diff --git a/validator/modules/llm_judge/__init__.py b/validator/modules/llm_judge/__init__.py index 21c3d63..cb11487 100644 --- a/validator/modules/llm_judge/__init__.py +++ b/validator/modules/llm_judge/__init__.py @@ -12,11 +12,12 @@ from loguru import logger from huggingface_hub import HfApi from typing import List, Dict, Any -from validator.modules.llm_judge.prompt import get_prompt +from validator.modules.llm_judge.prompt import get_prompt,template_str from validator.modules.llm_judge.utils import download_file +from validator.modules.llm_judge.constant import SUPPORTED_BASE_MODELS from validator.exceptions import LLMJudgeException, InvalidModelParametersException -from validator.modules.llm_judge.template import template_dict from peft import PeftModel +from jinja2 import Environment from transformers import AutoTokenizer, AutoModelForCausalLM from validator.modules.base import ( BaseValidationModule, @@ -37,7 +38,7 @@ class LLMJudgeConfig(BaseConfig): gen_batch_size: int = 1 eval_batch_size: int = 16 - gen_temperature: float = 0.1 + gen_temperature: float = 0.7 class LLMJudgeMetrics(BaseMetrics): @@ -144,7 +145,6 @@ def _load_model(self, repo_id: str, revision: str = "main", max_params: int = No model_kwargs = dict( trust_remote_code=True, torch_dtype=compute_dtype, - use_cache=False, device_map="auto", ) if is_lora: @@ -157,6 +157,18 @@ def _load_model(self, repo_id: str, revision: str = "main", max_params: int = No with open("judge/adapter_config.json", "r") as f: adapter_config = json.load(f) base_model = adapter_config["base_model_name_or_path"] + if base_model in SUPPORTED_BASE_MODELS: + logger.info( + f"LoRA's base model '{base_model}' is in SUPPORTED_BASE_MODELS. " + f"Using it for tokenizer." + ) + else: + logger.error( + f"LoRA's base model '{base_model}' is not in SUPPORTED_BASE_MODELS. " + f"Marking assignment as failed." 
+ ) + raise LLMJudgeException(f"Unsupported LoRA base model: {base_model}") + self.hf_tokenizer = AutoTokenizer.from_pretrained( base_model, trust_remote_code=True, use_fast=True, padding_side="left" ) @@ -189,88 +201,6 @@ def _load_model(self, repo_id: str, revision: str = "main", max_params: int = No f"Model parameters {total} exceed limit {max_params}" ) - def _construct_conversation_template( - self, - conversation: List[Dict[str, str]], - base_model: str, - ) -> str: - try: - if base_model not in template_dict: - logger.info(f"Template {base_model} not found, using default") - base_model = "default" - - template = template_dict[base_model] - - conversation_parts = [] - - # Validate conversation structure - if not isinstance(conversation, dict): - raise LLMJudgeException( - f"Conversation must be a dict, got {type(conversation)}" - ) - - if "conversations" not in conversation: - raise LLMJudgeException( - f"Conversation dict must have 'conversations' key" - ) - - if not conversation["conversations"]: - raise LLMJudgeException(f"Conversation 'conversations' list is empty") - - # Use provided system_text or fall back to template default - if template.system_format: - system_prompt = ( - conversation["system"] if "system" in conversation else None - ) - system_content = ( - system_prompt if system_prompt else "You are a helpful assistant." 
- ) - if system_content: - formatted_system = template.system_format.format( - content=system_content - ) - conversation_parts.append(formatted_system) - - # Multi-turn conversation: format each message according to template - for msg in conversation["conversations"]: - if ( - not isinstance(msg, dict) - or "role" not in msg - or "content" not in msg - ): - logger.warning(f"Skipping invalid message: {msg}") - continue - - if msg["role"] == "user": - user_text = template.user_format.format( - content=msg["content"], - stop_token=self.hf_tokenizer.eos_token, - ) - conversation_parts.append(user_text) - elif msg["role"] == "assistant": - assistant_text = template.assistant_format.format( - content=msg["content"], - stop_token=self.hf_tokenizer.eos_token, - ) - conversation_parts.append(assistant_text) - - conversation_format = "".join(conversation_parts) - - if not conversation_format.strip(): - logger.error( - f"Empty template generated. Template: {base_model}, Conversation: {conversation}, Parts: {conversation_parts}" - ) - raise LLMJudgeException( - f"Generated conversation template is empty after formatting" - ) - - except Exception as e: - raise LLMJudgeException( - f"Failed to construct conversation template: {e}" - ) from e - - return conversation_format - def _generate_response( self, context_length: int, @@ -294,9 +224,28 @@ def _generate_response( # Apply chat template with fallback batch_conversation_templates = [] for conversation in batch_conversations: - template = self._construct_conversation_template( - conversation, - base_model=base_model, + + messages = [] + if "system" in conversation: + messages.append({ + "role": "system", + "content": conversation["system"] + }) + + messages += conversation["conversations"] + tools_for_template = conversation.get("tools", None) + try: + if isinstance(tools_for_template, str): + tools_for_template = json.loads(tools_for_template) + except Exception: + # leave tools_for_template as-is if parsing fails + pass + 
template = self.hf_tokenizer.apply_chat_template( + messages, + tools=tools_for_template, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, ) # Validate template is not empty @@ -353,10 +302,10 @@ def _generate_response( outputs = self.hf_model.generate( **model_inputs, max_new_tokens=max_length, - temperature=self.config.gen_temperature, + temperature=self.config.gen_temperature, # Non thinking-General 0.7 ,Reasoning 1 do_sample=True, - top_p=0.95, # Nucleus sampling for stability - top_k=50, # Limit vocabulary for stability + top_p=0.8, # Non thinking-General 0.8 ,Reasoning 0.95 + top_k=20, # pad_token_id=self.hf_tokenizer.eos_token_id, eos_token_id=self.hf_tokenizer.eos_token_id, ) @@ -608,7 +557,8 @@ def _load_jsonl_conversations( conversation_to_process = [] reference_response = None tools_info = None - + pending_tool_call_ids: list[str] = [] + tool_call_counter = 0 if "conversations" in json_data: conversations = json_data["conversations"] if isinstance(conversations, list) and conversations: @@ -616,21 +566,70 @@ def _load_jsonl_conversations( for msg in conversations: role = msg.get("role", "") content = msg.get("content", "").strip() - if ( - role in ["user", "assistant", "function_call"] - and content - ): + if not content: + continue + if role == "function_call": + tool_call_counter += 1 + tool_call_id = f"call_{tool_call_counter}" + + try: + call_data = json.loads(content) + tool_call_msg = { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": tool_call_id, + "type": "function", + "function": { + "name": call_data.get("name", ""), + "arguments": call_data.get("arguments", {}) + }, + } + ], + } + conversation_to_process.append(tool_call_msg) + except (json.JSONDecodeError, KeyError): + tool_call_id = None + conversation_to_process.append( + {"role": "assistant", "content": content} + ) + if tool_call_id: + pending_tool_call_ids.append(tool_call_id) + + elif role == "observation": + if pending_tool_call_ids: + 
tool_call_id = pending_tool_call_ids.pop(0) + else: + tool_call_id = "call_unknown" + + conversation_to_process.append( + { + "role": "tool", + "tool_call_id": tool_call_id, + "content": content, + } + ) + elif role in ["user", "assistant"]: conversation_to_process.append( {"role": role, "content": content} ) # Extract reference response (last assistant or function_call message) reference_response = None + if conversation_to_process: - last_msg = conversation_to_process[-1] - if last_msg["role"] in ["assistant", "function_call"]: + last_msg = conversations[-1] + if last_msg["role"] in ["assistant"]: reference_response = last_msg["content"] conversation_to_process = conversation_to_process[:-1] + elif last_msg["role"] in ["function_call"]: + env = Environment(trim_blocks=True, lstrip_blocks=True); env.filters["from_json"] = json.loads + conversation_template = env.from_string(template_str) + reference_response = conversation_template.render( + messages=[conversation_to_process[-1]], trim_blocks=True, + lstrip_blocks=True) + conversation_to_process = conversation_to_process[:-1] # Extract tools information if available (for function_call evaluation) if "tools" in json_data: @@ -651,6 +650,8 @@ continue input_conversations_data["conversations"] = conversation_to_process + if tools_info is not None: + input_conversations_data["tools"] = tools_info input_conversations.append( { @@ -916,7 +917,10 @@ self._load_model(data.hg_repo_id, data.revision, data.max_params) except InvalidModelParametersException as e: # lowest possible reward for invalid model parameters - logger.info(f"Invalid model parameters: {e}") + logger.error(f"Invalid model parameters: {e}") + return LLMJudgeMetrics(score=LOWEST_POSSIBLE_SCORE) + except Exception as e: + logger.error(f"Exception when load model: {e}") return LLMJudgeMetrics(score=LOWEST_POSSIBLE_SCORE) # Stage 1: Generate all responses diff --git 
a/validator/modules/llm_judge/constant.py b/validator/modules/llm_judge/constant.py new file mode 100644 index 0000000..30efd9f --- /dev/null +++ b/validator/modules/llm_judge/constant.py @@ -0,0 +1,12 @@ +SUPPORTED_BASE_MODELS = [ + # qwen3.5 + "Qwen/Qwen3.5-0.8B", + "Qwen/Qwen3.5-0.8B-Base", + "Qwen/Qwen3.5-2B", + "Qwen/Qwen3.5-2B-Base", + "Qwen/Qwen3.5-4B", + "Qwen/Qwen3.5-4B-Base", + "Qwen/Qwen3.5-9B", + "Qwen/Qwen3.5-9B-Base", + "Qwen/Qwen3.5-27B", +] \ No newline at end of file diff --git a/validator/modules/llm_judge/environment.yml b/validator/modules/llm_judge/environment.yml index cc75b2f..361d284 100644 --- a/validator/modules/llm_judge/environment.yml +++ b/validator/modules/llm_judge/environment.yml @@ -8,11 +8,11 @@ dependencies: - openai>=1.0.0 # OpenAI API client - httpx # HTTP client for OpenAI requests - pydantic>=2.0.0 # Data validation and parsing - - transformers==4.49.0 # HuggingFace transformers library + - transformers==5.3.0 # HuggingFace transformers library - torch>=1.13.1 # PyTorch for model inference - accelerate>=0.27.2 # For efficient model loading - loguru>=0.6.0 # Logging library - - huggingface-hub==0.29.1 + - huggingface-hub==1.5.0 - tenacity - - peft>=0.10.0,<0.18.0 + - peft==0.18.1 - python-dotenv # Load environment variables from .env file \ No newline at end of file diff --git a/validator/modules/llm_judge/prompt.py b/validator/modules/llm_judge/prompt.py index 726d3e7..5bbf37e 100644 --- a/validator/modules/llm_judge/prompt.py +++ b/validator/modules/llm_judge/prompt.py @@ -181,3 +181,45 @@ def function_call_ref_eval_prompt( tools=Tools, assistant_response=assistant_response, ) + +template_str= """{% for message in messages %} + +{% if message.role == "system" %} + +{{ message.content }} + + +{% elif message.role == "user" %} + +{{ message.content }} + + +{% elif message.role == "assistant" %} + + {% if message.tool_calls %} + + {% for tool in message.tool_calls %} + + {% set args = tool.function.arguments %} + {% if args is 
string %} + {% set args = args | from_json %} + {% endif %} + {% for key, value in args.items() %} +{{ value }} + {% endfor %} + + {% endfor %} + + {% else %} + +{{ message.content }} + + {% endif %} +{% elif message.role == "tool" %} + +{{ message.content }} + + +{% endif %} + +{% endfor %}""" \ No newline at end of file diff --git a/validator/modules/llm_judge/template.py b/validator/modules/llm_judge/template.py deleted file mode 100644 index 862c52f..0000000 --- a/validator/modules/llm_judge/template.py +++ /dev/null @@ -1,184 +0,0 @@ -from dataclasses import dataclass -from typing import Dict - - -@dataclass -class Template: - template_name: str - system_format: str - user_format: str - assistant_format: str - tool_format: str - function_format: str - observation_format: str - system: str - stop_word: str - - -template_dict: Dict[str, Template] = dict() - - -def register_template( - template_name, - system_format, - user_format, - assistant_format, - tool_format, - function_format, - observation_format, - system, - stop_word=None, -): - template_dict[template_name] = Template( - template_name=template_name, - system_format=system_format, - user_format=user_format, - assistant_format=assistant_format, - tool_format=tool_format, - function_format=function_format, - observation_format=observation_format, - system=system, - stop_word=stop_word, - ) - - -register_template( - template_name="default", - system_format="System: {content}\n\n", - user_format="User: {content}\nAssistant: ", - assistant_format="{content} {stop_token}", - tool_format="{content}", - function_format="{content}", - observation_format="Tool\n{content}\n", - system=None, - stop_word=None, -) - - -register_template( - template_name="qwen1.5", - system_format="<|im_start|>system\n{content}<|im_end|>\n", - user_format="<|im_start|>user\n{content}<|im_end|>\n<|im_start|>assistant\n", - assistant_format="{content}<|im_end|>\n", - tool_format="{content}", - function_format="{content}", - 
observation_format="<|im_start|>tool\n{content}\n<|im_start|>assistant\n", - system="You are a helpful assistant.", - stop_word="<|im_end|>", -) - -register_template( - template_name="yi", - system_format="<|im_start|>system\n{content}<|im_end|>\n", - user_format="<|im_start|>user\n{content}<|im_end|>\n<|im_start|>assistant\n", - assistant_format="{content}<|im_end|>\n", - tool_format="{content}", - function_format="{content}", - observation_format="<|im_start|>tool\n{content}\n<|im_start|>assistant\n", - system=None, - stop_word="<|im_end|>", -) - - -register_template( - template_name="zephyr", - system_format="<|system|>\n{content}", - user_format="<|user|>\n{content}\n<|assistant|>\n", - assistant_format="{content}\n", - tool_format="{content}", - function_format="{content}", - observation_format="<|tool|>\n{content}\n<|assistant|>\n", - system=None, - stop_word="", -) - -register_template( - template_name="mistral", - system_format="", - user_format="[INST]{content}[/INST]", - assistant_format="{content}", - tool_format="{content}", - function_format="{content}", - observation_format="{content}", - system="", - stop_word="", -) - -register_template( - template_name="mixtral", - system_format="", - user_format="[INST]{content}[/INST]", - assistant_format="{content}", - tool_format="{content}", - function_format="{content}", - observation_format="{content}", - system="", - stop_word="", -) - -register_template( - template_name="llama2", - system_format="<>\n{content}\n<>\n\n", - user_format="[INST]{content}[/INST]", - assistant_format="{content} ", - tool_format="{content}", - function_format="{content}", - observation_format="{content}", - system="You are a helpful, respectful and honest assistant. " - "Always answer as helpfully as possible, while being safe. " - "Your answers should not include any harmful, unethical, " - "racist, sexist, toxic, dangerous, or illegal content. 
" - "Please ensure that your responses are socially unbiased and positive in nature.\n\n" - "If a question does not make any sense, or is not factually coherent, " - "explain why instead of answering something not correct. " - "If you don't know the answer to a question, please don't share false information.", - stop_word="", -) - -register_template( - template_name="gemma", - system_format="", - user_format="user\n{content}\nmodel\n", - assistant_format="{content}\n", - tool_format="{content}", - function_format="{content}", - observation_format="tool\n{content}\nmodel\n", - system="", - stop_word="", -) - -register_template( - template_name="llama3", - system_format="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{content}<|eot_id|>", - user_format="<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", - assistant_format="{content}<|eot_id|>", - tool_format="{content}", - function_format="{content}", - observation_format="<|start_header_id|>tool<|end_header_id|>\n\n{content}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", - system=None, - stop_word="<|eot_id|>", -) - -register_template( - template_name="phi3", - system_format=None, - user_format="<|user|>\n{content}<|end|>\n<|assistant|>", - assistant_format="{content}<|end|>\n", - tool_format="{content}", - function_format="{content}", - observation_format="<|tool|>\n{content}<|end|>\n<|assistant|>", - system=None, - stop_word="<|end|>", -) - -register_template( - template_name="phi4", - system_format=None, - user_format="<|user|>\n{content}<|end|>\n<|assistant|>", - assistant_format="{content}<|end|>\n", - tool_format="<|tool|>{content}<|/tool|>", - function_format="<|tool_call|>{content}<|/tool_call|>", - observation_format="<|tool|>\n{content}<|end|>\n<|assistant|>", - system=None, - stop_word="<|end|>", -)