diff --git a/cosmos_predict2/_src/predict2/models/text2world_model_rectified_flow.py b/cosmos_predict2/_src/predict2/models/text2world_model_rectified_flow.py index c048804b4..660c12d9d 100644 --- a/cosmos_predict2/_src/predict2/models/text2world_model_rectified_flow.py +++ b/cosmos_predict2/_src/predict2/models/text2world_model_rectified_flow.py @@ -591,7 +591,7 @@ def generate_samples_from_batch( def validation_step( self, data: dict[str, torch.Tensor], iteration: int ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: - pass + return self.training_step(data, iteration) @torch.no_grad() def forward(self, xt, t, condition: Text2WorldCondition):