diff --git a/skyrl/train/dataset/preprocess.py b/skyrl/train/dataset/preprocess.py index aea16711ff..ece77a3b84 100644 --- a/skyrl/train/dataset/preprocess.py +++ b/skyrl/train/dataset/preprocess.py @@ -196,6 +196,7 @@ def compute_prompt_mini_batch_boundaries( train_batch_size: int, is_stepwise: bool, n_samples_per_prompt: int, + is_training: bool = True, ) -> List[Tuple[int, int]]: """Compute mini-batch ``(start, end)`` slices from a flat ``uids`` list. @@ -206,10 +207,12 @@ def compute_prompt_mini_batch_boundaries( train_batch_size: Number of prompts in a training batch. For sanity check. is_stepwise: Whether the training is step-wise. For sanity check. n_samples_per_prompt: how many samples per prompt. For sanity check. + is_training: Whether this is a training batch (strict validation) or eval batch (allows partial batches). + Defaults to True for backward compatibility. Returns: List of (start, end) indices of the mini-batches. The length of the list is the number of - mini-batches, guaranteed to be `train_batch_size // mini_batch_size` regardless of whether - the training is step-wise or not. + mini-batches, guaranteed to be `train_batch_size // mini_batch_size` during training, but may differ + during evaluation if the final batch is partial. Consecutive equal entries in ``uids`` belong to the same prompt. Each mini batch spans exactly ``mini_batch_size`` prompts (the last may be smaller if the total prompt count is not divisible @@ -244,23 +247,35 @@ def compute_prompt_mini_batch_boundaries( prompt_end_indices.append(i) prompt_end_indices.append(len(uids)) - # seen_uids should equal to the number of prompts and equal to `train_batch_size` + # Check that num_prompts matches expected batch size num_prompts = len(prompt_end_indices) - assert num_prompts == train_batch_size and len(seen_uids) == train_batch_size - assert train_batch_size % mini_batch_size == 0 - - # Compute boundaries. + if is_training: + assert ( + num_prompts == train_batch_size and len(seen_uids) == train_batch_size + ), f"Expected {train_batch_size} prompts in training batch, got {num_prompts}." + assert train_batch_size % mini_batch_size == 0 + else: + if num_prompts != train_batch_size: + logger.info( + f"Partial batch detected during eval: got {num_prompts} prompts but " + f"train_batch_size={train_batch_size}. Using actual batch size for mini-batch boundaries." + ) + + # Compute boundaries. Handle partial batches during eval. boundaries: List[Tuple[int, int]] = [] start_seq = 0 + for i in range(0, num_prompts, mini_batch_size): - end_prompt_idx = i + mini_batch_size - 1 # i + mini_batch_size is next mini-batch's first prompt's end index + end_prompt_idx = min(i + mini_batch_size - 1, num_prompts - 1) end_seq = prompt_end_indices[end_prompt_idx] boundaries.append((start_seq, end_seq)) start_seq = end_seq - assert len(boundaries) == train_batch_size // mini_batch_size + + if is_training: + assert len(boundaries) == train_batch_size // mini_batch_size # Assert that the mini-batch boundaries are uniform for non-step-wise training. - if not is_stepwise: + if not is_stepwise and is_training: expected_num_seq_in_mini_batch = n_samples_per_prompt * mini_batch_size for i, (start, end) in enumerate(boundaries): assert start == i * expected_num_seq_in_mini_batch diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index 61655f0f65..b78c28645b 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -592,13 +592,17 @@ def init_weight_sync_state(self): self.dispatch.init_weight_sync_state(self.inference_engine_client) logger.info("Initialized weight sync state for policy model and inference engines.") - def convert_to_training_input(self, generator_output: GeneratorOutput, uids: List[str]) -> TrainingInputBatch: + def convert_to_training_input( + self, generator_output: GeneratorOutput, uids: List[str], is_training: bool = True + ) -> TrainingInputBatch: """Converts lists to a padded batch of tensors for training Args: generator_output (GeneratorOutput): Generated rollouts and associated data. uids (List[str]): List of prompt-unique identifiers for each generator ouput in the same order as `generator_output`. Used to identify which prompt each generated rollout belongs to. + is_training (bool): Whether this batch is for training (strict batch size) or evaluation + (allows partial batches). Defaults to True for backward compatibility. Returns: training_input (TrainingInputBatch): Padded batch of tensors for training. It preserves the order of `generator_output` and hence `uids`. @@ -680,11 +684,21 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis n_samples_per_prompt = self.cfg.generator.n_samples_per_prompt is_stepwise = self.cfg.generator.step_wise_trajectories training_input.metadata["policy_mini_batch_boundaries"] = compute_prompt_mini_batch_boundaries( - uids, self.cfg.trainer.policy_mini_batch_size, train_batch_size, is_stepwise, n_samples_per_prompt + uids, + self.cfg.trainer.policy_mini_batch_size, + train_batch_size, + is_stepwise, + n_samples_per_prompt, + is_training=is_training, ) if self.cfg.trainer.critic.model.path is not None: training_input.metadata["critic_mini_batch_boundaries"] = compute_prompt_mini_batch_boundaries( - uids, self.cfg.trainer.critic_mini_batch_size, train_batch_size, is_stepwise, n_samples_per_prompt + uids, + self.cfg.trainer.critic_mini_batch_size, + train_batch_size, + is_stepwise, + n_samples_per_prompt, + is_training=is_training, ) # 5. Record metadata and metrics. diff --git a/tests/train/test_prompt_mini_batch.py b/tests/train/test_prompt_mini_batch.py index 0ba6c2e684..4d4967dd61 100644 --- a/tests/train/test_prompt_mini_batch.py +++ b/tests/train/test_prompt_mini_batch.py @@ -198,10 +198,88 @@ def test_same_step_count_as_non_stepwise(self): ) assert len(stepwise_bounds) == len(non_stepwise_bounds) == 2 - # Non-step-wise boundaries should be uniform assert non_stepwise_bounds == [(0, 640), (640, 1280)] + def test_eval_partial_batch_nonstepwise(self): + """Test eval mode with partial batches during non-stepwise training. + + This addresses the issue where evaluation crashes when val set size is + not divisible by train_batch_size. With is_training=False, partial + batches should be allowed. + """ + train_batch_size = 4 + spp = 2 + is_stepwise = False + mini_batch_size = 2 + + # Only 3 prompts instead of 4 (partial batch) + uids = ["p0", "p0", "p1", "p1", "p2", "p2"] + + # Should work fine with is_training=False + boundaries = compute_prompt_mini_batch_boundaries( + uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False + ) + # With 3 prompts and mini_batch_size=2, we get 2 mini-batches: + # First mini-batch: prompts 0-1 (sequences 0-4) + # Second mini-batch: prompt 2 (sequences 4-6) + assert boundaries == [(0, 4), (4, 6)] + + def test_eval_partial_batch_single_minibatch(self): + """Test eval mode with partial batch that fits in single mini-batch.""" + train_batch_size = 4 + spp = 2 + is_stepwise = False + mini_batch_size = 2 + + # Only 1 prompt instead of 4 (very partial batch) + uids = ["p0", "p0"] + + boundaries = compute_prompt_mini_batch_boundaries( + uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False + ) + # With 1 prompt and mini_batch_size=2, we get 1 mini-batch + assert boundaries == [(0, 2)] + + def test_eval_rejects_noncontiguous_uids(self): + """Test that eval mode still enforces contiguous uids.""" + train_batch_size = 4 + spp = 2 + is_stepwise = False + mini_batch_size = 2 + # Non-contiguous uids: p0 appears at index 0-1 and 4-5 + uids = ["p0", "p0", "p1", "p1", "p0", "p0"] + + with pytest.raises(AssertionError, match="uid 'p0' appears in non-contiguous positions"): + compute_prompt_mini_batch_boundaries( + uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False + ) + + def test_eval_stepwise_partial_batch(self): + """Test eval mode with stepwise training and partial batch.""" + mini_batch_size = 2 + train_batch_size = 4 + spp = 2 + is_stepwise = True + + # Only 3 prompts instead of 4 + uids = _make_uids_stepwise( + [ + ("p0", 2, [3, 2]), # 5 seqs + ("p1", 2, [1, 4]), # 5 seqs + ("p2", 2, [2, 1]), # 3 seqs + ] + ) + + # Should work fine with is_training=False + boundaries = compute_prompt_mini_batch_boundaries( + uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False + ) + # With 3 prompts and mini_batch_size=2, we get 2 mini-batches: + # First: prompts 0-1 (sequences 0-10) + # Second: prompt 2 (sequences 10-13) + assert boundaries == [(0, 10), (10, 13)] + # --------------------------------------------------------------------------- # Tests for MeshDispatch.stage_chunks