diff --git a/examples/post_training/configs/sft.toml b/examples/post_training/configs/sft.toml index ac34a15..bd8d5ac 100644 --- a/examples/post_training/configs/sft.toml +++ b/examples/post_training/configs/sft.toml @@ -35,9 +35,11 @@ fsdp_offload = false fsdp_reshard_after_forward = "default" train_batch_per_replica = 32 sync_weight_interval = 1 -enable_validation = true -validation_step = 30 -validation_batch_per_replica = 2 + +[validation] +enable = true +freq = 30 +batch_size = 2 [policy] model_name_or_path = "nvidia/Cosmos-Reason1-7B" diff --git a/examples/post_training/tools/dataset/cosmos_grpo.py b/examples/post_training/tools/dataset/cosmos_grpo.py index 6240105..f96e83f 100644 --- a/examples/post_training/tools/dataset/cosmos_grpo.py +++ b/examples/post_training/tools/dataset/cosmos_grpo.py @@ -31,7 +31,6 @@ from cosmos_rl.dispatcher.data.packer import DataPacker, Qwen2_5_VLM_DataPacker from cosmos_rl.launcher.worker_entry import main as launch_worker from cosmos_rl.policy.config import Config -from cosmos_rl.policy.config import Config as CosmosConfig from cosmos_rl.utils.logging import logger from cosmos_rl.utils.util import basename_from_modelpath from datasets import load_dataset @@ -65,7 +64,7 @@ def get_mm_files_paths(self, dataset_name: str, dataset_subset: str): mm_files_paths[file] = os.path.join(root, file) return mm_files_paths - def setup(self, config: CosmosConfig, tokenizer: AutoTokenizer, *args, **kwargs): + def setup(self, config: Config, tokenizer: AutoTokenizer, *args, **kwargs): self.config = config self.tokenizer = tokenizer self.dataset = load_dataset( @@ -268,16 +267,16 @@ def policy_collate_fn( config = Config.from_dict(config) util.prepare_cosmos_data(dataset=config.train.train_policy.dataset) - if config.train.enable_validation: + if config.validation.enable: util.prepare_cosmos_data(dataset=config.validation.dataset) # It is best practice to pass the dataset and val_dataset as factory functions # so that the dataset and val_dataset can be loaded on demand. (Not all workers need them) - def get_dataset(config: CosmosConfig) -> Dataset: + def get_dataset(config: Config) -> Dataset: return CosmosGRPODataset() - def get_val_dataset(config: CosmosConfig) -> Dataset: - return CosmosGRPOValDataset() if config.train.enable_validation else None + def get_val_dataset(config: Config) -> Dataset: + return CosmosGRPOValDataset() if config.validation.enable else None launch_worker( dataset=get_dataset,