-
Notifications
You must be signed in to change notification settings - Fork 687
Wrapped LLM as a garak generator #1382
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
5f0b1af
a7d109b
6a54502
65645dc
094141d
7f20877
f2d730e
f87c2f3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| garak.generators.llm | ||
| ========================== | ||
|
|
||
| .. automodule:: garak.generators.llm | ||
| :members: | ||
| :undoc-members: | ||
| :show-inheritance: |
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,154 @@ | ||||||||||||||
| # SPDX-FileCopyrightText: Portions Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||||||||||||||
| # SPDX-License-Identifier: Apache-2.0 | ||||||||||||||
|
|
||||||||||||||
| """LLM (simonw/llm) generator support""" | ||||||||||||||
|
|
||||||||||||||
| import inspect | ||||||||||||||
| import logging | ||||||||||||||
| from typing import List, Union | ||||||||||||||
|
|
||||||||||||||
| import llm | ||||||||||||||
|
|
||||||||||||||
| from garak import _config | ||||||||||||||
| from garak.attempt import Message, Conversation | ||||||||||||||
| from garak.generators.base import Generator | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| class LLMGenerator(Generator): | ||||||||||||||
| """Class supporting simonw/llm-managed models | ||||||||||||||
|
|
||||||||||||||
| See https://pypi.org/project/llm/ and its provider plugins. | ||||||||||||||
|
|
||||||||||||||
| Calls model.prompt() with the prompt text and relays the response. Per-provider | ||||||||||||||
| options and API keys are all handled by `llm` (e.g., `llm keys set openai`). | ||||||||||||||
|
|
||||||||||||||
| Set --model_name to the `llm` model id or alias (e.g., "gpt-4o-mini", | ||||||||||||||
jmartin-tech marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
| "claude-3.5-haiku", or a local alias configured in `llm models`). | ||||||||||||||
| Set --target_name to the `llm` model id or alias (e.g., "gpt-4o-mini", | ||||||||||||||
| "claude-3.5-haiku", or a local alias configured in `llm models`). | ||||||||||||||
|
Comment on lines
+25
to
+28
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There looks to be some duplication in this section.
Suggested change
|
||||||||||||||
|
|
||||||||||||||
| Explicitly, garak delegates the majority of responsibility here: | ||||||||||||||
|
|
||||||||||||||
| * the generator calls prompt() on the resolved `llm` model | ||||||||||||||
| * provider setup, auth, and model-specific options live in `llm` | ||||||||||||||
| * there's no support for chains; this is a direct LLM interface | ||||||||||||||
|
|
||||||||||||||
| Notes: | ||||||||||||||
| * Not all providers support all parameters (e.g., temperature, max_tokens). | ||||||||||||||
| We pass only non-None params; providers ignore what they don't support. | ||||||||||||||
| """ | ||||||||||||||
|
|
||||||||||||||
| DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | { | ||||||||||||||
| "max_tokens": None, | ||||||||||||||
| "top_p": None, | ||||||||||||||
| "stop": [], | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| generator_family_name = "llm" | ||||||||||||||
|
|
||||||||||||||
| def __init__(self, name: str = "", config_root=_config): | ||||||||||||||
| self.target = None | ||||||||||||||
| self.name = name | ||||||||||||||
| self._load_config(config_root) | ||||||||||||||
| self.fullname = f"llm (simonw/llm) {self.name or '(default)'}" | ||||||||||||||
|
|
||||||||||||||
| super().__init__(self.name, config_root=config_root) | ||||||||||||||
| self._load_client() | ||||||||||||||
|
|
||||||||||||||
| def __getstate__(self) -> object: | ||||||||||||||
| self._clear_client() | ||||||||||||||
| return dict(self.__dict__) | ||||||||||||||
|
|
||||||||||||||
| def __setstate__(self, data: dict) -> None: | ||||||||||||||
| self.__dict__.update(data) | ||||||||||||||
| self._load_client() | ||||||||||||||
|
|
||||||||||||||
| def _load_client(self) -> None: | ||||||||||||||
| try: | ||||||||||||||
| self.target = llm.get_model(self.name) if self.name else llm.get_model() | ||||||||||||||
| except Exception as exc: | ||||||||||||||
| logging.error( | ||||||||||||||
| "Failed to resolve `llm` target '%s': %s", self.name, repr(exc) | ||||||||||||||
| ) | ||||||||||||||
| raise | ||||||||||||||
|
|
||||||||||||||
| def _clear_client(self) -> None: | ||||||||||||||
| self.target = None | ||||||||||||||
|
|
||||||||||||||
| def _call_model( | ||||||||||||||
| self, prompt: Conversation, generations_this_call: int = 1 | ||||||||||||||
| ) -> List[Union[Message, None]]: | ||||||||||||||
| """ | ||||||||||||||
| Continuation generation method for LLM integrations via `llm`. | ||||||||||||||
|
|
||||||||||||||
| This calls model.prompt() once per generation and materializes the text(). | ||||||||||||||
| """ | ||||||||||||||
| if self.target is None: | ||||||||||||||
| self._load_client() | ||||||||||||||
|
|
||||||||||||||
| system_turns = [turn for turn in prompt.turns if turn.role == "system"] | ||||||||||||||
| user_turns = [turn for turn in prompt.turns if turn.role == "user"] | ||||||||||||||
| assistant_turns = [turn for turn in prompt.turns if turn.role == "assistant"] | ||||||||||||||
|
|
||||||||||||||
| if assistant_turns: | ||||||||||||||
| logging.debug("llm generator does not accept assistant turns") | ||||||||||||||
| return [None] * generations_this_call | ||||||||||||||
| if len(system_turns) > 1: | ||||||||||||||
| logging.debug("llm generator supports at most one system turn") | ||||||||||||||
| return [None] * generations_this_call | ||||||||||||||
| if len(user_turns) != 1: | ||||||||||||||
| logging.debug("llm generator requires exactly one user turn") | ||||||||||||||
| return [None] * generations_this_call | ||||||||||||||
|
|
||||||||||||||
| text_prompt = prompt.last_message("user").text | ||||||||||||||
|
|
||||||||||||||
| prompt_kwargs = {} | ||||||||||||||
| try: | ||||||||||||||
| signature = inspect.signature(self.target.prompt) | ||||||||||||||
| accepted_params = signature.parameters | ||||||||||||||
| accepts_var_kwargs = any( | ||||||||||||||
| param.kind == inspect.Parameter.VAR_KEYWORD | ||||||||||||||
| for param in accepted_params.values() | ||||||||||||||
| ) | ||||||||||||||
| except (TypeError, ValueError): | ||||||||||||||
| accepted_params = {} | ||||||||||||||
| accepts_var_kwargs = False | ||||||||||||||
|
|
||||||||||||||
| if accepted_params: | ||||||||||||||
| for key, param in accepted_params.items(): | ||||||||||||||
| if key in {"prompt", "prompt_text", "text", "self"}: | ||||||||||||||
| continue | ||||||||||||||
| if not hasattr(self, key): | ||||||||||||||
| continue | ||||||||||||||
| value = getattr(self, key) | ||||||||||||||
| if value is None: | ||||||||||||||
| continue | ||||||||||||||
| if key == "stop" and not value: | ||||||||||||||
| continue | ||||||||||||||
| prompt_kwargs[key] = value | ||||||||||||||
|
|
||||||||||||||
| # Fallback to a conservative parameter subset if signature inspection fails | ||||||||||||||
| # or the target accepts arbitrary kwargs (so we should pass anything we have) | ||||||||||||||
| fallback_keys = ("max_tokens", "temperature", "top_p") | ||||||||||||||
| needs_fallback = not prompt_kwargs or accepts_var_kwargs or not accepted_params | ||||||||||||||
| if needs_fallback: | ||||||||||||||
| for key in fallback_keys: | ||||||||||||||
| if key in prompt_kwargs: | ||||||||||||||
| continue | ||||||||||||||
| value = getattr(self, key, None) | ||||||||||||||
| if value is not None: | ||||||||||||||
| prompt_kwargs[key] = value | ||||||||||||||
| stop_value = getattr(self, "stop", None) | ||||||||||||||
| if stop_value: | ||||||||||||||
| prompt_kwargs.setdefault("stop", stop_value) | ||||||||||||||
|
|
||||||||||||||
| try: | ||||||||||||||
| response = self.target.prompt(text_prompt, **prompt_kwargs) | ||||||||||||||
| out = response.text() | ||||||||||||||
| return [Message(out)] | ||||||||||||||
| except Exception as e: | ||||||||||||||
| logging.error("`llm` generation failed: %s", repr(e)) | ||||||||||||||
| return [None] | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| DEFAULT_CLASS = "LLMGenerator" | ||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should land #1199 before landing this, and then move this PR to the deferred loading pattern |
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should land #1199 before landing this, and then move this PR to the deferred loading pattern |
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,173 @@ | ||||||||||||
| # SPDX-FileCopyrightText: Portions Copyright (c) 2025 NVIDIA CORPORATION & | ||||||||||||
| # AFFILIATES. All rights reserved. | ||||||||||||
|
Comment on lines
+1
to
+2
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
| # SPDX-FileCopyrightText: Portions Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||||||||||||
|
Comment on lines
+1
to
+3
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again duplication to be removed.
Suggested change
|
||||||||||||
| # SPDX-License-Identifier: Apache-2.0 | ||||||||||||
|
|
||||||||||||
| """Tests for simonw/llm-backed garak generator""" | ||||||||||||
|
|
||||||||||||
| import pytest | ||||||||||||
| from unittest.mock import MagicMock | ||||||||||||
|
|
||||||||||||
| from garak.attempt import Conversation, Turn, Message | ||||||||||||
| from garak._config import GarakSubConfig | ||||||||||||
|
|
||||||||||||
| # Adjust import path/module name to where you placed the wrapper | ||||||||||||
| from garak.generators.llm import LLMGenerator | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| # ─── Helpers & Fixtures ───────────────────────────────────────────────── | ||||||||||||
|
|
||||||||||||
| class FakeResponse: | ||||||||||||
| """Minimal `llm` Response shim with .text()""" | ||||||||||||
| def __init__(self, txt: str): | ||||||||||||
| self._txt = txt | ||||||||||||
| def text(self) -> str: | ||||||||||||
| return self._txt | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| class FakeModel: | ||||||||||||
| """Minimal `llm` model shim with .prompt()""" | ||||||||||||
| def __init__(self): | ||||||||||||
| self.calls = [] | ||||||||||||
| def prompt(self, prompt_text: str, **kwargs): | ||||||||||||
| self.calls.append((prompt_text, kwargs)) | ||||||||||||
| return FakeResponse("OK_FAKE") | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| @pytest.fixture | ||||||||||||
| def cfg(): | ||||||||||||
| """Minimal garak sub-config; extend if you wire defaults via config.""" | ||||||||||||
| c = GarakSubConfig() | ||||||||||||
| c.generators = {} | ||||||||||||
| return c | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| @pytest.fixture | ||||||||||||
| def fake_llm(monkeypatch): | ||||||||||||
| """ | ||||||||||||
| Patch llm.get_model to return a fresh FakeModel for each test. | ||||||||||||
| Return the FakeModel so tests can inspect call args. | ||||||||||||
| """ | ||||||||||||
| import llm | ||||||||||||
| model = FakeModel() | ||||||||||||
| monkeypatch.setattr(llm, "get_model", lambda *a, **k: model) | ||||||||||||
| return model | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| # ─── Tests ────────────────────────────────────────────────────────────── | ||||||||||||
|
|
||||||||||||
| def test_instantiation_resolves_model(cfg, fake_llm): | ||||||||||||
| gen = LLMGenerator(name="my-alias", config_root=cfg) | ||||||||||||
| assert gen.name == "my-alias" | ||||||||||||
|
Comment on lines
+60
to
+61
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
| assert hasattr(gen, "target") | ||||||||||||
| assert "llm (simonw/llm)" in gen.fullname | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_generate_returns_message(cfg, fake_llm): | ||||||||||||
| gen = LLMGenerator(name="alias", config_root=cfg) | ||||||||||||
|
|
||||||||||||
| test_txt = "ping" | ||||||||||||
| conv = Conversation([Turn("user", Message(text=test_txt))]) | ||||||||||||
| out = gen._call_model(conv) | ||||||||||||
|
|
||||||||||||
| assert isinstance(out, list) and len(out) == 1 | ||||||||||||
| assert isinstance(out[0], Message) | ||||||||||||
| assert out[0].text == "OK_FAKE" | ||||||||||||
|
|
||||||||||||
| prompt_text, kwargs = fake_llm.calls[0] | ||||||||||||
| assert prompt_text == test_txt | ||||||||||||
| assert kwargs == {} | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_param_passthrough(cfg, fake_llm): | ||||||||||||
| gen = LLMGenerator(name="alias", config_root=cfg) | ||||||||||||
| temperature = 0.2 | ||||||||||||
| max_tokens = 64 | ||||||||||||
| top_p = 0.9 | ||||||||||||
| stop = ["\n\n"] | ||||||||||||
|
|
||||||||||||
| gen.temperature = temperature | ||||||||||||
| gen.max_tokens = max_tokens | ||||||||||||
| gen.top_p = top_p | ||||||||||||
| gen.stop = stop | ||||||||||||
|
|
||||||||||||
| conv = Conversation([Turn("user", Message(text="hello"))]) | ||||||||||||
| _ = gen._call_model(conv) | ||||||||||||
|
|
||||||||||||
| _, kwargs = fake_llm.calls[0] | ||||||||||||
| assert kwargs["temperature"] == temperature | ||||||||||||
| assert kwargs["max_tokens"] == max_tokens | ||||||||||||
| assert kwargs["top_p"] == top_p | ||||||||||||
| assert kwargs["stop"] == stop | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_wrapper_handles_llm_exception(cfg, monkeypatch): | ||||||||||||
| """If the underlying `llm` call explodes, wrapper returns [None].""" | ||||||||||||
| import llm | ||||||||||||
| class BoomModel: | ||||||||||||
| def prompt(self, *a, **k): | ||||||||||||
| raise RuntimeError("boom") | ||||||||||||
| monkeypatch.setattr(llm, "get_model", lambda *a, **k: BoomModel()) | ||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice |
||||||||||||
|
|
||||||||||||
| gen = LLMGenerator(name="alias", config_root=cfg) | ||||||||||||
| conv = Conversation([Turn("user", Message(text="ping"))]) | ||||||||||||
| out = gen._call_model(conv) | ||||||||||||
| assert out == [None] | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_default_model_when_name_empty(cfg, fake_llm, monkeypatch): | ||||||||||||
| """ | ||||||||||||
| If name is empty, wrapper should call llm.get_model() with no args, | ||||||||||||
| i.e., use llm's configured default model. | ||||||||||||
| """ | ||||||||||||
| import llm | ||||||||||||
| spy = MagicMock(side_effect=lambda *a, **k: fake_llm) | ||||||||||||
| monkeypatch.setattr(llm, "get_model", spy) | ||||||||||||
|
|
||||||||||||
| gen = LLMGenerator(name="", config_root=cfg) | ||||||||||||
| conv = Conversation([Turn("user", Message(text="x"))]) | ||||||||||||
| _ = gen._call_model(conv) | ||||||||||||
|
|
||||||||||||
| spy.assert_called_once() | ||||||||||||
| assert spy.call_args.args == () | ||||||||||||
| assert spy.call_args.kwargs == {} | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_skips_multiple_user_turns(cfg, fake_llm): | ||||||||||||
| gen = LLMGenerator(name="alias", config_root=cfg) | ||||||||||||
| user_turns = [ | ||||||||||||
| Turn("user", Message(text="first")), | ||||||||||||
| Turn("user", Message(text="second")), | ||||||||||||
| ] | ||||||||||||
| conv = Conversation(user_turns) | ||||||||||||
| out = gen._call_model(conv) | ||||||||||||
| assert out == [None] | ||||||||||||
| assert fake_llm.calls == [] | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_skips_assistant_turns(cfg, fake_llm): | ||||||||||||
| gen = LLMGenerator(name="alias", config_root=cfg) | ||||||||||||
| conv = Conversation( | ||||||||||||
| [ | ||||||||||||
| Turn("system", Message(text="system prompt")), | ||||||||||||
| Turn("assistant", Message(text="historic reply")), | ||||||||||||
| Turn("user", Message(text="question")), | ||||||||||||
| ] | ||||||||||||
| ) | ||||||||||||
| out = gen._call_model(conv) | ||||||||||||
| assert out == [None] | ||||||||||||
| assert fake_llm.calls == [] | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_skips_multiple_system_turns(cfg, fake_llm): | ||||||||||||
| gen = LLMGenerator(name="alias", config_root=cfg) | ||||||||||||
| conv = Conversation( | ||||||||||||
| [ | ||||||||||||
| Turn("system", Message(text="one")), | ||||||||||||
| Turn("system", Message(text="two")), | ||||||||||||
| Turn("user", Message(text="ping")), | ||||||||||||
| ] | ||||||||||||
| ) | ||||||||||||
| out = gen._call_model(conv) | ||||||||||||
| assert out == [None] | ||||||||||||
| assert fake_llm.calls == [] | ||||||||||||
Uh oh!
There was an error while loading. Please reload this page.