From ba024979446b7633c2ad8e7d1290f53103b91951 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Tue, 3 Feb 2026 07:14:24 -0800 Subject: [PATCH] bug: fix checkpoint override --- cosmos_transfer2/inference.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/cosmos_transfer2/inference.py b/cosmos_transfer2/inference.py index eecd3b04..8b0a8824 100644 --- a/cosmos_transfer2/inference.py +++ b/cosmos_transfer2/inference.py @@ -48,8 +48,8 @@ def __init__( if len(self.batch_hint_keys) == 1: # pyrefly: ignore # bad-argument-type checkpoint = MODEL_CHECKPOINTS[ModelKey(variant=self.batch_hint_keys[0])] - self.checkpoint_list = [checkpoint.path] - self.experiment = checkpoint.experiment + self.checkpoint_list = [args.checkpoint_path] if args.checkpoint_path else [checkpoint.path] + self.experiment = args.experiment if args.experiment else checkpoint.experiment else: # pyrefly: ignore # bad-argument-type self.checkpoint_list = [MODEL_CHECKPOINTS[ModelKey(variant=key)].path for key in self.batch_hint_keys] @@ -85,12 +85,20 @@ def __init__( self.video_guardrail_runner = None self.benchmark_timer = misc.TrainingTimer() + + if args.checkpoint_path: + registered_exp_name = self.experiment + exp_override_opts = [] + else: + registered_exp_name = EXPERIMENTS[self.experiment].registered_exp_name + exp_override_opts = EXPERIMENTS[self.experiment].command_args + # Initialize the inference class self.inference_pipeline = ControlVideo2WorldInference( - registered_exp_name=EXPERIMENTS[self.experiment].registered_exp_name, + registered_exp_name=registered_exp_name, checkpoint_paths=self.checkpoint_list, s3_credential_path="", - exp_override_opts=EXPERIMENTS[self.experiment].command_args, + exp_override_opts=exp_override_opts, process_group=process_group, use_cp_wan=args.enable_parallel_tokenizer, wan_cp_grid=args.parallel_tokenizer_grid,