From b3db957ae86ce402ccba6fccf7a3dde4f655ea77 Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Tue, 4 Nov 2025 07:36:27 -0800 Subject: [PATCH] Add proto serialization for the CustomCallThunk PiperOrigin-RevId: 827967159 --- xla/backends/gpu/runtime/BUILD | 19 +++ xla/backends/gpu/runtime/custom_call_thunk.cc | 81 +++++++++ xla/backends/gpu/runtime/custom_call_thunk.h | 10 ++ .../gpu/runtime/custom_call_thunk_test.cc | 124 ++++++++++++++ .../gpu/runtime/dynamic_slice_thunk_test.cc | 4 +- xla/backends/gpu/runtime/thunk.proto | 14 ++ .../runtime/thunk_proto_deserialization.cc | 44 +++-- .../gpu/runtime/thunk_proto_deserialization.h | 6 +- .../thunk_proto_deserialization_test.cc | 159 +++++++++++++++--- xla/service/gpu/gpu_executable.cc | 6 +- xla/service/gpu/gpu_executable.h | 3 +- xla/service/gpu/gpu_executable_test.cc | 2 +- 12 files changed, 428 insertions(+), 44 deletions(-) diff --git a/xla/backends/gpu/runtime/BUILD b/xla/backends/gpu/runtime/BUILD index 1d0c7ee31dd39..320c7c8b99aee 100644 --- a/xla/backends/gpu/runtime/BUILD +++ b/xla/backends/gpu/runtime/BUILD @@ -725,9 +725,11 @@ cc_library( "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", @@ -744,19 +746,27 @@ xla_test( ":shaped_slice", ":thunk", "//xla:executable_run_options", + "//xla:shape_util", "//xla/ffi", + "//xla/ffi:attribute_map", "//xla/ffi:ffi_api", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", "//xla/service:custom_call_status_public_headers", "//xla/service:custom_call_target_registry", "//xla/service:executable", + "//xla/service:hlo_module_config", "//xla/service:platform_util", 
"//xla/service/gpu:buffer_allocations", "//xla/service/gpu:resource_requests", + "//xla/stream_executor:device_memory", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_memory_allocator", "//xla/tsl/platform:statusor", + "//xla/tsl/util/proto:parse_text_proto", + "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -2494,6 +2504,7 @@ tf_proto_library( ":shaped_slice_proto", "//xla:xla_data_proto", "//xla/core/host_offloading:host_offloading_executable_proto", + "//xla/ffi:attribute_map_proto", "//xla/service:buffer_assignment_proto", "//xla/service:hlo_proto", "//xla/service/gpu:backend_configs", @@ -2534,6 +2545,7 @@ cc_library( ":convolution_thunk", ":copy_thunk", ":cudnn_thunk", + ":custom_call_thunk", ":dynamic_slice_thunk", ":fft_thunk", ":gemm_thunk", @@ -2549,8 +2561,10 @@ cc_library( ":triangular_solve_thunk", ":wait_for_streams_thunk", ":while_thunk", + "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", "//xla/tsl/platform:statusor", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -2572,7 +2586,12 @@ xla_cc_test( ":thunk_proto_cc", ":thunk_proto_deserialization", ":while_thunk", + "//xla:shape_util", + "//xla/ffi", + "//xla/ffi:ffi_api", + "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", + "//xla/service:hlo_module_config", "//xla/tsl/platform:statusor", "//xla/tsl/util/proto:parse_text_proto", "//xla/tsl/util/proto:proto_matchers", diff --git a/xla/backends/gpu/runtime/custom_call_thunk.cc b/xla/backends/gpu/runtime/custom_call_thunk.cc index ebf0426e32c00..4d06e2418559c 100644 --- a/xla/backends/gpu/runtime/custom_call_thunk.cc +++ b/xla/backends/gpu/runtime/custom_call_thunk.cc @@ -28,10 +28,12 @@ limitations under the License. 
#include "absl/algorithm/container.h" #include "absl/base/nullability.h" #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" #include "absl/log/log.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -540,5 +542,84 @@ absl::Status CustomCallThunk::ExecuteOnStream(const ExecuteParams& params) { return ExecuteCustomCall(params); } +absl::StatusOr CustomCallThunk::ToProto() const { + if (!api_version_.has_value()) { + return absl::FailedPreconditionError( + "CustomCallThunk was created from a non-registered target and cannot " + "be serialized to a proto"); + } + + ThunkProto proto; + *proto.mutable_thunk_info() = thunk_info().ToProto(); + proto.mutable_custom_call_thunk()->set_target_name(target_name_); + proto.mutable_custom_call_thunk()->set_opaque(opaque_); + proto.mutable_custom_call_thunk()->set_api_version(api_version_.value()); + if (called_computation_ != nullptr) { + proto.mutable_custom_call_thunk()->set_called_computation( + called_computation_->name()); + } + + for (const NullableShapedSlice& operand : operands_) { + TF_ASSIGN_OR_RETURN(*proto.mutable_custom_call_thunk()->add_operands(), + operand.ToProto()); + } + + for (const NullableShapedSlice& result : results_) { + TF_ASSIGN_OR_RETURN(*proto.mutable_custom_call_thunk()->add_results(), + result.ToProto()); + } + + if (attributes_.has_value()) { + *proto.mutable_custom_call_thunk()->mutable_attributes() = + attributes_->ToProto(); + } + return proto; +} + +absl::StatusOr> CustomCallThunk::FromProto( + ThunkInfo thunk_info, const CustomCallThunkProto& proto, + absl::Span buffer_allocations, + const HloModule* absl_nullable hlo_module, + absl::string_view platform_name) { + if (hlo_module == nullptr && proto.has_called_computation()) { + return absl::InvalidArgumentError( + "HloModule is required 
to deserialize a CustomCallThunk with a " + "called computation"); + } + + std::vector operands, results; + for (const auto& operand_proto : proto.operands()) { + TF_ASSIGN_OR_RETURN( + NullableShapedSlice operand, + NullableShapedSlice::FromProto(operand_proto, buffer_allocations)); + operands.push_back(std::move(operand)); + } + for (const auto& result_proto : proto.results()) { + TF_ASSIGN_OR_RETURN( + NullableShapedSlice result, + NullableShapedSlice::FromProto(result_proto, buffer_allocations)); + results.push_back(std::move(result)); + } + TF_ASSIGN_OR_RETURN(ffi::AttributesMap attributes, + ffi::AttributesMap::FromProto(proto.attributes())); + + HloComputation* called_computation = nullptr; + if (proto.has_called_computation()) { + CHECK(hlo_module != nullptr); // This check is needed for static analysis. + called_computation = + hlo_module->GetComputationWithName(proto.called_computation()); + if (called_computation == nullptr) { + return absl::InvalidArgumentError(absl::StrCat( + "HloComputation '", proto.called_computation(), + "' not found in the HloModule with name '", hlo_module->name(), "'")); + } + } + + return CustomCallThunk::Create(std::move(thunk_info), proto.target_name(), + std::move(operands), std::move(results), + std::move(attributes), called_computation, + platform_name); +} + } // namespace gpu } // namespace xla diff --git a/xla/backends/gpu/runtime/custom_call_thunk.h b/xla/backends/gpu/runtime/custom_call_thunk.h index b5c15f946665f..7324bd1666738 100644 --- a/xla/backends/gpu/runtime/custom_call_thunk.h +++ b/xla/backends/gpu/runtime/custom_call_thunk.h @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/backends/gpu/runtime/shaped_slice.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/executable_run_options.h" @@ -40,6 +41,7 @@ limitations under the License. 
#include "xla/ffi/ffi_api.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/runtime/object_pool.h" +#include "xla/service/buffer_assignment.h" #include "xla/service/custom_call_status.h" #include "xla/service/gpu/buffer_allocations.h" #include "xla/stream_executor/device_memory_allocator.h" @@ -146,6 +148,14 @@ class CustomCallThunk : public Thunk { absl::string_view opaque() const { return opaque_; } + absl::StatusOr ToProto() const override; + + static absl::StatusOr> FromProto( + ThunkInfo thunk_info, const CustomCallThunkProto& proto, + absl::Span buffer_allocations, + const HloModule* absl_nullable hlo_module, + absl::string_view platform_name); + private: CustomCallThunk(ThunkInfo thunk_info, std::string target_name, std::vector operands, diff --git a/xla/backends/gpu/runtime/custom_call_thunk_test.cc b/xla/backends/gpu/runtime/custom_call_thunk_test.cc index 00a8e15a53211..29199825bb96a 100644 --- a/xla/backends/gpu/runtime/custom_call_thunk_test.cc +++ b/xla/backends/gpu/runtime/custom_call_thunk_test.cc @@ -15,34 +15,47 @@ limitations under the License. 
#include "xla/backends/gpu/runtime/custom_call_thunk.h" +#include #include +#include #include #include #include #include #include +#include "absl/base/casts.h" #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/backends/gpu/runtime/shaped_slice.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/executable_run_options.h" +#include "xla/ffi/attribute_map.h" #include "xla/ffi/ffi.h" #include "xla/ffi/ffi_api.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/buffer_assignment.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/resource_requests.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/platform_util.h" #include "xla/service/service_executable_run_options.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tsl/platform/statusor.h" +#include "xla/tsl/util/proto/parse_text_proto.h" namespace xla::gpu { namespace { @@ -321,5 +334,116 @@ TEST(CustomCallThunkTest, CustomCallWithOwnedHandlersWithoutExecute) { StatusIs(absl::StatusCode::kInvalidArgument)); } +// A simple callback function that expects specific arguments. 
+absl::Status VerifyCallbackArguments(int my_attribute, + ffi::AnyBuffer my_operand, + ffi::Result my_result, + const HloComputation* called_computation) { + EXPECT_EQ(my_attribute, 42); + EXPECT_EQ(my_operand.element_type(), xla::PrimitiveType::U8); + EXPECT_EQ(my_operand.device_memory().opaque(), + absl::bit_cast(static_cast(0xDEADBEEF))); + EXPECT_EQ(my_result->element_type(), xla::PrimitiveType::U16); + EXPECT_EQ(my_result->device_memory().opaque(), + absl::bit_cast(static_cast(0xABCDEF))); + EXPECT_EQ(called_computation->name(), "test_computation"); + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kVerifyCallbackArguments, VerifyCallbackArguments, + ffi::Ffi::Bind() + .Attr("my_attribute") + .Arg() + .Ret() + .Ctx(), + {ffi::Traits::kCmdBufferCompatible}); + +constexpr absl::string_view kVerifyCallbackArgumentsCustomCallName = + "__xla_test$$verify_callback_arguments"; +constexpr absl::string_view kTestPlatformName = "TEST_PLATFORM"; + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), + kVerifyCallbackArgumentsCustomCallName, + kTestPlatformName, kVerifyCallbackArguments); + +TEST(CustomCallThunkTest, ProtoConversion) { + TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + executor->CreateStream()); + + HloModuleConfig config; + HloModule hlo_module("test_module", config); + HloComputation::Builder builder("test_computation"); + // This instruction is pretty arbitrary, we just need a non-empty computation. 
+ builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(U32, {42}), "parameter")); + hlo_module.AddEntryComputation(builder.Build()); + + BufferAllocation alloc0{0, 1024, 0}; + BufferAllocation alloc1{1, 1024, 0}; + ShapedSlice operand_slice{BufferAllocation::Slice{&alloc0, 0, 1024}, + ShapeUtil::MakeShape(U8, {1024})}; + ShapedSlice result_slice{BufferAllocation::Slice{&alloc1, 0, 1024}, + ShapeUtil::MakeShape(U16, {512})}; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr original_thunk, + CustomCallThunk::Create( + Thunk::ThunkInfo(), + /*target_name=*/std::string(kVerifyCallbackArgumentsCustomCallName), + /*operands=*/{operand_slice}, + /*results=*/{result_slice}, /*attributes=*/{{"my_attribute", 42}}, + hlo_module.entry_computation(), + /*platform_name=*/kTestPlatformName)); + TF_ASSERT_OK_AND_ASSIGN(ThunkProto proto, original_thunk->ToProto()); + ASSERT_TRUE(proto.has_custom_call_thunk()); + original_thunk.reset(); + + std::array allocations = {alloc0, alloc1}; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr new_thunk, + CustomCallThunk::FromProto(Thunk::ThunkInfo(), proto.custom_call_thunk(), + allocations, &hlo_module, kTestPlatformName)); + + se::StreamExecutorMemoryAllocator allocator(executor); + BufferAllocations device_allocations( + {stream_executor::DeviceMemoryBase( + absl::bit_cast(static_cast(0xDEADBEEF)), 1024), + stream_executor::DeviceMemoryBase( + absl::bit_cast(static_cast(0xABCDEF)), 1024)}, + 0, &allocator); + Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( + ServiceExecutableRunOptions(), device_allocations, + /*stream=*/stream.get(), + /*command_buffer_trace_stream=*/stream.get(), + /*collective_params=*/nullptr, + /*collective_cliques=*/nullptr); + EXPECT_THAT(new_thunk->ExecuteOnStream(params), IsOk()); +} + +TEST(CustomCallThunkTest, DeserializationFailsWithMissingHloModule) { + CustomCallThunkProto proto = + tsl::proto_testing::ParseTextProtoOrDie( + R"pb( + target_name: 
"__xla_test$$verify_callback_arguments" + api_version: API_VERSION_TYPED_FFI + called_computation: "called_computation" + )pb"); + + HloModuleConfig config; + HloModule hlo_module("test_module", config); + HloComputation::Builder builder("not_called_computation"); + // This instruction is pretty arbitrary, we just need a non-empty computation. + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(U32, {42}), "parameter")); + hlo_module.AddEntryComputation(builder.Build()); + + EXPECT_THAT(CustomCallThunk::FromProto(Thunk::ThunkInfo(), proto, + /*buffer_allocations=*/{}, &hlo_module, + /*platform_name=*/kTestPlatformName), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + } // namespace } // namespace xla::gpu diff --git a/xla/backends/gpu/runtime/dynamic_slice_thunk_test.cc b/xla/backends/gpu/runtime/dynamic_slice_thunk_test.cc index 255a32302b0d9..9ee4f097b72ac 100644 --- a/xla/backends/gpu/runtime/dynamic_slice_thunk_test.cc +++ b/xla/backends/gpu/runtime/dynamic_slice_thunk_test.cc @@ -115,7 +115,9 @@ void CheckProtoRoundTrip(const DynamicSliceThunk& thunk, [](const ThunkProto& thunk_proto, absl::Span fake_allocations_span) -> absl::StatusOr> { - return DeserializeThunkProto(thunk_proto, fake_allocations_span); + return DeserializeThunkProto(thunk_proto, fake_allocations_span, + /*hlo_module*/ nullptr, + /*platform_name=*/"TEST_PLATFORM"); }; TF_ASSERT_OK_AND_ASSIGN( diff --git a/xla/backends/gpu/runtime/thunk.proto b/xla/backends/gpu/runtime/thunk.proto index 689f3292a80af..d78f66fd11ca1 100644 --- a/xla/backends/gpu/runtime/thunk.proto +++ b/xla/backends/gpu/runtime/thunk.proto @@ -21,6 +21,7 @@ import "xla/backends/gpu/runtime/convolution_filter_thunk.proto"; import "xla/backends/gpu/runtime/dynamic_slice_thunk.proto"; import "xla/backends/gpu/runtime/shaped_slice.proto"; import "xla/core/host_offloading/host_offloading_executable.proto"; +import "xla/ffi/attribute_map.proto"; import "xla/service/buffer_assignment.proto"; 
import "xla/service/gpu/gpu_conv_runner.proto"; import "xla/service/gpu/gpu_norm_runner.proto"; @@ -248,6 +249,18 @@ message FftThunkProto { xla.ShapeProto output_shape = 6; } +message CustomCallThunkProto { + string target_name = 1; + repeated NullableShapedSliceProto operands = 2; + repeated NullableShapedSliceProto results = 3; + string opaque = 4; + CustomCallApiVersion api_version = 5; + xla.ffi.AttributesMapProto attributes = 6; + // The name of the called computation. It needs to match the HloComputation in + // the HloModule that is used to deserialize the thunk. + optional string called_computation = 7; +} + message ThunkProto { ThunkInfoProto thunk_info = 1; @@ -280,6 +293,7 @@ message ThunkProto { FftThunkProto fft_thunk = 27; CholeskyThunkProto cholesky_thunk = 28; Memset32BitValueThunkProto memset32bit_value_thunk = 29; + CustomCallThunkProto custom_call_thunk = 30; } } diff --git a/xla/backends/gpu/runtime/thunk_proto_deserialization.cc b/xla/backends/gpu/runtime/thunk_proto_deserialization.cc index b47370eb95523..61425b15e1621 100644 --- a/xla/backends/gpu/runtime/thunk_proto_deserialization.cc +++ b/xla/backends/gpu/runtime/thunk_proto_deserialization.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/base/nullability.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -32,6 +33,7 @@ limitations under the License. #include "xla/backends/gpu/runtime/convolution_thunk.h" #include "xla/backends/gpu/runtime/copy_thunk.h" #include "xla/backends/gpu/runtime/cudnn_thunk.h" +#include "xla/backends/gpu/runtime/custom_call_thunk.h" #include "xla/backends/gpu/runtime/dynamic_slice_thunk.h" #include "xla/backends/gpu/runtime/fft_thunk.h" #include "xla/backends/gpu/runtime/gemm_thunk.h" @@ -47,6 +49,7 @@ limitations under the License. 
#include "xla/backends/gpu/runtime/triangular_solve_thunk.h" #include "xla/backends/gpu/runtime/wait_for_streams_thunk.h" #include "xla/backends/gpu/runtime/while_thunk.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/buffer_assignment.h" #include "xla/tsl/platform/statusor.h" @@ -72,15 +75,19 @@ static std::optional GetStoredThunkTypeName( absl::StatusOr> DeserializeThunkProto( const ThunkProto& thunk_proto, - absl::Span buffer_allocations) { + absl::Span buffer_allocations, + const HloModule* absl_nullable hlo_module, + absl::string_view platform_name) { TF_ASSIGN_OR_RETURN(Thunk::ThunkInfo thunk_info, Thunk::ThunkInfo::FromProto(thunk_proto.thunk_info())); + auto deserializer = [&buffer_allocations, &hlo_module, + &platform_name](const ThunkProto& thunk_proto) { + return DeserializeThunkProto(thunk_proto, buffer_allocations, hlo_module, + platform_name); + }; switch (thunk_proto.impl_case()) { case ThunkProto::kSequentialThunk: { - auto deserializer = [&buffer_allocations](const ThunkProto& thunk_proto) { - return DeserializeThunkProto(thunk_proto, buffer_allocations); - }; return SequentialThunk::FromProto( std::move(thunk_info), thunk_proto.sequential_thunk(), deserializer); } @@ -100,18 +107,13 @@ absl::StatusOr> DeserializeThunkProto( std::move(thunk_info), thunk_proto.device_to_device_copy_thunk(), buffer_allocations); case ThunkProto::kWhileThunk: - return WhileThunk::FromProto( - std::move(thunk_info), thunk_proto.while_thunk(), buffer_allocations, - [&buffer_allocations](const ThunkProto& thunk_proto) { - return DeserializeThunkProto(thunk_proto, buffer_allocations); - }); + return WhileThunk::FromProto(std::move(thunk_info), + thunk_proto.while_thunk(), + buffer_allocations, deserializer); case ThunkProto::kConditionalThunk: - return ConditionalThunk::FromProto( - std::move(thunk_info), thunk_proto.conditional_thunk(), - buffer_allocations, - [&buffer_allocations](const ThunkProto& thunk_proto) { - return DeserializeThunkProto(thunk_proto, 
buffer_allocations); - }); + return ConditionalThunk::FromProto(std::move(thunk_info), + thunk_proto.conditional_thunk(), + buffer_allocations, deserializer); case ThunkProto::kGemmThunk: return GemmThunk::FromProto(std::move(thunk_info), thunk_proto.gemm_thunk(), buffer_allocations); @@ -170,14 +172,20 @@ absl::StatusOr> DeserializeThunkProto( buffer_allocations); case ThunkProto::kDynamicSliceThunk: { auto deserializer = - [](const ThunkProto& thunk_proto, - absl::Span custom_allocations) { - return DeserializeThunkProto(thunk_proto, custom_allocations); + [hlo_module, platform_name]( + const ThunkProto& thunk_proto, + absl::Span custom_allocations) { + return DeserializeThunkProto(thunk_proto, custom_allocations, + hlo_module, platform_name); }; return DynamicSliceThunk::FromProto(std::move(thunk_info), thunk_proto.dynamic_slice_thunk(), buffer_allocations, deserializer); } + case ThunkProto::kCustomCallThunk: + return CustomCallThunk::FromProto( + std::move(thunk_info), thunk_proto.custom_call_thunk(), + buffer_allocations, hlo_module, platform_name); default: std::optional unsupported_thunk_type = GetStoredThunkTypeName(thunk_proto); diff --git a/xla/backends/gpu/runtime/thunk_proto_deserialization.h b/xla/backends/gpu/runtime/thunk_proto_deserialization.h index a68b6ed9fa60f..729c09847cd1e 100644 --- a/xla/backends/gpu/runtime/thunk_proto_deserialization.h +++ b/xla/backends/gpu/runtime/thunk_proto_deserialization.h @@ -18,10 +18,13 @@ limitations under the License. #include +#include "absl/base/nullability.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/backends/gpu/runtime/thunk.pb.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/buffer_assignment.h" namespace xla::gpu { @@ -29,7 +32,8 @@ namespace xla::gpu { // Deserializes the given `thunk_proto` into a Thunk. 
absl::StatusOr> DeserializeThunkProto( const ThunkProto& thunk_proto, - absl::Span buffer_allocations); + absl::Span buffer_allocations, + const HloModule* absl_nullable hlo_module, absl::string_view platform_name); } // namespace xla::gpu diff --git a/xla/backends/gpu/runtime/thunk_proto_deserialization_test.cc b/xla/backends/gpu/runtime/thunk_proto_deserialization_test.cc index 62bc3255dcd48..bbb19ad3af9bf 100644 --- a/xla/backends/gpu/runtime/thunk_proto_deserialization_test.cc +++ b/xla/backends/gpu/runtime/thunk_proto_deserialization_test.cc @@ -30,7 +30,14 @@ limitations under the License. #include "xla/backends/gpu/runtime/thunk.h" #include "xla/backends/gpu/runtime/thunk.pb.h" #include "xla/backends/gpu/runtime/while_thunk.h" +#include "xla/ffi/ffi.h" +#include "xla/ffi/ffi_api.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/buffer_assignment.h" +#include "xla/service/hlo_module_config.h" +#include "xla/shape_util.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/util/proto/parse_text_proto.h" #include "xla/tsl/util/proto/proto_matchers.h" @@ -46,6 +53,8 @@ using ::tsl::proto_testing::EqualsProto; using ::tsl::proto_testing::ParseTextProtoOrDie; using Kind = Thunk::Kind; +constexpr absl::string_view kTestPlatformName = "TEST_PLATFORM"; + TEST(ThunkProtoDeserializationTest, SequentialThunkChain) { constexpr ExecutionStreamId kExecutionStreamId{123}; constexpr absl::string_view kProfileAnnotation = "profile_annotation"; @@ -63,8 +72,10 @@ TEST(ThunkProtoDeserializationTest, SequentialThunkChain) { SequentialThunk outer_thunk(thunk_info, std::move(thunk_sequence)); TF_ASSERT_OK_AND_ASSIGN(ThunkProto proto, outer_thunk.ToProto()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr new_thunk, - DeserializeThunkProto(proto, {})); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr new_thunk, + DeserializeThunkProto(proto, /*buffer_allocations=*/{}, + /*hlo_module=*/nullptr, 
kTestPlatformName)); EXPECT_THAT(new_thunk.get(), WhenDynamicCastTo(Property( @@ -91,8 +102,10 @@ TEST(ThunkProtoDeserializationTest, CopyThunk) { BufferAllocation(/*index=*/0, /*size=*/1024, /*color=*/0), BufferAllocation(/*index=*/1, /*size=*/1024, /*color=*/0)}; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr thunk, - DeserializeThunkProto(proto, buffer_allocations)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr thunk, + DeserializeThunkProto(proto, buffer_allocations, /*hlo_module=*/nullptr, + kTestPlatformName)); auto* copy_thunk = dynamic_cast(thunk.get()); ASSERT_NE(copy_thunk, nullptr); // Check the cast succeeded TF_ASSERT_OK_AND_ASSIGN(ThunkProto round_trip_proto, copy_thunk->ToProto()); @@ -123,8 +136,10 @@ TEST(ThunkProtoDeserializationTest, DeviceToHostCopyThunk) { BufferAllocation(/*index=*/0, /*size=*/1024, /*color=*/0), BufferAllocation(/*index=*/1, /*size=*/1024, /*color=*/0)}; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr thunk, - DeserializeThunkProto(proto, buffer_allocations)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr thunk, + DeserializeThunkProto(proto, buffer_allocations, /*hlo_module=*/nullptr, + kTestPlatformName)); auto* copy_thunk = dynamic_cast(thunk.get()); ASSERT_NE(copy_thunk, nullptr); // Check the cast succeeded TF_ASSERT_OK_AND_ASSIGN(ThunkProto round_trip_proto, copy_thunk->ToProto()); @@ -155,8 +170,10 @@ TEST(ThunkProtoDeserializationTest, HostToDeviceCopyThunk) { BufferAllocation(/*index=*/0, /*size=*/1024, /*color=*/0), BufferAllocation(/*index=*/1, /*size=*/1024, /*color=*/0)}; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr thunk, - DeserializeThunkProto(proto, buffer_allocations)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr thunk, + DeserializeThunkProto(proto, buffer_allocations, /*hlo_module=*/nullptr, + kTestPlatformName)); auto* copy_thunk = dynamic_cast(thunk.get()); ASSERT_NE(copy_thunk, nullptr); // Check the cast succeeded TF_ASSERT_OK_AND_ASSIGN(ThunkProto round_trip_proto, copy_thunk->ToProto()); @@ -187,8 +204,10 @@ 
TEST(ThunkProtoDeserializationTest, DeviceToDeviceCopyThunk) { BufferAllocation(/*index=*/0, /*size=*/1024, /*color=*/0), BufferAllocation(/*index=*/1, /*size=*/1024, /*color=*/0)}; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr thunk, - DeserializeThunkProto(proto, buffer_allocations)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr thunk, + DeserializeThunkProto(proto, buffer_allocations, /*hlo_module=*/nullptr, + kTestPlatformName)); auto* copy_thunk = dynamic_cast(thunk.get()); ASSERT_NE(copy_thunk, nullptr); // Check the cast succeeded TF_ASSERT_OK_AND_ASSIGN(ThunkProto round_trip_proto, copy_thunk->ToProto()); @@ -264,8 +283,10 @@ TEST(ThunkProtoDeserializationTest, WhileThunk) { BufferAllocation(/*index=*/4, /*size=*/1024, /*color=*/0), BufferAllocation(/*index=*/5, /*size=*/1024, /*color=*/0)}; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr athunk, - DeserializeThunkProto(proto, buffer_allocations)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr athunk, + DeserializeThunkProto(proto, buffer_allocations, /*hlo_module=*/nullptr, + kTestPlatformName)); auto* thunk = dynamic_cast(athunk.get()); ASSERT_NE(thunk, nullptr); TF_ASSERT_OK_AND_ASSIGN(ThunkProto round_trip_proto, thunk->ToProto()); @@ -353,8 +374,10 @@ TEST(ThunkProtoDeserializationTest, ConditionalThunk) { BufferAllocation(/*index=*/4, /*size=*/1024, /*color=*/0), BufferAllocation(/*index=*/5, /*size=*/1024, /*color=*/0)}; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr athunk, - DeserializeThunkProto(proto, buffer_allocations)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr athunk, + DeserializeThunkProto(proto, buffer_allocations, /*hlo_module=*/nullptr, + kTestPlatformName)); auto* thunk = dynamic_cast(athunk.get()); ASSERT_NE(thunk, nullptr); TF_ASSERT_OK_AND_ASSIGN(ThunkProto round_trip_proto, thunk->ToProto()); @@ -370,7 +393,8 @@ TEST(ThunkProtoDeserializationTest, WaitForStreamsThunk) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr thunk, - DeserializeThunkProto(proto, /*buffer_allocations=*/{})); + 
DeserializeThunkProto(proto, /*buffer_allocations=*/{}, + /*hlo_module=*/nullptr, kTestPlatformName)); TF_ASSERT_OK_AND_ASSIGN(ThunkProto round_trip_proto, thunk->ToProto()); EXPECT_THAT(round_trip_proto, EqualsProto(proto)); @@ -391,8 +415,10 @@ TEST(ThunkProtoDeserializationTest, CudnnThunk) { BufferAllocation(/*index=*/1, /*size=*/1024, /*color=*/0), }; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr thunk, - DeserializeThunkProto(proto, buffer_allocations)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr thunk, + DeserializeThunkProto(proto, buffer_allocations, /*hlo_module=*/nullptr, + kTestPlatformName)); TF_ASSERT_OK_AND_ASSIGN(ThunkProto round_trip_proto, thunk->ToProto()); EXPECT_THAT(round_trip_proto, EqualsProto(proto)); @@ -457,8 +483,100 @@ TEST(ThunkProtoDeserializationTest, CublasLtMatmulThunk) { BufferAllocation(/*index=*/5, /*size=*/161600, /*color=*/0), }; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr thunk, - DeserializeThunkProto(proto, allocations)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr thunk, + DeserializeThunkProto(proto, allocations, /*hlo_module=*/nullptr, + kTestPlatformName)); + + TF_ASSERT_OK_AND_ASSIGN(ThunkProto round_trip_proto, thunk->ToProto()); + EXPECT_THAT(round_trip_proto, EqualsProto(proto)); +} + +XLA_FFI_DEFINE_HANDLER(kSimpleCustomCall, []() { return absl::OkStatus(); }, + ffi::Ffi::Bind(), {ffi::Traits::kCmdBufferCompatible}); + +constexpr absl::string_view kSimpleCustomCallName = + "__xla_test$$simple_custom_call"; + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), kSimpleCustomCallName, + "TEST_PLATFORM", kSimpleCustomCall); + +TEST(ThunkProtoDeserializationTest, CustomCallThunk) { + ThunkProto proto = ParseTextProtoOrDie( + R"pb( + thunk_info { execution_stream_id: 7 } + custom_call_thunk { + target_name: "__xla_test$$simple_custom_call" + operands { + shaped_slice { + slice { buffer_allocation_index: 0 } + shape { + dimensions: 42 + element_type: S32 + is_dynamic_dimension: false + } + } + } + operands { + shaped_slice { + 
slice { buffer_allocation_index: 1 } + shape { + dimensions: 42 + element_type: S32 + is_dynamic_dimension: false + } + } + } + results { + shaped_slice { + slice { buffer_allocation_index: 2 } + shape { + dimensions: 42 + element_type: S32 + is_dynamic_dimension: false + } + } + } + results { + shaped_slice { + slice { buffer_allocation_index: 3 } + shape { + dimensions: 42 + element_type: S32 + is_dynamic_dimension: false + } + } + } + api_version: API_VERSION_TYPED_FFI + attributes { + attrs { + key: "my_attribute" + value { scalar { i32: 42 } } + } + } + called_computation: "called_computation" + } + )pb"); + std::vector buffer_allocations = { + BufferAllocation(/*index=*/0, /*size=*/1024, /*color=*/0), + BufferAllocation(/*index=*/1, /*size=*/1024, /*color=*/0), + BufferAllocation(/*index=*/2, /*size=*/1024, /*color=*/0), + BufferAllocation(/*index=*/3, /*size=*/1024, /*color=*/0), + }; + + HloModuleConfig config; + HloModule hlo_module("test_module", config); + HloComputation::Builder builder("called_computation"); + // This instruction is pretty arbitrary, we just need a non-empty computation. 
+ builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(U32, {42}), "parameter")); + hlo_module.AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr thunk, + DeserializeThunkProto(proto, buffer_allocations, &hlo_module, + kTestPlatformName)); + TF_ASSERT_OK_AND_ASSIGN(ThunkProto round_trip_proto, thunk->ToProto()); EXPECT_THAT(round_trip_proto, EqualsProto(proto)); } @@ -469,7 +587,8 @@ TEST(ThunkProtoDeserializationTest, EmptyThunkImplReturnsAnError) { thunk_info { execution_stream_id: 7 } )pb"); - EXPECT_THAT(DeserializeThunkProto(proto, /*buffer_allocations=*/{}), + EXPECT_THAT(DeserializeThunkProto(proto, /*buffer_allocations=*/{}, + /*hlo_module=*/nullptr, kTestPlatformName), absl_testing::StatusIs(absl::StatusCode::kInvalidArgument)); } diff --git a/xla/service/gpu/gpu_executable.cc b/xla/service/gpu/gpu_executable.cc index d747602057959..6af5bd4d648ce 100644 --- a/xla/service/gpu/gpu_executable.cc +++ b/xla/service/gpu/gpu_executable.cc @@ -1221,7 +1221,8 @@ absl::StatusOr GpuExecutable::ToProto() const { absl::StatusOr> GpuExecutable::FromProto( const GpuExecutableProto& proto, - const se::DeviceDescription& device_description) { + const se::DeviceDescription& device_description, + absl::string_view platform_name) { Params params; params.enable_debug_info_manager = false; params.asm_text = proto.asm_text(); @@ -1263,7 +1264,8 @@ absl::StatusOr> GpuExecutable::FromProto( TF_ASSIGN_OR_RETURN( std::unique_ptr thunk, - DeserializeThunkProto(proto.thunk(), params.mlir_allocations.value())); + DeserializeThunkProto(proto.thunk(), params.mlir_allocations.value(), + params.debug_module.get(), platform_name)); if (dynamic_cast(thunk.get()) == nullptr) { return absl::InvalidArgumentError( diff --git a/xla/service/gpu/gpu_executable.h b/xla/service/gpu/gpu_executable.h index 60333aa1811a5..52fa7b468bedf 100644 --- a/xla/service/gpu/gpu_executable.h +++ b/xla/service/gpu/gpu_executable.h @@ -218,7 +218,8 @@ 
class GpuExecutable : public Executable { static absl::StatusOr> FromProto( const GpuExecutableProto&, - const se::DeviceDescription& device_description); + const se::DeviceDescription& device_description, + absl::string_view platform); absl::StatusOr ToProto() const; diff --git a/xla/service/gpu/gpu_executable_test.cc b/xla/service/gpu/gpu_executable_test.cc index 881065fe91985..9f5c12b29e2fe 100644 --- a/xla/service/gpu/gpu_executable_test.cc +++ b/xla/service/gpu/gpu_executable_test.cc @@ -474,7 +474,7 @@ TEST(GpuExecutableTest, ProtoConversion) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr reconstructed_executable, - GpuExecutable::FromProto(proto, device_description)); + GpuExecutable::FromProto(proto, device_description, "TEST_PLATFORM")); EXPECT_THAT(reconstructed_executable->text(), "test_asm_text"); EXPECT_THAT(reconstructed_executable->binary(), ElementsAre(1, 2, 3)); EXPECT_THAT(