Skip to content

Commit 6611d4b

Browse files
committed
Merge branch 'main' into sharding_metadata
2 parents 1cc5511 + 5109e2c commit 6611d4b

31 files changed

+2957
-1711
lines changed

.github/workflows/flax_test.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ jobs:
8080
uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
8181
- name: Test importing Flax
8282
run: |
83-
uv run python -c "import flax"
83+
uv run --no-sync python -c "import flax"
8484
8585
tests:
8686
name: Run Tests
@@ -127,20 +127,20 @@ jobs:
127127
uv pip install -U git+https://github.com/google-deepmind/dm-haiku.git
128128
# temporary: install jax nightly
129129
uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
130-
uv run tests/run_all_tests.sh --only-doctest
130+
uv run --no-sync tests/run_all_tests.sh --only-doctest
131131
elif [[ "${{ matrix.test-type }}" == "pytest" ]]; then
132132
uv pip install -U tensorflow-datasets
133133
# temporary: install jax nightly
134134
uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
135-
uv run tests/run_all_tests.sh --only-pytest
135+
uv run --no-sync tests/run_all_tests.sh --only-pytest
136136
elif [[ "${{ matrix.test-type }}" == "pytype" ]]; then
137-
# temporary: install jax nightly
137+
# temporary: install jax nightly
138138
uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
139-
uv run tests/run_all_tests.sh --only-pytype
139+
uv run --no-sync tests/run_all_tests.sh --only-pytype
140140
elif [[ "${{ matrix.test-type }}" == "mypy" ]]; then
141141
# temporary: install jax nightly
142142
uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
143-
uv run tests/run_all_tests.sh --only-mypy
143+
uv run --no-sync tests/run_all_tests.sh --only-mypy
144144
else
145145
echo "Unknown test type: ${{ matrix.test-type }}"
146146
exit 1

docs_nnx/api_reference/flax.nnx/graph.rst

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,11 @@ graph
3131

3232
.. autofunction:: find_duplicates
3333
.. autofunction:: pure
34-
.. autofunction:: to_refs
35-
.. autofunction:: to_arrays
34+
.. autofunction:: as_immutable_vars
35+
.. autofunction:: as_mutable_vars
36+
.. autofunction:: as_hijax_vars
37+
.. autofunction:: as_pytree_vars
38+
.. autofunction:: as_ref_vars
39+
.. autofunction:: as_array_vars
3640
.. autofunction:: flatten
3741
.. autofunction:: unflatten

0 commit comments

Comments
 (0)