Skip to content

Commit 251ac34

Browse files
Version bump 3.11.2 and nnx fix #21565 (#21570)
* Fix nnx object state (#21565) * Update operation.py * Update actions.yml * Update operation.py * Update actions.yml * Update operation.py * Update operation.py * Update operation.py * fix test * code reformat * Version bump to 3.11.2 --------- Co-authored-by: Divyashree Sreepathihalli <[email protected]>
1 parent 0e11071 commit 251ac34

File tree

3 files changed

+9
-6
lines changed

3 files changed

+9
-6
lines changed

.github/workflows/actions.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ jobs:
5757
run: |
5858
pip install -r requirements.txt --progress-bar off --upgrade
5959
if [ "${{ matrix.nnx_enabled }}" == "true" ]; then
60-
pip install --upgrade flax>=0.11.0
60+
pip install --upgrade flax>=0.11.1
6161
fi
6262
pip uninstall -y keras keras-nightly
6363
pip install -e "." --progress-bar off --upgrade

keras/src/ops/operation.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,11 @@ def __new__(cls, *args, **kwargs):
123123
if backend.backend() == "jax" and is_nnx_enabled():
124124
from flax import nnx
125125

126-
vars(instance)["_object__state"] = nnx.object.ObjectState()
126+
try:
127+
vars(instance)["_pytree__state"] = nnx.pytreelib.PytreeState()
128+
except AttributeError:
129+
vars(instance)["_object__state"] = nnx.object.ObjectState()
130+
127131
# Generate a config to be returned by default by `get_config()`.
128132
arg_names = inspect.getfullargspec(cls.__init__).args
129133
kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args)))
@@ -206,10 +210,9 @@ def __init__(self, arg1, arg2, **kwargs):
206210
207211
def get_config(self):
208212
config = super().get_config()
209-
config.update({{
210-
"arg1": self.arg1,
213+
config.update({"arg1": self.arg1,
211214
"arg2": self.arg2,
212-
}})
215+
})
213216
return config"""
214217
)
215218
)

keras/src/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from keras.src.api_export import keras_export
22

33
# Unique source of truth for the version number.
4-
__version__ = "3.11.1"
4+
__version__ = "3.11.2"
55

66

77
@keras_export("keras.version")

0 commit comments

Comments
 (0)