diff --git a/cosmos_predict2/_src/predict2/action/inference/inference.py b/cosmos_predict2/_src/predict2/action/inference/inference.py index e30929f91..8e373699f 100644 --- a/cosmos_predict2/_src/predict2/action/inference/inference.py +++ b/cosmos_predict2/_src/predict2/action/inference/inference.py @@ -360,7 +360,9 @@ def main(): if args.single_chunk: break - chunk_list = [chunk_video[0]] + [chunk_video[i][: args.chunk_size] for i in range(1, len(chunk_video))] + # Drop the first frame of each subsequent chunk: it is the reconstructed + # conditioning frame (~ previous chunk's last frame) and would otherwise be duplicated. + chunk_list = [chunk_video[0]] + [chunk_video[i][1:] for i in range(1, len(chunk_video))] chunk_video = np.concatenate(chunk_list, axis=0) if args.single_chunk: chunk_video_name = f"{args.save_root}/{img_name + '_single_chunk.mp4'}" diff --git a/cosmos_predict2/_src/predict2/action/inference/inference_gr00t.py b/cosmos_predict2/_src/predict2/action/inference/inference_gr00t.py index 7efc2be06..881308a3b 100644 --- a/cosmos_predict2/_src/predict2/action/inference/inference_gr00t.py +++ b/cosmos_predict2/_src/predict2/action/inference/inference_gr00t.py @@ -337,7 +337,9 @@ def main(): if args.single_chunk: break - chunk_list = [chunk_video[0]] + [chunk_video[i][: args.chunk_size] for i in range(1, len(chunk_video))] + # Drop the first frame of each subsequent chunk: it is the reconstructed + # conditioning frame (~ previous chunk's last frame) and would otherwise be duplicated. 
+ chunk_list = [chunk_video[0]] + [chunk_video[i][1:] for i in range(1, len(chunk_video))] chunk_video = np.concatenate(chunk_list, axis=0) if args.single_chunk: chunk_video_name = f"{args.save_root}/{img_name + '_single_chunk.mp4'}" diff --git a/cosmos_predict2/action_conditioned.py b/cosmos_predict2/action_conditioned.py index 45dcfc4b0..727557382 100644 --- a/cosmos_predict2/action_conditioned.py +++ b/cosmos_predict2/action_conditioned.py @@ -358,9 +358,9 @@ def inference( if inference_args.single_chunk: break - chunk_list = [chunk_video[0]] + [ - chunk_video[i][: inference_args.chunk_size] for i in range(1, len(chunk_video)) - ] + # Drop the first frame of each subsequent chunk: it is the reconstructed + # conditioning frame (~ previous chunk's last frame) and would otherwise be duplicated. + chunk_list = [chunk_video[0]] + [chunk_video[i][1:] for i in range(1, len(chunk_video))] chunk_video = np.concatenate(chunk_list, axis=0) if inference_args.single_chunk: chunk_video_name = str(inference_args.save_root / f"{img_name}_single_chunk.mp4")