Compilation timeout on Flux transformer #19539

Closed
sogartar opened this issue Dec 20, 2024 · 6 comments
Assignees: sogartar
Labels: bug 🐞 (Something isn't working), regression (Marks regression of feature, compatibility or performance)

Comments

sogartar (Contributor) commented Dec 20, 2024

What happened?

I have a single-layer Flux transformer exported from Sharktank that did not finish compiling within 10 hours. The compile flags were copied from the SDXL UNet regression test.

A second attempt also failed to finish after 1-2 hours.

Steps to reproduce your issue

  1. Download flux-transformer-compilation-timeout.zip, extract it, and run compile.sh (a sketch of such a script is shown below).
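
For context, here is a sketch of what such a script might contain. This is an assumption, not the actual contents of the zip: the flags were reportedly copied from the SDXL UNet regression test, and the ROCm/gfx942 target details are taken from the error message quoted later in this issue. The input file name is hypothetical.

# Hypothetical sketch of compile.sh -- the real flags are inside the zip.
iree-compile \
  --iree-hal-target-backends=rocm \
  --iree-hip-target=gfx942 \
  flux_transformer.mlir \
  -o flux_transformer.vmfb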

What component(s) does this issue relate to?

Compiler

Version information

83af679

Additional context

My IREE build configuration is in configure-release.zip.

sogartar added the bug 🐞 (Something isn't working) label Dec 20, 2024
sogartar self-assigned this Dec 20, 2024

pdhirajkumarprasad commented

Here is a small IR to reproduce the issue:

module @module {
  util.global private @__auto.single_blocks.0.linear2.weight = #stream.parameter.named<"model"::"single_blocks.0.linear2.weight"> : tensor<3072x15360xbf16>
  util.global private @__auto.single_blocks.0.linear2.bias = #stream.parameter.named<"model"::"single_blocks.0.linear2.bias"> : tensor<3072xbf16>
  func.func @forward_bs1(%arg7: !torch.vtensor<[1,1536,3072],bf16>, %arg8: !torch.vtensor<[1,1,3072],bf16>, %arg9: !torch.vtensor<[1,1536,12288],bf16> , %arg10: !torch.vtensor<[1,24,1536,128],f32>, %arg11: !torch.vtensor<[1,24,1536,64,2],f32>  , %arg12:!torch.vtensor<[1,1536,21504],bf16>    ) -> !torch.vtensor<[1,1024,3072],f32> attributes {torch.assume_strict_symbolic_shapes} {
    %__auto.single_blocks.0.linear2.weight = util.global.load @__auto.single_blocks.0.linear2.weight : tensor<3072x15360xbf16>
    %46 = torch_c.from_builtin_tensor %__auto.single_blocks.0.linear2.weight : tensor<3072x15360xbf16> -> !torch.vtensor<[3072,15360],bf16>
    %__auto.single_blocks.0.linear2.bias = util.global.load @__auto.single_blocks.0.linear2.bias : tensor<3072xbf16>
    %47 = torch_c.from_builtin_tensor %__auto.single_blocks.0.linear2.bias : tensor<3072xbf16> -> !torch.vtensor<[3072],bf16>
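    // Take the first 9216 (= 3 * 24 heads * 128) channels of the fused
    // projection, reshape to stack Q/K/V per head, and keep stack index 2 (V).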
    %int-1_594 = torch.constant.int -1
    %int0_595 = torch.constant.int 0
    %int9216_596 = torch.constant.int 9216
    %int1_597 = torch.constant.int 1
    %484 = torch.aten.slice.Tensor %arg12, %int-1_594, %int0_595, %int9216_596, %int1_597 : !torch.vtensor<[1,1536,21504],bf16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1536,9216],bf16>
    %int1_602 = torch.constant.int 1
    %int1536_603 = torch.constant.int 1536
    %int3_604 = torch.constant.int 3
    %int24_605 = torch.constant.int 24
    %int-1_606 = torch.constant.int -1
    %486 = torch.prim.ListConstruct %int1_602, %int1536_603, %int3_604, %int24_605, %int-1_606 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %487 = torch.aten.view %484, %486 : !torch.vtensor<[1,1536,9216],bf16>, !torch.list<int> -> !torch.vtensor<[1,1536,3,24,128],bf16>
    %int2_607 = torch.constant.int 2
    %int0_608 = torch.constant.int 0
    %int3_609 = torch.constant.int 3
    %int1_610 = torch.constant.int 1
    %int4_611 = torch.constant.int 4
    %488 = torch.prim.ListConstruct %int2_607, %int0_608, %int3_609, %int1_610, %int4_611 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %489 = torch.aten.permute %487, %488 : !torch.vtensor<[1,1536,3,24,128],bf16>, !torch.list<int> -> !torch.vtensor<[3,1,24,1536,128],bf16>
    %int0_620 = torch.constant.int 0
    %int2_621 = torch.constant.int 2
    %int3_622 = torch.constant.int 3
    %int1_623 = torch.constant.int 1
    %492 = torch.aten.slice.Tensor %489, %int0_620, %int2_621, %int3_622, %int1_623 : !torch.vtensor<[3,1,24,1536,128],bf16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,24,1536,128],bf16>
    %int0_626 = torch.constant.int 0
    %495 = torch.aten.squeeze.dim %492, %int0_626 : !torch.vtensor<[1,1,24,1536,128],bf16>, !torch.int -> !torch.vtensor<[1,24,1536,128],bf16>
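    // Flatten %arg11 ([1,24,1536,64,2] f32) to [1,24,1536,128] and cast both
    // it and %arg10 to bf16; these become Q and K for the attention below.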
    %int1_675 = torch.constant.int 1
    %int24_676 = torch.constant.int 24
    %int1536_677 = torch.constant.int 1536
    %int128_678 = torch.constant.int 128
    %534 = torch.prim.ListConstruct %int1_675, %int24_676, %int1536_677, %int128_678 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %535 = torch.aten.view %arg11, %534 : !torch.vtensor<[1,24,1536,64,2],f32>, !torch.list<int> -> !torch.vtensor<[1,24,1536,128],f32>
    %int15_679 = torch.constant.int 15
    %536 = torch.prims.convert_element_type %535, %int15_679 : !torch.vtensor<[1,24,1536,128],f32>, !torch.int -> !torch.vtensor<[1,24,1536,128],bf16>
    %int15_684 = torch.constant.int 15
    %539 = torch.prims.convert_element_type %arg10, %int15_684 : !torch.vtensor<[1,24,1536,128],f32>, !torch.int -> !torch.vtensor<[1,24,1536,128],bf16>
    %float0.000000e00_685 = torch.constant.float 0.000000e+00
    %true_686 = torch.constant.bool true
    %none_687 = torch.constant.none
    %none_688 = torch.constant.none
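    // Causal flash attention: Q = %536, K = %539, V = %495 (dropout 0.0).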
    %540:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%536, %539, %495, %float0.000000e00_685, %true_686, %none_687, %none_688) : (!torch.vtensor<[1,24,1536,128],bf16>, !torch.vtensor<[1,24,1536,128],bf16>, !torch.vtensor<[1,24,1536,128],bf16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,24,1536,128],bf16>, !torch.vtensor<[1,24,1536],f32>) 
    %int0_689 = torch.constant.int 0
    %int2_690 = torch.constant.int 2
    %int1_691 = torch.constant.int 1
    %int3_692 = torch.constant.int 3
    %541 = torch.prim.ListConstruct %int0_689, %int2_690, %int1_691, %int3_692 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %542 = torch.aten.permute %540#0, %541 : !torch.vtensor<[1,24,1536,128],bf16>, !torch.list<int> -> !torch.vtensor<[1,1536,24,128],bf16>
    %int1_693 = torch.constant.int 1
    %int1536_694 = torch.constant.int 1536
    %int-1_695 = torch.constant.int -1
    %543 = torch.prim.ListConstruct %int1_693, %int1536_694, %int-1_695 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %544 = torch.aten.view %542, %543 : !torch.vtensor<[1,1536,24,128],bf16>, !torch.list<int> -> !torch.vtensor<[1,1536,3072],bf16>
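    // Concatenate the flattened attention output with the GELU-activated MLP
    // branch, then apply linear2 (matmul with the transposed weight plus bias).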
    %str_696 = torch.constant.str "none"
    %545 = torch.aten.gelu %arg9, %str_696 : !torch.vtensor<[1,1536,12288],bf16>, !torch.str -> !torch.vtensor<[1,1536,12288],bf16>
    %546 = torch.prim.ListConstruct %544, %545 : (!torch.vtensor<[1,1536,3072],bf16>, !torch.vtensor<[1,1536,12288],bf16>) -> !torch.list<vtensor>
    %int2_697 = torch.constant.int 2
    %547 = torch.aten.cat %546, %int2_697 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,1536,15360],bf16>
    %int-2_698 = torch.constant.int -2
    %int-1_699 = torch.constant.int -1
    %548 = torch.aten.transpose.int %46, %int-2_698, %int-1_699 : !torch.vtensor<[3072,15360],bf16>, !torch.int, !torch.int -> !torch.vtensor<[15360,3072],bf16>
    %int1536_700 = torch.constant.int 1536
    %int15360_701 = torch.constant.int 15360
    %549 = torch.prim.ListConstruct %int1536_700, %int15360_701 : (!torch.int, !torch.int) -> !torch.list<int>
    %550 = torch.aten.view %547, %549 : !torch.vtensor<[1,1536,15360],bf16>, !torch.list<int> -> !torch.vtensor<[1536,15360],bf16>
    %551 = torch.aten.mm %550, %548 : !torch.vtensor<[1536,15360],bf16>, !torch.vtensor<[15360,3072],bf16> -> !torch.vtensor<[1536,3072],bf16>
    %int1_702 = torch.constant.int 1
    %int1536_703 = torch.constant.int 1536
    %int3072_704 = torch.constant.int 3072
    %552 = torch.prim.ListConstruct %int1_702, %int1536_703, %int3072_704 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %553 = torch.aten.view %551, %552 : !torch.vtensor<[1536,3072],bf16>, !torch.list<int> -> !torch.vtensor<[1,1536,3072],bf16>
    %int1_705 = torch.constant.int 1
    %554 = torch.aten.add.Tensor %553, %47, %int1_705 : !torch.vtensor<[1,1536,3072],bf16>, !torch.vtensor<[3072],bf16>, !torch.int -> !torch.vtensor<[1,1536,3072],bf16>
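    // Scale by %arg8, add the %arg7 residual, then drop the first 512 tokens
    // along the sequence dimension and return the result in f32.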
    %555 = torch.aten.mul.Tensor %arg8, %554 : !torch.vtensor<[1,1,3072],bf16>, !torch.vtensor<[1,1536,3072],bf16> -> !torch.vtensor<[1,1536,3072],bf16>
    %int1_706 = torch.constant.int 1
    %556 = torch.aten.add.Tensor %arg7, %555, %int1_706 : !torch.vtensor<[1,1536,3072],bf16>, !torch.vtensor<[1,1536,3072],bf16>, !torch.int -> !torch.vtensor<[1,1536,3072],bf16>
    %int0_707 = torch.constant.int 0
    %int0_708 = torch.constant.int 0
    %int9223372036854775807_709 = torch.constant.int 9223372036854775807
    %int1_710 = torch.constant.int 1
    %557 = torch.aten.slice.Tensor %556, %int0_707, %int0_708, %int9223372036854775807_709, %int1_710 : !torch.vtensor<[1,1536,3072],bf16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1536,3072],bf16>
    %int1_711 = torch.constant.int 1
    %int512_712 = torch.constant.int 512
    %int9223372036854775807_713 = torch.constant.int 9223372036854775807
    %int1_714 = torch.constant.int 1
    %558 = torch.aten.slice.Tensor %557, %int1_711, %int512_712, %int9223372036854775807_713, %int1_714 : !torch.vtensor<[1,1536,3072],bf16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1024,3072],bf16>
    %int6_751 = torch.constant.int 6
    %573 = torch.prims.convert_element_type %558, %int6_751 : !torch.vtensor<[1,1024,3072],bf16>, !torch.int -> !torch.vtensor<[1,1024,3072],f32>
    return %573 : !torch.vtensor<[1,1024,3072],f32>
  }
}
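
To compile just this snippet, something like the following should work (a hedged sketch: the file name is hypothetical, and the target flags mirror the gfx942 HIP target in the error message below):

# Hypothetical invocation; repro.mlir holds the module above.
iree-compile \
  --iree-hal-target-backends=rocm \
  --iree-hip-target=gfx942 \
  repro.mlir -o repro.vmfb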

Up to the 1211 build, it was giving this error:

error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none", waves_per_eu = 2 : i64}>

From the 1212 build onward, the compilation hangs instead.

pdhirajkumarprasad added the regression (Marks regression of feature, compatibility or performance) label Dec 20, 2024
sogartar (Contributor, Author) commented Dec 20, 2024

@pdhirajkumarprasad, thank you for narrowing it down. Based on your input I figured out that iree-codegen-optimize-tensor-insert-extract-slices is the culprit pass.

This reproducer, iree-codegen-optimize-tensor-insert-extract-slices-reproducer.zip, uses iree-opt to run just that pass, and it hangs again. More precisely, it hangs in hoistLoopInvariantSubsetAtIterArg.
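
For reference, a sketch of how one might invoke only that pass with iree-opt; the exact pipeline nesting used in the zipped reproducer may differ, and the input file name is hypothetical:

# Hypothetical invocation; repro_dispatch.mlir stands in for the IR in the zip.
iree-opt \
  --pass-pipeline='builtin.module(func.func(iree-codegen-optimize-tensor-insert-extract-slices))' \
  repro_dispatch.mlir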

I may be wrong, but I think I also caught it hanging in iree-rocdl-annotate-kernel-for-translation a couple of times, though I can't reproduce that anymore.

MaheshRavishankar (Contributor) commented

@Groverkss this is failing somewhere deep in the guts of the ValueBoundsOpInterface. Could you PTAL?

Groverkss (Contributor) commented

I know about this issue. The fix is here: #19460.

sogartar (Contributor, Author) commented Jan 2, 2025

Hi @Groverkss, thank you for the fix. I can confirm that #19460 fixes this.
