diff --git a/flax/nnx/bridge/__init__.py b/flax/nnx/bridge/__init__.py index 7b155fa97..fc83b0da7 100644 --- a/flax/nnx/bridge/__init__.py +++ b/flax/nnx/bridge/__init__.py @@ -27,6 +27,7 @@ from .module import compact as compact from .module import current_context as current_context from .module import current_module as current_module +from .module import share_scope as share_scope from .interop import nnx_in_bridge_mdl as nnx_in_bridge_mdl from .interop import linen_in_bridge_mdl as linen_in_bridge_mdl from flax.nnx.nn import initializers as initializers diff --git a/flax/nnx/bridge/module.py b/flax/nnx/bridge/module.py index 76a7d222b..912e5ba4f 100644 --- a/flax/nnx/bridge/module.py +++ b/flax/nnx/bridge/module.py @@ -65,12 +65,14 @@ class ModuleState(statelib.State): class Scope(Object): - def __init__(self, rngs: rnglib.Rngs, mutable: CollectionFilter): + def __init__(self, module: Module, rngs: rnglib.Rngs, mutable: CollectionFilter): + self.module = module self.rngs = rngs self.mutable = mutable - def copy(self): - return Scope(self.rngs, self.mutable) + def copy(self, new_module): + # Never copy the module - always fill in a new one + return Scope(new_module, self.rngs, self.mutable) class _HasSetup(tp.Protocol): @@ -104,7 +106,7 @@ def _bind_module(parent: Module, module: Module) -> Module: for _, value in reversed(list(graph.iter_graph(module))): if isinstance(value, Module): if module.scope is None: - value.scope = parent.scope.copy() # type: ignore[attribute-error] + value.scope = parent.scope.copy(value) # type: ignore[attribute-error] _maybe_call_setup(value) return module @@ -280,8 +282,9 @@ def param( # type: ignore[invalid-annotation] 'Parameters must be initialized in `setup()` or in a method ' 'wrapped in `@compact`' ) - if hasattr(self, name): - value = getattr(self, name) + module = self.scope.module + if hasattr(module, name): + value = getattr(module, name) # TODO(cgarciae): implement reservations # if self._name_taken(name): # raise errors.NameInUseError('param', name, self.__class__.__name__) @@ -310,10 +313,10 @@ def param( # type: ignore[invalid-annotation] variable = variablelib.Param(value) else: - value = init_fn(self.make_rng('params'), *init_args, **init_kwargs) + value = init_fn(module.make_rng('params'), *init_args, **init_kwargs) variable = variablelib.Param(value) - setattr(self, name, variable) + setattr(module, name, variable) return variable def variable( # type: ignore[invalid-annotation] @@ -333,9 +336,10 @@ def variable( # type: ignore[invalid-annotation] 'Variables must be initialized in `setup()` or in a method ' 'wrapped in `@compact`' ) + module = self.scope.module - if hasattr(self, name): - value = getattr(self, name) + if hasattr(module, name): + value = getattr(module, name) # TODO(cgarciae): implement reservations # if self._name_taken(name): # raise errors.NameInUseError('param', name, self.__class__.__name__) @@ -367,7 +371,7 @@ def variable( # type: ignore[invalid-annotation] value = init_fn(*init_args, **init_kwargs) variable = variable_type(value) - setattr(self, name, variable) + setattr(module, name, variable) return variable def _get_variables(self) -> tp.Mapping: @@ -474,7 +478,7 @@ def to_variable(value): if isinstance(value, Object): value._object__state._initializing = _initialize if isinstance(value, Module): - value.scope = Scope(rngs, mutable) + value.scope = Scope(value, rngs, mutable) _maybe_call_setup(value) MODULE_CONTEXT.module_stack.append( @@ -570,3 +574,7 @@ def _get_unbound_fn(method_or_fn: tp.Callable) -> tp.Callable: return method_or_fn + +def share_scope(parent: Module, child: Module): + """Behaves like `linen.share_scope`, for a pair of parent and child modules.""" + child.scope = parent.scope \ No newline at end of file diff --git a/tests/nnx/bridge/module_test.py b/tests/nnx/bridge/module_test.py index e5995e642..bf1496a81 100644 --- a/tests/nnx/bridge/module_test.py +++ b/tests/nnx/bridge/module_test.py @@ -542,6 +542,28 @@ def __call__(self, x): params = model.init(jax.random.key(0), x)['params'] self.assertSameElements([f'layer_{i}' for i in range(3)], params.keys()) + def test_share_scope(self): + class Dense(bridge.Module): + dout: int + @bridge.compact + def __call__(self, x): + return x @ self.param('w', nn.initializers.normal(), + (x.shape[-1], self.dout)) + + class Top(bridge.Module): + def setup(self): + self.a = Dense(4) + bridge.module.share_scope(self, self.a) + + def __call__(self, x): + return self.a(x) + + model = Top() + x = jnp.ones((4, 32)) + params = model.init(jax.random.key(0), x)['params'] + self.assertSameElements(['w'], params.keys()) # 'a' doesn't exist + + if __name__ == '__main__': absltest.main()