-
Notifications
You must be signed in to change notification settings - Fork 680
Wrapped LLM as a garak generator #1382
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 6 commits
5f0b1af
a7d109b
6a54502
65645dc
094141d
7f20877
f2d730e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| garak.generators.llm | ||
| ========================== | ||
|
|
||
| .. automodule:: garak.generators.llm | ||
| :members: | ||
| :undoc-members: | ||
| :show-inheritance: |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,120 @@ | ||||||||||||||||||||||||||||||||
| # SPDX-FileCopyrightText: Portions Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||||||||||||||||||||||||||||||||
| # SPDX-License-Identifier: Apache-2.0 | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| """LLM (simonw/llm) generator support""" | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| import logging | ||||||||||||||||||||||||||||||||
| from typing import List, Union | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| import llm | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| from garak import _config | ||||||||||||||||||||||||||||||||
| from garak.attempt import Message, Conversation | ||||||||||||||||||||||||||||||||
| from garak.generators.base import Generator | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| class LLMGenerator(Generator): | ||||||||||||||||||||||||||||||||
| """Class supporting simonw/llm-managed models | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| See https://pypi.org/project/llm/ and its provider plugins. | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| Calls model.prompt() with the prompt text and relays the response. Per-provider | ||||||||||||||||||||||||||||||||
| options and API keys are all handled by `llm` (e.g., `llm keys set openai`). | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| Set --model_name to the `llm` model id or alias (e.g., "gpt-4o-mini", | ||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||
| "claude-3.5-haiku", or a local alias configured in `llm models`). | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| Explicitly, garak delegates the majority of responsibility here: | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| * the generator calls prompt() on the resolved `llm` model | ||||||||||||||||||||||||||||||||
| * provider setup, auth, and model-specific options live in `llm` | ||||||||||||||||||||||||||||||||
| * there's no support for chains; this is a direct LLM interface | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| Notes: | ||||||||||||||||||||||||||||||||
| * Not all providers support all parameters (e.g., temperature, max_tokens). | ||||||||||||||||||||||||||||||||
| We pass only non-None params; providers ignore what they don't support. | ||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | { | ||||||||||||||||||||||||||||||||
| "max_tokens": None, | ||||||||||||||||||||||||||||||||
| "top_p": None, | ||||||||||||||||||||||||||||||||
| "stop": [], | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| generator_family_name = "llm" | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def __init__(self, name: str = "", config_root=_config): | ||||||||||||||||||||||||||||||||
| self.target = None | ||||||||||||||||||||||||||||||||
| self.name = name | ||||||||||||||||||||||||||||||||
| self._load_config(config_root) | ||||||||||||||||||||||||||||||||
| self.fullname = f"llm (simonw/llm) {self.name or '(default)'}" | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| self._load_client() | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| super().__init__(self.name, config_root=config_root) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| self._clear_client() | ||||||||||||||||||||||||||||||||
|
Comment on lines
+52
to
+56
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It may be best to call
Suggested change
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def __getstate__(self) -> object: | ||||||||||||||||||||||||||||||||
| self._clear_client() | ||||||||||||||||||||||||||||||||
| return dict(self.__dict__) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def __setstate__(self, data: dict) -> None: | ||||||||||||||||||||||||||||||||
| self.__dict__.update(data) | ||||||||||||||||||||||||||||||||
| self._load_client() | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def _load_client(self) -> None: | ||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||
| self.target = llm.get_model(self.name) if self.name else llm.get_model() | ||||||||||||||||||||||||||||||||
| except Exception as exc: | ||||||||||||||||||||||||||||||||
| logging.error( | ||||||||||||||||||||||||||||||||
| "Failed to resolve `llm` target '%s': %s", self.name, repr(exc) | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
| raise | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def _clear_client(self) -> None: | ||||||||||||||||||||||||||||||||
| self.target = None | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def _call_model( | ||||||||||||||||||||||||||||||||
| self, prompt: Conversation, generations_this_call: int = 1 | ||||||||||||||||||||||||||||||||
| ) -> List[Union[Message, None]]: | ||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||
| Continuation generation method for LLM integrations via `llm`. | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| This calls model.prompt() once per generation and materializes the text(). | ||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||
| if self.target is None: | ||||||||||||||||||||||||||||||||
| self._load_client() | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| system_turns = [turn for turn in prompt.turns if turn.role == "system"] | ||||||||||||||||||||||||||||||||
| user_turns = [turn for turn in prompt.turns if turn.role == "user"] | ||||||||||||||||||||||||||||||||
| assistant_turns = [turn for turn in prompt.turns if turn.role == "assistant"] | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| if assistant_turns: | ||||||||||||||||||||||||||||||||
| raise ValueError("llm generator does not accept assistant turns") | ||||||||||||||||||||||||||||||||
| if len(system_turns) > 1: | ||||||||||||||||||||||||||||||||
| raise ValueError("llm generator supports at most one system turn") | ||||||||||||||||||||||||||||||||
| if len(user_turns) != 1: | ||||||||||||||||||||||||||||||||
| raise ValueError("llm generator requires exactly one user turn") | ||||||||||||||||||||||||||||||||
|
Comment on lines
+93
to
+98
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unsupported prompts in
Suggested change
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| text_prompt = prompt.last_message("user").text | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # Build kwargs only for parameters explicitly set (non-None / non-empty) | ||||||||||||||||||||||||||||||||
| prompt_kwargs = { | ||||||||||||||||||||||||||||||||
| key: getattr(self, key) | ||||||||||||||||||||||||||||||||
| for key in ("max_tokens", "temperature", "top_p") | ||||||||||||||||||||||||||||||||
| if getattr(self, key) is not None | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
| if self.stop: | ||||||||||||||||||||||||||||||||
| prompt_kwargs["stop"] = self.stop | ||||||||||||||||||||||||||||||||
|
Comment on lines
+103
to
+109
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could this inspect the accepted arguments to |
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||
| response = self.target.prompt(text_prompt, **prompt_kwargs) | ||||||||||||||||||||||||||||||||
| out = response.text() | ||||||||||||||||||||||||||||||||
| return [Message(out)] | ||||||||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||||||||
| logging.error("`llm` generation failed: %s", repr(e)) | ||||||||||||||||||||||||||||||||
| return [None] | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| DEFAULT_CLASS = "LLMGenerator" | ||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should land #1199 before landing this, and then move this PR to the deferred loading pattern |
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should land #1199 before landing this, and then move this PR to the deferred loading pattern |
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,170 @@ | ||||||||||||
| # SPDX-FileCopyrightText: Portions Copyright (c) 2025 NVIDIA CORPORATION & | ||||||||||||
| # AFFILIATES. All rights reserved. | ||||||||||||
|
Comment on lines
+1
to
+2
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
| # SPDX-FileCopyrightText: Portions Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||||||||||||
| # SPDX-License-Identifier: Apache-2.0 | ||||||||||||
|
|
||||||||||||
| """Tests for simonw/llm-backed garak generator""" | ||||||||||||
|
|
||||||||||||
| import pytest | ||||||||||||
| from unittest.mock import MagicMock | ||||||||||||
|
|
||||||||||||
| from garak.attempt import Conversation, Turn, Message | ||||||||||||
| from garak._config import GarakSubConfig | ||||||||||||
|
|
||||||||||||
| # Adjust import path/module name to where you placed the wrapper | ||||||||||||
| from garak.generators.llm import LLMGenerator | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| # ─── Helpers & Fixtures ───────────────────────────────────────────────── | ||||||||||||
|
|
||||||||||||
| class FakeResponse: | ||||||||||||
| """Minimal `llm` Response shim with .text()""" | ||||||||||||
| def __init__(self, txt: str): | ||||||||||||
| self._txt = txt | ||||||||||||
| def text(self) -> str: | ||||||||||||
| return self._txt | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| class FakeModel: | ||||||||||||
| """Minimal `llm` model shim with .prompt()""" | ||||||||||||
| def __init__(self): | ||||||||||||
| self.calls = [] | ||||||||||||
| def prompt(self, prompt_text: str, **kwargs): | ||||||||||||
| self.calls.append((prompt_text, kwargs)) | ||||||||||||
| return FakeResponse("OK_FAKE") | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| @pytest.fixture | ||||||||||||
| def cfg(): | ||||||||||||
| """Minimal garak sub-config; extend if you wire defaults via config.""" | ||||||||||||
| c = GarakSubConfig() | ||||||||||||
| c.generators = {} | ||||||||||||
| return c | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| @pytest.fixture | ||||||||||||
| def fake_llm(monkeypatch): | ||||||||||||
| """ | ||||||||||||
| Patch llm.get_model to return a fresh FakeModel for each test. | ||||||||||||
| Return the FakeModel so tests can inspect call args. | ||||||||||||
| """ | ||||||||||||
| import llm | ||||||||||||
| model = FakeModel() | ||||||||||||
| monkeypatch.setattr(llm, "get_model", lambda *a, **k: model) | ||||||||||||
| return model | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| # ─── Tests ────────────────────────────────────────────────────────────── | ||||||||||||
|
|
||||||||||||
| def test_instantiation_resolves_model(cfg, fake_llm): | ||||||||||||
| gen = LLMGenerator(name="my-alias", config_root=cfg) | ||||||||||||
| assert gen.name == "my-alias" | ||||||||||||
|
Comment on lines
+60
to
+61
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
| assert hasattr(gen, "target") | ||||||||||||
| assert "llm (simonw/llm)" in gen.fullname | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_generate_returns_message(cfg, fake_llm): | ||||||||||||
| gen = LLMGenerator(name="alias", config_root=cfg) | ||||||||||||
|
|
||||||||||||
| test_txt = "ping" | ||||||||||||
| conv = Conversation([Turn("user", Message(text=test_txt))]) | ||||||||||||
| out = gen._call_model(conv) | ||||||||||||
|
|
||||||||||||
| assert isinstance(out, list) and len(out) == 1 | ||||||||||||
| assert isinstance(out[0], Message) | ||||||||||||
| assert out[0].text == "OK_FAKE" | ||||||||||||
|
|
||||||||||||
| prompt_text, kwargs = fake_llm.calls[0] | ||||||||||||
| assert prompt_text == test_txt | ||||||||||||
| assert kwargs == {} | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_param_passthrough(cfg, fake_llm): | ||||||||||||
| gen = LLMGenerator(name="alias", config_root=cfg) | ||||||||||||
| temperature = 0.2 | ||||||||||||
| max_tokens = 64 | ||||||||||||
| top_p = 0.9 | ||||||||||||
| stop = ["\n\n"] | ||||||||||||
|
|
||||||||||||
| gen.temperature = temperature | ||||||||||||
| gen.max_tokens = max_tokens | ||||||||||||
| gen.top_p = top_p | ||||||||||||
| gen.stop = stop | ||||||||||||
|
|
||||||||||||
| conv = Conversation([Turn("user", Message(text="hello"))]) | ||||||||||||
| _ = gen._call_model(conv) | ||||||||||||
|
|
||||||||||||
| _, kwargs = fake_llm.calls[0] | ||||||||||||
| assert kwargs["temperature"] == temperature | ||||||||||||
| assert kwargs["max_tokens"] == max_tokens | ||||||||||||
| assert kwargs["top_p"] == top_p | ||||||||||||
| assert kwargs["stop"] == stop | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_wrapper_handles_llm_exception(cfg, monkeypatch): | ||||||||||||
| """If the underlying `llm` call explodes, wrapper returns [None].""" | ||||||||||||
| import llm | ||||||||||||
| class BoomModel: | ||||||||||||
| def prompt(self, *a, **k): | ||||||||||||
| raise RuntimeError("boom") | ||||||||||||
| monkeypatch.setattr(llm, "get_model", lambda *a, **k: BoomModel()) | ||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice |
||||||||||||
|
|
||||||||||||
| gen = LLMGenerator(name="alias", config_root=cfg) | ||||||||||||
| conv = Conversation([Turn("user", Message(text="ping"))]) | ||||||||||||
| out = gen._call_model(conv) | ||||||||||||
| assert out == [None] | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_default_model_when_name_empty(cfg, fake_llm, monkeypatch): | ||||||||||||
| """ | ||||||||||||
| If name is empty, wrapper should call llm.get_model() with no args, | ||||||||||||
| i.e., use llm's configured default model. | ||||||||||||
| """ | ||||||||||||
| import llm | ||||||||||||
| spy = MagicMock(side_effect=lambda *a, **k: fake_llm) | ||||||||||||
| monkeypatch.setattr(llm, "get_model", spy) | ||||||||||||
|
|
||||||||||||
| gen = LLMGenerator(name="", config_root=cfg) | ||||||||||||
| conv = Conversation([Turn("user", Message(text="x"))]) | ||||||||||||
| _ = gen._call_model(conv) | ||||||||||||
|
|
||||||||||||
| spy.assert_called_once() | ||||||||||||
| assert spy.call_args.args == () | ||||||||||||
| assert spy.call_args.kwargs == {} | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_rejects_multiple_user_turns(cfg, fake_llm): | ||||||||||||
| gen = LLMGenerator(name="alias", config_root=cfg) | ||||||||||||
| user_turns = [ | ||||||||||||
| Turn("user", Message(text="first")), | ||||||||||||
| Turn("user", Message(text="second")), | ||||||||||||
| ] | ||||||||||||
| conv = Conversation(user_turns) | ||||||||||||
| with pytest.raises(ValueError): | ||||||||||||
| gen._call_model(conv) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_rejects_assistant_turns(cfg, fake_llm): | ||||||||||||
| gen = LLMGenerator(name="alias", config_root=cfg) | ||||||||||||
| conv = Conversation( | ||||||||||||
| [ | ||||||||||||
| Turn("system", Message(text="system prompt")), | ||||||||||||
| Turn("assistant", Message(text="historic reply")), | ||||||||||||
| Turn("user", Message(text="question")), | ||||||||||||
| ] | ||||||||||||
| ) | ||||||||||||
| with pytest.raises(ValueError): | ||||||||||||
| gen._call_model(conv) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_rejects_multiple_system_turns(cfg, fake_llm): | ||||||||||||
| gen = LLMGenerator(name="alias", config_root=cfg) | ||||||||||||
| conv = Conversation( | ||||||||||||
| [ | ||||||||||||
| Turn("system", Message(text="one")), | ||||||||||||
| Turn("system", Message(text="two")), | ||||||||||||
| Turn("user", Message(text="ping")), | ||||||||||||
| ] | ||||||||||||
| ) | ||||||||||||
| with pytest.raises(ValueError): | ||||||||||||
| gen._call_model(conv) | ||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's worth implementing the
_load_client()/_clear_client()pattern here to support parallelisation - seeopenai.OpenAICompatiblefor an example