
Error when trying to post-train with validation #30

Description

@mtiezzi

Thanks for the amazing work!

I’m encountering an issue when trying to post-train a model and include a validation set. The script works fine without validation, but adding a validation dataset causes an error.

The same validation dataset loads successfully with Cosmos-Predict2, so the problem appears to be specific to the Cosmos-Predict2.5 post-training setup.

I am launching the post-training via the following command:

torchrun --nproc_per_node=1 --master_port=12341 -m scripts.train --config=cosmos_predict2/_src/predict2/configs/video2world/config.py -- experiment=predict2_video2world_training_2b trainer.run_validation=True trainer.validation_iter=100

and I am getting the following error:

2025-11-01 16:56:55 [WARNING  | cosmos_predict2._src.imaginaire.lazy_config.lazy]: Config is saved using omegaconf at /home/Projects/inProgress/cosmos-predict2.5/checkpoints/cosmos_predict_v2p5/video2world/2b_validation_error/config.yaml.
2025-11-01 16:56:55 wandb: wandb.init() called while a run is active and reinit is set to 'default', so returning the previous run.
2025-11-01 16:56:55 Traceback (most recent call last):
2025-11-01 16:56:55   File "/home/miniforge3/envs/cosmos-predict25-clean/lib/python3.10/runpy.py", line 196, in _run_module_as_main
2025-11-01 16:56:55     return _run_code(code, main_globals, None,
2025-11-01 16:56:55   File "/home/miniforge3/envs/cosmos-predict25-clean/lib/python3.10/runpy.py", line 86, in _run_code
2025-11-01 16:56:55     exec(code, run_globals)
2025-11-01 16:56:55   File "/home/Projects/inProgress/cosmos-predict2.5/scripts/train.py", line 110, in <module>
2025-11-01 16:56:55     launch(config, args)
2025-11-01 16:56:55   File "/home/Projects/inProgress/cosmos-predict2.5/.venv/lib/python3.10/site-packages/loguru/_logger.py", line 1297, in catch_wrapper
2025-11-01 16:56:55     return function(*args, **kwargs)
2025-11-01 16:56:55   File "/home/Projects/inProgress/cosmos-predict2.5/scripts/train.py", line 54, in launch
2025-11-01 16:56:55     trainer.train(
2025-11-01 16:56:55   File "/home/Projects/inProgress/cosmos-predict2.5/cosmos_predict2/_src/imaginaire/trainer.py", line 187, in train
2025-11-01 16:56:55     self.validate(model, dataloader_val, iteration=iteration)
2025-11-01 16:56:55   File "/home/Projects/inProgress/cosmos-predict2.5/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
2025-11-01 16:56:55     return func(*args, **kwargs)
2025-11-01 16:56:55   File "/home/Projects/inProgress/cosmos-predict2.5/cosmos_predict2/_src/imaginaire/trainer.py", line 345, in validate
2025-11-01 16:56:55     with ema.ema_scope(model, enabled=model.config.ema.enabled):
2025-11-01 16:56:55   File "/home/miniforge3/envs/cosmos-predict25-clean/lib/python3.10/contextlib.py", line 135, in __enter__
2025-11-01 16:56:55     return next(self.gen)
2025-11-01 16:56:55   File "/home/Projects/inProgress/cosmos-predict2.5/cosmos_predict2/_src/imaginaire/utils/ema.py", line 325, in ema_scope
2025-11-01 16:56:55     assert hasattr(model, "ema") and isinstance(model.ema, (FastEmaModelUpdater, EMAModelTracker, PowerEMATracker))
2025-11-01 16:56:55 AssertionError
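The assertion in `ema_scope` fires because the trainer requests an EMA scope (`enabled=model.config.ema.enabled` evaluates to true) while the model either has no `ema` attribute or holds an object that is not a `FastEmaModelUpdater`, `EMAModelTracker` or `PowerEMATracker`. The following is a simplified, hypothetical EMA scope in plain Python that makes this failure mode concrete. It is not the `cosmos_predict2` implementation, and `SimpleEMA`, `ToyModel` and the dict-of-floats parameters are illustrative assumptions. It applies the same kind of guard, swaps the EMA weights in for the body of the `with` block, and restores the live weights afterwards.

```python
from contextlib import contextmanager


class SimpleEMA:
    """Hypothetical EMA tracker keeping a shadow copy of each parameter."""

    def __init__(self, params, decay=0.99):
        self.decay = decay
        self.shadow = dict(params)

    def update(self, params):
        # shadow <- decay * shadow + (1 - decay) * current value
        for name, value in params.items():
            self.shadow[name] = self.decay * self.shadow[name] + (1 - self.decay) * value


class ToyModel:
    """Hypothetical model whose parameters are a plain dict of floats."""

    def __init__(self, params, with_ema=True):
        self.params = dict(params)
        if with_ema:
            self.ema = SimpleEMA(self.params)


@contextmanager
def ema_scope(model, enabled=True):
    if not enabled:
        # EMA disabled: run the block with the live weights and skip all checks.
        yield
        return
    # Same kind of guard as in the traceback: EMA requested, but no tracker attached.
    assert hasattr(model, "ema") and isinstance(model.ema, SimpleEMA)
    backup = dict(model.params)
    model.params.update(model.ema.shadow)  # swap the EMA weights in
    try:
        yield
    finally:
        model.params = backup  # restore the live training weights


# Example: the live weight is 3.0 and the EMA weight is about 1.02.
model = ToyModel({"w": 1.0})
model.params["w"] = 3.0
model.ema.update(model.params)
with ema_scope(model, enabled=True):
    ema_weight = model.params["w"]
live_weight = model.params["w"]
```

Calling `ema_scope(...)` with `enabled=True` on a `ToyModel` built with `with_ema=False` raises `AssertionError`. With `enabled=False`, both the guard and the swap are skipped. This mirrors why passing `model.config.ema.enabled=False` gets past this point in the real trainer.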

I then tried disabling EMA by adding model.config.ema.enabled=False:

torchrun --nproc_per_node=1 --master_port=12341 -m scripts.train --config=cosmos_predict2/_src/predict2/configs/video2world/config.py -- experiment=predict2_video2world_training_2b trainer.run_validation=True trainer.validation_iter=100 model.config.ema.enabled=False

but then model.validation_step returns None, so unpacking its result fails:

2025-11-01 16:59:48 [WARNING  | cosmos_predict2._src.imaginaire.lazy_config.lazy]: Config is saved using omegaconf at /home/Projects/inProgress/cosmos-predict2.5/checkpoints/cosmos_predict_v2p5/video2world/2b_validation/config.yaml.
2025-11-01 16:59:48 wandb: wandb.init() called while a run is active and reinit is set to 'default', so returning the previous run.
2025-11-01 16:59:49 Traceback (most recent call last):
2025-11-01 16:59:49   File "/home/miniforge3/envs/cosmos-predict25-clean/lib/python3.10/runpy.py", line 196, in _run_module_as_main
2025-11-01 16:59:49     return _run_code(code, main_globals, None,
2025-11-01 16:59:49   File "/home/miniforge3/envs/cosmos-predict25-clean/lib/python3.10/runpy.py", line 86, in _run_code
2025-11-01 16:59:49     exec(code, run_globals)
2025-11-01 16:59:49   File "/home/Projects/inProgress/cosmos-predict2.5/scripts/train.py", line 110, in <module>
2025-11-01 16:59:49     launch(config, args)
2025-11-01 16:59:49   File "/home/Projects/inProgress/cosmos-predict2.5/.venv/lib/python3.10/site-packages/loguru/_logger.py", line 1297, in catch_wrapper
2025-11-01 16:59:49     return function(*args, **kwargs)
2025-11-01 16:59:49   File "/home/Projects/inProgress/cosmos-predict2.5/scripts/train.py", line 54, in launch
2025-11-01 16:59:49     trainer.train(
2025-11-01 16:59:49   File "/home/Projects/inProgress/cosmos-predict2.5/cosmos_predict2/_src/imaginaire/trainer.py", line 187, in train
2025-11-01 16:59:49     self.validate(model, dataloader_val, iteration=iteration)
2025-11-01 16:59:49   File "/home/Projects/inProgress/cosmos-predict2.5/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
2025-11-01 16:59:49     return func(*args, **kwargs)
2025-11-01 16:59:49   File "/home/Projects/inProgress/cosmos-predict2.5/cosmos_predict2/_src/imaginaire/trainer.py", line 351, in validate
2025-11-01 16:59:49     output_batch, loss = model.validation_step(data_batch, iteration)
2025-11-01 16:59:49 TypeError: cannot unpack non-iterable NoneType object
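The second failure is a plain Python unpacking error. Line 351 of `trainer.py` expects `model.validation_step(data_batch, iteration)` to return a two-item `(output_batch, loss)` tuple and receives `None` instead. In Python, a method whose body never reaches a `return` statement (for example, a base-class placeholder) returns `None`. The sketch below shows how such a `None` can arise, using hypothetical stub classes (`BaseModelStub`, `ValidatingModelStub`). It also shows one possible stop-gap: delegating validation to the training step. Whether the Predict2.5 model class behaves like the stub, exposes a suitable `training_step`, and can safely reuse it for validation are all assumptions to check against the actual model code.

```python
class BaseModelStub:
    """Hypothetical stand-in for a model base class."""

    def training_step(self, data_batch, iteration):
        # Toy "forward pass": echo the batch and use its mean as the loss.
        loss = sum(data_batch["x"]) / len(data_batch["x"])
        return {"x": data_batch["x"]}, loss

    def validation_step(self, data_batch, iteration):
        # Placeholder with no return statement: implicitly returns None.
        pass


class ValidatingModelStub(BaseModelStub):
    """Hypothetical workaround: reuse the training step for validation."""

    def validation_step(self, data_batch, iteration):
        # A real model would typically validate without gradient tracking or
        # optimizer updates; this stub has neither.
        return self.training_step(data_batch, iteration)


def run_validation_step(model, data_batch, iteration):
    """Unpack (output_batch, loss), failing with a clearer message than a bare TypeError."""
    result = model.validation_step(data_batch, iteration)
    if result is None:
        raise RuntimeError(
            f"{type(model).__name__}.validation_step returned None; "
            "expected an (output_batch, loss) tuple"
        )
    output_batch, loss = result
    return output_batch, loss
```

Unpacking `BaseModelStub().validation_step(batch, 0)` directly reproduces `TypeError: cannot unpack non-iterable NoneType object`, while `run_validation_step` reports which model class returned `None`.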

I defined the post-training experiment config as follows:

from hydra.core.config_store import ConfigStore

from cosmos_predict2._src.imaginaire.lazy_config import LazyCall as L
from cosmos_predict2._src.imaginaire.utils.checkpoint_db import get_checkpoint_path
from cosmos_predict2._src.predict2.datasets.local_datasets.dataset_video import (
    VideoDataset,
    get_generic_dataloader,
    get_sampler,
)
from cosmos_predict2.config import MODEL_CHECKPOINTS, ModelKey

DEFAULT_CHECKPOINT = MODEL_CHECKPOINTS[ModelKey(post_trained=False)]


# Datasets and dataloaders

full_dataset_train = L(VideoDataset)(
    dataset_dir="datasets/Full_splits/train",
    num_frames=93,
    video_size=(224, 224),  
)

full_dataset_val = L(VideoDataset)(
    dataset_dir="datasets/Full_splits/val",
    num_frames=93,
    video_size=(224, 224), 
)

# Create DataLoader with distributed sampler
dataloader_train = L(get_generic_dataloader)(
    dataset=full_dataset_train,
    sampler=L(get_sampler)(dataset=full_dataset_train),
    batch_size=1,
    drop_last=True,
    num_workers=4,
    pin_memory=True,
)

# Create DataLoader with distributed sampler
dataloader_val = L(get_generic_dataloader)(
    dataset=full_dataset_val,
    sampler=L(get_sampler)(dataset=full_dataset_val),
    batch_size=1,
    drop_last=True,
    num_workers=4,
    pin_memory=True,
)

# Video2World post-training configuration for the 2B model
# torchrun --nproc_per_node=1 --master_port=12341 -m scripts.train --config=cosmos_predict2/_src/predict2/configs/video2world/config.py -- experiment=predict2_video2world_training_2b
predict2_video2world_training_2b = dict(
    defaults=[
        f"/experiment/{DEFAULT_CHECKPOINT.experiment}",
        {"override /data_train": "mock"},
        {"override /data_val": "mock"},
        "_self_",
    ],
    dataloader_train=dataloader_train,
    dataloader_val=dataloader_val,
    checkpoint=dict(
        save_iter=200,
        # pyrefly: ignore  # missing-attribute
        load_path=get_checkpoint_path(DEFAULT_CHECKPOINT.s3.uri),
        load_from_object_store=dict(
            enabled=False,
        ),
        save_to_object_store=dict(
            enabled=False,
        ),
    ),
    job=dict(
        project="cosmos_predict_v2p5",
        group="video2world",
        name="2b_480_validation",
    ),
    optimizer=dict(
        lr=2 ** (-14.5),
        weight_decay=0.001,
    ),
    scheduler=dict(
        f_max=[0.5],
        f_min=[0.2],
        warm_up_steps=[1_000],
        cycle_lengths=[100000],
    ),
    trainer=dict(
        logging_iter=100,
        max_iter=1000,
        callbacks=dict(
            heart_beat=dict(
                save_s3=False,
            ),
            iter_speed=dict(
                hit_thres=100,
                save_s3=False,
            ),
            device_monitor=dict(
                save_s3=False,
            ),
            every_n_sample_reg=dict(
                every_n=200,
                save_s3=False,
            ),
            every_n_sample_ema=dict(
                every_n=200,
                save_s3=False,
            ),
            wandb=dict(
                save_s3=False,
            ),
            wandb_10x=dict(
                save_s3=False,
            ),
            dataloader_speed=dict(
                save_s3=False,
            ),
        ),
    ),
    model_parallel=dict(
        context_parallel_size=1,
    ),
)

cs = ConfigStore.instance()

for _item in [
    predict2_video2world_training_2b,
]:
    # Get the experiment name from the global variable
    experiment_name = [name.lower() for name, value in globals().items() if value is _item][0]
    cs.store(
        group="experiment",
        package="_global_",
        name=experiment_name,
        node=_item,
    )
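The experiment above only sets `trainer.logging_iter` and `trainer.max_iter`. The values for `trainer.run_validation`, `trainer.validation_iter` and `model.config.ema.enabled` arrive as command-line overrides. The helper below is a rough, hypothetical illustration of how such dotted `key=value` overrides address nested config entries, merging them into a plain dict. Hydra and OmegaConf do this with considerably more machinery, including type checking against the existing config. `apply_overrides` and `parse_value` are illustrative names, not part of Hydra or `cosmos_predict2`.

```python
import ast


def parse_value(text):
    """Interpret 'True', '100', '0.5', ... as Python literals; keep anything else as a string."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return text


def apply_overrides(config, overrides):
    """Apply 'a.b.c=value' overrides to a nested dict in place (hypothetical helper)."""
    for override in overrides:
        dotted_key, raw_value = override.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = config
        for key in parents:
            # Create intermediate levels that do not exist yet.
            node = node.setdefault(key, {})
        node[leaf] = parse_value(raw_value)
    return config


# The trainer block from the experiment, plus the overrides from the second command.
config = {"trainer": {"logging_iter": 100, "max_iter": 1000}}
apply_overrides(
    config,
    ["trainer.run_validation=True", "trainer.validation_iter=100", "model.config.ema.enabled=False"],
)
```

After the call, `config["trainer"]` also holds `run_validation=True` and `validation_iter=100`, and `config["model"]["config"]["ema"]["enabled"]` is `False`.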
