diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py
index 3270b6ab1..0eb6dca0d 100644
--- a/flax/nnx/__init__.py
+++ b/flax/nnx/__init__.py
@@ -47,6 +47,10 @@
 from .helpers import TrainState as TrainState
 from .module import M as M
 from .module import Module as Module
+from .module import set_mode as set_mode
+from .module import train_mode as train_mode
+from .module import eval_mode as eval_mode
+from .module import set_attributes as set_attributes
 from .graph import merge as merge
 from .graph import UpdateContext as UpdateContext
 from .graph import update_context as update_context
@@ -58,6 +62,7 @@
 from .graph import state as state
 from .graph import graphdef as graphdef
 from .graph import iter_graph as iter_graph
+from .graph import recursive_map as recursive_map
 from .graph import find_duplicates as find_duplicates
 from .graph import call as call
 from .graph import SplitContext as SplitContext
diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py
index 6ab9d04f3..3e2060110 100644
--- a/flax/nnx/graph.py
+++ b/flax/nnx/graph.py
@@ -2583,7 +2583,7 @@ def pop(
   return states


-def clone(node: Node) -> Node:
+def clone(node: Node, /, *, variables: bool = True) -> Node:
   """Create a deep copy of the given graph node.

   Example usage::
@@ -2597,11 +2597,12 @@
   Args:
     node: A graph node object.
+    variables: Whether to create new copies of the Variables in the states, defaults to ``True``.

   Returns:
     A deep copy of the :class:`Module` object.
   """
   graphdef, state = split(node)
-  return merge(graphdef, state, copy=True)
+  return merge(graphdef, state, copy=variables)


 def _mutable_like(path, x):
@@ -2928,6 +2929,54 @@ def _iter_graph(
     yield path_parts, node


+def recursive_map(f: tp.Callable[[PathParts, tp.Any], tp.Any], node: tp.Any, /):
+  node = clone(node, variables=False)
+  path_parts: PathParts = ()
+  visited: set[int] = set()
+  results: dict[int, tp.Any] = {}
+  return _recursive_map(f, node, path_parts, visited, results)
+
+
+def _recursive_map(
+  f: tp.Callable[[PathParts, tp.Any], tp.Any],
+  node: tp.Any,
+  path: PathParts,
+  visited: set[int],
+  results: dict[int, tp.Any],
+) -> tp.Any:
+  node_id = id(node)
+  if node_id in visited:
+    if node_id in results:
+      return results[node_id]
+    path_str = '/'.join(map(str, path))
+    raise ValueError(
+      f"Found cycle in the graph at path '{path_str}'. Node of type"
+      f' {type(node)} has already been visited but has not been returned yet.'
+    )
+  node_impl = get_node_impl(node)
+  if (
+    type(node_impl) is GraphNodeImpl
+    or isinstance(node, Variable)
+    or is_array_ref(node)
+  ):
+    visited.add(node_id)
+  if node_impl is not None:
+    for key, value in node_impl.node_dict(node).items():
+      new_value = _recursive_map(f, value, (*path, key), visited, results)
+      if new_value is not value:
+        if node_impl.set_key is not None and value is not new_value:
+          node_impl.set_key(node, key, new_value)
+        else:
+          raise ValueError(
+            f"Cannot update key '{key}' for node of type '{type(node)}'"
+            ' because the node does not support mutation.'
+          )
+
+  new_node = f(path, node)
+  results[node_id] = new_node
+  return new_node
+
+
 def find_duplicates(node: tp.Any, /, *, only: filterlib.Filter = ...) -> list[list[PathParts]]:
   """Finds duplicate nodes or node leaves in the given node.
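The snippet below is a usage sketch, not part of the diff: it illustrates how the new `nnx.recursive_map` added above is meant to be used, mirroring the tests added in tests/nnx/graph_utils_test.py. The `Block` module and the `disable_dropout` callback are hypothetical names.

from flax import nnx

class Block(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(2, 2, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)

model = Block(nnx.Rngs(0))

def disable_dropout(path, node):
  # `path` is a tuple of keys from the root; children are visited before parents.
  if isinstance(node, nnx.Dropout):
    node.deterministic = True
  return node

# recursive_map clones the graph with `variables=False` (new module objects,
# shared Variables), applies the callback to every node, and preserves
# duplicate/shared references in the result.
new_model = nnx.recursive_map(disable_dropout, model)
assert new_model.dropout.deterministic is True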
@@ -3099,10 +3148,14 @@ def _unflatten_pytree(

 # common pytrees
 # list
+def _list_set_key(x: list[tp.Any], key: int, value: tp.Any):
+  x[key] = value
+
 register_pytree_node_type(
   list,
   flatten=lambda x: (list(enumerate(x)), None),
   unflatten=lambda nodes, _: [value for _, value in nodes],  # type: ignore
+  set_key=_list_set_key,
 )
 # tuple
 register_pytree_node_type(
diff --git a/flax/nnx/module.py b/flax/nnx/module.py
index cce1f6685..12bb28bf0 100644
--- a/flax/nnx/module.py
+++ b/flax/nnx/module.py
@@ -478,6 +478,54 @@ def eval(self, **attributes):
     )


+def set_mode(node: A, /, *, only: filterlib.Filter = ..., **kwargs) -> A:
+  predicate = filterlib.to_predicate(only)
+
+  def _set_mode_fn(path, node):
+    if hasattr(node, 'set_mode') and predicate(path, node):
+      node.set_mode(**kwargs)
+    return node
+
+  return graph.recursive_map(_set_mode_fn, node)
+
+
+def train_mode(node: A, /, *, only: filterlib.Filter = ..., **kwargs) -> A:
+  return set_mode(
+    node,
+    only=only,
+    train=True,
+    deterministic=False,
+    use_running_average=False,
+    **kwargs,
+  )
+
+
+def eval_mode(node: A, /, *, only: filterlib.Filter = ..., **kwargs) -> A:
+  return set_mode(
+    node,
+    only=only,
+    train=False,
+    deterministic=True,
+    use_running_average=True,
+    **kwargs,
+  )
+
+
+def set_attributes(
+  node: A, /, *, only: filterlib.Filter = ..., **attributes
+) -> A:
+  predicate = filterlib.to_predicate(only)
+
+  def _set_attributes_fn(path, node):
+    if predicate(path, node):
+      for name, value in attributes.items():
+        if hasattr(node, name):
+          setattr(node, name, value)
+    return node
+
+  return graph.recursive_map(_set_attributes_fn, node)
+
+
 def first_from(*args: tp.Optional[A], error_msg: str) -> A:
   """Return the first non-None argument.

diff --git a/flax/nnx/nn/attention.py b/flax/nnx/nn/attention.py
index 3514da2a7..9ff2eba1f 100644
--- a/flax/nnx/nn/attention.py
+++ b/flax/nnx/nn/attention.py
@@ -608,11 +608,64 @@ def __call__(
     out = self.out(x)
     return out

+  def set_mode(
+    self,
+    train: bool | None = None,
+    deterministic: bool | None = None,
+    decode: bool | None = None,
+    batch_size: int | Shape | None = None,
+    max_length: int | None = None,
+    **kwargs,
+  ):
+    """
+    Args:
+      train: if True, the module is set to training mode.
+      deterministic: if True, the module is set to deterministic mode.
+      decode: if True, the module is set to decode mode.
+      batch_size: the batch size to use for the cache.
+      max_length: the max length to use for the cache.
+    """
+    if deterministic is not None:
+      self.deterministic = deterministic
+    elif train is not None:
+      self.deterministic = not train
+
+    if decode is not None:
+      self.decode = decode
+      if (
+        not hasattr(self, 'cached_key')
+        or not hasattr(self, 'cached_value')
+        or not hasattr(self, 'cache_index')
+      ):
+        if batch_size is None:
+          raise TypeError(
+            "'batch_size' must be provided when initializing cache."
+          )
+        if max_length is None:
+          raise TypeError(
+            "'max_length' must be provided when initializing cache."
+          )
+        self.init_cache2(batch_size, max_length, dtype=self.dtype)
+
+  def init_cache2(
+    self, batch_size: int | Shape, max_length: int, dtype: Dtype | None = None
+  ):
+    if dtype is None:
+      dtype = self.dtype
+    if isinstance(batch_size, int):
+      batch_size = (batch_size,)
+
+    cache_shape = (*batch_size, max_length, self.num_heads, self.head_dim)
+    self.cached_key = nnx.Cache(jnp.zeros(cache_shape, dtype))
+    self.cached_value = nnx.Cache(jnp.zeros(cache_shape, dtype))
+    self.cache_index = nnx.Cache(jnp.array(0, dtype=jnp.int32))
+
   def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32):
-    """Initializes cache for fast autoregressive decoding. When
-    ``decode=True``, this method must be called first before performing
-    forward inference. When in decode mode, only one token must be passed
-    at a time.
+    """Initializes cache for fast autoregressive decoding.
+
+    When ``decode=True``, this method must be called first before performing
+    forward inference. When in decode mode, only one token must be passed at a
+    time.

     Example usage::

@@ -632,7 +685,8 @@ def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32):
     ...   rngs=nnx.Rngs(42),
     ... )
     ...
-    >>> # out_nnx = model_nnx(x) <-- throws an error because cache isn't initialized
+    >>> # out_nnx = model_nnx(x) <-- throws an error because cache isn't
+    >>> # initialized
     ...
     >>> model_nnx.init_cache(x.shape)
     >>> out_nnx = model_nnx(x)
diff --git a/flax/nnx/nn/normalization.py b/flax/nnx/nn/normalization.py
index a157c585a..e8c351c45 100644
--- a/flax/nnx/nn/normalization.py
+++ b/flax/nnx/nn/normalization.py
@@ -379,6 +379,17 @@ def __call__(
       self.epsilon,
     )

+  def set_mode(
+    self,
+    train: bool | None = None,
+    use_running_average: bool | None = None,
+    **kwargs,
+  ):
+    if use_running_average is not None:
+      self.use_running_average = use_running_average
+    elif train is not None:
+      self.use_running_average = not train
+

 class LayerNorm(Module):
   """Layer normalization (https://arxiv.org/abs/1607.06450).
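As a usage sketch (not part of the diff), the mode helpers added in flax/nnx/module.py are expected to compose with the per-layer `set_mode` methods above roughly as follows. The `MLP` module is hypothetical, and passing a module class such as `nnx.Dropout` as the `only` filter assumes that `filterlib.to_predicate` accepts types as instance filters.

from flax import nnx

class MLP(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(4, 4, rngs=rngs)
    self.bn = nnx.BatchNorm(4, rngs=rngs)
    self.dropout = nnx.Dropout(0.1, rngs=rngs)

model = MLP(nnx.Rngs(0))

# eval_mode calls set_mode(train=False, deterministic=True,
# use_running_average=True) on every sub-module that defines set_mode.
eval_model = nnx.eval_mode(model)
assert eval_model.dropout.deterministic is True
assert eval_model.bn.use_running_average is True

# `only` restricts which sub-modules are touched: here only Dropout layers are
# switched back to training behaviour while BatchNorm keeps its running stats.
mixed = nnx.train_mode(eval_model, only=nnx.Dropout)
assert mixed.dropout.deterministic is False
assert mixed.bn.use_running_average is True

# set_attributes assigns plain attributes on matching sub-modules without
# going through set_mode.
low_rate = nnx.set_attributes(model, only=nnx.Dropout, rate=0.05)
assert low_rate.dropout.rate == 0.05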
@@ -832,4 +843,4 @@ def __call__(self, x, *, mask: tp.Optional[jax.Array] = None):
       (self.feature_axis,),
       self.dtype,
       self.epsilon,
-    )
\ No newline at end of file
+    )
diff --git a/flax/nnx/nn/stochastic.py b/flax/nnx/nn/stochastic.py
index ab365ef5c..cd0b3195d 100644
--- a/flax/nnx/nn/stochastic.py
+++ b/flax/nnx/nn/stochastic.py
@@ -73,7 +73,7 @@ def __init__(
     rate: float,
     *,
     broadcast_dims: Sequence[int] = (),
-    deterministic: bool = False,
+    deterministic: bool | None = None,
     rng_collection: str = 'dropout',
     rngs: rnglib.Rngs | rnglib.RngStream | None = None,
   ):
@@ -153,3 +153,11 @@ def __call__(
     mask = random.bernoulli(key, p=keep_prob, shape=broadcast_shape)
     mask = jnp.broadcast_to(mask, inputs.shape)
     return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))
+
+  def set_mode(
+    self,
+    deterministic: bool | None = None,
+    **kwargs,
+  ):
+    if deterministic is not None:
+      self.deterministic = deterministic
diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py
index 5906812ee..f580bbf6d 100644
--- a/tests/nnx/graph_utils_test.py
+++ b/tests/nnx/graph_utils_test.py
@@ -1107,6 +1107,53 @@ def __init__(self, rngs: nnx.Rngs):
     self.assertLen(duplicates, 1)
     self.assertEqual(duplicates[0], [('a',), ('c',)])

+  def test_recursive_map(self):
+    class Foo(nnx.Pytree):
+      def __init__(self, d):
+        self.d = d
+
+    foo1 = Foo(10)
+    foo2 = Foo(20)
+    bar = [foo1, foo2, foo1]
+    n = 0
+
+    def inc_d(path, node):
+      nonlocal n
+      if isinstance(node, Foo):
+        n += 1
+        node.d += 1
+      return node
+
+    bar2 = nnx.recursive_map(inc_d, bar)
+    self.assertIs(bar2[0], bar2[2])
+    self.assertEqual(bar2[0].d, 11)
+    self.assertEqual(bar2[1].d, 21)
+    self.assertEqual(n, 2)
+
+  def test_recursive_map_replace(self):
+    class Foo(nnx.Pytree):
+      def __init__(self, d):
+        self.d = d
+
+    foo1 = Foo(10)
+    foo2 = Foo(20)
+    bar = [foo1, foo2, foo1]
+    n = 0
+
+    def swap(path, node):
+      nonlocal n
+      if isinstance(node, Foo):
+        n += 1
+        node = Foo(-node.d)
+      return node
+
+    bar2 = nnx.recursive_map(swap, bar)
+    self.assertIs(bar2[0], bar2[2])
+    self.assertEqual(bar2[0].d, -10)
+    self.assertEqual(bar2[1].d, -20)
+    self.assertEqual(n, 2)
+
+
 class SimpleModule(nnx.Module):
   pass
diff --git a/tests/nnx/integration_test.py b/tests/nnx/integration_test.py
index 67eaa71c5..cc31b6277 100644
--- a/tests/nnx/integration_test.py
+++ b/tests/nnx/integration_test.py
@@ -28,6 +28,59 @@


 class TestIntegration(absltest.TestCase):
+
+  def test_basic_example(self):
+    class Model(nnx.Module):
+
+      def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
+        self.linear = nnx.Linear(din, dmid, rngs=rngs)
+        self.bn = nnx.BatchNorm(dmid, rngs=rngs)
+        self.dropout = nnx.Dropout(0.2, rngs=rngs)
+        self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)
+
+      def __call__(self, x):
+        x = nnx.relu(self.dropout(self.bn(self.linear(x))))
+        return self.linear_out(x)
+
+    model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
+    train_model = nnx.set_mode(
+      model, deterministic=False, use_running_average=False
+    )
+    eval_model = nnx.set_mode(
+      model, deterministic=True, use_running_average=True
+    )
+    optimizer = nnx.Optimizer(train_model, optax.adam(1e-3), wrt=nnx.Param)
+
+    self.assertEqual(train_model.dropout.deterministic, False)
+    self.assertEqual(train_model.bn.use_running_average, False)
+    self.assertEqual(eval_model.dropout.deterministic, True)
+    self.assertEqual(eval_model.bn.use_running_average, True)
+    self.assertIs(train_model.dropout.rngs.count, eval_model.dropout.rngs.count)
+
+    @nnx.jit  # automatic state management for JAX transforms
+    def train_step(model, optimizer, x, y):
+      def loss_fn(model):
+        y_pred = model(x)
+        return jnp.mean((y_pred - y) ** 2)
+
+      loss, grads = nnx.value_and_grad(loss_fn)(model)
+      optimizer.update(model, grads)  # in-place updates
+
+      return loss
+
+    @nnx.jit
+    def eval_step(model, x, y):
+      y_pred = model(x)
+      return jnp.mean((y_pred - y) ** 2)
+
+    x = jax.random.normal(jax.random.key(0), (8, 2))
+    y = jax.random.normal(jax.random.key(1), (8, 3))
+
+    train_step(train_model, optimizer, x, y)
+    self.assertEqual(train_model.dropout.rngs.count.value, 1)
+    eval_step(eval_model, x, y)
+    self.assertEqual(train_model.dropout.rngs.count.value, 1)
+
   def test_shared_modules(self):
     class Block(nnx.Module):
       def __init__(self, linear: nnx.Linear, *, rngs):
@@ -319,7 +372,7 @@ def train_step(x, y):
       graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)
       def loss_fn(params):
         model = nnx.merge(graphdef, params, nondiff)
-        return ((model(x) - y) ** 2).mean()  # call methods directly
+        return ((model(x) - y) ** 2).mean()

       loss, grads = jax.value_and_grad(loss_fn)(nnx.to_arrays(params))
       optimizer.update(model, grads)  # in-place updates
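Finally, a sketch (not part of the diff) of the sharing behaviour the new integration test relies on: `set_mode`/`train_mode`/`eval_mode` are built on `recursive_map`, which clones with `variables=False`, so the returned views are new module objects that share the original Variables; the `assertIs` on `rngs.count` above checks exactly this. Whether parameter Variables are likewise shared across views is an assumption of this sketch, not something the diff asserts directly.

import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
train_view = nnx.train_mode(model)
eval_view = nnx.eval_mode(model)

# Distinct module objects, shared Variables (assumed, see note above).
assert train_view is not eval_view
assert train_view.kernel is eval_view.kernel

# An update made through one view is observed through the other.
train_view.kernel.value = jnp.zeros_like(train_view.kernel.value)
assert jnp.all(eval_view.kernel.value == 0)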