Merged

33 commits
d00dcdd  uses add_request (finbarrtimbers, Sep 2, 2025)
7b74746  ran linter (finbarrtimbers, Sep 2, 2025)
1408569  Clean up (finbarrtimbers, Sep 2, 2025)
dba447a  Fixed bug (finbarrtimbers, Sep 2, 2025)
3932bf1  Added duplication (finbarrtimbers, Sep 2, 2025)
7becde9  Added prompt_tokens to metadata. (finbarrtimbers, Sep 2, 2025)
d2cb9f7  Added missing key to metadata (finbarrtimbers, Sep 2, 2025)
01dd23b  Fixed bug where we weren't returning properly. (finbarrtimbers, Sep 2, 2025)
f47879f  Fix script (finbarrtimbers, Sep 2, 2025)
ca5f07d  Added logging (finbarrtimbers, Sep 2, 2025)
3f626fb  fix bug (finbarrtimbers, Sep 2, 2025)
ca79b8b  use clone for SamplingParams (finbarrtimbers, Sep 2, 2025)
084ad77  Fixes to duplication (finbarrtimbers, Sep 2, 2025)
d2e6041  Removed logging. (finbarrtimbers, Sep 2, 2025)
12a4ce7  Cleaned up PR. (finbarrtimbers, Sep 2, 2025)
0813b85  Clean PR (finbarrtimbers, Sep 2, 2025)
e9d6cfb  Removed whitespace (finbarrtimbers, Sep 2, 2025)
417748b  Cleaned up PR (finbarrtimbers, Sep 2, 2025)
1ff890c  Merge branch 'main' into combined-llm-loop (finbarrtimbers, Sep 2, 2025)
d96e7b2  Added comment for cleaner PR. (finbarrtimbers, Sep 2, 2025)
fe8e1bf  Merge branch 'main' into combined-llm-loop (finbarrtimbers, Sep 2, 2025)
f133a8e  Cleaning up PR (finbarrtimbers, Sep 2, 2025)
341a77b  Revert "load pretokenized user query (v0) (#965)" (finbarrtimbers, Sep 3, 2025)
8ebfdf9  Bug fix. (finbarrtimbers, Sep 3, 2025)
e88c2c2  Fixed issue where we weren't setting params right in tools. (finbarrtimbers, Sep 3, 2025)
a929c63  Updated descriptions. (finbarrtimbers, Sep 3, 2025)
ba441f0  Fix ordering. (finbarrtimbers, Sep 3, 2025)
d4e6fd9  Updated tool script with description. (finbarrtimbers, Sep 3, 2025)
4e7cbe3  Fixed use of wrong vllm.SamplingParams. (finbarrtimbers, Sep 3, 2025)
d425a42  Now, tool use should run. (finbarrtimbers, Sep 3, 2025)
826f199  Reapply "load pretokenized user query (v0) (#965)" (finbarrtimbers, Sep 3, 2025)
e214dba  Merge branch 'main' into combined-llm-loop (finbarrtimbers, Sep 3, 2025)
74fb0a6  minor clean up. (finbarrtimbers, Sep 3, 2025)
18 changes: 3 additions & 15 deletions open_instruct/dataset_transformation.py
@@ -763,7 +763,6 @@ def get_tokenizer_tulu_v2_2(tc: "TokenizerConfig"):
DEFAULT_SFT_MESSAGES_KEY = "messages"
GROUND_TRUTHS_KEY = "ground_truth"
VERIFIER_SOURCE_KEY = "dataset"
RAW_PROMPT_KEY = "prompt"


@dataclass
@@ -815,14 +814,8 @@ def tokenizer(self):
ATTENTION_MASK_KEY = "attention_mask"
LABELS_KEY = "labels"
DATASET_ORIGIN_KEY = "dataset_source" # just 'dataset' clashes with RLVR stuff (see VERIFIER_SOURCE_KEY)
TOKENIZED_SFT_DATASET_KEYS = [INPUT_IDS_KEY, ATTENTION_MASK_KEY, LABELS_KEY, RAW_PROMPT_KEY]
TOKENIZED_SFT_DATASET_KEYS_WITH_SOURCE = [
INPUT_IDS_KEY,
ATTENTION_MASK_KEY,
LABELS_KEY,
DATASET_ORIGIN_KEY,
RAW_PROMPT_KEY,
]
TOKENIZED_SFT_DATASET_KEYS = [INPUT_IDS_KEY, ATTENTION_MASK_KEY, LABELS_KEY]
TOKENIZED_SFT_DATASET_KEYS_WITH_SOURCE = [INPUT_IDS_KEY, ATTENTION_MASK_KEY, LABELS_KEY, DATASET_ORIGIN_KEY]


def remove_dataset_source_field(dataset: Dataset) -> Dataset:
@@ -1192,8 +1185,6 @@ def rlvr_tokenize_v1(
row[LABELS_KEY] = labels
row[GROUND_TRUTHS_KEY] = row[ground_truths_key]
row[VERIFIER_SOURCE_KEY] = row[verifier_source_key]
# concatenate all the previous messages as <role>: <content>\n <role>: <content>\n ...
row[RAW_PROMPT_KEY] = "\n".join(f"{msg['role']}: {msg['content']}" for msg in prompt)
return row


@@ -1221,10 +1212,6 @@ def rlvr_tokenize_v2(
row[LABELS_KEY] = labels
row[GROUND_TRUTHS_KEY] = row[ground_truths_key]
row[VERIFIER_SOURCE_KEY] = row[verifier_source_key]
# concatenate all the previous messages as <role>: <content>\n <role>: <content>\n ...
# row[DEFAULT_SFT_MESSAGES_KEY] = prompt
# concatenate all the previous messages as <role>: <content>\n <role>: <content>\n ...
row[RAW_PROMPT_KEY] = "\n".join(f"{msg['role']}: {msg['content']}" for msg in prompt)
# some basic transformations:
# if ground truths is a string, make it a list
if isinstance(row[ground_truths_key], str):
@@ -1686,6 +1673,7 @@ def get_cached_dataset_tulu_with_statistics(
frac_or_num_samples = float(frac_or_num_samples)
else:
frac_or_num_samples = int(frac_or_num_samples)

dataset_config = DatasetConfig(
dataset_name=dataset_name,
dataset_split=dataset_mixer_list_splits[i],
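
For orientation, here is a minimal sketch of what a tokenized SFT row carries under the simplified key constants above. The constant names come from the diff; the example row and its values are hypothetical.

```python
# Key names mirror the constants kept in dataset_transformation.py after this change;
# the example row below is made up and only illustrates the expected shape.
INPUT_IDS_KEY = "input_ids"
ATTENTION_MASK_KEY = "attention_mask"
LABELS_KEY = "labels"
DATASET_ORIGIN_KEY = "dataset_source"

TOKENIZED_SFT_DATASET_KEYS = [INPUT_IDS_KEY, ATTENTION_MASK_KEY, LABELS_KEY]
TOKENIZED_SFT_DATASET_KEYS_WITH_SOURCE = [INPUT_IDS_KEY, ATTENTION_MASK_KEY, LABELS_KEY, DATASET_ORIGIN_KEY]

# A hypothetical tokenized row: note there is no raw "prompt" field any more.
row = {
    INPUT_IDS_KEY: [1, 887, 526, 263, 8444, 20255],
    ATTENTION_MASK_KEY: [1, 1, 1, 1, 1, 1],
    LABELS_KEY: [-100, -100, -100, 263, 8444, 20255],
    DATASET_ORIGIN_KEY: "tulu-3-sft-mixture",
}
assert all(key in row for key in TOKENIZED_SFT_DATASET_KEYS_WITH_SOURCE)
```
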
11 changes: 2 additions & 9 deletions open_instruct/ground_truth_utils.py
@@ -38,14 +38,6 @@

logger = logger_utils.setup_logger(__name__)

# remove excessive logging from liteLLM
logging.getLogger("LiteLLM").setLevel(logging.WARNING)
logging.getLogger("litellm").setLevel(logging.ERROR)
logging.getLogger("litellm.cost_calculator").setLevel(logging.CRITICAL)
logging.getLogger("litellm._client").setLevel(logging.CRITICAL)
logging.getLogger("cost_calculator").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)


@dataclass
class VerifierConfig:
@@ -671,7 +663,8 @@ async def async_call(
for attempt in range(max_retries):
# judges the quality of a response
try:
messages = build_messages(prompt)
system_prompt = "Do not generate text between the <think> and </think> tags." # "You are a concise assistant who gives very short explanations before giving a quality score."
messages = build_messages(prompt, system_prompt)

# Faeze: check if the request would exceed context window
# Import the context window checker
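
The judge call above now threads a system prompt into build_messages. That helper's implementation is not part of this diff, so the following is only a sketch, assuming it assembles OpenAI-style chat messages with the system prompt prepended.

```python
# Stand-in for build_messages, which is defined elsewhere in the codebase; this sketch
# only illustrates the call shape used in async_call above.
from typing import Dict, List, Optional


def build_messages(prompt: str, system_prompt: Optional[str] = None) -> List[Dict[str, str]]:
    """Assemble chat-style messages, prepending an optional system prompt."""
    messages: List[Dict[str, str]] = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})
    return messages


system_prompt = "Do not generate text between the <think> and </think> tags."
messages = build_messages("Judge the response below...", system_prompt)
# -> [{"role": "system", "content": ...}, {"role": "user", "content": ...}]
```
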
58 changes: 18 additions & 40 deletions open_instruct/grpo_fast.py
@@ -87,7 +87,6 @@
from open_instruct.dataset_transformation import (
GROUND_TRUTHS_KEY,
INPUT_IDS_PROMPT_KEY,
RAW_PROMPT_KEY,
VERIFIER_SOURCE_KEY,
TokenizerConfig,
get_cached_dataset_tulu,
@@ -120,6 +119,7 @@
calibrate_checkpoint_state_dir,
clean_last_n_checkpoints_deepspeed,
download_latest_checkpoint_from_gs,
extract_user_query,
get_beaker_whoami,
get_eval_ds_config,
get_optimizer_grouped_parameters,
@@ -487,7 +487,6 @@ def next_batch(dataset_indices: List[int], dataset: datasets.Dataset) -> Batch:
queries=data_next[INPUT_IDS_PROMPT_KEY],
ground_truths=data_next[GROUND_TRUTHS_KEY],
datasets=data_next[VERIFIER_SOURCE_KEY],
raw_queries=data_next[RAW_PROMPT_KEY],
indices=dataset_indices,
)

@@ -1280,63 +1279,45 @@ def __init__(self):
self._map = {} # dataset_idx -> (query, ground_truth, dataset, count)
self._lock = threading.Lock()

def insert(self, dataset_idx, query, ground_truth, dataset, raw_query):
def insert(self, dataset_idx, query, ground_truth, dataset):
"""Insert or increment count for a dataset index."""
with self._lock:
if dataset_idx in self._map:
# Already exists - just increment count
existing_query, existing_ground_truth, existing_dataset, existing_raw_query, count = self._map[
dataset_idx
]
self._map[dataset_idx] = (
existing_query,
existing_ground_truth,
existing_dataset,
existing_raw_query,
count + 1,
)
existing_query, existing_ground_truth, existing_dataset, count = self._map[dataset_idx]
self._map[dataset_idx] = (existing_query, existing_ground_truth, existing_dataset, count + 1)
else:
# New entry - count starts at 1
self._map[dataset_idx] = (query, ground_truth, dataset, raw_query, 1)
self._map[dataset_idx] = (query, ground_truth, dataset, 1)

def insert_many(self, dataset_indices, queries, ground_truths, datasets, raw_queries):
def insert_many(self, dataset_indices, queries, ground_truths, datasets):
"""Insert or increment count for multiple dataset indices at once."""
with self._lock:
for i, dataset_idx in enumerate(dataset_indices):
current_raw_query = raw_queries[i]

if dataset_idx in self._map:
# Already exists - just increment count
existing_query, existing_ground_truth, existing_dataset, existing_raw_query, count = self._map[
dataset_idx
]
self._map[dataset_idx] = (
existing_query,
existing_ground_truth,
existing_dataset,
existing_raw_query,
count + 1,
)
existing_query, existing_ground_truth, existing_dataset, count = self._map[dataset_idx]
self._map[dataset_idx] = (existing_query, existing_ground_truth, existing_dataset, count + 1)
else:
# New entry - count starts at 1
self._map[dataset_idx] = (queries[i], ground_truths[i], datasets[i], current_raw_query, 1)
self._map[dataset_idx] = (queries[i], ground_truths[i], datasets[i], 1)

def pop(self, dataset_idx):
"""Retrieve data and decrement count. Removes entry when count reaches 0."""
with self._lock:
if dataset_idx not in self._map:
raise RuntimeError(f"Dataset index {dataset_idx} not found in pending_queries_map")

query, ground_truth, dataset, raw_query, count = self._map[dataset_idx]
query, ground_truth, dataset, count = self._map[dataset_idx]

if count > 1:
# More results expected - just decrement
self._map[dataset_idx] = (query, ground_truth, dataset, raw_query, count - 1)
self._map[dataset_idx] = (query, ground_truth, dataset, count - 1)
else:
# Last result - remove entry
del self._map[dataset_idx]

return query, ground_truth, dataset, raw_query
return query, ground_truth, dataset

def __len__(self):
"""Return the number of entries in the map."""
@@ -1389,7 +1370,6 @@ def accumulate_inference_batches(
all_queries = []
all_ground_truths = []
all_datasets = []
all_raw_queries = []
for i in tqdm(
range(args.vllm_num_engines),
total=args.vllm_num_engines,
@@ -1421,20 +1401,17 @@
batch_queries = []
batch_ground_truths = []
batch_datasets = []
batch_raw_queries = []

for dataset_idx in dataset_indices:
query, ground_truth, dataset, raw_query = pending_queries_map.pop(dataset_idx)
query, ground_truth, dataset = pending_queries_map.pop(dataset_idx)
batch_queries.append(query)
batch_ground_truths.append(ground_truth)
batch_datasets.append(dataset)
batch_raw_queries.append(raw_query)

results.append(result)
all_queries.extend(batch_queries)
all_ground_truths.extend(batch_ground_truths)
all_datasets.extend(batch_datasets)
all_raw_queries.extend(batch_raw_queries)

# Combine all results into a single GenerationResult
combined_responses = []
@@ -1496,7 +1473,6 @@ def accumulate_inference_batches(
queries=all_queries,
ground_truths=all_ground_truths,
datasets=all_datasets,
raw_queries=all_raw_queries,
indices=None, # Not meaningful for combined results
)
return combined_result, batch
@@ -1533,7 +1509,6 @@ def data_preparation_thread(
queries=repeat_each(batch.queries, args.num_samples_per_prompt_rollout),
ground_truths=repeat_each(batch.ground_truths, args.num_samples_per_prompt_rollout),
datasets=repeat_each(batch.datasets, args.num_samples_per_prompt_rollout),
raw_queries=repeat_each(batch.raw_queries, args.num_samples_per_prompt_rollout),
indices=repeat_each(batch.indices, args.num_samples_per_prompt_rollout) if batch.indices else None,
)
good_outputs = [
@@ -1555,7 +1530,8 @@

with Timer("🔥 [Data Preparation Thread] Decoding responses", noop=True):
decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True)
decoded_queries = batch.raw_queries
decoded_queries = tokenizer.batch_decode(batch.queries, skip_special_tokens=True)
decoded_queries = [extract_user_query(query) for query in decoded_queries]
stop_rate = sum(int(finish_reason == "stop") for finish_reason in result.finish_reasons) / len(
result.finish_reasons
)
@@ -2071,7 +2047,7 @@ def split_and_insert_batch(

# Store prompts in the map using thread-safe insert_many
pending_queries_map.insert_many(
sub_batch.indices, sub_batch.queries, sub_batch.ground_truths, sub_batch.datasets, sub_batch.raw_queries
sub_batch.indices, sub_batch.queries, sub_batch.ground_truths, sub_batch.datasets
)

# Use PromptRequest for Ray queue with batch-specific dataset_index list
@@ -2883,6 +2859,8 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa
actor_manager,
checkpoint_state,
)
except Exception as e:
logger.error(f"Error in run_training: {e}", exc_info=True)
finally:
cleanup_training_resources(
stop_event, executor, [inference_results_Q, param_prompt_Q, evaluation_inference_results_Q], actor_manager
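
The map above reference-counts in-flight prompts: one insert per expected result, one pop per result that arrives, and the entry disappears with the last pop. A toy walk-through, assuming the class shown in the diff is named PendingQueriesMap and that two engine results are expected for the same dataset index:

```python
# Hypothetical walk-through; the class name PendingQueriesMap and all values are assumed.
pending_queries_map = PendingQueriesMap()

# Two results are expected for dataset index 7, so it is inserted twice (count == 2).
pending_queries_map.insert(7, query=[101, 2023], ground_truth="42", dataset="gsm8k")
pending_queries_map.insert(7, query=[101, 2023], ground_truth="42", dataset="gsm8k")
assert len(pending_queries_map) == 1  # one entry, reference count of 2

first = pending_queries_map.pop(7)   # count drops to 1, entry is kept
second = pending_queries_map.pop(7)  # count reaches 0, entry is removed
assert first == second == ([101, 2023], "42", "gsm8k")
assert len(pending_queries_map) == 0
# A third pop(7) would raise RuntimeError, since the index is no longer pending.
```
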
23 changes: 10 additions & 13 deletions open_instruct/judge_utils.py
@@ -31,20 +31,19 @@
AI assistant to the user query displayed below.

Notes:
- Your evaluation should consider factors such as the helpfulness, relevance, accuracy, creativity, appropriate level of detail, and how well the response satisfies the user's explicit constraints or accurately follows their instructions.
- If there is a system prompt, ensure the AI answer prioritizes following it.
- Begin your evaluation by providing a short explanation.
- Be as objective as possible. After providing your short explanation, please output a score on a scale of 1 to 10.
- Please adhere to the following format.
1- Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of the response.
2- Begin your evaluation by providing a short explanation.
3- Be as objective as possible. After providing your explanation, please rate the response on a scale of 1 to 10.

[Conversation History]
[Query]
{input}

[AI Answer]
[Response]
{output}

[Your judgement]
Respond in JSON format. {{"REASONING": "[...]", "SCORE": "<your-score>"}}"""
Respond in JSON format. {{"REASONING": "[...]", "SCORE": "<your-score>"}}
"""


general_quality_rubric_template = """
@@ -77,18 +76,16 @@
general_quality_ref_template = """
### Task Description
Please act as an impartial judge and evaluate the quality of the answer provided by an
AI assistant to the conversation history leading up to the answer displayed below.
Judge whether the provided answer is good by comparing it to the reference answer.
AI assistant to the user query displayed below. Judge whether the provided answer is good by comparing it to the reference answer.

Notes:
- Besides comparing to the reference answer, your evaluation should consider factors such as the helpfulness, relevance, accuracy, creativity, appropriate level of detail, and how well the response satisfies the user's explicit constraints or accurately follows their instructions.
- Besides comparing to the reference answer, your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and appropriate level of detail of the response.
- Note that sometimes the reference answer is not the only answer. So any valid variation of the reference answer is also acceptable and can get a full score.
- If there is a system prompt, ensure the AI answer prioritizes following it.
- Begin your evaluation by providing a short explanation.
- Be as objective as possible. After providing your short explanation, please output a score on a scale of 1 to 10.
- Please adhere to the following format.

[Conversation History]
[Query]
{input}

[AI Answer]
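
For context, this is roughly how a judge template of the shape above gets used: fill {input} and {output}, send the prompt to the judge, and parse the JSON verdict. The template body is abridged and the binding name general_quality_template is assumed, since the assignment sits above the visible hunk; the parsing is a sketch, not the repository's code.

```python
# Abridged, assumed version of the judge prompt shown above; only the visible pieces
# are reproduced, and all concrete values below are made up.
import json

general_quality_template = """\
Please act as an impartial judge and evaluate the quality of the answer provided by an
AI assistant to the user query displayed below.

[Query]
{input}

[Response]
{output}

Respond in JSON format. {{"REASONING": "[...]", "SCORE": "<your-score>"}}
"""

prompt = general_quality_template.format(input="What is 2 + 2?", output="4")

raw_reply = '{"REASONING": "Correct and concise.", "SCORE": "9"}'  # made-up judge output
verdict = json.loads(raw_reply)
score = float(verdict["SCORE"])  # 9.0
```
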
4 changes: 0 additions & 4 deletions open_instruct/model_utils.py
@@ -54,7 +54,6 @@ class Batch:
queries: List[List[int]]
ground_truths: List[List[int]]
datasets: List[str]
raw_queries: Optional[List[str]]
indices: Optional[List[int]]

def __getitem__(self, key: Union[slice, int, List[int]]) -> "Batch":
@@ -65,7 +64,6 @@ def __getitem__(self, key: Union[slice, int, List[int]]) -> "Batch":
queries=self.queries[key],
ground_truths=self.ground_truths[key],
datasets=self.datasets[key],
raw_queries=self.raw_queries[key],
indices=self.indices[key] if self.indices else None,
)
elif isinstance(key, int):
@@ -74,7 +72,6 @@ def __getitem__(self, key: Union[slice, int, List[int]]) -> "Batch":
queries=[self.queries[key]],
ground_truths=[self.ground_truths[key]],
datasets=[self.datasets[key]],
raw_queries=[self.raw_queries[key]],
indices=[self.indices[key]] if self.indices else None,
)
else:
@@ -83,7 +80,6 @@ def __getitem__(self, key: Union[slice, int, List[int]]) -> "Batch":
queries=[self.queries[i] for i in key],
ground_truths=[self.ground_truths[i] for i in key],
datasets=[self.datasets[i] for i in key],
raw_queries=[self.raw_queries[i] for i in key],
indices=[self.indices[i] for i in key] if self.indices else None,
)

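
Finally, a small sketch of how the slimmed-down Batch behaves when indexed. The field values are invented, and it assumes the four fields shown above make up the whole dataclass; the slicing semantics follow the __getitem__ in the diff.

```python
# Hypothetical values; only the indexing behaviour of Batch is being illustrated.
batch = Batch(
    queries=[[101, 7592], [101, 2129], [101, 2054]],
    ground_truths=[["4"], ["7"], ["12"]],
    datasets=["gsm8k", "gsm8k", "math"],
    indices=[0, 1, 2],
)

first = batch[0]        # int index -> Batch with single-element lists
tail = batch[1:]        # slice -> Batch sliced like its underlying lists
picked = batch[[0, 2]]  # list of indices -> Batch gathered by position
assert first.queries == [[101, 7592]]
assert picked.datasets == ["gsm8k", "math"]
```
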