Skip to content

Commit 5feebde

Browse files
[Autotuner] Introduce FissionBackends factory.
- We are not yet using the backend factories in the pass; we will switch once all the autotuner passes are ported to the new infra. PiperOrigin-RevId: 826034090
1 parent d2b22a6 commit 5feebde

File tree

4 files changed

+68
-0
lines changed

4 files changed

+68
-0
lines changed

xla/backends/gpu/autotuner/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -683,10 +683,15 @@ cc_library(
683683
":cublaslt",
684684
":cudnn",
685685
":factory",
686+
":fission_backend",
686687
":triton",
687688
"//xla/backends/autotuner:codegen_backend",
688689
"//xla/hlo/analysis:symbolic_expr",
690+
"//xla/hlo/pass:hlo_pass_pipeline",
689691
"//xla/service:compiler",
692+
"//xla/service/gpu/transforms:dot_algorithm_rewriter",
693+
"//xla/service/gpu/transforms:gemm_rewriter",
694+
"//xla/stream_executor:device_description",
690695
"//xla/stream_executor:stream_executor_h",
691696
"//xla/stream_executor/cuda:cuda_platform_id",
692697
"//xla/stream_executor/platform:platform_object_registry",

xla/backends/gpu/autotuner/factory.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@ struct GetCodegenBackends {
3636
SymbolicExprContext* symbolic_expr_context)>;
3737
};
3838

39+
struct GetFissionBackends {
40+
using Type = std::function<std::vector<std::unique_ptr<CodegenBackend>>(
41+
stream_executor::StreamExecutor*, const DebugOptions*, Compiler*,
42+
const Compiler::TargetConfig*,
43+
SymbolicExprContext* symbolic_expr_context)>;
44+
};
45+
3946
} // namespace gpu
4047
} // namespace xla
4148

xla/backends/gpu/autotuner/factory_cuda.cc

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,48 @@ limitations under the License.
1717
#define TENSORFLOW_COMPILER_XLA_BACKENDS_GPU_AUTOTUNER_CUDA_FACTORY_H_
1818

1919
#include <memory>
20+
#include <utility>
2021
#include <vector>
2122

2223
#include "xla/backends/autotuner/codegen_backend.h"
2324
#include "xla/backends/gpu/autotuner/cublas.h"
2425
#include "xla/backends/gpu/autotuner/cublaslt.h"
2526
#include "xla/backends/gpu/autotuner/cudnn.h"
2627
#include "xla/backends/gpu/autotuner/factory.h"
28+
#include "xla/backends/gpu/autotuner/fission_backend.h"
2729
#include "xla/backends/gpu/autotuner/triton.h"
2830
#include "xla/hlo/analysis/symbolic_expr.h"
31+
#include "xla/hlo/pass/hlo_pass_pipeline.h"
2932
#include "xla/service/compiler.h"
33+
#include "xla/service/gpu/transforms/dot_algorithm_rewriter.h"
34+
#include "xla/service/gpu/transforms/gemm_rewriter.h"
3035
#include "xla/stream_executor/cuda/cuda_platform_id.h"
36+
#include "xla/stream_executor/device_description.h"
3137
#include "xla/stream_executor/platform/platform_object_registry.h"
3238
#include "xla/stream_executor/stream_executor.h"
3339

3440
namespace xla {
3541
namespace gpu {
3642

43+
namespace {
44+
45+
std::unique_ptr<HloPassPipeline> GetCublasRewriterPipeline(
46+
const se::DeviceDescription& device_description) {
47+
auto pipeline = std::make_unique<HloPassPipeline>("cublas_rewriter_pipeline");
48+
pipeline->AddPass(std::make_unique<DotAlgorithmRewriter>());
49+
for (GemmRewriterOptions::DType dtype :
50+
{GemmRewriterOptions::DType::kFp8Only,
51+
GemmRewriterOptions::DType::kNonFp8Only}) {
52+
auto gemm_rewriter = std::make_unique<GemmRewriter>(
53+
device_description.gpu_compute_capability(),
54+
device_description.runtime_version(), GemmRewriterOptions{dtype});
55+
pipeline->AddPass(std::move(gemm_rewriter));
56+
}
57+
return pipeline;
58+
}
59+
60+
} // namespace
61+
3762
std::vector<std::unique_ptr<CodegenBackend>> GetCodegenBackendsForCuda(
3863
stream_executor::StreamExecutor* stream_executor,
3964
const DebugOptions* debug_options, Compiler* compiler,
@@ -51,10 +76,29 @@ std::vector<std::unique_ptr<CodegenBackend>> GetCodegenBackendsForCuda(
5176
return backends;
5277
}
5378

79+
// Creates the fission backends for CUDA.
//
// The returned vector holds a single FissionBackend. It wraps a CublasBackend
// built from the same executor/options/compiler/target config, and is given
// the cuBLAS rewriter pipeline derived from
// `target_config->device_description`.
//
// `target_config` is dereferenced unconditionally, so it must not be null.
std::vector<std::unique_ptr<CodegenBackend>> GetFissionBackendsForCuda(
    stream_executor::StreamExecutor* stream_executor,
    const DebugOptions* debug_options, Compiler* compiler,
    const Compiler::TargetConfig* target_config,
    SymbolicExprContext* symbolic_expr_context) {
  auto cublas_backend = std::make_unique<CublasBackend>(
      stream_executor, debug_options, compiler, target_config);
  auto rewriter_pipeline =
      GetCublasRewriterPipeline(target_config->device_description);

  std::vector<std::unique_ptr<CodegenBackend>> fission_backends;
  fission_backends.push_back(std::make_unique<FissionBackend>(
      debug_options, compiler, target_config, std::move(cublas_backend),
      std::move(rewriter_pipeline), symbolic_expr_context));
  return fission_backends;
}
93+
5494
// Static registrations of the CUDA factories. Each invocation passes, in
// order: a registration name, the factory key struct (GetCodegenBackends or
// GetFissionBackends), the CUDA platform id, and the factory function to
// associate with that key and platform.
STREAM_EXECUTOR_REGISTER_OBJECT_STATICALLY(GetCodegenBackendsCudaRegistration,
5595
GetCodegenBackends,
5696
se::cuda::kCudaPlatformId,
5797
GetCodegenBackendsForCuda);
98+
STREAM_EXECUTOR_REGISTER_OBJECT_STATICALLY(GetFissionBackendsCudaRegistration,
99+
GetFissionBackends,
100+
se::cuda::kCudaPlatformId,
101+
GetFissionBackendsForCuda);
58102

59103
} // namespace gpu
60104
} // namespace xla

xla/backends/gpu/autotuner/factory_rocm.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,22 @@ std::vector<std::unique_ptr<CodegenBackend>> GetCodegenBackendsForROCm(
4545
return backends;
4646
}
4747

48+
// Fission-backend factory for ROCm. It creates no backends: the result is
// always an empty vector and none of the arguments are used.
std::vector<std::unique_ptr<CodegenBackend>> GetFissionBackendsForROCm(
    stream_executor::StreamExecutor* stream_executor,
    const DebugOptions* debug_options, Compiler* compiler,
    const Compiler::TargetConfig* target_config,
    SymbolicExprContext* symbolic_expr_context) {
  std::vector<std::unique_ptr<CodegenBackend>> no_backends;
  return no_backends;
}
55+
4856
// Static registrations of the ROCm factories. Each invocation passes, in
// order: a registration name, the factory key struct (GetCodegenBackends or
// GetFissionBackends), the ROCm platform id, and the factory function to
// associate with that key and platform.
STREAM_EXECUTOR_REGISTER_OBJECT_STATICALLY(GetCodegenBackendsROCmRegistration,
4957
GetCodegenBackends,
5058
se::rocm::kROCmPlatformId,
5159
GetCodegenBackendsForROCm);
60+
STREAM_EXECUTOR_REGISTER_OBJECT_STATICALLY(GetFissionBackendsROCmRegistration,
61+
GetFissionBackends,
62+
se::rocm::kROCmPlatformId,
63+
GetFissionBackendsForROCm);
5264

5365
} // namespace gpu
5466
} // namespace xla

0 commit comments

Comments
 (0)