Skip to content

Commit 8628804

Browse files
Optimize SparseCore input preprocessing by eliminating buffer copies.
* Introducing `OutputBufferViews`, allowing the caller to pass views of the destination buffers. * Modifying `CsrArraysPerHost` to optionally use `Eigen::Map` to wrap these buffers when provided. We can avoid populating large CSR arrays in the preprocessing return values and skip the data copy step. PiperOrigin-RevId: 831179831
1 parent 8eb529e commit 8628804

File tree

6 files changed

+191
-28
lines changed

6 files changed

+191
-28
lines changed

jax_tpu_embedding/sparsecore/lib/core/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ cc_library(
5757
":partitioned_coo_tensors",
5858
"@com_google_absl//absl/base:core_headers",
5959
"@com_google_absl//absl/base:nullability",
60+
"@com_google_absl//absl/container:flat_hash_map",
6061
"@com_google_absl//absl/log",
6162
"@com_google_absl//absl/log:check",
6263
"@com_google_absl//absl/strings:string_view",

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,26 @@ void CheckDeviceBatchSize(int batch_size_for_device, int num_sc_per_device,
113113
batch_size_for_device, stacked_table_name, num_sc_per_device);
114114
}
115115

116+
CsrArraysPerHost CreateCsrArraysPerHost(
117+
absl::string_view name,
118+
const PreprocessSparseDenseMatmulInputOptions& options,
119+
int coo_buffer_size_per_device, int row_pointers_size_per_bucket) {
120+
const int row_pointers_dim = row_pointers_size_per_bucket *
121+
(options.enable_minibatching
122+
? CooFormat::kMaxMinibatchingBuckets
123+
: 1) *
124+
options.num_sc_per_device;
125+
if (options.output_buffers) {
126+
auto it = options.output_buffers->find(name);
127+
if (it != options.output_buffers->end()) {
128+
return CsrArraysPerHost(options.local_device_count, row_pointers_dim,
129+
coo_buffer_size_per_device, it->second);
130+
}
131+
}
132+
return CsrArraysPerHost(options.local_device_count, row_pointers_dim,
133+
coo_buffer_size_per_device);
134+
}
135+
116136
// Holds the state for processing a single stacked table across all local
117137
// devices. This includes extracted COO tensors, partitioned COO tensors,
118138
// CSR arrays, and statistics.
@@ -138,13 +158,9 @@ struct TableState {
138158
coo_buffer_size_per_device(ComputeCooBufferSizePerDevice(
139159
num_scs, options.num_sc_per_device, metadata, options.batch_number,
140160
options.enable_minibatching)),
141-
csr_arrays_per_host(options.local_device_count,
142-
row_pointers_size_per_bucket *
143-
(options.enable_minibatching
144-
? CooFormat::kMaxMinibatchingBuckets
145-
: 1) *
146-
options.num_sc_per_device,
147-
coo_buffer_size_per_device),
161+
csr_arrays_per_host(CreateCsrArraysPerHost(
162+
name, options, coo_buffer_size_per_device,
163+
row_pointers_size_per_bucket)),
148164
stats_per_host(options.local_device_count, options.GetNumScs(),
149165
options.num_sc_per_device),
150166
batch_size_for_device(0) {
@@ -428,15 +444,21 @@ void PopulateOutput(TableState& state, PreprocessSparseDenseMatmulOutput& out,
428444
absl::Mutex& output_mutex) {
429445
state.stats_per_host.Flatten();
430446

431-
absl::MutexLock mutex(output_mutex);
432-
out.lhs_row_pointers[state.stacked_table_name] =
433-
std::move(state.csr_arrays_per_host.row_pointers);
434-
out.lhs_embedding_ids[state.stacked_table_name] =
435-
std::move(state.csr_arrays_per_host.embedding_ids);
436-
out.lhs_sample_ids[state.stacked_table_name] =
437-
std::move(state.csr_arrays_per_host.sample_ids);
438-
out.lhs_gains[state.stacked_table_name] =
439-
std::move(state.csr_arrays_per_host.gains);
447+
absl::MutexLock lock(output_mutex);
448+
// If `owns_data` is true, it indicates that the data is owned
449+
// by `CsrArraysPerHost`, so we need to move it to the output. Otherwise,
450+
// the data has already been written directly into the output buffers via
451+
// `views`, and no move is necessary.
452+
if (state.csr_arrays_per_host.owns_data) {
453+
out.lhs_row_pointers[state.stacked_table_name] =
454+
std::move(state.csr_arrays_per_host.row_pointers);
455+
out.lhs_embedding_ids[state.stacked_table_name] =
456+
std::move(state.csr_arrays_per_host.embedding_ids);
457+
out.lhs_sample_ids[state.stacked_table_name] =
458+
std::move(state.csr_arrays_per_host.sample_ids);
459+
out.lhs_gains[state.stacked_table_name] =
460+
std::move(state.csr_arrays_per_host.gains);
461+
}
440462

441463
out.stats.max_ids_per_partition[state.stacked_table_name] =
442464
std::move(state.stats_per_host.max_ids_per_partition);

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_test.cc

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ using ::testing::Each;
5858
using ::testing::ElementsAreArray;
5959
using ::testing::Eq;
6060
using ::testing::Gt;
61+
using ::testing::NanSensitiveFloatEq;
6162
using ::testing::SizeIs;
6263

6364
std::unique_ptr<AbstractInputBatch> CreateInputBatchFromSamples(
@@ -1388,5 +1389,68 @@ FUZZ_TEST(InputPreprocessingFuzzTest, StatsValidationTest)
13881389
{FeatureStackingStrategy::kStackThenSplit,
13891390
FeatureStackingStrategy::kSplitThenStack}));
13901391

1392+
TEST_F(TableStackingTest,
1393+
PreprocessingWithAndWithoutOutputBuffersIsEquivalent) {
1394+
PreprocessSparseDenseMatmulInputOptions options{
1395+
.local_device_count = 1,
1396+
.global_device_count = 2,
1397+
.num_sc_per_device = 4,
1398+
.feature_stacking_strategy = FeatureStackingStrategy::kStackThenSplit};
1399+
absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
1400+
stacked_tables({{"table_0", stacked_table_metadata_multi_}});
1401+
1402+
TF_ASSERT_OK_AND_ASSIGN(
1403+
PreprocessSparseDenseMatmulOutput output_matrix,
1404+
PreprocessSparseDenseMatmulInput(absl::MakeSpan(input_batches_multi_),
1405+
stacked_tables, options));
1406+
1407+
const int num_scs = options.GetNumScs();
1408+
const int coo_buffer_size_per_device =
1409+
ComputeCooBufferSizePerDevice(num_scs, options.num_sc_per_device,
1410+
stacked_table_metadata_multi_, 0, false);
1411+
const int row_pointers_size =
1412+
std::max(num_scs, TPU_VECTOR_REGISTER_ALIGNMENT_SIZE) *
1413+
options.num_sc_per_device;
1414+
std::vector<int> row_pointers_data(row_pointers_size, INT_MAX);
1415+
std::vector<int> embedding_ids_data(coo_buffer_size_per_device, INT_MAX);
1416+
std::vector<int> sample_ids_data(coo_buffer_size_per_device, INT_MAX);
1417+
std::vector<float> gains_data(coo_buffer_size_per_device, std::nanf(""));
1418+
absl::flat_hash_map<std::string, internal::CsrArraysPerDevice> output_buffers;
1419+
output_buffers["table_0"] = internal::CsrArraysPerDevice{
1420+
.row_pointers = absl::MakeSpan(row_pointers_data),
1421+
.embedding_ids = absl::MakeSpan(embedding_ids_data),
1422+
.sample_ids = absl::MakeSpan(sample_ids_data),
1423+
.gains = absl::MakeSpan(gains_data),
1424+
};
1425+
options.output_buffers = &output_buffers;
1426+
1427+
TF_ASSERT_OK_AND_ASSIGN(
1428+
PreprocessSparseDenseMatmulOutput output_zero_copy,
1429+
PreprocessSparseDenseMatmulInput(absl::MakeSpan(input_batches_multi_),
1430+
stacked_tables, options));
1431+
options.output_buffers = nullptr; // for next test
1432+
1433+
ASSERT_EQ(output_matrix.lhs_row_pointers["table_0"].rows(), 1);
1434+
ASSERT_EQ(output_matrix.lhs_embedding_ids["table_0"].rows(), 1);
1435+
ASSERT_EQ(output_matrix.lhs_sample_ids["table_0"].rows(), 1);
1436+
ASSERT_EQ(output_matrix.lhs_gains["table_0"].rows(), 1);
1437+
1438+
EXPECT_THAT(row_pointers_data,
1439+
ElementsAreArray(absl::MakeConstSpan(
1440+
output_matrix.lhs_row_pointers["table_0"].data(),
1441+
row_pointers_size)));
1442+
for (int i = 0; i < coo_buffer_size_per_device; ++i) {
1443+
if (embedding_ids_data[i] != INT_MAX) {
1444+
EXPECT_EQ(embedding_ids_data[i],
1445+
output_matrix.lhs_embedding_ids["table_0"].data()[i]);
1446+
EXPECT_EQ(sample_ids_data[i],
1447+
output_matrix.lhs_sample_ids["table_0"].data()[i]);
1448+
EXPECT_THAT(
1449+
gains_data[i],
1450+
NanSensitiveFloatEq(output_matrix.lhs_gains["table_0"].data()[i]));
1451+
}
1452+
}
1453+
}
1454+
13911455
} // namespace
13921456
} // namespace jax_sc_embedding

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ bool ValidIndices(int row_index, int coo_offset, int processed,
9191

9292
// Pad the row pointers buffer to the end of the buffer.
9393
void PadRowPointersBuffer(int& lhs_row_offset, int padding, int row_end,
94-
Eigen::Ref<RowVectorXi> row_pointers) {
94+
absl::Span<int> row_pointers) {
9595
while (lhs_row_offset < row_end) {
9696
row_pointers[lhs_row_offset++] = padding;
9797
}

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h

Lines changed: 54 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323

2424
#include "absl/base/attributes.h" // from @com_google_absl
2525
#include "absl/base/nullability.h" // from @com_google_absl
26+
#include "absl/container/flat_hash_map.h" // from @com_google_absl
27+
#include "absl/log/check.h" // from @com_google_absl
2628
#include "absl/strings/string_view.h" // from @com_google_absl
2729
#include "absl/types/span.h" // from @com_google_absl
2830
#include "Eigen/Core" // from @eigen_archive
@@ -60,10 +62,10 @@ using BlockRow = Eigen::Block<MatrixX<T>, 1, Eigen::Dynamic, Eigen::RowMajor>;
6062
namespace internal {
6163

6264
struct CsrArraysPerDevice {
63-
BlockRow<int> row_pointers;
64-
BlockRow<int> embedding_ids;
65-
BlockRow<int> sample_ids;
66-
BlockRow<float> gains;
65+
absl::Span<int> row_pointers;
66+
absl::Span<int> embedding_ids;
67+
absl::Span<int> sample_ids;
68+
absl::Span<float> gains;
6769
};
6870

6971
struct StatsPerDevice {
@@ -81,21 +83,58 @@ struct CsrArraysPerHost {
8183
MatrixXi sample_ids;
8284
MatrixXf gains;
8385

86+
internal::CsrArraysPerDevice views;
87+
bool owns_data = false;
88+
89+
const int local_device_count_;
90+
8491
CsrArraysPerHost(int local_device_count, int row_pointers_size_per_device,
8592
int coo_buffer_size_per_device)
8693
: row_pointers(local_device_count, row_pointers_size_per_device),
8794
embedding_ids(local_device_count, coo_buffer_size_per_device),
8895
sample_ids(local_device_count, coo_buffer_size_per_device),
89-
gains(local_device_count, coo_buffer_size_per_device) {}
96+
gains(local_device_count, coo_buffer_size_per_device),
97+
owns_data(true),
98+
local_device_count_(local_device_count) {
99+
views = {absl::MakeSpan(row_pointers.data(), row_pointers.size()),
100+
absl::MakeSpan(embedding_ids.data(), embedding_ids.size()),
101+
absl::MakeSpan(sample_ids.data(), sample_ids.size()),
102+
absl::MakeSpan(gains.data(), gains.size())};
103+
}
104+
105+
CsrArraysPerHost(int local_device_count, int row_pointers_size_per_device,
106+
int coo_buffer_size_per_device,
107+
internal::CsrArraysPerDevice output_buffers)
108+
: views(output_buffers),
109+
owns_data(false),
110+
local_device_count_(local_device_count) {
111+
CHECK_EQ(output_buffers.row_pointers.size(),
112+
local_device_count * row_pointers_size_per_device);
113+
CHECK_EQ(output_buffers.embedding_ids.size(),
114+
local_device_count * coo_buffer_size_per_device);
115+
CHECK_EQ(output_buffers.sample_ids.size(),
116+
local_device_count * coo_buffer_size_per_device);
117+
CHECK_EQ(output_buffers.gains.size(),
118+
local_device_count * coo_buffer_size_per_device);
119+
}
90120

91121
internal::CsrArraysPerDevice GetCsrArraysPerDevice(int local_device_id)
92122
ABSL_ATTRIBUTE_LIFETIME_BOUND {
93-
return internal::CsrArraysPerDevice{
94-
.row_pointers = row_pointers.row(local_device_id),
95-
.embedding_ids = embedding_ids.row(local_device_id),
96-
.sample_ids = sample_ids.row(local_device_id),
97-
.gains = gains.row(local_device_id),
98-
};
123+
int row_pointers_size_per_device =
124+
views.row_pointers.size() / local_device_count_;
125+
int coo_buffer_size_per_device =
126+
views.embedding_ids.size() / local_device_count_;
127+
return {views.row_pointers.subspan(
128+
local_device_id * row_pointers_size_per_device,
129+
row_pointers_size_per_device),
130+
views.embedding_ids.subspan(
131+
local_device_id * coo_buffer_size_per_device,
132+
coo_buffer_size_per_device),
133+
views.sample_ids.subspan(
134+
local_device_id * coo_buffer_size_per_device,
135+
coo_buffer_size_per_device),
136+
views.gains.subspan(local_device_id * coo_buffer_size_per_device,
137+
coo_buffer_size_per_device)};
99138
}
100139
};
101140

@@ -195,6 +234,10 @@ struct PreprocessSparseDenseMatmulInputOptions {
195234
// mini-batching to synchronize state across different hosts.
196235
AllReduceInterface* absl_nullable all_reduce_interface;
197236

237+
// If provided, CSR data will be written directly to these buffers.
238+
const absl::flat_hash_map<std::string, internal::CsrArraysPerDevice>*
239+
output_buffers = nullptr;
240+
198241
// Hash function used for creating minibatching buckets.
199242
CooFormat::HashFn minibatching_bucketing_hash_fn = HighwayHash;
200243

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util_test.cc

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1269,5 +1269,38 @@ TEST(InputPreprocessingUtilTest,
12691269
ElementsAreArray(expected_sample_ids));
12701270
}
12711271

1272+
TEST(CsrArraysPerHostTest,
1273+
CreatingCsrArraysPerHostFromExternalArraysTriggersZeroCopies) {
1274+
const int kLocalDeviceCount = 1;
1275+
const int kRpSize = 10;
1276+
const int kCooSize = 20;
1277+
std::vector<int> rp_data(kRpSize, 0);
1278+
std::vector<int> eid_data(kCooSize, 0);
1279+
std::vector<int> sid_data(kCooSize, 0);
1280+
std::vector<float> gains_data(kCooSize, 0.0);
1281+
1282+
internal::CsrArraysPerDevice buffers{
1283+
.row_pointers = absl::MakeSpan(rp_data),
1284+
.embedding_ids = absl::MakeSpan(eid_data),
1285+
.sample_ids = absl::MakeSpan(sid_data),
1286+
.gains = absl::MakeSpan(gains_data),
1287+
};
1288+
1289+
CsrArraysPerHost csr_arrays_per_host(kLocalDeviceCount, kRpSize, kCooSize,
1290+
buffers);
1291+
EXPECT_FALSE(csr_arrays_per_host.owns_data);
1292+
1293+
internal::CsrArraysPerDevice device_array =
1294+
csr_arrays_per_host.GetCsrArraysPerDevice(0);
1295+
device_array.row_pointers[0] = 1;
1296+
device_array.embedding_ids[0] = 2;
1297+
device_array.sample_ids[0] = 3;
1298+
device_array.gains[0] = 4.0;
1299+
1300+
EXPECT_EQ(rp_data[0], 1);
1301+
EXPECT_EQ(eid_data[0], 2);
1302+
EXPECT_EQ(sid_data[0], 3);
1303+
EXPECT_EQ(gains_data[0], 4.0);
1304+
}
12721305
} // namespace
12731306
} // namespace jax_sc_embedding

0 commit comments

Comments
 (0)