Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions examples/post_training/configs/sft.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ fsdp_offload = false
fsdp_reshard_after_forward = "default"
train_batch_per_replica = 32
sync_weight_interval = 1
enable_validation = true
validation_step = 30
validation_batch_per_replica = 2

[validation]
enable = true
freq = 30
batch_size = 2

[policy]
model_name_or_path = "nvidia/Cosmos-Reason1-7B"
Expand Down
11 changes: 5 additions & 6 deletions examples/post_training/tools/dataset/cosmos_grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from cosmos_rl.dispatcher.data.packer import DataPacker, Qwen2_5_VLM_DataPacker
from cosmos_rl.launcher.worker_entry import main as launch_worker
from cosmos_rl.policy.config import Config
from cosmos_rl.policy.config import Config as CosmosConfig
from cosmos_rl.utils.logging import logger
from cosmos_rl.utils.util import basename_from_modelpath
from datasets import load_dataset
Expand Down Expand Up @@ -65,7 +64,7 @@ def get_mm_files_paths(self, dataset_name: str, dataset_subset: str):
mm_files_paths[file] = os.path.join(root, file)
return mm_files_paths

def setup(self, config: CosmosConfig, tokenizer: AutoTokenizer, *args, **kwargs):
def setup(self, config: Config, tokenizer: AutoTokenizer, *args, **kwargs):
self.config = config
self.tokenizer = tokenizer
self.dataset = load_dataset(
Expand Down Expand Up @@ -268,16 +267,16 @@ def policy_collate_fn(
config = Config.from_dict(config)

util.prepare_cosmos_data(dataset=config.train.train_policy.dataset)
if config.train.enable_validation:
if config.validation.enable:
util.prepare_cosmos_data(dataset=config.validation.dataset)

# It is best practice to pass the dataset and val_dataset as factory functions
# so that the dataset and val_dataset can be loaded on demand. (Not all workers need them)
def get_dataset(config: CosmosConfig) -> Dataset:
def get_dataset(config: Config) -> Dataset:
return CosmosGRPODataset()

def get_val_dataset(config: CosmosConfig) -> Dataset:
return CosmosGRPOValDataset() if config.train.enable_validation else None
def get_val_dataset(config: Config) -> Dataset:
return CosmosGRPOValDataset() if config.validation.enable else None

launch_worker(
dataset=get_dataset,
Expand Down