-
Notifications
You must be signed in to change notification settings - Fork 330
Fix AssertionError during eval when val set size is not divisible by train_batch_size #1589
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
7185ef1
4adaade
6d7f74c
dbb781d
bde201e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -2,7 +2,6 @@ | |||||||||||||||||||||||||||||
| from typing import List, Optional, Tuple | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||
| from jaxtyping import Float, Integer | ||||||||||||||||||||||||||||||
| from transformers import AutoTokenizer | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| logger = logging.getLogger(__name__) | ||||||||||||||||||||||||||||||
|
|
@@ -39,13 +38,13 @@ def convert_prompts_responses_to_batch_tensors( | |||||||||||||||||||||||||||||
| rollout_expert_indices: Optional[List[List[List[List[int]]]]] = None, | ||||||||||||||||||||||||||||||
| max_seq_len: Optional[int] = None, | ||||||||||||||||||||||||||||||
| ) -> Tuple[ | ||||||||||||||||||||||||||||||
| Float[torch.Tensor, "batch seq_len"], | ||||||||||||||||||||||||||||||
| Float[torch.Tensor, "batch seq_len"], | ||||||||||||||||||||||||||||||
| Float[torch.Tensor, "batch response_len"], | ||||||||||||||||||||||||||||||
| Float[torch.Tensor, "batch response_len"], | ||||||||||||||||||||||||||||||
| Float[torch.Tensor, "batch response_len"], | ||||||||||||||||||||||||||||||
| Optional[Float[torch.Tensor, "batch response_len"]], | ||||||||||||||||||||||||||||||
| Optional[Integer[torch.Tensor, "batch seq_len layer_num topk"]], | ||||||||||||||||||||||||||||||
| torch.Tensor, | ||||||||||||||||||||||||||||||
| torch.Tensor, | ||||||||||||||||||||||||||||||
| torch.Tensor, | ||||||||||||||||||||||||||||||
| torch.Tensor, | ||||||||||||||||||||||||||||||
| torch.Tensor, | ||||||||||||||||||||||||||||||
| Optional[torch.Tensor], | ||||||||||||||||||||||||||||||
| Optional[torch.Tensor], | ||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The detailed shape annotations for the return types were removed. It is recommended to keep these for better maintainability and readability.
Suggested change
|
||||||||||||||||||||||||||||||
| ]: | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
| Convert prompts and responses to batch tensors for training. | ||||||||||||||||||||||||||||||
|
|
@@ -196,6 +195,7 @@ def compute_prompt_mini_batch_boundaries( | |||||||||||||||||||||||||||||
| train_batch_size: int, | ||||||||||||||||||||||||||||||
| is_stepwise: bool, | ||||||||||||||||||||||||||||||
| n_samples_per_prompt: int, | ||||||||||||||||||||||||||||||
| is_training: bool = True, | ||||||||||||||||||||||||||||||
| ) -> List[Tuple[int, int]]: | ||||||||||||||||||||||||||||||
| """Compute mini-batch ``(start, end)`` slices from a flat ``uids`` list. | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
@@ -206,10 +206,12 @@ def compute_prompt_mini_batch_boundaries( | |||||||||||||||||||||||||||||
| train_batch_size: Number of prompts in a training batch. For sanity check. | ||||||||||||||||||||||||||||||
| is_stepwise: Whether the training is step-wise. For sanity check. | ||||||||||||||||||||||||||||||
| n_samples_per_prompt: how many samples per prompt. For sanity check. | ||||||||||||||||||||||||||||||
| is_training: Whether this is a training batch (strict validation) or eval batch (allows partial batches). | ||||||||||||||||||||||||||||||
| Defaults to True for backward compatibility. | ||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||
| List of (start, end) indices of the mini-batches. The length of the list is the number of | ||||||||||||||||||||||||||||||
| mini-batches, guaranteed to be `train_batch_size // mini_batch_size` regardless of whether | ||||||||||||||||||||||||||||||
| the training is step-wise or not. | ||||||||||||||||||||||||||||||
| mini-batches, guaranteed to be `train_batch_size // mini_batch_size` during training, but may differ | ||||||||||||||||||||||||||||||
| during evaluation if the final batch is partial. | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| Consecutive equal entries in ``uids`` belong to the same prompt. Each mini batch spans exactly | ||||||||||||||||||||||||||||||
| ``mini_batch_size`` prompts (the last may be smaller if the total prompt count is not divisible | ||||||||||||||||||||||||||||||
|
|
@@ -244,23 +246,35 @@ def compute_prompt_mini_batch_boundaries( | |||||||||||||||||||||||||||||
| prompt_end_indices.append(i) | ||||||||||||||||||||||||||||||
| prompt_end_indices.append(len(uids)) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # seen_uids should equal to the number of prompts and equal to `train_batch_size` | ||||||||||||||||||||||||||||||
| # Check that num_prompts matches expected batch size | ||||||||||||||||||||||||||||||
| num_prompts = len(prompt_end_indices) | ||||||||||||||||||||||||||||||
| assert num_prompts == train_batch_size and len(seen_uids) == train_batch_size | ||||||||||||||||||||||||||||||
| assert train_batch_size % mini_batch_size == 0 | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Compute boundaries. | ||||||||||||||||||||||||||||||
| if is_training: | ||||||||||||||||||||||||||||||
| assert num_prompts == train_batch_size and len(seen_uids) == train_batch_size, ( | ||||||||||||||||||||||||||||||
| f"Expected {train_batch_size} prompts in training batch, got {num_prompts}." | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| assert train_batch_size % mini_batch_size == 0 | ||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||
| if num_prompts != train_batch_size: | ||||||||||||||||||||||||||||||
| logger.warning( | ||||||||||||||||||||||||||||||
| f"Partial batch detected during eval: got {num_prompts} prompts but " | ||||||||||||||||||||||||||||||
| f"train_batch_size={train_batch_size}. Using actual batch size for mini-batch boundaries." | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Logging a warning for partial batches during evaluation is likely too noisy. When …
Suggested change
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Compute boundaries. Handle partial batches during eval. | ||||||||||||||||||||||||||||||
| boundaries: List[Tuple[int, int]] = [] | ||||||||||||||||||||||||||||||
| start_seq = 0 | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| for i in range(0, num_prompts, mini_batch_size): | ||||||||||||||||||||||||||||||
| end_prompt_idx = i + mini_batch_size - 1 # i + mini_batch_size is next mini-batch's first prompt's end index | ||||||||||||||||||||||||||||||
| end_prompt_idx = min(i + mini_batch_size - 1, num_prompts - 1) | ||||||||||||||||||||||||||||||
| end_seq = prompt_end_indices[end_prompt_idx] | ||||||||||||||||||||||||||||||
| boundaries.append((start_seq, end_seq)) | ||||||||||||||||||||||||||||||
| start_seq = end_seq | ||||||||||||||||||||||||||||||
| assert len(boundaries) == train_batch_size // mini_batch_size | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if is_training: | ||||||||||||||||||||||||||||||
| assert len(boundaries) == train_batch_size // mini_batch_size | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Assert that the mini-batch boundaries are uniform for non-step-wise training. | ||||||||||||||||||||||||||||||
| if not is_stepwise: | ||||||||||||||||||||||||||||||
| if not is_stepwise and is_training: | ||||||||||||||||||||||||||||||
| expected_num_seq_in_mini_batch = n_samples_per_prompt * mini_batch_size | ||||||||||||||||||||||||||||||
| for i, (start, end) in enumerate(boundaries): | ||||||||||||||||||||||||||||||
| assert start == i * expected_num_seq_in_mini_batch | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -199,6 +199,85 @@ def test_same_step_count_as_non_stepwise(self): | |
|
|
||
| assert len(stepwise_bounds) == len(non_stepwise_bounds) == 2 | ||
|
|
||
| def test_eval_partial_batch_nonstepwise(self): | ||
| """Test eval mode with partial batches during non-stepwise training. | ||
|
|
||
| This addresses the issue where evaluation crashes when val set size is | ||
| not divisible by train_batch_size. With is_training=False, partial | ||
| batches should be allowed. | ||
| """ | ||
| train_batch_size = 4 | ||
| spp = 2 | ||
| is_stepwise = False | ||
| mini_batch_size = 2 | ||
|
|
||
| # Only 3 prompts instead of 4 (partial batch) | ||
| uids = ["p0", "p0", "p1", "p1", "p2", "p2"] | ||
|
|
||
| # Should work fine with is_training=False | ||
| boundaries = compute_prompt_mini_batch_boundaries( | ||
| uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False | ||
| ) | ||
| # With 3 prompts and mini_batch_size=2, we get 2 mini-batches: | ||
| # First mini-batch: prompts 0-1 (sequences 0-4) | ||
| # Second mini-batch: prompt 2 (sequences 4-6) | ||
| assert boundaries == [(0, 4), (4, 6)] | ||
|
|
||
| def test_eval_partial_batch_single_minibatch(self): | ||
| """Test eval mode with partial batch that fits in single mini-batch.""" | ||
| train_batch_size = 4 | ||
| spp = 2 | ||
| is_stepwise = False | ||
| mini_batch_size = 2 | ||
|
|
||
| # Only 1 prompt instead of 4 (very partial batch) | ||
| uids = ["p0", "p0"] | ||
|
|
||
| boundaries = compute_prompt_mini_batch_boundaries( | ||
| uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False | ||
| ) | ||
| # With 1 prompt and mini_batch_size=2, we get 1 mini-batch | ||
| assert boundaries == [(0, 2)] | ||
|
|
||
| def test_eval_rejects_noncontiguous_uids(self): | ||
| """Test that eval mode still enforces contiguous uids.""" | ||
| train_batch_size = 4 | ||
| spp = 2 | ||
| is_stepwise = False | ||
| mini_batch_size = 2 | ||
| # Non-contiguous uids: p0 appears at index 0-1 and 4-5 | ||
| uids = ["p0", "p0", "p1", "p1", "p0", "p0"] | ||
|
|
||
| with pytest.raises(AssertionError, match="uid 'p0' appears in non-contiguous positions"): | ||
| compute_prompt_mini_batch_boundaries( | ||
| uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False | ||
| ) | ||
|
|
||
| def test_eval_stepwise_partial_batch(self): | ||
| """Test eval mode with stepwise training and partial batch.""" | ||
| mini_batch_size = 2 | ||
| train_batch_size = 4 | ||
| spp = 2 | ||
| is_stepwise = True | ||
|
|
||
| # Only 3 prompts instead of 4 | ||
| uids = _make_uids_stepwise( | ||
| [ | ||
| ("p0", 2, [3, 2]), # 5 seqs | ||
| ("p1", 2, [1, 4]), # 5 seqs | ||
| ("p2", 2, [2, 1]), # 3 seqs | ||
| ] | ||
| ) | ||
|
|
||
| # Should work fine with is_training=False | ||
| boundaries = compute_prompt_mini_batch_boundaries( | ||
| uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False | ||
| ) | ||
| # With 3 prompts and mini_batch_size=2, we get 2 mini-batches: | ||
| # First: prompts 0-1 (sequences 0-10) | ||
| # Second: prompt 2 (sequences 10-13) | ||
| assert boundaries == [(0, 10), (10, 13)] | ||
|
|
||
| # Non-step-wise boundaries should be uniform | ||
| assert non_stepwise_bounds == [(0, 640), (640, 1280)] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The last two lines of …
devin-ai-integration[bot] marked this conversation as resolved.
Outdated
|
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `jaxtyping` imports and associated type annotations were removed in this file. These annotations provide valuable documentation regarding tensor shapes and dtypes, which is particularly helpful in complex batching logic. Unless there is a specific reason for their removal, they should be retained to maintain code clarity and type safety.