diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py index 930caf59425..8a21969060e 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -36,7 +36,7 @@ "dpo_trainer": ["DPOTrainer"], "gkd_config": ["GKDConfig"], "gkd_trainer": ["GKDTrainer"], - "qwen_grpo_trainer": ["QwenGRPOTrainer"], + "qwen_grpo_trainer": ["QwenGRPOTrainer", "ToolDefinition"], "grpo_config": ["GRPOConfig"], "grpo_trainer": ["GRPOTrainer"], "iterative_sft_trainer": ["IterativeSFTTrainer"], @@ -134,7 +134,7 @@ from .ppo_trainer import PPOTrainer from .prm_config import PRMConfig from .prm_trainer import PRMTrainer - from .qwen_grpo_trainer import QwenGRPOTrainer + from .qwen_grpo_trainer import QwenGRPOTrainer, ToolDefinition from .reward_config import RewardConfig from .reward_trainer import RewardTrainer from .rloo_config import RLOOConfig @@ -164,3 +164,12 @@ import sys sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + +# Importing from qwen_grpo_trainer to expose them at the package level. +from .qwen_grpo_trainer import QwenGRPOTrainer, ToolDefinition + +# Define __all__ to explicitly list the public API of this package. +__all__ = [ + "QwenGRPOTrainer", + "ToolDefinition", +] diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 0fd0d9f5d28..cce3a1f6941 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -174,3 +174,8 @@ class GRPOConfig(TrainingArguments): default=0.04, metadata={"help": "KL coefficient."}, ) + + loss_magnifier: float = field( + default=1.0, + metadata={"help": "Multiplies the loss on the way out to avoid underflow."}, + ) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index 339c529750a..c99938d2ee0 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -12,13 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
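Usage note (illustrative, not part of the patch): with the export changes above, QwenGRPOTrainer and ToolDefinition become importable directly from trl.trainer, and loss_magnifier is read off GRPOConfig. Because the multiplier rescales the loss, and therefore every gradient, by a constant factor, it leaves the optimum unchanged; it only keeps the ~1e-7 losses mentioned later in the patch from underflowing to zero.

from trl.trainer import GRPOConfig, QwenGRPOTrainer, ToolDefinition  # new package-level exports

# loss_magnifier=100.0 is a hypothetical value; the default of 1.0 leaves the loss untouched.
args = GRPOConfig(output_dir="qwen-grpo-output", loss_magnifier=100.0)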
-import os -import textwrap -import warnings + from collections import defaultdict +from dataclasses import dataclass from typing import Any, Callable, Optional, Union from unittest.mock import patch import copy +import math +import os +import textwrap +import warnings import torch import torch.utils.data @@ -50,6 +53,7 @@ from .grpo_config import GRPOConfig from .utils import generate_model_card, get_comet_experiment_url, pad +VERBOSE = int(os.environ.get("VERBOSE", "0")) if is_peft_available(): from peft import PeftConfig, get_peft_model @@ -65,6 +69,73 @@ RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]] +@dataclass +class ToolDefinition: + """Basic metadata that the trainer needs to know about the tools.""" + stop_string: str + call_tool: Callable[[torch.Tensor], torch.Tensor] + + def completion_has_tool_call(self, completion_str: str) -> bool: + """Check if the completion has a tool call.""" + return self.stop_string in completion_str + + +class SingleConversationWithTools: + """Keeps track of the prompt, and knows how to put together the partial responses and the tool call responses.""" + + def __init__(self, prompt_inputs: dict[str, torch.Tensor], tool_defn: ToolDefinition, processing_class: PreTrainedTokenizerBase): + self.prompt_inputs = prompt_inputs + self.tool_defn = tool_defn + self.response = [] + self.processing_class = processing_class + + def process_response(self, prompt_completion_ids: torch.Tensor) -> bool: + """Adds the response to the conversation, including calling the tool if necessary. + Returns True if there was a tool call, and the conversation should continue. + Returns False if there was no tool call, and the conversation is complete. + """ + self.response.append(prompt_completion_ids) + prompt_completion_str = self.processing_class.tokenizer.decode(prompt_completion_ids[0], skip_special_tokens=True) + # Check if the stop string is in the completions + # We need to convert the tensor to a string. + if self.tool_defn.completion_has_tool_call(prompt_completion_str): + try: + tool_response_str = self.tool_defn.call_tool(prompt_completion_str) + except Exception as e: + tool_response_str = f"Tool failed: {e}\n" + tool_response_ids_list = self.processing_class.tokenizer.encode(tool_response_str, add_special_tokens=False) + tool_response_ids = torch.tensor(tool_response_ids_list, device=prompt_completion_ids.device) # [L] + tool_response_ids = tool_response_ids[None, :] # [1, L] + self.response.append(tool_response_ids) + self.prompt_inputs = self._add_response_to_prompt_inputs(self.prompt_inputs, prompt_completion_ids) + self.prompt_inputs = self._add_response_to_prompt_inputs(self.prompt_inputs, tool_response_ids) + # Note: we're gonna have to figure out images. + return True + else: + # No tool call, so we're done. 
+ return False + + def _add_response_to_prompt_inputs(self, prompt_inputs: dict[str, torch.Tensor], response: torch.Tensor) -> dict[str, torch.Tensor]: + """Add the response to the prompt inputs.""" + if VERBOSE > 0: + addition_str = self.processing_class.decode(response[0]) + print(f"Adding response: {addition_str}") + prompt_inputs["input_ids"] = torch.cat([prompt_inputs["input_ids"], response], dim=1) + ones = torch.ones_like(response, device=response.device) + prompt_inputs["attention_mask"] = torch.cat([prompt_inputs["attention_mask"], ones], dim=1) + return prompt_inputs + + + def get_just_completion_ids(self) -> torch.Tensor: + """Returns the response (not including the prompt) as a tensor.""" + # String together all the response tensors on their long dimension. + return torch.cat(self.response, dim=1) + + def get_prompt_completion_ids(self) -> torch.Tensor: + """Returns the prompt and completion as a tensor. The full completion includes the prompt and the response.""" + out = torch.cat([self.prompt_inputs["input_ids"], self.get_just_completion_ids()], dim=1) + return out + class QwenGRPOTrainer(Trainer): """ Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the @@ -147,6 +218,10 @@ class QwenGRPOTrainer(Trainer): model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`): PEFT configuration used to wrap the model. If `None`, the model is not wrapped. + tool_defn ([`~trl.ToolDefinition`], *optional*, defaults to `None`): + Tool definition used to define the tool call. + loss_magnifier (float, *optional*, defaults to 1.0): + Multiplies the loss on the way out to avoid underflow. """ _tag_names = ["trl", "grpo"] @@ -164,6 +239,8 @@ def __init__( callbacks: Optional[list[TrainerCallback]] = None, optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), peft_config: Optional["PeftConfig"] = None, + tool_defn: Optional[ToolDefinition] = None, + loss_magnifier: float = 1.0, ): # Args if args is None: @@ -220,6 +297,8 @@ def __init__( ) self.reward_funcs = reward_funcs + self.tool_defn = tool_defn + # Reward processing class if reward_processing_classes is None: reward_processing_classes = [None] * len(reward_funcs) @@ -250,6 +329,7 @@ def data_collator(features): # No data collation is needed in GRPO self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper self.num_generations = args.num_generations # = G in the GRPO paper self.use_vllm = args.use_vllm + self.loss_magnifier = args.loss_magnifier self.beta = args.beta @@ -330,12 +410,19 @@ def data_collator(features): # No data collation is needed in GRPO # synchronize all processes after vLLM has been fully initialized. self.accelerator.wait_for_everyone() else: + # No vLLM, so we use the regular generation config + + stop_strings = None + if self.tool_defn: + stop_strings = [self.tool_defn.stop_string] + self.generation_config = GenerationConfig( max_new_tokens=self.max_completion_length, do_sample=True, temperature=args.temperature, num_return_sequences=self.num_generations, pad_token_id=processing_class.tokenizer.pad_token_id, + stop_strings=stop_strings, ) # Gradient accumulation requires scaled loss. 
Normally, loss scaling in the parent class depends on whether the @@ -369,6 +456,59 @@ def _set_signature_columns_if_needed(self): def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]: return inputs + def _generate_completion( + self, model: PreTrainedModel, prompt_inputs: dict[str, torch.Tensor] + ) -> torch.Tensor: + """Generate completion(s) using the model.""" + temp_generation_config = copy.deepcopy(self.generation_config) + temp_generation_config.num_return_sequences = 1 + prompt_completion_ids = model.generate( + **prompt_inputs, + generation_config=temp_generation_config, + tokenizer=self.processing_class.tokenizer, + ) + return prompt_completion_ids + + def _generate_single_completion_with_tools( + self, model: PreTrainedModel, prompt_inputs: dict[str, torch.Tensor], max_steps: int = 10 + ) -> torch.Tensor: + """Iterates between generation and tool calling. + + Note this is currently only called from the non-vLLM path + + prompt_inputs is a dict with the following keys: + - input_ids: [1, 710] ints. Some stuff at the beginning and the end, the middle full of 151655 + - attention_mask: [1, 710] ints. All 1 + - pixel_values: 2024x1176 floats. The image. + - image_grid_thw: a 1x3 tensor with values: [1, 46, 44]. + (Note that 46*44 is 2024). + """ + conv = SingleConversationWithTools(prompt_inputs, self.tool_defn, self.processing_class) + # Loop until tool isn't called, of we max out + for step in range(max_steps): + if VERBOSE > 1: + print(f"\n\n\nGenerating completion with tool call. Step {step}. Shapes of inputs:") + for key, val in prompt_inputs.items(): + print(f"{key}: {val.shape}") + print(f"Text of the prompt: {self.processing_class.decode(prompt_inputs['input_ids'][0])}") + prompt_completion_ids = self._generate_completion(model, prompt_inputs) + # prompt_completion_ids is a tensor of shape (1, L) Because we only generated one completion. + # Note that L includes both the prompt and the response. + # We only want to process the response, so we'll strip the prompt. 
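As a usage sketch (not part of the patch; the tag names, stop string, and helper are hypothetical), a minimal ToolDefinition for a toy calculator could look like the following. Note that although the dataclass annotates call_tool as Callable[[torch.Tensor], torch.Tensor], SingleConversationWithTools passes the decoded completion string and re-encodes the returned string, so in practice a str -> str callable is what the loop exercises.

import re

from trl.trainer import ToolDefinition

def run_calculator(completion_str: str) -> str:
    """Toy tool: evaluate the arithmetic inside the last <calc>...</calc> block."""
    blocks = re.findall(r"<calc>(.*?)</calc>", completion_str, re.DOTALL)
    if not blocks:
        return "Tool error: no <calc> block found.\n"
    # Exceptions can simply propagate: process_response() catches them and feeds the
    # model a "Tool failed: ..." message instead.
    result = eval(blocks[-1], {"__builtins__": {}}, {})  # toy evaluator; a real tool should parse safely
    return f"<calc_result>{result}</calc_result>\n"

# The stop string doubles as the generation stop (via GenerationConfig.stop_strings) and as the
# marker that completion_has_tool_call() looks for in the decoded completion.
calc_tool = ToolDefinition(stop_string="</calc>", call_tool=run_calculator)
# trainer = QwenGRPOTrainer(..., tool_defn=calc_tool)  # other trainer arguments omitted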
+ input_length = len(prompt_inputs["input_ids"][0]) + ids_to_process = prompt_completion_ids[:, input_length:] + tool_was_used = conv.process_response(ids_to_process) + if not tool_was_used: + break + + full_completion = conv.get_prompt_completion_ids() + if VERBOSE > 0: + print(f"\nDONE!") + print(f"Final completion (with prompt):\n{self.processing_class.decode(full_completion[0,:])}") + print(f"^^^ Final Response!\n\n\n\n\n\n") + return full_completion + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): if return_outputs: raise ValueError("The GRPOTrainer does not support returning outputs") @@ -419,17 +559,17 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N prompt_inputs_repeated = torch.repeat_interleave(prompt_inputs["input_ids"], self.num_generations, dim=0) prompt_completion_ids = torch.cat([prompt_inputs_repeated, completion_ids], dim=1) else: - # Regular generation path + # Regular generation path (not using vLLM) with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: # Generate N times, each generate one with the temp_generation_config num_generations = self.generation_config.num_return_sequences - temp_generation_config = copy.deepcopy(self.generation_config) - temp_generation_config.num_return_sequences = 1 all_completions = [] - for i in range(num_generations): - completion = unwrapped_model.generate(**prompt_inputs, generation_config=temp_generation_config) + if self.tool_defn: + completion = self._generate_single_completion_with_tools(unwrapped_model, prompt_inputs) + else: + completion = self._generate_completion(unwrapped_model, prompt_inputs) all_completions.append(completion) # Stack all completions and pad if needed @@ -449,6 +589,8 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N # Stack all padded completions prompt_completion_ids = torch.cat(padded_completions, dim=0) + if VERBOSE > 0: + print(f"Done generating {num_generations} completions.") prompt_length = prompt_inputs["input_ids"].size(1) completion_ids = prompt_completion_ids[:, prompt_length:] @@ -542,6 +684,12 @@ def get_per_token_logps(model, input_ids): per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) per_token_loss = -(per_token_loss - self.beta * per_token_kl) loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() + # We could break this down like this: + #loss = (per_token_loss * completion_mask).sum(dim=1) # Now we have a [B] tensor of losses for each example + #loss = loss / completion_mask.sum(dim=1) # normalize by number of unmasked tokens. Still [B] + #loss = loss.mean() # average across the batch. + # Rescale to avoid underflow - we see losses underflow to 0 when they're around 1e-7, which is common + loss = loss * self.loss_magnifier # Log the metrics completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() @@ -644,3 +792,8 @@ def create_model_card( ) model_card.save(os.path.join(self.args.output_dir, "README.md")) + + + +def simple_stats(x) -> str: + return f"Min: {x.min()}, Mean: {x.mean()}, Max: {x.max()}" \ No newline at end of file
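For reference, the loss block above restated as a self-contained sketch with dummy tensors (shapes and the loss_magnifier value are illustrative; variable names follow the surrounding code):

import torch

B, T = 4, 16                                   # 4 completions in the group, 16 generated tokens each
per_token_logps = torch.randn(B, T)            # log-probs of sampled tokens under the policy
ref_per_token_logps = torch.randn(B, T)        # log-probs of the same tokens under the reference model
advantages = torch.randn(B)                    # group-normalized rewards, one per completion
completion_mask = torch.ones(B, T)             # 1 for real completion tokens, 0 after EOS
beta, loss_magnifier = 0.04, 100.0

# k3 estimator of KL(policy || reference), per token
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1

# exp(logps - logps.detach()) equals 1 in value but carries the policy gradient
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - beta * per_token_kl)

loss = (per_token_loss * completion_mask).sum(dim=1)  # [B], summed over unmasked tokens
loss = loss / completion_mask.sum(dim=1)              # [B], mean per-token loss per completion
loss = loss.mean()                                    # scalar, averaged over the group
loss = loss * loss_magnifier                          # constant rescale; every gradient scales by the same factor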