Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: configuration based rest proxy support #1073

Merged
merged 3 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/garak.generators.rest.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Uses the following options from ``_config.plugins.generators["rest.RestGenerator
* ``req_template_json_object`` - (optional) the request template as a Python object, to be serialised as a JSON string before replacements
* ``method`` - a string describing the HTTP method, to be passed to the requests module; default "post".
* ``headers`` - dict describing HTTP headers to be sent with the request
* ``proxies`` - dict passed to ``requests`` method call see `required format<https://requests.readthedocs.io/en/latest/user/advanced/#proxies">`_.
jmartin-tech marked this conversation as resolved.
Show resolved Hide resolved
* ``response_json`` - Is the response in JSON format? (bool)
* ``response_json_field`` - (optional) Which field of the response JSON should be used as the output string? Default ``text``. Can also be a JSONPath value, and ``response_json_field`` is used as such if it starts with ``$``.
* ``request_timeout`` - How many seconds should we wait before timing out? Default 20
Expand Down
13 changes: 12 additions & 1 deletion garak/generators/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from jsonpath_ng.exceptions import JsonPathParserError

from garak import _config
from garak.exception import APIKeyMissingError, RateLimitHit
from garak.exception import APIKeyMissingError, BadGeneratorException, RateLimitHit
from garak.generators.base import Generator


Expand All @@ -35,6 +35,7 @@ class RestGenerator(Generator):
"response_json_field": None,
"req_template": "$INPUT",
"request_timeout": 20,
"proxies": None,
jmartin-tech marked this conversation as resolved.
Show resolved Hide resolved
}

ENV_VAR = "REST_API_KEY"
Expand All @@ -59,6 +60,7 @@ class RestGenerator(Generator):
"skip_codes",
"temperature",
"top_k",
"proxies",
)

def __init__(self, uri=None, config_root=_config):
Expand Down Expand Up @@ -118,6 +120,14 @@ def __init__(self, uri=None, config_root=_config):
self.method = "post"
self.http_function = getattr(requests, self.method)

# validate proxies formatting
# sanity check only leave actual parsing of values to the `requests` library on call.
if hasattr(self, "proxies") and self.proxies is not None:
if not isinstance(self.proxies, dict):
raise BadGeneratorException(
"`proxies` value provided is not in the required format. See documentation from the `requests` package for details on expected format. https://requests.readthedocs.io/en/latest/user/advanced/#proxies"
)

# validate jsonpath
if self.response_json and self.response_json_field:
try:
Expand Down Expand Up @@ -193,6 +203,7 @@ def _call_model(
data_kw: request_data,
"headers": request_headers,
"timeout": self.request_timeout,
"proxies": self.proxies,
}
jmartin-tech marked this conversation as resolved.
Show resolved Hide resolved
resp = self.http_function(self.uri, **req_kArgs)

Expand Down
50 changes: 48 additions & 2 deletions tests/generators/test_rest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import json
import pytest
import requests_mock
from sympy import is_increasing

from garak import _config, _plugins

Expand Down Expand Up @@ -122,3 +120,51 @@ def test_rest_skip_code(requests_mock):
)
output = generator._call_model("Who is Enabran Tain's son?")
assert output == [None]


@pytest.mark.usefixtures("set_rest_config")
def test_rest_valid_proxy(mocker, requests_mock):
test_proxies = {
"http": "http://localhost:8080",
"https": "https://localhost:8443",
}
_config.plugins.generators["rest"]["RestGenerator"]["proxies"] = test_proxies
generator = _plugins.load_plugin(
"generators.rest.RestGenerator", config_root=_config
)
requests_mock.post(
DEFAULT_URI,
text=json.dumps(
{
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": DEFAULT_TEXT_RESPONSE,
},
}
]
}
),
)
mock_http_function = mocker.patch.object(
generator, "http_function", wraps=generator.http_function
)
generator._call_model("Who is Enabran Tain's son?")
mock_http_function.assert_called_once()
assert mock_http_function.call_args_list[0].kwargs["proxies"] == test_proxies
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i guess maybe this isn't the strongest possible test we can do (we could check mock proxy logs), but validation shows that the code works (by tailing proxy logs)



@pytest.mark.usefixtures("set_rest_config")
def test_rest_invalid_proxy(requests_mock):
from garak.exception import GarakException

test_proxies = [
"http://localhost:8080",
"https://localhost:8443",
]
_config.plugins.generators["rest"]["RestGenerator"]["proxies"] = test_proxies
with pytest.raises(GarakException) as exc_info:
_plugins.load_plugin("generators.rest.RestGenerator", config_root=_config)
assert "not in the required format" in str(exc_info.value)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good thinking

Loading