Skip to content

Commit 9a77a88

Browse files
EusebioDM and Google-ML-Automation
authored and committed
Use Deserializer lambda for embedded thunks in DynamicSliceThunk
PiperOrigin-RevId: 826474606
1 parent faa892d commit 9a77a88

File tree

4 files changed

+18
-8
lines changed

4 files changed

+18
-8
lines changed

xla/backends/gpu/runtime/BUILD

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -226,7 +226,6 @@ cc_library(
226226
":sequential_thunk",
227227
":thunk",
228228
":thunk_proto_cc",
229-
":thunk_proto_deserialization",
230229
"//xla:literal",
231230
"//xla:literal_util",
232231
"//xla:shape_util",
@@ -270,6 +269,7 @@ xla_test(
270269
":gemm_thunk",
271270
":sequential_thunk",
272271
":thunk",
272+
":thunk_proto_deserialization",
273273
"//xla:shape_util",
274274
"//xla:xla_data_proto_cc",
275275
"//xla/ffi",

xla/backends/gpu/runtime/dynamic_slice_thunk.cc

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -41,7 +41,6 @@ limitations under the License.
4141
#include "xla/backends/gpu/runtime/sequential_thunk.h"
4242
#include "xla/backends/gpu/runtime/thunk.h"
4343
#include "xla/backends/gpu/runtime/thunk.pb.h"
44-
#include "xla/backends/gpu/runtime/thunk_proto_deserialization.h"
4544
#include "xla/hlo/evaluator/hlo_evaluator.h"
4645
#include "xla/hlo/ir/hlo_module.h"
4746
#include "xla/literal.h"
@@ -617,7 +616,8 @@ absl::StatusOr<ThunkProto> DynamicSliceThunk::ToProto() const {
617616
absl::StatusOr<std::unique_ptr<DynamicSliceThunk>> DynamicSliceThunk::FromProto(
618617
ThunkInfo thunk_info, const DynamicSliceThunkProto& proto,
619618
absl::Span<const BufferAllocation> buffer_allocations,
620-
absl::Span<const BufferAllocation> fake_allocations) {
619+
absl::Span<const BufferAllocation> fake_allocations,
620+
const Deserializer& deserializer) {
621621
// offset_as_function_of_indvar_metadata
622622
std::optional<OffsetAsFunctionOfIndvarModulesMetadata>
623623
offset_as_function_of_indvar_metadata;
@@ -677,9 +677,9 @@ absl::StatusOr<std::unique_ptr<DynamicSliceThunk>> DynamicSliceThunk::FromProto(
677677
// embedded_thunk
678678
std::vector<std::unique_ptr<Thunk>> embedded_thunks;
679679
for (const auto& thunk_proto : proto.embedded_thunk().thunks()) {
680-
TF_ASSIGN_OR_RETURN(auto thunk,
681-
DeserializeThunkProto(thunk_proto, fake_allocations));
682-
embedded_thunks.push_back(std::move(thunk));
680+
TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> embedded_thunk,
681+
deserializer(thunk_proto));
682+
embedded_thunks.push_back(std::move(embedded_thunk));
683683
}
684684

685685
// leave fake_allocations empty, because we manage their lifetime outside

xla/backends/gpu/runtime/dynamic_slice_thunk.h

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -186,10 +186,12 @@ class DynamicSliceThunk : public Thunk {
186186
// replaced during execution in `ExecuteOnStream` with the actual (dynamic)
187187
// slices. We have to create these outside of this method to manage their
188188
// lifetime correctly.
189+
// `deserializer`: Callback invoked once per embedded thunk proto to deserialize it.
189190
static absl::StatusOr<std::unique_ptr<DynamicSliceThunk>> FromProto(
190191
ThunkInfo thunk_info, const DynamicSliceThunkProto& proto,
191192
absl::Span<const BufferAllocation> buffer_allocations,
192-
absl::Span<const BufferAllocation> fake_allocations);
193+
absl::Span<const BufferAllocation> fake_allocations,
194+
const Deserializer& deserializer);
193195

194196
std::optional<const OffsetAsFunctionOfIndvarModulesMetadata*>
195197
get_offset_function() const {

xla/backends/gpu/runtime/dynamic_slice_thunk_test.cc

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@ limitations under the License.
3535
#include "xla/backends/gpu/runtime/gemm_thunk.h"
3636
#include "xla/backends/gpu/runtime/sequential_thunk.h"
3737
#include "xla/backends/gpu/runtime/thunk.h"
38+
#include "xla/backends/gpu/runtime/thunk_proto_deserialization.h"
3839
#include "xla/ffi/attribute_map.h"
3940
#include "xla/ffi/ffi.h"
4041
#include "xla/ffi/ffi_api.h"
@@ -108,11 +109,18 @@ void CheckProtoRoundTrip(const DynamicSliceThunk& thunk,
108109
BufferAllocation(i, arguments[i].value().allocation()->size(), 0));
109110
}
110111
}
112+
113+
Thunk::Deserializer deserializer =
114+
[&buffer_allocations](const ThunkProto& thunk_proto)
115+
-> absl::StatusOr<std::unique_ptr<Thunk>> {
116+
return DeserializeThunkProto(thunk_proto, buffer_allocations);
117+
};
111118
TF_ASSERT_OK_AND_ASSIGN(
112119
auto thunk_from_proto,
113120
DynamicSliceThunk::FromProto(Thunk::ThunkInfo(), proto,
114121
/*buffer_allocations=*/buffer_allocations,
115-
/*fake_allocations=*/fake_allocations_span));
122+
/*fake_allocations=*/fake_allocations_span,
123+
deserializer));
116124
TF_ASSERT_OK_AND_ASSIGN(auto proto_roundtrip, thunk_from_proto->ToProto());
117125
auto dynamic_slice_thunk_proto_roundtrip =
118126
proto_roundtrip.dynamic_slice_thunk();

0 commit comments

Comments
 (0)