Skip to content

Commit 450be22

Browse files
hugomanoGoogle-ML-Automation
authored andcommitted
PR #33671: fix(Triton/ROCm): Add missing createTritonGPUAllocateWarpGroups pass to pipeline
Imported from GitHub PR #33671 This PR fixes the Triton compilation pipeline for ROCm by adding the `createTritonGPUAllocateWarpGroups` pass, which was missing. This pass is necessary for the `ExtractThreadDims` function to work correctly during code generation. It adds the `ttg.total-num-warps` attribute to the MLIR module, which is later consumed in `emitter_helpers.cc`. Without this pass, the compilation fails when trying to extract thread dimensions. c/ @khasanovaa @chsigg @AleksaArsic Copybara import of the project: -- 3f9c437 by Hugo Mano <[email protected]>: fix(Triton/ROCm): Add missing createTritonGPUAllocateWarpGroups pass to pipeline -- 4ec8907 by Hugo Mano <[email protected]>: format Merging this change closes #33671 COPYBARA_INTEGRATE_REVIEW=#33671 from hugomano:hugomano/fix-rocm-triton-compilation-pipeline 4ec8907 PiperOrigin-RevId: 829473885
1 parent e07b779 commit 450be22

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,10 @@ static void MakeLLIR(mlir::OpPassManager* pm,
123123
const stream_executor::RocmComputeCapability& rocm_cc,
124124
int num_stages) {
125125
const int custom_lds_size = 0;
126+
// The `createTritonGPUAllocateWarpGroups` pass is not implemented in the
127+
// upstream Triton, but is necessary for `ExtractThreadDims` in emitter
128+
// helpers. It adds the `ttg.total-num-warps` attribute.
129+
pm->addPass(mt::gpu::createTritonGPUAllocateWarpGroups());
126130
pm->addPass(mlir::triton::AMD::createOptimizeLDSUsagePass(
127131
rocm_cc.gfx_version(), custom_lds_size));
128132
pm->addPass(mlir::createSCFToControlFlowPass());

0 commit comments

Comments
 (0)