Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hf_llama model error #3702

Open
kevinstephano opened this issue Jan 14, 2025 · 3 comments
Open

hf_llama model error #3702

kevinstephano opened this issue Jan 14, 2025 · 3 comments
Assignees
Labels

Comments

@kevinstephano
Copy link
Collaborator

kevinstephano commented Jan 14, 2025

This error blocks fusing more ops into the forward pass, which we need in order to reduce CPU overhead.

Error:

Error from segmentation group 14:  INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/bfs.h":241, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. BFS traversal could no
t visit some nodes:  idg{108 503 510} (from:  idg{600 602 607 609 614 616 621 623 628 630 641 643 647 649 653 655 659 661 677 679 683 685 689 691 695 697 701 703 707 709 713 715 719 721 725 727 731 733 753 755 757
759} idg{603 610 617 624 631 644 650 656 662 680 686 692 698 704 710 716 722 728 734 758} idg{601 608 615 622 629 642 648 654 660 678 684 690 696 702 708 714 720 726 732 754 756 760}), visited: ( idg{665 667 671 67
3 737 739 743 745 749 751} idg{389 393} idg{668 674 740 746 752} idg{663 664 669 670 735 736 741 742 747 748} idg{1 14 390 394 398 506 535 536} idg{666 672 738 744 750} idg{2 15 28 31 34 37 40 43 106 112 137 145 15
7 161 165 169 213 221 241 297 305 388 392 396 502 507 509 513 517 521 525 529 533} idg{603 610 617 624 631 644 650 656 662 680 686 692 698 704 710 716 722 728 734 758} idg{601 608 615 622 629 642 648 654 660 678 68
4 690 696 702 708 714 720 726 732 754 756 760} idg{397} idg{541 548 555 562 569 576 583 590 597 598 599 604 605 606 611 612 613 618 619 620 625 626 627 632 639 640 645 646 651 652 657 658 675 676 681 682 687 688 69
3 694 699 700 705 706 711 712 717 718 723 724 729 730} idg{600 602 607 609 614 616 621 623 628 630 641 643 647 649 653 655 659 661 677 679 683 685 689 691 695 697 701 703 707 709 713 715 719 721 725 727 731 733 753
 755 757 759} idg{29 32 35 38 41 44 110 114 138 146 158 162 166 170 214 222 242 298 306 399 514 518 522 526 530 534 538 540} idg{0 13 27 30 33 36 39 42 105 111 135 136 143 155 156 159 163 164 167 211 219 239 240 29
5 303 387 391 395 501 505 508 511 515 516 519 523 527 531})

Repro:

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id3(fd : FusionDefinition) -> None :
    """Repro fusion extracted from an hf_llama forward segment.

    Builds an RMSNorm-style normalization, three linear projections,
    a cos/sin table with a rotate-half combination applied to two of
    the projections, and a broadcast/reshape expansion of the last
    projection, all in one FusionDefinition.  Executing it triggers an
    INTERNAL ASSERT in nvFuser's BFS indexing traversal
    (csrc/bfs.h:241) while lowering one pointwise segment — see the
    issue text above for the error and stack trace.
    """
    # Inputs.  Names follow the captured fusion; presumably T0 feeds the
    # rotary-embedding table, T1 is the hidden state, T2 the norm weight,
    # and T3/T4/T5 the Q/K/V projection weights — TODO confirm against
    # the original hf_llama trace.
    T0 = fd.define_tensor(shape=[1, 32, 6], contiguity=[None, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[1, 6, 2048], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.define_tensor(shape=[2048], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T3 = fd.define_tensor(shape=[2048, 2048], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T4 = fd.define_tensor(shape=[512, 2048], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T5 = fd.define_tensor(shape=[512, 2048], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    # cos/sin tables: duplicate T0 along the last dim, then take cos and
    # sin, scaled by 1.0 (scale kept to match the captured trace exactly).
    T6 = fd.ops.permute(T0, dims=[0, 2, 1])
    T7 = fd.ops.cat([T6, T6], dim=-1, manual_padding=0)
    T8 = fd.ops.cos(T7)
    T9 = fd.ops.sin(T7)
    S10 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T11 = fd.ops.mul(T8, S10)
    S12 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T13 = fd.ops.mul(T9, S12)
    T14 = fd.ops.cast(T11, dtype=DataType.BFloat16)
    T15 = fd.ops.cast(T13, dtype=DataType.BFloat16)
    # RMSNorm-style normalization of T1 in float: mean of squares over the
    # last dim (sum * 1/2048), + 1e-5 eps, rsqrt, then scale by weight T2.
    T16 = fd.ops.cast(T1, dtype=DataType.Float)
    S17 = fd.define_scalar(2.00000, dtype=DataType.Double)
    T18 = fd.ops.pow(T16, S17)
    T19 = fd.ops.sum(T18, dims=[2], keepdim=False, dtype=DataType.Null)
    T24 = fd.ops.broadcast_in_dim(T19, shape=[1, 6, 1], broadcast_dims=[0, 1])
    S25 = fd.define_scalar(2048.00, dtype=DataType.Double)
    S26 = fd.ops.reciprocal(S25)
    T27 = fd.ops.mul(T24, S26)
    S28 = fd.define_scalar(1.00000e-05, dtype=DataType.Double)
    T29 = fd.ops.add(T27, S28)
    T30 = fd.ops.rsqrt(T29)
    T35 = fd.ops.broadcast_in_dim(T30, shape=[1, 6, 2048], broadcast_dims=[0, 1, 2])
    T36 = fd.ops.mul(T16, T35)
    T41 = fd.ops.broadcast_in_dim(T2, shape=[1, 6, 2048], broadcast_dims=[2])
    T42 = fd.ops.cast(T41, dtype=DataType.Float)
    T43 = fd.ops.mul(T42, T36)
    T44 = fd.ops.cast(T43, dtype=DataType.BFloat16)
    # Three projections of the normalized activation; T45 keeps 2048
    # features, T46/T47 project to 512.
    T45 = fd.ops.linear(T44, T3)
    T46 = fd.ops.linear(T44, T4)
    T47 = fd.ops.linear(T44, T5)
    # Split features into (heads, 64) and move heads ahead of the
    # sequence dim: 32 heads for T45, 8 heads for T46/T47.
    T53 = fd.ops.reshape(T45, new_shape=[1, 6, 32, 64])
    T54 = fd.ops.permute(T53, dims=[0, 2, 1, 3])
    T60 = fd.ops.reshape(T46, new_shape=[1, 6, 8, 64])
    T61 = fd.ops.permute(T60, dims=[0, 2, 1, 3])
    T67 = fd.ops.reshape(T47, new_shape=[1, 6, 8, 64])
    T68 = fd.ops.permute(T67, dims=[0, 2, 1, 3])
    # Broadcast cos (T14) / sin (T15) to the 32-head layout and combine
    # with T54 using the rotate-half pattern: negate the second half of
    # the last dim, concatenate it in front of the first half, then
    # T54*cos + rotated*sin.
    T74 = fd.ops.broadcast_in_dim(T14, shape=[1, 1, 6, 64], broadcast_dims=[0, 2, 3])
    T80 = fd.ops.broadcast_in_dim(T15, shape=[1, 1, 6, 64], broadcast_dims=[0, 2, 3])
    T86 = fd.ops.broadcast_in_dim(T74, shape=[1, 32, 6, 64], broadcast_dims=[0, 1, 2, 3])
    T87 = fd.ops.cast(T54, dtype=DataType.Float)
    T88 = fd.ops.cast(T86, dtype=DataType.Float)
    T89 = fd.ops.mul(T87, T88)
    T105 = fd.ops.slice(T54, start_indices=[0, 0, 0, 0], end_indices=[1, 32, 6, 32], strides=[1, 1, 1, 1], manual_normalization=0)
    T121 = fd.ops.slice(T54, start_indices=[0, 0, 0, 32], end_indices=[1, 32, 6, 64], strides=[1, 1, 1, 1], manual_normalization=0)
    T122 = fd.ops.cast(T121, dtype=DataType.Float)
    T123 = fd.ops.neg(T122)
    T124 = fd.ops.cast(T123, dtype=DataType.BFloat16)
    T125 = fd.ops.cat([T124, T105], dim=-1, manual_padding=0)
    T131 = fd.ops.broadcast_in_dim(T80, shape=[1, 32, 6, 64], broadcast_dims=[0, 1, 2, 3])
    T132 = fd.ops.cast(T125, dtype=DataType.Float)
    T133 = fd.ops.cast(T131, dtype=DataType.Float)
    T134 = fd.ops.mul(T132, T133)
    T135 = fd.ops.add(T89, T134)
    T136 = fd.ops.cast(T135, dtype=DataType.BFloat16)
    # Same rotate-half combination applied to the 8-head projection T61.
    T142 = fd.ops.broadcast_in_dim(T74, shape=[1, 8, 6, 64], broadcast_dims=[0, 1, 2, 3])
    T143 = fd.ops.cast(T61, dtype=DataType.Float)
    T144 = fd.ops.cast(T142, dtype=DataType.Float)
    T145 = fd.ops.mul(T143, T144)
    T161 = fd.ops.slice(T61, start_indices=[0, 0, 0, 0], end_indices=[1, 8, 6, 32], strides=[1, 1, 1, 1], manual_normalization=0)
    T177 = fd.ops.slice(T61, start_indices=[0, 0, 0, 32], end_indices=[1, 8, 6, 64], strides=[1, 1, 1, 1], manual_normalization=0)
    T178 = fd.ops.cast(T177, dtype=DataType.Float)
    T179 = fd.ops.neg(T178)
    T180 = fd.ops.cast(T179, dtype=DataType.BFloat16)
    T181 = fd.ops.cat([T180, T161], dim=-1, manual_padding=0)
    T187 = fd.ops.broadcast_in_dim(T80, shape=[1, 8, 6, 64], broadcast_dims=[0, 1, 2, 3])
    T188 = fd.ops.cast(T181, dtype=DataType.Float)
    T189 = fd.ops.cast(T187, dtype=DataType.Float)
    T190 = fd.ops.mul(T188, T189)
    T191 = fd.ops.add(T145, T190)
    T192 = fd.ops.cast(T191, dtype=DataType.BFloat16)
    # Expand the 8-head tensors T192 and T68 by a factor of 4
    # (broadcast a new size-4 dim, then flatten to 32 heads).
    T199 = fd.ops.broadcast_in_dim(T192, shape=[1, 8, 1, 6, 64], broadcast_dims=[0, 1, 3, 4])
    T206 = fd.ops.broadcast_in_dim(T199, shape=[1, 8, 4, 6, 64], broadcast_dims=[0, 1, 2, 3, 4])
    T212 = fd.ops.reshape(T206, new_shape=[1, 32, 6, 64])
    T219 = fd.ops.broadcast_in_dim(T68, shape=[1, 8, 1, 6, 64], broadcast_dims=[0, 1, 3, 4])
    T226 = fd.ops.broadcast_in_dim(T219, shape=[1, 8, 4, 6, 64], broadcast_dims=[0, 1, 2, 3, 4])
    T232 = fd.ops.reshape(T226, new_shape=[1, 32, 6, 64])
    # Force a contiguous last-dim-innermost layout on the three results.
    T233 = fd.ops.stride_order(T136, stride_order=[3, 2, 1, 0])
    T234 = fd.ops.stride_order(T212, stride_order=[3, 2, 1, 0])
    T235 = fd.ops.stride_order(T232, stride_order=[3, 2, 1, 0])
    fd.add_output(T68)
    fd.add_output(T192)
    fd.add_output(T233)
    fd.add_output(T234)
    fd.add_output(T235)

# Build the repro fusion and execute it on CUDA inputs whose shapes and
# dtypes match the tensors declared in nvfuser_fusion_id3.
with FusionDefinition() as fd:
    nvfuser_fusion_id3(fd)

_INPUT_SPECS = [
    ((1, 32, 6), torch.float32),
    ((1, 6, 2048), torch.bfloat16),
    ((2048,), torch.bfloat16),
    ((2048, 2048), torch.bfloat16),
    ((512, 2048), torch.bfloat16),
    ((512, 2048), torch.bfloat16),
]
inputs = [
    torch.testing.make_tensor(shape, dtype=dtype, device='cuda:0')
    for shape, dtype in _INPUT_SPECS
]
fd.execute(inputs)
@kevinstephano
Copy link
Collaborator Author

A more detailed stacktrace:

#0  nvfuser::BFS<std::shared_ptr<nvfuser::VectorOfUniqueEntries<nvfuser::Expr*, std::hash<nvfuser::Expr*> > >, std::shared_ptr<nvfuser::VectorOfUniqueEntries<nvfuser::Val*, std::hash<nvfuser::Val*> > >, nvfuser::ValGraphDefinitions, nvfuser::ValGraphUses, nvfuser::ValGraphInputs, nvfuser::ValGraphOutputs>::traverse (this=0x7ffd19ff7bf0) at /opt/pytorch/nvfuser/csrc/bfs.h:241
#1  0x00007ffd6b1e523f in nvfuser::IndexingTraversal::getExprsBetween (expr=0x7ffcf0201730, graph=..., from_domains=std::vector of length 3, capacity 3 = {...}, to_domains=std::vector of length 2, capacity 2 = {...})
    at /opt/pytorch/nvfuser/csrc/id_model/indexing_traversal.cpp:297
#2  0x00007ffd6b1cf9e4 in nvfuser::TensorIndexer::computeIndex (this=0x7ffcf0394e40, expr=0x7ffcf0201730, index_ids=std::vector of length 2, capacity 2 = {...}, for_loops=std::vector of length 3, capacity 4 = {...})
    at /opt/pytorch/nvfuser/csrc/id_model/indexing.cpp:956
#3  0x00007ffd6b1d214e in nvfuser::TensorIndexer::getContigIndexFor (this=0x7ffcf0394e40, expr=0x7ffcf0201730, as_consumer=false, alloc_info=..., for_loops=std::vector of length 3, capacity 4 = {...})
    at /opt/pytorch/nvfuser/csrc/id_model/indexing.cpp:1316
#4  0x00007ffd6b1cf451 in nvfuser::TensorIndexer::getLinearIndex (this=0x7ffcf0394e40, tv=0x7ffcf0238350, expr=0x7ffcf0201730, for_loops=std::vector of length 3, capacity 4 = {...}) at /opt/pytorch/nvfuser/csrc/id_model/indexing.cpp:899
#5  0x00007ffd6b225eff in nvfuser::Index::getProducerIndex (producer=0x7ffcf0238350, consumer=0x7ffcf01fc290, loops=std::vector of length 3, capacity 4 = {...}, rotated_loops=std::unordered_set with 0 elements,
    override_index=std::unordered_map with 0 elements, generate_pointer=false, as_type=...) at /opt/pytorch/nvfuser/csrc/index_compute.cpp:2134
#6  0x00007ffd6ae2aad3 in nvfuser::IndexLowering::lowerSrcIndex (this=0x7ffd19ff9000, src=0x7ffcf0238350, dst=0x7ffcf01fc290, override_index=std::unordered_map with 0 elements, generate_pointer=false, as_type=...)
    at /opt/pytorch/nvfuser/csrc/device_lower/pass/index.cpp:48
#7  0x00007ffd6ae37e39 in nvfuser::IndexLowering::handle (this=0x7ffd19ff9000, ldst=0x7ffcf0201730) at /opt/pytorch/nvfuser/csrc/device_lower/pass/index.cpp:2144
#8  0x00007ffd6af3cc3a in nvfuser::Expr::constDispatch<nvfuser::OptOutConstDispatch*> (handler=0x7ffd19ff9000, expr=0x7ffcf0201730) at /opt/pytorch/nvfuser/csrc/dispatch.cpp:157
#9  0x00007ffd6af30539 in nvfuser::OptOutConstDispatch::dispatch (this=0x7ffd19ff9000, e=0x7ffcf0201730) at /opt/pytorch/nvfuser/csrc/dispatch.cpp:295
#10 0x00007ffd6ae2b0ae in nvfuser::IndexLowering::handle (this=0x7ffd19ff9000, ite=0x7ffcf0379020) at /opt/pytorch/nvfuser/csrc/device_lower/pass/index.cpp:131
#11 0x00007ffd6af3d78c in nvfuser::Expr::constDispatch<nvfuser::OptOutConstDispatch*> (handler=0x7ffd19ff9000, expr=0x7ffcf0379020) at /opt/pytorch/nvfuser/csrc/dispatch.cpp:165
#12 0x00007ffd6af30539 in nvfuser::OptOutConstDispatch::dispatch (this=0x7ffd19ff9000, e=0x7ffcf0379020) at /opt/pytorch/nvfuser/csrc/dispatch.cpp:295
#13 0x00007ffd6ae2b2b3 in nvfuser::IndexLowering::handle (this=0x7ffd19ff9000, for_loop=0x7ffcf03aa2c0) at /opt/pytorch/nvfuser/csrc/device_lower/pass/index.cpp:157
#14 0x00007ffd6af3d1e3 in nvfuser::Expr::constDispatch<nvfuser::OptOutConstDispatch*> (handler=0x7ffd19ff9000, expr=0x7ffcf03aa2c0) at /opt/pytorch/nvfuser/csrc/dispatch.cpp:157
#15 0x00007ffd6af30539 in nvfuser::OptOutConstDispatch::dispatch (this=0x7ffd19ff9000, e=0x7ffcf03aa2c0) at /opt/pytorch/nvfuser/csrc/dispatch.cpp:295
#16 0x00007ffd6ae2b2b3 in nvfuser::IndexLowering::handle (this=0x7ffd19ff9000, for_loop=0x7ffcf038f2d0) at /opt/pytorch/nvfuser/csrc/device_lower/pass/index.cpp:157
#17 0x00007ffd6af3d1e3 in nvfuser::Expr::constDispatch<nvfuser::OptOutConstDispatch*> (handler=0x7ffd19ff9000, expr=0x7ffcf038f2d0) at /opt/pytorch/nvfuser/csrc/dispatch.cpp:157
#18 0x00007ffd6af30539 in nvfuser::OptOutConstDispatch::dispatch (this=0x7ffd19ff9000, e=0x7ffcf038f2d0) at /opt/pytorch/nvfuser/csrc/dispatch.cpp:295
#19 0x00007ffd6ae2b0ae in nvfuser::IndexLowering::handle (this=0x7ffd19ff9000, ite=0x7ffcf03cb000) at /opt/pytorch/nvfuser/csrc/device_lower/pass/index.cpp:131
#20 0x00007ffd6af3d78c in nvfuser::Expr::constDispatch<nvfuser::OptOutConstDispatch*> (handler=0x7ffd19ff9000, expr=0x7ffcf03cb000) at /opt/pytorch/nvfuser/csrc/dispatch.cpp:165
#21 0x00007ffd6af30539 in nvfuser::OptOutConstDispatch::dispatch (this=0x7ffd19ff9000, e=0x7ffcf03cb000) at /opt/pytorch/nvfuser/csrc/dispatch.cpp:295
#22 0x00007ffd6ae2b2b3 in nvfuser::IndexLowering::handle (this=0x7ffd19ff9000, for_loop=0x7ffcf03b38a0) at /opt/pytorch/nvfuser/csrc/device_lower/pass/index.cpp:157
#23 0x00007ffd6af3d1e3 in nvfuser::Expr::constDispatch<nvfuser::OptOutConstDispatch*> (handler=0x7ffd19ff9000, expr=0x7ffcf03b38a0) at /opt/pytorch/nvfuser/csrc/dispatch.cpp:157
#24 0x00007ffd6af30539 in nvfuser::OptOutConstDispatch::dispatch (this=0x7ffd19ff9000, e=0x7ffcf03b38a0) at /opt/pytorch/nvfuser/csrc/dispatch.cpp:295
#25 0x00007ffd6ae3b0d1 in nvfuser::IndexLowering::generate (this=0x7ffd19ff9000, exprs=std::vector of length 1, capacity 1 = {...}) at /opt/pytorch/nvfuser/csrc/device_lower/pass/index.cpp:2613
#26 0x00007ffd6ae2a8bf in nvfuser::IndexLowering::getIndexedExprs (incoming_exprs=std::vector of length 1, capacity 1 = {...}) at /opt/pytorch/nvfuser/csrc/device_lower/pass/index.cpp:36
#27 0x00007ffd6adafe49 in std::__invoke_impl<std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> >, std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> > (*&)(std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> >), std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> > const&> (__f=@0x7ffd19ff9370: 0x7ffd6ae2a81e <nvfuser::IndexLowering::getIndexedExprs(std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> >)>)
    at /usr/include/c++/13/bits/invoke.h:61
#28 0x00007ffd6adada73 in std::__invoke_r<std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> >, std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> > (*&)(std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> >), std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> > const&> (__fn=@0x7ffd19ff9370: 0x7ffd6ae2a81e <nvfuser::IndexLowering::getIndexedExprs(std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> >)>)
    at /usr/include/c++/13/bits/invoke.h:116
#29 0x00007ffd6adab43d in std::_Function_handler<std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> > (std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> > const&), std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> > (*)(std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> >)>::_M_invoke(std::_Any_data const&, std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> > const&) (__functor=...,
    __args#0=std::vector of length 1, capacity 1 = {...}) at /usr/include/c++/13/bits/std_function.h:291
#30 0x00007ffd6ada9d60 in std::function<std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> > (std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> > const&)>::operator()(std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> > const&) const (this=0x7ffd19ff9370, __args#0=std::vector of length 1, capacity 1 = {...}) at /usr/include/c++/13/bits/std_function.h:591
#31 0x00007ffd6ada0711 in nvfuser::GpuLower::run (this=0x7ffcf01b5de0) at /opt/pytorch/nvfuser/csrc/device_lower/lower2device.cpp:323
#32 0x00007ffd6b5c3409 in nvfuser::KernelExecutor::compile (this=0x7ffcf019aec0, fusion=0x7ffcf0003050, args=..., launch_constraints=..., compile_params=..., scheduler_type=nvfuser::SchedulerType::PointWise)
    at /opt/pytorch/nvfuser/csrc/runtime/executor.cpp:429

@naoyam naoyam self-assigned this Jan 14, 2025
@naoyam
Copy link
Collaborator

naoyam commented Jan 15, 2025

Recording what we have found so far.

Here's the failing segment:

Inputs:
  T0_g_float[iS751{2}, iS752{1}, iS750{128}]
  T32_g___bfloat[iS581{96}, iS582{1}, iS580{128}]
Outputs:
  T61_g_float[iblockIdx.x733{3}, iUS734{1}, ithreadIdx.x732{128}, bS512{1 ex 32}] ca_pos( 2 ) produce_pos( 3 )
  T66_g___bfloat[iblockIdx.x630{3 ex ( ceilDiv(( 8 * 384 ), 128) )}, iUS631{1}, ithreadIdx.x629{128}] ca_pos( 2 ) produce_pos( 3 )
  T81_g_float[iblockIdx.x715{3}, iUS716{1}, ithreadIdx.x714{128}, bS520{1 ex 8}] ca_pos( 2 ) produce_pos( 3 )
  T129_g___bfloat[iblockIdx.x546{96}, iUS547{1}, ithreadIdx.x545{128}] ca_pos( 2 ) produce_pos( 3 )
  T49_g_float[iblockIdx.x637{96}, iUS638{1}, ithreadIdx.x636{128}] ca_pos( 2 ) produce_pos( 3 )

%kernel {
T124_l_float[iblockIdx.x745{2}, iUS746{1}, ithreadIdx.x744{128}] ca_pos( 2 )
   = Set( T0_g_float[iS751{2}, iS752{1}, iS750{128}], cache_op=Streaming )
T6_l_float[iblockIdx.x739{2}, iUS740{1}, ithreadIdx.x738{128}] ca_pos( 3 ) produce_pos( 2 )
   = Set.Permute( T124_l_float[iblockIdx.x745{2}, iUS746{1}, ithreadIdx.x744{128}] ca_pos( 2 ), cache_op=Streaming )
T98_l_float[iblockIdx.x673{2}, iUS674{1}, ithreadIdx.x672{128}] ca_pos( 3 ) produce_pos( 3 )
   = broadcast( T6_l_float[iblockIdx.x739{2}, iUS740{1}, ithreadIdx.x738{128}] ca_pos( 3 ) produce_pos( 2 ), flags = {false, false, true, false} )
T99_l_float[iblockIdx.x667{2 ex ( ceilDiv(( 6 * ( 2 * 32 ) ), 128) )}, iUS668{1}, ithreadIdx.x666{128}] ca_pos( 3 ) produce_pos( 3 ) = expand( T98_l_float[iblockIdx.x673{2}, iUS674{1}, ithreadIdx.x672{128}] ca_pos( 3 ) produce_pos( 3 ), {1, 6, 2, 32} )
T100_l_float[iblockIdx.x661{( ceilDiv(( ( 2 * 32 ) * 6 ), 128) )}, iUS662{1}, ithreadIdx.x660{128}] ca_pos( 3 ) produce_pos( 3 ) = view( T99_l_float[iblockIdx.x667{2 ex ( ceilDiv(( 6 * ( 2 * 32 ) ), 128) )}, iUS668{1}, ithreadIdx.x666{128}] ca_pos( 3 ) produce_pos( 3 ) )
T11_l_float[iblockIdx.x679{3}, iUS680{1}, ithreadIdx.x678{128}] ca_pos( 3 ) produce_pos( 3 )
   = sinf(T100_l_float[iblockIdx.x661{( ceilDiv(( ( 2 * 32 ) * 6 ), 128) )}, iUS662{1}, ithreadIdx.x660{128}] ca_pos( 3 ) produce_pos( 3 ));
T13_l_float[iblockIdx.x685{3}, iUS686{1}, ithreadIdx.x684{128}] ca_pos( 3 ) produce_pos( 3 )
   = T11_l_float[iblockIdx.x679{3}, iUS680{1}, ithreadIdx.x678{128}] ca_pos( 3 ) produce_pos( 3 )
   * double(1);
T15_l___bfloat[iblockIdx.x691{3}, iUS692{1}, ithreadIdx.x690{128}] ca_pos( 3 ) produce_pos( 3 )
   = __float2bfloat(T13_l_float[iblockIdx.x685{3}, iUS686{1}, ithreadIdx.x684{128}] ca_pos( 3 ) produce_pos( 3 ));
T43_l___bfloat[iblockIdx.x697{3}, iUS698{1}, ithreadIdx.x696{128}, bS144{1}] ca_pos( 3 ) produce_pos( 3 )
   = broadcast( T15_l___bfloat[iblockIdx.x691{3}, iUS692{1}, ithreadIdx.x690{128}] ca_pos( 3 ) produce_pos( 3 ), flags = {false, true, false, false} )
T59_l___bfloat[iblockIdx.x721{3}, iUS722{1}, ithreadIdx.x720{128}, bS212{1 ex 32}] ca_pos( 3 ) produce_pos( 3 ) = expand( T43_l___bfloat[iblockIdx.x697{3}, iUS698{1}, ithreadIdx.x696{128}, bS144{1}] ca_pos( 3 ) produce_pos( 3 ), {1, 32, 6, 64} )
T126_l_float[iblockIdx.x727{3}, iUS728{1}, ithreadIdx.x726{128}, bS220{1 ex 32}] ca_pos( 3 ) produce_pos( 3 )
   = __bfloat2float(T59_l___bfloat[iblockIdx.x721{3}, iUS722{1}, ithreadIdx.x720{128}, bS212{1 ex 32}] ca_pos( 3 ) produce_pos( 3 ));
T61_g_float[iblockIdx.x733{3}, iUS734{1}, ithreadIdx.x732{128}, bS512{1 ex 32}] ca_pos( 2 ) produce_pos( 3 )
   = Set( T126_l_float[iblockIdx.x727{3}, iUS728{1}, ithreadIdx.x726{128}, bS220{1 ex 32}] ca_pos( 3 ) produce_pos( 3 ), cache_op=Streaming )
T10_l_float[iblockIdx.x655{3}, iUS656{1}, ithreadIdx.x654{128}] ca_pos( 3 ) produce_pos( 3 )
   = cosf(T100_l_float[iblockIdx.x661{( ceilDiv(( ( 2 * 32 ) * 6 ), 128) )}, iUS662{1}, ithreadIdx.x660{128}] ca_pos( 3 ) produce_pos( 3 ));
T12_l_float[iblockIdx.x649{3}, iUS650{1}, ithreadIdx.x648{128}] ca_pos( 3 ) produce_pos( 3 )
   = T10_l_float[iblockIdx.x655{3}, iUS656{1}, ithreadIdx.x654{128}] ca_pos( 3 ) produce_pos( 3 )
   * double(1);
T14_l___bfloat[iblockIdx.x643{3}, iUS644{1}, ithreadIdx.x642{128}] ca_pos( 3 ) produce_pos( 3 )
   = __float2bfloat(T12_l_float[iblockIdx.x649{3}, iUS650{1}, ithreadIdx.x648{128}] ca_pos( 3 ) produce_pos( 3 ));
T41_l___bfloat[iblockIdx.x616{3}, iUS617{1}, ithreadIdx.x615{128}] ca_pos( 3 ) produce_pos( 3 )
   = broadcast( T14_l___bfloat[iblockIdx.x643{3}, iUS644{1}, ithreadIdx.x642{128}] ca_pos( 3 ) produce_pos( 3 ), flags = {false, true, false, false} )
T127_l___bfloat[iblockIdx.x623{3 ex ( ceilDiv(( 8 * 384 ), 128) )}, iUS624{1}, ithreadIdx.x622{128}] ca_pos( 3 ) produce_pos( 3 ) = expand( T41_l___bfloat[iblockIdx.x616{3}, iUS617{1}, ithreadIdx.x615{128}] ca_pos( 3 ) produce_pos( 3 ), {1, 8, 6, 64} )
T66_g___bfloat[iblockIdx.x630{3 ex ( ceilDiv(( 8 * 384 ), 128) )}, iUS631{1}, ithreadIdx.x629{128}] ca_pos( 2 ) produce_pos( 3 )
   = Set( T127_l___bfloat[iblockIdx.x623{3 ex ( ceilDiv(( 8 * 384 ), 128) )}, iUS624{1}, ithreadIdx.x622{128}] ca_pos( 3 ) produce_pos( 3 ), cache_op=Streaming )
T79_l___bfloat[iblockIdx.x703{3}, iUS704{1}, ithreadIdx.x702{128}, bS296{1 ex 8}] ca_pos( 3 ) produce_pos( 3 ) = expand( T43_l___bfloat[iblockIdx.x697{3}, iUS698{1}, ithreadIdx.x696{128}, bS144{1}] ca_pos( 3 ) produce_pos( 3 ), {1, 8, 6, 64} )
T128_l_float[iblockIdx.x709{3}, iUS710{1}, ithreadIdx.x708{128}, bS304{1 ex 8}] ca_pos( 3 ) produce_pos( 3 )
   = __bfloat2float(T79_l___bfloat[iblockIdx.x703{3}, iUS704{1}, ithreadIdx.x702{128}, bS296{1 ex 8}] ca_pos( 3 ) produce_pos( 3 ));
T81_g_float[iblockIdx.x715{3}, iUS716{1}, ithreadIdx.x714{128}, bS520{1 ex 8}] ca_pos( 2 ) produce_pos( 3 )
   = Set( T128_l_float[iblockIdx.x709{3}, iUS710{1}, ithreadIdx.x708{128}, bS304{1 ex 8}] ca_pos( 3 ) produce_pos( 3 ), cache_op=Streaming )
T125_l___bfloat[iblockIdx.x574{96}, iUS575{1}, ithreadIdx.x573{128}] ca_pos( 2 )
   = Set( T32_g___bfloat[iS581{96}, iS582{1}, iS580{128}], cache_op=Streaming )
T35_l___bfloat[iblockIdx.x567{96}, iUS568{1}, ithreadIdx.x566{128}] ca_pos( 3 ) produce_pos( 2 ) = view( T125_l___bfloat[iblockIdx.x574{96}, iUS575{1}, ithreadIdx.x573{128}] ca_pos( 2 ) )
T36_l___bfloat[iblockIdx.x560{96}, iUS561{1}, ithreadIdx.x559{128}] ca_pos( 3 ) produce_pos( 3 )
   = Set.Permute( T35_l___bfloat[iblockIdx.x567{96}, iUS568{1}, ithreadIdx.x566{128}] ca_pos( 3 ) produce_pos( 2 ), cache_op=Streaming )
T130_l___bfloat[iblockIdx.x553{96}, iUS554{1}, ithreadIdx.x552{128}] ca_pos( 3 ) produce_pos( 3 )
   = Set( T36_l___bfloat[iblockIdx.x560{96}, iUS561{1}, ithreadIdx.x559{128}] ca_pos( 3 ) produce_pos( 3 ), cache_op=Streaming )
T129_g___bfloat[iblockIdx.x546{96}, iUS547{1}, ithreadIdx.x545{128}] ca_pos( 2 ) produce_pos( 3 )
   = Set( T130_l___bfloat[iblockIdx.x553{96}, iUS554{1}, ithreadIdx.x552{128}] ca_pos( 3 ) produce_pos( 3 ), cache_op=Streaming )
T47_l_float[iblockIdx.x588{96}, iUS589{1}, ithreadIdx.x587{128}] ca_pos( 3 ) produce_pos( 3 )
   = __bfloat2float(T36_l___bfloat[iblockIdx.x560{96}, iUS561{1}, ithreadIdx.x559{128}] ca_pos( 3 ) produce_pos( 3 ));
T46_l___bfloat[iblockIdx.x609{3 ex ( ceilDiv(( 32 * 384 ), 128) )}, iUS610{1}, ithreadIdx.x608{128}] ca_pos( 3 ) produce_pos( 3 ) = expand( T41_l___bfloat[iblockIdx.x616{3}, iUS617{1}, ithreadIdx.x615{128}] ca_pos( 3 ) produce_pos( 3 ), {1, 32, 6, 64} )
T48_l_float[iblockIdx.x602{3 ex ( ceilDiv(( 32 * 384 ), 128) )}, iUS603{1}, ithreadIdx.x601{128}] ca_pos( 3 ) produce_pos( 3 )
   = __bfloat2float(T46_l___bfloat[iblockIdx.x609{3 ex ( ceilDiv(( 32 * 384 ), 128) )}, iUS610{1}, ithreadIdx.x608{128}] ca_pos( 3 ) produce_pos( 3 ));
T131_l_float[iblockIdx.x595{96}, iUS596{1}, ithreadIdx.x594{128}] ca_pos( 3 ) produce_pos( 3 )
   = T47_l_float[iblockIdx.x588{96}, iUS589{1}, ithreadIdx.x587{128}] ca_pos( 3 ) produce_pos( 3 )
   * T48_l_float[iblockIdx.x602{3 ex ( ceilDiv(( 32 * 384 ), 128) )}, iUS603{1}, ithreadIdx.x601{128}] ca_pos( 3 ) produce_pos( 3 );
T49_g_float[iblockIdx.x637{96}, iUS638{1}, ithreadIdx.x636{128}] ca_pos( 2 ) produce_pos( 3 )
   = Set( T131_l_float[iblockIdx.x595{96}, iUS596{1}, ithreadIdx.x594{128}] ca_pos( 3 ) produce_pos( 3 ), cache_op=Streaming )
TransformPrinter :
T0_g_float[iS751{2}, iS752{1}, iS750{128}]
 logical domain : (bS0{1}, iS1{32}, iS2{6})
 contiguity: n t t
  Merge: iS2{6} and iS1{32} -> iS747{192}
  Merge: bS0{1} and iS747{192} -> iS748{192}
  Split: iS748{192} by factor 128 -> iS749{2}, iS750{128}
  Split: iS749{2} by factor 1 -> iS751{2}, iS752{1}
 loop domain : (iS751{2}, iS752{1}, iS750{128})
T124_l_float[iblockIdx.x745{2}, iUS746{1}, ithreadIdx.x744{128}] ca_pos( 2 )
 logical domain : (bS505{1}, iS506{32}, iS507{6})
 contiguity: n t t
  Merge: iS507{6} and iS506{32} -> iS741{192}
  Merge: bS505{1} and iS741{192} -> iS742{192}
  Split: iS742{192} by factor 128 -> iS743{2}, ithreadIdx.x744{128}
  Split: iS743{2} by factor 1 -> iblockIdx.x745{2}, iUS746{1}
 loop domain : (iblockIdx.x745{2}, iUS746{1}, ithreadIdx.x744{128})
T6_l_float[iblockIdx.x739{2}, iUS740{1}, ithreadIdx.x738{128}] ca_pos( 3 ) produce_pos( 2 )
 root domain : (bS13{1}, iS14{32}, iS15{6})
 logical domain : (bS13{1}, iS15{6}, iS14{32})
 allocation domain : (bS13{1}, iS14{32}, iS15{6})
 contiguity: n t t
  Merge: iS15{6} and iS14{32} -> iS735{192}
  Merge: bS13{1} and iS735{192} -> iS736{192}
  Split: iS736{192} by factor 128 -> iS737{2}, ithreadIdx.x738{128}
  Split: iS737{2} by factor 1 -> iblockIdx.x739{2}, iUS740{1}
 loop domain : (iblockIdx.x739{2}, iUS740{1}, ithreadIdx.x738{128})
T98_l_float[iblockIdx.x673{2}, iUS674{1}, ithreadIdx.x672{128}] ca_pos( 3 ) produce_pos( 3 )
 logical domain : (bS387{1}, iS388{6}, bS389{1}, iS390{32})
 allocation domain : (bS387{1}, iS390{32}, iS388{6}, bS389{1})
 contiguity: n t t n
  Merge: bS389{1} and iS390{32} -> iS536{32}
  Merge: iS388{6} and iS536{32} -> iS669{192}
  Merge: bS387{1} and iS669{192} -> iS670{192}
  Split: iS670{192} by factor 128 -> iS671{2}, ithreadIdx.x672{128}
  Split: iS671{2} by factor 1 -> iblockIdx.x673{2}, iUS674{1}
 loop domain : (iblockIdx.x673{2}, iUS674{1}, ithreadIdx.x672{128})
T99_l_float[iblockIdx.x667{2 ex ( ceilDiv(( 6 * ( 2 * 32 ) ), 128) )}, iUS668{1}, ithreadIdx.x666{128}] ca_pos( 3 ) produce_pos( 3 )
 logical domain : (bS391{1}, iS392{6}, bS393{1 ex 2}, iS394{32})
 allocation domain : (bS391{1}, iS394{32}, iS392{6}, bS393{1 ex 2})
 contiguity: n t t n
  Merge: bS393{1 ex 2} and iS394{32} -> iS535{32 ex ( 2 * 32 )}
  Merge: iS392{6} and iS535{32 ex ( 2 * 32 )} -> iS663{192 ex ( 6 * ( 2 * 32 ) )}
  Merge: bS391{1} and iS663{192 ex ( 6 * ( 2 * 32 ) )} -> iS664{192 ex ( 6 * ( 2 * 32 ) )}
  Split: iS664{192 ex ( 6 * ( 2 * 32 ) )} by factor 128 -> iS665{2 ex ( ceilDiv(( 6 * ( 2 * 32 ) ), 128) )}, ithreadIdx.x666{128}
  Split: iS665{2 ex ( ceilDiv(( 6 * ( 2 * 32 ) ), 128) )} by factor 1 -> iblockIdx.x667{2 ex ( ceilDiv(( 6 * ( 2 * 32 ) ), 128) )}, iUS668{1}
 loop domain : (iblockIdx.x667{2 ex ( ceilDiv(( 6 * ( 2 * 32 ) ), 128) )}, iUS668{1}, ithreadIdx.x666{128})
T100_l_float[iblockIdx.x661{( ceilDiv(( ( 2 * 32 ) * 6 ), 128) )}, iUS662{1}, ithreadIdx.x660{128}] ca_pos( 3 ) produce_pos( 3 )
 root domain : (bS395{1}, iS396{6}, iS397{2 ex 2}rf, iS398{32}rf)
  Merge: iS397{2 ex 2}rf and iS398{32}rf -> iS399{( 2 * 32 )}rf
 logical domain : (bS395{1}, iS396{6}, iS399{( 2 * 32 )}rf)
 contiguity: n t t
  Merge: iS396{6} and iS399{( 2 * 32 )}rf -> iS657{( ( 2 * 32 ) * 6 )}
  Merge: bS395{1} and iS657{( ( 2 * 32 ) * 6 )} -> iS658{( ( 2 * 32 ) * 6 )}
  Split: iS658{( ( 2 * 32 ) * 6 )} by factor 128 -> iS659{( ceilDiv(( ( 2 * 32 ) * 6 ), 128) )}, ithreadIdx.x660{128}
  Split: iS659{( ceilDiv(( ( 2 * 32 ) * 6 ), 128) )} by factor 1 -> iblockIdx.x661{( ceilDiv(( ( 2 * 32 ) * 6 ), 128) )}, iUS662{1}
 loop domain : (iblockIdx.x661{( ceilDiv(( ( 2 * 32 ) * 6 ), 128) )}, iUS662{1}, ithreadIdx.x660{128})
T11_l_float[iblockIdx.x679{3}, iUS680{1}, ithreadIdx.x678{128}] ca_pos( 3 ) produce_pos( 3 )
 logical domain : (bS30{1}, iS31{6}, iS32{64})
 contiguity: n t t
  Merge: iS31{6} and iS32{64} -> iS675{384}
  Merge: bS30{1} and iS675{384} -> iS676{384}
  Split: iS676{384} by factor 128 -> iS677{3}, ithreadIdx.x678{128}
  Split: iS677{3} by factor 1 -> iblockIdx.x679{3}, iUS680{1}
 loop domain : (iblockIdx.x679{3}, iUS680{1}, ithreadIdx.x678{128})
T13_l_float[iblockIdx.x685{3}, iUS686{1}, ithreadIdx.x684{128}] ca_pos( 3 ) produce_pos( 3 )
 logical domain : (bS36{1}, iS37{6}, iS38{64})
 contiguity: n t t
  Merge: iS37{6} and iS38{64} -> iS681{384}
  Merge: bS36{1} and iS681{384} -> iS682{384}
  Split: iS682{384} by factor 128 -> iS683{3}, ithreadIdx.x684{128}
  Split: iS683{3} by factor 1 -> iblockIdx.x685{3}, iUS686{1}
 loop domain : (iblockIdx.x685{3}, iUS686{1}, ithreadIdx.x684{128})
T15_l___bfloat[iblockIdx.x691{3}, iUS692{1}, ithreadIdx.x690{128}] ca_pos( 3 ) produce_pos( 3 )
 logical domain : (bS42{1}, iS43{6}, iS44{64})
 contiguity: n t t
  Merge: iS43{6} and iS44{64} -> iS687{384}
  Merge: bS42{1} and iS687{384} -> iS688{384}
  Split: iS688{384} by factor 128 -> iS689{3}, ithreadIdx.x690{128}
  Split: iS689{3} by factor 1 -> iblockIdx.x691{3}, iUS692{1}
 loop domain : (iblockIdx.x691{3}, iUS692{1}, ithreadIdx.x690{128})
T43_l___bfloat[iblockIdx.x697{3}, iUS698{1}, ithreadIdx.x696{128}, bS144{1}] ca_pos( 3 ) produce_pos( 3 )
 logical domain : (bS143{1}, bS144{1}, iS145{6}, iS146{64})
 allocation domain : (bS143{1}, iS145{6}, iS146{64}, bS144{1})
 contiguity: n t t n
  Merge: iS145{6} and iS146{64} -> iS693{384}
  Merge: bS143{1} and iS693{384} -> iS694{384}
  Split: iS694{384} by factor 128 -> iS695{3}, ithreadIdx.x696{128}
  Split: iS695{3} by factor 1 -> iblockIdx.x697{3}, iUS698{1}
 loop domain : (iblockIdx.x697{3}, iUS698{1}, ithreadIdx.x696{128}, bS144{1})
T59_l___bfloat[iblockIdx.x721{3}, iUS722{1}, ithreadIdx.x720{128}, bS212{1 ex 32}] ca_pos( 3 ) produce_pos( 3 )
 logical domain : (bS211{1}, bS212{1 ex 32}, iS213{6}, iS214{64})
 allocation domain : (bS211{1}, iS213{6}, iS214{64}, bS212{1 ex 32})
 contiguity: n t t n
  Merge: iS213{6} and iS214{64} -> iS717{384}
  Merge: bS211{1} and iS717{384} -> iS718{384}
  Split: iS718{384} by factor 128 -> iS719{3}, ithreadIdx.x720{128}
  Split: iS719{3} by factor 1 -> iblockIdx.x721{3}, iUS722{1}
 loop domain : (iblockIdx.x721{3}, iUS722{1}, ithreadIdx.x720{128}, bS212{1 ex 32})
T126_l_float[iblockIdx.x727{3}, iUS728{1}, ithreadIdx.x726{128}, bS220{1 ex 32}] ca_pos( 3 ) produce_pos( 3 )
 logical domain : (bS219{1}, bS220{1 ex 32}, iS221{6}, iS222{64})
 contiguity: n n t t
  Merge: iS221{6} and iS222{64} -> iS723{384}
  Merge: bS219{1} and iS723{384} -> iS724{384}
  Split: iS724{384} by factor 128 -> iS725{3}, ithreadIdx.x726{128}
  Split: iS725{3} by factor 1 -> iblockIdx.x727{3}, iUS728{1}
 loop domain : (iblockIdx.x727{3}, iUS728{1}, ithreadIdx.x726{128}, bS220{1 ex 32})
T61_g_float[iblockIdx.x733{3}, iUS734{1}, ithreadIdx.x732{128}, bS512{1 ex 32}] ca_pos( 2 ) produce_pos( 3 )
 logical domain : (bS511{1}, bS512{1 ex 32}, iS513{6}, iS514{64})
 contiguity: n n t t
  Merge: iS513{6} and iS514{64} -> iS729{384}
  Merge: bS511{1} and iS729{384} -> iS730{384}
  Split: iS730{384} by factor 128 -> iS731{3}, ithreadIdx.x732{128}
  Split: iS731{3} by factor 1 -> iblockIdx.x733{3}, iUS734{1}
 loop domain : (iblockIdx.x733{3}, iUS734{1}, ithreadIdx.x732{128}, bS512{1 ex 32})
T10_l_float[iblockIdx.x655{3}, iUS656{1}, ithreadIdx.x654{128}] ca_pos( 3 ) produce_pos( 3 )
 logical domain : (bS27{1}, iS28{6}, iS29{64})
 contiguity: n t t
  Merge: iS28{6} and iS29{64} -> iS651{384}
  Merge: bS27{1} and iS651{384} -> iS652{384}
  Split: iS652{384} by factor 128 -> iS653{3}, ithreadIdx.x654{128}
  Split: iS653{3} by factor 1 -> iblockIdx.x655{3}, iUS656{1}
 loop domain : (iblockIdx.x655{3}, iUS656{1}, ithreadIdx.x654{128})
T12_l_float[iblockIdx.x649{3}, iUS650{1}, ithreadIdx.x648{128}] ca_pos( 3 ) produce_pos( 3 )
 logical domain : (bS33{1}, iS34{6}, iS35{64})
 contiguity: n t t
  Merge: iS34{6} and iS35{64} -> iS645{384}
  Merge: bS33{1} and iS645{384} -> iS646{384}
  Split: iS646{384} by factor 128 -> iS647{3}, ithreadIdx.x648{128}
  Split: iS647{3} by factor 1 -> iblockIdx.x649{3}, iUS650{1}
 loop domain : (iblockIdx.x649{3}, iUS650{1}, ithreadIdx.x648{128})
T14_l___bfloat[iblockIdx.x643{3}, iUS644{1}, ithreadIdx.x642{128}] ca_pos( 3 ) produce_pos( 3 )
 logical domain : (bS39{1}, iS40{6}, iS41{64})
 contiguity: n t t
  Merge: iS40{6} and iS41{64} -> iS639{384}
  Merge: bS39{1} and iS639{384} -> iS640{384}
  Split: iS640{384} by factor 128 -> iS641{3}, ithreadIdx.x642{128}
  Split: iS641{3} by factor 1 -> iblockIdx.x643{3}, iUS644{1}
 loop domain : (iblockIdx.x643{3}, iUS644{1}, ithreadIdx.x642{128})
T41_l___bfloat[iblockIdx.x616{3}, iUS617{1}, ithreadIdx.x615{128}] ca_pos( 3 ) produce_pos( 3 )
 logical domain : (bS135{1}, bS136{1}, iS137{6}, iS138{64})
 allocation domain : (bS135{1}, iS137{6}, iS138{64}, bS136{1})
 contiguity: n t t n
  Merge: iS137{6} and iS138{64} -> iS611{384}
  Merge: bS136{1} and iS611{384} -> iS612{384}
  Merge: bS135{1} and iS612{384} -> iS613{384}
  Split: iS613{384} by factor 128 -> iS614{3}, ithreadIdx.x615{128}
  Split: iS614{3} by factor 1 -> iblockIdx.x616{3}, iUS617{1}
 loop domain : (iblockIdx.x616{3}, iUS617{1}, ithreadIdx.x615{128})
T127_l___bfloat[iblockIdx.x623{3 ex ( ceilDiv(( 8 * 384 ), 128) )}, iUS624{1}, ithreadIdx.x622{128}] ca_pos( 3 ) produce_pos( 3 )
 logical domain : (bS239{1}, bS240{1 ex 8}, iS241{6}, iS242{64})
 allocation domain : (bS239{1}, iS241{6}, iS242{64}, bS240{1 ex 8})
 contiguity: n t t n
  Merge: iS241{6} and iS242{64} -> iS618{384}
  Merge: bS240{1 ex 8} and iS618{384} -> iS619{384 ex ( 8 * 384 )}
  Merge: bS239{1} and iS619{384 ex ( 8 * 384 )} -> iS620{384 ex ( 8 * 384 )}
  Split: iS620{384 ex ( 8 * 384 )} by factor 128 -> iS621{3 ex ( ceilDiv(( 8 * 384 ), 128) )}, ithreadIdx.x622{128}
  Split: iS621{3 ex ( ceilDiv(( 8 * 384 ), 128) )} by factor 1 -> iblockIdx.x623{3 ex ( ceilDiv(( 8 * 384 ), 128) )}, iUS624{1}
 loop domain : (iblockIdx.x623{3 ex ( ceilDiv(( 8 * 384 ), 128) )}, iUS624{1}, ithreadIdx.x622{128})
T66_g___bfloat[iblockIdx.x630{3 ex ( ceilDiv(( 8 * 384 ), 128) )}, iUS631{1}, ithreadIdx.x629{128}] ca_pos( 2 ) produce_pos( 3 )
 logical domain : (bS515{1}, bS516{1 ex 8}, iS517{6}, iS518{64})
 allocation domain : (bS515{1}, iS517{6}, iS518{64}, bS516{1 ex 8})
 contiguity: n t t n
  Merge: iS517{6} and iS518{64} -> iS625{384}
  Merge: bS516{1 ex 8} and iS625{384} -> iS626{384 ex ( 8 * 384 )}
  Merge: bS515{1} and iS626{384 ex ( 8 * 384 )} -> iS627{384 ex ( 8 * 384 )}
  Split: iS627{384 ex ( 8 * 384 )} by factor 128 -> iS628{3 ex ( ceilDiv(( 8 * 384 ), 128) )}, ithreadIdx.x629{128}
  Split: iS628{3 ex ( ceilDiv(( 8 * 384 ), 128) )} by factor 1 -> iblockIdx.x630{3 ex ( ceilDiv(( 8 * 384 ), 128) )}, iUS631{1}
 loop domain : (iblockIdx.x630{3 ex ( ceilDiv(( 8 * 384 ), 128) )}, iUS631{1}, ithreadIdx.x629{128})
T79_l___bfloat[iblockIdx.x703{3}, iUS704{1}, ithreadIdx.x702{128}, bS296{1 ex 8}] ca_pos( 3 ) produce_pos( 3 )
 logical domain : (bS295{1}, bS296{1 ex 8}, iS297{6}, iS298{64})
 allocation domain : (bS295{1}, iS297{6}, iS298{64}, bS296{1 ex 8})
 contiguity: n t t n
  Merge: iS297{6} and iS298{64} -> iS699{384}
  Merge: bS295{1} and iS699{384} -> iS700{384}
  Split: iS700{384} by factor 128 -> iS701{3}, ithreadIdx.x702{128}
  Split: iS701{3} by factor 1 -> iblockIdx.x703{3}, iUS704{1}
 loop domain : (iblockIdx.x703{3}, iUS704{1}, ithreadIdx.x702{128}, bS296{1 ex 8})
T128_l_float[iblockIdx.x709{3}, iUS710{1}, ithreadIdx.x708{128}, bS304{1 ex 8}] ca_pos( 3 ) produce_pos( 3 )
 logical domain : (bS303{1}, bS304{1 ex 8}, iS305{6}, iS306{64})
 contiguity: n n t t
  Merge: iS305{6} and iS306{64} -> iS705{384}
  Merge: bS303{1} and iS705{384} -> iS706{384}
  Split: iS706{384} by factor 128 -> iS707{3}, ithreadIdx.x708{128}
  Split: iS707{3} by factor 1 -> iblockIdx.x709{3}, iUS710{1}
 loop domain : (iblockIdx.x709{3}, iUS710{1}, ithreadIdx.x708{128}, bS304{1 ex 8})
T81_g_float[iblockIdx.x715{3}, iUS716{1}, ithreadIdx.x714{128}, bS520{1 ex 8}] ca_pos( 2 ) produce_pos( 3 )
 logical domain : (bS519{1}, bS520{1 ex 8}, iS521{6}, iS522{64})
 contiguity: n n t t
  Merge: iS521{6} and iS522{64} -> iS711{384}
  Merge: bS519{1} and iS711{384} -> iS712{384}
  Split: iS712{384} by factor 128 -> iS713{3}, ithreadIdx.x714{128}
  Split: iS713{3} by factor 1 -> iblockIdx.x715{3}, iUS716{1}
 loop domain : (iblockIdx.x715{3}, iUS716{1}, ithreadIdx.x714{128}, bS520{1 ex 8})
T32_g___bfloat[iS581{96}, iS582{1}, iS580{128}]
 logical domain : (bS501{1}, iS502{6}, iS503{2048})
 contiguity: n t t
  Outer split: iS503{2048} by factor 32 -> iS539{32}, iS540{64}
  Merge: iS502{6} and iS540{64} -> iS576{384}
  Merge: iS539{32} and iS576{384} -> iS577{12288}
  Merge: bS501{1} and iS577{12288} -> iS578{12288}
  Split: iS578{12288} by factor 128 -> iS579{96}, iS580{128}
  Split: iS579{96} by factor 1 -> iS581{96}, iS582{1}
 loop domain : (iS581{96}, iS582{1}, iS580{128})
T125_l___bfloat[iblockIdx.x574{96}, iUS575{1}, ithreadIdx.x573{128}] ca_pos( 2 )
 logical domain : (bS508{1}, iS509{6}, iS510{2048})
 contiguity: n t t
  Outer split: iS510{2048} by factor 32 -> iS537{32}, iS538{64}
  Merge: iS509{6} and iS538{64} -> iS569{384}
  Merge: iS537{32} and iS569{384} -> iS570{12288}
  Merge: bS508{1} and iS570{12288} -> iS571{12288}
  Split: iS571{12288} by factor 128 -> iS572{96}, ithreadIdx.x573{128}
  Split: iS572{96} by factor 1 -> iblockIdx.x574{96}, iUS575{1}
 loop domain : (iblockIdx.x574{96}, iUS575{1}, ithreadIdx.x573{128})
T35_l___bfloat[iblockIdx.x567{96}, iUS568{1}, ithreadIdx.x566{128}] ca_pos( 3 ) produce_pos( 2 )
 root domain : (bS105{1}, iS106{6}, iS108{2048}rf)
  Outer split: iS108{2048}rf by factor 32 -> iS109{32}rf, iS110{64}rf
 logical domain : (bS105{1}, iS106{6}, iS109{32}rf, iS110{64}rf)
 allocation domain : (bS105{1}, iS106{6}, iS109{32}rf, iS110{64}rf)
 contiguity: n t t t
  Merge: iS106{6} and iS110{64}rf -> iS562{384}
  Merge: iS109{32}rf and iS562{384} -> iS563{12288}
  Merge: bS105{1} and iS563{12288} -> iS564{12288}
  Split: iS564{12288} by factor 128 -> iS565{96}, ithreadIdx.x566{128}
  Split: iS565{96} by factor 1 -> iblockIdx.x567{96}, iUS568{1}
 loop domain : (iblockIdx.x567{96}, iUS568{1}, ithreadIdx.x566{128})
T36_l___bfloat[iblockIdx.x560{96}, iUS561{1}, ithreadIdx.x559{128}] ca_pos( 3 ) produce_pos( 3 )
 root domain : (bS111{1}, iS112{6}, iS113{32}, iS114{64})
 logical domain : (bS111{1}, iS113{32}, iS112{6}, iS114{64})
 allocation domain : (bS111{1}, iS112{6}, iS113{32}, iS114{64})
 contiguity: n t t t
  Merge: iS112{6} and iS114{64} -> iS555{384}
  Merge: iS113{32} and iS555{384} -> iS556{12288}
  Merge: bS111{1} and iS556{12288} -> iS557{12288}
  Split: iS557{12288} by factor 128 -> iS558{96}, ithreadIdx.x559{128}
  Split: iS558{96} by factor 1 -> iblockIdx.x560{96}, iUS561{1}
 loop domain : (iblockIdx.x560{96}, iUS561{1}, ithreadIdx.x559{128})
T130_l___bfloat[iblockIdx.x553{96}, iUS554{1}, ithreadIdx.x552{128}] ca_pos( 3 ) produce_pos( 3 )
 logical domain : (bS523{1}, iS524{32}, iS525{6}, iS526{64})
 allocation domain : (bS523{1}, iS525{6}, iS524{32}, iS526{64})
 contiguity: n t t t
  Merge: iS525{6} and iS526{64} -> iS548{384}
  Merge: iS524{32} and iS548{384} -> iS549{12288}
  Merge: bS523{1} and iS549{12288} -> iS550{12288}
  Split: iS550{12288} by factor 128 -> iS551{96}, ithreadIdx.x552{128}
  Split: iS551{96} by factor 1 -> iblockIdx.x553{96}, iUS554{1}
 loop domain : (iblockIdx.x553{96}, iUS554{1}, ithreadIdx.x552{128})
T129_g___bfloat[iblockIdx.x546{96}, iUS547{1}, ithreadIdx.x545{128}] ca_pos( 2 ) produce_pos( 3 )
 logical domain : (bS527{1}, iS528{32}, iS529{6}, iS530{64})
 allocation domain : (bS527{1}, iS529{6}, iS528{32}, iS530{64})
 contiguity: n t t t
  Merge: iS529{6} and iS530{64} -> iS541{384}
  Merge: iS528{32} and iS541{384} -> iS542{12288}
  Merge: bS527{1} and iS542{12288} -> iS543{12288}
  Split: iS543{12288} by factor 128 -> iS544{96}, ithreadIdx.x545{128}
  Split: iS544{96} by factor 1 -> iblockIdx.x546{96}, iUS547{1}
 loop domain : (iblockIdx.x546{96}, iUS547{1}, ithreadIdx.x545{128})
T47_l_float[iblockIdx.x588{96}, iUS589{1}, ithreadIdx.x587{128}] ca_pos( 3 ) produce_pos( 3 )
 logical domain : (bS159{1}, iS160{32}, iS161{6}, iS162{64})
 contiguity: n t t t
  Merge: iS161{6} and iS162{64} -> iS583{384}
  Merge: iS160{32} and iS583{384} -> iS584{12288}
  Merge: bS159{1} and iS584{12288} -> iS585{12288}
  Split: iS585{12288} by factor 128 -> iS586{96}, ithreadIdx.x587{128}
  Split: iS586{96} by factor 1 -> iblockIdx.x588{96}, iUS589{1}
 loop domain : (iblockIdx.x588{96}, iUS589{1}, ithreadIdx.x587{128})
T46_l___bfloat[iblockIdx.x609{3 ex ( ceilDiv(( 32 * 384 ), 128) )}, iUS610{1}, ithreadIdx.x608{128}] ca_pos( 3 ) produce_pos( 3 )
 logical domain : (bS155{1}, bS156{1 ex 32}, iS157{6}, iS158{64})
 allocation domain : (bS155{1}, iS157{6}, iS158{64}, bS156{1 ex 32})
 contiguity: n t t n
  Merge: iS157{6} and iS158{64} -> iS604{384}
  Merge: bS156{1 ex 32} and iS604{384} -> iS605{384 ex ( 32 * 384 )}
  Merge: bS155{1} and iS605{384 ex ( 32 * 384 )} -> iS606{384 ex ( 32 * 384 )}
  Split: iS606{384 ex ( 32 * 384 )} by factor 128 -> iS607{3 ex ( ceilDiv(( 32 * 384 ), 128) )}, ithreadIdx.x608{128}
  Split: iS607{3 ex ( ceilDiv(( 32 * 384 ), 128) )} by factor 1 -> iblockIdx.x609{3 ex ( ceilDiv(( 32 * 384 ), 128) )}, iUS610{1}
 loop domain : (iblockIdx.x609{3 ex ( ceilDiv(( 32 * 384 ), 128) )}, iUS610{1}, ithreadIdx.x608{128})
T48_l_float[iblockIdx.x602{3 ex ( ceilDiv(( 32 * 384 ), 128) )}, iUS603{1}, ithreadIdx.x601{128}] ca_pos( 3 ) produce_pos( 3 )
 logical domain : (bS163{1}, bS164{1 ex 32}, iS165{6}, iS166{64})
 contiguity: n n t t
  Merge: iS165{6} and iS166{64} -> iS597{384}
  Merge: bS164{1 ex 32} and iS597{384} -> iS598{384 ex ( 32 * 384 )}
  Merge: bS163{1} and iS598{384 ex ( 32 * 384 )} -> iS599{384 ex ( 32 * 384 )}
  Split: iS599{384 ex ( 32 * 384 )} by factor 128 -> iS600{3 ex ( ceilDiv(( 32 * 384 ), 128) )}, ithreadIdx.x601{128}
  Split: iS600{3 ex ( ceilDiv(( 32 * 384 ), 128) )} by factor 1 -> iblockIdx.x602{3 ex ( ceilDiv(( 32 * 384 ), 128) )}, iUS603{1}
 loop domain : (iblockIdx.x602{3 ex ( ceilDiv(( 32 * 384 ), 128) )}, iUS603{1}, ithreadIdx.x601{128})
T131_l_float[iblockIdx.x595{96}, iUS596{1}, ithreadIdx.x594{128}] ca_pos( 3 ) produce_pos( 3 )
 logical domain : (bS167{1}, iS168{32}, iS169{6}, iS170{64})
 contiguity: n t t t
  Merge: iS169{6} and iS170{64} -> iS590{384}
  Merge: iS168{32} and iS590{384} -> iS591{12288}
  Merge: bS167{1} and iS591{12288} -> iS592{12288}
  Split: iS592{12288} by factor 128 -> iS593{96}, ithreadIdx.x594{128}
  Split: iS593{96} by factor 1 -> iblockIdx.x595{96}, iUS596{1}
 loop domain : (iblockIdx.x595{96}, iUS596{1}, ithreadIdx.x594{128})
T49_g_float[iblockIdx.x637{96}, iUS638{1}, ithreadIdx.x636{128}] ca_pos( 2 ) produce_pos( 3 )
 logical domain : (bS531{1}, iS532{32}, iS533{6}, iS534{64})
 contiguity: n t t t
  Merge: iS533{6} and iS534{64} -> iS632{384}
  Merge: iS532{32} and iS632{384} -> iS633{12288}
  Merge: bS531{1} and iS633{12288} -> iS634{12288}
  Split: iS634{12288} by factor 128 -> iS635{96}, ithreadIdx.x636{128}
  Split: iS635{96} by factor 1 -> iblockIdx.x637{96}, iUS638{1}
 loop domain : (iblockIdx.x637{96}, iUS638{1}, ithreadIdx.x636{128})
} // %kernel

Here's the Exact graph.

Disjoint Expression groups:
  (exprgs){
    (idgs){idg{108 503 510}} --exprg{35 225 226}--> (idgs){idg{29 32 35 38 41 44 110 114 138 146 158 162 166 170 214 222 242 298 306 399 514 518 522 526 530 534 538 540}, idg{109 113 160 168 524 528 532 537 539}}
    (idgs){idg{1 14 390 394 398 506}, idg{397}} --exprg{120}--> (idgs){idg{29 32 35 38 41 44 110 114 138 146 158 162 166 170 214 222 242 298 306 399 514 518 522 526 530 534 538 540}}
    (idgs){idg{1 14 390 394 398 506}, idg{389 393}} --exprg{223 224}--> (idgs){idg{535 536}}
    (idgs){idg{2 15 28 31 34 37 40 43 106 112 137 145 157 161 165 169 213 221 241 297 305 388 392 396 502 507 509 513 517 521 525 529 533}, idg{29 32 35 38 41 44 110 114 138 146 158 162 166 170 214 222 242 298 306 399 514 518 522 526 530 534 538 540}} --exprg{227 232 237 242 247 252 257 262 267 274 281 286 293 300 305 309 313 318 333 337 341 345 349 353 357 361 365 369}--> (idgs){idg{541 548 555 562 569 576 583 590 597 604 611 618 625 632 639 645 651 657 675 681 687 693 699 705 711 717 723 729}}
    (idgs){idg{109 113 160 168 524 528 532 537 539}, idg{541 548 555 562 569 576 583 590 597 604 611 618 625 632 639 645 651 657 675 681 687 693 699 705 711 717 723 729}} --exprg{228 233 238 243 248 253 258 263 301}--> (idgs){idg{542 549 556 563 570 577 584 591 633}}
    (idgs){idg{0 13 27 30 33 36 39 42 105 111 135 143 155 159 163 167 211 219 239 295 303 387 391 395 501 505 508 511 515 519 523 527 531}, idg{542 549 556 563 570 577 584 591 633}} --exprg{229 234 239 244 249 254 259 264 302}--> (idgs){idg{543 550 557 564 571 578 585 592 634}}
    (idgs){idg{543 550 557 564 571 578 585 592 634}} --exprg{230 235 240 245 250 255 260 265 303}--> (idgs){idg{544 551 558 565 572 579 586 593 635}, idg{545 552 559 566 573 580 587 594 636}}
    (idgs){idg{544 551 558 565 572 579 586 593 635}} --exprg{231 236 241 246 251 256 261 266 304}--> (idgs){idg{546 553 560 567 574 581 588 595 637}, idg{547 554 561 568 575 582 589 596 638}}
    (idgs){idg{136 156 164 240 516}, idg{541 548 555 562 569 576 583 590 597 604 611 618 625 632 639 645 651 657 675 681 687 693 699 705 711 717 723 729}} --exprg{269 276 282 288 295}--> (idgs){idg{598 605 612 619 626}}
    (idgs){idg{0 13 27 30 33 36 39 42 105 111 135 143 155 159 163 167 211 219 239 295 303 387 391 395 501 505 508 511 515 519 523 527 531}, idg{598 605 612 619 626}} --exprg{270 277 283 289 296}--> (idgs){idg{599 606 613 620 627}}
    (idgs){idg{599 606 613 620 627}} --exprg{272 279 284 291 298}--> (idgs){idg{600 607 614 621 628}, idg{601 608 615 622 629}}
    (idgs){idg{600 607 614 621 628}} --exprg{273 280 285 292 299}--> (idgs){idg{602 609 616 623 630}, idg{603 610 617 624 631}}
    (idgs){idg{0 13 27 30 33 36 39 42 105 111 135 143 155 159 163 167 211 219 239 295 303 387 391 395 501 505 508 511 515 519 523 527 531}, idg{541 548 555 562 569 576 583 590 597 604 611 618 625 632 639 645 651 657 675 681 687 693 699 705 711 717 723 729}} --exprg{306 310 314 319 334 338 342 346 350 354 358 362 366 370}--> (idgs){idg{640 646 652 658 676 682 688 694 700 706 712 718 724 730}}
    (idgs){idg{640 646 652 658 676 682 688 694 700 706 712 718 724 730}} --exprg{307 311 315 321 335 339 343 347 351 355 359 363 367 371}--> (idgs){idg{641 647 653 659 677 683 689 695 701 707 713 719 725 731}, idg{642 648 654 660 678 684 690 696 702 708 714 720 726 732}}
    (idgs){idg{641 647 653 659 677 683 689 695 701 707 713 719 725 731}} --exprg{308 312 316 322 336 340 344 348 352 356 360 364 368 372}--> (idgs){idg{643 649 655 661 679 685 691 697 703 709 715 721 727 733}, idg{644 650 656 662 680 686 692 698 704 710 716 722 728 734}}
    (idgs){idg{2 15 28 31 34 37 40 43 106 112 137 145 157 161 165 169 213 221 241 297 305 388 392 396 502 507 509 513 517 521 525 529 533}, idg{535 536}} --exprg{324 329}--> (idgs){idg{663 669}}
    (idgs){idg{0 13 27 30 33 36 39 42 105 111 135 143 155 159 163 167 211 219 239 295 303 387 391 395 501 505 508 511 515 519 523 527 531}, idg{663 669}} --exprg{325 330}--> (idgs){idg{664 670}}
    (idgs){idg{664 670}} --exprg{327 331}--> (idgs){idg{665 671}, idg{666 672}}
    (idgs){idg{665 671}} --exprg{328 332}--> (idgs){idg{667 673}, idg{668 674}}
    (idgs){idg{1 14 390 394 398 506}, idg{2 15 28 31 34 37 40 43 106 112 137 145 157 161 165 169 213 221 241 297 305 388 392 396 502 507 509 513 517 521 525 529 533}} --exprg{373 377 381}--> (idgs){idg{735 741 747}}
    (idgs){idg{0 13 27 30 33 36 39 42 105 111 135 143 155 159 163 167 211 219 239 295 303 387 391 395 501 505 508 511 515 519 523 527 531}, idg{735 741 747}} --exprg{374 378 382}--> (idgs){idg{736 742 748}}
    (idgs){idg{736 742 748}} --exprg{375 379 383}--> (idgs){idg{737 743 749}, idg{738 744 750}}
    (idgs){idg{737 743 749}} --exprg{376 380 384}--> (idgs){idg{739 745 751}, idg{740 746 752}}
  }
 } IdGraph


IdGraph {
Disjoint Ids:
  (idgs){
    idg{0 13 27 30 33 36 39 42 105 111 135 143 155 159 163 167 211 219 239 295 303 387 391 395 501 505 508 511 515 519 523 527 531}
    idg{1 14 390 394 398 506}
    idg{2 15 28 31 34 37 40 43 106 112 137 145 157 161 165 169 213 221 241 297 305 388 392 396 502 507 509 513 517 521 525 529 533}
    idg{29 32 35 38 41 44 110 114 138 146 158 162 166 170 214 222 242 298 306 399 514 518 522 526 530 534 538 540}
    idg{108 503 510}
    idg{109 113 160 168 524 528 532 537 539}
    idg{136 156 164 240 516}
    idg{144 212 220 296 304 512 520}
    idg{389 393}
    idg{397}
    idg{535 536}
    idg{541 548 555 562 569 576 583 590 597 604 611 618 625 632 639 645 651 657 675 681 687 693 699 705 711 717 723 729}
    idg{542 549 556 563 570 577 584 591 633}
    idg{543 550 557 564 571 578 585 592 634}
    idg{544 551 558 565 572 579 586 593 635}
    idg{545 552 559 566 573 580 587 594 636}
    idg{546 553 560 567 574 581 588 595 637}
    idg{547 554 561 568 575 582 589 596 638}
    idg{598 605 612 619 626}
    idg{599 606 613 620 627}
    idg{600 607 614 621 628}
    idg{601 608 615 622 629}
    idg{602 609 616 623 630}
    idg{603 610 617 624 631}
    idg{640 646 652 658 676 682 688 694 700 706 712 718 724 730}
    idg{641 647 653 659 677 683 689 695 701 707 713 719 725 731}
    idg{642 648 654 660 678 684 690 696 702 708 714 720 726 732}
    idg{643 649 655 661 679 685 691 697 703 709 715 721 727 733}
    idg{644 650 656 662 680 686 692 698 704 710 716 722 728 734}
    idg{663 669}
    idg{664 670}
    idg{665 671}
    idg{666 672}
    idg{667 673}
    idg{668 674}
    idg{735 741 747}
    idg{736 742 748}
    idg{737 743 749}
    idg{738 744 750}
    idg{739 745 751}
    idg{740 746 752}
}

I mentioned it's strange to have this group:

idg{108 503 510}

But actually this seems to make sense. These three IDs come from:

T32_g___bfloat[iS581{96}, iS582{1}, iS580{128}]
 logical domain : (bS501{1}, iS502{6}, iS503{2048})
 contiguity: n t t
  Outer split: iS503{2048} by factor 32 -> iS539{32}, iS540{64}
  Merge: iS502{6} and iS540{64} -> iS576{384}
  Merge: iS539{32} and iS576{384} -> iS577{12288}
  Merge: bS501{1} and iS577{12288} -> iS578{12288}
  Split: iS578{12288} by factor 128 -> iS579{96}, iS580{128}
  Split: iS579{96} by factor 1 -> iS581{96}, iS582{1}
 loop domain : (iS581{96}, iS582{1}, iS580{128})
T125_l___bfloat[iblockIdx.x574{96}, iUS575{1}, ithreadIdx.x573{128}] ca_pos( 2 )
 logical domain : (bS508{1}, iS509{6}, iS510{2048})
 contiguity: n t t
  Outer split: iS510{2048} by factor 32 -> iS537{32}, iS538{64}
  Merge: iS509{6} and iS538{64} -> iS569{384}
  Merge: iS537{32} and iS569{384} -> iS570{12288}
  Merge: bS508{1} and iS570{12288} -> iS571{12288}
  Split: iS571{12288} by factor 128 -> iS572{96}, ithreadIdx.x573{128}
  Split: iS572{96} by factor 1 -> iblockIdx.x574{96}, iUS575{1}
 loop domain : (iblockIdx.x574{96}, iUS575{1}, ithreadIdx.x573{128})
T35_l___bfloat[iblockIdx.x567{96}, iUS568{1}, ithreadIdx.x566{128}] ca_pos( 3 ) produce_pos( 2 )
 root domain : (bS105{1}, iS106{6}, iS108{2048}rf)
  Outer split: iS108{2048}rf by factor 32 -> iS109{32}rf, iS110{64}rf
 logical domain : (bS105{1}, iS106{6}, iS109{32}rf, iS110{64}rf)
 allocation domain : (bS105{1}, iS106{6}, iS109{32}rf, iS110{64}rf)
 contiguity: n t t t
  Merge: iS106{6} and iS110{64}rf -> iS562{384}
  Merge: iS109{32}rf and iS562{384} -> iS563{12288}
  Merge: bS105{1} and iS563{12288} -> iS564{12288}
  Split: iS564{12288} by factor 128 -> iS565{96}, ithreadIdx.x566{128}
  Split: iS565{96} by factor 1 -> iblockIdx.x567{96}, iUS568{1}
 loop domain : (iblockIdx.x567{96}, iUS568{1}, ithreadIdx.x566{128})

I must have been looking at different IDs.

@tfogal
Copy link
Collaborator

tfogal commented Jan 16, 2025

Reduced test case:

import torch
from nvfuser import FusionDefinition, DataType

def a(fd):
    """Build the reduced fusion that reproduces the BFS-traversal assert.

    ``fd`` is the active FusionDefinition. (The parameter was previously
    named ``FusionDefinition`` — shadowing the imported class and never
    used; the body then relied on the global ``fd`` bound by the ``with``
    block below. It is renamed here so the function uses its own argument,
    which is the same object at the call site.)
    """
    b = fd.define_tensor([1, 32, True] )
    # Curiously these empty/useless define_tensors are critical to the bug!
    fd.define_tensor([1, 6, 2048], [None, True, True], DataType.BFloat16 )
    d = fd.define_tensor([], [], DataType.BFloat16, stride_order=[0])
    fd.define_tensor([], [], DataType.BFloat16, stride_order=[1, 0])
    fd.define_tensor([], [], DataType.BFloat16, stride_order=[1, 0])
    e = fd.define_tensor([], [], DataType.BFloat16, stride_order=[1, 0])
    f = fd.ops.permute(b, [0, 2, 1])
    g = fd.ops.cat([f, f], -1 )
    ab = fd.ops.cast(g, DataType.BFloat16)
    l = fd.ops.broadcast_in_dim(d, [1, 6, 2048], [2])
    n = fd.ops.mul(l, l)
    o = fd.ops.cast(n, DataType.BFloat16)
    p = fd.ops.linear(o, e)
    ah = fd.ops.reshape(p, [1, 6, 8, 64])
    ai = fd.ops.permute(ah, [0, 2, 1, 3])
    al = fd.ops.broadcast_in_dim(g, [1, 1, 6, 64], [0, 2, 3])
    am = fd.ops.broadcast_in_dim(ab, [1, 1, 6, 64], [0, 2, 3] )
    bc = fd.ops.broadcast_in_dim(al, [1, 8, 6, 64], [0, 1, 2, 3])
    bf = fd.ops.mul(ai, bc)
    # RoPE-style rotate-half: split the last dim and swap the halves.
    bg = fd.ops.slice(ai, [0, 0, 0, 0], [1, 8, 6, 32] )
    bh = fd.ops.slice(ai, [0, 0, 0, 32], [1, 8, 6, 64] )
    bl = fd.ops.cat([bh, bg], -1 )
    bp = fd.ops.mul(bl, am)
    bq = fd.ops.add(bf, bp)
    br = fd.ops.cast(bq, DataType.BFloat16)
    bs = fd.ops.broadcast_in_dim(br, [1, 8, 1, 6, 64], [0, 1, 3, 4])
    bt = fd.ops.broadcast_in_dim(bs, [1, 8, 4, 6, 64], [0, 1, 2, 3, 4])
    bu = fd.ops.reshape(bt, [1, 32, 6, 64])
    by = fd.ops.stride_order(bu, [3, 2, 1, 0])
    fd.add_output(by)

with FusionDefinition() as fd:
    a(fd)
inputs = [
    torch.testing.make_tensor(1, 32, 6, dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor(1, 6, 2048, dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor(2048, dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor(2048, 2048, dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor(2, 2048, dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor(512, 2048, dtype=torch.bfloat16, device='cuda:0')]
fd.execute(inputs)

The automated reduction didn't do a great job; I manually reduced this a bit afterwards, and further simplification may still be possible.

@protonu protonu self-assigned this Jan 16, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

No branches or pull requests

4 participants