Skip to content

Commit e9bd138

Browse files
[Autotuner] Introduce FissionBackends factory.
- We are not yet using backend factories in the pass; we will switch once all the autotuner passes are ported to the new infra. PiperOrigin-RevId: 826034090
1 parent 9a77a88 commit e9bd138

20 files changed

+590
-23
lines changed

xla/backends/gpu/autotuner/BUILD

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -683,10 +683,15 @@ cc_library(
683683
":cublaslt",
684684
":cudnn",
685685
":factory",
686+
":fission_backend",
686687
":triton",
687688
"//xla/backends/autotuner:codegen_backend",
688689
"//xla/hlo/analysis:symbolic_expr",
690+
"//xla/hlo/pass:hlo_pass_pipeline",
689691
"//xla/service:compiler",
692+
"//xla/service/gpu/transforms:dot_algorithm_rewriter",
693+
"//xla/service/gpu/transforms:gemm_rewriter",
694+
"//xla/stream_executor:device_description",
690695
"//xla/stream_executor:stream_executor_h",
691696
"//xla/stream_executor/cuda:cuda_platform_id",
692697
"//xla/stream_executor/platform:platform_object_registry",
@@ -787,6 +792,60 @@ cc_library(
787792
],
788793
)
789794

795+
cc_library(
796+
name = "fission_backend",
797+
srcs = ["fission_backend.cc"],
798+
hdrs = ["fission_backend.h"],
799+
deps = [
800+
":gpu_codegen_backend",
801+
"//xla/backends/autotuner:codegen_backend",
802+
"//xla/hlo/analysis:symbolic_expr",
803+
"//xla/hlo/ir:hlo",
804+
"//xla/hlo/pass:hlo_pass_pipeline",
805+
"//xla/service:compiler",
806+
"//xla/service:hlo_cost_analysis",
807+
"//xla/service/gpu/transforms:priority_fusion",
808+
"//xla/stream_executor:stream_executor_h",
809+
"//xla/tools:hlo_decomposer_lib",
810+
"//xla/tsl/platform:errors",
811+
"//xla/tsl/platform:statusor",
812+
"@com_google_absl//absl/container:flat_hash_map",
813+
"@com_google_absl//absl/log",
814+
"@com_google_absl//absl/status",
815+
"@com_google_absl//absl/status:statusor",
816+
"@com_google_absl//absl/strings",
817+
],
818+
)
819+
820+
xla_test(
821+
name = "fission_backend_test",
822+
srcs = ["fission_backend_test.cc"],
823+
backends = ["h100"],
824+
tags = ["cuda-only"],
825+
deps = [
826+
":cublas",
827+
":fission_backend",
828+
":gpu_codegen_backend",
829+
"//xla/backends/autotuner:codegen_backend",
830+
"//xla/hlo/analysis:symbolic_expr",
831+
"//xla/hlo/ir:hlo",
832+
"//xla/hlo/pass:hlo_pass_pipeline",
833+
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
834+
"//xla/service:compiler",
835+
"//xla/service:executable",
836+
"//xla/service:platform_util",
837+
"//xla/service/gpu:nvptx_compiler_impl",
838+
"//xla/service/gpu/transforms:dot_algorithm_rewriter",
839+
"//xla/service/gpu/transforms:gemm_rewriter",
840+
"//xla/stream_executor:device_description",
841+
"//xla/stream_executor:stream_executor_h",
842+
"//xla/tsl/platform:statusor",
843+
"@com_google_absl//absl/status:statusor",
844+
"@com_google_googletest//:gtest_main",
845+
"@llvm-project//mlir:IR",
846+
],
847+
)
848+
790849
xla_cc_test(
791850
name = "legacy_cache_test",
792851
srcs = ["legacy_cache_test.cc"],

xla/backends/gpu/autotuner/cublas.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,5 +153,9 @@ absl::Status CublasBackend::ApplyConfig(HloInstruction& instr,
153153
return absl::OkStatus();
154154
}
155155

156+
bool CublasBackend::IsSupported(const HloInstruction& instr) {
157+
return IsLegacyCublasMatmul(instr);
158+
}
159+
156160
} // namespace gpu
157161
} // namespace xla

xla/backends/gpu/autotuner/cublas.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ class CublasBackend : public GpuCodegenBackend {
6060

6161
absl::Status ApplyConfig(HloInstruction& instr,
6262
const BackendConfig& config) override;
63+
64+
private:
65+
bool IsSupported(const HloInstruction& instr) override;
6366
};
6467

6568
} // namespace gpu

xla/backends/gpu/autotuner/cublaslt.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,12 @@ absl::StatusOr<BlasLt::Epilogue> AsBlasLtEpilogue(
7474
}
7575
}
7676

77-
bool IsSupported(const HloInstruction& instr) {
77+
} // namespace
78+
79+
bool CublasLtBackend::IsSupported(const HloInstruction& instr) {
7880
return IsCublasLtMatmul(instr) || IsCublasLtMatmulF8(instr);
7981
}
8082

81-
} // namespace
82-
8383
absl::StatusOr<std::vector<std::unique_ptr<BackendConfig>>>
8484
CublasLtBackend::GetSupportedConfigs(const HloInstruction& instr) {
8585
if (!IsSupported(instr)) {

xla/backends/gpu/autotuner/cublaslt.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ class CublasLtBackend : public GpuCodegenBackend {
5959

6060
absl::Status ApplyConfig(HloInstruction& instr,
6161
const BackendConfig& config) override;
62+
63+
private:
64+
bool IsSupported(const HloInstruction& instr) override;
6265
};
6366

6467
} // namespace gpu

xla/backends/gpu/autotuner/cudnn.cc

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -155,20 +155,6 @@ bool IsSupportedCudnnFusion(const HloInstruction& instr,
155155
return false;
156156
}
157157

158-
bool IsSupportedByCudnn(const HloInstruction& instr,
159-
se::StreamExecutor* stream_executor,
160-
const DebugOptions& debug_options) {
161-
if (instr.opcode() == HloOpcode::kFusion) {
162-
return IsSupportedCudnnFusion(instr, stream_executor, debug_options);
163-
}
164-
165-
if (instr.opcode() == HloOpcode::kCustomCall) {
166-
return IsCustomCallToDnnConvolution(instr);
167-
}
168-
169-
return false;
170-
}
171-
172158
absl::StatusOr<std::vector<CudnnBackendConfig>> GetAlgorithms(
173159
se::dnn::DnnSupport* dnn, se::dnn::ConvolutionKind conv_kind,
174160
se::dnn::DataType input_type, se::dnn::DataType output_type,
@@ -338,6 +324,18 @@ absl::Status ApplyConfigToCudnnCustomCall(HloInstruction& instr,
338324

339325
} // namespace
340326

327+
bool CudnnBackend::IsSupported(const HloInstruction& instr) {
328+
if (instr.opcode() == HloOpcode::kFusion) {
329+
return IsSupportedCudnnFusion(instr, stream_executor(), debug_options());
330+
}
331+
332+
if (instr.opcode() == HloOpcode::kCustomCall) {
333+
return IsCustomCallToDnnConvolution(instr);
334+
}
335+
336+
return false;
337+
}
338+
341339
absl::StatusOr<std::unique_ptr<BackendConfig>> CudnnBackend::GetDefaultConfig(
342340
const HloInstruction& instr) {
343341
if (IsCustomCallToDnnConvolution(instr)) {
@@ -358,7 +356,7 @@ absl::StatusOr<std::unique_ptr<BackendConfig>> CudnnBackend::GetDefaultConfig(
358356

359357
absl::StatusOr<std::vector<std::unique_ptr<BackendConfig>>>
360358
CudnnBackend::GetSupportedConfigs(const HloInstruction& instr) {
361-
if (!IsSupportedByCudnn(instr, stream_executor(), debug_options())) {
359+
if (!IsSupported(instr)) {
362360
return std::vector<std::unique_ptr<BackendConfig>>();
363361
}
364362
if (instr.opcode() == HloOpcode::kFusion) {

xla/backends/gpu/autotuner/cudnn.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ class CudnnBackend : public GpuCodegenBackend {
7272
// apply the configs with non-zero workspace size.
7373
absl::Status ApplyConfig(HloInstruction& instr,
7474
const BackendConfig& config) override;
75+
76+
private:
77+
bool IsSupported(const HloInstruction& instr) override;
7578
};
7679

7780
} // namespace gpu

xla/backends/gpu/autotuner/custom_kernel.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ namespace se = ::stream_executor;
4444

4545
using CustomKernelBackendConfig = AutotuneResult::CustomKernelFusionKey;
4646

47-
namespace {
48-
bool IsSupported(const HloInstruction& instr) {
47+
bool CustomKernelBackend::IsSupported(const HloInstruction& instr) {
4948
if (instr.opcode() != HloOpcode::kFusion) {
5049
LOG(ERROR)
5150
<< "CustomKernelBackend doesn't support non-fusion instructions.";
@@ -61,7 +60,6 @@ bool IsSupported(const HloInstruction& instr) {
6160

6261
return true;
6362
}
64-
} // namespace
6563

6664
absl::StatusOr<std::vector<CustomKernel>> LoadKernels(
6765
const HloInstruction* fusion_instruction,

xla/backends/gpu/autotuner/custom_kernel.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ class CustomKernelBackend : public GpuCodegenBackend {
4848

4949
absl::Status ApplyConfig(HloInstruction& instr,
5050
const BackendConfig& config) override;
51+
52+
private:
53+
bool IsSupported(const HloInstruction& instr) override;
5154
};
5255

5356
} // namespace gpu

xla/backends/gpu/autotuner/factory.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@ struct GetCodegenBackends {
3636
SymbolicExprContext* symbolic_expr_context)>;
3737
};
3838

39+
struct GetFissionBackends {
40+
using Type = std::function<std::vector<std::unique_ptr<CodegenBackend>>(
41+
stream_executor::StreamExecutor*, const DebugOptions*, Compiler*,
42+
const Compiler::TargetConfig*,
43+
SymbolicExprContext* symbolic_expr_context)>;
44+
};
45+
3946
} // namespace gpu
4047
} // namespace xla
4148

0 commit comments

Comments
 (0)