diff --git a/xla/backends/gpu/autotuner/BUILD b/xla/backends/gpu/autotuner/BUILD
index 160cc595935ce..503ee50d6892c 100644
--- a/xla/backends/gpu/autotuner/BUILD
+++ b/xla/backends/gpu/autotuner/BUILD
@@ -683,10 +683,15 @@ cc_library(
         ":cublaslt",
         ":cudnn",
         ":factory",
+        ":fission_backend",
        ":triton",
         "//xla/backends/autotuner:codegen_backend",
         "//xla/hlo/analysis:symbolic_expr",
+        "//xla/hlo/pass:hlo_pass_pipeline",
         "//xla/service:compiler",
+        "//xla/service/gpu/transforms:dot_algorithm_rewriter",
+        "//xla/service/gpu/transforms:gemm_rewriter",
+        "//xla/stream_executor:device_description",
         "//xla/stream_executor:stream_executor_h",
         "//xla/stream_executor/cuda:cuda_platform_id",
         "//xla/stream_executor/platform:platform_object_registry",
diff --git a/xla/backends/gpu/autotuner/factory.h b/xla/backends/gpu/autotuner/factory.h
index 4d08e5b95182b..9e8fe7c461747 100644
--- a/xla/backends/gpu/autotuner/factory.h
+++ b/xla/backends/gpu/autotuner/factory.h
@@ -36,6 +36,13 @@ struct GetCodegenBackends {
       SymbolicExprContext* symbolic_expr_context)>;
 };
 
+struct GetFissionBackends {
+  using Type = std::function<std::vector<std::unique_ptr<CodegenBackend>>(
+      stream_executor::StreamExecutor*, const DebugOptions*, Compiler*,
+      const Compiler::TargetConfig*,
+      SymbolicExprContext* symbolic_expr_context)>;
+};
+
 }  // namespace gpu
 }  // namespace xla
 
diff --git a/xla/backends/gpu/autotuner/factory_cuda.cc b/xla/backends/gpu/autotuner/factory_cuda.cc
index e01fdbb516bfb..199d6467866dd 100644
--- a/xla/backends/gpu/autotuner/factory_cuda.cc
+++ b/xla/backends/gpu/autotuner/factory_cuda.cc
@@ -17,6 +17,7 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_XLA_BACKENDS_GPU_AUTOTUNER_CUDA_FACTORY_H_
 
 #include <memory>
+#include <utility>
 
 #include "xla/backends/autotuner/codegen_backend.h"
@@ -24,16 +25,40 @@ limitations under the License.
 #include "xla/backends/gpu/autotuner/cublaslt.h"
 #include "xla/backends/gpu/autotuner/cudnn.h"
 #include "xla/backends/gpu/autotuner/factory.h"
+#include "xla/backends/gpu/autotuner/fission_backend.h"
 #include "xla/backends/gpu/autotuner/triton.h"
 #include "xla/hlo/analysis/symbolic_expr.h"
+#include "xla/hlo/pass/hlo_pass_pipeline.h"
 #include "xla/service/compiler.h"
+#include "xla/service/gpu/transforms/dot_algorithm_rewriter.h"
+#include "xla/service/gpu/transforms/gemm_rewriter.h"
 #include "xla/stream_executor/cuda/cuda_platform_id.h"
+#include "xla/stream_executor/device_description.h"
 #include "xla/stream_executor/platform/platform_object_registry.h"
 #include "xla/stream_executor/stream_executor.h"
 
 namespace xla {
 namespace gpu {
 
+namespace {
+
+std::unique_ptr<HloPassPipeline> GetCublasRewriterPipeline(
+    const se::DeviceDescription& device_description) {
+  auto pipeline = std::make_unique<HloPassPipeline>("cublas_rewriter_pipeline");
+  pipeline->AddPass(std::make_unique<DotAlgorithmRewriter>());
+  for (GemmRewriterOptions::DType dtype :
+       {GemmRewriterOptions::DType::kFp8Only,
+        GemmRewriterOptions::DType::kNonFp8Only}) {
+    auto gemm_rewriter = std::make_unique<GemmRewriter>(
+        device_description.gpu_compute_capability(),
+        device_description.runtime_version(), GemmRewriterOptions{dtype});
+    pipeline->AddPass(std::move(gemm_rewriter));
+  }
+  return pipeline;
+}
+
+}  // namespace
+
 std::vector<std::unique_ptr<CodegenBackend>> GetCodegenBackendsForCuda(
     stream_executor::StreamExecutor* stream_executor,
     const DebugOptions* debug_options, Compiler* compiler,
@@ -51,10 +76,29 @@ std::vector<std::unique_ptr<CodegenBackend>> GetCodegenBackendsForCuda(
   return backends;
 }
 
+std::vector<std::unique_ptr<CodegenBackend>> GetFissionBackendsForCuda(
+    stream_executor::StreamExecutor* stream_executor,
+    const DebugOptions* debug_options, Compiler* compiler,
+    const Compiler::TargetConfig* target_config,
+    SymbolicExprContext* symbolic_expr_context) {
+  std::vector<std::unique_ptr<CodegenBackend>> backends;
+  backends.push_back(std::make_unique<FissionBackend>(
+      debug_options, compiler, target_config,
+      std::make_unique<CublasBackend>(stream_executor, debug_options, compiler,
+                                      target_config),
+      GetCublasRewriterPipeline(target_config->device_description),
+      symbolic_expr_context));
+  return backends;
+}
+
 STREAM_EXECUTOR_REGISTER_OBJECT_STATICALLY(GetCodegenBackendsCudaRegistration,
                                            GetCodegenBackends,
                                            se::cuda::kCudaPlatformId,
                                            GetCodegenBackendsForCuda);
 
+STREAM_EXECUTOR_REGISTER_OBJECT_STATICALLY(GetFissionBackendsCudaRegistration,
+                                           GetFissionBackends,
+                                           se::cuda::kCudaPlatformId,
+                                           GetFissionBackendsForCuda);
 }  // namespace gpu
 }  // namespace xla
diff --git a/xla/backends/gpu/autotuner/factory_rocm.cc b/xla/backends/gpu/autotuner/factory_rocm.cc
index 4cceefae4902b..6f0a549799576 100644
--- a/xla/backends/gpu/autotuner/factory_rocm.cc
+++ b/xla/backends/gpu/autotuner/factory_rocm.cc
@@ -45,10 +45,22 @@ std::vector<std::unique_ptr<CodegenBackend>> GetCodegenBackendsForROCm(
   return backends;
 }
 
+std::vector<std::unique_ptr<CodegenBackend>> GetFissionBackendsForROCm(
+    stream_executor::StreamExecutor* stream_executor,
+    const DebugOptions* debug_options, Compiler* compiler,
+    const Compiler::TargetConfig* target_config,
+    SymbolicExprContext* symbolic_expr_context) {
+  return {};
+}
+
 STREAM_EXECUTOR_REGISTER_OBJECT_STATICALLY(GetCodegenBackendsROCmRegistration,
                                            GetCodegenBackends,
                                            se::rocm::kROCmPlatformId,
                                            GetCodegenBackendsForROCm);
 
+STREAM_EXECUTOR_REGISTER_OBJECT_STATICALLY(GetFissionBackendsROCmRegistration,
+                                           GetFissionBackends,
+                                           se::rocm::kROCmPlatformId,
+                                           GetFissionBackendsForROCm);
 }  // namespace gpu
 }  // namespace xla