hf_llama model error #3702
A more detailed stacktrace:

#0 nvfuser::BFS<std::shared_ptr<nvfuser::VectorOfUniqueEntries<nvfuser::Expr*, std::hash<nvfuser::Expr*> > >, std::shared_ptr<nvfuser::VectorOfUniqueEntries<nvfuser::Val*, std::hash<nvfuser::Val*> > >, nvfuser::ValGraphDefinitions, nvfuser::ValGraphUses, nvfuser::ValGraphInputs, nvfuser::ValGraphOutputs>::traverse (this=0x7ffd19ff7bf0) at /opt/pytorch/nvfuser/csrc/bfs.h:241
#1 0x00007ffd6b1e523f in nvfuser::IndexingTraversal::getExprsBetween (expr=0x7ffcf0201730, graph=..., from_domains=std::vector of length 3, capacity 3 = {...}, to_domains=std::vector of length 2, capacity 2 = {...})
at /opt/pytorch/nvfuser/csrc/id_model/indexing_traversal.cpp:297
#2 0x00007ffd6b1cf9e4 in nvfuser::TensorIndexer::computeIndex (this=0x7ffcf0394e40, expr=0x7ffcf0201730, index_ids=std::vector of length 2, capacity 2 = {...}, for_loops=std::vector of length 3, capacity 4 = {...})
at /opt/pytorch/nvfuser/csrc/id_model/indexing.cpp:956
#3 0x00007ffd6b1d214e in nvfuser::TensorIndexer::getContigIndexFor (this=0x7ffcf0394e40, expr=0x7ffcf0201730, as_consumer=false, alloc_info=..., for_loops=std::vector of length 3, capacity 4 = {...})
at /opt/pytorch/nvfuser/csrc/id_model/indexing.cpp:1316
#4 0x00007ffd6b1cf451 in nvfuser::TensorIndexer::getLinearIndex (this=0x7ffcf0394e40, tv=0x7ffcf0238350, expr=0x7ffcf0201730, for_loops=std::vector of length 3, capacity 4 = {...}) at /opt/pytorch/nvfuser/csrc/id_model/indexing.cpp:899
#5 0x00007ffd6b225eff in nvfuser::Index::getProducerIndex (producer=0x7ffcf0238350, consumer=0x7ffcf01fc290, loops=std::vector of length 3, capacity 4 = {...}, rotated_loops=std::unordered_set with 0 elements,
override_index=std::unordered_map with 0 elements, generate_pointer=false, as_type=...) at /opt/pytorch/nvfuser/csrc/index_compute.cpp:2134
#6 0x00007ffd6ae2aad3 in nvfuser::IndexLowering::lowerSrcIndex (this=0x7ffd19ff9000, src=0x7ffcf0238350, dst=0x7ffcf01fc290, override_index=std::unordered_map with 0 elements, generate_pointer=false, as_type=...)
at /opt/pytorch/nvfuser/csrc/device_lower/pass/index.cpp:48
#7 0x00007ffd6ae37e39 in nvfuser::IndexLowering::handle (this=0x7ffd19ff9000, ldst=0x7ffcf0201730) at /opt/pytorch/nvfuser/csrc/device_lower/pass/index.cpp:2144
#8 0x00007ffd6af3cc3a in nvfuser::Expr::constDispatch<nvfuser::OptOutConstDispatch*> (handler=0x7ffd19ff9000, expr=0x7ffcf0201730) at /opt/pytorch/nvfuser/csrc/dispatch.cpp:157
#9 0x00007ffd6af30539 in nvfuser::OptOutConstDispatch::dispatch (this=0x7ffd19ff9000, e=0x7ffcf0201730) at /opt/pytorch/nvfuser/csrc/dispatch.cpp:295
#10 0x00007ffd6ae2b0ae in nvfuser::IndexLowering::handle (this=0x7ffd19ff9000, ite=0x7ffcf0379020) at /opt/pytorch/nvfuser/csrc/device_lower/pass/index.cpp:131
#11 0x00007ffd6af3d78c in nvfuser::Expr::constDispatch<nvfuser::OptOutConstDispatch*> (handler=0x7ffd19ff9000, expr=0x7ffcf0379020) at /opt/pytorch/nvfuser/csrc/dispatch.cpp:165
#12 0x00007ffd6af30539 in nvfuser::OptOutConstDispatch::dispatch (this=0x7ffd19ff9000, e=0x7ffcf0379020) at /opt/pytorch/nvfuser/csrc/dispatch.cpp:295
#13 0x00007ffd6ae2b2b3 in nvfuser::IndexLowering::handle (this=0x7ffd19ff9000, for_loop=0x7ffcf03aa2c0) at /opt/pytorch/nvfuser/csrc/device_lower/pass/index.cpp:157
#14 0x00007ffd6af3d1e3 in nvfuser::Expr::constDispatch<nvfuser::OptOutConstDispatch*> (handler=0x7ffd19ff9000, expr=0x7ffcf03aa2c0) at /opt/pytorch/nvfuser/csrc/dispatch.cpp:157
#15 0x00007ffd6af30539 in nvfuser::OptOutConstDispatch::dispatch (this=0x7ffd19ff9000, e=0x7ffcf03aa2c0) at /opt/pytorch/nvfuser/csrc/dispatch.cpp:295
#16 0x00007ffd6ae2b2b3 in nvfuser::IndexLowering::handle (this=0x7ffd19ff9000, for_loop=0x7ffcf038f2d0) at /opt/pytorch/nvfuser/csrc/device_lower/pass/index.cpp:157
#17 0x00007ffd6af3d1e3 in nvfuser::Expr::constDispatch<nvfuser::OptOutConstDispatch*> (handler=0x7ffd19ff9000, expr=0x7ffcf038f2d0) at /opt/pytorch/nvfuser/csrc/dispatch.cpp:157
#18 0x00007ffd6af30539 in nvfuser::OptOutConstDispatch::dispatch (this=0x7ffd19ff9000, e=0x7ffcf038f2d0) at /opt/pytorch/nvfuser/csrc/dispatch.cpp:295
#19 0x00007ffd6ae2b0ae in nvfuser::IndexLowering::handle (this=0x7ffd19ff9000, ite=0x7ffcf03cb000) at /opt/pytorch/nvfuser/csrc/device_lower/pass/index.cpp:131
#20 0x00007ffd6af3d78c in nvfuser::Expr::constDispatch<nvfuser::OptOutConstDispatch*> (handler=0x7ffd19ff9000, expr=0x7ffcf03cb000) at /opt/pytorch/nvfuser/csrc/dispatch.cpp:165
#21 0x00007ffd6af30539 in nvfuser::OptOutConstDispatch::dispatch (this=0x7ffd19ff9000, e=0x7ffcf03cb000) at /opt/pytorch/nvfuser/csrc/dispatch.cpp:295
#22 0x00007ffd6ae2b2b3 in nvfuser::IndexLowering::handle (this=0x7ffd19ff9000, for_loop=0x7ffcf03b38a0) at /opt/pytorch/nvfuser/csrc/device_lower/pass/index.cpp:157
#23 0x00007ffd6af3d1e3 in nvfuser::Expr::constDispatch<nvfuser::OptOutConstDispatch*> (handler=0x7ffd19ff9000, expr=0x7ffcf03b38a0) at /opt/pytorch/nvfuser/csrc/dispatch.cpp:157
#24 0x00007ffd6af30539 in nvfuser::OptOutConstDispatch::dispatch (this=0x7ffd19ff9000, e=0x7ffcf03b38a0) at /opt/pytorch/nvfuser/csrc/dispatch.cpp:295
#25 0x00007ffd6ae3b0d1 in nvfuser::IndexLowering::generate (this=0x7ffd19ff9000, exprs=std::vector of length 1, capacity 1 = {...}) at /opt/pytorch/nvfuser/csrc/device_lower/pass/index.cpp:2613
#26 0x00007ffd6ae2a8bf in nvfuser::IndexLowering::getIndexedExprs (incoming_exprs=std::vector of length 1, capacity 1 = {...}) at /opt/pytorch/nvfuser/csrc/device_lower/pass/index.cpp:36
#27 0x00007ffd6adafe49 in std::__invoke_impl<std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> >, std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> > (*&)(std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> >), std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> > const&> (__f=@0x7ffd19ff9370: 0x7ffd6ae2a81e <nvfuser::IndexLowering::getIndexedExprs(std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> >)>)
at /usr/include/c++/13/bits/invoke.h:61
#28 0x00007ffd6adada73 in std::__invoke_r<std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> >, std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> > (*&)(std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> >), std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> > const&> (__fn=@0x7ffd19ff9370: 0x7ffd6ae2a81e <nvfuser::IndexLowering::getIndexedExprs(std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> >)>)
at /usr/include/c++/13/bits/invoke.h:116
#29 0x00007ffd6adab43d in std::_Function_handler<std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> > (std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> > const&), std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> > (*)(std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> >)>::_M_invoke(std::_Any_data const&, std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> > const&) (__functor=...,
__args#0=std::vector of length 1, capacity 1 = {...}) at /usr/include/c++/13/bits/std_function.h:291
#30 0x00007ffd6ada9d60 in std::function<std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> > (std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> > const&)>::operator()(std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> > const&) const (this=0x7ffd19ff9370, __args#0=std::vector of length 1, capacity 1 = {...}) at /usr/include/c++/13/bits/std_function.h:591
#31 0x00007ffd6ada0711 in nvfuser::GpuLower::run (this=0x7ffcf01b5de0) at /opt/pytorch/nvfuser/csrc/device_lower/lower2device.cpp:323
#32 0x00007ffd6b5c3409 in nvfuser::KernelExecutor::compile (this=0x7ffcf019aec0, fusion=0x7ffcf0003050, args=..., launch_constraints=..., compile_params=..., scheduler_type=nvfuser::SchedulerType::PointWise)
at /opt/pytorch/nvfuser/csrc/runtime/executor.cpp:429
Recording what we have found so far. Here's the failing segment: [image not captured]

Here's the Exact graph: [image not captured]

I mentioned it's strange to have this group: [image not captured] But actually this seems to make sense. These three IDs come from: [image not captured]

I must have been looking at different IDs.
Reduced test case:

import torch
from nvfuser import FusionDefinition, DataType

def a(fd: FusionDefinition) -> None:
    # Tensor definitions mirror the shapes/dtypes of `inputs` below.
    b = fd.define_tensor([1, 32, 6], [None, True, True])
    # Curiously these empty/useless define_tensors are critical to the bug!
    fd.define_tensor([1, 6, 2048], [None, True, True], DataType.BFloat16)
    d = fd.define_tensor([2048], [True], DataType.BFloat16, stride_order=[0])
    fd.define_tensor([2048, 2048], [True, True], DataType.BFloat16, stride_order=[1, 0])
    fd.define_tensor([2, 2048], [True, True], DataType.BFloat16, stride_order=[1, 0])
    e = fd.define_tensor([512, 2048], [True, True], DataType.BFloat16, stride_order=[1, 0])
    f = fd.ops.permute(b, [0, 2, 1])
    g = fd.ops.cat([f, f], -1)
    ab = fd.ops.cast(g, DataType.BFloat16)
    l = fd.ops.broadcast_in_dim(d, [1, 6, 2048], [2])
    n = fd.ops.mul(l, l)
    o = fd.ops.cast(n, DataType.BFloat16)
    p = fd.ops.linear(o, e)
    ah = fd.ops.reshape(p, [1, 6, 8, 64])
    ai = fd.ops.permute(ah, [0, 2, 1, 3])
    al = fd.ops.broadcast_in_dim(g, [1, 1, 6, 64], [0, 2, 3])
    am = fd.ops.broadcast_in_dim(ab, [1, 1, 6, 64], [0, 2, 3])
    bc = fd.ops.broadcast_in_dim(al, [1, 8, 6, 64], [0, 1, 2, 3])
    bf = fd.ops.mul(ai, bc)
    bg = fd.ops.slice(ai, [0, 0, 0, 0], [1, 8, 6, 32])
    bh = fd.ops.slice(ai, [0, 0, 0, 32], [1, 8, 6, 64])
    bl = fd.ops.cat([bh, bg], -1)
    bp = fd.ops.mul(bl, am)
    bq = fd.ops.add(bf, bp)
    br = fd.ops.cast(bq, DataType.BFloat16)
    bs = fd.ops.broadcast_in_dim(br, [1, 8, 1, 6, 64], [0, 1, 3, 4])
    bt = fd.ops.broadcast_in_dim(bs, [1, 8, 4, 6, 64], [0, 1, 2, 3, 4])
    bu = fd.ops.reshape(bt, [1, 32, 6, 64])
    by = fd.ops.stride_order(bu, [3, 2, 1, 0])
    fd.add_output(by)

with FusionDefinition() as fd:
    a(fd)

inputs = [
    torch.testing.make_tensor(1, 32, 6, dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor(1, 6, 2048, dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor(2048, dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor(2048, 2048, dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor(2, 2048, dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor(512, 2048, dtype=torch.bfloat16, device='cuda:0')]
fd.execute(inputs)

The automated reduction didn't do a great job; I manually reduced this a bit afterwards, and further simplification may still be possible.
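One way to keep further manual reduction honest is to re-run the definition after each simplification and check that the same internal assert still fires. The helper below is a sketch, not part of the original repro; `still_fails` and its string match on "bfs.h" are my own names and assumptions, and it relies only on `FusionDefinition` and `fd.execute` as used above:

from nvfuser import FusionDefinition

def still_fails(build_fn, inputs) -> bool:
    # Build the fusion with the candidate (further-reduced) definition.
    with FusionDefinition() as fd:
        build_fn(fd)
    try:
        fd.execute(inputs)
    except Exception as err:
        # The C++ INTERNAL ASSERT surfaces as a Python exception whose
        # message names the assert site; match on the bfs.h location.
        return "bfs.h" in str(err)
    return False

# Usage with the reduced test case above:
# assert still_fails(a, inputs)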
This blocks an effort to take more ops into the forward pass to reduce CPU overhead.
Error:
Error from segmentation group 14: INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/bfs.h":241, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. BFS traversal could not visit some nodes: idg{108 503 510} (from: idg{600 602 607 609 614 616 621 623 628 630 641 643 647 649 653 655 659 661 677 679 683 685 689 691 695 697 701 703 707 709 713 715 719 721 725 727 731 733 753 755 757 759} idg{603 610 617 624 631 644 650 656 662 680 686 692 698 704 710 716 722 728 734 758} idg{601 608 615 622 629 642 648 654 660 678 684 690 696 702 708 714 720 726 732 754 756 760}), visited: (idg{665 667 671 673 737 739 743 745 749 751} idg{389 393} idg{668 674 740 746 752} idg{663 664 669 670 735 736 741 742 747 748} idg{1 14 390 394 398 506 535 536} idg{666 672 738 744 750} idg{2 15 28 31 34 37 40 43 106 112 137 145 157 161 165 169 213 221 241 297 305 388 392 396 502 507 509 513 517 521 525 529 533} idg{603 610 617 624 631 644 650 656 662 680 686 692 698 704 710 716 722 728 734 758} idg{601 608 615 622 629 642 648 654 660 678 684 690 696 702 708 714 720 726 732 754 756 760} idg{397} idg{541 548 555 562 569 576 583 590 597 598 599 604 605 606 611 612 613 618 619 620 625 626 627 632 639 640 645 646 651 652 657 658 675 676 681 682 687 688 693 694 699 700 705 706 711 712 717 718 723 724 729 730} idg{600 602 607 609 614 616 621 623 628 630 641 643 647 649 653 655 659 661 677 679 683 685 689 691 695 697 701 703 707 709 713 715 719 721 725 727 731 733 753 755 757 759} idg{29 32 35 38 41 44 110 114 138 146 158 162 166 170 214 222 242 298 306 399 514 518 522 526 530 534 538 540} idg{0 13 27 30 33 36 39 42 105 111 135 136 143 155 156 159 163 164 167 211 219 239 240 295 303 387 391 395 501 505 508 511 515 516 519 523 527 531})
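Reading the message: each idg{...} is a group of IterDomain IDs in the Exact ValGraph, the from: groups appear to correspond to the loop domains (from_domains in frame #1 of the stacktrace), and the assert at bfs.h:241 fires because the BFS never reaches the target group idg{108 503 510}. Below is a minimal Python model of that reachability check, an illustration of the failing invariant only, not nvfuser's actual implementation (the real traversal operates on ValGraph groups and expressions):

from collections import deque

def traverse(graph, from_nodes, to_nodes):
    # Toy model of the check behind the bfs.h:241 assert: BFS from
    # `from_nodes`; fail if any node in `to_nodes` is unreachable.
    visited = set(from_nodes)
    queue = deque(from_nodes)
    while queue:
        node = queue.popleft()
        for nbr in graph.get(node, ()):
            if nbr not in visited:
                visited.add(nbr)
                queue.append(nbr)
    missing = set(to_nodes) - visited
    if missing:
        raise AssertionError(
            f"BFS traversal could not visit some nodes: {sorted(missing)}")
    return visited

# In the failing fusion, the Exact graph has no path from the loop-domain
# groups to idg{108 503 510}, so a check like this one raises.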
Repro: [script not captured; see the reduced test case above]