@@ -175,7 +175,6 @@ struct LocalSparseCoreTensorGroupingContext {
175175 absl::Span<const CooFormat> coo_tensors;
176176 const StackedTableMetadata& stacked_table_metadata;
177177 const PreprocessSparseDenseMatmulInputOptions& options;
178- const bool create_buckets;
179178 const int32_t local_sc_id;
180179 const int32_t num_sc_bits;
181180
@@ -184,11 +183,13 @@ struct LocalSparseCoreTensorGroupingContext {
184183 MatrixXi& ids_per_sc_partition_per_bucket;
185184 MatrixXi& unique_ids_per_partition_per_bucket;
186185 StatsPerDevice& stats;
186+ // These are only used for id dropping decisions and can be ignored otherwise.
187187 MatrixXi& kept_ids_per_sc_partition_per_bucket;
188188 MatrixXi& kept_unique_ids_per_partition_per_bucket;
189189};
190190
191- template <bool kHasVariableWeights >
191+
192+ template <bool kHasVariableWeights , bool kCreateBuckets >
192193inline void GroupAndDeduplicateCooTensorsForLocalSparseCore (
193194 LocalSparseCoreTensorGroupingContext context) {
194195 // Unpack context for readability.
@@ -219,7 +220,8 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
219220 const int num_sc_bits = context.num_sc_bits ;
220221 for (const uint64_t key : context.keys ) {
221222 // Step 1: Unpack key to get tensor coordinates.
222- const uint32_t bucket_id = CooFormat::GetBucketIdFromKey (key);
223+ const uint32_t bucket_id =
224+ kCreateBuckets ? CooFormat::GetBucketIdFromKey (key) : 0 ;
223225 const uint32_t col_id =
224226 absl::rotl (CooFormat::GetRotatedColIdFromKey (key), num_sc_bits);
225227 const uint32_t global_sc_id = col_id & (global_sc_count - 1 );
@@ -244,8 +246,11 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
244246 }
245247 // If the ID is a duplicate of the last seen ID, it must have been dropped
246248 // (otherwise it would have been merged above), so drop this one too.
247- if (bucket_id == prev_bucket_id && col_id == prev_col_id &&
248- row_id == prev_row_id) {
249+ bool fully_duplicate = col_id == prev_col_id && row_id == prev_row_id;
250+ if constexpr (kCreateBuckets ) {
251+ fully_duplicate = fully_duplicate && bucket_id == prev_bucket_id;
252+ }
253+ if (fully_duplicate) {
249254 ++stats.dropped_id_count ;
250255 continue ;
251256 }
@@ -254,24 +259,27 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
254259 // We have a new column if the bucket_id changes (we can't dedupe across
255260 // bucket boundaries) or if the col_id changes within the same bucket. Note
256261 // that multiple col_ids can map to the same bucket.
257- const bool is_new_col =
258- (bucket_id != prev_bucket_id || col_id != prev_col_id);
262+ bool is_new_col = col_id != prev_col_id;
263+ if constexpr (kCreateBuckets ) {
264+ is_new_col = is_new_col || bucket_id != prev_bucket_id;
265+ }
259266 // Update observed stats. These are never decremented and are used for
260267 // reporting.
268+ // We do NOT drop IDs when minibatching is enabled and we are in the
269+ // first pass (`kCreateBuckets` is false), as we need to detect limit
270+ // overflows to decide if minibatching is required.
271+ const bool can_drop_id = !options.enable_minibatching || kCreateBuckets ;
261272 observed_ids (global_sc_id, bucket_id) += 1 ;
262273 if (is_new_col) {
263274 observed_unique_ids (global_sc_id, bucket_id) += 1 ;
264- dropping_current_unique_col_id =
265- (kept_unique_ids (global_sc_id, bucket_id) + 1 ) >
266- max_unique_ids_per_partition;
275+ if (allow_id_dropping && can_drop_id) {
276+ dropping_current_unique_col_id =
277+ (kept_unique_ids (global_sc_id, bucket_id) + 1 ) >
278+ max_unique_ids_per_partition;
279+ }
267280 }
268281
269282 // Step 4: Determine if the ID should be dropped based on capacity limits.
270- // We do NOT drop IDs when minibatching is enabled and we are in the
271- // first pass (`create_buckets=false`), as we need to detect limit
272- // overflows to decide if minibatching is required.
273- const bool can_drop_id =
274- !options.enable_minibatching || context.create_buckets ;
275283 const bool exceeds_ids_limit =
276284 (kept_ids (global_sc_id, bucket_id) + 1 ) > max_ids_per_partition;
277285
@@ -283,9 +291,11 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
283291 } else {
284292 grouped_coo_tensors.Add (context.local_sc_id , bucket_id, coo_tensor);
285293 // Update kept counts.
286- kept_ids (global_sc_id, bucket_id) += 1 ;
287- if (is_new_col) {
288- kept_unique_ids (global_sc_id, bucket_id) += 1 ;
294+ if (allow_id_dropping && can_drop_id) {
295+ kept_ids (global_sc_id, bucket_id) += 1 ;
296+ if (is_new_col) {
297+ kept_unique_ids (global_sc_id, bucket_id) += 1 ;
298+ }
289299 }
290300 }
291301
@@ -307,8 +317,9 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
307317// NOTE: We use output buffers `max_ids_per_sc`, `max_unique_ids_per_sc`, and
308318// `required_buffer_size_per_sc` because we fill values in a loop to a bigger
309319// array.
310- template <bool kHasVariableWeights = true , typename SplitType>
311- PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice (
320+ template <bool kHasVariableWeights = true , bool kCreateBuckets ,
321+ typename SplitType>
322+ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDeviceImpl (
312323 const ExtractedCooTensors& extracted_coo_tensors,
313324 const StackedTableMetadata& stacked_table_metadata,
314325 const PreprocessSparseDenseMatmulInputOptions& options,
@@ -329,20 +340,18 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
329340 // This function can be called in two passes for minibatching. The logic for
330341 // stats collection and ID dropping depends on the pass.
331342 //
332- // Pass 1: Check if minibatching is required (`create_buckets ` is false).
343+ // Pass 1: Check if minibatching is required (`kCreateBuckets ` is false).
333344 // - No IDs are dropped.
334345 // - Stats are collected on all observed IDs to compute splits.
335346 //
336- // Pass 2: Create buckets (`create_buckets ` is true).
347+ // Pass 2: Create buckets (`kCreateBuckets ` is true).
337348 // - A dummy stats object is used (stats are not re-computed).
338349 // - IDs may be dropped if they exceed capacity.
339- const bool create_buckets = options.enable_minibatching &&
340- (std::is_same_v<SplitType, MinibatchingSplit>);
341350
342351 // Partition COO tensors among SparseCores for the local device (based on row
343352 // id).
344353 const int bucket_count =
345- create_buckets ? CooFormat::kMaxMinibatchingBuckets : 1 ;
354+ kCreateBuckets ? CooFormat::kMaxMinibatchingBuckets : 1 ;
346355 PartitionedCooTensors grouped_coo_tensors (
347356 coo_tensors.size (), num_sc_per_device, global_sc_count, bucket_count);
348357
@@ -378,7 +387,7 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
378387 // local_embedding_id(32-num_scs bits), index(26 bits)].
379388 // Note that this assumes `num_scs` is a power of 2.
380389 keys.push_back (coo_tensor.GetGroupingKey (
381- num_sc_bits, coo_tensor_index, create_buckets ,
390+ num_sc_bits, coo_tensor_index, kCreateBuckets ,
382391 options.minibatching_bucketing_hash_fn , kHasVariableWeights ));
383392 DCHECK (kHasVariableWeights || coo_tensors[coo_tensor_index].gain == 1 .0f )
384393 << " kHasVariableWeights: " << kHasVariableWeights
@@ -394,7 +403,6 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
394403 .coo_tensors = coo_tensors,
395404 .stacked_table_metadata = stacked_table_metadata,
396405 .options = options,
397- .create_buckets = create_buckets,
398406 .local_sc_id = local_sc_id,
399407 .num_sc_bits = num_sc_bits,
400408 .grouped_coo_tensors = grouped_coo_tensors,
@@ -409,7 +417,7 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
409417 };
410418
411419 internal::GroupAndDeduplicateCooTensorsForLocalSparseCore<
412- kHasVariableWeights >(context);
420+ kHasVariableWeights , kCreateBuckets >(context);
413421
414422 grouped_coo_tensors.FillRemainingScBuckets ();
415423
@@ -445,7 +453,7 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
445453
446454 // Only validate if creating minibatching buckets or when minibatching is
447455 // disabled, not when checking if minibatching is required.
448- if (!options.enable_minibatching || create_buckets )
456+ if (!options.enable_minibatching || kCreateBuckets )
449457 internal::ValidateMaxIdsOrDie (
450458 observed_max_ids_per_bucket, observed_max_unique_ids_per_bucket,
451459 max_ids_per_partition, max_unique_ids_per_partition,
@@ -455,6 +463,26 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
455463 return grouped_coo_tensors;
456464}
457465
466+ // Public entry point: forwards to the Impl, selecting kCreateBuckets from
467+ // options.enable_minibatching and SplitType. Default kept for old callers.
468+ template <bool kHasVariableWeights = true, typename SplitType>
469+ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
470+     const ExtractedCooTensors& extracted_coo_tensors,
471+     const StackedTableMetadata& stacked_table_metadata,
472+     const PreprocessSparseDenseMatmulInputOptions& options,
473+     internal::StatsPerDevice& stats, SplitType& minibatching_split) {
474+   if constexpr (std::is_same_v<SplitType, MinibatchingSplit>) {
475+     if (options.enable_minibatching) {
476+       return SortAndGroupCooTensorsPerLocalDeviceImpl<kHasVariableWeights, true>(
477+           extracted_coo_tensors, stacked_table_metadata, options, stats,
478+           minibatching_split);
479+     }
480+   }
481+   return SortAndGroupCooTensorsPerLocalDeviceImpl<kHasVariableWeights, false>(
482+       extracted_coo_tensors, stacked_table_metadata, options, stats,
483+       minibatching_split);
484+ }
485+
458486} // namespace jax_sc_embedding
459487
460488#endif // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_SORT_AND_GROUP_COO_TENSORS_IMPL_H_
0 commit comments