Skip to content

Revert "use Pytree as base type in nnx_wrappers"#3

Merged
gulsumgudukbay merged 3 commits into
mainfrom
new_src_layout_modifications
Sep 23, 2025
Merged

Revert "use Pytree as base type in nnx_wrappers"#3
gulsumgudukbay merged 3 commits into
mainfrom
new_src_layout_modifications

Conversation

@gulsumgudukbay

@gulsumgudukbay gulsumgudukbay commented Sep 23, 2025

Copy link
Copy Markdown
Collaborator

Description

This reverts commit 8840296. This commit was using Pytree as base type in nnx_wrappers. Currently, with JAX 0.6.0 on ROCm, the latest flax version supported is 0.10.7, hence this commit is not valid and breaks the execution.

Note: This commit should be reverted when flax version is upgraded to anything >= 0.11.0. Right now, JAX 0.6.0 does not support any flax version >= 0.10.7, hence this commit is required as flax started exposing PyTree after 0.11.0.

Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.

Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.

Tests

I ran train.py with the following args:
python -m MaxText.train MaxText/configs/base.yml run_name=test hardware=gpu steps=5 model_name=llama2-7b attention=cudnn_flash_te enable_checkpointing=False ici_expert_parallelism=1 ici_fsdp_parallelism=-1 ici_data_parallelism=1 remat_policy=minimal scan_layers=True dataset_type=synthetic logits_dot_in_fp32=False dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_target_length=2048 shardy=False

It was able to successfully complete execution.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

gulsumgudukbay and others added 3 commits September 23, 2025 05:54
This reverts commit 8840296.

Note: This commit should be reverted when flax version is upgraded
to anything >= 0.11.0. Right now, JAX 0.6.0 does not support any flax
version >= 0.10.7, hence this commit is required as flax started
exposing PyTree after 0.11.0.
@gulsumgudukbay gulsumgudukbay merged commit 3e0b587 into main Sep 23, 2025
2 of 3 checks passed
@gulsumgudukbay gulsumgudukbay deleted the new_src_layout_modifications branch September 23, 2025 06:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant