Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion cosmos_transfer2/multiview_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def validate_input_paths(self):
if not active_views:
raise ValueError("At least one view configuration with a control_path must be provided.")

if self.num_conditional_frames > 0 or self.enable_autoregressive:
if self.num_conditional_frames > 0 and not self.enable_autoregressive:
missing_input_paths = [
view_name for view_name, view_config in active_views if view_config.input_path is None
]
Expand All @@ -131,6 +131,8 @@ def validate_input_paths(self):
"input_path is required for all active views when num_conditional_frames > 0. "
f"Missing input_path for views: {', '.join(missing_input_paths)}"
)

if self.enable_autoregressive:
# Check per-view frame counts when autoregressive mode is enabled.
num_conditional_frames_per_view = [
view_config.num_conditional_frames_per_view for _, view_config in active_views
Expand All @@ -149,6 +151,16 @@ def validate_input_paths(self):
"num_conditional_frames_per_view must be consistent across all active views in autoregressive mode. "
"Either set it for all views or leave all at default (0)."
)

if any(frames > 0 for frames in num_conditional_frames_per_view):
missing_input_paths = [
view_name for view_name, view_config in active_views if view_config.input_path is None
]
if missing_input_paths:
raise ValueError(
"input_path is required for all active views when num_conditional_frames > 0. "
f"Missing input_path for views: {', '.join(missing_input_paths)}"
)
return self

@property
Expand Down