Skip to content

Commit 8b85e5d

Browse files
Cristian Garcia and Flax Authors
Cristian Garcia
authored and
Flax Authors
committed
convert DenseGeneral to NNX
PiperOrigin-RevId: 748311465
1 parent aa89779 commit 8b85e5d

File tree

4 files changed

+128
-86
lines changed

4 files changed

+128
-86
lines changed

flax/nnx/bridge/wrappers.py

Lines changed: 106 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -185,20 +185,27 @@ def __call__(
185185
return out
186186

187187

188-
def linen_rngs_dict(linen_module: linen.Module) -> tp.Mapping[str, jax.Array]:
188+
def linen_rngs_dict(linen_module: linen.Module, add_default: bool = False):
189189
"""Given a module, split out one key from each of its active RNG collections."""
190190
assert linen_module.scope is not None, 'linen_rngs_dict() must be called inside a Linen module.'
191-
return {name: linen_module.make_rng(name)
192-
for name in linen_module.scope.rngs.keys()}
191+
rngs: dict[str, tp.Any] = {
192+
name: linen_module.make_rng(name)
193+
for name in linen_module.scope.rngs.keys()
194+
}
195+
if add_default and 'default' not in rngs:
196+
rngs['default'] = 0
197+
return rngs
193198

194199

195200
class ToLinen(linen.Module):
196201
"""A wrapper to turn any NNX module into a Linen module.
197202
198-
The result Linen module can be used standalone with all Linen APIs, or as a submodule of
203+
The result Linen module can be used standalone with all Linen APIs, or as a
204+
submodule of
199205
another Linen module.
200206
201-
Since NNX modules are stateful and owns the state, we only create it once during init
207+
Since NNX modules are stateful and own the state, we only create it once
208+
during init
202209
time, and will track its state and static data as separate variables.
203210
204211
Example::
@@ -214,15 +221,16 @@ class ToLinen(linen.Module):
214221
(32, 64)
215222
>>> # The static GraphDef of the underlying NNX module
216223
>>> variables.keys()
217-
dict_keys(['nnx', 'params'])
218-
>>> type(variables['nnx']['graphdef'])
219-
<class 'flax.nnx.graph.GraphDef'>
224+
dict_keys(['params'])
220225
221226
Args:
222227
nnx_class: The NNX Module class (not instance!).
223-
args: The arguments that normally would be passed in to create the NNX module.
224-
kwargs: The keyword arguments that normally would be passed in to create the NNX module.
225-
skip_rng: True if this NNX module doesn't need `rngs` arg during initialization (not common).
228+
args: The arguments that normally would be passed in to create the NNX
229+
module.
230+
kwargs: The keyword arguments that normally would be passed in to create the
231+
NNX module.
232+
skip_rng: True if this NNX module doesn't need `rngs` arg during
233+
initialization (not common).
226234
227235
Returns:
228236
A stateful NNX module that behaves the same as the wrapped Linen module.
@@ -231,59 +239,108 @@ class ToLinen(linen.Module):
231239
args: tp.Sequence = ()
232240
kwargs: tp.Mapping[str, tp.Any] = FrozenDict({})
233241
skip_rng: bool = False
234-
metadata_type: tp.Type = bv.NNXMeta
242+
metadata_fn: tp.Callable[[variablelib.VariableState], tp.Any] | None = (
243+
bv.to_linen_var
244+
)
235245

236246
@linen.compact
237247
def __call__(self, *args, **kwargs):
248+
module_kwargs = dict(self.kwargs)
249+
maybe_add_default = not self.is_initializing()
250+
def _module_kwargs():
251+
if not self.skip_rng:
252+
module_kwargs['rngs'] = nnx.Rngs(
253+
**linen_rngs_dict(self, add_default=maybe_add_default)
254+
)
255+
return module_kwargs
256+
238257
# init codepath
239258
if self.is_initializing():
240-
module_kwargs = dict(self.kwargs)
241-
if not self.skip_rng:
242-
module_kwargs |= dict(rngs=nnx.Rngs(**linen_rngs_dict(self)))
243-
module = self.nnx_class(*self.args, **module_kwargs)
259+
module = self.nnx_class(*self.args, **_module_kwargs())
244260
# TODO: add lazy_init here in case there's an `ToNNX` submodule under `module`.
261+
# update linen variables before call module to save initial state
245262
self._update_variables(module)
246-
return module(*args, **kwargs)
247-
248-
# apply codepath
249-
gdef = self.get_variable('nnx', 'graphdef')
250-
assert gdef, 'GraphDef not found in variables. Was the collection "nnx" dropped somewhere?'
251-
variables = {col: v for col, v in self.variables.items() if col != 'nnx'}
252-
states = jtu.tree_map_with_path(
253-
lambda kp, x: bv.to_nnx_var(bv.get_col_name(kp), x).to_state(),
254-
variables, is_leaf=lambda x: isinstance(x, meta.AxisMetadata))
255-
states = [State(v) for v in states.values()]
256-
nnx_state = nnx.merge_state(*states) if states else nnx.GraphState({})
257-
module = nnx.merge(gdef, nnx_state)
258-
nnx.reseed(module, **linen_rngs_dict(self)) # reseed with keys from linen apply call.
263+
out = module(*args, **kwargs)
264+
return out
265+
266+
# create state
267+
def maybe_unbox(x):
268+
if isinstance(x, meta.AxisMetadata):
269+
return x.unbox()
270+
return x
271+
states = jtu.tree_map(
272+
maybe_unbox,
273+
list(self.variables.values()),
274+
is_leaf=lambda x: isinstance(x, meta.AxisMetadata),
275+
)
276+
if not states:
277+
states = ({},)
278+
279+
# update module state
280+
module = nnx.eval_shape(
281+
lambda: self.nnx_class(*self.args, **_module_kwargs())
282+
)
283+
nnx.update(module, *states)
284+
nnx.reseed(
285+
module, **linen_rngs_dict(self, add_default=maybe_add_default)
286+
) # reseed with keys from linen apply call.
287+
259288
out = module(*args, **kwargs)
260289
self._update_variables(module)
261290
return out
262291

263292
def _update_variables(self, module):
264293
"""Store the NNX module's graph def and state inside Linen module variables."""
265-
gdef, state = nnx.split(module)
266-
# Save the graph def.
267-
if self.is_mutable_collection('nnx'):
268-
self.put_variable('nnx', 'graphdef', gdef)
269-
# Sort all the variable types.
270-
types = set(jax.tree.leaves(
271-
jax.tree.map(lambda x: x.type, state,
272-
is_leaf=lambda x: isinstance(x, nnx.VariableState))))
273-
types = bv.sort_variable_types(types)
274-
_, *state_by_types = nnx.split(module, *types)
275-
# Each variable type goes to its own linen collection, and
276-
# each attribute goes to its own linen variable
277-
for typ, state in zip(types, state_by_types):
278-
collection = variablelib.variable_name_from_type(typ, allow_register=True)
294+
state = nnx.state(module, nnx.Not(nnx.RngState))
295+
296+
collection_flat_state: dict[str, list[tuple[tuple[str, ...], tp.Any]]] = {}
297+
298+
# group state by collection
299+
for path, leaf in nnx.to_flat_state(state):
300+
type_ = leaf.type if isinstance(leaf, nnx.VariableState) else type(leaf)
301+
collection = variablelib.variable_name_from_type(
302+
type_, allow_register=True
303+
)
304+
if collection not in collection_flat_state:
305+
collection_flat_state[collection] = []
306+
collection_flat_state[collection].append((path, leaf))
307+
308+
# update linen variables
309+
for collection, flat_state in collection_flat_state.items():
279310
if self.is_mutable_collection(collection):
280-
for k, v in state.raw_mapping.items():
281-
v = jax.tree.map(bv.to_linen_var, v,
282-
is_leaf=lambda x: isinstance(x, nnx.VariableState))
311+
312+
def _to_linen_var(x):
313+
if isinstance(x, nnx.VariableState):
314+
if self.metadata_fn:
315+
return self.metadata_fn(x)
316+
else:
317+
return x.value
318+
return x
319+
320+
collection_state = nnx.traversals.unflatten_mapping(flat_state)
321+
collection_state = jax.tree.map(
322+
_to_linen_var,
323+
collection_state,
324+
is_leaf=lambda x: isinstance(x, nnx.VariableState),
325+
)
326+
for k, v in collection_state.items():
283327
self.put_variable(collection, k, v)
284328

285329

286-
def to_linen(nnx_class: tp.Callable[..., Module], *args,
287-
name: str | None = None, **kwargs):
330+
def to_linen(
331+
nnx_class: tp.Callable[..., Module],
332+
*args,
333+
metadata_fn: (
334+
tp.Callable[[variablelib.VariableState], tp.Any] | None
335+
) = bv.to_linen_var,
336+
name: str | None = None,
337+
**kwargs,
338+
):
288339
"""Shortcut of `nnx.bridge.ToLinen` if user is not changing any of its default fields."""
289-
return ToLinen(nnx_class, args=args, kwargs=FrozenDict(kwargs), name=name)
340+
return ToLinen(
341+
nnx_class,
342+
args=args,
343+
kwargs=FrozenDict(kwargs),
344+
metadata_fn=metadata_fn,
345+
name=name,
346+
)

flax/nnx/rnglib.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,10 +157,9 @@ class Rngs(Object):
157157
"""
158158

159159
def __init__(
160-
self,
161-
default: RngValue | RngDict | None = None,
162-
/,
163-
**rngs: RngValue,
160+
self,
161+
default: RngValue | RngDict | None = None,
162+
**rngs: RngValue,
164163
):
165164
"""
166165
Args:

flax/nnx/traversals.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
from __future__ import annotations
1818

1919
from collections.abc import Callable, Mapping
20+
from collections.abc import Iterable, Sequence
2021
from typing import Any, overload
21-
from collections.abc import Iterable
2222

2323
from flax import struct
2424

@@ -167,28 +167,25 @@ def _flatten(xs: Any, prefix: tuple[Any, ...]):
167167

168168

169169
@overload
170-
def unflatten_mapping(xs: Mapping[tuple[Any, ...], Any],
171-
/,
172-
*,
173-
sep: None = None
174-
) -> dict[Any, Any]:
170+
def unflatten_mapping(
171+
xs: Sequence[tuple[tuple[Any, ...], Any]], /, *, sep: None = None
172+
) -> dict[Any, Any]:
175173
...
176174

177175

178176
@overload
179-
def unflatten_mapping(xs: Mapping[str, Any],
180-
/,
181-
*,
182-
sep: str
183-
) -> dict[Any, Any]:
177+
def unflatten_mapping(
178+
xs: Mapping[tuple[Any, ...], Any], /, *, sep: None = None
179+
) -> dict[Any, Any]:
184180
...
185181

186182

187-
def unflatten_mapping(xs: Any,
188-
/,
189-
*,
190-
sep: str | None = None
191-
) -> dict[Any, Any]:
183+
@overload
184+
def unflatten_mapping(xs: Mapping[str, Any], /, *, sep: str) -> dict[Any, Any]:
185+
...
186+
187+
188+
def unflatten_mapping(xs: Any, /, *, sep: str | None = None) -> dict[Any, Any]:
192189
"""Unflatten a mapping.
193190
194191
See ``flatten_mapping``

tests/nnx/bridge/wrappers_test.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -253,8 +253,6 @@ def test_nnx_to_linen(self):
253253
y, variables = model.init_with_output(jax.random.key(0), x)
254254
assert y.shape == (1, 64)
255255
np.testing.assert_allclose(y, x @ variables['params']['kernel'])
256-
assert 'nnx' in variables
257-
assert isinstance(variables['nnx']['graphdef'], nnx.GraphDef)
258256

259257
def test_nnx_to_linen_multiple_rngs(self):
260258
class NNXInner(nnx.Module):
@@ -291,7 +289,7 @@ def __call__(self, x):
291289
x = jax.random.normal(xkey, (2, 4))
292290
model = bridge.to_linen(NNXInner, 4, 3)
293291
var = model.init({'params': pkey, 'dropout': dkey}, x)
294-
self.assertSameElements(var.keys(), ['nnx', 'LoRAParam', 'params', 'batch_stats'])
292+
self.assertSameElements(var.keys(), ['LoRAParam', 'params', 'batch_stats'])
295293
y = model.apply(var, x)
296294
assert y.shape == (2, 3)
297295

@@ -332,15 +330,7 @@ def __call__(self):
332330
variables = model.init(jax.random.key(0))
333331
assert variables['Count']['count'] == 0
334332

335-
# This does not work, because the __call__ also changes the static data of the model.
336-
_, updates = model.apply(variables, mutable='Count')
337-
assert updates['Count']['count'] == 1
338-
assert updates['Count']['count_nonzero'] == 1
339-
with self.assertRaises(ValueError):
340-
_ = model.apply(variables | updates)
341-
342-
# This makes sure the static data is updated too. Using mutable=True also works.
343-
_, updates = model.apply(variables, mutable=['Count', 'nnx'])
333+
_, updates = model.apply(variables, mutable=['Count'])
344334
assert updates['Count']['count'] == 1
345335
assert updates['Count']['count_nonzero'] == 1
346336
_ = model.apply(variables | updates)
@@ -351,9 +341,9 @@ class LinenOuter(nn.Module):
351341
@nn.compact
352342
def __call__(self, x):
353343
inner = nn.vmap(
354-
bridge.ToLinen,
355-
variable_axes={'params': 0, 'nnx': None},
356-
split_rngs={'params': True}
344+
bridge.ToLinen,
345+
variable_axes={'params': 0},
346+
split_rngs={'params': True},
357347
)(nnx.Linear, args=(x.shape[-1], self.dout))
358348
return inner(x)
359349

@@ -364,7 +354,6 @@ def __call__(self, x):
364354
k = var['params']['VmapToLinen_0']['kernel']
365355
assert k.shape == (2, 4, 3)
366356
np.testing.assert_allclose(y, jnp.einsum('ab,abc->ac', x, k))
367-
assert 'nnx' in var
368357

369358
def test_nnx_to_linen_metadata(self):
370359
model = bridge.to_linen(
@@ -420,7 +409,7 @@ def __call__(self, x):
420409
def get_weights(variables):
421410
non_rngs = {}
422411
for kp, v in flax.traverse_util.flatten_dict(variables).items():
423-
if 'rngs' not in kp and 'nnx' not in kp:
412+
if 'rngs' not in kp:
424413
non_rngs[kp] = v
425414
return flax.traverse_util.unflatten_dict(non_rngs)
426415
from_top_weights = get_weights(from_top)

0 commit comments

Comments
 (0)