Commit 49ba0b9

[JAX SC] perf: Refactor COO grouping key and add optimization for constant weights.

~2x improvement from avoiding random memory accesses.

The 64-bit grouping key now uses the lower bits to store either the original index (if variable weights are present) or the row_id (if weights are constant). This allows the grouping pass to skip the memory lookup into the `CooFormat` array when feature weights are always `1.0`, improving performance. The key unpacking logic is updated to use new static helper functions.

PiperOrigin-RevId: 836364065
1 parent 8b38bd6 commit 49ba0b9
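
Note on the layout: the commit packs three fields into one 64-bit key, and only the meaning of the low 26 bits changes. Below is a minimal, self-contained C++ sketch of that packing; the constants and names here are illustrative stand-ins, not the library code (the real constants and helpers are in the coo_format.h diff below).

// Illustrative sketch of the 64-bit grouping-key layout; assumes
// 6 minibatching-bucket bits, matching the layout comment in coo_format.h.
#include <cstdint>
#include <cstdio>

constexpr uint32_t kBucketBits = 6;               // bucket_id at [63:58]
constexpr uint32_t kDataBits = 32 - kBucketBits;  // index or row_id at [25:0]
constexpr uint32_t kDataMask = (1u << kDataBits) - 1;
constexpr uint32_t kColIdOffset = kDataBits;      // rotated col_id at [57:26]
constexpr uint32_t kBucketOffset = kColIdOffset + 32;

uint64_t PackKey(uint32_t bucket_id, uint32_t rotated_col_id, uint32_t data) {
  return (uint64_t{bucket_id} << kBucketOffset) |
         (uint64_t{rotated_col_id} << kColIdOffset) |
         uint64_t{data & kDataMask};
}

int main() {
  // With constant weights, `data` is the row_id itself, so a consumer of the
  // sorted keys never has to gather from the original CooFormat array.
  const uint64_t key = PackKey(/*bucket_id=*/3, /*rotated_col_id=*/0xABCD1234u,
                               /*data=*/42);
  std::printf("bucket=%u col=0x%X data=%u\n",
              static_cast<unsigned>(key >> kBucketOffset),
              static_cast<unsigned>((key >> kColIdOffset) & 0xFFFFFFFFu),
              static_cast<unsigned>(key & kDataMask));
}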

7 files changed: +122 −41 lines changed

jax_tpu_embedding/sparsecore/lib/core/abstract_input_batch.h

Lines changed: 3 additions & 1 deletion

@@ -49,7 +49,9 @@ class AbstractInputBatch {
   // Return the batch size or the number of samples in this input batch.
   virtual ssize_t size() const = 0;

-  // Extract COO Tensors.
+  // Returns true if the input batch has variable weights.
+  virtual bool HasVariableWeights() const { return true; }
+
   virtual void ExtractCooTensors(
       const ExtractCooTensorsOptions& options,
       ExtractedCooTensors& extracted_coo_tensors) = 0;
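
For contrast with the two `false` overrides later in this commit (ragged_tensor_input_batch.h, sparse_coo_input_batch.h), a hypothetical subclass that carries real per-id weights simply keeps the conservative default. Sketch only, assuming abstract_input_batch.h and its option types are in scope; `MyWeightedInputBatch` is not part of the library.

// Hypothetical input batch with genuine per-id weights: no override needed,
// since the base-class default of `true` selects the always-correct
// variable-weight path.
class MyWeightedInputBatch : public AbstractInputBatch {
 public:
  ssize_t size() const override { return batch_size_; }

  // HasVariableWeights() intentionally not overridden; the default is true.

  void ExtractCooTensors(const ExtractCooTensorsOptions& options,
                         ExtractedCooTensors& extracted_coo_tensors) override {
    // ... emit (row_id, col_id, gain) triples where gain may differ from 1.0 ...
  }

 private:
  ssize_t batch_size_ = 0;
};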

jax_tpu_embedding/sparsecore/lib/core/coo_format.h

Lines changed: 27 additions & 11 deletions

@@ -58,10 +58,14 @@ struct CooFormat {
   // Bits taken by minibatching bucket ID.
   static constexpr uint32_t kMinibatchingBucketBits =
       absl::bit_width(kMaxMinibatchingBuckets - 1);
-  // Bits for Index
-  static constexpr uint32_t kIndexBits = 32 - kMinibatchingBucketBits;
-  // Index Mask
-  static constexpr uint32_t kIndexMask = (1 << kIndexBits) - 1;
+  // Bits for variable data (index or row_id).
+  static constexpr uint32_t kDataBits = 32 - kMinibatchingBucketBits;
+  // Mask for variable data (index or row_id).
+  static constexpr uint32_t kDataMask = (1 << kDataBits) - 1;
+  // Bit offset for rotated_col_id in grouping key.
+  static constexpr uint32_t kRotatedColIdOffset = kDataBits;
+  // Bit offset for bucket_id in grouping key.
+  static constexpr uint32_t kBucketIdOffset = kRotatedColIdOffset + 32;

   // A deterministic hash function eventually used to compute mini-batching
   // bucket id as `hash(col_id) % bucket_count`.

@@ -136,24 +140,36 @@ struct CooFormat {
   // Computes a 64-bit sorting key with the following layout:
   // [63:58] bucket_id (6 bits)
   // [57:26] {global_sc_id, local_embedding_id} (32 bits) <- rotated col_id
-  // [25:0] index (26 bits)
+  // [25:0] index or row_id (26 bits)
   // The key is used to group and sort COO tensors for efficient processing.
   uint64_t GetGroupingKey(const uint32_t num_scs_bit, const int index,
-                          const bool create_buckets = false,
-                          HashFn hash_fn = HighwayHash) const {
+                          const bool create_buckets,
+                          HashFn hash_fn = HighwayHash,
+                          const bool has_variable_weights = true) const {
     // This structure ensures tensors are sorted first by bucket_id, then by
     // sparse core, and finally by embedding ID.
     const uint32_t bucket_id = create_buckets ? GetBucketId(hash_fn) : 0;

-    DCHECK_LE(index, kIndexMask);
+    const uint32_t data = has_variable_weights ? index : row_id;
+    DCHECK_LE(data, kDataMask);

     // [global_sc_id, local_embedding_id]
    uint32_t rotated_col_id =
        absl::rotr(static_cast<uint32_t>(col_id), num_scs_bit);

-    return (uint64_t{bucket_id} << (64 - kMinibatchingBucketBits)) |
-           (uint64_t{rotated_col_id} << (32 - kMinibatchingBucketBits)) |
-           static_cast<uint64_t>(index);
+    return (uint64_t{bucket_id} << kBucketIdOffset) |
+           (uint64_t{rotated_col_id} << kRotatedColIdOffset) |
+           static_cast<uint64_t>(data);
+  }
+
+  static uint32_t GetDataFromKey(uint64_t key) { return key & kDataMask; }
+
+  static uint32_t GetRotatedColIdFromKey(uint64_t key) {
+    return (key >> kRotatedColIdOffset) & 0xFFFFFFFF;
+  }
+
+  static uint32_t GetBucketIdFromKey(uint64_t key) {
+    return key >> kBucketIdOffset;
   }
 };
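
A round-trip sketch of the new helpers, assuming coo_format.h is included, that `CooFormat` is constructible as `(row_id, col_id, gain)` (as the grouping code below does), and that `HighwayHash` is reachable under this qualification; treat the names and qualifications as assumptions, not the library's documented API.

#include <cstdint>
#include "absl/numeric/bits.h"
#include "jax_tpu_embedding/sparsecore/lib/core/coo_format.h"

void RoundTripExample() {
  const uint32_t num_sc_bits = 2;  // hypothetical: 4 global SparseCores
  const CooFormat coo(/*row_id=*/7, /*col_id=*/123, /*gain=*/1.0f);

  // Constant-weight path: has_variable_weights = false stores row_id in the
  // low bits instead of the tensor's position in the COO array.
  const uint64_t key =
      coo.GetGroupingKey(num_sc_bits, /*index=*/0, /*create_buckets=*/false,
                         /*hash_fn=*/CooFormat::HighwayHash,
                         /*has_variable_weights=*/false);

  // Unpacking touches only the key; no lookup into the CooFormat array.
  const uint32_t row_id = CooFormat::GetDataFromKey(key);  // == 7
  const uint32_t col_id =
      absl::rotl(CooFormat::GetRotatedColIdFromKey(key), num_sc_bits);  // == 123
  const uint32_t bucket_id = CooFormat::GetBucketIdFromKey(key);        // == 0
  (void)row_id; (void)col_id; (void)bucket_id;
}

This is the same decode sequence the grouping loop in sort_and_group_coo_tensors_impl.h (below) performs per key.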

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc

Lines changed: 52 additions & 11 deletions

@@ -113,12 +113,34 @@ void CheckDeviceBatchSize(int batch_size_for_device, int num_sc_per_device,
       batch_size_for_device, stacked_table_name, num_sc_per_device);
 }

+// We consider a stack to have variable weights if any feature in the stack
+// has explicitly variable weights or if any feature uses a row combiner
+// other than 'sum' (e.g., 'mean' or 'sqrtn').
+bool StackHasVariableWeights(
+    absl::Span<std::unique_ptr<AbstractInputBatch>> input_batches,
+    absl::Span<const StackedTableMetadata> stacked_table_metadata) {
+  for (const auto& metadata : stacked_table_metadata) {
+    // `kHasVariableWeights` must be true if any feature in the stack:
+    //   1. Is explicitly marked as having variable weights.
+    //   2. Uses a row combiner other than 'sum'. Non-'sum' combiners (e.g.,
+    //      'mean', 'sqrtn') adjust gains during `ExtractCooTensors`. This
+    //      means the gains in `coo_tensors` are not always 1.0, even with
+    //      unity input weights.
+    if (input_batches[metadata.feature_index]->HasVariableWeights() ||
+        metadata.row_combiner != RowCombiner::kSum) {
+      return true;
+    }
+  }
+  return false;
+}
+
 // Holds the state for processing a single stacked table across all local
 // devices. This includes extracted COO tensors, partitioned COO tensors,
 // CSR arrays, and statistics.
 struct TableState {
   absl::string_view stacked_table_name;
   absl::Span<const StackedTableMetadata> stacked_table_metadata;
+  bool has_variable_weights;
   int coo_buffer_size_per_device;
   CsrArraysPerHost csr_arrays_per_host;
   StatsPerHost stats_per_host;

@@ -131,10 +153,12 @@

   TableState(absl::string_view name,
              absl::Span<const StackedTableMetadata> metadata,
+             bool has_variable_weights,
              const PreprocessSparseDenseMatmulInputOptions& options,
              int num_scs, int row_pointers_size_per_bucket)
       : stacked_table_name(name),
         stacked_table_metadata(metadata),
+        has_variable_weights(has_variable_weights),
         coo_buffer_size_per_device(ComputeCooBufferSizePerDevice(
             num_scs, options.num_sc_per_device, metadata, options.batch_number,
             options.enable_minibatching)),

@@ -154,6 +178,24 @@
   }
 };

+template <typename SplitType>
+void SortAndGroupCooTensorsForTableState(
+    TableState& state, int local_device,
+    const PreprocessSparseDenseMatmulInputOptions& options,
+    internal::StatsPerDevice& stats, SplitType& split) {
+  if (state.has_variable_weights) {
+    state.partitioned_coo_tensors_per_device[local_device] =
+        SortAndGroupCooTensorsPerLocalDevice<true>(
+            state.extracted_coo_tensors_per_device[local_device],
+            state.stacked_table_metadata[0], options, stats, split);
+  } else {
+    state.partitioned_coo_tensors_per_device[local_device] =
+        SortAndGroupCooTensorsPerLocalDevice<false>(
+            state.extracted_coo_tensors_per_device[local_device],
+            state.stacked_table_metadata[0], options, stats, split);
+  }
+}
+
 // Extracts, sorts, and groups COO tensors for a single stacked table across
 // all local devices. This function populates
 // `state.extracted_coo_tensors_per_device` and

@@ -180,11 +222,9 @@ void ExtractSortAndGroupCooTensorsForTable(

     internal::StatsPerDevice stats_per_device =
         state.stats_per_host.GetStatsPerDevice(local_device);
-    state.partitioned_coo_tensors_per_device[local_device] =
-        SortAndGroupCooTensorsPerLocalDevice(
-            state.extracted_coo_tensors_per_device[local_device],
-            state.stacked_table_metadata[0], options, stats_per_device,
-            state.table_minibatching_required);
+    SortAndGroupCooTensorsForTableState(
+        state, local_device, options, stats_per_device,
+        state.table_minibatching_required);
     state.dropped_id_count_per_device[local_device] =
         stats_per_device.dropped_id_count;
     counter.DecrementCount();

@@ -230,11 +270,9 @@ void CreateMinibatchingBucketsForTable(
                                  options.num_sc_per_device);
     internal::StatsPerDevice dummy_stats =
         dummy_stats_host.GetStatsPerDevice(0);
-    state.partitioned_coo_tensors_per_device[local_device] =
-        SortAndGroupCooTensorsPerLocalDevice(
-            state.extracted_coo_tensors_per_device[local_device],
-            state.stacked_table_metadata[0], options, dummy_stats,
-            state.table_minibatching_split);
+    SortAndGroupCooTensorsForTableState(state, local_device, options,
+                                        dummy_stats,
+                                        state.table_minibatching_split);
     state.dropped_id_count_per_device[local_device] =
         dummy_stats.dropped_id_count;
     counter.DecrementCount();

@@ -538,8 +576,11 @@ PreprocessSparseDenseMatmulInput(
   table_states.reserve(stacked_tables.size());
   for (const auto& [stacked_table_name, stacked_table_metadata] :
        stacked_tables) {
+    const bool stack_has_weights =
+        StackHasVariableWeights(input_batches, stacked_table_metadata);
     table_states.emplace_back(stacked_table_name, stacked_table_metadata,
-                              options, num_scs, row_pointers_size_per_bucket);
+                              stack_has_weights, options, num_scs,
+                              row_pointers_size_per_bucket);
   }

   // Stage 1: COO Extraction and Initial Sort/Group
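
The `SortAndGroupCooTensorsForTableState` helper above is a standard runtime-to-compile-time dispatch: the `has_variable_weights` flag is branched once per device call, and each branch instantiates a distinct specialization so the per-tensor loop can use `if constexpr` at zero runtime cost. A self-contained sketch of the idiom, with illustrative names rather than the library's:

#include <cstdio>
#include <vector>

template <bool kHasVariableWeights>
int ProcessKeys(const std::vector<int>& keys) {
  int acc = 0;
  for (int key : keys) {
    if constexpr (kHasVariableWeights) {
      acc += key * 2;  // stand-in for the "gather from CooFormat array" path
    } else {
      acc += key;      // stand-in for the cheaper constant-weight path
    }
  }
  return acc;
}

int Dispatch(bool has_variable_weights, const std::vector<int>& keys) {
  // One branch outside the hot loop selects the specialization.
  return has_variable_weights ? ProcessKeys<true>(keys)
                              : ProcessKeys<false>(keys);
}

int main() {
  const std::vector<int> keys = {1, 2, 3};
  std::printf("%d %d\n", Dispatch(true, keys), Dispatch(false, keys));
}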

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_test.cc

Lines changed: 1 addition & 1 deletion

@@ -608,7 +608,7 @@ TEST_P(MinibatchingTest, KeysAreSorted) {
 TEST_P(MinibatchingTest, IndexFromKeyIsCorrect) {
   std::vector<uint64_t> keys = GenerateGroupingKeys();
   for (int i = 0; i < keys.size(); ++i) {
-    EXPECT_EQ(keys[i] & CooFormat::kIndexMask, i);
+    EXPECT_EQ(CooFormat::GetDataFromKey(keys[i]), i);
   }
 }

jax_tpu_embedding/sparsecore/lib/core/ragged_tensor_input_batch.h

Lines changed: 3 additions & 0 deletions

@@ -75,6 +75,9 @@ class RaggedTensorInputBatch : public AbstractInputBatch {
         max_vocab_id_(max_vocab_id) {}

   int64_t size() const override { return row_offsets_.size() - 1; }
+
+  bool HasVariableWeights() const override { return false; }
+
   void ExtractCooTensors(const ExtractCooTensorsOptions& options,
                          ExtractedCooTensors& coo_tensors) override {
     SparseCsrInputBatchStream<int64_t, EmbeddingIdsView, RowOffsetsView>

jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h

Lines changed: 34 additions & 17 deletions

@@ -84,13 +84,13 @@ inline void ValidateMaxIdsOrDie(
 }

 inline void ValidateKeyCapacity(const int local_sc_id, const int key_count) {
-  // Index = 0 to kIndexMask giving us a count of kIndexMask + 1.
-  if (key_count > 1 + CooFormat::kIndexMask) {
+  // Index = 0 to kDataMask giving us a count of kDataMask + 1.
+  if (key_count > 1 + CooFormat::kDataMask) {
     LOG(ERROR) << absl::StrFormat(
         "Too many tensors for SparseCore #%d: got %d, limit: "
         "%d. Preprocessed output may not be reliable and cause undefined "
         "behavior.",
-        local_sc_id, key_count, CooFormat::kIndexMask);
+        local_sc_id, key_count, CooFormat::kDataMask);
   }
 }

@@ -177,6 +177,7 @@ struct LocalSparseCoreTensorGroupingContext {
   const PreprocessSparseDenseMatmulInputOptions& options;
   const bool create_buckets;
   const int32_t local_sc_id;
+  const int32_t num_sc_bits;

   // Outputs.
   PartitionedCooTensors& grouped_coo_tensors;

@@ -187,6 +188,7 @@
   MatrixXi& kept_unique_ids_per_partition_per_bucket;
 };

+template <bool kHasVariableWeights>
 inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
     LocalSparseCoreTensorGroupingContext context) {
   // Unpack context for readability.

@@ -214,17 +216,24 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
   // capacity. This decision is sticky for all tensors with the same `col_id`
   // within the same bucket.
   bool dropping_current_unique_col_id = false;
+  const int num_sc_bits = context.num_sc_bits;
   for (const uint64_t key : context.keys) {
     // Step 1: Unpack key to get tensor coordinates.
-    const uint32_t index = key & CooFormat::kIndexMask;
-    const CooFormat& coo_tensor = coo_tensors[index];
-    const uint32_t col_id = coo_tensor.col_id;
-    const uint32_t global_sc_id = coo_tensor.col_id & (global_sc_count - 1);
-    const uint32_t bucket_id =
-        context.create_buckets
-            ? coo_tensor.GetBucketId(options.minibatching_bucketing_hash_fn)
-            : 0;
-    const uint32_t row_id = coo_tensor.row_id;
+    const uint32_t bucket_id = CooFormat::GetBucketIdFromKey(key);
+    const uint32_t col_id =
+        absl::rotl(CooFormat::GetRotatedColIdFromKey(key), num_sc_bits);
+    const uint32_t global_sc_id = col_id & (global_sc_count - 1);
+
+    uint32_t row_id;
+    CooFormat coo_tensor(0, 0, 0.0f);
+    if constexpr (kHasVariableWeights) {
+      const uint32_t index = CooFormat::GetDataFromKey(key);
+      coo_tensor = coo_tensors[index];
+      row_id = coo_tensor.row_id;
+    } else {
+      row_id = CooFormat::GetDataFromKey(key);
+      coo_tensor = CooFormat(row_id, col_id, 1.0f);
+    }

     // Step 2: Handle duplicates.
     // An ID that is a duplicate of a previously non-dropped ID is merged.

@@ -298,7 +307,7 @@
 // NOTE: We use output buffers `max_ids_per_sc`, `max_unique_ids_per_sc`, and
 // `required_buffer_size_per_sc` because we fill values in a loop to a bigger
 // array.
-template <typename SplitType>
+template <bool kHasVariableWeights = true, typename SplitType>
 PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
     const ExtractedCooTensors& extracted_coo_tensors,
     const StackedTableMetadata& stacked_table_metadata,

@@ -364,25 +373,30 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
          coo_tensors[coo_tensor_index].row_id <
              (local_sc_id + 1) * batch_size_per_sc;
          coo_tensor_index++) {
+      const CooFormat& coo_tensor = coo_tensors[coo_tensor_index];
       // The key here is [bucket_id(6 bits), global_sc_id(num_scs bits),
       // local_embedding_id(32-num_scs bits), index(26 bits)].
       // Note that this assumes `num_scs` is a power of 2.
-      keys.push_back(coo_tensors[coo_tensor_index].GetGroupingKey(
+      keys.push_back(coo_tensor.GetGroupingKey(
           num_sc_bits, coo_tensor_index, create_buckets,
-          options.minibatching_bucketing_hash_fn));
+          options.minibatching_bucketing_hash_fn, kHasVariableWeights));
+      DCHECK(kHasVariableWeights || coo_tensors[coo_tensor_index].gain == 1.0f)
+          << "kHasVariableWeights: " << kHasVariableWeights
+          << ", coo: " << coo_tensor;
     }

     // The expected allocation size may be uninitialized.
     DCHECK(expected_keys_size == 0 || keys.size() == expected_keys_size);
     hwy::VQSort(keys.data(), keys.size(), hwy::SortAscending());

-    internal::GroupAndDeduplicateCooTensorsForLocalSparseCore({
+    const internal::LocalSparseCoreTensorGroupingContext context = {
         .keys = keys,
         .coo_tensors = coo_tensors,
         .stacked_table_metadata = stacked_table_metadata,
         .options = options,
         .create_buckets = create_buckets,
         .local_sc_id = local_sc_id,
+        .num_sc_bits = num_sc_bits,
         .grouped_coo_tensors = grouped_coo_tensors,
         .ids_per_sc_partition_per_bucket = ids_per_sc_partition_per_bucket,
         .unique_ids_per_partition_per_bucket =

@@ -392,7 +406,10 @@
             kept_ids_per_sc_partition_per_bucket,
         .kept_unique_ids_per_partition_per_bucket =
             kept_unique_ids_per_partition_per_bucket,
-    });
+    };
+
+    internal::GroupAndDeduplicateCooTensorsForLocalSparseCore<
+        kHasVariableWeights>(context);

     grouped_coo_tensors.FillRemainingScBuckets();
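
Why this buys ~2x: after `hwy::VQSort`, keys are consumed in sorted order, but in the variable-weight path the low bits index back into `coo_tensors`, which is laid out in extraction (row) order, so nearly every iteration is a dependent random load. In the constant-weight path the low bits already are the row_id and the gain is known to be 1.0, so the loop only reads the sequential key array. A self-contained sketch of the two access patterns (illustrative, not the library code):

#include <cstdint>
#include <cstdio>
#include <vector>

struct Coo { uint32_t row_id; uint32_t col_id; float gain; };

// Variable-weight path: one random gather per key.
uint64_t SumRowIdsWithLookup(const std::vector<uint64_t>& sorted_keys,
                             const std::vector<Coo>& coo_tensors,
                             uint32_t data_mask) {
  uint64_t sum = 0;
  for (uint64_t key : sorted_keys) sum += coo_tensors[key & data_mask].row_id;
  return sum;
}

// Constant-weight path: the row_id is decoded straight from the key.
uint64_t SumRowIdsFromKeys(const std::vector<uint64_t>& sorted_keys,
                           uint32_t data_mask) {
  uint64_t sum = 0;
  for (uint64_t key : sorted_keys) sum += key & data_mask;
  return sum;
}

int main() {
  const uint32_t mask = (1u << 26) - 1;  // 26 data bits, as in coo_format.h
  const std::vector<Coo> coos = {{5, 0, 1.0f}, {9, 0, 1.0f}};
  const std::vector<uint64_t> keys_by_index = {0, 1};  // low bits = index
  const std::vector<uint64_t> keys_by_row = {5, 9};    // low bits = row_id
  std::printf("%llu %llu\n",
              static_cast<unsigned long long>(
                  SumRowIdsWithLookup(keys_by_index, coos, mask)),
              static_cast<unsigned long long>(
                  SumRowIdsFromKeys(keys_by_row, mask)));
}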

jax_tpu_embedding/sparsecore/lib/core/sparse_coo_input_batch.h

Lines changed: 2 additions & 0 deletions

@@ -59,6 +59,8 @@ class PySparseCooInputBatch : public AbstractInputBatch {
   // Returns the number of rows in the current slice.
   int64_t size() const override { return batch_size_; }

+  bool HasVariableWeights() const override { return false; }
+
   // Extracts COO tensors for each SparseCore.
   void ExtractCooTensors(const ExtractCooTensorsOptions& options,
                          ExtractedCooTensors& coo_tensors) override;
