Commit d803065
feat: add explicit vlm config flag for VLM detection (#2063)
1 parent 19511dc commit d803065

11 files changed: +240 -118 lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -2,6 +2,7 @@
 
 Documenting changes which affect configuration usage patterns (added/moved/removed/renamed fields, notable logic changes).
 
+- **`[model.vlm]` (NEW — replaces auto-detection)**: VLM mode is now opt-in via a `[model.vlm]` sub-config with required `vision_encoder_attr` and `language_model_attr` fields. There is no auto-detection — if you train a VLM, you must add `[model.vlm]`. Existing multimodal configs need the new section. See `docs/multimodal.md` for the table of known model attrs. (2026-03-24)
 - **`model.optimization_dtype` / `model.reduce_dtype` (VLM models, RL only)**: VLM dtype validation now only applies to RL training (`TrainerConfig`), not SFT. VLM models used with `sft` no longer require `optimization_dtype='bfloat16'` / `reduce_dtype='bfloat16'`. RL training still enforces both to match vLLM inference. (2026-03-24)
 - **`model.optimization_dtype` / `model.reduce_dtype` (VLM models)**: Added validation that VLM models must use `optimization_dtype='bfloat16'` and `reduce_dtype='bfloat16'` to match vLLM inference. Previously valid configs with `float32` (the default) are now rejected for VLM model names. Set both fields to `"bfloat16"` when training VLMs. (2026-03-21)
 - **`orchestrator.advantage.length_weighted_mean`**: Removed. The default advantage now always uses the plain per-problem mean baseline unless `orchestrator.advantage.length_shaping_alpha` is set. Existing configs must delete this field. (2026-03-19)
```

configs/ci/nightly/multimodal_color_codeword.toml

Lines changed: 4 additions & 0 deletions

```diff
@@ -9,6 +9,10 @@ num_infer_gpus = 4
 [model]
 name = "Qwen/Qwen3-VL-4B-Instruct"
 
+[model.vlm]
+vision_encoder_attr = "model.visual"
+language_model_attr = "model.language_model"
+
 [orchestrator]
 batch_size = 256
 rollouts_per_example = 16
```

configs/multimodal/rl_color_codeword.toml

Lines changed: 4 additions & 0 deletions

```diff
@@ -4,6 +4,10 @@ seq_len = 4096
 [model]
 name = "Qwen/Qwen3-VL-4B-Instruct"
 
+[model.vlm]
+vision_encoder_attr = "model.visual"
+language_model_attr = "model.language_model"
+
 [orchestrator]
 batch_size = 256
 rollouts_per_example = 16
```
New file

Lines changed: 35 additions & 0 deletions

```diff
@@ -0,0 +1,35 @@
+max_steps = 3
+seq_len = 2048
+output_dir = "outputs/rl_color_codeword_test"
+
+[model]
+name = "Qwen/Qwen3-VL-4B-Instruct"
+
+[model.vlm]
+vision_encoder_attr = "model.visual"
+language_model_attr = "model.language_model"
+
+[orchestrator]
+batch_size = 16
+rollouts_per_example = 2
+
+[orchestrator.sampling]
+max_tokens = 32
+
+[[orchestrator.env]]
+id = "color-codeword"
+args = { images_per_turn = 1, max_turns = 2, num_examples = 100, seed = 42 }
+
+[trainer]
+
+[trainer.model]
+optimization_dtype = "bfloat16"
+reduce_dtype = "bfloat16"
+
+[trainer.optim]
+lr = 3e-6
+
+[inference]
+
+[inference.parallel]
+dp = 1
```

docs/multimodal.md

Lines changed: 43 additions & 16 deletions

````diff
@@ -1,35 +1,62 @@
 # Multimodal (VLM) Support
 
-Prime-RL has experimental support for training vision-language models (VLMs) like Qwen3-VL.
+Prime-RL supports training vision-language models (VLMs) like Qwen3-VL.
 
-## Current Limitations
+## VLM Configuration
+
+### Supported Models
+
+The built-in registry supports these model families out of the box:
+
+| Model Family | model_type | Vision Encoder | Language Model |
+|-------------|------------|---------------|----------------|
+| Qwen3-VL | `qwen3_vl` | `model.visual` | `model.language_model` |
+| Qwen3.5 | `qwen3_5` | `model.visual` | `model.language_model` |
+| Qwen3.5-MoE | `qwen3_5_moe` | `model.visual` | `model.language_model` |
+
+Enable VLM mode by adding a `[model.vlm]` section. Both fields are required — they tell prime-rl where the vision encoder and language model live on the model object:
+
+```toml
+[model]
+name = "Qwen/Qwen3-VL-4B-Instruct"
+
+[model.vlm]
+vision_encoder_attr = "model.visual"
+language_model_attr = "model.language_model"
+```
 
-- **No SFT support**: Supervised fine-tuning is not yet supported for VLM models. Only RL training is available.
+For the registered models in the table above, use the attrs shown there. For custom VLMs, check your model's structure with `model.named_children()`.
+
+Both fields are dotted attribute paths resolved on the loaded model. A bad path raises a `ValueError` immediately — there are no silent fallbacks.
+
+The weight key prefix for NCCL broadcasting is derived automatically as `{language_model_attr}.layers.`.
+
+To add permanent support for a new model family, add an entry to `VLM_REGISTRY` in `src/prime_rl/utils/vlm.py`.
+
+## Current Limitations
 
 - **Vision encoder is frozen**: The vision encoder is automatically frozen during training. Only the language model is trained.
 
-- **No multimodal-safe truncation**: Token sequences are truncated to `seq_len`, but `pixel_values` and `image_grid_thw` are passed through unchanged. If a multimodal sample exceeds `seq_len`, image tokens can be dropped while image tensors still describe the full set of images. Ensure `seq_len` covers your longest VLM samples or avoid overlong rollouts.
+- **No multimodal-safe truncation**: Token sequences are truncated to `seq_len`, but `pixel_values` and `image_grid_thw` are passed through unchanged. If a multimodal sample exceeds `seq_len`, image tokens can be dropped while image tensors still describe the full set of images. Ensure `seq_len` covers your longest VLM samples.
 
-- **The images that the VLM sees are not logged**
+- **Optimization dtype must be bfloat16**: Set `optimization_dtype = "bfloat16"` and `reduce_dtype = "bfloat16"` in your trainer config.
 
-- **Optimization dtype must be bfloat16**: VLM models must load in bfloat16 to match vLLM inference. If the trainer uses a different dtype, the vision encoder produces different `pixel_values`, causing a mismatch between inference and training. A workaround would be to propagate the `pixel_values` computed by vLLM to the trainer, but this is more involved. For now, set `optimization_dtype = "bfloat16"` and `reduce_dtype = "bfloat16"` in your trainer config.
+- **Higher KL mismatch with multi-image inputs**: VLM training exhibits higher KL mismatch compared to text-only, especially with multiple images.
 
-- **Higher KL mismatch with multi-image inputs**: VLM training exhibits higher KL mismatch between inference and trainer logprobs compared to text-only models, especially with multiple images per sample. We are investigating the root cause. The existing importance ratio masking thresholds should handle reasonable mismatches.
+- **Images are not logged**: The images the VLM sees during training are not logged to monitors.
 
-## How Multi-Turn VLM Training Works
+## How Multi-Turn VLM RL Training Works
 
-VLM training uses the same `interleave_rollout` path as text-only models. Multi-turn trajectory steps are merged into a single training sample wherever the extension property holds (consecutive steps share a token prefix). When extension breaks (e.g., due to context compaction), a new sample is started automatically.
+VLM training uses the same `interleave_rollout` path as text-only models. Multi-turn trajectory steps are merged into a single training sample wherever the extension property holds.
 
 Images are handled via a `VLMImageCache` built once per batch:
 
-1. **Extract**: Base64 images are decoded from trajectory step prompts into PIL images. Since prompts are cumulative, only new images per step are extracted.
-2. **Preprocess**: All images are processed in a single batched call through the HuggingFace image processor, producing `pixel_values` (patches) and `image_grid_thw` (grid dimensions).
-3. **Attach**: Each training sample receives the cumulative `pixel_values` up to its last merged step. When steps are merged, the sample's images are updated to include all images seen so far.
-
-This works correctly for all combinations: images in early turns with text-only follow-ups, images appearing mid-conversation, new images accumulating across turns, and interleaved agents with separate image streams.
+1. **Extract**: Base64 images are decoded from trajectory step prompts into PIL images.
+2. **Preprocess**: Images are processed through the HuggingFace image processor, producing `pixel_values` and `image_grid_thw`.
+3. **Attach**: Each training sample receives the cumulative `pixel_values` up to its last merged step.
 
-Each multimodal sample becomes its own micro-batch during training (no packing with other samples) since image tensor sizes vary per sample.
+Each multimodal sample becomes its own micro-batch during training (no packing) since image tensor sizes vary.
 
 ## vLLM Configuration
 
-`VLLM_WORKER_MULTIPROC_METHOD=spawn` is required for VLM inference. This is set automatically in `src/prime_rl/inference/config.py`, so if you use `uv run rl @ ...` it works out of the box, but if you start the vLLM server yourself, make sure this environment variable is set.
+`VLLM_WORKER_MULTIPROC_METHOD=spawn` is required for VLM inference. This is set automatically when using `uv run rl @ ...`, but if you start the vLLM server yourself, make sure this environment variable is set.
````
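The docs above describe `vision_encoder_attr` and `language_model_attr` as dotted attribute paths resolved on the loaded model, with no silent fallbacks. A resolver for such paths could look like the following sketch (plain Python, no torch; `resolve_attr_path` and `_Stub` are illustrative names, not the repo's actual helpers in `src/prime_rl/utils/vlm.py`):

```python
from functools import reduce


def resolve_attr_path(obj, path: str):
    """Resolve a dotted attribute path like 'model.visual' on an object.

    Returns None if any attribute along the path is missing, so callers
    can raise a ValueError instead of silently falling back.
    """
    try:
        return reduce(getattr, path.split("."), obj)
    except AttributeError:
        return None


class _Stub:
    pass


# Build a tiny object graph shaped like a Qwen3-VL model: root.model.visual
root = _Stub()
root.model = _Stub()
root.model.visual = "vision-encoder"

print(resolve_attr_path(root, "model.visual"))   # vision-encoder
print(resolve_attr_path(root, "model.missing"))  # None
```

The weight key prefix mentioned above would then be a simple string derivation from the same path, e.g. `f"{language_model_attr}.layers."`.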

src/prime_rl/configs/rl.py

Lines changed: 12 additions & 0 deletions

```diff
@@ -19,6 +19,7 @@
 )
 from prime_rl.configs.shared import (
     SlurmConfig,
+    VLMConfig,
     WandbConfig,
     WandbWithExtrasConfig,
 )
@@ -127,6 +128,11 @@ class SharedModelConfig(BaseConfig):
         Field(description="The name of the model to use."),
     ] = "Qwen/Qwen3-0.6B"
 
+    vlm: Annotated[
+        "VLMConfig | None",
+        Field(description="VLM configuration. Set to enable vision-language model support."),
+    ] = None
+
 
 class SharedWeightBroadcastConfig(BaseConfig):
     """Configures shared weight broadcast settings."""
@@ -520,6 +526,12 @@ def auto_setup_model(self):
         else:
             self.orchestrator.model.name = self.model.name
 
+        if self.model.vlm is not None:
+            self.trainer.model.vlm = self.model.vlm
+            self.orchestrator.model.vlm = self.model.vlm
+            if self.inference is not None:
+                self.inference.model.vlm = self.model.vlm
+
         validate_shared_model_name(self.trainer, self.orchestrator, self.inference)
 
         return self
```
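The `auto_setup_model` hunk above fans the top-level `model.vlm` out to the trainer, orchestrator, and (if present) inference sub-configs. Its control flow can be sketched with plain namespace stubs (the `Ns` class and `propagate_vlm` name are illustrative, not the repo's API):

```python
class Ns:
    """Tiny namespace stub standing in for the nested pydantic configs."""
    def __init__(self, **kw):
        self.__dict__.update(kw)


def propagate_vlm(cfg):
    """Sketch of the propagation logic from auto_setup_model above."""
    if cfg.model.vlm is not None:
        cfg.trainer.model.vlm = cfg.model.vlm
        cfg.orchestrator.model.vlm = cfg.model.vlm
        if cfg.inference is not None:
            cfg.inference.model.vlm = cfg.model.vlm
    return cfg


vlm = {"vision_encoder_attr": "model.visual", "language_model_attr": "model.language_model"}
cfg = Ns(
    model=Ns(vlm=vlm),
    trainer=Ns(model=Ns(vlm=None)),
    orchestrator=Ns(model=Ns(vlm=None)),
    inference=Ns(model=Ns(vlm=None)),
)
propagate_vlm(cfg)
print(cfg.trainer.model.vlm is vlm)  # True
```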

src/prime_rl/configs/shared.py

Lines changed: 30 additions & 0 deletions

```diff
@@ -80,6 +80,29 @@ def resolve_project_dir(self):
 ServerType = Literal["vllm", "openai"]
 
 
+class VLMConfig(BaseConfig):
+    """Configures vision-language model support.
+
+    Presence of this config enables VLM mode. You must specify where the
+    vision encoder and language model live on the model object.
+
+    Usage:
+        [model.vlm]
+        vision_encoder_attr = "model.visual"
+        language_model_attr = "model.language_model"
+    """
+
+    vision_encoder_attr: Annotated[
+        str,
+        Field(description="Dotted attribute path to the vision encoder module (e.g. 'model.visual')."),
+    ]
+
+    language_model_attr: Annotated[
+        str,
+        Field(description="Dotted attribute path to the language model module (e.g. 'model.language_model')."),
+    ]
+
+
 class BaseModelConfig(BaseConfig):
     """Configures the model."""
 
@@ -92,6 +115,13 @@ class BaseModelConfig(BaseConfig):
         ),
     ] = False
 
+    vlm: Annotated[
+        "VLMConfig | None",
+        Field(
+            description="VLM configuration. Set this to enable vision-language model support.",
+        ),
+    ] = None
+
 
 class ElasticConfig(BaseConfig):
     """Configures elastic inference pool with DNS-based service discovery.
```
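The semantics introduced here are "presence of the sub-config enables VLM mode, and both attr fields are required". Stripped of pydantic, that contract can be mimicked with dataclasses (the `*Sketch` names are illustrative stand-ins, not the repo's classes):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class VLMConfigSketch:
    """Stand-in for VLMConfig above: both fields are required, no defaults."""
    vision_encoder_attr: str
    language_model_attr: str


@dataclass
class ModelConfigSketch:
    name: str = "Qwen/Qwen3-0.6B"
    # Presence of this sub-config (not a model-name pattern) enables VLM mode.
    vlm: Optional[VLMConfigSketch] = None


text_cfg = ModelConfigSketch(name="Qwen/Qwen3-0.6B")
vlm_cfg = ModelConfigSketch(
    name="Qwen/Qwen3-VL-4B-Instruct",
    vlm=VLMConfigSketch(
        vision_encoder_attr="model.visual",
        language_model_attr="model.language_model",
    ),
)

print(text_cfg.vlm is not None)  # False
print(vlm_cfg.vlm is not None)   # True
```

Omitting either required field from `VLMConfigSketch(...)` raises a `TypeError` at construction, which mirrors pydantic rejecting a `[model.vlm]` section that is missing `vision_encoder_attr` or `language_model_attr`.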

src/prime_rl/configs/trainer.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -13,7 +13,6 @@
     WandbConfig,
 )
 from prime_rl.utils.config import BaseConfig
-from prime_rl.utils.vlm import is_vlm_model
 
 # -- Shared trainer configs (used by both SFT and RL trainers) --
 
@@ -714,7 +713,7 @@ class TrainerConfig(BaseConfig):
 
     @model_validator(mode="after")
     def vlms_require_bfloat16(self):
-        if is_vlm_model(self.model.name) and (
+        if self.model.vlm is not None and (
             self.model.optimization_dtype != "bfloat16" or self.model.reduce_dtype != "bfloat16"
         ):
             raise ValueError(
```
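The validator above switches the bfloat16 check from name-pattern detection to the explicit `model.vlm` flag. As a standalone function its logic looks like this sketch (a free function for illustration; the real check is a pydantic `model_validator` on `TrainerConfig`):

```python
def vlms_require_bfloat16(vlm_enabled: bool, optimization_dtype: str, reduce_dtype: str) -> None:
    """Reject VLM RL training configs that do not use bfloat16 everywhere."""
    if vlm_enabled and (optimization_dtype != "bfloat16" or reduce_dtype != "bfloat16"):
        raise ValueError(
            "VLM models require optimization_dtype='bfloat16' and "
            "reduce_dtype='bfloat16' to match vLLM inference."
        )


# A compliant VLM config passes silently.
vlms_require_bfloat16(True, "bfloat16", "bfloat16")

# A float32 VLM config is rejected.
try:
    vlms_require_bfloat16(True, "float32", "bfloat16")
except ValueError as e:
    print("rejected:", e)

# Text-only configs (vlm_enabled=False) are never checked.
vlms_require_bfloat16(False, "float32", "float32")
```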

src/prime_rl/orchestrator/orchestrator.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -75,7 +75,6 @@
     strip_env_version,
     to_col_format,
 )
-from prime_rl.utils.vlm import is_vlm_model
 
 
 @clean_exit
@@ -137,7 +136,7 @@ async def orchestrate(config: OrchestratorConfig):
     teacher_inference_pool = None
 
     # Check if this is a vision-language model (used throughout for VLM-specific paths)
-    is_vlm = is_vlm_model(config.model.name)
+    is_vlm = config.model.vlm is not None
 
     # Load tokenizer and processor (processor only for VLM models)
     logger.info(f"Initializing tokenizer for {config.model.name}")
```

src/prime_rl/trainer/model.py

Lines changed: 15 additions & 59 deletions

```diff
@@ -41,7 +41,7 @@
 )
 from prime_rl.trainer.world import get_world
 from prime_rl.utils.logger import get_logger
-from prime_rl.utils.vlm import is_vlm_config, is_vlm_model
+from prime_rl.utils.vlm import get_language_model, get_vision_encoder
 
 
 def _patch_qwen3_5_moe_conversion_mapping():
@@ -118,24 +118,11 @@ def _patched_forward(self, hidden_states, position_ids=None, **kwargs):
 torch._dynamo.config.cache_size_limit = 64  # default: 8
 
 
-def freeze_vision_encoder(model: nn.Module) -> None:
-    """Freeze the vision encoder parameters for VLM training.
-
-    For Qwen3-VL, the vision encoder is at model.model.visual.
-    This freezes all parameters in the vision encoder so only the
-    language model (with LoRA) is trained.
-    """
+def freeze_vision_encoder(model: nn.Module, override_attr: str | None = None) -> None:
     logger = get_logger()
-
-    # Qwen3-VL structure: model.model.visual
-    if hasattr(model, "model") and hasattr(model.model, "visual"):
-        vision_encoder = model.model.visual
-    # Qwen2-VL structure: model.visual
-    elif hasattr(model, "visual"):
-        vision_encoder = model.visual
-    else:
-        raise ValueError("Could not find vision encoder to freeze. Expected model.model.visual or model.visual")
-
+    vision_encoder = get_vision_encoder(model, override=override_attr)
+    if vision_encoder is None:
+        raise ValueError("Could not find vision encoder to freeze")
     num_frozen = 0
     for param in vision_encoder.parameters():
         param.requires_grad = False
@@ -175,17 +162,6 @@ def is_tt_moe_model(model: nn.Module) -> bool:
     return hasattr(model.config, "num_experts") or hasattr(model.config, "n_routed_experts")
 
 
-def get_language_model(model: nn.Module) -> nn.Module:
-    """Get the language model component containing transformer layers.
-
-    For VLM models (Qwen3-VL): model.model.language_model
-    For text-only models: model.model
-    """
-    if hasattr(model.model, "language_model"):
-        return model.model.language_model
-    return model.model
-
-
 def get_load_balance_stats(
     model: nn.Module, reset_stats: bool = True, try_to_avoid_padding_experts: bool = True
 ) -> dict[str, Tensor | None]:
@@ -218,8 +194,8 @@ def get_model(
         f"Loading model config (name={config.name}, attn={config.attn}, trust_remote_code={config.trust_remote_code})"
     )
 
-    # Check if this is a vision-language model (by name pattern first)
-    is_vlm = is_vlm_model(config.name)
+    # VLM mode is enabled by setting [model.vlm] in config
+    is_vlm = config.vlm is not None
 
     if "Qwen3.5" in config.name or "qwen3_5" in config.name.lower():
         _patch_qwen3_5_text_position_ids()
@@ -233,9 +209,6 @@
     )
     model_config.use_cache = False
 
-    # Fallback VLM detection from loaded config (catches local paths)
-    if not is_vlm and is_vlm_config(model_config):
-        is_vlm = True
     if is_vlm:
         logger.info(f"Detected vision-language model: {config.name}")
 
@@ -327,7 +300,7 @@
 
     # For VLM models, freeze the vision encoder
     if is_vlm:
-        freeze_vision_encoder(model)
+        freeze_vision_encoder(model, override_attr=config.vlm.vision_encoder_attr)
 
     assert model.lm_head.weight.dtype == dtype, (
         f"LM head dtype wasnt loaded correctly {model.lm_head.weight.dtype} != {dtype}"
@@ -365,33 +338,16 @@ def setup_fsdp(model: nn.Module, config: ModelConfig, parallel_dims: ParallelDims
 
     dp_mod_ep_mesh = parallel_dims.world_mesh[tuple(dp_mod_ep_mesh_dim_names)]
 
-    # For VLM models, shard the frozen vision encoder as a single unit
-    # This allows FSDP to manage the memory while keeping it frozen
-    is_vlm = is_vlm_model(config.name) or (hasattr(model, "model") and hasattr(model.model, "visual"))
+    is_vlm = config.vlm is not None
     if is_vlm:
-        if hasattr(model, "model") and hasattr(model.model, "visual"):
-            vision_encoder = model.model.visual
-        elif hasattr(model, "visual"):
-            vision_encoder = model.visual
-        else:
-            raise ValueError(f"VLM model {config.name} does not have a recognized vision encoder attribute")
-
-        fully_shard(
-            vision_encoder,
-            mesh=hsdp_mesh,
-            **fsdp_config,
-        )
+        vision_encoder = get_vision_encoder(model, override=config.vlm.vision_encoder_attr)
+        if vision_encoder is None:
+            raise ValueError(f"VLM model {config.name} has no recognized vision encoder")
+        fully_shard(vision_encoder, mesh=hsdp_mesh, **fsdp_config)
         get_logger().info("Applied FSDP to frozen vision encoder")
 
-    # Get the language model layers (handle VLM structure)
-    # For Qwen3-VL: model.model.language_model contains the transformer layers
-    # For text-only models: model.model contains the layers directly
-    if is_vlm:
-        language_model = model.model.language_model
-        transformer_layers = language_model.layers
-    else:
-        language_model = model.model
-        transformer_layers = language_model.layers
+    language_model = get_language_model(model, override=config.vlm.language_model_attr if is_vlm else None)
+    transformer_layers = language_model.layers
 
     for transformer_block in transformer_layers:
         if parallel_dims.ep_enabled and isinstance(transformer_block.mlp, MoE):
```
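The surviving freeze loop in `freeze_vision_encoder` above just flips `requires_grad` off on every encoder parameter and counts them. That behavior can be checked without torch using tiny stubs (`FakeParam`/`FakeEncoder` are hypothetical names for this sketch; real code operates on `nn.Module` parameters):

```python
class FakeParam:
    """Stands in for a torch parameter: just a requires_grad flag."""
    def __init__(self):
        self.requires_grad = True


class FakeEncoder:
    """Stands in for a vision encoder module with n parameters."""
    def __init__(self, n: int):
        self._params = [FakeParam() for _ in range(n)]

    def parameters(self):
        return iter(self._params)


def freeze(encoder) -> int:
    """Mirror of the freeze loop above: disable grads, count params frozen."""
    num_frozen = 0
    for param in encoder.parameters():
        param.requires_grad = False
        num_frozen += 1
    return num_frozen


enc = FakeEncoder(3)
print(freeze(enc))                                         # 3
print(all(not p.requires_grad for p in enc.parameters()))  # True
```

Because only the language model's parameters keep `requires_grad=True`, the optimizer never updates the encoder, which is why FSDP can still shard it as a single frozen unit in `setup_fsdp`.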
