Skip to content

Commit d8df0d7

Browse files
committed
Rename sharding_names to sharding_metadata
1 parent 7710c30 commit d8df0d7

File tree

6 files changed

+30
-18
lines changed

6 files changed

+30
-18
lines changed

flax/core/meta.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,13 +300,13 @@ def get_sharding(self, mesh: jax.sharding.Mesh) -> jax.sharding.Sharding:
300300
def to_nnx_metadata(self) -> dict[str, Any]:
301301
"""Return a dict of metadata that can translate into an `nnx.Variable`."""
302302
metadata = dict(vars(self))
303-
metadata['sharding_names'] = metadata.pop('names')
303+
metadata['sharding_metadata'] = metadata.pop('names')
304304
return metadata
305305

306306
@classmethod
307307
def from_nnx_metadata(cls, metadata: dict[str, Any]):
308308
"""Given a dict of `nnx.Variable` format metadata, create a `nn.Partitioned`."""
309-
metadata['names'] = metadata.pop('sharding_names')
309+
metadata['names'] = metadata.pop('sharding_metadata')
310310
fields = {x.name for x in dataclasses.fields(cls)}
311311
return cls(**{k: v for k, v in metadata.items() if k in fields})
312312

flax/core/spmd.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ def shard_value(value, sharding_names, sharding_rules, mesh):
4545
f' with annotation {sharding_names=}. '
4646
'For more guidance, see https://flax.readthedocs.io/en/latest/flip/4844-var-eager-sharding.html.')
4747
pspec = get_pspec(sharding_names, sharding_rules)
48+
if isinstance(sharding_names, NamedSharding) and mesh is not None:
49+
assert sharding_names.mesh == mesh
4850
if mesh is not None:
4951
return _apply_sharding(value, NamedSharding(mesh, pspec))
5052
return _apply_sharding(value, pspec)
@@ -107,8 +109,10 @@ def composite_rules(rule1, rule2):
107109

108110

109111
def from_sharding_rules(
110-
sharding: Sharding, sharding_rules: LogicalRules
112+
sharding, sharding_rules: LogicalRules
111113
) -> Sharding:
114+
if isinstance(sharding, NamedSharding):
115+
sharding = sharding.spec
112116
rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules}
113117
return tuple(
114118
rules[str(s)] if (s and str(s) in rules) else s for s in sharding

flax/linen/spmd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,15 +290,15 @@ def to_nnx_metadata(self) -> dict[str, Any]:
290290
"""Return a dict of metadata that can translate into an `nnx.Variable`."""
291291
metadata = vars(self)
292292
if 'names' in metadata:
293-
metadata['sharding_names'] = metadata.pop('names')
293+
metadata['sharding_metadata'] = metadata.pop('names')
294294
if 'rules' in metadata:
295295
metadata['sharding_rules'] = metadata.pop('rules')
296296
return metadata
297297

298298
@classmethod
299299
def from_nnx_metadata(cls, metadata: dict[str, Any]):
300300
"""Given a dict of `nnx.Variable` format metadata, create a `nn.LogicallyPartitioned`."""
301-
metadata['names'] = metadata.pop('sharding_names')
301+
metadata['names'] = metadata.pop('sharding_metadata')
302302
metadata['rules'] = metadata.pop('sharding_rules')
303303
fields = {x.name for x in dataclasses.fields(cls)}
304304
return cls(**{k: v for k, v in metadata.items() if k in fields})

flax/nnx/spmd.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ def insert_field(fields, index, value):
4545
def _add_axis(x: tp.Any):
4646
if isinstance(x, variablelib.Variable):
4747
metadata = x.get_metadata()
48-
if 'sharding_names' in metadata and metadata['sharding_names']:
49-
sharding = metadata['sharding_names']
48+
if 'sharding_metadata' in metadata and metadata['sharding_metadata']:
49+
sharding = metadata['sharding_metadata']
5050
x.set_metadata(sharding_names=insert_field(sharding, index, axis_name))
5151

5252
for k, v in other_meta.items():
@@ -74,7 +74,7 @@ def remove_field(fields, index, value):
7474

7575
def _remove_axis(x: tp.Any):
7676
if isinstance(x, variablelib.Variable):
77-
if hasattr(x, 'sharding_names') and x.sharding_names is not None:
77+
if hasattr(x, 'sharding_metadata') and x.sharding_names is not None:
7878
x.set_metadata(
7979
sharding_names=remove_field(x.sharding_names, index, axis_name)
8080
)
@@ -119,7 +119,7 @@ def with_partitioning(
119119
"""A wrapper over any initializer to add sharding annotation data to a `Variable`."""
120120
return variablelib.with_metadata(
121121
initializer,
122-
sharding_names=sharding,
122+
sharding_metadata=sharding,
123123
mesh=mesh,
124124
**metadata,
125125
)
@@ -128,8 +128,8 @@ def with_partitioning(
128128
def get_var_pspec(v: variablelib.Variable) -> PartitionSpec | None:
129129
"""Given an `nnx.Variable`, return its `PartitionSpec`."""
130130
metadata = v.get_metadata()
131-
if 'sharding_names' in metadata and metadata['sharding_names']:
132-
sharding = metadata['sharding_names']
131+
if 'sharding_metadata' in metadata and metadata['sharding_metadata']:
132+
sharding = metadata['sharding_metadata']
133133
if core_spmd.get_logical_axis_rules() or 'sharding_rules' in metadata:
134134
context_rules = core_spmd.get_logical_axis_rules()
135135
local_rules = metadata.get('sharding_rules', ())
@@ -174,4 +174,4 @@ def get_abstract_model(init_fn, mesh):
174174
lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
175175
abs_state, get_named_sharding(abs_state, mesh)
176176
)
177-
return gdef, abs_state
177+
return gdef, abs_state

flax/nnx/variablelib.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import threading
2222
import typing as tp
2323
from typing import Any
24+
import warnings
2425
from flax import config
2526

2627
import jax
@@ -375,16 +376,20 @@ def __init__(
375376
metadata['on_remove_axis'] = var_t.on_remove_axis
376377

377378
if 'sharding' in metadata:
378-
metadata['sharding_names'] = metadata.pop('sharding')
379+
metadata['sharding_metadata'] = metadata.pop('sharding')
380+
381+
if 'sharding_names' in metadata: # for bw compat
382+
warnings.warn("'sharding_names' is deprecated. Use 'sharding_metadata' instead.", DeprecationWarning)
383+
metadata['sharding_metadata'] = metadata.pop('sharding_names')
379384

380385
object.__setattr__(self, '_var_metadata', metadata)
381386
# run create_value hooks
382387
value = self.create_value(self.raw_value)
383388

384389
# shard the value if applicable
385-
if metadata.get('eager_sharding', using_eager_sharding()) and 'sharding_names' in metadata:
390+
if metadata.get('eager_sharding', using_eager_sharding()) and 'sharding_metadata' in metadata:
386391
value = core_spmd.shard_value(
387-
value, metadata['sharding_names'], metadata.get('sharding_rules', None),
392+
value, metadata['sharding_metadata'], metadata.get('sharding_rules', None),
388393
metadata.get('mesh', None))
389394

390395
# Create the ref out of the array value
@@ -394,6 +399,9 @@ def __init__(
394399
object.__setattr__(self, 'raw_value', value)
395400

396401
def __getattr__(self, name: str) -> tp.Any:
402+
if name == 'sharding_names': # for backward compatibility
403+
warnings.warn("'sharding_names' is deprecated. Use 'sharding_metadata' instead.", DeprecationWarning)
404+
return self.sharding_metadata
397405
if name in object.__getattribute__(self, '_var_metadata'):
398406
return self._var_metadata[name]
399407
return getattr(self.raw_value, name)

tests/nnx/bridge/wrappers_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def test_nnx_to_linen_metadata(self):
410410
pspec_tree = nn.get_partition_spec(variables)
411411
assert y.shape == (1, 64)
412412
self.assertIsInstance(variables['params']['kernel'], nnx.bridge.NNXMeta)
413-
assert variables['params']['kernel'].metadata['sharding_names'] == ('in', 'out')
413+
assert variables['params']['kernel'].metadata['sharding_metadata'] == ('in', 'out')
414414
self.assertEqual(pspec_tree['params']['kernel'],
415415
jax.sharding.PartitionSpec('in', 'out'))
416416
np.testing.assert_allclose(y, x @ variables['params']['kernel'].value)
@@ -519,8 +519,8 @@ def __call__(self, x):
519519
w, b = model.inner.dot['w'], model.inner.b
520520
np.testing.assert_allclose(model(x), x @ w + b)
521521
self.assertIsInstance(w, nnx.Param)
522-
assert hasattr(w, 'sharding_names') and w.sharding_names == ('in', 'out')
523-
assert hasattr(b, 'sharding_names') and b.sharding_names == ('out-alias', )
522+
assert hasattr(w, 'sharding_metadata') and w.sharding_metadata == ('in', 'out')
523+
assert hasattr(b, 'sharding_metadata') and b.sharding_metadata == ('out-alias', )
524524

525525
def test_linen_nnx_linen(self):
526526
# TODO: add when we can safely `lazy_init` the NNX module inside `ToLinen` without

0 commit comments

Comments
 (0)