Skip to content

Commit b424c49

Browse files
committed
update jax minver to 0.8.1
1 parent 09537f4 commit b424c49

File tree

2 files changed

+1
-11
lines changed

2 files changed

+1
-11
lines changed

.github/workflows/flax_test.yml

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,6 @@ jobs:
7676
- name: Install standalone dependencies only
7777
run: |
7878
uv sync
79-
# temporary: install jax nightly
80-
uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
8179
- name: Test importing Flax
8280
run: |
8381
uv run --no-sync python -c "import flax"
@@ -125,21 +123,13 @@ jobs:
125123
if [[ "${{ matrix.test-type }}" == "doctest" ]]; then
126124
# TODO(cgarciae): Remove this once dm-haiku 0.0.14 is released
127125
uv pip install -U git+https://github.com/google-deepmind/dm-haiku.git
128-
# temporary: install jax nightly
129-
uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
130126
uv run --no-sync tests/run_all_tests.sh --only-doctest
131127
elif [[ "${{ matrix.test-type }}" == "pytest" ]]; then
132128
uv pip install -U tensorflow-datasets
133-
# temporary: install jax nightly
134-
uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
135129
uv run --no-sync tests/run_all_tests.sh --only-pytest
136130
elif [[ "${{ matrix.test-type }}" == "pytype" ]]; then
137-
# temporary: install jax nightly
138-
uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
139131
uv run --no-sync tests/run_all_tests.sh --only-pytype
140132
elif [[ "${{ matrix.test-type }}" == "mypy" ]]; then
141-
# temporary: install jax nightly
142-
uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
143133
uv run --no-sync tests/run_all_tests.sh --only-mypy
144134
else
145135
echo "Unknown test type: ${{ matrix.test-type }}"

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ dependencies = [
1414
"numpy>=1.23.2; python_version>='3.11'",
1515
"numpy>=1.26.0; python_version>='3.12'",
1616
# keep in sync with jax-version in .github/workflows/build.yml
17-
"jax>=0.7.1",
17+
"jax>=0.8.1",
1818
"msgpack",
1919
"optax",
2020
"orbax-checkpoint",

0 commit comments

Comments
 (0)