@@ -17,23 +17,48 @@ limitations under the License.
1717#define TENSORFLOW_COMPILER_XLA_BACKENDS_GPU_AUTOTUNER_CUDA_FACTORY_H_
1818
1919#include < memory>
20+ #include < utility>
2021#include < vector>
2122
2223#include " xla/backends/autotuner/codegen_backend.h"
2324#include " xla/backends/gpu/autotuner/cublas.h"
2425#include " xla/backends/gpu/autotuner/cublaslt.h"
2526#include " xla/backends/gpu/autotuner/cudnn.h"
2627#include " xla/backends/gpu/autotuner/factory.h"
28+ #include " xla/backends/gpu/autotuner/fission_backend.h"
2729#include " xla/backends/gpu/autotuner/triton.h"
2830#include " xla/hlo/analysis/symbolic_expr.h"
31+ #include " xla/hlo/pass/hlo_pass_pipeline.h"
2932#include " xla/service/compiler.h"
33+ #include " xla/service/gpu/transforms/dot_algorithm_rewriter.h"
34+ #include " xla/service/gpu/transforms/gemm_rewriter.h"
3035#include " xla/stream_executor/cuda/cuda_platform_id.h"
36+ #include " xla/stream_executor/device_description.h"
3137#include " xla/stream_executor/platform/platform_object_registry.h"
3238#include " xla/stream_executor/stream_executor.h"
3339
3440namespace xla {
3541namespace gpu {
3642
43+ namespace {
44+
45+ std::unique_ptr<HloPassPipeline> GetCublasRewriterPipeline (
46+ const se::DeviceDescription& device_description) {
47+ auto pipeline = std::make_unique<HloPassPipeline>(" cublas_rewriter_pipeline" );
48+ pipeline->AddPass (std::make_unique<DotAlgorithmRewriter>());
49+ for (GemmRewriterOptions::DType dtype :
50+ {GemmRewriterOptions::DType::kFp8Only ,
51+ GemmRewriterOptions::DType::kNonFp8Only }) {
52+ auto gemm_rewriter = std::make_unique<GemmRewriter>(
53+ device_description.gpu_compute_capability (),
54+ device_description.runtime_version (), GemmRewriterOptions{dtype});
55+ pipeline->AddPass (std::move (gemm_rewriter));
56+ }
57+ return pipeline;
58+ }
59+
60+ } // namespace
61+
3762std::vector<std::unique_ptr<CodegenBackend>> GetCodegenBackendsForCuda (
3863 stream_executor::StreamExecutor* stream_executor,
3964 const DebugOptions* debug_options, Compiler* compiler,
@@ -51,10 +76,29 @@ std::vector<std::unique_ptr<CodegenBackend>> GetCodegenBackendsForCuda(
5176 return backends;
5277}
5378
79+ std::vector<std::unique_ptr<CodegenBackend>> GetFissionBackendsForCuda (
80+ stream_executor::StreamExecutor* stream_executor,
81+ const DebugOptions* debug_options, Compiler* compiler,
82+ const Compiler::TargetConfig* target_config,
83+ SymbolicExprContext* symbolic_expr_context) {
84+ std::vector<std::unique_ptr<CodegenBackend>> backends;
85+ backends.push_back (std::make_unique<FissionBackend>(
86+ debug_options, compiler, target_config,
87+ std::make_unique<CublasBackend>(stream_executor, debug_options, compiler,
88+ target_config),
89+ GetCublasRewriterPipeline (target_config->device_description ),
90+ symbolic_expr_context));
91+ return backends;
92+ }
93+
5494STREAM_EXECUTOR_REGISTER_OBJECT_STATICALLY (GetCodegenBackendsCudaRegistration,
5595 GetCodegenBackends,
5696 se::cuda::kCudaPlatformId ,
5797 GetCodegenBackendsForCuda);
98+ STREAM_EXECUTOR_REGISTER_OBJECT_STATICALLY (GetFissionBackendsCudaRegistration,
99+ GetFissionBackends,
100+ se::cuda::kCudaPlatformId ,
101+ GetFissionBackendsForCuda);
58102
59103} // namespace gpu
60104} // namespace xla
0 commit comments