diff --git a/.github/workflows/flax_test.yml b/.github/workflows/flax_test.yml index af6a285c6..58f99a0a7 100644 --- a/.github/workflows/flax_test.yml +++ b/.github/workflows/flax_test.yml @@ -76,8 +76,6 @@ jobs: - name: Install standalone dependencies only run: | uv sync - # temporary: install jax nightly - uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ - name: Test importing Flax run: | uv run --no-sync python -c "import flax" @@ -125,21 +123,13 @@ jobs: if [[ "${{ matrix.test-type }}" == "doctest" ]]; then # TODO(cgarciae): Remove this once dm-haiku 0.0.14 is released uv pip install -U git+https://github.com/google-deepmind/dm-haiku.git - # temporary: install jax nightly - uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ uv run --no-sync tests/run_all_tests.sh --only-doctest elif [[ "${{ matrix.test-type }}" == "pytest" ]]; then uv pip install -U tensorflow-datasets - # temporary: install jax nightly - uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ uv run --no-sync tests/run_all_tests.sh --only-pytest elif [[ "${{ matrix.test-type }}" == "pytype" ]]; then - # temporary: install jax nightly - uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ uv run --no-sync tests/run_all_tests.sh --only-pytype elif [[ "${{ matrix.test-type }}" == "mypy" ]]; then - # temporary: install jax nightly - uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ uv run --no-sync tests/run_all_tests.sh --only-mypy else echo "Unknown test type: ${{ matrix.test-type }}" diff --git a/pyproject.toml b/pyproject.toml index ebf5dcd90..56596a540 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ dependencies = [ "numpy>=1.23.2; python_version>='3.11'", "numpy>=1.26.0; python_version>='3.12'", # keep in sync with jax-version in .github/workflows/build.yml - "jax>=0.7.1", + "jax>=0.8.1", "msgpack", "optax", "orbax-checkpoint",