diff --git a/cli/eval/__init__.py b/cli/eval/__init__.py
new file mode 100644
index 00000000..7f625a26
--- /dev/null
+++ b/cli/eval/__init__.py
@@ -0,0 +1 @@
+"""CLI for test-based evaluation."""
diff --git a/cli/eval/commands.py b/cli/eval/commands.py
new file mode 100644
index 00000000..ebc85dd6
--- /dev/null
+++ b/cli/eval/commands.py
@@ -0,0 +1,49 @@
+"""The eval command runs LLM-as-a-judge evaluation over one or more test files of prompts, instructions, and (optionally) targets.
+It instantiates a generator model to produce candidate responses and a judge model to decide whether the instructions have been followed."""
+
+import typer
+
+eval_app = typer.Typer(name="eval")
+
+
+def eval_run(
+    test_files: list[str] = typer.Argument(
+        ..., help="List of paths to json/jsonl files containing test cases"
+    ),
+    backend: str = typer.Option("ollama", "--backend", "-b", help="Generation backend"),
+    model: str = typer.Option(None, "--model", help="Generation model name"),
+    max_gen_tokens: int = typer.Option(
+        256, "--max-gen-tokens", help="Max tokens to generate for responses"
+    ),
+    judge_backend: str = typer.Option(
+        None, "--judge-backend", "-jb", help="Judge backend"
+    ),
+    judge_model: str = typer.Option(None, "--judge-model", help="Judge model name"),
+    max_judge_tokens: int = typer.Option(
+        256, "--max-judge-tokens", help="Max tokens for the judge model's judgement"
+    ),
+    output_path: str = typer.Option(
+        "eval_results", "--output-path", "-o", help="Output path for results"
+    ),
+    output_format: str = typer.Option(
+        "json", "--output-format", help="Either json or jsonl format for results"
+    ),
+    continue_on_error: bool = typer.Option(
+        True,
+        "--continue-on-error/--no-continue-on-error",
+        help="Continue with the remaining tests when a test fails",
+    ),
+):
+    from cli.eval.runner import run_evaluations
+
+    run_evaluations(
+        test_files=test_files,
+        backend=backend,
+        model=model,
+        max_gen_tokens=max_gen_tokens,
+        judge_backend=judge_backend,
+        judge_model=judge_model,
+        max_judge_tokens=max_judge_tokens,
+        output_path=output_path,
+        output_format=output_format,
+        continue_on_error=continue_on_error,
+    )
+
+
+eval_app.command("run")(eval_run)
diff --git a/cli/eval/runner.py b/cli/eval/runner.py
new file mode 100644
index 00000000..199581f1
--- /dev/null
+++ b/cli/eval/runner.py
@@ -0,0 +1,350 @@
+import json
+import re
+from pathlib import Path
+from typing import List
+
+import mellea
+from mellea.stdlib.base import ModelOutputThunk
+from mellea.stdlib.test_based_eval import TestBasedEval
+from mellea.backends.types import ModelOption
+
+from rich.console import Console
+from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn
+
+console = Console()
+
+
+class InputEvalResult:
+    """Store results of a single input evaluation (within a unit test)."""
+
+    def __init__(
+        self,
+        input_text: str,
+        model_output: str,
+        validation_passed: bool,
+        score: int | None,
+        validation_reason: str,  # TODO: also record the input_id
+    ):
+        self.input_text = input_text
+        self.model_output = model_output
+        self.validation_passed = validation_passed
+        self.score = score
+        self.validation_reason = validation_reason
+
+    def to_dict(self):
+        return {
+            "input": self.input_text,
+            "model_output": self.model_output,
+            "passed": self.validation_passed,
+            "score": self.score,
+            "justification": self.validation_reason,
+        }
+
+
+class TestEvalResult:
+    """Store results of a single test evaluation."""
+
+    def __init__(self, test_eval: TestBasedEval, input_results: list[InputEvalResult]):
+        self.test_eval = test_eval
+        self.input_results = input_results
+
+    def to_dict(self):
+        return {
+            "test_id": self.test_eval.test_id,
+            "source": self.test_eval.source,
+            "name": self.test_eval.name,
+            "instructions": self.test_eval.instructions,
+            "input_results": [r.to_dict() for r in self.input_results],
+            "expected_targets": self.test_eval.targets,
+            "passed": self.passed_count,
+            "total_count": self.total_count,
+            "pass_rate": self.pass_rate,
+        }
+
+    @property
+    def passed_count(self) -> int:
+        return sum(1 for r in self.input_results if r.validation_passed)
+
+    @property
+    def total_count(self) -> int:
+        return len(self.input_results)
+
+    @property
+    def pass_rate(self) -> float:
+        return self.passed_count / self.total_count if self.total_count > 0 else 0.0
+
+
+def create_session(
+    backend: str, model: str | None, max_tokens: int | None
+) -> mellea.MelleaSession:
+    """Create a mellea session with the specified backend and model."""
+
+    model_id = None
+    if model:
+        if model.isupper() or "_" in model:
+            if hasattr(mellea.model_ids, model):
+                model_id = getattr(mellea.model_ids, model)
+            else:
+                model_id = model
+        else:
+            model_id = model
+    else:
+        model_id = mellea.model_ids.IBM_GRANITE_4_MICRO_3B
+
+    try:
+        backend_lower = backend.lower()
+
+        if backend_lower == "ollama":
+            from mellea.backends.ollama import OllamaModelBackend
+
+            backend_instance = OllamaModelBackend(
+                model_id=model_id,
+                model_options={ModelOption.MAX_NEW_TOKENS: max_tokens},
+            )
+
+        elif backend_lower == "openai":
+            from mellea.backends.openai import OpenAIBackend
+
+            backend_instance = OpenAIBackend(
+                model_id=model_id,
+                model_options={ModelOption.MAX_NEW_TOKENS: max_tokens},
+            )
+
+        elif backend_lower in ["hf", "huggingface"]:
+            from mellea.backends.huggingface import LocalHFBackend
+
+            backend_instance = LocalHFBackend(
+                model_id=model_id,
+                model_options={ModelOption.MAX_NEW_TOKENS: max_tokens},
+            )
+
+        elif backend_lower == "watsonx":
+            from mellea.backends.watsonx import WatsonxAIBackend
+
+            backend_instance = WatsonxAIBackend(
+                model_id=model_id,
+                model_options={ModelOption.MAX_NEW_TOKENS: max_tokens},
+            )
+
+        elif backend_lower == "litellm":
+            from mellea.backends.litellm import LiteLLMBackend
+
+            backend_instance = LiteLLMBackend(
+                model_id=model_id,
+                model_options={ModelOption.MAX_NEW_TOKENS: max_tokens},
+            )
+
+        else:
+            raise ValueError(
+                f"Unknown backend: {backend}. Supported: ollama, openai, hf, watsonx, litellm"
+            )
+
+        # create session with backend instance
+        from mellea.stdlib.base import SimpleContext
+
+        session = mellea.MelleaSession(backend=backend_instance, ctx=SimpleContext())
+        return session
+
+    except Exception as e:
+        console.print(
+            f"[red]Error creating session with backend={backend}, model={model_id}: {e}[/red]"
+        )
+        raise
+
+
+def run_evaluations(
+    test_files: List[str],
+    backend: str,
+    model: str | None,
+    max_gen_tokens: int | None,
+    judge_backend: str | None,
+    judge_model: str | None,
+    max_judge_tokens: int | None,
+    output_path: str,
+    output_format: str,
+    continue_on_error: bool,
+):
+    """Run all 'unit test' evaluations.
+
+    Each test file should be a JSON object (or a list of them) containing:
+        "id": an id that is unique to this test file
+        "source": the origin of the evaluation prompts, else "N/A"
+        "name": the instruction-following attribute that the user intends to evaluate through this test
+        "instructions": a set (in string form) of requirements the generation should follow; the judge evaluates whether these are satisfied
+        "examples": a list of entries, each containing an input_id, an input (prompt), and a list of targets.
+            Each input may have multiple (or no) targets; inputs and targets are in messages format.
+    """
+    all_test_evals: List[TestBasedEval] = []
+
+    for test_file in test_files:
+        try:
+            test_evals = TestBasedEval.from_json_file(test_file)
+            all_test_evals.extend(test_evals)
+            console.print(f"Loaded {len(test_evals)} test evaluations from {test_file}")
+        except Exception as e:
+            console.print(f"Error loading {test_file}: {e}")
+
+    if not all_test_evals:
+        console.print("Failed to load any test evaluations")
+        return
+
+    console.print(f"Total test evals to run: {len(all_test_evals)}")
+    total_inputs = sum(len(test_eval.inputs) for test_eval in all_test_evals)
+    console.print(f"Total inputs to run: {total_inputs}")
+
+    console.print(f"Generation model: {model}")
+    console.print(f"Judge model: {judge_model}")
+
+    m = create_session(backend=backend, model=model, max_tokens=max_gen_tokens)
+    # fall back to the generation backend when no judge backend is given
+    judge_backend = judge_backend or backend
+    judge_session = create_session(
+        backend=judge_backend, model=judge_model, max_tokens=max_judge_tokens
+    )
+
+    all_results = []
+
+    with Progress(
+        SpinnerColumn(),
+        TextColumn("[progress.description]{task.description}"),
+        BarColumn(),
+        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
+        console=console,
+    ) as progress:
+        task = progress.add_task("Running evals", total=len(all_test_evals))
+        for test_eval in all_test_evals:
+            try:
+                result = execute_test_eval(
+                    test_eval=test_eval,
+                    generation_session=m,
+                    judge_session=judge_session,
+                )
+                all_results.append(result)
+            except Exception as e:
+                console.print(f"Error {e} on test {test_eval.test_id}")
+                if not continue_on_error:
+                    raise
+
+            progress.advance(task)
+
+    summary_stats(all_results)
+    save_results(all_results, output_path, output_format)
+
+    m.cleanup()
+    judge_session.cleanup()
+
+
+def execute_test_eval(
+    test_eval: TestBasedEval,
+    generation_session: mellea.MelleaSession,
+    judge_session: mellea.MelleaSession,
+) -> TestEvalResult:
+    """Execute a single test evaluation.
+
+    For each input in the test, generate a response using generation_session,
+    then validate that response against the test's instructions using judge_session.
+    """
+
+    input_results = []
+
+    # for each input, generate a response and have the judge score it
+    for idx, input_text in enumerate(test_eval.inputs):
+        result: ModelOutputThunk = generation_session.act(input_text)
+        model_output = str(result)
+
+        targets_for_input = (
+            test_eval.targets[idx] if idx < len(test_eval.targets) else []
+        )
+
+        # query the judge
+        test_eval.set_judge_context(
+            input_text=input_text,
+            prediction=model_output,
+            targets_for_input=targets_for_input,
+        )
+        judge_output_thunk = judge_session.act(test_eval)
+        judge_output = str(judge_output_thunk)
+        score, justification = parse_judge_output(judge_output)
+        passed = score == 1 if score is not None else False
+
+        input_result = InputEvalResult(
+            input_text=input_text,
+            model_output=model_output,
+            validation_passed=passed,
+            score=score,
+            validation_reason=justification,
+        )
+        input_results.append(input_result)
+
+        # reset both the generator and the judge between inputs
+        generation_session.reset()
+        judge_session.reset()
+
+    test_result = TestEvalResult(test_eval=test_eval, input_results=input_results)
+    return test_result
+
+
+def parse_judge_output(judge_output: str):
+    try:
+        json_match = re.search(r'\{[^}]*"score"[^}]*\}', judge_output, re.DOTALL)
+        if json_match:
+            json_str = json_match.group(0)
+            data = json.loads(json_str)
+            score = data.get("score")
+            justification = data.get("justification")
+            return score, justification
+    except (json.JSONDecodeError, AttributeError):
+        pass
+
+    # if the above fails, search the text for the score
+    score_match = re.search(r'score["\s:]+(\d+)', judge_output, re.IGNORECASE)
+    if score_match:
+        score = int(score_match.group(1))
+        return score, judge_output
+
+    return None, judge_output
+
+
+def save_results(results: List[TestEvalResult], output_path: str, output_format: str):
+    output_path_obj = Path(output_path)
+    if output_path_obj.suffix != f".{output_format}":
+        output_path_obj = Path(f"{output_path}.{output_format}")
+
+    total_inputs = sum(r.total_count for r in results)
+    passed_inputs = sum(r.passed_count for r in results)
+    overall_pass_rate = passed_inputs / total_inputs if total_inputs > 0 else 0.0
+
+    if output_format == "jsonl":
+        with output_path_obj.open("w") as f:
+            for result in results:
+                f.write(json.dumps(result.to_dict()) + "\n")
+    else:  # json
+        summary = {
+            "total_tests": len(results),
+            "total_inputs": total_inputs,
+            "passed_inputs": passed_inputs,
+            "failed_inputs": total_inputs - passed_inputs,
+            "overall_pass_rate": overall_pass_rate,
+        }
+
+        with output_path_obj.open("w") as f:
+            json.dump(
+                {"summary": summary, "results": [r.to_dict() for r in results]},
+                f,
+                indent=2,
+            )
+
+    console.print(f"Results saved to {output_path_obj}")
+
+
+def summary_stats(results: List[TestEvalResult]):
+    total_inputs = sum(r.total_count for r in results)
+    passed_inputs = sum(r.passed_count for r in results)
+    overall_pass_rate = passed_inputs / total_inputs if total_inputs > 0 else 0.0
+
+    console.print(f"Total number of inputs across tests: {total_inputs}")
+    console.print(f"Number of inputs passed across tests: {passed_inputs}")
+    console.print(f"Cumulative Pass Rate: {overall_pass_rate * 100:.1f}%")
+
+    if len(results) > 1:
+        console.print("Per-Test Breakdown:")
+        for result in results:
+            console.print(
+                f"{result.test_eval.name}:\n\t{result.passed_count}/{result.total_count} ({result.pass_rate * 100:.1f}%)\n\n"
+            )
diff --git a/cli/m.py b/cli/m.py
index 3aa32aa1..07fc14b9 100644
--- a/cli/m.py
+++ b/cli/m.py
@@ -5,6 +5,7 @@
 from cli.alora.commands import alora_app
 from cli.decompose import app as decompose_app
 from cli.serve.app import serve
+from cli.eval.commands import eval_app
 
 
 cli = typer.Typer(name="m", no_args_is_help=True)
@@ -25,3 +26,5 @@ def callback() -> None:
 # as documented: https://typer.tiangolo.com/tutorial/subcommands/add-typer/#put-them-together.
 cli.add_typer(alora_app)
 cli.add_typer(decompose_app)
+
+cli.add_typer(eval_app)
diff --git a/mellea/stdlib/reqlib/md.py b/mellea/stdlib/reqlib/md.py
index fc6c4386..9a1836ed 100644
--- a/mellea/stdlib/reqlib/md.py
+++ b/mellea/stdlib/reqlib/md.py
@@ -14,11 +14,15 @@ def as_markdown_list(ctx: Context) -> list[str] | None:
     raw_output = ctx.last_output()
     assert raw_output is not None
     try:
-        parsed = mistletoe.Document(raw_output.value)  # type: ignore
-        for child in parsed.children:  # type: ignore
+        assert raw_output.value is not None
+        parsed = mistletoe.Document(raw_output.value)
+        assert parsed.children is not None
+        children = list(parsed.children)
+        for child in children:
             if type(child) is not mistletoe.block_token.List:
                 return None
-            for item in child.children:  # type: ignore
+            assert child.children is not None
+            for item in child.children:
                 xs.append(mistletoe.base_renderer.BaseRenderer().render(item))
         return xs
     except Exception:
@@ -44,10 +48,13 @@ def _md_table(ctx: Context):
     raw_output = ctx.last_output()
     assert raw_output is not None
     try:
-        parsed = mistletoe.Document(raw_output.value)  # type: ignore
-        if len(parsed.children) != 1:  # type: ignore
+        assert raw_output.value is not None
+        parsed = mistletoe.Document(raw_output.value)
+        assert parsed.children is not None
+        children = list(parsed.children)
+        if len(children) != 1:
             return False
-        return type(parsed.children[0]) is mistletoe.block_token.Table  # type: ignore
+        return type(children[0]) is mistletoe.block_token.Table
     except Exception:
         return False
 
diff --git a/mellea/stdlib/test_based_eval.py b/mellea/stdlib/test_based_eval.py
new file mode 100644
index 00000000..1e96ad61
--- /dev/null
+++ b/mellea/stdlib/test_based_eval.py
@@ -0,0 +1,144 @@
+"""LLM Evaluation with Unit Tests in Mellea."""
+
+import json
+from pathlib import Path
+from typing import Any
+
+from pydantic import BaseModel, Field, field_validator
+
+from mellea.stdlib.base import CBlock, Component, TemplateRepresentation
+
+
+class Message(BaseModel):
+    """Schema for a message in the test data."""
+
+    role: str
+    content: str
+
+
+class Example(BaseModel):
+    """Schema for an example in the test data."""
+
+    input: list[Message]
+    targets: list[Message] = Field(default_factory=list)
+    input_id: str = ""
+
+
+class TestData(BaseModel):
+    """Schema for test data loaded from json."""
+
+    source: str
+    name: str
+    instructions: str
+    examples: list[Example] = Field(default_factory=list)
+    id: str
+
+    @field_validator("examples")
+    @classmethod
+    def validate_examples(cls, v):
+        """Ensure examples list is not empty."""
+        if not v:
+            raise ValueError("examples list cannot be empty")
+        return v
+
+
+class TestBasedEval(Component):
+    """Each TestBasedEval represents a single unit test."""
+
+    def __init__(
+        self,
+        source: str,
+        name: str,
+        instructions: str,
+        inputs: list[str],
+        targets: list[list[str]] | None = None,  # can be optional
+        test_id: str | None = None,
+        input_ids: list[str] | None = None,
+    ):
+        """Initialize TestBasedEval (for a single unit test)."""
+        self.source = source
+        self.name = name
+        self.instructions = instructions
+        self.inputs = inputs
+        self.targets = targets or []
+        self.test_id = test_id
+        self.input_ids = input_ids or []
+
+    def parts(self) -> list[Component | CBlock]:
+        """The set of constituent parts of the Component."""
+        return []
+
+    def format_for_llm(self) -> TemplateRepresentation:
+        """Formats the test for judge evaluation."""
+        return TemplateRepresentation(
+            obj=self,
+            args=self._judge_context if hasattr(self, "_judge_context") else {},
+            template_order=["*"],
+        )
+
+    def set_judge_context(
+        self, input_text: str, prediction: str, targets_for_input: list[str]
+    ):
+        """Set context for judge evaluation."""
+        if len(targets_for_input) == 0:  # no reference
+            target_text = "N/A"
+        elif len(targets_for_input) == 1:
+            target_text = targets_for_input[0]
+        else:  # enumerate when there are multiple targets
+            target_text = "\n".join(
+                [f"{i}. {target}" for i, target in enumerate(targets_for_input, 1)]
+            )
+
+        self._judge_context: dict[str, Any] = {
+            "input": input_text,
+            "prediction": prediction,
+            "target": target_text,
+            "guidelines": self.instructions,
+        }
+
+    @classmethod
+    def from_json_file(cls, filepath: str) -> list["TestBasedEval"]:
+        """Load test evaluations from a json/jsonl file and return a list of TestBasedEval instances, one per 'unit test'."""
+        path = Path(filepath)
+
+        with path.open("r") as f:
+            text = f.read()
+
+        try:
+            data = json.loads(text)
+        except json.JSONDecodeError:
+            # jsonl: one test object per line
+            data = [json.loads(line) for line in text.splitlines() if line.strip()]
+
+        if not isinstance(data, list):
+            data = [data]
+
+        test_evals = []
+        for test_data_dict in data:
+            try:
+                test_data = TestData(**test_data_dict)
+            except Exception as e:
+                raise ValueError(f"Invalid test data in {filepath}: {e}")
+
+            inputs = []
+            targets = []
+            input_ids = []
+
+            for example in test_data.examples:
+                user_messages = [msg for msg in example.input if msg.role == "user"]
+                if not user_messages:
+                    # skip examples without a user message to keep inputs/targets aligned
+                    continue
+                inputs.append(user_messages[-1].content)
+
+                targets_for_input = [
+                    msg.content for msg in example.targets if msg.role == "assistant"
+                ]
+                targets.append(targets_for_input)
+
+                input_ids.append(example.input_id)
+
+            test_eval = cls(
+                source=test_data.source,
+                name=test_data.name,
+                instructions=test_data.instructions,
+                inputs=inputs,
+                targets=targets,
+                test_id=test_data.id,
+                input_ids=input_ids,
+            )
+            test_evals.append(test_eval)
+
+        return test_evals
diff --git a/mellea/templates/prompts/default/TestBasedEval.jinja2 b/mellea/templates/prompts/default/TestBasedEval.jinja2
new file mode 100644
index 00000000..57a5688b
--- /dev/null
+++ b/mellea/templates/prompts/default/TestBasedEval.jinja2
@@ -0,0 +1,27 @@
+**Input to the model**
+
+{{ input }}
+
+**Model output to be rated**
+
+{{ prediction }}
+
+{% if target and target != "N/A" %}
+**Ground truth text**
+
+{{ target }}
+{% endif %}
+
+**Rating Guidelines**
+The model output should adhere to the following guidelines:
+{{ guidelines }}
+
+**Scoring Criteria**
+* Score 0: The model output violates any of the guidelines.
+* Score 1: The model output is well aligned with the input to the model{% if target and target != "N/A" %} and with the ground truth{% endif %}, and adheres to all guidelines.
+
+**Return Your Rating**
+Return your rating in the following format:
+{"score": your_score, "justification": "your_justification"}
+
+Your rating:
\ No newline at end of file
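
Example usage: the sketch below shows a test file in the shape expected by the TestData schema above (id, source, name, instructions, and a list of examples whose input/targets are message lists), plus the CLI invocation it enables. The file name, field values, and the choice of Ollama for both generator and judge are illustrative assumptions, not prescribed by this change.

    # write_example_test.py -- hypothetical helper for trying out `m eval run`
    import json

    example_test = {
        "id": "example-001",            # unique id for this test file
        "source": "N/A",                # origin of the evaluation prompts
        "name": "politeness",           # attribute this test evaluates
        "instructions": "Respond politely and answer the user's question directly.",
        "examples": [
            {
                "input_id": "ex-1",
                "input": [{"role": "user", "content": "Where is the Eiffel Tower?"}],
                "targets": [
                    {"role": "assistant", "content": "The Eiffel Tower is in Paris, France."}
                ],
            }
        ],
    }

    # from_json_file accepts a single test object or a list of them
    with open("politeness_test.json", "w") as f:
        json.dump([example_test], f, indent=2)

    # then, with an Ollama server available for both the generator and the judge:
    #   m eval run politeness_test.json --backend ollama --judge-backend ollama -o eval_results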