Compilation timeout on Flux transformer #19539
Debugging tips: https://iree.dev/developers/debugging/compile-time-regressions/
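For context, a minimal sketch of how that guide's flags can be used to localize a hang or slowdown (the file name repro.mlir and the choice of flags are assumptions; the usual target flags are omitted for brevity):

# Pass timing shows where compile time is spent (useful for slowdowns).
iree-compile repro.mlir -o /dev/null --mlir-timing --mlir-timing-display=tree

# For a hang, dumping IR before every pass on a single thread shows the last
# pass that started before the compiler stopped making progress.
iree-compile repro.mlir -o /dev/null --mlir-disable-threading --mlir-print-ir-before-all 2> ir_dump.txt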
Here is a small IR to reproduce the issue:
module @module {
util.global private @__auto.single_blocks.0.linear2.weight = #stream.parameter.named<"model"::"single_blocks.0.linear2.weight"> : tensor<3072x15360xbf16>
util.global private @__auto.single_blocks.0.linear2.bias = #stream.parameter.named<"model"::"single_blocks.0.linear2.bias"> : tensor<3072xbf16>
func.func @forward_bs1(%arg7: !torch.vtensor<[1,1536,3072],bf16>, %arg8: !torch.vtensor<[1,1,3072],bf16>, %arg9: !torch.vtensor<[1,1536,12288],bf16> , %arg10: !torch.vtensor<[1,24,1536,128],f32>, %arg11: !torch.vtensor<[1,24,1536,64,2],f32> , %arg12:!torch.vtensor<[1,1536,21504],bf16> ) -> !torch.vtensor<[1,1024,3072],f32> attributes {torch.assume_strict_symbolic_shapes} {
%__auto.single_blocks.0.linear2.weight = util.global.load @__auto.single_blocks.0.linear2.weight : tensor<3072x15360xbf16>
%46 = torch_c.from_builtin_tensor %__auto.single_blocks.0.linear2.weight : tensor<3072x15360xbf16> -> !torch.vtensor<[3072,15360],bf16>
%__auto.single_blocks.0.linear2.bias = util.global.load @__auto.single_blocks.0.linear2.bias : tensor<3072xbf16>
%47 = torch_c.from_builtin_tensor %__auto.single_blocks.0.linear2.bias : tensor<3072xbf16> -> !torch.vtensor<[3072],bf16>
%int-1_594 = torch.constant.int -1
%int0_595 = torch.constant.int 0
%int9216_596 = torch.constant.int 9216
%int1_597 = torch.constant.int 1
%484 = torch.aten.slice.Tensor %arg12, %int-1_594, %int0_595, %int9216_596, %int1_597 : !torch.vtensor<[1,1536,21504],bf16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1536,9216],bf16>
%int1_602 = torch.constant.int 1
%int1536_603 = torch.constant.int 1536
%int3_604 = torch.constant.int 3
%int24_605 = torch.constant.int 24
%int-1_606 = torch.constant.int -1
%486 = torch.prim.ListConstruct %int1_602, %int1536_603, %int3_604, %int24_605, %int-1_606 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%487 = torch.aten.view %484, %486 : !torch.vtensor<[1,1536,9216],bf16>, !torch.list<int> -> !torch.vtensor<[1,1536,3,24,128],bf16>
%int2_607 = torch.constant.int 2
%int0_608 = torch.constant.int 0
%int3_609 = torch.constant.int 3
%int1_610 = torch.constant.int 1
%int4_611 = torch.constant.int 4
%488 = torch.prim.ListConstruct %int2_607, %int0_608, %int3_609, %int1_610, %int4_611 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%489 = torch.aten.permute %487, %488 : !torch.vtensor<[1,1536,3,24,128],bf16>, !torch.list<int> -> !torch.vtensor<[3,1,24,1536,128],bf16>
%int0_620 = torch.constant.int 0
%int2_621 = torch.constant.int 2
%int3_622 = torch.constant.int 3
%int1_623 = torch.constant.int 1
%492 = torch.aten.slice.Tensor %489, %int0_620, %int2_621, %int3_622, %int1_623 : !torch.vtensor<[3,1,24,1536,128],bf16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,24,1536,128],bf16>
%int0_626 = torch.constant.int 0
%495 = torch.aten.squeeze.dim %492, %int0_626 : !torch.vtensor<[1,1,24,1536,128],bf16>, !torch.int -> !torch.vtensor<[1,24,1536,128],bf16>
%int1_675 = torch.constant.int 1
%int24_676 = torch.constant.int 24
%int1536_677 = torch.constant.int 1536
%int128_678 = torch.constant.int 128
%534 = torch.prim.ListConstruct %int1_675, %int24_676, %int1536_677, %int128_678 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%535 = torch.aten.view %arg11, %534 : !torch.vtensor<[1,24,1536,64,2],f32>, !torch.list<int> -> !torch.vtensor<[1,24,1536,128],f32>
%int15_679 = torch.constant.int 15
%536 = torch.prims.convert_element_type %535, %int15_679 : !torch.vtensor<[1,24,1536,128],f32>, !torch.int -> !torch.vtensor<[1,24,1536,128],bf16>
%int15_684 = torch.constant.int 15
%539 = torch.prims.convert_element_type %arg10, %int15_684 : !torch.vtensor<[1,24,1536,128],f32>, !torch.int -> !torch.vtensor<[1,24,1536,128],bf16>
%float0.000000e00_685 = torch.constant.float 0.000000e+00
%true_686 = torch.constant.bool true
%none_687 = torch.constant.none
%none_688 = torch.constant.none
%540:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%536, %539, %495, %float0.000000e00_685, %true_686, %none_687, %none_688) : (!torch.vtensor<[1,24,1536,128],bf16>, !torch.vtensor<[1,24,1536,128],bf16>, !torch.vtensor<[1,24,1536,128],bf16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,24,1536,128],bf16>, !torch.vtensor<[1,24,1536],f32>)
%int0_689 = torch.constant.int 0
%int2_690 = torch.constant.int 2
%int1_691 = torch.constant.int 1
%int3_692 = torch.constant.int 3
%541 = torch.prim.ListConstruct %int0_689, %int2_690, %int1_691, %int3_692 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%542 = torch.aten.permute %540#0, %541 : !torch.vtensor<[1,24,1536,128],bf16>, !torch.list<int> -> !torch.vtensor<[1,1536,24,128],bf16>
%int1_693 = torch.constant.int 1
%int1536_694 = torch.constant.int 1536
%int-1_695 = torch.constant.int -1
%543 = torch.prim.ListConstruct %int1_693, %int1536_694, %int-1_695 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%544 = torch.aten.view %542, %543 : !torch.vtensor<[1,1536,24,128],bf16>, !torch.list<int> -> !torch.vtensor<[1,1536,3072],bf16>
%str_696 = torch.constant.str "none"
%545 = torch.aten.gelu %arg9, %str_696 : !torch.vtensor<[1,1536,12288],bf16>, !torch.str -> !torch.vtensor<[1,1536,12288],bf16>
%546 = torch.prim.ListConstruct %544, %545 : (!torch.vtensor<[1,1536,3072],bf16>, !torch.vtensor<[1,1536,12288],bf16>) -> !torch.list<vtensor>
%int2_697 = torch.constant.int 2
%547 = torch.aten.cat %546, %int2_697 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,1536,15360],bf16>
%int-2_698 = torch.constant.int -2
%int-1_699 = torch.constant.int -1
%548 = torch.aten.transpose.int %46, %int-2_698, %int-1_699 : !torch.vtensor<[3072,15360],bf16>, !torch.int, !torch.int -> !torch.vtensor<[15360,3072],bf16>
%int1536_700 = torch.constant.int 1536
%int15360_701 = torch.constant.int 15360
%549 = torch.prim.ListConstruct %int1536_700, %int15360_701 : (!torch.int, !torch.int) -> !torch.list<int>
%550 = torch.aten.view %547, %549 : !torch.vtensor<[1,1536,15360],bf16>, !torch.list<int> -> !torch.vtensor<[1536,15360],bf16>
%551 = torch.aten.mm %550, %548 : !torch.vtensor<[1536,15360],bf16>, !torch.vtensor<[15360,3072],bf16> -> !torch.vtensor<[1536,3072],bf16>
%int1_702 = torch.constant.int 1
%int1536_703 = torch.constant.int 1536
%int3072_704 = torch.constant.int 3072
%552 = torch.prim.ListConstruct %int1_702, %int1536_703, %int3072_704 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%553 = torch.aten.view %551, %552 : !torch.vtensor<[1536,3072],bf16>, !torch.list<int> -> !torch.vtensor<[1,1536,3072],bf16>
%int1_705 = torch.constant.int 1
%554 = torch.aten.add.Tensor %553, %47, %int1_705 : !torch.vtensor<[1,1536,3072],bf16>, !torch.vtensor<[3072],bf16>, !torch.int -> !torch.vtensor<[1,1536,3072],bf16>
%555 = torch.aten.mul.Tensor %arg8, %554 : !torch.vtensor<[1,1,3072],bf16>, !torch.vtensor<[1,1536,3072],bf16> -> !torch.vtensor<[1,1536,3072],bf16>
%int1_706 = torch.constant.int 1
%556 = torch.aten.add.Tensor %arg7, %555, %int1_706 : !torch.vtensor<[1,1536,3072],bf16>, !torch.vtensor<[1,1536,3072],bf16>, !torch.int -> !torch.vtensor<[1,1536,3072],bf16>
%int0_707 = torch.constant.int 0
%int0_708 = torch.constant.int 0
%int9223372036854775807_709 = torch.constant.int 9223372036854775807
%int1_710 = torch.constant.int 1
%557 = torch.aten.slice.Tensor %556, %int0_707, %int0_708, %int9223372036854775807_709, %int1_710 : !torch.vtensor<[1,1536,3072],bf16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1536,3072],bf16>
%int1_711 = torch.constant.int 1
%int512_712 = torch.constant.int 512
%int9223372036854775807_713 = torch.constant.int 9223372036854775807
%int1_714 = torch.constant.int 1
%558 = torch.aten.slice.Tensor %557, %int1_711, %int512_712, %int9223372036854775807_713, %int1_714 : !torch.vtensor<[1,1536,3072],bf16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1024,3072],bf16>
%int6_751 = torch.constant.int 6
%573 = torch.prims.convert_element_type %558, %int6_751 : !torch.vtensor<[1,1024,3072],bf16>, !torch.int -> !torch.vtensor<[1,1024,3072],f32>
return %573 : !torch.vtensor<[1,1024,3072],f32>
}
}
Up to build 1211 it was failing with an error, while from build 1212 the compilation hangs.
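For reference, a hedged sketch of feeding the IR above to iree-compile (the file names are made up, and the ROCm/gfx942 target is an assumption; the actual flags were copied from the SDXL UNet regression test, see compile.sh below):

# Hypothetical invocation; substitute the real flags from compile.sh.
# The parameter globals are external (#stream.parameter.named), so no
# parameter archive is needed at compile time.
iree-compile repro.mlir \
  --iree-input-type=torch \
  --iree-hal-target-backends=rocm \
  --iree-hip-target=gfx942 \
  -o repro.vmfb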
@pdhirajkumarprasad, thank you for narrowing it down. Based on your input I figured out that this reproducer, iree-codegen-optimize-tensor-insert-extract-slices-reproducer.zip, uses … I may be wrong, but I think I caught it hanging in …
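A hedged way to check that suspicion, assuming (from the reproducer's name only) that the hang is in the OptimizeTensorInsertExtractSlices pass; the exact pass flag spelling and the five-minute budget are assumptions, not confirmed:

# Run only the suspected pass on the extracted reproducer; GNU timeout exits
# with status 124 if the pass does not terminate in time.
timeout 5m iree-opt extracted_repro.mlir \
  --iree-codegen-optimize-tensor-insert-extract-slices \
  -o /dev/null || echo "pass did not finish within 5 minutes"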
@Groverkss this is failing somewhere deep in the guts of the …
I know about this issue. The fix is here: #19460
Hi @Groverkss, thank you for the fix. I can confirm that #19460 fixes this.
What happened?
I got a Flux transformer with a single layer exported from Sharktank that did not finish compiling in 10 hours.
The compile flags were copied from the SDXL Unet's regression test.
A second attempt also did not finish compiling after 1-2 hours.
Steps to reproduce your issue
compile.sh
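One way to confirm the hang without waiting the full ten hours is to bound the run of compile.sh (the two-hour budget here is arbitrary):

timeout 2h ./compile.sh
echo "exit status: $?"   # 124 means compile.sh was killed by the timeout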
What component(s) does this issue relate to?
Compiler
Version information
83af679
Additional context
My IREE build configuration is in configure-release.zip.