Skip to content

Commit 00cf44e

Browse files
committed
Fixed deprecated .value usage failing CI tests
1 parent 09537f4 commit 00cf44e

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

flax/nnx/variablelib.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1712,25 +1712,25 @@ def __contains__(self, item) -> bool:
17121712

17131713
def __eq__(self, other) -> bool:
17141714
if isinstance(other, Variable):
1715-
other = other.value
1716-
return self.value.__eq__(other) # type: ignore
1715+
other = other.get_value()
1716+
return self.get_value().__eq__(other) # type: ignore
17171717

17181718
def __iadd__(self: V, other) -> V:
17191719
raise NotImplementedError(
17201720
'In-place operations are no longer supported for Variable.\n'
1721-
'Use `variable.value += x` instead.'
1721+
'Use `variable[...] += x` instead.'
17221722
)
17231723

17241724
def __isub__(self: V, other) -> V:
17251725
raise NotImplementedError(
17261726
'In-place operations are no longer supported for Variable.\n'
1727-
'Use `variable.value -= x` instead.'
1727+
'Use `variable[...] -= x` instead.'
17281728
)
17291729

17301730
def __imul__(self: V, other) -> V:
17311731
raise NotImplementedError(
17321732
'In-place operations are no longer supported for Variable.\n'
1733-
'Use `variable.value *= x` instead.'
1733+
'Use `variable[...] *= x` instead.'
17341734
)
17351735

17361736
def __imatmul__(self: V, other) -> V:

tests/nnx/variable_test.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,15 @@
1616

1717
import jax
1818
import jax.numpy as jnp
19+
import numpy as np
1920
import pytest
2021

21-
from absl.testing import absltest
22+
from absl.testing import absltest, parameterized
2223
from flax import nnx
2324

2425
A = tp.TypeVar('A')
2526

26-
class TestVariable(absltest.TestCase):
27+
class TestVariable(parameterized.TestCase):
2728
def test_pytree(self):
2829
r1 = nnx.Param(1)
2930
self.assertEqual(r1.get_value(), 1)
@@ -95,6 +96,18 @@ def test_binary_ops(self):
9596

9697
self.assertEqual(v1[...], 5)
9798

99+
@parameterized.product(
100+
v1=[np.array([1, 2]), np.array(2), 3],
101+
v2=[np.array([1, 2]), np.array(2), 3],
102+
)
103+
def test_eq_op(self, v1, v2):
104+
p1 = nnx.Param(jnp.asarray(v1) if isinstance(v1, np.ndarray) else v1)
105+
p2 = nnx.Param(jnp.asarray(v2) if isinstance(v2, np.ndarray) else v2)
106+
if isinstance(v1, np.ndarray) or isinstance(v2, np.ndarray):
107+
self.assertEqual((p1 == p2).all(), (v1 == v2).all())
108+
else:
109+
self.assertEqual(p1 == p2, v1 == v2)
110+
98111
def test_mutable_array_context(self):
99112
initial_mode = nnx.using_hijax()
100113
with nnx.use_hijax(False):

0 commit comments

Comments
 (0)