Skip to content

Commit 51d5c82

Browse files
authored
Add heuristic for max ids of stacked tables. (#167)
Propagate the `max_ids_per_partition` and `max_unique_ids_per_partition` from `TableSpec`s to `StackedTableSpec`s by taking the max from the stacked tables.
1 parent 5713afe commit 51d5c82

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""JAX implementation of the TPU embedding layer."""
22

3+
import dataclasses
34
import math
45
import typing
56
from typing import Any, Mapping, Sequence, Union
@@ -445,6 +446,30 @@ def sparsecore_build(
445446
table_specs = embedding.get_table_specs(feature_specs)
446447
table_stacks = jte_table_stacking.get_table_stacks(table_specs)
447448

449+
# Create new instances of StackedTableSpec with updated values that are
450+
# the maximum from stacked tables.
451+
stacked_table_specs = embedding.get_stacked_table_specs(feature_specs)
452+
stacked_table_specs = {
453+
stack_name: dataclasses.replace(
454+
stacked_table_spec,
455+
max_ids_per_partition=max(
456+
table.max_ids_per_partition
457+
for table in table_stacks[stack_name]
458+
),
459+
max_unique_ids_per_partition=max(
460+
table.max_unique_ids_per_partition
461+
for table in table_stacks[stack_name]
462+
),
463+
)
464+
for stack_name, stacked_table_spec in stacked_table_specs.items()
465+
}
466+
467+
# Rewrite the stacked_table_spec in all TableSpecs.
468+
for stack_name, table_specs in table_stacks.items():
469+
stacked_table_spec = stacked_table_specs[stack_name]
470+
for table_spec in table_specs:
471+
table_spec.stacked_table_spec = stacked_table_spec
472+
448473
# Create variables for all stacked tables and slot variables.
449474
with sparsecore_distribution.scope():
450475
self._table_and_slot_variables = {

0 commit comments

Comments
 (0)