62 changes: 40 additions & 22 deletions keras_rs/src/layers/embedding/jax/distributed_embedding.py
@@ -1,6 +1,5 @@
"""JAX implementation of the TPU embedding layer."""

import dataclasses
import math
import typing
from typing import Any, Mapping, Sequence, Union
@@ -446,29 +445,48 @@ def sparsecore_build(
         table_specs = embedding.get_table_specs(feature_specs)
         table_stacks = jte_table_stacking.get_table_stacks(table_specs)
 
-        # Create new instances of StackedTableSpec with updated values that are
-        # the maximum from stacked tables.
         stacked_table_specs = embedding.get_stacked_table_specs(feature_specs)
-        stacked_table_specs = {
-            stack_name: dataclasses.replace(
-                stacked_table_spec,
-                max_ids_per_partition=max(
-                    table.max_ids_per_partition
-                    for table in table_stacks[stack_name]
-                ),
-                max_unique_ids_per_partition=max(
-                    table.max_unique_ids_per_partition
-                    for table in table_stacks[stack_name]
-                ),
-            )
-            for stack_name, stacked_table_spec in stacked_table_specs.items()
-        }
+        # Update stacked table stats to max of values across involved tables.
+        max_ids_per_partition = {}
+        max_unique_ids_per_partition = {}
+        required_buffer_size_per_device = {}
+        id_drop_counters = {}
+        for stack_name, stack in table_stacks.items():
+            max_ids_per_partition[stack_name] = np.max(
+                np.asarray(
+                    [s.max_ids_per_partition for s in stack], dtype=np.int32
+                )
+            )
+            max_unique_ids_per_partition[stack_name] = np.max(
+                np.asarray(
+                    [s.max_unique_ids_per_partition for s in stack],
+                    dtype=np.int32,
+                )
+            )
+            # Only set the suggested buffer size if set on any individual table.
+            valid_buffer_sizes = [
+                s.suggested_coo_buffer_size_per_device
+                for s in stack
+                if s.suggested_coo_buffer_size_per_device is not None
+            ]
+            if valid_buffer_sizes:
+                required_buffer_size_per_device[stack_name] = np.max(
+                    np.asarray(valid_buffer_sizes, dtype=np.int32)
+                )
+
+            id_drop_counters[stack_name] = 0
 
-        # Rewrite the stacked_table_spec in all TableSpecs.
-        for stack_name, table_specs in table_stacks.items():
-            stacked_table_spec = stacked_table_specs[stack_name]
-            for table_spec in table_specs:
-                table_spec.stacked_table_spec = stacked_table_spec
+        aggregated_stats = embedding.SparseDenseMatmulInputStats(
+            max_ids_per_partition=max_ids_per_partition,
+            max_unique_ids_per_partition=max_unique_ids_per_partition,
+            required_buffer_size_per_sc=required_buffer_size_per_device,
+            id_drop_counters=id_drop_counters,
+        )
+        embedding.update_preprocessing_parameters(
+            feature_specs,
+            aggregated_stats,
+            num_sc_per_device,
+        )
 
         # Create variables for all stacked tables and slot variables.
         with sparsecore_distribution.scope():
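In summary, the old code built new `StackedTableSpec` instances via `dataclasses.replace` and then manually patched them back into every `TableSpec`; the new code instead aggregates per-stack statistics into an `embedding.SparseDenseMatmulInputStats` and lets `embedding.update_preprocessing_parameters` propagate them into the feature specs. The sketch below illustrates the aggregation rule in isolation: each per-partition limit for a stack is the maximum over its member tables, and a buffer size is only produced when at least one member table suggests one. This is a minimal sketch, not the library code: `TableSpec` here is a hypothetical stand-in modeling only the fields the diff reads, `aggregate_stack_stats` is an illustrative helper, and plain `max` stands in for the `np.max` over int32 arrays used in the diff.

import dataclasses
from typing import Optional


@dataclasses.dataclass
class TableSpec:
    # Hypothetical stand-in for the real table spec; only the fields read by
    # the aggregation logic in the diff are modeled here.
    name: str
    max_ids_per_partition: int
    max_unique_ids_per_partition: int
    suggested_coo_buffer_size_per_device: Optional[int] = None


def aggregate_stack_stats(table_stacks: dict[str, list[TableSpec]]):
    # Per-stack maxima: a stacked table must satisfy its most demanding member.
    max_ids = {}
    max_unique_ids = {}
    buffer_sizes = {}
    for stack_name, stack in table_stacks.items():
        max_ids[stack_name] = max(s.max_ids_per_partition for s in stack)
        max_unique_ids[stack_name] = max(
            s.max_unique_ids_per_partition for s in stack
        )
        # Only suggest a buffer size if some member table suggested one.
        valid = [
            s.suggested_coo_buffer_size_per_device
            for s in stack
            if s.suggested_coo_buffer_size_per_device is not None
        ]
        if valid:
            buffer_sizes[stack_name] = max(valid)
    return max_ids, max_unique_ids, buffer_sizes


stacks = {
    "stack_0": [
        TableSpec("a", 64, 32),
        TableSpec("b", 256, 16, suggested_coo_buffer_size_per_device=4096),
    ]
}
print(aggregate_stack_stats(stacks))
# -> ({'stack_0': 256}, {'stack_0': 32}, {'stack_0': 4096})

Routing everything through a single aggregated-stats structure keeps the stacked-table limits and the input-preprocessing buffers in sync from one place, instead of mutating specs in two separate passes as the deleted code did.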