From 5f15aaa1a6a782135b18d100bb0a0d92bc286228 Mon Sep 17 00:00:00 2001 From: Leo Dirac Date: Wed, 5 Feb 2025 18:03:00 -0800 Subject: [PATCH 01/16] First crack at tool hook. --- trl/trainer/qwen_grpo_trainer.py | 46 +++++++++++++++++++++++++++++--- 1 file changed, 42 insertions(+), 4 deletions(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index 303746643b4..df02ce66689 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -16,6 +16,7 @@ import textwrap import warnings from collections import defaultdict +from dataclasses import dataclass from typing import Any, Callable, Optional, Union from unittest.mock import patch @@ -62,6 +63,12 @@ RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]] +@dataclass +class ToolDefinition: + """Basic metadata that the trainer needs to know about the tools.""" + stop_string: str + call_tool: Callable[[torch.Tensor], torch.Tensor] + class QwenGRPOTrainer(Trainer): """ Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the @@ -161,6 +168,7 @@ def __init__( callbacks: Optional[list[TrainerCallback]] = None, optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), peft_config: Optional["PeftConfig"] = None, + tool_defn: Optional[ToolDefinition] = None, ): # Args if args is None: @@ -217,6 +225,8 @@ def __init__( ) self.reward_funcs = reward_funcs + self.tool_defn = tool_defn + # Reward processing class if reward_processing_classes is None: reward_processing_classes = [None] * len(reward_funcs) @@ -327,12 +337,19 @@ def data_collator(features): # No data collation is needed in GRPO # synchronize all processes after vLLM has been fully initialized. self.accelerator.wait_for_everyone() else: + # No vLLM, so we use the regular generation config + + stop_strings = None + if self.tool_defn: + stop_strings = [self.tool_defn.stop_string] + self.generation_config = GenerationConfig( max_new_tokens=self.max_completion_length, do_sample=True, temperature=args.temperature, num_return_sequences=self.num_generations, pad_token_id=processing_class.tokenizer.pad_token_id, + stop_strings=stop_strings, ) # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the @@ -366,6 +383,29 @@ def _set_signature_columns_if_needed(self): def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]: return inputs + def _generate_completions_with_tools(self, model: PreTrainedModel, prompt_inputs: dict[str, torch.Tensor]) -> torch.Tensor: + """Iterates between generation and tool calling. + + Note this is currently only called from the non-vLLM path + """ + + # Loop until tool isn't called. + out = [] + while True: + prompt_completion_ids = model.generate( + **prompt_inputs, generation_config=self.generation_config + ) + out.append(prompt_completion_ids) + if self.tool_defn: + # Check if the stop string is in the completions + import pdb; pdb.set_trace() + if self.tool_defn.stop_string in prompt_completion_ids: + # Call the tool. + tool_response = self.tool_defn.call_tool(prompt_completion_ids) + out.append(tool_response) + # TODO: Feed the tool response back into the model. 
+ return torch.cat(out, dim=0) + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): if return_outputs: raise ValueError("The GRPOTrainer does not support returning outputs") @@ -416,11 +456,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N prompt_inputs_repeated = torch.repeat_interleave(prompt_inputs["input_ids"], self.num_generations, dim=0) prompt_completion_ids = torch.cat([prompt_inputs_repeated, completion_ids], dim=1) else: - # Regular generation path + # Regular generation path (not using vLLM) with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: - prompt_completion_ids = unwrapped_model.generate( - **prompt_inputs, generation_config=self.generation_config - ) + prompt_completion_ids = self._generate_completions(unwrapped_model, prompt_inputs) prompt_length = prompt_inputs["input_ids"].size(1) completion_ids = prompt_completion_ids[:, prompt_length:] From 29a8ef1b8e5fb92325af9ed5adf793c0f34110e2 Mon Sep 17 00:00:00 2001 From: Leo Dirac Date: Wed, 5 Feb 2025 18:15:42 -0800 Subject: [PATCH 02/16] It's calling my completion now. --- trl/trainer/__init__.py | 13 +++++++++++-- trl/trainer/qwen_grpo_trainer.py | 6 +++--- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py index 930caf59425..8a21969060e 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -36,7 +36,7 @@ "dpo_trainer": ["DPOTrainer"], "gkd_config": ["GKDConfig"], "gkd_trainer": ["GKDTrainer"], - "qwen_grpo_trainer": ["QwenGRPOTrainer"], + "qwen_grpo_trainer": ["QwenGRPOTrainer", "ToolDefinition"], "grpo_config": ["GRPOConfig"], "grpo_trainer": ["GRPOTrainer"], "iterative_sft_trainer": ["IterativeSFTTrainer"], @@ -134,7 +134,7 @@ from .ppo_trainer import PPOTrainer from .prm_config import PRMConfig from .prm_trainer import PRMTrainer - from .qwen_grpo_trainer import QwenGRPOTrainer + from .qwen_grpo_trainer import QwenGRPOTrainer, ToolDefinition from .reward_config import RewardConfig from .reward_trainer import RewardTrainer from .rloo_config import RLOOConfig @@ -164,3 +164,12 @@ import sys sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + +# Importing from qwen_grpo_trainer to expose them at the package level. +from .qwen_grpo_trainer import QwenGRPOTrainer, ToolDefinition + +# Define __all__ to explicitly list the public API of this package. +__all__ = [ + "QwenGRPOTrainer", + "ToolDefinition", +] diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index df02ce66689..e05031f7c6b 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -69,6 +69,7 @@ class ToolDefinition: stop_string: str call_tool: Callable[[torch.Tensor], torch.Tensor] + class QwenGRPOTrainer(Trainer): """ Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the @@ -388,9 +389,8 @@ def _generate_completions_with_tools(self, model: PreTrainedModel, prompt_inputs Note this is currently only called from the non-vLLM path """ - - # Loop until tool isn't called. out = [] + # Loop until tool isn't called. 
while True: prompt_completion_ids = model.generate( **prompt_inputs, generation_config=self.generation_config @@ -458,7 +458,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N else: # Regular generation path (not using vLLM) with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: - prompt_completion_ids = self._generate_completions(unwrapped_model, prompt_inputs) + prompt_completion_ids = self._generate_completions_with_tools(unwrapped_model, prompt_inputs) prompt_length = prompt_inputs["input_ids"].size(1) completion_ids = prompt_completion_ids[:, prompt_length:] From c1c8a925812b9a9c380a977ff8add466122cb95c Mon Sep 17 00:00:00 2001 From: Leo Dirac Date: Wed, 5 Feb 2025 21:29:03 -0800 Subject: [PATCH 03/16] Progress - sorta maybe almost incorporating the tool response into the conversation. --- trl/trainer/qwen_grpo_trainer.py | 81 ++++++++++++++++++++++++++------ 1 file changed, 67 insertions(+), 14 deletions(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index e05031f7c6b..2830ec7de22 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -69,6 +69,10 @@ class ToolDefinition: stop_string: str call_tool: Callable[[torch.Tensor], torch.Tensor] + def completion_has_tool_call(self, completion_str: str) -> bool: + """Check if the completion has a tool call.""" + return self.stop_string in completion_str + class QwenGRPOTrainer(Trainer): """ @@ -384,27 +388,59 @@ def _set_signature_columns_if_needed(self): def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]: return inputs - def _generate_completions_with_tools(self, model: PreTrainedModel, prompt_inputs: dict[str, torch.Tensor]) -> torch.Tensor: + def _generate_completions( + self, model: PreTrainedModel, prompt_inputs: dict[str, torch.Tensor] + ) -> torch.Tensor: + """Generate completions using the model.""" + prompt_completion_ids = model.generate( + **prompt_inputs, + generation_config=self.generation_config, + tokenizer=self.processing_class.tokenizer, + ) + return prompt_completion_ids + + def _add_response_to_prompt_inputs(self, prompt_inputs: dict[str, torch.Tensor], response: torch.Tensor) -> dict[str, torch.Tensor]: + """Add the response to the prompt inputs.""" + prompt_inputs["input_ids"] = torch.cat([prompt_inputs["input_ids"], response], dim=1) + prompt_inputs["attention_mask"] = torch.cat([prompt_inputs["attention_mask"], torch.ones_like(tool_response)], dim=1) + return prompt_inputs + + def _generate_completions_with_tools( + self, model: PreTrainedModel, prompt_inputs: dict[str, torch.Tensor] + ) -> torch.Tensor: """Iterates between generation and tool calling. Note this is currently only called from the non-vLLM path + + prompt_inputs is a dict with the following keys: + - input_ids: [1, 710] ints. Some stuff at the beginning and the end, the middle full of 151655 + - attention_mask: [1, 710] ints. All 1 + - pixel_values: 2024x1176 floats. The image. + - image_grid_thw: a 1x3 tensor with values: [1, 46, 44]. + (Note that 46*44 is 2024). """ out = [] # Loop until tool isn't called. while True: - prompt_completion_ids = model.generate( - **prompt_inputs, generation_config=self.generation_config - ) + prompt_completion_ids = self._generate_completions(model, prompt_inputs) + # prompt_completion_ids is a tensor of shape (B, L) + # Where B is (3) for the number of generations. + # L is 875 here. It's just token ids, nothing else. 
out.append(prompt_completion_ids) - if self.tool_defn: - # Check if the stop string is in the completions - import pdb; pdb.set_trace() - if self.tool_defn.stop_string in prompt_completion_ids: - # Call the tool. - tool_response = self.tool_defn.call_tool(prompt_completion_ids) - out.append(tool_response) - # TODO: Feed the tool response back into the model. - return torch.cat(out, dim=0) + # Check if the stop string is in the completions + # We need to convert the tensor to a string. + prompt_completion_str = self.processing_class.tokenizer.decode(prompt_completion_ids[0], skip_special_tokens=True) + if self.tool_defn.completion_has_tool_call(prompt_completion_str): + tool_response = self.tool_defn.call_tool(prompt_completion_ids) + out.append(tool_response) + prompt_inputs = self._add_response_to_prompt_inputs(prompt_inputs, prompt_completion_ids) + prompt_inputs = self._add_response_to_prompt_inputs(prompt_inputs, tool_response) + # Note: we're gonna have to figure out images. + else: + # No tool call, so we're done. + break + all_out = torch.cat(out, dim=0) + return all_out def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): if return_outputs: @@ -458,7 +494,10 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N else: # Regular generation path (not using vLLM) with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: - prompt_completion_ids = self._generate_completions_with_tools(unwrapped_model, prompt_inputs) + if self.tool_defn: + prompt_completion_ids = self._generate_completions_with_tools(unwrapped_model, prompt_inputs) + else: + prompt_completion_ids = self._generate_completions(unwrapped_model, prompt_inputs) prompt_length = prompt_inputs["input_ids"].size(1) completion_ids = prompt_completion_ids[:, prompt_length:] @@ -657,3 +696,17 @@ def create_model_card( ) model_card.save(os.path.join(self.args.output_dir, "README.md")) + + + +def print_random_alphabet(n:int=4) -> None: + """Print the uppercase alphabet in a random order with four spaces between each letter. + For the secret-decoder-ring task. + """ + import random + import string + letters = list(string.ascii_uppercase) + random.shuffle(letters) + js = " " * n + print(js.join(letters[0:13])) + print(js.join(letters[13:])) \ No newline at end of file From e6e4761b7d3aad7ab0df74e944884266f2b3a26c Mon Sep 17 00:00:00 2001 From: Leo Dirac Date: Wed, 5 Feb 2025 21:50:25 -0800 Subject: [PATCH 04/16] Computes completions + tool-calls individually and pads them together. 
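Since each rollout can now contain a different number of tool-call rounds, the
completions come back as a ragged list of [1, L_i] tensors; they get right-padded
to the longest one and stacked. A standalone sketch of that step (made-up lengths,
and zero as the pad value, mirroring the code in this patch):

    import torch

    # Three fake completions of different lengths, each [1, L_i].
    completions = [torch.randint(0, 100, (1, n)) for n in (5, 8, 3)]

    # Right-pad everything to the longest completion, then stack along the batch dim.
    max_length = max(c.size(1) for c in completions)
    padded = [
        torch.cat([c, torch.zeros(1, max_length - c.size(1), dtype=torch.long)], dim=1)
        for c in completions
    ]
    batch = torch.cat(padded, dim=0)  # shape [3, 8], one row per generation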
--- trl/trainer/qwen_grpo_trainer.py | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index 2830ec7de22..ec445d8b2a5 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -389,9 +389,11 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s return inputs def _generate_completions( - self, model: PreTrainedModel, prompt_inputs: dict[str, torch.Tensor] + self, model: PreTrainedModel, prompt_inputs: dict[str, torch.Tensor], num_generations: int | None = None ) -> torch.Tensor: - """Generate completions using the model.""" + """Generate completion(s) using the model.""" + if num_generations is not None: + self.generation_config.num_return_sequences = num_generations prompt_completion_ids = model.generate( **prompt_inputs, generation_config=self.generation_config, @@ -402,11 +404,27 @@ def _generate_completions( def _add_response_to_prompt_inputs(self, prompt_inputs: dict[str, torch.Tensor], response: torch.Tensor) -> dict[str, torch.Tensor]: """Add the response to the prompt inputs.""" prompt_inputs["input_ids"] = torch.cat([prompt_inputs["input_ids"], response], dim=1) - prompt_inputs["attention_mask"] = torch.cat([prompt_inputs["attention_mask"], torch.ones_like(tool_response)], dim=1) + prompt_inputs["attention_mask"] = torch.cat([prompt_inputs["attention_mask"], torch.ones_like(response)], dim=1) return prompt_inputs def _generate_completions_with_tools( self, model: PreTrainedModel, prompt_inputs: dict[str, torch.Tensor] + ) -> torch.Tensor: + """Generate the full set of completions with tools, and stitch them together. + """ + out = [] + for _ in range(self.num_generations): + out.append(self._generate_single_completion_with_tools(model, prompt_inputs)) + # Now we have a ragged list of tensors. We need to pad them to the same length. + max_length = max(completion.size(1) for completion in out) + for i in range(len(out)): + padding = torch.zeros(1, max_length - out[i].size(1), dtype=torch.long, device=out[i].device) + out[i] = torch.cat([out[i], padding], dim=1) + final = torch.cat(out, dim=0) + return final + + def _generate_single_completion_with_tools( + self, model: PreTrainedModel, prompt_inputs: dict[str, torch.Tensor] ) -> torch.Tensor: """Iterates between generation and tool calling. @@ -422,9 +440,9 @@ def _generate_completions_with_tools( out = [] # Loop until tool isn't called. while True: - prompt_completion_ids = self._generate_completions(model, prompt_inputs) - # prompt_completion_ids is a tensor of shape (B, L) - # Where B is (3) for the number of generations. + prompt_completion_ids = self._generate_completions(model, prompt_inputs, num_generations=1) + # prompt_completion_ids is a tensor of shape (1, L) + # Because we only generated one completion. # L is 875 here. It's just token ids, nothing else. out.append(prompt_completion_ids) # Check if the stop string is in the completions @@ -439,7 +457,7 @@ def _generate_completions_with_tools( else: # No tool call, so we're done. break - all_out = torch.cat(out, dim=0) + all_out = torch.cat(out, dim=1) # dim=1 means the token sequence gets longer. return all_out def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): From 7d72188d2a7acc3d86c626f387bb86f4de5f71ef Mon Sep 17 00:00:00 2001 From: Leo Dirac Date: Wed, 5 Feb 2025 22:08:15 -0800 Subject: [PATCH 05/16] Tool call with strings not ids. 
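The tool now works on text rather than token ids: the completion is decoded, the
tool gets the string, and its string response is re-tokenized before going back
into the conversation. A rough sketch of that round trip (the "gpt2" tokenizer and
the toy tool are stand-ins, not what the trainer actually uses):

    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in for processing_class.tokenizer

    def call_tool(completion_str: str) -> str:
        # Toy tool: report how long the model's request was.
        return f"<tool_response>{len(completion_str)} chars</tool_response>\n"

    completion_ids = torch.tensor([[15496, 995]])  # pretend model output, shape [1, L]
    completion_str = tokenizer.decode(completion_ids[0], skip_special_tokens=True)
    tool_response_str = call_tool(completion_str)
    tool_response_ids = tokenizer.encode(tool_response_str, add_special_tokens=False)
    tool_response = torch.tensor(tool_response_ids)[None, :]  # back to [1, L'] for concatenation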
--- trl/trainer/qwen_grpo_trainer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index ec445d8b2a5..80ead5cc484 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -449,10 +449,11 @@ def _generate_single_completion_with_tools( # We need to convert the tensor to a string. prompt_completion_str = self.processing_class.tokenizer.decode(prompt_completion_ids[0], skip_special_tokens=True) if self.tool_defn.completion_has_tool_call(prompt_completion_str): - tool_response = self.tool_defn.call_tool(prompt_completion_ids) - out.append(tool_response) + tool_response_str = self.tool_defn.call_tool(prompt_completion_str) + tool_response_ids = self.processing_class.tokenizer.encode(tool_response_str, add_special_tokens=False) + out.append(tool_response_ids) prompt_inputs = self._add_response_to_prompt_inputs(prompt_inputs, prompt_completion_ids) - prompt_inputs = self._add_response_to_prompt_inputs(prompt_inputs, tool_response) + prompt_inputs = self._add_response_to_prompt_inputs(prompt_inputs, tool_response_ids) # Note: we're gonna have to figure out images. else: # No tool call, so we're done. From 69291c4a7b19d5f234b1a45db0ed5ba2b4024473 Mon Sep 17 00:00:00 2001 From: Leo Dirac Date: Wed, 5 Feb 2025 22:46:28 -0800 Subject: [PATCH 06/16] Much closer to incorporating tool responses. --- trl/trainer/qwen_grpo_trainer.py | 71 ++++++++++++++++++++++---------- 1 file changed, 49 insertions(+), 22 deletions(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index 80ead5cc484..521e0161b51 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -74,6 +74,50 @@ def completion_has_tool_call(self, completion_str: str) -> bool: return self.stop_string in completion_str +class SingleConversationWithTools: + """Keeps track of the prompt, and knows how to put together the partial responses and the tool call responses.""" + + def __init__(self, prompt_inputs: dict[str, torch.Tensor], tool_defn: ToolDefinition, processing_class: PreTrainedTokenizerBase): + self.prompt_inputs = prompt_inputs + self.tool_defn = tool_defn + self.response = [] + self.processing_class = processing_class + + def process_response(self, prompt_completion_ids: torch.Tensor) -> bool: + """Adds the response to the conversation, including calling the tool if necessary. + Returns True if there was a tool call, and the conversation should continue. + Returns False if there was no tool call, and the conversation is complete. + """ + self.response.append(prompt_completion_ids) + prompt_completion_str = self.processing_class.tokenizer.decode(prompt_completion_ids[0], skip_special_tokens=True) + # Check if the stop string is in the completions + # We need to convert the tensor to a string. 
+ if self.tool_defn.completion_has_tool_call(prompt_completion_str): + tool_response_str = self.tool_defn.call_tool(prompt_completion_str) + tool_response_ids_list = self.processing_class.tokenizer.encode(tool_response_str, add_special_tokens=False) + tool_response_ids = torch.tensor(tool_response_ids_list, device=prompt_completion_ids.device) # [L] + tool_response_ids = tool_response_ids[None, :] # [1, L] + self.response.append(tool_response_ids) + self.prompt_inputs = self._add_response_to_prompt_inputs(self.prompt_inputs, prompt_completion_ids) + self.prompt_inputs = self._add_response_to_prompt_inputs(self.prompt_inputs, tool_response_ids) + # Note: we're gonna have to figure out images. + return True + else: + # No tool call, so we're done. + return False + + def _add_response_to_prompt_inputs(self, prompt_inputs: dict[str, torch.Tensor], response: torch.Tensor) -> dict[str, torch.Tensor]: + """Add the response to the prompt inputs.""" + prompt_inputs["input_ids"] = torch.cat([prompt_inputs["input_ids"], response], dim=1) + ones = torch.ones_like(response, device=response.device) + prompt_inputs["attention_mask"] = torch.cat([prompt_inputs["attention_mask"], ones], dim=1) + return prompt_inputs + + + def get_response(self) -> torch.Tensor: + """Returns the response as a tensor.""" + return torch.cat(self.response, dim=1) + class QwenGRPOTrainer(Trainer): """ Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the @@ -401,12 +445,6 @@ def _generate_completions( ) return prompt_completion_ids - def _add_response_to_prompt_inputs(self, prompt_inputs: dict[str, torch.Tensor], response: torch.Tensor) -> dict[str, torch.Tensor]: - """Add the response to the prompt inputs.""" - prompt_inputs["input_ids"] = torch.cat([prompt_inputs["input_ids"], response], dim=1) - prompt_inputs["attention_mask"] = torch.cat([prompt_inputs["attention_mask"], torch.ones_like(response)], dim=1) - return prompt_inputs - def _generate_completions_with_tools( self, model: PreTrainedModel, prompt_inputs: dict[str, torch.Tensor] ) -> torch.Tensor: @@ -437,29 +475,18 @@ def _generate_single_completion_with_tools( - image_grid_thw: a 1x3 tensor with values: [1, 46, 44]. (Note that 46*44 is 2024). """ - out = [] + conv = SingleConversationWithTools(prompt_inputs, self.tool_defn, self.processing_class) # Loop until tool isn't called. while True: prompt_completion_ids = self._generate_completions(model, prompt_inputs, num_generations=1) # prompt_completion_ids is a tensor of shape (1, L) # Because we only generated one completion. # L is 875 here. It's just token ids, nothing else. - out.append(prompt_completion_ids) - # Check if the stop string is in the completions - # We need to convert the tensor to a string. - prompt_completion_str = self.processing_class.tokenizer.decode(prompt_completion_ids[0], skip_special_tokens=True) - if self.tool_defn.completion_has_tool_call(prompt_completion_str): - tool_response_str = self.tool_defn.call_tool(prompt_completion_str) - tool_response_ids = self.processing_class.tokenizer.encode(tool_response_str, add_special_tokens=False) - out.append(tool_response_ids) - prompt_inputs = self._add_response_to_prompt_inputs(prompt_inputs, prompt_completion_ids) - prompt_inputs = self._add_response_to_prompt_inputs(prompt_inputs, tool_response_ids) - # Note: we're gonna have to figure out images. - else: - # No tool call, so we're done. 
+ tool_used = conv.process_response(prompt_completion_ids) + if not tool_used: break - all_out = torch.cat(out, dim=1) # dim=1 means the token sequence gets longer. - return all_out + + return conv.get_response() def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): if return_outputs: From 92fab09ba347e69a3498f24ee1cb37000015cac2 Mon Sep 17 00:00:00 2001 From: Leo Dirac Date: Wed, 5 Feb 2025 23:16:05 -0800 Subject: [PATCH 07/16] More debug output. --- trl/trainer/qwen_grpo_trainer.py | 36 ++++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index 521e0161b51..49a821021bc 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +VERBOSE = True + import os import textwrap import warnings @@ -108,6 +110,10 @@ def process_response(self, prompt_completion_ids: torch.Tensor) -> bool: def _add_response_to_prompt_inputs(self, prompt_inputs: dict[str, torch.Tensor], response: torch.Tensor) -> dict[str, torch.Tensor]: """Add the response to the prompt inputs.""" + if VERBOSE: + import pdb; pdb.set_trace() + addition_str = self.processing_class.decode(response[0]) + print(f"Adding response: {addition_str}") prompt_inputs["input_ids"] = torch.cat([prompt_inputs["input_ids"], response], dim=1) ones = torch.ones_like(response, device=response.device) prompt_inputs["attention_mask"] = torch.cat([prompt_inputs["attention_mask"], ones], dim=1) @@ -116,6 +122,7 @@ def _add_response_to_prompt_inputs(self, prompt_inputs: dict[str, torch.Tensor], def get_response(self) -> torch.Tensor: """Returns the response as a tensor.""" + # String together all the response tensors on their long dimension. return torch.cat(self.response, dim=1) class QwenGRPOTrainer(Trainer): @@ -462,7 +469,7 @@ def _generate_completions_with_tools( return final def _generate_single_completion_with_tools( - self, model: PreTrainedModel, prompt_inputs: dict[str, torch.Tensor] + self, model: PreTrainedModel, prompt_inputs: dict[str, torch.Tensor], max_steps: int = 10 ) -> torch.Tensor: """Iterates between generation and tool calling. @@ -476,17 +483,28 @@ def _generate_single_completion_with_tools( (Note that 46*44 is 2024). """ conv = SingleConversationWithTools(prompt_inputs, self.tool_defn, self.processing_class) - # Loop until tool isn't called. - while True: + # Loop until tool isn't called, of we max out + for step in range(max_steps): + if VERBOSE: + print(f"\n\n\nGenerating completion with tool call. Step {step}. Shapes of inputs:") + for key, val in prompt_inputs.items(): + print(f"{key}: {val.shape}") + print(f"Text of the prompt: {self.processing_class.decode(prompt_inputs['input_ids'][0])}") prompt_completion_ids = self._generate_completions(model, prompt_inputs, num_generations=1) - # prompt_completion_ids is a tensor of shape (1, L) - # Because we only generated one completion. - # L is 875 here. It's just token ids, nothing else. - tool_used = conv.process_response(prompt_completion_ids) - if not tool_used: + # prompt_completion_ids is a tensor of shape (1, L) Because we only generated one completion. + # Note that L includes both the prompt and the response. + # We only want to process the response, so we'll strip the prompt. 
+ ids_to_process = prompt_completion_ids[:, len(prompt_inputs["input_ids"][0]):] + tool_was_used = conv.process_response(ids_to_process) + if not tool_was_used: break - return conv.get_response() + response_ids = conv.get_response() + import pdb; pdb.set_trace() + if VERBOSE: + print(f"\n\n\nDONE!") + print(f"Text of the response: {self.processing_class.decode(response_ids[0,:])}") + return response_ids def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): if return_outputs: From f73e9135916564a44e2ea81b243afbcf11499574 Mon Sep 17 00:00:00 2001 From: Leo Dirac Date: Thu, 6 Feb 2025 09:13:20 -0800 Subject: [PATCH 08/16] Taking out pdb. --- trl/trainer/qwen_grpo_trainer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index 49a821021bc..a05921f3333 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -111,7 +111,6 @@ def process_response(self, prompt_completion_ids: torch.Tensor) -> bool: def _add_response_to_prompt_inputs(self, prompt_inputs: dict[str, torch.Tensor], response: torch.Tensor) -> dict[str, torch.Tensor]: """Add the response to the prompt inputs.""" if VERBOSE: - import pdb; pdb.set_trace() addition_str = self.processing_class.decode(response[0]) print(f"Adding response: {addition_str}") prompt_inputs["input_ids"] = torch.cat([prompt_inputs["input_ids"], response], dim=1) @@ -500,7 +499,6 @@ def _generate_single_completion_with_tools( break response_ids = conv.get_response() - import pdb; pdb.set_trace() if VERBOSE: print(f"\n\n\nDONE!") print(f"Text of the response: {self.processing_class.decode(response_ids[0,:])}") From 8186689d64c9d1f046ad4b5c508a075bd9231189 Mon Sep 17 00:00:00 2001 From: Leo Dirac Date: Thu, 6 Feb 2025 16:28:56 -0800 Subject: [PATCH 09/16] Taking out secret decoder ring. --- trl/trainer/qwen_grpo_trainer.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index 8ae41fd70e3..69c42aa4692 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -767,16 +767,3 @@ def create_model_card( model_card.save(os.path.join(self.args.output_dir, "README.md")) - - -def print_random_alphabet(n:int=4) -> None: - """Print the uppercase alphabet in a random order with four spaces between each letter. - For the secret-decoder-ring task. - """ - import random - import string - letters = list(string.ascii_uppercase) - random.shuffle(letters) - js = " " * n - print(js.join(letters[0:13])) - print(js.join(letters[13:])) \ No newline at end of file From c6da40070110110b83562b1d10c37da777e215e7 Mon Sep 17 00:00:00 2001 From: Leo Dirac Date: Thu, 6 Feb 2025 16:48:38 -0800 Subject: [PATCH 10/16] Fixing bug with including the prompt in the completion output. --- trl/trainer/qwen_grpo_trainer.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index 69c42aa4692..e00d260b452 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-VERBOSE = True +VERBOSE = False import os import textwrap @@ -122,11 +122,16 @@ def _add_response_to_prompt_inputs(self, prompt_inputs: dict[str, torch.Tensor], return prompt_inputs - def get_response(self) -> torch.Tensor: - """Returns the response as a tensor.""" + def get_just_completion_ids(self) -> torch.Tensor: + """Returns the response (not including the prompt) as a tensor.""" # String together all the response tensors on their long dimension. return torch.cat(self.response, dim=1) + def get_prompt_completion_ids(self) -> torch.Tensor: + """Returns the prompt and completion as a tensor. The full completion includes the prompt and the response.""" + out = torch.cat([self.prompt_inputs["input_ids"], self.get_just_completion_ids()], dim=1) + return out + class QwenGRPOTrainer(Trainer): """ Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the @@ -480,16 +485,19 @@ def _generate_single_completion_with_tools( # prompt_completion_ids is a tensor of shape (1, L) Because we only generated one completion. # Note that L includes both the prompt and the response. # We only want to process the response, so we'll strip the prompt. - ids_to_process = prompt_completion_ids[:, len(prompt_inputs["input_ids"][0]):] + input_length = len(prompt_inputs["input_ids"][0]) + ids_to_process = prompt_completion_ids[:, input_length:] tool_was_used = conv.process_response(ids_to_process) if not tool_was_used: break - response_ids = conv.get_response() if VERBOSE: + just_completion_ids = conv.get_just_completion_ids() print(f"\n\n\nDONE!") - print(f"Text of the response: {self.processing_class.decode(response_ids[0,:])}") - return response_ids + print(f"Text of the response: {self.processing_class.decode(just_completion_ids[0,:])}") + print(f"^^^ I said DONE!\n\n\n\n\n\n") + return conv.get_prompt_completion_ids() + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): if return_outputs: From d2ee50b2de82efa8f19f60f13a41bcf2df756159 Mon Sep 17 00:00:00 2001 From: Leo Dirac Date: Thu, 6 Feb 2025 23:13:28 -0800 Subject: [PATCH 11/16] VERBOSE from environment. Catches exceptions in tool calls. --- trl/trainer/qwen_grpo_trainer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index e00d260b452..16e1cad5f49 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -VERBOSE = False import os import textwrap @@ -53,6 +52,7 @@ from .grpo_config import GRPOConfig from .utils import generate_model_card, get_comet_experiment_url, pad +VERBOSE = os.environ.get("VERBOSE", "false").lower() == "true" if is_peft_available(): from peft import PeftConfig, get_peft_model @@ -98,7 +98,10 @@ def process_response(self, prompt_completion_ids: torch.Tensor) -> bool: # Check if the stop string is in the completions # We need to convert the tensor to a string. 
if self.tool_defn.completion_has_tool_call(prompt_completion_str): - tool_response_str = self.tool_defn.call_tool(prompt_completion_str) + try: + tool_response_str = self.tool_defn.call_tool(prompt_completion_str) + except Exception as e: + tool_response_str = f"Tool failed: {e}\n" tool_response_ids_list = self.processing_class.tokenizer.encode(tool_response_str, add_special_tokens=False) tool_response_ids = torch.tensor(tool_response_ids_list, device=prompt_completion_ids.device) # [L] tool_response_ids = tool_response_ids[None, :] # [1, L] From 5ac1d04f4aee7ff5ae00f5af98aeb118c1c8bc75 Mon Sep 17 00:00:00 2001 From: Leo Dirac Date: Fri, 7 Feb 2025 13:12:42 -0800 Subject: [PATCH 12/16] Multiple verbosity levels. --- trl/trainer/qwen_grpo_trainer.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index 16e1cad5f49..5ce42e7b3ba 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -52,7 +52,7 @@ from .grpo_config import GRPOConfig from .utils import generate_model_card, get_comet_experiment_url, pad -VERBOSE = os.environ.get("VERBOSE", "false").lower() == "true" +VERBOSE = int(os.environ.get("VERBOSE", "0")) if is_peft_available(): from peft import PeftConfig, get_peft_model @@ -116,7 +116,7 @@ def process_response(self, prompt_completion_ids: torch.Tensor) -> bool: def _add_response_to_prompt_inputs(self, prompt_inputs: dict[str, torch.Tensor], response: torch.Tensor) -> dict[str, torch.Tensor]: """Add the response to the prompt inputs.""" - if VERBOSE: + if VERBOSE > 0: addition_str = self.processing_class.decode(response[0]) print(f"Adding response: {addition_str}") prompt_inputs["input_ids"] = torch.cat([prompt_inputs["input_ids"], response], dim=1) @@ -479,7 +479,7 @@ def _generate_single_completion_with_tools( conv = SingleConversationWithTools(prompt_inputs, self.tool_defn, self.processing_class) # Loop until tool isn't called, of we max out for step in range(max_steps): - if VERBOSE: + if VERBOSE > 1: print(f"\n\n\nGenerating completion with tool call. Step {step}. Shapes of inputs:") for key, val in prompt_inputs.items(): print(f"{key}: {val.shape}") @@ -494,12 +494,12 @@ def _generate_single_completion_with_tools( if not tool_was_used: break - if VERBOSE: - just_completion_ids = conv.get_just_completion_ids() - print(f"\n\n\nDONE!") - print(f"Text of the response: {self.processing_class.decode(just_completion_ids[0,:])}") - print(f"^^^ I said DONE!\n\n\n\n\n\n") - return conv.get_prompt_completion_ids() + full_completion = conv.get_prompt_completion_ids() + if VERBOSE > 0: + print(f"\nDONE!") + print(f"Final completion (with prompt):\n{self.processing_class.decode(full_completion[0,:])}") + print(f"^^^ Final Response!\n\n\n\n\n\n") + return full_completion def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): From efc25dfe7a710052ccc24f20b88a59a1781bc1b3 Mon Sep 17 00:00:00 2001 From: Leo Dirac Date: Fri, 7 Feb 2025 15:40:07 -0800 Subject: [PATCH 13/16] Prints when it's done generating. 
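Also worth noting about the tool plumbing from a couple of patches back: exceptions
raised by the tool are caught and turned into text that goes back into the
conversation, so a bad tool call costs the model a turn instead of killing the
rollout. Schematically (toy tool, names illustrative):

    def run_tool_safely(call_tool, completion_str: str) -> str:
        # Any exception from the tool becomes model-visible text instead of
        # crashing the training step.
        try:
            return call_tool(completion_str)
        except Exception as e:
            return f"Tool failed: {e}\n"

    print(run_tool_safely(lambda expr: str(eval(expr)), "6 * 7"))  # 42
    print(run_tool_safely(lambda expr: str(eval(expr)), "1 / 0"))  # Tool failed: division by zero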
--- trl/trainer/qwen_grpo_trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index 5ce42e7b3ba..346c4cdb724 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -582,6 +582,8 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N # Stack all padded completions prompt_completion_ids = torch.cat(padded_completions, dim=0) + if VERBOSE > 0: + print(f"Done generating {num_generations} completions.") prompt_length = prompt_inputs["input_ids"].size(1) completion_ids = prompt_completion_ids[:, prompt_length:] From d88f6c259d81f9f827aaba7e79bf1d32a1da37b8 Mon Sep 17 00:00:00 2001 From: Leo Dirac Date: Sat, 8 Feb 2025 12:01:58 -0800 Subject: [PATCH 14/16] Loss magnifier to avoid underflow. --- trl/trainer/grpo_config.py | 5 +++++ trl/trainer/qwen_grpo_trainer.py | 23 ++++++++++++++++++++--- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 0fd0d9f5d28..e4736d3e5a4 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -174,3 +174,8 @@ class GRPOConfig(TrainingArguments): default=0.04, metadata={"help": "KL coefficient."}, ) + + loss_magnifier: float = field( + default=1.0e4, + metadata={"help": "Multiplies the loss on the way out to avoid underflow."}, + ) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index 346c4cdb724..b5bc9408520 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -13,14 +13,15 @@ # limitations under the License. -import os -import textwrap -import warnings from collections import defaultdict from dataclasses import dataclass from typing import Any, Callable, Optional, Union from unittest.mock import patch import copy +import math +import os +import textwrap +import warnings import torch import torch.utils.data @@ -217,6 +218,10 @@ class QwenGRPOTrainer(Trainer): model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`): PEFT configuration used to wrap the model. If `None`, the model is not wrapped. + tool_defn ([`~trl.ToolDefinition`], *optional*, defaults to `None`): + Tool definition used to define the tool call. + loss_magnifier (float, *optional*, defaults to 1000.0): + Multiplies the loss on the way out to avoid underflow. 
""" _tag_names = ["trl", "grpo"] @@ -235,6 +240,7 @@ def __init__( optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), peft_config: Optional["PeftConfig"] = None, tool_defn: Optional[ToolDefinition] = None, + loss_magnifier: float = 1.0e4, ): # Args if args is None: @@ -323,6 +329,7 @@ def data_collator(features): # No data collation is needed in GRPO self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper self.num_generations = args.num_generations # = G in the GRPO paper self.use_vllm = args.use_vllm + self.loss_magnifier = args.loss_magnifier self.beta = args.beta @@ -677,6 +684,12 @@ def get_per_token_logps(model, input_ids): per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) per_token_loss = -(per_token_loss - self.beta * per_token_kl) loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() + # We could break this down like this: + #loss = (per_token_loss * completion_mask).sum(dim=1) # Now we have a [B] tensor of losses for each example + #loss = loss / completion_mask.sum(dim=1) # normalize by number of unmasked tokens. Still [B] + #loss = loss.mean() # average across the batch. + # Rescale to avoid underflow - we see losses underflow to 0 when they're around 1e-7, which is common + loss = loss * self.loss_magnifier # Log the metrics completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() @@ -780,3 +793,7 @@ def create_model_card( model_card.save(os.path.join(self.args.output_dir, "README.md")) + + +def simple_stats(x) -> str: + return f"Min: {x.min()}, Mean: {x.mean()}, Max: {x.max()}" \ No newline at end of file From ae74b4aa3b80d5edff0fb33cebccb3fd72d09093 Mon Sep 17 00:00:00 2001 From: Leo Dirac Date: Sat, 8 Feb 2025 17:57:05 -0800 Subject: [PATCH 15/16] Turning off loss magnifier by default. --- trl/trainer/grpo_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index e4736d3e5a4..cce3a1f6941 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -176,6 +176,6 @@ class GRPOConfig(TrainingArguments): ) loss_magnifier: float = field( - default=1.0e4, + default=1.0, metadata={"help": "Multiplies the loss on the way out to avoid underflow."}, ) From 68965d1613de28dc48a1145d60e63d63f692079d Mon Sep 17 00:00:00 2001 From: Leo Dirac Date: Sun, 9 Feb 2025 13:01:30 -0800 Subject: [PATCH 16/16] Reducing loss magnifier default to 1 --- trl/trainer/qwen_grpo_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index b5bc9408520..c99938d2ee0 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -220,7 +220,7 @@ class QwenGRPOTrainer(Trainer): PEFT configuration used to wrap the model. If `None`, the model is not wrapped. tool_defn ([`~trl.ToolDefinition`], *optional*, defaults to `None`): Tool definition used to define the tool call. - loss_magnifier (float, *optional*, defaults to 1000.0): + loss_magnifier (float, *optional*, defaults to 1.0): Multiplies the loss on the way out to avoid underflow. 
""" @@ -240,7 +240,7 @@ def __init__( optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), peft_config: Optional["PeftConfig"] = None, tool_defn: Optional[ToolDefinition] = None, - loss_magnifier: float = 1.0e4, + loss_magnifier: float = 1.0, ): # Args if args is None: