Commit 023dbd7

Prevent max ids per partition from fluctuating by always taking the max.

There are two scenarios that currently cause the numbers to go down:

1. The numbers went up for one stack but slightly down for another stack. Because we pass all the stacks to update_preprocessing_parameters, it updates the specs for all the stacks, so it writes the lower values into the other stacks. We could filter out the stacks that went down, but that still would not cover the next case.

2. For one stack, `max_ids_per_partition` went up but `max_unique_ids_per_partition` went down for a batch, for whatever reason. When the stats are updated, the higher value is written for `max_ids_per_partition`, but a lower value is written for `max_unique_ids_per_partition`.

So instead of converging to their upper bounds, `max_ids_per_partition` and `max_unique_ids_per_partition` keep going up and down batch after batch. The fix is to take the max between the values in the current `StackedTableSpec` and the incoming max values.
1 parent 5d7f18a commit 023dbd7
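Below is a minimal, self-contained sketch of the running-max idea described in the message above. It does not use the actual keras_rs data structures; the plain integers and the merge_running_max helper are illustrative stand-ins for the values stored on a StackedTableSpec.

import numpy as np

def merge_running_max(current_spec_value: int, incoming_value: int) -> int:
    # Never let a tuning parameter shrink: keep the largest value seen so far.
    return int(np.max([current_spec_value, incoming_value]))

# Batch 1: the observed max_ids_per_partition rises from 64 to 96.
max_ids = merge_running_max(64, 96)       # 96
# Batch 2: the observation drops to 80. Without the fix the spec would
# fluctuate back down; with the fix it stays at the upper bound seen so far.
max_ids = merge_running_max(max_ids, 80)  # still 96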

File tree

1 file changed
+31 -17 lines changed

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 31 additions & 17 deletions
@@ -643,39 +643,53 @@ def _sparsecore_preprocess(
         # Synchronize input statistics across all devices and update the
         # underlying stacked tables specs in the feature specs.
 
-        # Aggregate stats across all processes/devices via pmax.
+        # Gather stats across all processes/devices via process_allgather.
         all_stats = multihost_utils.process_allgather(stats)
-        aggregated_stats = jax.tree.map(
-            lambda x: jnp.max(x, axis=0), all_stats
-        )
+        all_stats = jax.tree.map(np.max, all_stats)
 
         # Check if stats changed enough to warrant action.
         stacked_table_specs = embedding.get_stacked_table_specs(
             self._config.feature_specs
         )
         changed = any(
-            np.max(aggregated_stats.max_ids_per_partition[stack_name])
+            all_stats.max_ids_per_partition[stack_name]
             > spec.max_ids_per_partition
-            or np.max(
-                aggregated_stats.max_unique_ids_per_partition[stack_name]
-            )
+            or all_stats.max_unique_ids_per_partition[stack_name]
             > spec.max_unique_ids_per_partition
-            or (
-                np.max(
-                    aggregated_stats.required_buffer_size_per_sc[stack_name]
-                )
-                * num_sc_per_device
-            )
+            or all_stats.required_buffer_size_per_sc[stack_name]
+            * num_sc_per_device
             > (spec.suggested_coo_buffer_size_per_device or 0)
             for stack_name, spec in stacked_table_specs.items()
         )
 
         # Update configuration and repeat preprocessing if stats changed.
         if changed:
+            for stack_name, spec in stacked_table_specs.items():
+                all_stats.max_ids_per_partition[stack_name] = np.max(
+                    [
+                        all_stats.max_ids_per_partition[stack_name],
+                        spec.max_ids_per_partition,
+                    ]
+                )
+                all_stats.max_unique_ids_per_partition[stack_name] = np.max(
+                    [
+                        all_stats.max_unique_ids_per_partition[stack_name],
+                        spec.max_unique_ids_per_partition,
+                    ]
+                )
+                all_stats.required_buffer_size_per_sc[stack_name] = np.max(
+                    [
+                        all_stats.required_buffer_size_per_sc[stack_name],
+                        (
+                            (spec.suggested_coo_buffer_size_per_device or 0)
+                            + (num_sc_per_device - 1)
+                        )
+                        // num_sc_per_device,
+                    ]
+                )
+
             embedding.update_preprocessing_parameters(
-                self._config.feature_specs,
-                aggregated_stats,
-                num_sc_per_device,
+                self._config.feature_specs, all_stats, num_sc_per_device
             )
 
             # Re-execute preprocessing with consistent input statistics.
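One detail worth noting in the new max for required_buffer_size_per_sc: the spec stores a suggested buffer size per device, while the gathered stat is per SparseCore, so the per-device value is converted with a ceiling division before taking the max. A small arithmetic sketch with made-up numbers (per_device stands in for spec.suggested_coo_buffer_size_per_device or 0):

per_device = 1001
num_sc_per_device = 4
per_sc = (per_device + (num_sc_per_device - 1)) // num_sc_per_device
assert per_sc == 251  # rounds up, so per_sc * num_sc_per_device >= per_device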
