Skip to content

Commit f400015

Browse files
IvyZXFlax Authors
authored andcommitted
Copybara import of the project:
-- ff85248 by IvyZX <[email protected]>: Correctly expose flax.config.temp_flip_flag PiperOrigin-RevId: 811006892
1 parent 8bb5784 commit f400015

File tree

6 files changed

+43
-76
lines changed

6 files changed

+43
-76
lines changed

docs/api_reference/flax.traverse_util.rst

Lines changed: 0 additions & 53 deletions
This file was deleted.

docs/api_reference/flax.config.rst renamed to docs_nnx/api_reference/flax.config.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ flax.config package
55
.. automodule:: flax.configurations
66
:members:
77
:undoc-members:
8-
:exclude-members: FlagHolder, bool_flag, temp_flip_flag, static_bool_env
8+
:exclude-members: FlagHolder, bool_flag, static_bool_env
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
flax.traverse_util package
2+
============================
3+
4+
.. currentmodule:: flax.traverse_util
5+
6+
.. automodule:: flax.traverse_util
7+
8+
Dict utils
9+
------------
10+
11+
.. autofunction:: flatten_dict
12+
13+
.. autofunction:: unflatten_dict
14+
15+
.. autofunction:: path_aware_map

flax/configurations.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,21 @@ def __repr__(self):
7272
values_repr = ', '.join(f'\n {k}={v!r}' for k, v in self._values.items())
7373
return f'Config({values_repr}\n)'
7474

75+
@contextmanager
76+
def temp_flip_flag(self, var_name: str, var_value: bool):
77+
"""Context manager to temporarily flip feature flags for test functions.
78+
79+
Args:
80+
var_name: the config variable name (without the 'flax_' prefix)
81+
var_value: the boolean value to set var_name to temporarily
82+
"""
83+
old_value = getattr(self, f'flax_{var_name}')
84+
try:
85+
self.update(f'flax_{var_name}', var_value)
86+
yield
87+
finally:
88+
self.update(f'flax_{var_name}', old_value)
89+
7590

7691
config = Config()
7792

@@ -207,22 +222,6 @@ def static_int_env(varname: str, default: int | None) -> int | None:
207222
) from None
208223

209224

210-
@contextmanager
211-
def temp_flip_flag(var_name: str, var_value: bool):
212-
"""Context manager to temporarily flip feature flags for test functions.
213-
214-
Args:
215-
var_name: the config variable name (without the 'flax_' prefix)
216-
var_value: the boolean value to set var_name to temporarily
217-
"""
218-
old_value = getattr(config, f'flax_{var_name}')
219-
try:
220-
config.update(f'flax_{var_name}', var_value)
221-
yield
222-
finally:
223-
config.update(f'flax_{var_name}', old_value)
224-
225-
226225
# Flax Global Configuration Variables:
227226

228227
flax_filter_frames = bool_flag(

tests/configurations_test.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,30 +23,36 @@ class MyTestCase(absltest.TestCase):
2323
def setUp(self):
2424
super().setUp()
2525
self.enter_context(mock.patch.object(config, '_values', {}))
26-
self._flag = bool_flag('test', default=False, help='Just a test flag.')
26+
self._flag = bool_flag('flax_test', default=False, help='Just a test flag.')
2727

2828
def test_duplicate_flag(self):
2929
with self.assertRaisesRegex(RuntimeError, 'already defined'):
3030
bool_flag(self._flag.name, default=False, help='Another test flag.')
3131

3232
def test_default(self):
3333
self.assertFalse(self._flag.value)
34-
self.assertFalse(config.test)
34+
self.assertFalse(config.flax_test)
3535

3636
def test_typed_update(self):
3737
config.update(self._flag, True)
3838
self.assertTrue(self._flag.value)
39-
self.assertTrue(config.test)
39+
self.assertTrue(config.flax_test)
4040

4141
def test_untyped_update(self):
4242
config.update(self._flag.name, True)
4343
self.assertTrue(self._flag.value)
44-
self.assertTrue(config.test)
44+
self.assertTrue(config.flax_test)
4545

4646
def test_update_unknown_flag(self):
4747
with self.assertRaisesRegex(LookupError, 'Unrecognized config option'):
4848
config.update('unknown', True)
4949

50+
def test_temp_flip(self):
51+
self.assertFalse(self._flag.value)
52+
with config.temp_flip_flag('test', True):
53+
self.assertTrue(self._flag.value)
54+
self.assertFalse(self._flag.value)
55+
5056

5157
if __name__ == '__main__':
5258
absltest.main()

tests/core/core_scope_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from jax import random
2020

2121
from flax import errors
22-
from flax.configurations import temp_flip_flag
22+
from flax import config
2323
from flax.core import Scope, apply, freeze, init, lazy_init, nn, scope
2424
from flax.core.scope import LazyRng
2525

@@ -283,7 +283,7 @@ def f(scope, x):
283283
with self.assertRaises(errors.LazyInitError):
284284
init_fn(random.key(0), jax.ShapeDtypeStruct((8, 4), jnp.float32))
285285

286-
@temp_flip_flag('fix_rng_separator', True)
286+
@config.temp_flip_flag('fix_rng_separator', True)
287287
def test_fold_in_static_seperator(self):
288288
x = LazyRng(random.key(0), ('ab', 'c'))
289289
y = LazyRng(random.key(0), ('a', 'bc'))

0 commit comments

Comments
 (0)