1 change: 1 addition & 0 deletions .gitignore
@@ -23,3 +23,4 @@ flaxlib_src/subprojects

# custom
/tmp-files
.env
5 changes: 5 additions & 0 deletions flax/nnx/__init__.py
@@ -47,6 +47,10 @@
from .helpers import TrainState as TrainState
from .module import M as M
from .module import Module as Module
from .module import set_mode as set_mode
from .module import train_mode as train_mode
from .module import eval_mode as eval_mode
from .module import set_attributes as set_attributes
from .graph import merge as merge
from .graph import UpdateContext as UpdateContext
from .graph import update_context as update_context
@@ -58,6 +62,7 @@
from .graph import state as state
from .graph import graphdef as graphdef
from .graph import iter_graph as iter_graph
from .graph import recursive_map as recursive_map
from .graph import find_duplicates as find_duplicates
from .graph import call as call
from .graph import SplitContext as SplitContext
57 changes: 55 additions & 2 deletions flax/nnx/graph.py
@@ -2583,7 +2583,7 @@ def pop(
return states


def clone(node: Node) -> Node:
def clone(node: Node, /, *, variables: bool = True) -> Node:
"""Create a deep copy of the given graph node.

Example usage::
@@ -2597,11 +2597,12 @@ def clone(node: Node) -> Node:

Args:
node: A graph node object.
    variables: Whether to create new copies of the Variables in the state;
      if ``False``, the original Variables are shared. Defaults to ``True``.
  Returns:
    A copy of the :class:`Module` object; a deep copy when ``variables=True``.
"""
graphdef, state = split(node)
return merge(graphdef, state, copy=True)
return merge(graphdef, state, copy=variables)


def _mutable_like(path, x):
@@ -2928,6 +2929,54 @@ def _iter_graph(
yield path_parts, node


def recursive_map(f: tp.Callable[[PathParts, tp.Any], tp.Any], node: tp.Any, /):
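  """Maps ``f`` over every node in the graph of ``node``, children first.

  The input is first cloned with ``clone(node, variables=False)`` (so the
  original ``Variable`` objects are reused), then ``f(path, value)`` is applied
  to each inner value and finally to the root. If ``f`` returns a new object,
  the parent container is updated through its ``set_key``. Values that are
  referenced multiple times are mapped once and the result is reused, so shared
  references are preserved in the output.
  """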
node = clone(node, variables=False)
path_parts: PathParts = ()
visited: set[int] = set()
results: dict[int, tp.Any] = {}
return _recursive_map(f, node, path_parts, visited, results)


def _recursive_map(
f: tp.Callable[[PathParts, tp.Any], tp.Any],
node: tp.Any,
path: PathParts,
visited: set[int],
results: dict[int, tp.Any],
) -> tp.Any:
node_id = id(node)
if node_id in visited:
if node_id in results:
return results[node_id]
path_str = '/'.join(map(str, path))
raise ValueError(
f"Found cycle in the graph at path '{path_str}'. Node of type"
f' {type(node)} has already been visited but has not been returned yet.'
)
node_impl = get_node_impl(node)
if (
type(node_impl) is GraphNodeImpl
or isinstance(node, Variable)
or is_array_ref(node)
):
visited.add(node_id)
if node_impl is not None:
for key, value in node_impl.node_dict(node).items():
new_value = _recursive_map(f, value, (*path, key), visited, results)
      if new_value is not value:
        if node_impl.set_key is not None:
          node_impl.set_key(node, key, new_value)
else:
raise ValueError(
f"Cannot update key '{key}' for node of type '{type(node)}'"
' because the node does not support mutation.'
)

new_node = f(path, node)
results[node_id] = new_node
return new_node


def find_duplicates(node: tp.Any, /, *, only: filterlib.Filter = ...) -> list[list[PathParts]]:
"""Finds duplicate nodes or node leaves in the given node.

@@ -3099,10 +3148,14 @@ def _unflatten_pytree(

# common pytrees
# list
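# `set_key` lets `recursive_map` write a mapped child back into the container
# in place.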
def _list_set_key(x: list[tp.Any], key: int, value: tp.Any):
x[key] = value

register_pytree_node_type(
list,
flatten=lambda x: (list(enumerate(x)), None),
unflatten=lambda nodes, _: [value for _, value in nodes], # type: ignore
set_key=_list_set_key,
)
# tuple
register_pytree_node_type(
48 changes: 48 additions & 0 deletions flax/nnx/module.py
@@ -478,6 +478,54 @@ def eval(self, **attributes):
)


def set_mode(node: A, /, *, only: filterlib.Filter = ..., **kwargs) -> A:
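  """Returns a copy of ``node`` where ``set_mode(**kwargs)`` has been called on
  every node that defines a ``set_mode`` method and matches the ``only``
  filter. The copy shares the original Variables.
  """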
predicate = filterlib.to_predicate(only)

def _set_mode_fn(path, node):
if hasattr(node, 'set_mode') and predicate(path, node):
node.set_mode(**kwargs)
return node

return graph.recursive_map(_set_mode_fn, node)


def train_mode(node: A, /, *, only: filterlib.Filter = ..., **kwargs) -> A:
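  """Shorthand for :func:`set_mode` with ``train=True``, ``deterministic=False``
  and ``use_running_average=False``; extra keyword arguments are forwarded.
  """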
return set_mode(
node,
only=only,
train=True,
deterministic=False,
use_running_average=False,
**kwargs,
)


def eval_mode(node: A, /, *, only: filterlib.Filter = ..., **kwargs) -> A:
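  """Shorthand for :func:`set_mode` with ``train=False``, ``deterministic=True``
  and ``use_running_average=True``; extra keyword arguments are forwarded.
  """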
return set_mode(
node,
only=only,
train=False,
deterministic=True,
use_running_average=True,
**kwargs,
)


def set_attributes(
node: A, /, *, only: filterlib.Filter = ..., **attributes
) -> A:
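  """Returns a copy of ``node`` where every attribute in ``attributes`` is set
  on each node that matches the ``only`` filter and already defines an
  attribute of that name.
  """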
predicate = filterlib.to_predicate(only)

def _set_attributes_fn(path, node):
if predicate(path, node):
for name, value in attributes.items():
if hasattr(node, name):
setattr(node, name, value)
return node

return graph.recursive_map(_set_attributes_fn, node)


def first_from(*args: tp.Optional[A], error_msg: str) -> A:
"""Return the first non-None argument.

64 changes: 59 additions & 5 deletions flax/nnx/nn/attention.py
@@ -608,11 +608,64 @@ def __call__(
out = self.out(x)
return out

def set_mode(
self,
train: bool | None = None,
deterministic: bool | None = None,
decode: bool | None = None,
batch_size: int | Shape | None = None,
max_length: int | None = None,
**kwargs,
):
"""
Args:
train: if True, the module is set to training mode.
deterministic: if True, the module is set to deterministic mode.
decode: if True, the module is set to decode mode.
batch_size: the batch size to use for the cache.
max_length: the max length to use for the cache.
"""
if deterministic is not None:
self.deterministic = deterministic
elif train is not None:
self.deterministic = not train

    if decode is not None:
      self.decode = decode
      if decode and (
        not hasattr(self, 'cached_key')
        or not hasattr(self, 'cached_value')
        or not hasattr(self, 'cache_index')
      ):
        if batch_size is None:
          raise TypeError(
            "'batch_size' must be provided when initializing the cache."
          )
        if max_length is None:
          raise TypeError(
            "'max_length' must be provided when initializing the cache."
          )
        self.init_cache2(batch_size, max_length, dtype=self.dtype)

def init_cache2(
self, batch_size: int | Shape, max_length: int, dtype: Dtype | None = None
):
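    """Initializes the decoding cache from explicit batch dimensions.

    A variant of :meth:`init_cache` that takes ``batch_size`` (an int or shape
    prefix) and ``max_length`` directly instead of deriving them from a full
    input shape.
    """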
if dtype is None:
dtype = self.dtype
if isinstance(batch_size, int):
batch_size = (batch_size,)

cache_shape = (*batch_size, max_length, self.num_heads, self.head_dim)
self.cached_key = nnx.Cache(jnp.zeros(cache_shape, dtype))
self.cached_value = nnx.Cache(jnp.zeros(cache_shape, dtype))
self.cache_index = nnx.Cache(jnp.array(0, dtype=jnp.int32))

def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32):
"""Initializes cache for fast autoregressive decoding. When
``decode=True``, this method must be called first before performing
forward inference. When in decode mode, only one token must be passed
at a time.
"""Initializes cache for fast autoregressive decoding.

When ``decode=True``, this method must be called first before performing
forward inference. When in decode mode, only one token must be passed at a
time.

Example usage::

@@ -632,7 +685,8 @@ def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32):
... rngs=nnx.Rngs(42),
... )
...
>>> # out_nnx = model_nnx(x) <-- throws an error because cache isn't initialized
    >>> # out_nnx = model_nnx(x) <-- throws an error because cache isn't initialized
...
>>> model_nnx.init_cache(x.shape)
>>> out_nnx = model_nnx(x)
13 changes: 12 additions & 1 deletion flax/nnx/nn/normalization.py
@@ -379,6 +379,17 @@ def __call__(
self.epsilon,
)

def set_mode(
self,
train: bool | None = None,
use_running_average: bool | None = None,
**kwargs,
):
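    """Sets ``use_running_average``; an explicit ``use_running_average`` value
    takes precedence, otherwise ``train`` maps to
    ``use_running_average = not train``. Additional keyword arguments are
    ignored.
    """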
if use_running_average is not None:
self.use_running_average = use_running_average
elif train is not None:
self.use_running_average = not train


class LayerNorm(Module):
"""Layer normalization (https://arxiv.org/abs/1607.06450).
@@ -832,4 +843,4 @@ def __call__(self, x, *, mask: tp.Optional[jax.Array] = None):
(self.feature_axis,),
self.dtype,
self.epsilon,
    )
10 changes: 9 additions & 1 deletion flax/nnx/nn/stochastic.py
@@ -73,7 +73,7 @@ def __init__(
rate: float,
*,
broadcast_dims: Sequence[int] = (),
    deterministic: bool | None = None,
rng_collection: str = 'dropout',
rngs: rnglib.Rngs | rnglib.RngStream | None = None,
):
@@ -153,3 +153,11 @@ def __call__(
mask = random.bernoulli(key, p=keep_prob, shape=broadcast_shape)
mask = jnp.broadcast_to(mask, inputs.shape)
return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))

def set_mode(
self,
deterministic: bool | None = None,
**kwargs,
):
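    """Sets ``deterministic`` if a value is given; additional keyword arguments
    are ignored.
    """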
if deterministic is not None:
self.deterministic = deterministic
47 changes: 47 additions & 0 deletions tests/nnx/graph_utils_test.py
@@ -1107,6 +1107,53 @@ def __init__(self, rngs: nnx.Rngs):
self.assertLen(duplicates, 1)
self.assertEqual(duplicates[0], [('a',), ('c',)])

  def test_recursive_map(self):
class Foo(nnx.Pytree):
def __init__(self, d):
self.d = d

foo1 = Foo(10)
foo2 = Foo(20)
bar = [foo1, foo2, foo1]
n = 0

def inc_d(path, node):
nonlocal n
if isinstance(node, Foo):
n += 1
node.d += 1
return node

bar2 = nnx.recursive_map(inc_d, bar)
self.assertIs(bar2[0], bar2[2])
self.assertEqual(bar2[0].d, 11)
self.assertEqual(bar2[1].d, 21)
self.assertEqual(n, 2)

  def test_recursive_map_replace(self):
class Foo(nnx.Pytree):
def __init__(self, d):
self.d = d

foo1 = Foo(10)
foo2 = Foo(20)
bar = [foo1, foo2, foo1]
n = 0

def swap(path, node):
nonlocal n
if isinstance(node, Foo):
n += 1
node = Foo(-node.d)
return node

bar2 = nnx.recursive_map(swap, bar)
self.assertIs(bar2[0], bar2[2])
self.assertEqual(bar2[0].d, -10)
self.assertEqual(bar2[1].d, -20)
self.assertEqual(n, 2)
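
  def test_set_mode_sketch(self):
    # A minimal sketch (not part of the original change): `eval_mode` should
    # set `deterministic=True` on `Dropout` through its `set_mode` method,
    # and `train_mode` should set it back to `False`.
    model = nnx.Dropout(0.5)
    eval_model = nnx.eval_mode(model)
    self.assertTrue(eval_model.deterministic)
    train_model = nnx.train_mode(model)
    self.assertFalse(train_model.deterministic)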


class SimpleModule(nnx.Module):
pass
