Skip to content

Commit 429f49e

Browse files
committed
Fixed deprecated .value usage failing CI tests
1 parent 09537f4 commit 429f49e

File tree

2 files changed

+19
-7
lines changed

2 files changed

+19
-7
lines changed

flax/nnx/variablelib.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1712,25 +1712,25 @@ def __contains__(self, item) -> bool:
17121712

17131713
def __eq__(self, other) -> bool:
17141714
if isinstance(other, Variable):
1715-
other = other.value
1716-
return self.value.__eq__(other) # type: ignore
1715+
other = other.get_value()
1716+
return self.get_value().__eq__(other) # type: ignore
17171717

17181718
def __iadd__(self: V, other) -> V:
17191719
raise NotImplementedError(
17201720
'In-place operations are no longer supported for Variable.\n'
1721-
'Use `variable.value += x` instead.'
1721+
'Use `variable[...] += x` instead.'
17221722
)
17231723

17241724
def __isub__(self: V, other) -> V:
17251725
raise NotImplementedError(
17261726
'In-place operations are no longer supported for Variable.\n'
1727-
'Use `variable.value -= x` instead.'
1727+
'Use `variable[...] -= x` instead.'
17281728
)
17291729

17301730
def __imul__(self: V, other) -> V:
17311731
raise NotImplementedError(
17321732
'In-place operations are no longer supported for Variable.\n'
1733-
'Use `variable.value *= x` instead.'
1733+
'Use `variable[...] *= x` instead.'
17341734
)
17351735

17361736
def __imatmul__(self: V, other) -> V:

tests/nnx/variable_test.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@
1818
import jax.numpy as jnp
1919
import pytest
2020

21-
from absl.testing import absltest
21+
from absl.testing import absltest, parameterized
2222
from flax import nnx
2323

2424
A = tp.TypeVar('A')
2525

26-
class TestVariable(absltest.TestCase):
26+
class TestVariable(parameterized.TestCase):
2727
def test_pytree(self):
2828
r1 = nnx.Param(1)
2929
self.assertEqual(r1.get_value(), 1)
@@ -95,6 +95,18 @@ def test_binary_ops(self):
9595

9696
self.assertEqual(v1[...], 5)
9797

98+
@parameterized.product(
99+
v1=[jnp.array([1, 2]), jnp.array(2), 3],
100+
v2=[jnp.array([1, 2]), jnp.array(2), 3],
101+
)
102+
def test_eq_op(self, v1, v2):
103+
p1 = nnx.Param(v1)
104+
p2 = nnx.Param(v2)
105+
if isinstance(v1, jax.Array) or isinstance(v2, jax.Array):
106+
self.assertEqual((p1 == p2).all(), (v1 == v2).all())
107+
else:
108+
self.assertEqual(p1 == p2, v1 == v2)
109+
98110
def test_mutable_array_context(self):
99111
initial_mode = nnx.using_hijax()
100112
with nnx.use_hijax(False):

0 commit comments

Comments
 (0)