Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 31 additions & 17 deletions keras_rs/src/layers/embedding/jax/distributed_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,39 +643,53 @@ def _sparsecore_preprocess(
# Synchronize input statistics across all devices and update the
# underlying stacked table specs in the feature specs.

# Aggregate stats across all processes/devices via pmax.
# Gather stats across all processes/devices via process_allgather.
all_stats = multihost_utils.process_allgather(stats)
aggregated_stats = jax.tree.map(
lambda x: jnp.max(x, axis=0), all_stats
)
all_stats = jax.tree.map(np.max, all_stats)

# Check if stats changed enough to warrant action.
stacked_table_specs = embedding.get_stacked_table_specs(
self._config.feature_specs
)
changed = any(
np.max(aggregated_stats.max_ids_per_partition[stack_name])
all_stats.max_ids_per_partition[stack_name]
> spec.max_ids_per_partition
or np.max(
aggregated_stats.max_unique_ids_per_partition[stack_name]
)
or all_stats.max_unique_ids_per_partition[stack_name]
> spec.max_unique_ids_per_partition
or (
np.max(
aggregated_stats.required_buffer_size_per_sc[stack_name]
)
* num_sc_per_device
)
or all_stats.required_buffer_size_per_sc[stack_name]
* num_sc_per_device
> (spec.suggested_coo_buffer_size_per_device or 0)
for stack_name, spec in stacked_table_specs.items()
)

# Update configuration and repeat preprocessing if stats changed.
if changed:
for stack_name, spec in stacked_table_specs.items():
all_stats.max_ids_per_partition[stack_name] = np.max(
[
all_stats.max_ids_per_partition[stack_name],
spec.max_ids_per_partition,
]
)
all_stats.max_unique_ids_per_partition[stack_name] = np.max(
[
all_stats.max_unique_ids_per_partition[stack_name],
spec.max_unique_ids_per_partition,
]
)
all_stats.required_buffer_size_per_sc[stack_name] = np.max(
[
all_stats.required_buffer_size_per_sc[stack_name],
(
(spec.suggested_coo_buffer_size_per_device or 0)
+ (num_sc_per_device - 1)
)
// num_sc_per_device,
]
)
Comment on lines +667 to +689
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For improved readability and more idiomatic Python, you can refactor this loop. The built-in `max()` is more direct for comparing two scalar values than calling `np.max()` on a two-element list. Additionally, extracting the previous `required_buffer_size_per_sc` value (the ceiling division of `spec.suggested_coo_buffer_size_per_device or 0` by `num_sc_per_device`) into a separate variable makes the logic clearer.

                for stack_name, spec in stacked_table_specs.items():
                    all_stats.max_ids_per_partition[stack_name] = max(
                        all_stats.max_ids_per_partition[stack_name],
                        spec.max_ids_per_partition,
                    )
                    all_stats.max_unique_ids_per_partition[stack_name] = max(
                        all_stats.max_unique_ids_per_partition[stack_name],
                        spec.max_unique_ids_per_partition,
                    )
                    old_required_buffer_size_per_sc = (
                        (spec.suggested_coo_buffer_size_per_device or 0)
                        + (num_sc_per_device - 1)
                    ) // num_sc_per_device
                    all_stats.required_buffer_size_per_sc[stack_name] = max(
                        all_stats.required_buffer_size_per_sc[stack_name],
                        old_required_buffer_size_per_sc,
                    )


embedding.update_preprocessing_parameters(
self._config.feature_specs,
aggregated_stats,
num_sc_per_device,
self._config.feature_specs, all_stats, num_sc_per_device
)

# Re-execute preprocessing with consistent input statistics.
Expand Down