Skip to content

Commit 01b16e4

Browse files
committed
hijax Variable
1 parent e8e6572 commit 01b16e4

File tree

7 files changed

+587
-177
lines changed

7 files changed

+587
-177
lines changed

flax/nnx/extract.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import jax
1919

2020
from flax import struct
21+
from flax import typing
2122
from flax.nnx.pytreelib import Pytree
2223
from flax.typing import Missing, PathParts
2324
from flax.nnx import graph, variablelib
@@ -35,7 +36,7 @@ class PrefixMapping(abc.ABC):
3536
@abc.abstractmethod
3637
def map_prefix(
3738
self,
38-
path: variablelib.PathParts,
39+
path: typing.PathParts,
3940
variable: variablelib.Variable,
4041
/,
4142
) -> tp.Any: ...

flax/nnx/transforms/iteration.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import typing as tp
2020

2121
from flax import struct
22+
from flax import typing
2223
from flax.core.frozen_dict import FrozenDict
2324
from flax.nnx import extract, filterlib, graph, spmd, variablelib
2425
from flax.nnx import statelib
@@ -89,7 +90,7 @@ def axes(self) -> tuple[Index | type[Carry] | None, ...]:
8990
return self._axes
9091

9192
def map_prefix(
92-
self, path: variablelib.PathParts, variable: variablelib.Variable
93+
self, path: typing.PathParts, variable: variablelib.Variable
9394
) -> tp.Any:
9495
for filter, axis in zip(self.filters, self.axes):
9596
predicate = filterlib.to_predicate(filter)

0 commit comments

Comments
 (0)