19 changes: 19 additions & 0 deletions xla/backends/gpu/runtime/BUILD
@@ -725,9 +725,11 @@ cc_library(
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
@@ -744,19 +746,27 @@ xla_test(
":shaped_slice",
":thunk",
"//xla:executable_run_options",
"//xla:shape_util",
"//xla/ffi",
"//xla/ffi:attribute_map",
"//xla/ffi:ffi_api",
"//xla/hlo/ir:hlo",
"//xla/service:buffer_assignment",
"//xla/service:custom_call_status_public_headers",
"//xla/service:custom_call_target_registry",
"//xla/service:executable",
"//xla/service:hlo_module_config",
"//xla/service:platform_util",
"//xla/service/gpu:buffer_allocations",
"//xla/service/gpu:resource_requests",
"//xla/stream_executor:device_memory",
"//xla/stream_executor:platform",
"//xla/stream_executor:platform_manager",
"//xla/stream_executor:stream",
"//xla/stream_executor:stream_executor_memory_allocator",
"//xla/tsl/platform:statusor",
"//xla/tsl/util/proto:parse_text_proto",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
@@ -2494,6 +2504,7 @@ tf_proto_library(
":shaped_slice_proto",
"//xla:xla_data_proto",
"//xla/core/host_offloading:host_offloading_executable_proto",
"//xla/ffi:attribute_map_proto",
"//xla/service:buffer_assignment_proto",
"//xla/service:hlo_proto",
"//xla/service/gpu:backend_configs",
@@ -2534,6 +2545,7 @@ cc_library(
":convolution_thunk",
":copy_thunk",
":cudnn_thunk",
":custom_call_thunk",
":dynamic_slice_thunk",
":fft_thunk",
":gemm_thunk",
@@ -2549,8 +2561,10 @@ cc_library(
":triangular_solve_thunk",
":wait_for_streams_thunk",
":while_thunk",
"//xla/hlo/ir:hlo",
"//xla/service:buffer_assignment",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
@@ -2572,7 +2586,12 @@ xla_cc_test(
":thunk_proto_cc",
":thunk_proto_deserialization",
":while_thunk",
"//xla:shape_util",
"//xla/ffi",
"//xla/ffi:ffi_api",
"//xla/hlo/ir:hlo",
"//xla/service:buffer_assignment",
"//xla/service:hlo_module_config",
"//xla/tsl/platform:statusor",
"//xla/tsl/util/proto:parse_text_proto",
"//xla/tsl/util/proto:proto_matchers",
81 changes: 81 additions & 0 deletions xla/backends/gpu/runtime/custom_call_thunk.cc
@@ -28,10 +28,12 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/base/nullability.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
@@ -540,5 +542,84 @@ absl::Status CustomCallThunk::ExecuteOnStream(const ExecuteParams& params) {
return ExecuteCustomCall(params);
}

absl::StatusOr<ThunkProto> CustomCallThunk::ToProto() const {
if (!api_version_.has_value()) {
return absl::FailedPreconditionError(
"CustomCallThunk was created from a non-registered target and cannot "
"be serialized to a proto");
}

ThunkProto proto;
*proto.mutable_thunk_info() = thunk_info().ToProto();
proto.mutable_custom_call_thunk()->set_target_name(target_name_);
proto.mutable_custom_call_thunk()->set_opaque(opaque_);
proto.mutable_custom_call_thunk()->set_api_version(api_version_.value());
if (called_computation_ != nullptr) {
proto.mutable_custom_call_thunk()->set_called_computation(
called_computation_->name());
}

for (const NullableShapedSlice& operand : operands_) {
TF_ASSIGN_OR_RETURN(*proto.mutable_custom_call_thunk()->add_operands(),
operand.ToProto());
}

for (const NullableShapedSlice& result : results_) {
TF_ASSIGN_OR_RETURN(*proto.mutable_custom_call_thunk()->add_results(),
result.ToProto());
}

if (attributes_.has_value()) {
*proto.mutable_custom_call_thunk()->mutable_attributes() =
attributes_->ToProto();
}
return proto;
}

absl::StatusOr<std::unique_ptr<CustomCallThunk>> CustomCallThunk::FromProto(
ThunkInfo thunk_info, const CustomCallThunkProto& proto,
absl::Span<const BufferAllocation> buffer_allocations,
const HloModule* absl_nullable hlo_module,
absl::string_view platform_name) {
if (hlo_module == nullptr && proto.has_called_computation()) {
return absl::InvalidArgumentError(
"HloModule is required to deserialize a CustomCallThunk with a "
"called computation");
}

std::vector<NullableShapedSlice> operands, results;
for (const auto& operand_proto : proto.operands()) {
TF_ASSIGN_OR_RETURN(
NullableShapedSlice operand,
NullableShapedSlice::FromProto(operand_proto, buffer_allocations));
operands.push_back(std::move(operand));
}
for (const auto& result_proto : proto.results()) {
TF_ASSIGN_OR_RETURN(
NullableShapedSlice result,
NullableShapedSlice::FromProto(result_proto, buffer_allocations));
results.push_back(std::move(result));
}
TF_ASSIGN_OR_RETURN(ffi::AttributesMap attributes,
ffi::AttributesMap::FromProto(proto.attributes()));

HloComputation* called_computation = nullptr;
if (proto.has_called_computation()) {
CHECK(hlo_module != nullptr); // This check is needed for static analysis.
called_computation =
hlo_module->GetComputationWithName(proto.called_computation());
if (called_computation == nullptr) {
return absl::InvalidArgumentError(absl::StrCat(
"HloComputation '", proto.called_computation(),
"' not found in the HloModule with name '", hlo_module->name(), "'"));
}
}

return CustomCallThunk::Create(std::move(thunk_info), proto.target_name(),
std::move(operands), std::move(results),
std::move(attributes), called_computation,
platform_name);
}

} // namespace gpu
} // namespace xla
10 changes: 10 additions & 0 deletions xla/backends/gpu/runtime/custom_call_thunk.h
@@ -28,6 +28,7 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/backends/gpu/runtime/shaped_slice.h"
#include "xla/backends/gpu/runtime/thunk.h"
#include "xla/executable_run_options.h"
@@ -40,6 +41,7 @@ limitations under the License.
#include "xla/ffi/ffi_api.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/runtime/object_pool.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/custom_call_status.h"
#include "xla/service/gpu/buffer_allocations.h"
#include "xla/stream_executor/device_memory_allocator.h"
@@ -146,6 +148,14 @@ class CustomCallThunk : public Thunk {

absl::string_view opaque() const { return opaque_; }

absl::StatusOr<ThunkProto> ToProto() const override;

static absl::StatusOr<std::unique_ptr<CustomCallThunk>> FromProto(
ThunkInfo thunk_info, const CustomCallThunkProto& proto,
absl::Span<const BufferAllocation> buffer_allocations,
const HloModule* absl_nullable hlo_module,
absl::string_view platform_name);

private:
CustomCallThunk(ThunkInfo thunk_info, std::string target_name,
std::vector<NullableShapedSlice> operands,
124 changes: 124 additions & 0 deletions xla/backends/gpu/runtime/custom_call_thunk_test.cc
@@ -15,34 +15,47 @@ limitations under the License.

#include "xla/backends/gpu/runtime/custom_call_thunk.h"

#include <array>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/base/casts.h"
#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/backends/gpu/runtime/shaped_slice.h"
#include "xla/backends/gpu/runtime/thunk.h"
#include "xla/executable_run_options.h"
#include "xla/ffi/attribute_map.h"
#include "xla/ffi/ffi.h"
#include "xla/ffi/ffi_api.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/custom_call_status.h"
#include "xla/service/custom_call_target_registry.h"
#include "xla/service/gpu/buffer_allocations.h"
#include "xla/service/gpu/resource_requests.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/platform_util.h"
#include "xla/service/service_executable_run_options.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform_manager.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor_memory_allocator.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/util/proto/parse_text_proto.h"

namespace xla::gpu {
namespace {
@@ -321,5 +334,116 @@ TEST(CustomCallThunkTest, CustomCallWithOwnedHandlersWithoutExecute) {
StatusIs(absl::StatusCode::kInvalidArgument));
}

// A simple callback function that expects specific arguments.
absl::Status VerifyCallbackArguments(int my_attribute,
ffi::AnyBuffer my_operand,
ffi::Result<ffi::AnyBuffer> my_result,
const HloComputation* called_computation) {
EXPECT_EQ(my_attribute, 42);
EXPECT_EQ(my_operand.element_type(), xla::PrimitiveType::U8);
EXPECT_EQ(my_operand.device_memory().opaque(),
absl::bit_cast<void*>(static_cast<intptr_t>(0xDEADBEEF)));
EXPECT_EQ(my_result->element_type(), xla::PrimitiveType::U16);
EXPECT_EQ(my_result->device_memory().opaque(),
absl::bit_cast<void*>(static_cast<intptr_t>(0xABCDEF)));
EXPECT_EQ(called_computation->name(), "test_computation");
return absl::OkStatus();
}

XLA_FFI_DEFINE_HANDLER(kVerifyCallbackArguments, VerifyCallbackArguments,
ffi::Ffi::Bind()
.Attr<int>("my_attribute")
.Arg<ffi::AnyBuffer>()
.Ret<ffi::AnyBuffer>()
.Ctx<ffi::CalledComputation>(),
{ffi::Traits::kCmdBufferCompatible});

constexpr absl::string_view kVerifyCallbackArgumentsCustomCallName =
"__xla_test$$verify_callback_arguments";
constexpr absl::string_view kTestPlatformName = "TEST_PLATFORM";

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(),
kVerifyCallbackArgumentsCustomCallName,
kTestPlatformName, kVerifyCallbackArguments);

TEST(CustomCallThunkTest, ProtoConversion) {
TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<se::Stream> stream,
executor->CreateStream());

HloModuleConfig config;
HloModule hlo_module("test_module", config);
HloComputation::Builder builder("test_computation");
// This instruction is fairly arbitrary; we just need a non-empty computation.
builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(U32, {42}), "parameter"));
hlo_module.AddEntryComputation(builder.Build());

BufferAllocation alloc0{0, 1024, 0};
BufferAllocation alloc1{1, 1024, 0};
ShapedSlice operand_slice{BufferAllocation::Slice{&alloc0, 0, 1024},
ShapeUtil::MakeShape(U8, {1024})};
ShapedSlice result_slice{BufferAllocation::Slice{&alloc1, 0, 1024},
ShapeUtil::MakeShape(U16, {512})};

TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<CustomCallThunk> original_thunk,
CustomCallThunk::Create(
Thunk::ThunkInfo(),
/*target_name=*/std::string(kVerifyCallbackArgumentsCustomCallName),
/*operands=*/{operand_slice},
/*results=*/{result_slice}, /*attributes=*/{{"my_attribute", 42}},
hlo_module.entry_computation(),
/*platform_name=*/kTestPlatformName));
TF_ASSERT_OK_AND_ASSIGN(ThunkProto proto, original_thunk->ToProto());
ASSERT_TRUE(proto.has_custom_call_thunk());
original_thunk.reset();

std::array allocations = {alloc0, alloc1};

TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<CustomCallThunk> new_thunk,
CustomCallThunk::FromProto(Thunk::ThunkInfo(), proto.custom_call_thunk(),
allocations, &hlo_module, kTestPlatformName));

se::StreamExecutorMemoryAllocator allocator(executor);
BufferAllocations device_allocations(
{stream_executor::DeviceMemoryBase(
absl::bit_cast<void*>(static_cast<intptr_t>(0xDEADBEEF)), 1024),
stream_executor::DeviceMemoryBase(
absl::bit_cast<void*>(static_cast<intptr_t>(0xABCDEF)), 1024)},
0, &allocator);
Thunk::ExecuteParams params = Thunk::ExecuteParams::Create(
ServiceExecutableRunOptions(), device_allocations,
/*stream=*/stream.get(),
/*command_buffer_trace_stream=*/stream.get(),
/*collective_params=*/nullptr,
/*collective_cliques=*/nullptr);
EXPECT_THAT(new_thunk->ExecuteOnStream(params), IsOk());
}

TEST(CustomCallThunkTest, DeserializationFailsWithMissingHloModule) {
CustomCallThunkProto proto =
tsl::proto_testing::ParseTextProtoOrDie<CustomCallThunkProto>(
R"pb(
target_name: "__xla_test$$verify_callback_arguments"
api_version: API_VERSION_TYPED_FFI
called_computation: "called_computation"
)pb");

HloModuleConfig config;
HloModule hlo_module("test_module", config);
HloComputation::Builder builder("not_called_computation");
// This instruction is fairly arbitrary; we just need a non-empty computation.
builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(U32, {42}), "parameter"));
hlo_module.AddEntryComputation(builder.Build());

EXPECT_THAT(CustomCallThunk::FromProto(Thunk::ThunkInfo(), proto,
/*buffer_allocations=*/{}, &hlo_module,
/*platform_name=*/kTestPlatformName),
StatusIs(absl::StatusCode::kInvalidArgument));
}

} // namespace
} // namespace xla::gpu
4 changes: 3 additions & 1 deletion xla/backends/gpu/runtime/dynamic_slice_thunk_test.cc
@@ -115,7 +115,9 @@ void CheckProtoRoundTrip(const DynamicSliceThunk& thunk,
[](const ThunkProto& thunk_proto,
absl::Span<const BufferAllocation> fake_allocations_span)
-> absl::StatusOr<std::unique_ptr<Thunk>> {
return DeserializeThunkProto(thunk_proto, fake_allocations_span);
return DeserializeThunkProto(thunk_proto, fake_allocations_span,
/*hlo_module=*/nullptr,
/*platform_name=*/"TEST_PLATFORM");
};

TF_ASSERT_OK_AND_ASSIGN(