Skip to content

Commit 54444da

Browse files
Optimize SparseCore input preprocessing by eliminating buffer copies.

* Introducing `OutputBufferViews`, allowing the caller to pass views of the destination buffers.
* Modifying `CsrArraysPerHost` to optionally use `Eigen::Map` to wrap these buffers when provided.

We can avoid populating large CSR arrays in the preprocessing return values and skip the data copy step.

PiperOrigin-RevId: 831179831
1 parent e8aeb14 commit 54444da

File tree

6 files changed

+192
-27
lines changed

6 files changed

+192
-27
lines changed

jax_tpu_embedding/sparsecore/lib/core/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ cc_library(
5757
":partitioned_coo_tensors",
5858
"@com_google_absl//absl/base:core_headers",
5959
"@com_google_absl//absl/base:nullability",
60+
"@com_google_absl//absl/container:flat_hash_map",
6061
"@com_google_absl//absl/log",
6162
"@com_google_absl//absl/log:check",
6263
"@com_google_absl//absl/strings:string_view",

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,26 @@ void CheckDeviceBatchSize(int batch_size_for_device, int num_sc_per_device,
113113
batch_size_for_device, stacked_table_name, num_sc_per_device);
114114
}
115115

116+
CsrArraysPerHost CreateCsrArraysPerHost(
117+
absl::string_view name,
118+
const PreprocessSparseDenseMatmulInputOptions& options,
119+
int coo_buffer_size_per_device, int row_pointers_size_per_bucket) {
120+
const int row_pointers_dim = row_pointers_size_per_bucket *
121+
(options.enable_minibatching
122+
? CooFormat::kMaxMinibatchingBuckets
123+
: 1) *
124+
options.num_sc_per_device;
125+
if (options.output_buffers) {
126+
auto it = options.output_buffers->find(name);
127+
if (it != options.output_buffers->end()) {
128+
return CsrArraysPerHost(options.local_device_count, row_pointers_dim,
129+
coo_buffer_size_per_device, it->second);
130+
}
131+
}
132+
return CsrArraysPerHost(options.local_device_count, row_pointers_dim,
133+
coo_buffer_size_per_device);
134+
}
135+
116136
// Holds the state for processing a single stacked table across all local
117137
// devices. This includes extracted COO tensors, partitioned COO tensors,
118138
// CSR arrays, and statistics.
@@ -138,13 +158,9 @@ struct TableState {
138158
coo_buffer_size_per_device(ComputeCooBufferSizePerDevice(
139159
num_scs, options.num_sc_per_device, metadata, options.batch_number,
140160
options.enable_minibatching)),
141-
csr_arrays_per_host(options.local_device_count,
142-
row_pointers_size_per_bucket *
143-
(options.enable_minibatching
144-
? CooFormat::kMaxMinibatchingBuckets
145-
: 1) *
146-
options.num_sc_per_device,
147-
coo_buffer_size_per_device),
161+
csr_arrays_per_host(CreateCsrArraysPerHost(
162+
name, options, coo_buffer_size_per_device,
163+
row_pointers_size_per_bucket)),
148164
stats_per_host(options.local_device_count, options.GetNumScs(),
149165
options.num_sc_per_device),
150166
batch_size_for_device(0) {
@@ -419,14 +435,20 @@ void PopulateOutput(TableState& state, PreprocessSparseDenseMatmulOutput& out,
419435
state.stats_per_host.Flatten();
420436

421437
absl::MutexLock mutex_lock(output_mutex);
422-
out.lhs_row_pointers[state.stacked_table_name] =
423-
std::move(state.csr_arrays_per_host.row_pointers);
424-
out.lhs_embedding_ids[state.stacked_table_name] =
425-
std::move(state.csr_arrays_per_host.embedding_ids);
426-
out.lhs_sample_ids[state.stacked_table_name] =
427-
std::move(state.csr_arrays_per_host.sample_ids);
428-
out.lhs_gains[state.stacked_table_name] =
429-
std::move(state.csr_arrays_per_host.gains);
438+
// If `owns_data` is true, it indicates that the data is owned
439+
// by `CsrArraysPerHost`, so we need to move it to the output. Otherwise,
440+
// the data has already been written directly into the output buffers via
441+
// `views`, and no move is necessary.
442+
if (state.csr_arrays_per_host.owns_data) {
443+
out.lhs_row_pointers[state.stacked_table_name] =
444+
std::move(state.csr_arrays_per_host.row_pointers);
445+
out.lhs_embedding_ids[state.stacked_table_name] =
446+
std::move(state.csr_arrays_per_host.embedding_ids);
447+
out.lhs_sample_ids[state.stacked_table_name] =
448+
std::move(state.csr_arrays_per_host.sample_ids);
449+
out.lhs_gains[state.stacked_table_name] =
450+
std::move(state.csr_arrays_per_host.gains);
451+
}
430452

431453
out.stats.max_ids_per_partition[state.stacked_table_name] =
432454
std::move(state.stats_per_host.max_ids_per_partition);

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_test.cc

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ using ::testing::Each;
5757
using ::testing::ElementsAreArray;
5858
using ::testing::Eq;
5959
using ::testing::Gt;
60+
using ::testing::NanSensitiveFloatEq;
61+
using ::testing::Pointwise;
6062
using ::testing::SizeIs;
6163

6264
std::unique_ptr<AbstractInputBatch> CreateInputBatchFromSamples(
@@ -1387,5 +1389,68 @@ FUZZ_TEST(InputPreprocessingFuzzTest, StatsValidationTest)
13871389
{FeatureStackingStrategy::kStackThenSplit,
13881390
FeatureStackingStrategy::kSplitThenStack}));
13891391

1392+
TEST_F(TableStackingTest,
1393+
PreprocessingWithAndWithoutOutputBuffersIsEquivalent) {
1394+
PreprocessSparseDenseMatmulInputOptions options{
1395+
.local_device_count = 1,
1396+
.global_device_count = 2,
1397+
.num_sc_per_device = 4,
1398+
.feature_stacking_strategy = FeatureStackingStrategy::kStackThenSplit};
1399+
absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
1400+
stacked_tables({{"table_0", stacked_table_metadata_multi_}});
1401+
1402+
TF_ASSERT_OK_AND_ASSIGN(
1403+
PreprocessSparseDenseMatmulOutput output_matrix,
1404+
PreprocessSparseDenseMatmulInput(absl::MakeSpan(input_batches_multi_),
1405+
stacked_tables, options));
1406+
1407+
const int num_scs = options.GetNumScs();
1408+
const int coo_buffer_size_per_device =
1409+
ComputeCooBufferSizePerDevice(num_scs, options.num_sc_per_device,
1410+
stacked_table_metadata_multi_, 0, false);
1411+
const int row_pointers_size =
1412+
std::max(num_scs, TPU_VECTOR_REGISTER_ALIGNMENT_SIZE) *
1413+
options.num_sc_per_device;
1414+
std::vector<int> row_pointers_data(row_pointers_size, INT_MAX);
1415+
std::vector<int> embedding_ids_data(coo_buffer_size_per_device, INT_MAX);
1416+
std::vector<int> sample_ids_data(coo_buffer_size_per_device, INT_MAX);
1417+
std::vector<float> gains_data(coo_buffer_size_per_device, std::nanf(""));
1418+
absl::flat_hash_map<std::string, internal::CsrArraysPerDevice> output_buffers;
1419+
output_buffers["table_0"] = internal::CsrArraysPerDevice{
1420+
.row_pointers = absl::MakeSpan(row_pointers_data),
1421+
.embedding_ids = absl::MakeSpan(embedding_ids_data),
1422+
.sample_ids = absl::MakeSpan(sample_ids_data),
1423+
.gains = absl::MakeSpan(gains_data),
1424+
};
1425+
options.output_buffers = &output_buffers;
1426+
1427+
TF_ASSERT_OK_AND_ASSIGN(
1428+
PreprocessSparseDenseMatmulOutput output_zero_copy,
1429+
PreprocessSparseDenseMatmulInput(absl::MakeSpan(input_batches_multi_),
1430+
stacked_tables, options));
1431+
options.output_buffers = nullptr; // for next test
1432+
1433+
ASSERT_EQ(output_matrix.lhs_row_pointers["table_0"].rows(), 1);
1434+
ASSERT_EQ(output_matrix.lhs_embedding_ids["table_0"].rows(), 1);
1435+
ASSERT_EQ(output_matrix.lhs_sample_ids["table_0"].rows(), 1);
1436+
ASSERT_EQ(output_matrix.lhs_gains["table_0"].rows(), 1);
1437+
1438+
EXPECT_THAT(row_pointers_data,
1439+
ElementsAreArray(absl::MakeConstSpan(
1440+
output_matrix.lhs_row_pointers["table_0"].data(),
1441+
row_pointers_size)));
1442+
for (int i = 0; i < coo_buffer_size_per_device; ++i) {
1443+
if (embedding_ids_data[i] != INT_MAX) {
1444+
EXPECT_EQ(embedding_ids_data[i],
1445+
output_matrix.lhs_embedding_ids["table_0"].data()[i]);
1446+
EXPECT_EQ(sample_ids_data[i],
1447+
output_matrix.lhs_sample_ids["table_0"].data()[i]);
1448+
EXPECT_THAT(
1449+
gains_data[i],
1450+
NanSensitiveFloatEq(output_matrix.lhs_gains["table_0"].data()[i]));
1451+
}
1452+
}
1453+
}
1454+
13901455
} // namespace
13911456
} // namespace jax_sc_embedding

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ bool ValidIndices(int row_index, int coo_offset, int processed,
9191

9292
// Pad the row pointers buffer to the end of the buffer.
9393
void PadRowPointersBuffer(int& lhs_row_offset, int padding, int row_end,
94-
Eigen::Ref<RowVectorXi> row_pointers) {
94+
absl::Span<int> row_pointers) {
9595
while (lhs_row_offset < row_end) {
9696
row_pointers[lhs_row_offset++] = padding;
9797
}

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h

Lines changed: 54 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323

2424
#include "absl/base/attributes.h" // from @com_google_absl
2525
#include "absl/base/nullability.h" // from @com_google_absl
26+
#include "absl/container/flat_hash_map.h" // from @com_google_absl
27+
#include "absl/log/check.h" // from @com_google_absl
2628
#include "absl/strings/string_view.h" // from @com_google_absl
2729
#include "absl/types/span.h" // from @com_google_absl
2830
#include "Eigen/Core" // from @eigen_archive
@@ -60,10 +62,10 @@ using BlockRow = Eigen::Block<MatrixX<T>, 1, Eigen::Dynamic, Eigen::RowMajor>;
6062
namespace internal {
6163

6264
struct CsrArraysPerDevice {
63-
BlockRow<int> row_pointers;
64-
BlockRow<int> embedding_ids;
65-
BlockRow<int> sample_ids;
66-
BlockRow<float> gains;
65+
absl::Span<int> row_pointers;
66+
absl::Span<int> embedding_ids;
67+
absl::Span<int> sample_ids;
68+
absl::Span<float> gains;
6769
};
6870

6971
struct StatsPerDevice {
@@ -81,21 +83,58 @@ struct CsrArraysPerHost {
8183
MatrixXi sample_ids;
8284
MatrixXf gains;
8385

86+
internal::CsrArraysPerDevice views;
87+
bool owns_data = false;
88+
89+
const int local_device_count_;
90+
8491
CsrArraysPerHost(int local_device_count, int row_pointers_size_per_device,
8592
int coo_buffer_size_per_device)
8693
: row_pointers(local_device_count, row_pointers_size_per_device),
8794
embedding_ids(local_device_count, coo_buffer_size_per_device),
8895
sample_ids(local_device_count, coo_buffer_size_per_device),
89-
gains(local_device_count, coo_buffer_size_per_device) {}
96+
gains(local_device_count, coo_buffer_size_per_device),
97+
owns_data(true),
98+
local_device_count_(local_device_count) {
99+
views = {absl::MakeSpan(row_pointers.data(), row_pointers.size()),
100+
absl::MakeSpan(embedding_ids.data(), embedding_ids.size()),
101+
absl::MakeSpan(sample_ids.data(), sample_ids.size()),
102+
absl::MakeSpan(gains.data(), gains.size())};
103+
}
104+
105+
CsrArraysPerHost(int local_device_count, int row_pointers_size_per_device,
106+
int coo_buffer_size_per_device,
107+
internal::CsrArraysPerDevice output_buffers)
108+
: views(output_buffers),
109+
owns_data(false),
110+
local_device_count_(local_device_count) {
111+
CHECK_EQ(output_buffers.row_pointers.size(),
112+
local_device_count * row_pointers_size_per_device);
113+
CHECK_EQ(output_buffers.embedding_ids.size(),
114+
local_device_count * coo_buffer_size_per_device);
115+
CHECK_EQ(output_buffers.sample_ids.size(),
116+
local_device_count * coo_buffer_size_per_device);
117+
CHECK_EQ(output_buffers.gains.size(),
118+
local_device_count * coo_buffer_size_per_device);
119+
}
90120

91121
internal::CsrArraysPerDevice GetCsrArraysPerDevice(int local_device_id)
92122
ABSL_ATTRIBUTE_LIFETIME_BOUND {
93-
return internal::CsrArraysPerDevice{
94-
.row_pointers = row_pointers.row(local_device_id),
95-
.embedding_ids = embedding_ids.row(local_device_id),
96-
.sample_ids = sample_ids.row(local_device_id),
97-
.gains = gains.row(local_device_id),
98-
};
123+
int row_pointers_size_per_device =
124+
views.row_pointers.size() / local_device_count_;
125+
int coo_buffer_size_per_device =
126+
views.embedding_ids.size() / local_device_count_;
127+
return {views.row_pointers.subspan(
128+
local_device_id * row_pointers_size_per_device,
129+
row_pointers_size_per_device),
130+
views.embedding_ids.subspan(
131+
local_device_id * coo_buffer_size_per_device,
132+
coo_buffer_size_per_device),
133+
views.sample_ids.subspan(
134+
local_device_id * coo_buffer_size_per_device,
135+
coo_buffer_size_per_device),
136+
views.gains.subspan(local_device_id * coo_buffer_size_per_device,
137+
coo_buffer_size_per_device)};
99138
}
100139
};
101140

@@ -195,6 +234,10 @@ struct PreprocessSparseDenseMatmulInputOptions {
195234
// mini-batching to synchronize state across different hosts.
196235
AllReduceInterface* absl_nullable all_reduce_interface;
197236

237+
// If provided, CSR data will be written directly to these buffers.
238+
const absl::flat_hash_map<std::string, internal::CsrArraysPerDevice>*
239+
output_buffers = nullptr;
240+
198241
// Hash function used for creating minibatching buckets.
199242
CooFormat::HashFn minibatching_bucketing_hash_fn = HighwayHash;
200243

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util_test.cc

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,4 +1270,38 @@ TEST(InputPreprocessingUtilTest,
12701270
}
12711271

12721272
} // namespace
1273+
1274+
TEST(CsrArraysPerHostTest,
1275+
CreatingCsrArraysPerHostFromExternalArraysTriggersZeroCopies) {
1276+
const int kLocalDeviceCount = 1;
1277+
const int kRpSize = 10;
1278+
const int kCooSize = 20;
1279+
std::vector<int> rp_data(kRpSize, 0);
1280+
std::vector<int> eid_data(kCooSize, 0);
1281+
std::vector<int> sid_data(kCooSize, 0);
1282+
std::vector<float> gains_data(kCooSize, 0.0);
1283+
1284+
internal::CsrArraysPerDevice buffers{
1285+
.row_pointers = absl::MakeSpan(rp_data),
1286+
.embedding_ids = absl::MakeSpan(eid_data),
1287+
.sample_ids = absl::MakeSpan(sid_data),
1288+
.gains = absl::MakeSpan(gains_data),
1289+
};
1290+
1291+
CsrArraysPerHost csr_arrays_per_host(kLocalDeviceCount, kRpSize, kCooSize,
1292+
buffers);
1293+
EXPECT_FALSE(csr_arrays_per_host.owns_data);
1294+
1295+
internal::CsrArraysPerDevice device_array =
1296+
csr_arrays_per_host.GetCsrArraysPerDevice(0);
1297+
device_array.row_pointers[0] = 1;
1298+
device_array.embedding_ids[0] = 2;
1299+
device_array.sample_ids[0] = 3;
1300+
device_array.gains[0] = 4.0;
1301+
1302+
EXPECT_EQ(rp_data[0], 1);
1303+
EXPECT_EQ(eid_data[0], 2);
1304+
EXPECT_EQ(sid_data[0], 3);
1305+
EXPECT_EQ(gains_data[0], 4.0);
1306+
}
12731307
} // namespace jax_sc_embedding

0 commit comments

Comments
 (0)