[cfg] fix: sync strategy from ActorConfig/CriticConfig to EngineConfig #5885

Merged
wuxibin89 merged 1 commit into verl-project:main from yifannnwu:fix/sync-strategy-to-engine-config
Apr 7, 2026

@yifannnwu
Contributor

Summary

FSDPActorConfig.__post_init__ and FSDPCriticConfig.__post_init__ set self.engine = self.fsdp_config but never sync self.strategy to self.engine.strategy. Since EngineConfig.strategy defaults to None, engine_workers.py:162 always passes None as the backend to EngineRegistry.new(), which falls back to FSDP1 regardless of the user's actor.strategy setting.

This causes crashes for models that require FSDP2, such as Qwen3.5 and other models with multi-dimensional RoPE position_ids, where FSDP1's parameter wrapping breaks apply_rotary_pos_emb with shape mismatches.

Repro

  1. Set actor.strategy=fsdp2 with use_legacy_worker_impl=disable (new engine_workers.py path)
  2. Train any model
  3. engine_workers.py reads engine_config.strategy → gets None → defaults to FSDP1
  4. For models with multi-dimensional position_ids (Qwen3.5, Qwen3-VL), FSDP1 wrapping breaks apply_rotary_pos_emb
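
The failure mode in the steps above can be illustrated with a toy model. This is a minimal sketch, not verl's code: `ToyEngineConfig`, `ToyActorConfig`, and `pick_backend` are hypothetical stand-ins, and the fallback from `None` to FSDP1 is modeled only as the summary describes it.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ToyEngineConfig:
    # As described in the summary: the engine-level strategy defaults to None.
    strategy: Optional[str] = None


@dataclass
class ToyActorConfig:
    strategy: str = "fsdp"
    fsdp_config: ToyEngineConfig = field(default_factory=ToyEngineConfig)
    engine: Optional[ToyEngineConfig] = None

    def __post_init__(self):
        # The engine is aliased to fsdp_config, but strategy is never copied over.
        self.engine = self.fsdp_config


def pick_backend(engine_config: ToyEngineConfig) -> str:
    # Hypothetical stand-in for the backend lookup: None falls back to FSDP1.
    return engine_config.strategy or "fsdp"


actor = ToyActorConfig(strategy="fsdp2")
backend = pick_backend(actor.engine)  # the user's "fsdp2" is silently ignored
```

In this sketch the user-facing `strategy` and the value the backend lookup reads are two separate fields, so nothing fails until a model that needs FSDP2 is actually trained.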

Fix

Sync strategy from the actor/critic config to the engine config in __post_init__:

object.__setattr__(self.engine, "strategy", self.strategy)

Uses object.__setattr__ because BaseConfig has frozen field logic that prevents normal attribute assignment.
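
Why `object.__setattr__` is needed can be sketched with a toy base class. The guard below is an assumption made for illustration (it rejects any reassignment of an already-set attribute); the real `BaseConfig` logic may differ, and all `Toy*` names are hypothetical.

```python
from dataclasses import dataclass, field
from typing import Optional


class ToyBaseConfig:
    """Hypothetical base class that rejects reassignment of existing fields."""

    def __setattr__(self, name, value):
        # The first assignment (during __init__) is allowed; later ones are not.
        if name in self.__dict__:
            raise AttributeError(f"field '{name}' is frozen")
        super().__setattr__(name, value)


@dataclass
class ToyEngineConfig(ToyBaseConfig):
    strategy: Optional[str] = None


@dataclass
class ToyActorConfig(ToyBaseConfig):
    strategy: str = "fsdp"
    fsdp_config: ToyEngineConfig = field(default_factory=ToyEngineConfig)

    def __post_init__(self):
        self.engine = self.fsdp_config  # first assignment, so the guard allows it
        # engine.strategy was already set to None by ToyEngineConfig.__init__,
        # so bypass the guard by calling object.__setattr__ directly.
        object.__setattr__(self.engine, "strategy", self.strategy)


actor = ToyActorConfig(strategy="fsdp2")

# A normal assignment on the engine config is still rejected by the guard.
try:
    actor.engine.strategy = "fsdp"
    blocked = False
except AttributeError:
    blocked = True
```

Calling `object.__setattr__` skips the class's own `__setattr__` override, which is why it works inside `__post_init__` while ordinary assignment does not.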

Affected configs

  • FSDPActorConfig (verl/workers/config/actor.py)
  • FSDPCriticConfig (verl/workers/config/critic.py)

Note: McoreActorConfig, VeOmniActorConfig, TorchTitanActorConfig are not affected because their engine configs have matching hardcoded strategy defaults.

Impact

Affects all FSDP2 training using the new engine_workers.py path (use_legacy_worker_impl=disable). The legacy worker path is unaffected because it doesn't read engine_config.strategy.

Test plan

  • Verified engine_config.strategy == "fsdp2" when actor.strategy = "fsdp2" with use_legacy_worker_impl=disable
  • Trained Qwen3.5-0.8B with strategy=fsdp2 + use_legacy_worker_impl=disable — FSDP2 correctly applied, 3 training steps completed
  • Confirmed strategy=fsdp (default) is unchanged

FSDPActorConfig and FSDPCriticConfig set self.engine = self.fsdp_config
but never sync self.strategy to self.engine.strategy. Since
EngineConfig.strategy defaults to None, engine_workers.py (the new
worker path used with use_legacy_worker_impl=disable) always falls back
to FSDP1 regardless of the user's actor.strategy setting.

This causes crashes for models that require FSDP2, such as Qwen3.5 and
other models with multi-dimensional RoPE position_ids, where FSDP1's
parameter wrapping breaks apply_rotary_pos_emb with shape mismatches.

Fix: sync strategy in __post_init__ using object.__setattr__ (needed
because BaseConfig has frozen field logic).
Contributor

@gemini-code-assist (bot) left a comment

Code Review

This pull request ensures that the FSDP strategy is correctly propagated to the engine configuration in both actor and critic workers, preventing an unintended fallback to FSDP1. A review comment suggests also syncing the ulysses_sequence_parallel_size in the critic configuration to maintain consistency and ensure sequence parallelism settings are properly applied.

# Sync strategy to engine config so engine_workers can pick the right FSDP version.
# EngineConfig.strategy defaults to None, so without this, engine_workers.py always
# falls back to FSDP1 even when critic.strategy="fsdp2".
object.__setattr__(self.engine, "strategy", self.strategy)

Severity: high

In addition to syncing the strategy, ulysses_sequence_parallel_size should also be synced to the engine configuration in FSDPCriticConfig for consistency and backward compatibility, similar to the implementation in FSDPActorConfig. Without this, sequence parallelism settings defined at the top level of the critic configuration will not propagate to the underlying FSDP engine.

Note that ulysses_sequence_parallel_size is already defined as a mutable field in FSDPEngineConfig, so direct assignment is permitted.

        object.__setattr__(self.engine, "strategy", self.strategy)

        # backward compatibility
        if self.ulysses_sequence_parallel_size > 1:
            self.fsdp.ulysses_sequence_parallel_size = self.ulysses_sequence_parallel_size
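
The backward-compatibility pattern in this suggestion can be sketched in isolation. This is an illustrative toy, assuming the engine config is reachable as `fsdp_config` and that a top-level value of 1 means "not set"; the class and attribute names are hypothetical, not verl's.

```python
from dataclasses import dataclass, field


@dataclass
class ToyFSDPEngineConfig:
    ulysses_sequence_parallel_size: int = 1


@dataclass
class ToyCriticConfig:
    ulysses_sequence_parallel_size: int = 1
    fsdp_config: ToyFSDPEngineConfig = field(default_factory=ToyFSDPEngineConfig)

    def __post_init__(self):
        # Propagate only when the top-level value was raised above the default,
        # so an engine-level setting is not overwritten by the default of 1.
        if self.ulysses_sequence_parallel_size > 1:
            self.fsdp_config.ulysses_sequence_parallel_size = (
                self.ulysses_sequence_parallel_size
            )


top_level = ToyCriticConfig(ulysses_sequence_parallel_size=4)
engine_level = ToyCriticConfig(fsdp_config=ToyFSDPEngineConfig(2))
```

The `> 1` guard is the design choice worth noting: it lets older configs that set the value at the top level keep working without clobbering configs that set it directly on the engine.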

@wuxibin89 changed the title from "fix: sync strategy from ActorConfig/CriticConfig to EngineConfig" to "[cfg] fix: sync strategy from ActorConfig/CriticConfig to EngineConfig" on Apr 7, 2026
@wuxibin89 merged commit 74dc16c into verl-project:main on Apr 7, 2026
59 of 69 checks passed