Skip to content

[bridge module] Add bridge.share_scope for layer-sublayer pairs. #4638

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flax/nnx/bridge/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .module import compact as compact
from .module import current_context as current_context
from .module import current_module as current_module
from .module import share_scope as share_scope
from .interop import nnx_in_bridge_mdl as nnx_in_bridge_mdl
from .interop import linen_in_bridge_mdl as linen_in_bridge_mdl
from flax.nnx.nn import initializers as initializers
32 changes: 20 additions & 12 deletions flax/nnx/bridge/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,14 @@ class ModuleState(statelib.State):


class Scope(Object):
  """Binding context attached to a bridge ``Module``.

  Carries the RNG streams (``rngs``) and the mutable-collection filter
  (``mutable``), together with a reference to the module that variables
  created through this scope are stored on.
  """

  def __init__(self, module: Module, rngs: rnglib.Rngs, mutable: CollectionFilter):
    # `Module.param` / `Module.variable` read and write attributes on
    # `scope.module`, so every module holding this Scope object stores its
    # variables on the same target module.
    self.module = module
    self.rngs = rngs
    self.mutable = mutable

  def copy(self, new_module: Module) -> 'Scope':
    """Return a new ``Scope`` for ``new_module`` that reuses ``rngs`` and ``mutable``."""
    # Never copy the module - always fill in a new one
    return Scope(new_module, self.rngs, self.mutable)


class _HasSetup(tp.Protocol):
Expand Down Expand Up @@ -104,7 +106,7 @@ def _bind_module(parent: Module, module: Module) -> Module:
for _, value in reversed(list(graph.iter_graph(module))):
if isinstance(value, Module):
if module.scope is None:
value.scope = parent.scope.copy() # type: ignore[attribute-error]
value.scope = parent.scope.copy(value) # type: ignore[attribute-error]
_maybe_call_setup(value)
return module

Expand Down Expand Up @@ -280,8 +282,9 @@ def param( # type: ignore[invalid-annotation]
'Parameters must be initialized in `setup()` or in a method '
'wrapped in `@compact`'
)
if hasattr(self, name):
value = getattr(self, name)
module = self.scope.module
if hasattr(module, name):
value = getattr(module, name)
# TODO(cgarciae): implement reservations
# if self._name_taken(name):
# raise errors.NameInUseError('param', name, self.__class__.__name__)
Expand Down Expand Up @@ -310,10 +313,10 @@ def param( # type: ignore[invalid-annotation]

variable = variablelib.Param(value)
else:
value = init_fn(self.make_rng('params'), *init_args, **init_kwargs)
value = init_fn(module.make_rng('params'), *init_args, **init_kwargs)
variable = variablelib.Param(value)

setattr(self, name, variable)
setattr(module, name, variable)
return variable

def variable( # type: ignore[invalid-annotation]
Expand All @@ -333,9 +336,10 @@ def variable( # type: ignore[invalid-annotation]
'Variables must be initialized in `setup()` or in a method '
'wrapped in `@compact`'
)
module = self.scope.module

if hasattr(self, name):
value = getattr(self, name)
if hasattr(module, name):
value = getattr(module, name)
# TODO(cgarciae): implement reservations
# if self._name_taken(name):
# raise errors.NameInUseError('param', name, self.__class__.__name__)
Expand Down Expand Up @@ -367,7 +371,7 @@ def variable( # type: ignore[invalid-annotation]
value = init_fn(*init_args, **init_kwargs)
variable = variable_type(value)

setattr(self, name, variable)
setattr(module, name, variable)
return variable

def _get_variables(self) -> tp.Mapping:
Expand Down Expand Up @@ -474,7 +478,7 @@ def to_variable(value):
if isinstance(value, Object):
value._object__state._initializing = _initialize
if isinstance(value, Module):
value.scope = Scope(rngs, mutable)
value.scope = Scope(value, rngs, mutable)
_maybe_call_setup(value)

MODULE_CONTEXT.module_stack.append(
Expand Down Expand Up @@ -570,3 +574,7 @@ def _get_unbound_fn(method_or_fn: tp.Callable) -> tp.Callable:

return method_or_fn


def share_scope(parent: Module, child: Module):
  """Behaves like `linen.share_scope`, for a pair of parent and child modules.

  After this call ``child`` holds the very same ``Scope`` object as ``parent``.
  Because ``param`` and ``variable`` store what they create on
  ``scope.module``, variables created by ``child`` are attached to ``parent``
  directly instead of being nested under the attribute ``child`` is assigned to.

  Args:
    parent: A bound module whose scope will be shared.
    child: The module that should adopt ``parent``'s scope.

  Raises:
    ValueError: If ``parent`` has no scope yet (i.e. it is unbound).
  """
  if parent.scope is None:
    # Handing `None` to the child would be a silent no-op: the binding code
    # gives modules without a scope a fresh copy of their own, so nothing
    # would end up shared.
    raise ValueError(
      'Cannot share scope: `parent` is not bound to a scope yet. Call '
      '`share_scope` from `setup()` or from a method wrapped in `@compact`.'
    )
  child.scope = parent.scope
22 changes: 22 additions & 0 deletions tests/nnx/bridge/module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,28 @@ def __call__(self, x):
params = model.init(jax.random.key(0), x)['params']
self.assertSameElements([f'layer_{i}' for i in range(3)], params.keys())

def test_share_scope(self):
  """A child sharing its parent's scope stores its params on the parent."""

  class Dense(bridge.Module):
    dout: int

    @bridge.compact
    def __call__(self, x):
      return x @ self.param('w', nn.initializers.normal(),
                            (x.shape[-1], self.dout))

  class Top(bridge.Module):
    def setup(self):
      self.a = Dense(4)
      # Go through the public `bridge.share_scope` export so the test also
      # covers the alias added in `flax/nnx/bridge/__init__.py`.
      bridge.share_scope(self, self.a)

    def __call__(self, x):
      return self.a(x)

  model = Top()
  x = jnp.ones((4, 32))
  params = model.init(jax.random.key(0), x)['params']
  # The kernel lives directly under the top-level params: 'a' doesn't exist.
  self.assertSameElements(['w'], params.keys())



if __name__ == '__main__':
absltest.main()
Expand Down
Loading