diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index c4e647b..2b71340 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -643,39 +643,53 @@ def _sparsecore_preprocess( # Synchronize input statistics across all devices and update the # underlying stacked tables specs in the feature specs. - # Aggregate stats across all processes/devices via pmax. + # Gather stats across all processes/devices via process_allgather. all_stats = multihost_utils.process_allgather(stats) - aggregated_stats = jax.tree.map( - lambda x: jnp.max(x, axis=0), all_stats - ) + all_stats = jax.tree.map(np.max, all_stats) # Check if stats changed enough to warrant action. stacked_table_specs = embedding.get_stacked_table_specs( self._config.feature_specs ) changed = any( - np.max(aggregated_stats.max_ids_per_partition[stack_name]) + all_stats.max_ids_per_partition[stack_name] > spec.max_ids_per_partition - or np.max( - aggregated_stats.max_unique_ids_per_partition[stack_name] - ) + or all_stats.max_unique_ids_per_partition[stack_name] > spec.max_unique_ids_per_partition - or ( - np.max( - aggregated_stats.required_buffer_size_per_sc[stack_name] - ) - * num_sc_per_device - ) + or all_stats.required_buffer_size_per_sc[stack_name] + * num_sc_per_device > (spec.suggested_coo_buffer_size_per_device or 0) for stack_name, spec in stacked_table_specs.items() ) # Update configuration and repeat preprocessing if stats changed. if changed: + for stack_name, spec in stacked_table_specs.items(): + all_stats.max_ids_per_partition[stack_name] = np.max( + [ + all_stats.max_ids_per_partition[stack_name], + spec.max_ids_per_partition, + ] + ) + all_stats.max_unique_ids_per_partition[stack_name] = np.max( + [ + all_stats.max_unique_ids_per_partition[stack_name], + spec.max_unique_ids_per_partition, + ] + ) + all_stats.required_buffer_size_per_sc[stack_name] = np.max( + [ + all_stats.required_buffer_size_per_sc[stack_name], + ( + (spec.suggested_coo_buffer_size_per_device or 0) + + (num_sc_per_device - 1) + ) + // num_sc_per_device, + ] + ) + embedding.update_preprocessing_parameters( - self._config.feature_specs, - aggregated_stats, - num_sc_per_device, + self._config.feature_specs, all_stats, num_sc_per_device ) # Re-execute preprocessing with consistent input statistics.