
[BUG] Megatron Engine fails to save optimizer checkpoint #1341

@daihaowz

Description

Checklist

  • The error occurs when using our provided Docker image.
  • I can consistently reproduce the bug across multiple trials or random seeds.
  • If the error causes experiment abortion, I've verified that this error is the root
    cause, not a secondary error caused by peer workers.

Detailed Information

Describe the bug

When MegatronEngine saves a recovery checkpoint with optimizer state enabled (the default), MegatronCheckpointManager.generate_state_dict() calls self.optimizer.sharded_state_dict(state_dict) without passing metadata, so the distributed optimizer falls back to its default fully_sharded_model_space sharding strategy.

This strategy sets flattened_range=item_slice (distrib_optimizer.py:1760), but megatron-core 0.16.0 added a blanket rejection of any non-None flattened_range in ShardedTensor.validate_metadata_integrity() (mapping.py:134), so serializing the optimizer state raises a CheckpointingException.

This is an internal inconsistency in megatron-core 0.16.0+: the default sharding strategy uses a feature that the same version explicitly forbids.
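A minimal toy model of this mismatch may help show why the default path fails while dp_reshardable does not. Everything in it (ToyShardedTensor, make_param_state_tensor, try_validate, the "exp_avg" key) is hypothetical: it only mimics the behavior this issue reports for distrib_optimizer.py and mapping.py and does not import or reproduce megatron-core code.

```python
# Toy model of the reported inconsistency. None of these names are
# megatron-core APIs; they only mimic the behavior described in this issue.
from dataclasses import dataclass
from typing import Optional, Tuple


class CheckpointingException(Exception):
    """Stand-in for megatron.core.dist_checkpointing.core.CheckpointingException."""


@dataclass
class ToyShardedTensor:
    key: str
    flattened_range: Optional[slice] = None

    def validate_metadata_integrity(self) -> None:
        # Mirrors the blanket check reported at mapping.py:134 in 0.16.0+.
        if self.flattened_range is not None:
            raise CheckpointingException("ShardedTensor.flattened_range is not supported.")


def make_param_state_tensor(sharding_type: str, item_slice: slice) -> ToyShardedTensor:
    # Per this issue: fully_sharded_model_space sets flattened_range=item_slice,
    # while dp_reshardable leaves flattened_range as None.
    if sharding_type == "fully_sharded_model_space":
        return ToyShardedTensor("exp_avg", flattened_range=item_slice)
    if sharding_type == "dp_reshardable":
        return ToyShardedTensor("exp_avg", flattened_range=None)
    raise ValueError(f"unknown sharding type: {sharding_type}")


def try_validate(sharding_type: str) -> Tuple[bool, str]:
    # Returns (passed, error message) for one toy parameter-state tensor.
    tensor = make_param_state_tensor(sharding_type, slice(0, 128))
    try:
        tensor.validate_metadata_integrity()
    except CheckpointingException as exc:
        return False, str(exc)
    return True, ""


if __name__ == "__main__":
    for st in ("fully_sharded_model_space", "dp_reshardable"):
        print(st, try_validate(st))
```

Running it reports a failure for fully_sharded_model_space and a pass for dp_reshardable, the same split the traceback below shows for the real code path.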

Root cause location: areal/engine/megatron_utils/checkpointer.py:283

optimizer_sharded_states = self.optimizer.sharded_state_dict(state_dict)

Introduced by: bf9b3c3be (2026-03-17), which bumped megatron-core from 0.13.1 to 0.16.0 without adapting the checkpointer. Commit 58304de84 (2026-03-23) added a workaround (the no_save_optim flag), but the root cause was never fixed.

Expected behavior

Recovery checkpoints should save optimizer state (Adam momentum and variance) successfully, so training can resume with its optimizer state intact instead of losing it and having to restart warmup.

Full logs

Traceback (most recent call last):
  File ".../megatron/core/optimizer/distrib_optimizer.py", line 1766, in _get_param_state_sharded_tensors
    tensors[state_key].validate_metadata_integrity()
  File ".../megatron/core/dist_checkpointing/mapping.py", line 134, in validate_metadata_integrity
    raise CheckpointingException("ShardedTensor.flattened_range is not supported.")
megatron.core.dist_checkpointing.core.CheckpointingException: ShardedTensor.flattened_range is not supported.

To Reproduce

Commit ID

main branch

Environment

  • megatron-core >= 0.16.0 (pyproject.toml currently declares 0.17.0)
  • MegatronEngine with distributed optimizer enabled
  • RecoverConfig with default no_save_optim=False
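To check whether a given environment falls in the affected range, a small helper along these lines could be used. It assumes the distribution is installed under the name megatron-core and, per this report, treats any release at or above 0.16.0 as affected; pre-release suffixes are ignored, so 0.16.0rc1 counts as 0.16.0. The helper names are hypothetical.

```python
# Hypothetical helper: report whether the installed megatron-core is in the
# range (>= 0.16.0) that this issue describes as affected.
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple

AFFECTED_SINCE = (0, 16, 0)


def parse_release(ver: str) -> Tuple[int, ...]:
    # Keep only the leading numeric release segment, e.g. "0.16.0rc1" -> (0, 16, 0).
    match = re.match(r"(\d+)(?:\.(\d+))?(?:\.(\d+))?", ver)
    if match is None:
        raise ValueError(f"unrecognised version string: {ver!r}")
    return tuple(int(part or 0) for part in match.groups())


def is_affected(ver: str) -> bool:
    return parse_release(ver) >= AFFECTED_SINCE


def installed_megatron_core() -> Optional[str]:
    # Returns None when the distribution cannot be found in this environment.
    try:
        return version("megatron-core")
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    ver = installed_megatron_core()
    if ver is None:
        print("megatron-core is not installed")
    else:
        print(f"megatron-core {ver}: affected={is_affected(ver)}")
```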

Script

Any RL training script using MegatronEngine that triggers a recovery checkpoint save. For example:

recover:
  mode: auto
  no_save_optim: false  # default value, triggers the bug

Fix: Pass metadata={"distrib_optim_sharding_type": "dp_reshardable"} at checkpointer.py:283. The dp_reshardable strategy sets flattened_range=None and works correctly under 0.16.0+. Its only limitation is that the data-parallel (DP) configuration cannot change on resume, which is not a concern for recovery checkpoints.

Labels: bug (Something isn't working)