7 changes: 7 additions & 0 deletions docs/source/garak.generators.llm.rst
@@ -0,0 +1,7 @@
garak.generators.llm
==========================

.. automodule:: garak.generators.llm
   :members:
   :undoc-members:
   :show-inheritance:
1 change: 1 addition & 0 deletions docs/source/generators.rst
@@ -20,6 +20,7 @@ For a detailed oversight into how a generator operates, see :doc:`garak.generato
garak.generators.langchain
garak.generators.langchain_serve
garak.generators.litellm
garak.generators.llm
garak.generators.mistral
garak.generators.ollama
garak.generators.openai
120 changes: 120 additions & 0 deletions garak/generators/llm.py
Review comment (Collaborator): It's worth implementing the _load_client() / _clear_client() pattern here to support parallelisation - see openai.OpenAICompatible for an example
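For context, a rough sketch of why the load/clear pattern matters for parallelisation: garak's parallel options hand the generator to worker processes, which involves a pickle round trip that a live `llm` model handle may not survive. Illustrative only; the harness drives this internally, and the names follow the class added in this PR (the example model id assumes `llm` can resolve it):

    import pickle

    from garak.generators.llm import LLMGenerator

    gen = LLMGenerator(name="gpt-4o-mini")   # example model id
    blob = pickle.dumps(gen)                 # __getstate__ drops the live llm handle via _clear_client()
    worker_gen = pickle.loads(blob)          # __setstate__ re-resolves the model via _load_client()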

@@ -0,0 +1,120 @@
# SPDX-FileCopyrightText: Portions Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""LLM (simonw/llm) generator support"""

import logging
from typing import List, Union

import llm

from garak import _config
from garak.attempt import Message, Conversation
from garak.generators.base import Generator


class LLMGenerator(Generator):
"""Class supporting simonw/llm-managed models

See https://pypi.org/project/llm/ and its provider plugins.

Calls model.prompt() with the prompt text and relays the response. Per-provider
options and API keys are all handled by `llm` (e.g., `llm keys set openai`).

Set --model_name to the `llm` model id or alias (e.g., "gpt-4o-mini",
Suggested change (Collaborator):
-    Set --model_name to the `llm` model id or alias (e.g., "gpt-4o-mini",
+    Set --target_name to the `llm` model id or alias (e.g., "gpt-4o-mini",

"claude-3.5-haiku", or a local alias configured in `llm models`).

Explicitly, garak delegates the majority of responsibility here:

* the generator calls prompt() on the resolved `llm` model
* provider setup, auth, and model-specific options live in `llm`
* there's no support for chains; this is a direct LLM interface

Notes:
* Not all providers support all parameters (e.g., temperature, max_tokens).
We pass only non-None params; providers ignore what they don't support.
"""

DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | {
"max_tokens": None,
"top_p": None,
"stop": [],
}

generator_family_name = "llm"

def __init__(self, name: str = "", config_root=_config):
self.target = None
self.name = name
self._load_config(config_root)
self.fullname = f"llm (simonw/llm) {self.name or '(default)'}"

self._load_client()

super().__init__(self.name, config_root=config_root)

self._clear_client()
Review comment (Collaborator) on lines +52 to +56: It may be best to call super().__init__() before calling _load_client() to ensure the end object state is not impacted. Also no need to call _clear_client() during __init__().
Suggested change:
-        self._load_client()
-        super().__init__(self.name, config_root=config_root)
-        self._clear_client()
+        super().__init__(self.name, config_root=config_root)
+        self._load_client()


    def __getstate__(self) -> object:
        self._clear_client()
        return dict(self.__dict__)

    def __setstate__(self, data: dict) -> None:
        self.__dict__.update(data)
        self._load_client()

    def _load_client(self) -> None:
        try:
            self.target = llm.get_model(self.name) if self.name else llm.get_model()
        except Exception as exc:
            logging.error(
                "Failed to resolve `llm` target '%s': %s", self.name, repr(exc)
            )
            raise

    def _clear_client(self) -> None:
        self.target = None

    def _call_model(
        self, prompt: Conversation, generations_this_call: int = 1
    ) -> List[Union[Message, None]]:
        """
        Continuation generation method for LLM integrations via `llm`.

        This calls model.prompt() once per generation and materializes the text().
        """
        if self.target is None:
            self._load_client()

        system_turns = [turn for turn in prompt.turns if turn.role == "system"]
        user_turns = [turn for turn in prompt.turns if turn.role == "user"]
        assistant_turns = [turn for turn in prompt.turns if turn.role == "assistant"]

        if assistant_turns:
            raise ValueError("llm generator does not accept assistant turns")
        if len(system_turns) > 1:
            raise ValueError("llm generator supports at most one system turn")
        if len(user_turns) != 1:
            raise ValueError("llm generator requires exactly one user turn")
Review comment (Collaborator) on lines +93 to +98: Unsupported prompts in _call_model() are currently expected to return None to avoid early termination of the test run. While I understand the thought process of using raise here, currently logging the reason for skipping the prompt would align better.

import logging

Suggested change:
-        if assistant_turns:
-            raise ValueError("llm generator does not accept assistant turns")
-        if len(system_turns) > 1:
-            raise ValueError("llm generator supports at most one system turn")
-        if len(user_turns) != 1:
-            raise ValueError("llm generator requires exactly one user turn")
+        if assistant_turns:
+            logging.debug("llm generator does not accept assistant turns")
+            return [None] * generations_this_call
+        if len(system_turns) > 1:
+            logging.debug("llm generator supports at most one system turn")
+            return [None] * generations_this_call
+        if len(user_turns) != 1:
+            logging.debug("llm generator requires exactly one user turn")
+            return [None] * generations_this_call


        text_prompt = prompt.last_message("user").text

        # Build kwargs only for parameters explicitly set (non-None / non-empty)
        prompt_kwargs = {
            key: getattr(self, key)
            for key in ("max_tokens", "temperature", "top_p")
            if getattr(self, key) is not None
        }
        if self.stop:
            prompt_kwargs["stop"] = self.stop
Review comment (Collaborator) on lines +103 to +109: Could this inspect the accepted arguments to self.target.prompt() vs a hard coded list here? Something similar exists in the OpenAICompatible class, where we collect all options set on the generator that the target API accepts.
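A rough sketch of that introspection idea (illustrative only; the helper name is hypothetical, and `llm`'s prompt() may accept provider options via **kwargs, in which case signature inspection alone is not enough):

    import inspect

    def collect_prompt_kwargs(generator, prompt_fn):
        # keep only options the callable explicitly names and that are actually set on the generator
        accepted = set(inspect.signature(prompt_fn).parameters)
        candidates = {
            "max_tokens": generator.max_tokens,
            "temperature": generator.temperature,
            "top_p": generator.top_p,
            "stop": generator.stop,
        }
        return {k: v for k, v in candidates.items() if k in accepted and v not in (None, [])}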


        try:
            response = self.target.prompt(text_prompt, **prompt_kwargs)
            out = response.text()
            return [Message(out)]
        except Exception as e:
            logging.error("`llm` generation failed: %s", repr(e))
            return [None]


DEFAULT_CLASS = "LLMGenerator"
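For reference, the underlying simonw/llm call this generator delegates to looks roughly like the following (the model id is an example, and the provider key is assumed to have been configured beforehand via `llm keys set openai`):

    import llm

    model = llm.get_model("gpt-4o-mini")    # or llm.get_model() for the configured default
    response = model.prompt("Say hello in five words", max_tokens=64)
    print(response.text())                  # materializes the full completion text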
6 changes: 3 additions & 3 deletions garak/probes/encoding.py
@@ -472,6 +472,7 @@ class InjectAtbash(EncodingMixin, garak.probes.Probe):
encoding_name = "Atbash"
active = True
doc_uri = "https://en.wikipedia.org/wiki/Atbash"
encoding_funcs = []

@staticmethod
def atbash(text: bytes) -> bytes:
@@ -485,9 +486,8 @@ def atbash(text: bytes) -> bytes:
out.append(ch)
return "".join(out).encode("utf-8")

encoding_funcs = [atbash]

def __init__(self, config_root=None):
def __init__(self, config_root=_config):
self.encoding_funcs = [self.atbash]
garak.probes.Probe.__init__(self, config_root=config_root)
EncodingMixin.__init__(self)
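For orientation, the Atbash substitution this probe encodes with simply mirrors the alphabet (a↔z, b↔y, ...); a standalone illustration, independent of the probe code above:

    def atbash_demo(text: str) -> str:
        # mirror each ASCII letter within its own case, leave everything else untouched
        out = []
        for ch in text:
            if ch.isascii() and ch.isalpha():
                base = "A" if ch.isupper() else "a"
                out.append(chr(ord(base) + 25 - (ord(ch) - ord(base))))
            else:
                out.append(ch)
        return "".join(out)

    assert atbash_demo("abc XYZ") == "zyx CBA"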

1 change: 1 addition & 0 deletions pyproject.toml
Review comment (Collaborator): we should land #1199 before landing this, and then move this PR to the deferred loading pattern
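A rough sketch of the deferred-loading pattern being referred to (illustrative only; garak's exact convention after #1199 may differ): the optional dependency is imported inside _load_client() rather than at module import time, so plugin enumeration still works when `llm` is not installed.

    from garak.generators.base import Generator

    class LLMGenerator(Generator):      # sketch; mirrors the class added in this PR
        def _load_client(self) -> None:
            try:
                import llm              # deferred: resolved only when the generator is actually used
            except ImportError as exc:
                raise RuntimeError("the `llm` package is required for this generator: pip install llm") from exc
            self.target = llm.get_model(self.name) if self.name else llm.get_model()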

@@ -111,6 +111,7 @@ dependencies = [
"mistralai==1.5.2",
"pillow>=10.4.0",
"ftfy>=6.3.1",
"llm>=0.11",
]

[project.optional-dependencies]
1 change: 1 addition & 0 deletions requirements.txt
Review comment (Collaborator): we should land #1199 before landing this, and then move this PR to the deferred loading pattern

@@ -11,6 +11,7 @@ backoff>=2.1.1
rapidfuzz>=3.0.0
jinja2>=3.1.6
nltk>=3.9.1
llm>=0.11
accelerate>=0.23.0
avidtools==0.1.2
stdlibs>=2022.10.9
170 changes: 170 additions & 0 deletions tests/generators/test_llm.py
@@ -0,0 +1,170 @@
# SPDX-FileCopyrightText: Portions Copyright (c) 2025 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved.
Suggested change (Collaborator, on lines +1 to +2):
-# SPDX-FileCopyrightText: Portions Copyright (c) 2025 NVIDIA CORPORATION &
-# AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Portions Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

# SPDX-FileCopyrightText: Portions Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Tests for simonw/llm-backed garak generator"""

import pytest
from unittest.mock import MagicMock

from garak.attempt import Conversation, Turn, Message
from garak._config import GarakSubConfig

# Adjust import path/module name to where you placed the wrapper
from garak.generators.llm import LLMGenerator


# ─── Helpers & Fixtures ─────────────────────────────────────────────────

class FakeResponse:
"""Minimal `llm` Response shim with .text()"""
def __init__(self, txt: str):
self._txt = txt
def text(self) -> str:
return self._txt


class FakeModel:
"""Minimal `llm` model shim with .prompt()"""
def __init__(self):
self.calls = []
def prompt(self, prompt_text: str, **kwargs):
self.calls.append((prompt_text, kwargs))
return FakeResponse("OK_FAKE")


@pytest.fixture
def cfg():
"""Minimal garak sub-config; extend if you wire defaults via config."""
c = GarakSubConfig()
c.generators = {}
return c


@pytest.fixture
def fake_llm(monkeypatch):
"""
Patch llm.get_model to return a fresh FakeModel for each test.
Return the FakeModel so tests can inspect call args.
"""
import llm
model = FakeModel()
monkeypatch.setattr(llm, "get_model", lambda *a, **k: model)
return model


# ─── Tests ──────────────────────────────────────────────────────────────

def test_instantiation_resolves_model(cfg, fake_llm):
gen = LLMGenerator(name="my-alias", config_root=cfg)
assert gen.name == "my-alias"
Suggested change (Collaborator, on lines +60 to +61):
-    gen = LLMGenerator(name="my-alias", config_root=cfg)
-    assert gen.name == "my-alias"
+    test_name = "my-alias"
+    gen = LLMGenerator(name=test_name, config_root=cfg)
+    assert gen.name == test_name

    assert hasattr(gen, "target")
    assert "llm (simonw/llm)" in gen.fullname


def test_generate_returns_message(cfg, fake_llm):
gen = LLMGenerator(name="alias", config_root=cfg)

test_txt = "ping"
conv = Conversation([Turn("user", Message(text=test_txt))])
out = gen._call_model(conv)

assert isinstance(out, list) and len(out) == 1
assert isinstance(out[0], Message)
assert out[0].text == "OK_FAKE"

prompt_text, kwargs = fake_llm.calls[0]
assert prompt_text == test_txt
assert kwargs == {}


def test_param_passthrough(cfg, fake_llm):
gen = LLMGenerator(name="alias", config_root=cfg)
temperature = 0.2
max_tokens = 64
top_p = 0.9
stop = ["\n\n"]

gen.temperature = temperature
gen.max_tokens = max_tokens
gen.top_p = top_p
gen.stop = stop

conv = Conversation([Turn("user", Message(text="hello"))])
_ = gen._call_model(conv)

_, kwargs = fake_llm.calls[0]
assert kwargs["temperature"] == temperature
assert kwargs["max_tokens"] == max_tokens
assert kwargs["top_p"] == top_p
assert kwargs["stop"] == stop


def test_wrapper_handles_llm_exception(cfg, monkeypatch):
"""If the underlying `llm` call explodes, wrapper returns [None]."""
import llm
class BoomModel:
def prompt(self, *a, **k):
raise RuntimeError("boom")
monkeypatch.setattr(llm, "get_model", lambda *a, **k: BoomModel())
Review comment (Collaborator): nice


gen = LLMGenerator(name="alias", config_root=cfg)
conv = Conversation([Turn("user", Message(text="ping"))])
out = gen._call_model(conv)
assert out == [None]


def test_default_model_when_name_empty(cfg, fake_llm, monkeypatch):
"""
If name is empty, wrapper should call llm.get_model() with no args,
i.e., use llm's configured default model.
"""
import llm
spy = MagicMock(side_effect=lambda *a, **k: fake_llm)
monkeypatch.setattr(llm, "get_model", spy)

gen = LLMGenerator(name="", config_root=cfg)
conv = Conversation([Turn("user", Message(text="x"))])
_ = gen._call_model(conv)

spy.assert_called_once()
assert spy.call_args.args == ()
assert spy.call_args.kwargs == {}


def test_rejects_multiple_user_turns(cfg, fake_llm):
gen = LLMGenerator(name="alias", config_root=cfg)
user_turns = [
Turn("user", Message(text="first")),
Turn("user", Message(text="second")),
]
conv = Conversation(user_turns)
with pytest.raises(ValueError):
gen._call_model(conv)


def test_rejects_assistant_turns(cfg, fake_llm):
gen = LLMGenerator(name="alias", config_root=cfg)
conv = Conversation(
[
Turn("system", Message(text="system prompt")),
Turn("assistant", Message(text="historic reply")),
Turn("user", Message(text="question")),
]
)
with pytest.raises(ValueError):
gen._call_model(conv)


def test_rejects_multiple_system_turns(cfg, fake_llm):
gen = LLMGenerator(name="alias", config_root=cfg)
conv = Conversation(
[
Turn("system", Message(text="one")),
Turn("system", Message(text="two")),
Turn("user", Message(text="ping")),
]
)
with pytest.raises(ValueError):
gen._call_model(conv)
2 changes: 1 addition & 1 deletion tests/probes/test_probes_encoding.py
@@ -36,7 +36,7 @@ def test_encoding_prompt_trigger_match(classname):

@pytest.mark.parametrize(
"classname",
[classname for classname in ENCODING_PROBES if not CLEAR_TRIGGER_PROBES],
[classname for classname in ENCODING_PROBES if classname not in CLEAR_TRIGGER_PROBES],
)
def test_encoding_triggers_not_in_prompts(classname):
p = _plugins.load_plugin(classname)
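For context, the change above fixes a truthiness bug: `not CLEAR_TRIGGER_PROBES` evaluates the whole list, so with a non-empty list the comprehension produced no parameters and the test never exercised any probe. A minimal illustration with hypothetical probe names:

    CLEAR_TRIGGER_PROBES = ["probes.encoding.InjectAtbash"]       # hypothetical entries
    ENCODING_PROBES = ["probes.encoding.InjectBase64", "probes.encoding.InjectAtbash"]

    # old condition: `not CLEAR_TRIGGER_PROBES` is False for any non-empty list, so nothing passes the filter
    assert [c for c in ENCODING_PROBES if not CLEAR_TRIGGER_PROBES] == []

    # fixed condition: membership test keeps only probes outside the clear-trigger set
    assert [c for c in ENCODING_PROBES if c not in CLEAR_TRIGGER_PROBES] == ["probes.encoding.InjectBase64"]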