7 changes: 7 additions & 0 deletions docs/source/garak.generators.llm.rst
@@ -0,0 +1,7 @@
garak.generators.llm
==========================

.. automodule:: garak.generators.llm
:members:
:undoc-members:
:show-inheritance:
1 change: 1 addition & 0 deletions docs/source/generators.rst
@@ -20,6 +20,7 @@ For a detailed oversight into how a generator operates, see :doc:`garak.generato
garak.generators.langchain
garak.generators.langchain_serve
garak.generators.litellm
garak.generators.llm
garak.generators.mistral
garak.generators.ollama
garak.generators.openai
154 changes: 154 additions & 0 deletions garak/generators/llm.py
@@ -0,0 +1,154 @@
# SPDX-FileCopyrightText: Portions Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""LLM (simonw/llm) generator support"""

import inspect
import logging
from typing import List, Union

import llm

from garak import _config
from garak.attempt import Message, Conversation
from garak.generators.base import Generator


class LLMGenerator(Generator):
"""Class supporting simonw/llm-managed models

See https://pypi.org/project/llm/ and its provider plugins.

Calls model.prompt() with the prompt text and relays the response. Per-provider
options and API keys are all handled by `llm` (e.g., `llm keys set openai`).

Set --model_name to the `llm` model id or alias (e.g., "gpt-4o-mini",
"claude-3.5-haiku", or a local alias configured in `llm models`).
Set --target_name to the `llm` model id or alias (e.g., "gpt-4o-mini",
"claude-3.5-haiku", or a local alias configured in `llm models`).
Comment on lines +25 to +28
Collaborator
There looks to be some duplication in this section.

Suggested change (keep a single pair of lines, using --target_name):
Set --target_name to the `llm` model id or alias (e.g., "gpt-4o-mini",
"claude-3.5-haiku", or a local alias configured in `llm models`).


Explicitly, garak delegates the majority of responsibility here:

* the generator calls prompt() on the resolved `llm` model
* provider setup, auth, and model-specific options live in `llm`
* there's no support for chains; this is a direct LLM interface

Notes:
* Not all providers support all parameters (e.g., temperature, max_tokens).
We pass only non-None params; providers ignore what they don't support.
"""

DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | {
"max_tokens": None,
"top_p": None,
"stop": [],
}

generator_family_name = "llm"

def __init__(self, name: str = "", config_root=_config):
self.target = None
self.name = name
self._load_config(config_root)
self.fullname = f"llm (simonw/llm) {self.name or '(default)'}"

super().__init__(self.name, config_root=config_root)
self._load_client()

def __getstate__(self) -> object:
self._clear_client()
return dict(self.__dict__)

def __setstate__(self, data: dict) -> None:
self.__dict__.update(data)
self._load_client()

def _load_client(self) -> None:
try:
self.target = llm.get_model(self.name) if self.name else llm.get_model()
except Exception as exc:
logging.error(
"Failed to resolve `llm` target '%s': %s", self.name, repr(exc)
)
raise

def _clear_client(self) -> None:
self.target = None

def _call_model(
self, prompt: Conversation, generations_this_call: int = 1
) -> List[Union[Message, None]]:
"""
Continuation generation method for LLM integrations via `llm`.

This calls model.prompt() once per generation and materializes the text().
"""
if self.target is None:
self._load_client()

system_turns = [turn for turn in prompt.turns if turn.role == "system"]
user_turns = [turn for turn in prompt.turns if turn.role == "user"]
assistant_turns = [turn for turn in prompt.turns if turn.role == "assistant"]

if assistant_turns:
logging.debug("llm generator does not accept assistant turns")
return [None] * generations_this_call
if len(system_turns) > 1:
logging.debug("llm generator supports at most one system turn")
return [None] * generations_this_call
if len(user_turns) != 1:
logging.debug("llm generator requires exactly one user turn")
return [None] * generations_this_call

text_prompt = prompt.last_message("user").text

prompt_kwargs = {}
try:
signature = inspect.signature(self.target.prompt)
accepted_params = signature.parameters
accepts_var_kwargs = any(
param.kind == inspect.Parameter.VAR_KEYWORD
for param in accepted_params.values()
)
except (TypeError, ValueError):
accepted_params = {}
accepts_var_kwargs = False

if accepted_params:
for key, param in accepted_params.items():
if key in {"prompt", "prompt_text", "text", "self"}:
continue
if not hasattr(self, key):
continue
value = getattr(self, key)
if value is None:
continue
if key == "stop" and not value:
continue
prompt_kwargs[key] = value

# Fallback to a conservative parameter subset if signature inspection fails
# or the target accepts arbitrary kwargs (so we should pass anything we have)
fallback_keys = ("max_tokens", "temperature", "top_p")
needs_fallback = not prompt_kwargs or accepts_var_kwargs or not accepted_params
if needs_fallback:
for key in fallback_keys:
if key in prompt_kwargs:
continue
value = getattr(self, key, None)
if value is not None:
prompt_kwargs[key] = value
stop_value = getattr(self, "stop", None)
if stop_value:
prompt_kwargs.setdefault("stop", stop_value)

try:
response = self.target.prompt(text_prompt, **prompt_kwargs)
out = response.text()
return [Message(out)]
except Exception as e:
logging.error("`llm` generation failed: %s", repr(e))
return [None]


DEFAULT_CLASS = "LLMGenerator"
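
A minimal sketch of the underlying `llm` library calls this generator wraps, handy for checking a model id or alias outside garak. It assumes the `llm` package is installed and a provider key has been configured (e.g., `llm keys set openai`); the model id "gpt-4o-mini" and the temperature value are illustrative, and providers that do not support an option may ignore or reject it.

    import llm

    # Resolve a model by id/alias; llm.get_model() with no args uses llm's configured default
    model = llm.get_model("gpt-4o-mini")

    # Keyword arguments become provider options, mirroring what LLMGenerator passes through
    response = model.prompt("Say hello in one word.", temperature=0.2)

    # LLMGenerator relays this text back to garak as a Message
    print(response.text())
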
6 changes: 3 additions & 3 deletions garak/probes/encoding.py
@@ -472,6 +472,7 @@ class InjectAtbash(EncodingMixin, garak.probes.Probe):
encoding_name = "Atbash"
active = True
doc_uri = "https://en.wikipedia.org/wiki/Atbash"
encoding_funcs = []

@staticmethod
def atbash(text: bytes) -> bytes:
@@ -485,9 +486,8 @@ def atbash(text: bytes) -> bytes:
out.append(ch)
return "".join(out).encode("utf-8")

encoding_funcs = [atbash]

def __init__(self, config_root=None):
def __init__(self, config_root=_config):
self.encoding_funcs = [self.atbash]
garak.probes.Probe.__init__(self, config_root=config_root)
EncodingMixin.__init__(self)

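This hunk moves the probe's `encoding_funcs` from a list built in the class body to a per-instance list built in `__init__`, and changes the `config_root` default from None to `_config`. The diff does not state the motivation; one relevant difference, illustrated by the hypothetical standalone sketch below, is that a list built in the class body captures the `staticmethod` object itself (not directly callable before Python 3.10), whereas `self.atbash` resolves through the descriptor to the plain function.

    class Demo:
        @staticmethod
        def f(text: str) -> str:
            return text.upper()

        # Inside the class body, `f` is still the staticmethod descriptor,
        # so items in this list are not callable on Python < 3.10.
        class_level = [f]

        def __init__(self):
            # `self.f` goes through the descriptor and yields the plain
            # function, callable on every supported Python version.
            self.instance_level = [self.f]

    print(Demo().instance_level[0]("ok"))  # prints "OK"
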
1 change: 1 addition & 0 deletions pyproject.toml
Collaborator
we should land #1199 before landing this, and then move this PR to the deferred loading pattern
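
The "deferred loading pattern" referenced here is not shown in this diff; below is a hypothetical sketch of what it might look like for this generator, moving `import llm` from module scope into the point of use so the dependency is only needed when a generator instance is created. This is an assumption and a fragment mirroring the `_load_client` method above, not the implementation from #1199.

    import logging

    def _load_client(self) -> None:
        # Deferred import: `llm` (and its plugin discovery) is only loaded
        # when the generator resolves its target model.
        import llm

        try:
            self.target = llm.get_model(self.name) if self.name else llm.get_model()
        except Exception as exc:
            logging.error(
                "Failed to resolve `llm` target '%s': %s", self.name, repr(exc)
            )
            raise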

@@ -117,6 +117,7 @@ dependencies = [
"mistralai==1.5.2",
"pillow>=10.4.0",
"ftfy>=6.3.1",
"llm>=0.11",
]

[project.optional-dependencies]
1 change: 1 addition & 0 deletions requirements.txt
Collaborator
we should land #1199 before landing this, and then move this PR to the deferred loading pattern

@@ -11,6 +11,7 @@ backoff>=2.1.1
rapidfuzz>=3.0.0
jinja2>=3.1.6
nltk>=3.9.1
llm>=0.11
accelerate>=0.23.0
avidtools==0.1.2
stdlibs>=2022.10.9
173 changes: 173 additions & 0 deletions tests/generators/test_llm.py
@@ -0,0 +1,173 @@
# SPDX-FileCopyrightText: Portions Copyright (c) 2025 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved.
Comment on lines +1 to +2
Collaborator
Suggested change (join the wrapped copyright header onto a single line):
# SPDX-FileCopyrightText: Portions Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

# SPDX-FileCopyrightText: Portions Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Comment on lines +1 to +3
Collaborator
Again duplication to be removed.
Suggested change (keep a single copyright line):
# SPDX-FileCopyrightText: Portions Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

# SPDX-License-Identifier: Apache-2.0

"""Tests for simonw/llm-backed garak generator"""

import pytest
from unittest.mock import MagicMock

from garak.attempt import Conversation, Turn, Message
from garak._config import GarakSubConfig

# Adjust import path/module name to where you placed the wrapper
from garak.generators.llm import LLMGenerator


# ─── Helpers & Fixtures ─────────────────────────────────────────────────

class FakeResponse:
"""Minimal `llm` Response shim with .text()"""
def __init__(self, txt: str):
self._txt = txt
def text(self) -> str:
return self._txt


class FakeModel:
"""Minimal `llm` model shim with .prompt()"""
def __init__(self):
self.calls = []
def prompt(self, prompt_text: str, **kwargs):
self.calls.append((prompt_text, kwargs))
return FakeResponse("OK_FAKE")


@pytest.fixture
def cfg():
"""Minimal garak sub-config; extend if you wire defaults via config."""
c = GarakSubConfig()
c.generators = {}
return c


@pytest.fixture
def fake_llm(monkeypatch):
"""
Patch llm.get_model to return a fresh FakeModel for each test.
Return the FakeModel so tests can inspect call args.
"""
import llm
model = FakeModel()
monkeypatch.setattr(llm, "get_model", lambda *a, **k: model)
return model


# ─── Tests ──────────────────────────────────────────────────────────────

def test_instantiation_resolves_model(cfg, fake_llm):
gen = LLMGenerator(name="my-alias", config_root=cfg)
assert gen.name == "my-alias"
Comment on lines +60 to +61
Collaborator
Suggested change (name the alias once and reuse it):
test_name = "my-alias"
gen = LLMGenerator(name=test_name, config_root=cfg)
assert gen.name == test_name

assert hasattr(gen, "target")
assert "llm (simonw/llm)" in gen.fullname


def test_generate_returns_message(cfg, fake_llm):
gen = LLMGenerator(name="alias", config_root=cfg)

test_txt = "ping"
conv = Conversation([Turn("user", Message(text=test_txt))])
out = gen._call_model(conv)

assert isinstance(out, list) and len(out) == 1
assert isinstance(out[0], Message)
assert out[0].text == "OK_FAKE"

prompt_text, kwargs = fake_llm.calls[0]
assert prompt_text == test_txt
assert kwargs == {}


def test_param_passthrough(cfg, fake_llm):
gen = LLMGenerator(name="alias", config_root=cfg)
temperature = 0.2
max_tokens = 64
top_p = 0.9
stop = ["\n\n"]

gen.temperature = temperature
gen.max_tokens = max_tokens
gen.top_p = top_p
gen.stop = stop

conv = Conversation([Turn("user", Message(text="hello"))])
_ = gen._call_model(conv)

_, kwargs = fake_llm.calls[0]
assert kwargs["temperature"] == temperature
assert kwargs["max_tokens"] == max_tokens
assert kwargs["top_p"] == top_p
assert kwargs["stop"] == stop


def test_wrapper_handles_llm_exception(cfg, monkeypatch):
"""If the underlying `llm` call explodes, wrapper returns [None]."""
import llm
class BoomModel:
def prompt(self, *a, **k):
raise RuntimeError("boom")
monkeypatch.setattr(llm, "get_model", lambda *a, **k: BoomModel())
Collaborator
nice


gen = LLMGenerator(name="alias", config_root=cfg)
conv = Conversation([Turn("user", Message(text="ping"))])
out = gen._call_model(conv)
assert out == [None]


def test_default_model_when_name_empty(cfg, fake_llm, monkeypatch):
"""
If name is empty, wrapper should call llm.get_model() with no args,
i.e., use llm's configured default model.
"""
import llm
spy = MagicMock(side_effect=lambda *a, **k: fake_llm)
monkeypatch.setattr(llm, "get_model", spy)

gen = LLMGenerator(name="", config_root=cfg)
conv = Conversation([Turn("user", Message(text="x"))])
_ = gen._call_model(conv)

spy.assert_called_once()
assert spy.call_args.args == ()
assert spy.call_args.kwargs == {}


def test_skips_multiple_user_turns(cfg, fake_llm):
gen = LLMGenerator(name="alias", config_root=cfg)
user_turns = [
Turn("user", Message(text="first")),
Turn("user", Message(text="second")),
]
conv = Conversation(user_turns)
out = gen._call_model(conv)
assert out == [None]
assert fake_llm.calls == []


def test_skips_assistant_turns(cfg, fake_llm):
gen = LLMGenerator(name="alias", config_root=cfg)
conv = Conversation(
[
Turn("system", Message(text="system prompt")),
Turn("assistant", Message(text="historic reply")),
Turn("user", Message(text="question")),
]
)
out = gen._call_model(conv)
assert out == [None]
assert fake_llm.calls == []


def test_skips_multiple_system_turns(cfg, fake_llm):
gen = LLMGenerator(name="alias", config_root=cfg)
conv = Conversation(
[
Turn("system", Message(text="one")),
Turn("system", Message(text="two")),
Turn("user", Message(text="ping")),
]
)
out = gen._call_model(conv)
assert out == [None]
assert fake_llm.calls == []
2 changes: 1 addition & 1 deletion tests/probes/test_probes_encoding.py
@@ -36,7 +36,7 @@ def test_encoding_prompt_trigger_match(classname):

@pytest.mark.parametrize(
"classname",
[classname for classname in ENCODING_PROBES if not CLEAR_TRIGGER_PROBES],
[classname for classname in ENCODING_PROBES if classname not in CLEAR_TRIGGER_PROBES],
)
def test_encoding_triggers_not_in_prompts(classname):
p = _plugins.load_plugin(classname)
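
The one-line fix above changes the parametrization filter from testing the truthiness of the whole CLEAR_TRIGGER_PROBES list to a per-classname membership test. A small illustration, with hypothetical list contents:

    ENCODING_PROBES = ["probes.encoding.InjectBase64", "probes.encoding.InjectAtbash"]
    CLEAR_TRIGGER_PROBES = ["probes.encoding.InjectAtbash"]

    # Old filter: `not CLEAR_TRIGGER_PROBES` never looks at classname, so the
    # result is [] whenever CLEAR_TRIGGER_PROBES is non-empty.
    old = [c for c in ENCODING_PROBES if not CLEAR_TRIGGER_PROBES]           # -> []

    # Fixed filter: keep only the probes without clear triggers.
    new = [c for c in ENCODING_PROBES if c not in CLEAR_TRIGGER_PROBES]      # -> ["probes.encoding.InjectBase64"]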