@@ -185,20 +185,27 @@ def __call__(
185
185
return out
186
186
187
187
188
- def linen_rngs_dict (linen_module : linen .Module ) -> tp . Mapping [ str , jax . Array ] :
188
+ def linen_rngs_dict (linen_module : linen .Module , add_default : bool = False ) :
189
189
"""Given a module, split out one of its every active RNG key collections."""
190
190
assert linen_module .scope is not None , 'linen_rngs_dict() must be called inside a Linen module.'
191
- return {name : linen_module .make_rng (name )
192
- for name in linen_module .scope .rngs .keys ()}
191
+ rngs : dict [str , tp .Any ] = {
192
+ name : linen_module .make_rng (name )
193
+ for name in linen_module .scope .rngs .keys ()
194
+ }
195
+ if add_default and 'default' not in rngs :
196
+ rngs ['default' ] = 0
197
+ return rngs
193
198
194
199
195
200
class ToLinen (linen .Module ):
196
201
"""A wrapper to turn any NNX module into a Linen module.
197
202
198
- The result Linen module can be used standalone with all Linen APIs, or as a submodule of
203
+ The result Linen module can be used standalone with all Linen APIs, or as a
204
+ submodule of
199
205
another Linen module.
200
206
201
- Since NNX modules are stateful and owns the state, we only create it once during init
207
+ Since NNX modules are stateful and owns the state, we only create it once
208
+ during init
202
209
time, and will track its state and static data as separate variables.
203
210
204
211
Example::
@@ -214,15 +221,16 @@ class ToLinen(linen.Module):
214
221
(32, 64)
215
222
>>> # The static GraphDef of the underlying NNX module
216
223
>>> variables.keys()
217
- dict_keys(['nnx', 'params'])
218
- >>> type(variables['nnx']['graphdef'])
219
- <class 'flax.nnx.graph.GraphDef'>
224
+ dict_keys(['params'])
220
225
221
226
Args:
222
227
nnx_class: The NNX Module class (not instance!).
223
- args: The arguments that normally would be passed in to create the NNX module.
224
- kwargs: The keyword arguments that normally would be passed in to create the NNX module.
225
- skip_rng: True if this NNX module doesn't need `rngs` arg during initialization (not common).
228
+ args: The arguments that normally would be passed in to create the NNX
229
+ module.
230
+ kwargs: The keyword arguments that normally would be passed in to create the
231
+ NNX module.
232
+ skip_rng: True if this NNX module doesn't need `rngs` arg during
233
+ initialization (not common).
226
234
227
235
Returns:
228
236
A stateful NNX module that behaves the same as the wrapped Linen module.
@@ -231,59 +239,108 @@ class ToLinen(linen.Module):
231
239
args : tp .Sequence = ()
232
240
kwargs : tp .Mapping [str , tp .Any ] = FrozenDict ({})
233
241
skip_rng : bool = False
234
- metadata_type : tp .Type = bv .NNXMeta
242
+ metadata_fn : tp .Callable [[variablelib .VariableState ], tp .Any ] | None = (
243
+ bv .to_linen_var
244
+ )
235
245
236
246
@linen .compact
237
247
def __call__ (self , * args , ** kwargs ):
248
+ module_kwargs = dict (self .kwargs )
249
+ maybe_add_default = not self .is_initializing ()
250
+ def _module_kwargs ():
251
+ if not self .skip_rng :
252
+ module_kwargs ['rngs' ] = nnx .Rngs (
253
+ ** linen_rngs_dict (self , add_default = maybe_add_default )
254
+ )
255
+ return module_kwargs
256
+
238
257
# init codepath
239
258
if self .is_initializing ():
240
- module_kwargs = dict (self .kwargs )
241
- if not self .skip_rng :
242
- module_kwargs |= dict (rngs = nnx .Rngs (** linen_rngs_dict (self )))
243
- module = self .nnx_class (* self .args , ** module_kwargs )
259
+ module = self .nnx_class (* self .args , ** _module_kwargs ())
244
260
# TODO: add lazy_init here in case there's an `ToNNX` submodule under `module`.
261
+ # update linen variables before call module to save initial state
245
262
self ._update_variables (module )
246
- return module (* args , ** kwargs )
247
-
248
- # apply codepath
249
- gdef = self .get_variable ('nnx' , 'graphdef' )
250
- assert gdef , 'GraphDef not found in variables. Was the collection "nnx" dropped somewhere?'
251
- variables = {col : v for col , v in self .variables .items () if col != 'nnx' }
252
- states = jtu .tree_map_with_path (
253
- lambda kp , x : bv .to_nnx_var (bv .get_col_name (kp ), x ).to_state (),
254
- variables , is_leaf = lambda x : isinstance (x , meta .AxisMetadata ))
255
- states = [State (v ) for v in states .values ()]
256
- nnx_state = nnx .merge_state (* states ) if states else nnx .GraphState ({})
257
- module = nnx .merge (gdef , nnx_state )
258
- nnx .reseed (module , ** linen_rngs_dict (self )) # reseed with keys from linen apply call.
263
+ out = module (* args , ** kwargs )
264
+ return out
265
+
266
+ # create state
267
+ def maybe_unbox (x ):
268
+ if isinstance (x , meta .AxisMetadata ):
269
+ return x .unbox ()
270
+ return x
271
+ states = jtu .tree_map (
272
+ maybe_unbox ,
273
+ list (self .variables .values ()),
274
+ is_leaf = lambda x : isinstance (x , meta .AxisMetadata ),
275
+ )
276
+ if not states :
277
+ states = ({},)
278
+
279
+ # update module state
280
+ module = nnx .eval_shape (
281
+ lambda : self .nnx_class (* self .args , ** _module_kwargs ())
282
+ )
283
+ nnx .update (module , * states )
284
+ nnx .reseed (
285
+ module , ** linen_rngs_dict (self , add_default = maybe_add_default )
286
+ ) # reseed with keys from linen apply call.
287
+
259
288
out = module (* args , ** kwargs )
260
289
self ._update_variables (module )
261
290
return out
262
291
263
292
def _update_variables (self , module ):
264
293
"""Store the NNX module's graph def and state inside Linen module variables."""
265
- gdef , state = nnx .split (module )
266
- # Save the graph def.
267
- if self .is_mutable_collection ('nnx' ):
268
- self .put_variable ('nnx' , 'graphdef' , gdef )
269
- # Sort all the variable types.
270
- types = set (jax .tree .leaves (
271
- jax .tree .map (lambda x : x .type , state ,
272
- is_leaf = lambda x : isinstance (x , nnx .VariableState ))))
273
- types = bv .sort_variable_types (types )
274
- _ , * state_by_types = nnx .split (module , * types )
275
- # Each variable type goes to its own linen collection, and
276
- # each attribute goes to its own linen variable
277
- for typ , state in zip (types , state_by_types ):
278
- collection = variablelib .variable_name_from_type (typ , allow_register = True )
294
+ state = nnx .state (module , nnx .Not (nnx .RngState ))
295
+
296
+ collection_flat_state : dict [str , list [tuple [tuple [str , ...], tp .Any ]]] = {}
297
+
298
+ # group state by collection
299
+ for path , leaf in nnx .to_flat_state (state ):
300
+ type_ = leaf .type if isinstance (leaf , nnx .VariableState ) else type (leaf )
301
+ collection = variablelib .variable_name_from_type (
302
+ type_ , allow_register = True
303
+ )
304
+ if collection not in collection_flat_state :
305
+ collection_flat_state [collection ] = []
306
+ collection_flat_state [collection ].append ((path , leaf ))
307
+
308
+ # update linen variables
309
+ for collection , flat_state in collection_flat_state .items ():
279
310
if self .is_mutable_collection (collection ):
280
- for k , v in state .raw_mapping .items ():
281
- v = jax .tree .map (bv .to_linen_var , v ,
282
- is_leaf = lambda x : isinstance (x , nnx .VariableState ))
311
+
312
+ def _to_linen_var (x ):
313
+ if isinstance (x , nnx .VariableState ):
314
+ if self .metadata_fn :
315
+ return self .metadata_fn (x )
316
+ else :
317
+ return x .value
318
+ return x
319
+
320
+ collection_state = nnx .traversals .unflatten_mapping (flat_state )
321
+ collection_state = jax .tree .map (
322
+ _to_linen_var ,
323
+ collection_state ,
324
+ is_leaf = lambda x : isinstance (x , nnx .VariableState ),
325
+ )
326
+ for k , v in collection_state .items ():
283
327
self .put_variable (collection , k , v )
284
328
285
329
286
- def to_linen (nnx_class : tp .Callable [..., Module ], * args ,
287
- name : str | None = None , ** kwargs ):
330
+ def to_linen (
331
+ nnx_class : tp .Callable [..., Module ],
332
+ * args ,
333
+ metadata_fn : (
334
+ tp .Callable [[variablelib .VariableState ], tp .Any ] | None
335
+ ) = bv .to_linen_var ,
336
+ name : str | None = None ,
337
+ ** kwargs ,
338
+ ):
288
339
"""Shortcut of `nnx.bridge.ToLinen` if user is not changing any of its default fields."""
289
- return ToLinen (nnx_class , args = args , kwargs = FrozenDict (kwargs ), name = name )
340
+ return ToLinen (
341
+ nnx_class ,
342
+ args = args ,
343
+ kwargs = FrozenDict (kwargs ),
344
+ metadata_fn = metadata_fn ,
345
+ name = name ,
346
+ )
0 commit comments