54 changes: 25 additions & 29 deletions open_instruct/grpo_fast.py
@@ -1514,44 +1514,36 @@ class PendingQueriesMap:
"""Thread-safe map for tracking pending queries with reference counting."""

def __init__(self):
self._map = {} # dataset_idx -> (query, ground_truth, dataset, count)
# dataset_idx -> [data, count]
self._map: dict[int, list[Any]] = {}
Review comment (Contributor, severity: medium):
Using list[Any] to store [data, count] can be ambiguous and relies on magic-number indexing (e.g., [1] for the count). For better code clarity and type safety, consider using a dataclass to structure this data. A PendingItem dataclass with data and count attributes would make the code more self-documenting and field accesses more explicit (e.g., self._map[dataset_idx].count += 1).
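
A minimal sketch of that suggestion, assuming a hypothetical PendingItem dataclass (the name and fields follow the comment above and are not part of this PR):

# Hypothetical sketch of the reviewer's suggestion; not the code in this PR.
from dataclasses import dataclass
from typing import Any
import threading


@dataclass
class PendingItem:
    """Reference-counted entry for one pending query."""
    data: dict[str, Any]
    count: int = 1


class PendingQueriesMap:
    """Thread-safe map for tracking pending queries with reference counting."""

    def __init__(self):
        self._map: dict[int, PendingItem] = {}
        self._lock = threading.Lock()

    def insert(self, dataset_idx: int, data: dict[str, Any]) -> None:
        """Insert or increment count for a dataset index."""
        with self._lock:
            if dataset_idx in self._map:
                self._map[dataset_idx].count += 1
            else:
                self._map[dataset_idx] = PendingItem(data=data.copy())

    def pop(self, dataset_idx: int) -> dict[str, Any]:
        """Retrieve data and decrement count. Removes entry when count reaches 0."""
        with self._lock:
            if dataset_idx not in self._map:
                raise RuntimeError(f"Dataset index {dataset_idx} not found in pending_queries_map")
            item = self._map[dataset_idx]
            if item.count > 1:
                item.count -= 1
            else:
                del self._map[dataset_idx]
            return item.data.copy()

This keeps the same insert/pop semantics as the PR while replacing the [data, count] list and its index-1 access with named attributes.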

self._lock = threading.Lock()

def insert(self, dataset_idx, query, ground_truth, dataset, raw_query):
def insert(self, dataset_idx: int, data: dict[str, Any]) -> None:
"""Insert or increment count for a dataset index."""
with self._lock:
if dataset_idx in self._map:
# Already exists - just increment count
existing_query, existing_ground_truth, existing_dataset, existing_raw_query, count = self._map[
dataset_idx
]
self._map[dataset_idx] = (
existing_query,
existing_ground_truth,
existing_dataset,
existing_raw_query,
count + 1,
)
self._map[dataset_idx][1] += 1
else:
# New entry - count starts at 1
self._map[dataset_idx] = (query, ground_truth, dataset, raw_query, 1)
self._map[dataset_idx] = [data.copy(), 1]

def pop(self, dataset_idx):
def pop(self, dataset_idx: int) -> dict[str, Any]:
"""Retrieve data and decrement count. Removes entry when count reaches 0."""
with self._lock:
if dataset_idx not in self._map:
raise RuntimeError(f"Dataset index {dataset_idx} not found in pending_queries_map")

query, ground_truth, dataset, raw_query, count = self._map[dataset_idx]
data, count = self._map[dataset_idx]

if count > 1:
# More results expected - just decrement
self._map[dataset_idx] = (query, ground_truth, dataset, raw_query, count - 1)
self._map[dataset_idx][1] -= 1
else:
# Last result - remove entry
del self._map[dataset_idx]

return query, ground_truth, dataset, raw_query
return data.copy()

def __len__(self):
"""Return the number of entries in the map."""
@@ -1730,7 +1722,7 @@ def accumulate_inference_batches(
f"Dataset index: {result.dataset_index}, Epoch: {result.epoch_number}"
)

query, ground_truth, dataset_name, raw_query = pending_queries_map.pop(result.dataset_index)
pending_data = pending_queries_map.pop(result.dataset_index)

# Replenish generation queue with new prompt
if replenish_prompts:
@@ -1756,10 +1748,10 @@ def accumulate_inference_batches(
decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True)

# TODO(finbarrtimbers): Make PendingQueriesMap.pop return a Batch, and add a Batch.repeat method.
k_queries = repeat_each([query], generation_config.n)
k_ground_truths = repeat_each([ground_truth], generation_config.n)
k_datasets = repeat_each([dataset_name], generation_config.n)
k_raw_queries = repeat_each([raw_query], generation_config.n)
k_queries = repeat_each([pending_data["query"]], generation_config.n)
k_ground_truths = repeat_each([pending_data["ground_truth"]], generation_config.n)
k_datasets = repeat_each([pending_data["dataset"]], generation_config.n)
k_raw_queries = repeat_each([pending_data["raw_query"]], generation_config.n)

scores, reward_metrics = asyncio.run(
reward_fn(
@@ -1931,7 +1923,7 @@ def data_preparation_thread(
inference_results_Q: ray_queue.Queue, # Ray queue
param_prompt_Q: ray_queue.Queue,
packed_sequences_Q: Queue,
pending_queries_map: dict,
pending_queries_map: PendingQueriesMap,
args: Args,
tokenizer: PreTrainedTokenizer,
num_training_steps: int,
@@ -2401,15 +2393,19 @@ def add_prompt_to_generator(
is_eval: bool,
) -> None:
"""Split a batch into multiple inference batches and insert individual prompts into queues and mapping."""
query = example[INPUT_IDS_PROMPT_KEY]
ground_truth = example[GROUND_TRUTHS_KEY]
dataset_name = example[VERIFIER_SOURCE_KEY]
raw_query = example[RAW_PROMPT_KEY]
pending_queries_map.insert(example_index, query, ground_truth, dataset_name, raw_query)
pending_queries_map.insert(
example_index,
{
"query": example[INPUT_IDS_PROMPT_KEY],
"ground_truth": example[GROUND_TRUTHS_KEY],
"dataset": example[VERIFIER_SOURCE_KEY],
"raw_query": example[RAW_PROMPT_KEY],
},
)

param_prompt_Q.put(
PromptRequest(
prompt=query,
prompt=example[INPUT_IDS_PROMPT_KEY],
generation_config=generation_config,
epoch_number=epoch_number,
training_step=training_step,
77 changes: 50 additions & 27 deletions open_instruct/test_grpo_fast.py
@@ -262,6 +262,14 @@ def setup_and_add_prompts_to_generator(

return param_prompt_Q, inference_results_Q, pending_queries_map

def create_pending_data(self, query, ground_truth, dataset, raw_query):
return {
"query": query,
"ground_truth": ground_truth,
"dataset": dataset,
"raw_query": raw_query,
}


class TestGrpoFastVLLM(TestGrpoFastBase):
def test_vllm_queue_system_single_prompt(self):
@@ -379,13 +387,13 @@ def test_batch_splitting_and_engine_configurations(self, vllm_num_engines: int,
dataset_index = result.dataset_index

# Get query from pending_queries_map
q, gt, d, raw_q = pending_queries_map.pop(dataset_index)
pending_data = pending_queries_map.pop(dataset_index)

combined_responses.extend(result.responses)
combined_queries.append(q)
combined_raw_queries.append(raw_q)
combined_ground_truths.append(gt)
combined_datasets.append(d)
combined_queries.append(pending_data["query"])
combined_raw_queries.append(pending_data["raw_query"])
combined_ground_truths.append(pending_data["ground_truth"])
combined_datasets.append(pending_data["dataset"])

combined_result = GenerationResult(
responses=combined_responses,
@@ -452,11 +460,11 @@ def test_dataset_index_preservation_through_pipeline(self):
result = inference_results_Q.get()
dataset_index = result.dataset_index

q, gt, d, raw_q = pending_queries_map.pop(dataset_index)
combined_queries.append(q)
combined_raw_queries.append(raw_q)
combined_ground_truths.append(gt)
combined_datasets.append(d)
pending_data = pending_queries_map.pop(dataset_index)
combined_queries.append(pending_data["query"])
combined_raw_queries.append(pending_data["raw_query"])
combined_ground_truths.append(pending_data["ground_truth"])
combined_datasets.append(pending_data["dataset"])

# Verify results
self.assertEqual(combined_queries, queries_next)
@@ -485,7 +493,10 @@ def test_multiple_samples_per_prompt(self, vllm_num_engines: int, num_samples_pe
for idx, query, ground_truth, dataset, raw_query in zip(
dataset_indices, queries_next, ground_truths_next, datasets_next, raw_queries_next
):
pending_queries_map.insert(idx, query, ground_truth, dataset, raw_query)
pending_queries_map.insert(
idx,
self.create_pending_data(query, ground_truth, dataset, raw_query),
)

# Simulate vLLM processing with multiple samples
batch_idx = 0
@@ -507,16 +518,16 @@ def test_multiple_samples_per_prompt(self, vllm_num_engines: int, num_samples_pe
dataset_index = result.dataset_index

# Pop the query data for this specific result - pop multiple times for multiple samples
q, gt, d, raw_q = pending_queries_map.pop(dataset_index)
pending_data = pending_queries_map.pop(dataset_index)
# Pop additional times to handle multiple samples per prompt
for _ in range(num_samples_per_prompt - 1):
pending_queries_map.pop(dataset_index)

combined_responses.extend(result.responses)
combined_queries.append(q)
combined_raw_queries.append(raw_q)
combined_ground_truths.append(gt)
combined_datasets.append(d)
combined_queries.append(pending_data["query"])
combined_raw_queries.append(pending_data["raw_query"])
combined_ground_truths.append(pending_data["ground_truth"])
combined_datasets.append(pending_data["dataset"])

combined_result = GenerationResult(
responses=combined_responses,
@@ -647,10 +658,12 @@ def add_and_remove_entries(thread_id):
for i in range(start_idx, start_idx + entries_per_thread):
pending_queries_map.insert(
i,
f"query_{thread_id}_{i}",
f"truth_{thread_id}_{i}",
f"dataset_{thread_id}_{i}",
f"query_{thread_id}_{i}",
self.create_pending_data(
f"query_{thread_id}_{i}",
f"truth_{thread_id}_{i}",
f"dataset_{thread_id}_{i}",
f"query_{thread_id}_{i}",
),
)
time.sleep(0.0001)

@@ -696,7 +709,10 @@ def test_accumulate_waits_for_all_engines(self):

# Add entries to map
for i in range(num_prompts):
pending_queries_map.insert(i, f"q_{i}", f"t_{i}", f"d_{i}", f"q_{i}")
pending_queries_map.insert(
i,
self.create_pending_data(f"q_{i}", f"t_{i}", f"d_{i}", f"q_{i}"),
)

# Add results from only 3 engines (missing one)
# With individual prompts, we add individual results
@@ -863,7 +879,10 @@ def test_streaming_accumulation_basic(self):

# Insert data into pending_queries_map
for i in range(num_prompts):
pending_queries_map.insert(i, queries[i], ground_truths[i], datasets[i], raw_queries[i])
pending_queries_map.insert(
i,
self.create_pending_data(queries[i], ground_truths[i], datasets[i], raw_queries[i]),
)

# Create mock results - one per prompt
for i in range(num_prompts):
@@ -882,8 +901,8 @@ def test_streaming_accumulation_basic(self):

# Get query for this prompt
dataset_index = result.dataset_index
q, gt, d, raw_q = pending_queries_map.pop(dataset_index)
queries_list.append((q, gt, d, raw_q))
pending_data = pending_queries_map.pop(dataset_index)
queries_list.append(pending_data)

# Verify all results processed
self.assertEqual(len(results_list), expected_results)
Expand All @@ -892,8 +911,7 @@ def test_streaming_accumulation_basic(self):
# Combine in order
combined_queries = []
for i in range(num_prompts):
q, _, _, _ = queries_list[i]
combined_queries.append(q)
combined_queries.append(queries_list[i]["query"])

# Verify order is preserved
self.assertEqual(combined_queries, queries)
Expand All @@ -916,7 +934,12 @@ def test_streaming_with_multiple_samples(self):
# Insert data with reference counting for multiple samples
for i in range(num_prompts):
for _ in range(num_samples):
pending_queries_map.insert(i, queries[i], ground_truths[i], datasets[i], raw_queries[i])
pending_queries_map.insert(
i,
self.create_pending_data(
queries[i], ground_truths[i], datasets[i], raw_queries[i]
),
)

# Create results - one per prompt with multiple samples
for i in range(num_prompts):