From 3727e0c3e70ce1ffcc2012eb8ccd3aebc73653d2 Mon Sep 17 00:00:00 2001 From: Jeffrey Martin Date: Mon, 13 Jan 2025 13:48:41 -0600 Subject: [PATCH] enforce proxies as `dict` validates the `proxies` value if provided is a `dict` and passed on to the requests call --- docs/source/garak.generators.rest.rst | 1 + garak/generators/rest.py | 10 +++++- tests/generators/test_rest.py | 50 +++++++++++++++++++++++++-- 3 files changed, 58 insertions(+), 3 deletions(-) diff --git a/docs/source/garak.generators.rest.rst b/docs/source/garak.generators.rest.rst index 6d303e063..65da65613 100644 --- a/docs/source/garak.generators.rest.rst +++ b/docs/source/garak.generators.rest.rst @@ -12,6 +12,7 @@ Uses the following options from ``_config.plugins.generators["rest.RestGenerator * ``req_template_json_object`` - (optional) the request template as a Python object, to be serialised as a JSON string before replacements * ``method`` - a string describing the HTTP method, to be passed to the requests module; default "post". * ``headers`` - dict describing HTTP headers to be sent with the request +* ``proxies`` - dict passed to ``requests`` method call see `required format`_. * ``response_json`` - Is the response in JSON format? (bool) * ``response_json_field`` - (optional) Which field of the response JSON should be used as the output string? Default ``text``. Can also be a JSONPath value, and ``response_json_field`` is used as such if it starts with ``$``. * ``request_timeout`` - How many seconds should we wait before timing out? Default 20 diff --git a/garak/generators/rest.py b/garak/generators/rest.py index 194199665..6516a8f14 100644 --- a/garak/generators/rest.py +++ b/garak/generators/rest.py @@ -16,7 +16,7 @@ from jsonpath_ng.exceptions import JsonPathParserError from garak import _config -from garak.exception import APIKeyMissingError, RateLimitHit +from garak.exception import APIKeyMissingError, BadGeneratorException, RateLimitHit from garak.generators.base import Generator @@ -120,6 +120,14 @@ def __init__(self, uri=None, config_root=_config): self.method = "post" self.http_function = getattr(requests, self.method) + # validate proxies formatting + # sanity check only leave actual parsing of values to the `requests` library on call. + if hasattr(self, "proxies") and self.proxies is not None: + if not isinstance(self.proxies, dict): + raise BadGeneratorException( + "`proxies` value provided is not in the required format. See documentation from the `requests` package for details on expected format. https://requests.readthedocs.io/en/latest/user/advanced/#proxies" + ) + # validate jsonpath if self.response_json and self.response_json_field: try: diff --git a/tests/generators/test_rest.py b/tests/generators/test_rest.py index f9a82422a..55aa9d128 100644 --- a/tests/generators/test_rest.py +++ b/tests/generators/test_rest.py @@ -1,7 +1,5 @@ import json import pytest -import requests_mock -from sympy import is_increasing from garak import _config, _plugins @@ -122,3 +120,51 @@ def test_rest_skip_code(requests_mock): ) output = generator._call_model("Who is Enabran Tain's son?") assert output == [None] + + +@pytest.mark.usefixtures("set_rest_config") +def test_rest_valid_proxy(mocker, requests_mock): + test_proxies = { + "http": "http://localhost:8080", + "https": "https://localhost:8443", + } + _config.plugins.generators["rest"]["RestGenerator"]["proxies"] = test_proxies + generator = _plugins.load_plugin( + "generators.rest.RestGenerator", config_root=_config + ) + requests_mock.post( + DEFAULT_URI, + text=json.dumps( + { + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": DEFAULT_TEXT_RESPONSE, + }, + } + ] + } + ), + ) + mock_http_function = mocker.patch.object( + generator, "http_function", wraps=generator.http_function + ) + generator._call_model("Who is Enabran Tain's son?") + mock_http_function.assert_called_once() + assert mock_http_function.call_args_list[0].kwargs["proxies"] == test_proxies + + +@pytest.mark.usefixtures("set_rest_config") +def test_rest_invalid_proxy(requests_mock): + from garak.exception import GarakException + + test_proxies = [ + "http://localhost:8080", + "https://localhost:8443", + ] + _config.plugins.generators["rest"]["RestGenerator"]["proxies"] = test_proxies + with pytest.raises(GarakException) as exc_info: + _plugins.load_plugin("generators.rest.RestGenerator", config_root=_config) + assert "not in the required format" in str(exc_info.value)