diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index f27e109..c4e647b 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -1,6 +1,5 @@ """JAX implementation of the TPU embedding layer.""" -import dataclasses import math import typing from typing import Any, Mapping, Sequence, Union @@ -446,29 +445,48 @@ def sparsecore_build( table_specs = embedding.get_table_specs(feature_specs) table_stacks = jte_table_stacking.get_table_stacks(table_specs) - # Create new instances of StackTableSpec with updated values that are - # the maximum from stacked tables. - stacked_table_specs = embedding.get_stacked_table_specs(feature_specs) - stacked_table_specs = { - stack_name: dataclasses.replace( - stacked_table_spec, - max_ids_per_partition=max( - table.max_ids_per_partition - for table in table_stacks[stack_name] - ), - max_unique_ids_per_partition=max( - table.max_unique_ids_per_partition - for table in table_stacks[stack_name] - ), + # Update stacked table stats to max of values across involved tables. + max_ids_per_partition = {} + max_unique_ids_per_partition = {} + required_buffer_size_per_device = {} + id_drop_counters = {} + for stack_name, stack in table_stacks.items(): + max_ids_per_partition[stack_name] = np.max( + np.asarray( + [s.max_ids_per_partition for s in stack], dtype=np.int32 + ) + ) + max_unique_ids_per_partition[stack_name] = np.max( + np.asarray( + [s.max_unique_ids_per_partition for s in stack], + dtype=np.int32, + ) ) - for stack_name, stacked_table_spec in stacked_table_specs.items() - } - # Rewrite the stacked_table_spec in all TableSpecs. - for stack_name, table_specs in table_stacks.items(): - stacked_table_spec = stacked_table_specs[stack_name] - for table_spec in table_specs: - table_spec.stacked_table_spec = stacked_table_spec + # Only set the suggested buffer size if set on any individual table. + valid_buffer_sizes = [ + s.suggested_coo_buffer_size_per_device + for s in stack + if s.suggested_coo_buffer_size_per_device is not None + ] + if valid_buffer_sizes: + required_buffer_size_per_device[stack_name] = np.max( + np.asarray(valid_buffer_sizes, dtype=np.int32) + ) + + id_drop_counters[stack_name] = 0 + + aggregated_stats = embedding.SparseDenseMatmulInputStats( + max_ids_per_partition=max_ids_per_partition, + max_unique_ids_per_partition=max_unique_ids_per_partition, + required_buffer_size_per_sc=required_buffer_size_per_device, + id_drop_counters=id_drop_counters, + ) + embedding.update_preprocessing_parameters( + feature_specs, + aggregated_stats, + num_sc_per_device, + ) # Create variables for all stacked tables and slot variables. with sparsecore_distribution.scope():