Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 35 additions & 92 deletions validator/modules/llm_judge/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from typing import List, Dict, Any
from validator.modules.llm_judge.prompt import get_prompt
from validator.modules.llm_judge.utils import download_file
from validator.modules.llm_judge.constant import SUPPORTED_BASE_MODELS
from validator.exceptions import LLMJudgeException, InvalidModelParametersException
from validator.modules.llm_judge.template import template_dict
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from validator.modules.base import (
Expand All @@ -37,7 +37,7 @@
class LLMJudgeConfig(BaseConfig):
gen_batch_size: int = 1
eval_batch_size: int = 16
gen_temperature: float = 0.1
gen_temperature: float = 0.7


class LLMJudgeMetrics(BaseMetrics):
Expand Down Expand Up @@ -144,7 +144,6 @@ def _load_model(self, repo_id: str, revision: str = "main", max_params: int = No
model_kwargs = dict(
trust_remote_code=True,
torch_dtype=compute_dtype,
use_cache=False,
device_map="auto",
)
if is_lora:
Expand All @@ -157,6 +156,18 @@ def _load_model(self, repo_id: str, revision: str = "main", max_params: int = No
with open("judge/adapter_config.json", "r") as f:
adapter_config = json.load(f)
base_model = adapter_config["base_model_name_or_path"]
if base_model in SUPPORTED_BASE_MODELS:
logger.info(
f"LoRA's base model '{base_model}' is in SUPPORTED_BASE_MODELS. "
f"Using it for tokenizer."
)
else:
logger.error(
f"LoRA's base model '{base_model}' is not in SUPPORTED_BASE_MODELS. "
f"Marking assignment as failed."
)
raise

self.hf_tokenizer = AutoTokenizer.from_pretrained(
base_model, trust_remote_code=True, use_fast=True, padding_side="left"
)
Expand Down Expand Up @@ -189,88 +200,6 @@ def _load_model(self, repo_id: str, revision: str = "main", max_params: int = No
f"Model parameters {total} exceed limit {max_params}"
)

def _construct_conversation_template(
    self,
    conversation: Dict[str, Any],
    base_model: str,
) -> str:
    """Render a conversation dict into a single prompt string using the
    base-model-specific template.

    Args:
        conversation: Dict with a non-empty ``"conversations"`` list of
            ``{"role", "content"}`` messages and an optional ``"system"``
            prompt. (Annotation fixed from ``List[Dict[str, str]]`` — the
            body validates that a dict, not a list, is passed.)
        base_model: Key into ``template_dict``; unknown models fall back
            to the ``"default"`` template.

    Returns:
        The concatenated, template-formatted conversation string.

    Raises:
        LLMJudgeException: If the conversation structure is invalid, the
            rendered template is empty, or any formatting step fails.
    """
    try:
        # Unknown base models fall back to the generic template rather
        # than failing the whole assignment.
        if base_model not in template_dict:
            logger.info(f"Template {base_model} not found, using default")
            base_model = "default"

        template = template_dict[base_model]

        conversation_parts = []

        # Validate conversation structure before formatting.
        if not isinstance(conversation, dict):
            raise LLMJudgeException(
                f"Conversation must be a dict, got {type(conversation)}"
            )

        if "conversations" not in conversation:
            raise LLMJudgeException(
                f"Conversation dict must have 'conversations' key"
            )

        if not conversation["conversations"]:
            raise LLMJudgeException(f"Conversation 'conversations' list is empty")

        # Use provided system prompt or fall back to a generic default.
        # (A falsy "system" value also falls back, matching the original
        # `system_prompt if system_prompt else ...` logic; the follow-up
        # truthiness check was dead code and has been removed.)
        if template.system_format:
            system_content = (
                conversation.get("system") or "You are a helpful assistant."
            )
            conversation_parts.append(
                template.system_format.format(content=system_content)
            )

        # Multi-turn conversation: format each message per the template,
        # skipping (with a warning) malformed entries.
        for msg in conversation["conversations"]:
            if (
                not isinstance(msg, dict)
                or "role" not in msg
                or "content" not in msg
            ):
                logger.warning(f"Skipping invalid message: {msg}")
                continue

            if msg["role"] == "user":
                conversation_parts.append(
                    template.user_format.format(
                        content=msg["content"],
                        stop_token=self.hf_tokenizer.eos_token,
                    )
                )
            elif msg["role"] == "assistant":
                conversation_parts.append(
                    template.assistant_format.format(
                        content=msg["content"],
                        stop_token=self.hf_tokenizer.eos_token,
                    )
                )

        conversation_format = "".join(conversation_parts)

        if not conversation_format.strip():
            logger.error(
                f"Empty template generated. Template: {base_model}, Conversation: {conversation}, Parts: {conversation_parts}"
            )
            raise LLMJudgeException(
                f"Generated conversation template is empty after formatting"
            )

    except Exception as e:
        # Wrap every failure (including the LLMJudgeExceptions raised
        # above) so callers see a single exception type, with the cause
        # chained for debugging.
        raise LLMJudgeException(
            f"Failed to construct conversation template: {e}"
        ) from e

    return conversation_format

def _generate_response(
self,
context_length: int,
Expand All @@ -294,9 +223,20 @@ def _generate_response(
# Apply chat template with fallback
batch_conversation_templates = []
for conversation in batch_conversations:
template = self._construct_conversation_template(
conversation,
base_model=base_model,

messages = []
if "system" in conversation:
messages.append({
"role": "system",
"content": conversation["system"]
})

messages += conversation["conversations"]
template = self.hf_tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)

# Validate template is not empty
Expand Down Expand Up @@ -353,10 +293,10 @@ def _generate_response(
outputs = self.hf_model.generate(
**model_inputs,
max_new_tokens=max_length,
temperature=self.config.gen_temperature,
temperature=self.config.gen_temperature, # Non thinking-General 0.7 ,Reasoning 1
do_sample=True,
top_p=0.95, # Nucleus sampling for stability
top_k=50, # Limit vocabulary for stability
top_p=0.8, # Non thinking-General 0.8 ,Reasoning 0.95
top_k=20, #
pad_token_id=self.hf_tokenizer.eos_token_id,
eos_token_id=self.hf_tokenizer.eos_token_id,
)
Expand Down Expand Up @@ -916,7 +856,10 @@ def validate(self, data: LLMJudgeInputData, **kwargs) -> LLMJudgeMetrics:
self._load_model(data.hg_repo_id, data.revision, data.max_params)
except InvalidModelParametersException as e:
# lowest possible reward for invalid model parameters
logger.info(f"Invalid model parameters: {e}")
logger.error(f"Invalid model parameters: {e}")
return LLMJudgeMetrics(score=LOWEST_POSSIBLE_SCORE)
except Exception as e:
logger.error(f"Exception when load model: {e}")
return LLMJudgeMetrics(score=LOWEST_POSSIBLE_SCORE)

# Stage 1: Generate all responses
Expand Down
12 changes: 12 additions & 0 deletions validator/modules/llm_judge/constant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Allowlist of base models a submitted LoRA adapter may be built on.
# The model loader compares the adapter's ``base_model_name_or_path``
# against this list and marks the assignment as failed when the base
# model is not present.
SUPPORTED_BASE_MODELS = [
    # qwen3.5
    "Qwen/Qwen3.5-0.8B",
    "Qwen/Qwen3.5-0.8B-Base",
    "Qwen/Qwen3.5-2B",
    "Qwen/Qwen3.5-2B-Base",
    "Qwen/Qwen3.5-4B",
    "Qwen/Qwen3.5-4B-Base",
    "Qwen/Qwen3.5-9B",
    "Qwen/Qwen3.5-9B-Base",
    "Qwen/Qwen3.5-27B",
]
6 changes: 3 additions & 3 deletions validator/modules/llm_judge/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ dependencies:
- openai>=1.0.0 # OpenAI API client
- httpx # HTTP client for OpenAI requests
- pydantic>=2.0.0 # Data validation and parsing
- transformers==4.49.0 # HuggingFace transformers library
- transformers==5.3.0 # HuggingFace transformers library
- torch>=1.13.1 # PyTorch for model inference
- accelerate>=0.27.2 # For efficient model loading
- loguru>=0.6.0 # Logging library
- huggingface-hub==0.29.1
- huggingface-hub==1.5.0
- tenacity
- peft>=0.10.0,<0.18.0
- peft==0.18.1
- python-dotenv # Load environment variables from .env file
184 changes: 0 additions & 184 deletions validator/modules/llm_judge/template.py

This file was deleted.