Skip to content

Commit 699b9d0

Browse files
[JAX SC] Continuation of cl/836767096. Templatize kCreateBuckets for compile-time optimization.
* `10.89%` geomean reduction in wall time with `9.02%` CPU time decrease. PiperOrigin-RevId: 826579723
1 parent 49ba0b9 commit 699b9d0

File tree

2 files changed: +59 −31 lines changed

jax_tpu_embedding/sparsecore/lib/core/sort_and_group_benchmark.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ void BM_SortAndGroup_Phase1(benchmark::State& state) {
124124
stats_per_host.GetStatsPerDevice(0);
125125

126126
if (state.thread_index() == 0) {
127-
SortAndGroupCooTensorsPerLocalDevice(
127+
SortAndGroupCooTensorsPerLocalDevice<true>(
128128
extracted_coo_tensors, stacked_table_metadata, options,
129129
stats_per_device, minibatching_required);
130130
LogStats(stats_per_device.max_ids_per_partition,
@@ -134,7 +134,7 @@ void BM_SortAndGroup_Phase1(benchmark::State& state) {
134134
}
135135

136136
for (auto s : state) {
137-
SortAndGroupCooTensorsPerLocalDevice(
137+
SortAndGroupCooTensorsPerLocalDevice<true>(
138138
extracted_coo_tensors, stacked_table_metadata, options,
139139
stats_per_device, minibatching_required);
140140
}

jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h

Lines changed: 57 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,6 @@ struct LocalSparseCoreTensorGroupingContext {
175175
absl::Span<const CooFormat> coo_tensors;
176176
const StackedTableMetadata& stacked_table_metadata;
177177
const PreprocessSparseDenseMatmulInputOptions& options;
178-
const bool create_buckets;
179178
const int32_t local_sc_id;
180179
const int32_t num_sc_bits;
181180

@@ -184,11 +183,13 @@ struct LocalSparseCoreTensorGroupingContext {
184183
MatrixXi& ids_per_sc_partition_per_bucket;
185184
MatrixXi& unique_ids_per_partition_per_bucket;
186185
StatsPerDevice& stats;
186+
// These are only used for id dropping decisions and can be ignored otherwise.
187187
MatrixXi& kept_ids_per_sc_partition_per_bucket;
188188
MatrixXi& kept_unique_ids_per_partition_per_bucket;
189189
};
190190

191-
template <bool kHasVariableWeights>
191+
192+
template <bool kHasVariableWeights, bool kCreateBuckets>
192193
inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
193194
LocalSparseCoreTensorGroupingContext context) {
194195
// Unpack context for readability.
@@ -219,7 +220,8 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
219220
const int num_sc_bits = context.num_sc_bits;
220221
for (const uint64_t key : context.keys) {
221222
// Step 1: Unpack key to get tensor coordinates.
222-
const uint32_t bucket_id = CooFormat::GetBucketIdFromKey(key);
223+
const uint32_t bucket_id =
224+
kCreateBuckets ? CooFormat::GetBucketIdFromKey(key) : 0;
223225
const uint32_t col_id =
224226
absl::rotl(CooFormat::GetRotatedColIdFromKey(key), num_sc_bits);
225227
const uint32_t global_sc_id = col_id & (global_sc_count - 1);
@@ -244,8 +246,11 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
244246
}
245247
// If the ID is a duplicate of the last seen ID, it must have been dropped
246248
// (otherwise it would have been merged above), so drop this one too.
247-
if (bucket_id == prev_bucket_id && col_id == prev_col_id &&
248-
row_id == prev_row_id) {
249+
bool fully_duplicate = col_id == prev_col_id && row_id == prev_row_id;
250+
if constexpr (kCreateBuckets) {
251+
fully_duplicate = fully_duplicate && bucket_id == prev_bucket_id;
252+
}
253+
if (fully_duplicate) {
249254
++stats.dropped_id_count;
250255
continue;
251256
}
@@ -254,24 +259,27 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
254259
// We have a new column if the bucket_id changes (we can't dedupe across
255260
// bucket boundaries) or if the col_id changes within the same bucket. Note
256261
// that multiple col_ids can map to the same bucket.
257-
const bool is_new_col =
258-
(bucket_id != prev_bucket_id || col_id != prev_col_id);
262+
bool is_new_col = col_id != prev_col_id;
263+
if constexpr (kCreateBuckets) {
264+
is_new_col = is_new_col || bucket_id != prev_bucket_id;
265+
}
259266
// Update observed stats. These are never decremented and are used for
260267
// reporting.
268+
// We do NOT drop IDs when minibatching is enabled and we are in the
269+
// first pass (`create_buckets=false`), as we need to detect limit
270+
// overflows to decide if minibatching is required.
271+
const bool can_drop_id = !options.enable_minibatching || kCreateBuckets;
261272
observed_ids(global_sc_id, bucket_id) += 1;
262273
if (is_new_col) {
263274
observed_unique_ids(global_sc_id, bucket_id) += 1;
264-
dropping_current_unique_col_id =
265-
(kept_unique_ids(global_sc_id, bucket_id) + 1) >
266-
max_unique_ids_per_partition;
275+
if (allow_id_dropping && can_drop_id) {
276+
dropping_current_unique_col_id =
277+
(kept_unique_ids(global_sc_id, bucket_id) + 1) >
278+
max_unique_ids_per_partition;
279+
}
267280
}
268281

269282
// Step 4: Determine if the ID should be dropped based on capacity limits.
270-
// We do NOT drop IDs when minibatching is enabled and we are in the
271-
// first pass (`create_buckets=false`), as we need to detect limit
272-
// overflows to decide if minibatching is required.
273-
const bool can_drop_id =
274-
!options.enable_minibatching || context.create_buckets;
275283
const bool exceeds_ids_limit =
276284
(kept_ids(global_sc_id, bucket_id) + 1) > max_ids_per_partition;
277285

@@ -283,9 +291,11 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
283291
} else {
284292
grouped_coo_tensors.Add(context.local_sc_id, bucket_id, coo_tensor);
285293
// Update kept counts.
286-
kept_ids(global_sc_id, bucket_id) += 1;
287-
if (is_new_col) {
288-
kept_unique_ids(global_sc_id, bucket_id) += 1;
294+
if (allow_id_dropping && can_drop_id) {
295+
kept_ids(global_sc_id, bucket_id) += 1;
296+
if (is_new_col) {
297+
kept_unique_ids(global_sc_id, bucket_id) += 1;
298+
}
289299
}
290300
}
291301

@@ -307,8 +317,9 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
307317
// NOTE: We use output buffers `max_ids_per_sc`, `max_unique_ids_per_sc`, and
308318
// `required_buffer_size_per_sc` because we fill values in a loop to a bigger
309319
// array.
310-
template <bool kHasVariableWeights = true, typename SplitType>
311-
PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
320+
template <bool kHasVariableWeights = true, bool kCreateBuckets,
321+
typename SplitType>
322+
PartitionedCooTensors SortAndGroupCooTensorsPerLocalDeviceImpl(
312323
const ExtractedCooTensors& extracted_coo_tensors,
313324
const StackedTableMetadata& stacked_table_metadata,
314325
const PreprocessSparseDenseMatmulInputOptions& options,
@@ -329,20 +340,18 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
329340
// This function can be called in two passes for minibatching. The logic for
330341
// stats collection and ID dropping depends on the pass.
331342
//
332-
// Pass 1: Check if minibatching is required (`create_buckets` is false).
343+
// Pass 1: Check if minibatching is required (`kCreateBuckets` is false).
333344
// - No IDs are dropped.
334345
// - Stats are collected on all observed IDs to compute splits.
335346
//
336-
// Pass 2: Create buckets (`create_buckets` is true).
347+
// Pass 2: Create buckets (`kCreateBuckets` is true).
337348
// - A dummy stats object is used (stats are not re-computed).
338349
// - IDs may be dropped if they exceed capacity.
339-
const bool create_buckets = options.enable_minibatching &&
340-
(std::is_same_v<SplitType, MinibatchingSplit>);
341350

342351
// Partition COO tensors among SparseCores for the local device (based on row
343352
// id).
344353
const int bucket_count =
345-
create_buckets ? CooFormat::kMaxMinibatchingBuckets : 1;
354+
kCreateBuckets ? CooFormat::kMaxMinibatchingBuckets : 1;
346355
PartitionedCooTensors grouped_coo_tensors(
347356
coo_tensors.size(), num_sc_per_device, global_sc_count, bucket_count);
348357

@@ -378,7 +387,7 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
378387
// local_embedding_id(32-num_scs bits), index(26 bits)].
379388
// Note that this assumes `num_scs` is a power of 2.
380389
keys.push_back(coo_tensor.GetGroupingKey(
381-
num_sc_bits, coo_tensor_index, create_buckets,
390+
num_sc_bits, coo_tensor_index, kCreateBuckets,
382391
options.minibatching_bucketing_hash_fn, kHasVariableWeights));
383392
DCHECK(kHasVariableWeights || coo_tensors[coo_tensor_index].gain == 1.0f)
384393
<< "kHasVariableWeights: " << kHasVariableWeights
@@ -394,7 +403,6 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
394403
.coo_tensors = coo_tensors,
395404
.stacked_table_metadata = stacked_table_metadata,
396405
.options = options,
397-
.create_buckets = create_buckets,
398406
.local_sc_id = local_sc_id,
399407
.num_sc_bits = num_sc_bits,
400408
.grouped_coo_tensors = grouped_coo_tensors,
@@ -409,7 +417,7 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
409417
};
410418

411419
internal::GroupAndDeduplicateCooTensorsForLocalSparseCore<
412-
kHasVariableWeights>(context);
420+
kHasVariableWeights, kCreateBuckets>(context);
413421

414422
grouped_coo_tensors.FillRemainingScBuckets();
415423

@@ -445,7 +453,7 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
445453

446454
// Only validate if creating minibatching buckets or when minibatching is
447455
// disabled, not when checking if minibatching is required.
448-
if (!options.enable_minibatching || create_buckets)
456+
if (!options.enable_minibatching || kCreateBuckets)
449457
internal::ValidateMaxIdsOrDie(
450458
observed_max_ids_per_bucket, observed_max_unique_ids_per_bucket,
451459
max_ids_per_partition, max_unique_ids_per_partition,
@@ -455,6 +463,26 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
455463
return grouped_coo_tensors;
456464
}
457465

466+
template <bool kHasVariableWeights, typename SplitType>
467+
PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
468+
const ExtractedCooTensors& extracted_coo_tensors,
469+
const StackedTableMetadata& stacked_table_metadata,
470+
const PreprocessSparseDenseMatmulInputOptions& options,
471+
internal::StatsPerDevice& stats, SplitType& minibatching_split) {
472+
const bool create_buckets =
473+
options.enable_minibatching &&
474+
std::is_same_v<SplitType, MinibatchingSplit>;
475+
if (create_buckets) {
476+
return SortAndGroupCooTensorsPerLocalDeviceImpl<kHasVariableWeights, true>(
477+
extracted_coo_tensors, stacked_table_metadata, options, stats,
478+
minibatching_split);
479+
} else {
480+
return SortAndGroupCooTensorsPerLocalDeviceImpl<kHasVariableWeights, false>(
481+
extracted_coo_tensors, stacked_table_metadata, options, stats,
482+
minibatching_split);
483+
}
484+
}
485+
458486
} // namespace jax_sc_embedding
459487

460488
#endif // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_SORT_AND_GROUP_COO_TENSORS_IMPL_H_

0 commit comments

Comments (0)