From f62e9924094cec4c532db212a5390ea68ab3768e Mon Sep 17 00:00:00 2001 From: Martin Pouliot Date: Mon, 17 Feb 2025 17:33:35 -0500 Subject: [PATCH 01/17] Added gradio scorer implementation --- pyproject.toml | 8 +- pyrit/score/human_in_the_loop_gradio.py | 36 ++++ pyrit/ui/app.py | 57 ++++++ pyrit/ui/connection_status.py | 55 ++++++ pyrit/ui/rpc.py | 238 ++++++++++++++++++++++++ pyrit/ui/rpc_client.py | 110 +++++++++++ pyrit/ui/scorer.py | 94 ++++++++++ 7 files changed, 597 insertions(+), 1 deletion(-) create mode 100644 pyrit/score/human_in_the_loop_gradio.py create mode 100644 pyrit/ui/app.py create mode 100644 pyrit/ui/connection_status.py create mode 100644 pyrit/ui/rpc.py create mode 100644 pyrit/ui/rpc_client.py create mode 100644 pyrit/ui/scorer.py diff --git a/pyproject.toml b/pyproject.toml index 6329e7035..f404ea730 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ classifiers = [ requires-python = ">=3.10, <3.13" dependencies = [ "aioconsole>=0.7.1", - "aiofiles>=24.1.0", + "aiofiles>=23.2.1", "appdirs>=1.4.0", "art==6.1.0", "azure-cognitiveservices-speech>=1.36.0", @@ -113,6 +113,12 @@ playwright = [ "ollama>=0.4.4" ] +gradio = [ + "gradio>=5.16.0", + "rpyc>=6.0.1", + "pywebview>==5.4" +] + all = [ "accelerate==0.34.2", "azureml-mlflow==1.57.0", diff --git a/pyrit/score/human_in_the_loop_gradio.py b/pyrit/score/human_in_the_loop_gradio.py new file mode 100644 index 000000000..16ef84131 --- /dev/null +++ b/pyrit/score/human_in_the_loop_gradio.py @@ -0,0 +1,36 @@ +import asyncio +from pyrit.score.scorer import Scorer +from pyrit.models import Score, PromptRequestPiece +from typing import Optional + +from ui.rpc import AppRpcServer + + +class HumanInTheLoopScorerGradio(Scorer): + + def __init__(self, *, scorer: Scorer = None, re_scorers: list[Scorer] = None) -> None: + self._scorer = scorer + self._re_scorers = re_scorers + self._rpc_server = AppRpcServer() + self._rpc_server.start() + + + async def score_async(self, request: PromptRequestPiece, *, task: Optional[str] = None) -> list[Score]: + try: + return await asyncio.to_thread(self.score_prompt_manually, request, task=task) + except asyncio.CancelledError: + self._rpc_server.stop() + raise + + + def score_prompt_manually(self, request_prompt: PromptRequestPiece, *, task: Optional[str] = None) -> list[Score]: + self._rpc_server.wait_for_client() + self._rpc_server.send_score_prompt(request_prompt) + score = self._rpc_server.wait_for_score() + return [score] + + def validate(self, request_response: PromptRequestPiece, *, task: Optional[str] = None): + pass + + def __del__(self): + self._rpc_server.stop() \ No newline at end of file diff --git a/pyrit/ui/app.py b/pyrit/ui/app.py new file mode 100644 index 000000000..556b7a51f --- /dev/null +++ b/pyrit/ui/app.py @@ -0,0 +1,57 @@ +import os +import sys +import subprocess +import traceback + +GLOBAL_MUTEX_NAME = "PyRIT-Gradio" + +def launch_app(): + # Launch a new process to run the gradio UI. + # Locate the python executable and run this file. + current_path = os.path.abspath(__file__) + python_path = sys.executable + + # Start a new process to run it + subprocess.Popen([python_path, current_path], creationflags=subprocess.CREATE_NEW_CONSOLE) + +def is_app_running(): + if sys.platform != "win32": + raise NotImplementedError("This function is only supported on Windows.") + return True + + import ctypes.wintypes + + SYNCHRONIZE = 0x00100000 + mutex = ctypes.windll.kernel32.OpenMutexW(SYNCHRONIZE, False, GLOBAL_MUTEX_NAME) + if not mutex: + return False + + # Close the handle to the mutex + ctypes.windll.kernel32.CloseHandle(mutex) + return True + +if __name__ == "__main__": + def create_mutex(): + if sys.platform != "win32": + raise NotImplementedError("This function is only supported on Windows.") + + # TODO make sure to add cross-platform support for this. + import ctypes.wintypes + mutex = ctypes.windll.kernel32.CreateMutexW(None, False, GLOBAL_MUTEX_NAME) + last_error = ctypes.windll.kernel32.GetLastError() + if last_error == 183: # ERROR_ALREADY_EXISTS + return False + return True + + if not create_mutex(): + print("Gradio UI is already running.") + sys.exit(1) + print("Starting Gradio Interface please wait...") + try: + from scorer import GradioApp + app = GradioApp() + app.start_gradio(open_browser=True) + except: + # Print the error message and traceback + print(traceback.format_exc()) + input("Press Enter to exit.") \ No newline at end of file diff --git a/pyrit/ui/connection_status.py b/pyrit/ui/connection_status.py new file mode 100644 index 000000000..161ade692 --- /dev/null +++ b/pyrit/ui/connection_status.py @@ -0,0 +1,55 @@ +import gradio as gr + +from rpc_client import RpcClient + +class ConnectionStatusHandler: + def __init__(self, + is_connected_state: gr.State, + rpc_client: RpcClient): + self.state = is_connected_state + self.server_disconnected = False + self.rpc_client = rpc_client + self.next_prompt = "" + + def setup(self, main_interface: gr.Column, loading_animation: gr.Column, next_prompt_state: gr.State): + self.state.change(fn=self.__on_state_change, inputs=[self.state], outputs=[main_interface, loading_animation, next_prompt_state]) + + connection_status_timer = gr.Timer(1) + connection_status_timer.tick( + fn=self.__check_connection_status, + inputs=[self.state], + outputs=[self.state] + ).then( + fn=self.__reconnect_if_needed, + outputs=[self.state] + ) + + def set_ready(self): + self.server_disconnected = False + + def set_disconnected(self): + self.server_disconnected = True + + def set_next_prompt(self, next_prompt: str): + self.next_prompt = next_prompt + + def __on_state_change(self, is_connected: bool): + print("Connection status changed to: ", is_connected, " - ", self.next_prompt) + if is_connected: + return [gr.Column(visible=True), gr.Row(visible=False), self.next_prompt] + return [gr.Column(visible=False), gr.Row(visible=True), self.next_prompt] + + def __check_connection_status(self, is_connected: bool): + if self.server_disconnected or not is_connected: + print("Gradio disconnected") + return False + return True + + def __reconnect_if_needed(self): + if self.server_disconnected: + print("Attempting to reconnect") + self.rpc_client.reconnect() + prompt = self.rpc_client.wait_for_prompt() + self.next_prompt = str(prompt.original_value) + self.server_disconnected = False + return True \ No newline at end of file diff --git a/pyrit/ui/rpc.py b/pyrit/ui/rpc.py new file mode 100644 index 000000000..f5e02f24f --- /dev/null +++ b/pyrit/ui/rpc.py @@ -0,0 +1,238 @@ +import rpyc +import time +import logging + +from typing import Callable, Optional +from threading import Thread, Semaphore + +from ui.app import is_app_running, launch_app + +from pyrit.models import Score, PromptRequestPiece + + +DEFAULT_PORT = 18812 + +logger = logging.getLogger(__name__) + +# Exceptions +class RpcAppException(Exception): + def __init__(self, message: str): + super().__init__(message) + +class RpcAlreadyRunningException(RpcAppException): + """ + This exception is thrown when an RPC server is already running and the user tries to start another one. + """ + def __init__(self): + super().__init__("RPC server is already running.") + +class RpcClientNotReadyException(RpcAppException): + """ + This exception is thrown when the RPC client is not ready to receive messages. + """ + def __init__(self): + super().__init__("RPC client is not ready.") + +class RpcServerStoppedException(RpcAppException): + """ + This exception is thrown when the RPC server is stopped. + """ + def __init__(self): + super().__init__("RPC server is stopped.") + + +# RPC Server +class AppRpcServer: + def __init__(self): + self.__server = None + self.__server_thread = None + self.__rpc_service = None + self.__is_alive_thread = None + self.__is_alive_stop = False + self.__score_received_sem = None + self.__client_ready_sem = None + self.__server_is_running = False + + def start(self): + """ + Attempt to start the RPC server. If the server is already running, this method will throw an exception. + """ + + # Check if the server is already running by checking if the port is already in use. + # If the port is already in use, throw an exception. + if self.__is_instance_running(): + raise RpcAlreadyRunningException() + + self.__score_received_sem = Semaphore(0) + self.__client_ready_sem = Semaphore(0) + + # Start the RPC server. + self.__rpc_service = RpcService(self.__score_received_sem, self.__client_ready_sem) + self.__server = rpyc.ThreadedServer(self.__rpc_service, port=DEFAULT_PORT, protocol_config={"allow_all_attrs": True}) + self.__server_thread = Thread(target=self.__server.start) + self.__server_thread.start() + + # Start a thread to check if the client is still alive + self.__is_alive_stop = False + self.__is_alive_thread = Thread(target=self.__is_alive) + self.__is_alive_thread.start() + + self.__server_is_running = True + + logger.info("RPC server started") + + if not is_app_running(): + logger.info("Launching Gradio UI") + launch_app() + else: + logger.info("Gradio UI is already running. Will not launch another instance.") + + def stop(self): + """ + Stop the RPC server and free up the listening port. + """ + self.stop_request() + if self.__server is not None: + self.__server_thread.join() + + + if self.__is_alive_thread is not None: + self.__is_alive_thread.join() + + logger.info("RPC server stopped") + + def stop_request(self): + """ + Request the RPC server to stop. This method is does not block while waiting for the server to stop. + """ + + logger.info("RPC server stopping") + if self.__server is not None: + self.__server.close() + self.__server = None + + + if self.__is_alive_thread is not None: + self.__is_alive_stop = True + + self.__server_is_running = False + + self.__client_ready_sem.release() + self.__score_received_sem.release() + + def send_score_prompt(self, prompt: PromptRequestPiece, task: Optional[str] = None): + """ + Send a score prompt to the client. + """ + if self.__rpc_service is None: + raise RpcAppException("RPC server is not running.") + + self.__rpc_service.send_score_prompt(prompt, task) + + def wait_for_score(self) -> Score: + """ + Wait for the client to send a score. Should always return a score, but if the synchronisation fails it will return None. + """ + if self.__score_received_sem is None or self.__rpc_service is None: + raise RpcAppException("RPC server is not running.") + + self.__score_received_sem.acquire() + if not self.__server_is_running: + raise RpcServerStoppedException() + + score = self.__rpc_service.pop_score_received() + if score is None: + return None + return score + + def wait_for_client(self): + """ + Wait for the client to be ready to receive messages. + """ + if self.__client_ready_sem is None: + raise RpcAppException("RPC server is not running.") + + if self.__rpc_service.is_client_ready(): + return + + logger.info("Waiting for client to be ready") + self.__client_ready_sem.acquire() + + if not self.__server_is_running: + raise RpcServerStoppedException() + + logger.info("Client is ready") + + def __is_instance_running(self): + """ + Check if the RPC server is running. + """ + import socket + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex(('localhost', DEFAULT_PORT)) == 0 + + def __is_alive(self): + """ + Check if a ping has been missed. If a ping has been missed, stop the server. + """ + while not self.__is_alive_stop: + if self.__rpc_service.is_ping_missed(): + logger.error("Ping missed. Stopping server.") + self.stop_request() + break + time.sleep(1) + +# RPC Service +class RpcService(rpyc.Service): + """ + RPC service is the service that RPyC is using + """ + def __init__(self, score_received_sem: Semaphore, client_ready_sem: Semaphore): + super().__init__() + self.__callback_score_prompt = None + self.__last_ping = None + self.__scores_received = [] + self.__score_received_sem = score_received_sem + self.__client_ready_sem = client_ready_sem + + def on_connect(self, conn): + logger.info("Client connected") + + def on_disconnect(self, conn): + logger.info("Client disconnected") + + def exposed_receive_score(self, score: Score): + logger.info(f"Score received: {score}") + self.__scores_received.append(score) + self.__score_received_sem.release() + + def exposed_receive_ping(self): + # A ping should be received every 2s from the client. If a client misses a ping then the server should stoped + self.__last_ping = time.time() + logger.debug("Ping received") + + def exposed_callback_score_prompt(self, callback: Callable[[PromptRequestPiece, Optional[str]], None]): + self.__callback_score_prompt = callback + self.__client_ready_sem.release() + + def is_client_ready(self): + if self.__callback_score_prompt is None: + return False + return True + + def send_score_prompt(self, prompt: PromptRequestPiece, task: Optional[str] = None): + if not self.is_client_ready(): + raise RpcClientNotReadyException() + self.__callback_score_prompt(prompt, task) + + def is_ping_missed(self): + if self.__last_ping is None: + return False + + return time.time() - self.__last_ping > 2 + + def pop_score_received(self) -> Score | None: + try: + return self.__scores_received.pop() + except IndexError: + return None \ No newline at end of file diff --git a/pyrit/ui/rpc_client.py b/pyrit/ui/rpc_client.py new file mode 100644 index 000000000..9f05e159e --- /dev/null +++ b/pyrit/ui/rpc_client.py @@ -0,0 +1,110 @@ +import rpyc +import time +import socket + +from typing import Callable, Optional +from threading import Thread, Semaphore, Event +from pyrit.models import PromptRequestPiece, Score + +DEFAULT_PORT = 18812 + +class RpcClient: + def __init__(self, callback_disconnected: Optional[Callable] = None): + self.__c = None + self.__bgsrv = None + + self.__ping_thread = None + self.__bgsrv_thread = None + self.__is_running = False + + self.__shutdown_event = None + self.__prompt_received_sem = None + + self.__prompt_received = None + self.__callback_disconnected = callback_disconnected + + def start(self): + # Check if the port is open + self.__wait_for_server_avaible() + self.__prompt_received_sem = Semaphore(0) + + self.__c = rpyc.connect("localhost", DEFAULT_PORT, config={'allow_public_attrs': True}) + self.__is_running = True + self.__shutdown_event = Event() + self.__bgsrv_thread = Thread(target=self.__bgsrv_lifecycle) + self.__bgsrv_thread.start() + + def wait_for_prompt(self) -> PromptRequestPiece: + self.__prompt_received_sem.acquire() + return self.__prompt_received + + def send_prompt_response(self, response: bool): + score = Score( + score_value=response, + score_type="true_false", + score_category="safety", + score_value_description="Safe" if response else "Unsafe", + score_rationale="The prompt is safe" if response else "The prompt is unsafe", + score_metadata={"prompt_target_identifier": self.__prompt_received.prompt_target_identifier}, + prompt_request_response_id=self.__prompt_received.conversation_id + ) + self.__c.root.receive_score(score) + + def __wait_for_server_avaible(self): + # Wait for the server to be available + while not self.__is_server_running(): + print("Server is not running. Waiting for server to start...") + time.sleep(1) + + def stop(self): + """ + Stop the client. + """ + # Send a signal to the thread to stop + self.__shutdown_event.set() + + def reconnect(self): + """ + Reconnect to the server. + """ + self.stop() + print("Reconnecting to server...") + self.start() + + def __receive_prompt(self, prompt_request: PromptRequestPiece, task: Optional[str] = None): + print(f"Received prompt: {prompt_request}") + self.__prompt_received = prompt_request + self.__prompt_received_sem.release() + + def __ping(self): + try: + while self.__is_running: + self.__c.root.receive_ping() + time.sleep(1.5) + except EOFError: + print("Connection closed") + if self.__callback_disconnected is not None: + self.__callback_disconnected() + + def __bgsrv_lifecycle(self): + self.__bgsrv = rpyc.BgServingThread(self.__c) + self.__ping_thread = Thread(target=self.__ping) + self.__ping_thread.start() + + # Register callback + self.__c.root.callback_score_prompt(self.__receive_prompt) + + # Wait for the server to be disconnected + self.__shutdown_event.wait() + + self.__is_running = False + self.__ping_thread.join() + + # Avoid calling stop() twice if the server is already stopped. This can happen if the server is stopped + # by the ping request. + if self.__bgsrv._active: + self.__bgsrv.stop() + + def __is_server_running(self): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex(('localhost', DEFAULT_PORT)) == 0 \ No newline at end of file diff --git a/pyrit/ui/scorer.py b/pyrit/ui/scorer.py new file mode 100644 index 000000000..606bebf8f --- /dev/null +++ b/pyrit/ui/scorer.py @@ -0,0 +1,94 @@ +import gradio as gr +import webview +from rpc_client import RpcClient +from connection_status import ConnectionStatusHandler + +class GradioApp: + def __init__(self): + self.i = 0 + self.rpc_client = RpcClient(self.__disconnected_rpc_callback) + self.connect_status = None + self.url = "" + + def start_gradio(self, open_browser=False): + with gr.Blocks() as demo: + is_connected = gr.State(False) + next_prompt_state = gr.State("") + + self.connect_status = ConnectionStatusHandler(is_connected, self.rpc_client) + with gr.Column(visible=False) as main_interface: + prompt = gr.Markdown("Prompt: ") + prompt.height = "200px" + with gr.Row(): + safe = gr.Button("Safe") + unsafe = gr.Button("Unsafe") + + safe.click( + fn=lambda: [gr.update(interactive=False)]*2 + [""],outputs=[safe, unsafe, next_prompt_state] + ).then( + fn=self.__safe_clicked, outputs=next_prompt_state + ) + unsafe.click( + fn=lambda: [gr.update(interactive=False)]*2 + [""], outputs=[safe, unsafe, next_prompt_state] + ).then( + fn=self.__unsafe_clicked, outputs=next_prompt_state + ) + + with gr.Row() as loading_animation: + loading_text = gr.Markdown("Connecting to PyRIT") + timer = gr.Timer(0.5) + timer.tick(fn=self.__loading_dots, outputs=loading_text) + + next_prompt_state.change(fn=self.__on_next_prompt_change, inputs=[next_prompt_state], outputs=[prompt, safe, unsafe]) + self.connect_status.setup(main_interface, loading_animation, next_prompt_state) + + demo.load( + fn=self.__main_inteface_loaded, + outputs=[main_interface, loading_animation, next_prompt_state, is_connected] + ) + + if open_browser: + demo.launch(inbrowser=True) + else: + _, url, _ = demo.launch(prevent_thread_lock=True) + self.url = url + print("Gradio launched") + webview.create_window("PyRIT - Scorer", self.url) + webview.start() + print("Webview closed!") + + if self.rpc_client: + self.rpc_client.stop() + + def __safe_clicked(self): + self.rpc_client.send_prompt_response(True) + prompt_request = self.rpc_client.wait_for_prompt() + return str(prompt_request.original_value) + + def __unsafe_clicked(self): + self.rpc_client.send_prompt_response(False) + prompt_request = self.rpc_client.wait_for_prompt() + return str(prompt_request.original_value) + + def __on_next_prompt_change(self, next_prompt): + if next_prompt == "": + return [gr.Markdown(f"Waiting for next prompt..."), gr.update(interactive=False), gr.update(interactive=False)] + return [gr.Markdown("Prompt: " + next_prompt), gr.update(interactive=True), gr.update(interactive=True)] + + def __loading_dots(self): + self.i = (self.i + 1) % 4 + return gr.Markdown("Connecting to PyRIT" + "." * self.i) + + def __disconnected_rpc_callback(self): + self.connect_status.set_disconnected() + + def __main_inteface_loaded(self): + print("Showing main interface") + self.rpc_client.start() + prompt_request = self.rpc_client.wait_for_prompt() + next_prompt = str(prompt_request.original_value) + self.connect_status.set_next_prompt(next_prompt) + self.connect_status.set_ready() + print("PyRIT connected") + return [gr.Column(visible=True), gr.Row(visible=False), next_prompt, True] + From c5a24d5317f6ed709d7b19c78e620ba01a52ce05 Mon Sep 17 00:00:00 2001 From: Martin Pouliot Date: Mon, 17 Feb 2025 17:47:15 -0500 Subject: [PATCH 02/17] Added missing export --- pyrit/score/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyrit/score/__init__.py b/pyrit/score/__init__.py index 3043fafa1..28085eb57 100644 --- a/pyrit/score/__init__.py +++ b/pyrit/score/__init__.py @@ -5,6 +5,7 @@ from pyrit.score.float_scale_threshold_scorer import FloatScaleThresholdScorer from pyrit.score.gandalf_scorer import GandalfScorer from pyrit.score.human_in_the_loop_scorer import HumanInTheLoopScorer +from pyrit.score.human_in_the_loop_gradio import HumanInTheLoopScorerGradio from pyrit.score.insecure_code_scorer import InsecureCodeScorer from pyrit.score.markdown_injection import MarkdownInjectionScorer from pyrit.score.prompt_shield_scorer import PromptShieldScorer @@ -23,6 +24,7 @@ "FloatScaleThresholdScorer", "GandalfScorer", "HumanInTheLoopScorer", + "HumanInTheLoopScorerGradio", "InsecureCodeScorer", "LikertScalePaths", "MarkdownInjectionScorer", From b1df05983a8e1a33dae712e0e264ea8efc5659a6 Mon Sep 17 00:00:00 2001 From: Martin Pouliot Date: Tue, 18 Feb 2025 10:13:08 -0500 Subject: [PATCH 03/17] Fixed a few issues with PyRIT integration --- pyrit/score/human_in_the_loop_gradio.py | 10 +++++----- pyrit/ui/__init__.py | 0 pyrit/ui/app.py | 10 +++++++--- pyrit/ui/rpc.py | 12 ++++++------ pyrit/ui/rpc_client.py | 2 +- 5 files changed, 19 insertions(+), 15 deletions(-) create mode 100644 pyrit/ui/__init__.py diff --git a/pyrit/score/human_in_the_loop_gradio.py b/pyrit/score/human_in_the_loop_gradio.py index 16ef84131..97dbfe057 100644 --- a/pyrit/score/human_in_the_loop_gradio.py +++ b/pyrit/score/human_in_the_loop_gradio.py @@ -3,21 +3,21 @@ from pyrit.models import Score, PromptRequestPiece from typing import Optional -from ui.rpc import AppRpcServer +from pyrit.ui.rpc import AppRpcServer class HumanInTheLoopScorerGradio(Scorer): - def __init__(self, *, scorer: Scorer = None, re_scorers: list[Scorer] = None) -> None: + def __init__(self, *, open_browser=False, scorer: Scorer = None, re_scorers: list[Scorer] = None) -> None: self._scorer = scorer self._re_scorers = re_scorers - self._rpc_server = AppRpcServer() + self._rpc_server = AppRpcServer(open_browser=open_browser) self._rpc_server.start() - async def score_async(self, request: PromptRequestPiece, *, task: Optional[str] = None) -> list[Score]: + async def score_async(self, request_response: PromptRequestPiece, *, task: Optional[str] = None) -> list[Score]: try: - return await asyncio.to_thread(self.score_prompt_manually, request, task=task) + return await asyncio.to_thread(self.score_prompt_manually, request_response, task=task) except asyncio.CancelledError: self._rpc_server.stop() raise diff --git a/pyrit/ui/__init__.py b/pyrit/ui/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pyrit/ui/app.py b/pyrit/ui/app.py index 556b7a51f..4fba66962 100644 --- a/pyrit/ui/app.py +++ b/pyrit/ui/app.py @@ -5,14 +5,14 @@ GLOBAL_MUTEX_NAME = "PyRIT-Gradio" -def launch_app(): +def launch_app(open_browser=False): # Launch a new process to run the gradio UI. # Locate the python executable and run this file. current_path = os.path.abspath(__file__) python_path = sys.executable # Start a new process to run it - subprocess.Popen([python_path, current_path], creationflags=subprocess.CREATE_NEW_CONSOLE) + subprocess.Popen([python_path, current_path, str(open_browser)], creationflags=subprocess.CREATE_NEW_CONSOLE) def is_app_running(): if sys.platform != "win32": @@ -48,9 +48,13 @@ def create_mutex(): sys.exit(1) print("Starting Gradio Interface please wait...") try: + open_browser = False + if len(sys.argv) > 1: + open_browser = sys.argv[1] == "True" + from scorer import GradioApp app = GradioApp() - app.start_gradio(open_browser=True) + app.start_gradio(open_browser=open_browser) except: # Print the error message and traceback print(traceback.format_exc()) diff --git a/pyrit/ui/rpc.py b/pyrit/ui/rpc.py index f5e02f24f..9ba030dd2 100644 --- a/pyrit/ui/rpc.py +++ b/pyrit/ui/rpc.py @@ -5,8 +5,7 @@ from typing import Callable, Optional from threading import Thread, Semaphore -from ui.app import is_app_running, launch_app - +from pyrit.ui.app import is_app_running, launch_app from pyrit.models import Score, PromptRequestPiece @@ -43,7 +42,7 @@ def __init__(self): # RPC Server class AppRpcServer: - def __init__(self): + def __init__(self, open_browser: bool = False): self.__server = None self.__server_thread = None self.__rpc_service = None @@ -52,6 +51,7 @@ def __init__(self): self.__score_received_sem = None self.__client_ready_sem = None self.__server_is_running = False + self.__open_browser = open_browser def start(self): """ @@ -83,7 +83,7 @@ def start(self): if not is_app_running(): logger.info("Launching Gradio UI") - launch_app() + launch_app(open_browser=self.__open_browser) else: logger.info("Gradio UI is already running. Will not launch another instance.") @@ -143,6 +143,8 @@ def wait_for_score(self) -> Score: score = self.__rpc_service.pop_score_received() if score is None: return None + + self.__client_ready_sem.release() return score def wait_for_client(self): @@ -152,8 +154,6 @@ def wait_for_client(self): if self.__client_ready_sem is None: raise RpcAppException("RPC server is not running.") - if self.__rpc_service.is_client_ready(): - return logger.info("Waiting for client to be ready") self.__client_ready_sem.acquire() diff --git a/pyrit/ui/rpc_client.py b/pyrit/ui/rpc_client.py index 9f05e159e..4a556f5b5 100644 --- a/pyrit/ui/rpc_client.py +++ b/pyrit/ui/rpc_client.py @@ -40,7 +40,7 @@ def wait_for_prompt(self) -> PromptRequestPiece: def send_prompt_response(self, response: bool): score = Score( - score_value=response, + score_value=str(response), score_type="true_false", score_category="safety", score_value_description="Safe" if response else "Unsafe", From 829532da8cc71b6c18dbc5cdc6f63645625dba38 Mon Sep 17 00:00:00 2001 From: Martin Pouliot Date: Tue, 18 Feb 2025 10:44:53 -0500 Subject: [PATCH 04/17] Fixed a typo --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f404ea730..3ee60bada 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,7 +116,7 @@ playwright = [ gradio = [ "gradio>=5.16.0", "rpyc>=6.0.1", - "pywebview>==5.4" + "pywebview>=5.4" ] all = [ From 1499dc2bb833f9a02088034635c4ae03d0c562cc Mon Sep 17 00:00:00 2001 From: Martin Pouliot Date: Tue, 18 Feb 2025 10:52:35 -0500 Subject: [PATCH 05/17] Fixed a typing issue --- pyrit/ui/rpc.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pyrit/ui/rpc.py b/pyrit/ui/rpc.py index 9ba030dd2..2e19473c7 100644 --- a/pyrit/ui/rpc.py +++ b/pyrit/ui/rpc.py @@ -1,4 +1,7 @@ -import rpyc +from typing import TYPE_CHECKING +if not TYPE_CHECKING: + import rpyc + import time import logging @@ -153,7 +156,7 @@ def wait_for_client(self): """ if self.__client_ready_sem is None: raise RpcAppException("RPC server is not running.") - + logger.info("Waiting for client to be ready") self.__client_ready_sem.acquire() From 2557cfd4e73d8cf0362d26520cf40f7b1ab32085 Mon Sep 17 00:00:00 2001 From: Martin Pouliot Date: Tue, 18 Feb 2025 11:17:19 -0500 Subject: [PATCH 06/17] Fixed import errors --- pyproject.toml | 3 ++ pyrit/ui/rpc.py | 123 ++++++++++++++++++++++++------------------------ 2 files changed, 64 insertions(+), 62 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3ee60bada..01a8b4439 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -144,6 +144,9 @@ all = [ "flask>=3.1.0", "ollama>=0.4.4", "types-PyYAML>=6.0.12.9", + "gradio>=5.16.0", + "rpyc>=6.0.1", + "pywebview>=5.4" ] [project.scripts] diff --git a/pyrit/ui/rpc.py b/pyrit/ui/rpc.py index 2e19473c7..1fd8e2a2f 100644 --- a/pyrit/ui/rpc.py +++ b/pyrit/ui/rpc.py @@ -1,7 +1,3 @@ -from typing import TYPE_CHECKING -if not TYPE_CHECKING: - import rpyc - import time import logging @@ -45,6 +41,64 @@ def __init__(self): # RPC Server class AppRpcServer: + import rpyc + # RPC Service + class RpcService(rpyc.Service): + """ + RPC service is the service that RPyC is using + """ + def __init__(self, score_received_sem: Semaphore, client_ready_sem: Semaphore): + super().__init__() + self.__callback_score_prompt = None + self.__last_ping = None + self.__scores_received = [] + self.__score_received_sem = score_received_sem + self.__client_ready_sem = client_ready_sem + + def on_connect(self, conn): + logger.info("Client connected") + + def on_disconnect(self, conn): + logger.info("Client disconnected") + + def exposed_receive_score(self, score: Score): + logger.info(f"Score received: {score}") + self.__scores_received.append(score) + self.__score_received_sem.release() + + def exposed_receive_ping(self): + # A ping should be received every 2s from the client. If a client misses a ping then the server should stoped + self.__last_ping = time.time() + logger.debug("Ping received") + + def exposed_callback_score_prompt(self, callback: Callable[[PromptRequestPiece, Optional[str]], None]): + self.__callback_score_prompt = callback + self.__client_ready_sem.release() + + def is_client_ready(self): + if self.__callback_score_prompt is None: + return False + return True + + def send_score_prompt(self, prompt: PromptRequestPiece, task: Optional[str] = None): + if not self.is_client_ready(): + raise RpcClientNotReadyException() + self.__callback_score_prompt(prompt, task) + + def is_ping_missed(self): + if self.__last_ping is None: + return False + + return time.time() - self.__last_ping > 2 + + def pop_score_received(self) -> Score | None: + try: + return self.__scores_received.pop() + except IndexError: + return None + + + def __init__(self, open_browser: bool = False): self.__server = None self.__server_thread = None @@ -70,8 +124,8 @@ def start(self): self.__client_ready_sem = Semaphore(0) # Start the RPC server. - self.__rpc_service = RpcService(self.__score_received_sem, self.__client_ready_sem) - self.__server = rpyc.ThreadedServer(self.__rpc_service, port=DEFAULT_PORT, protocol_config={"allow_all_attrs": True}) + self.__rpc_service = self.RpcService(self.__score_received_sem, self.__client_ready_sem) + self.__server = self.rpyc.ThreadedServer(self.__rpc_service, port=DEFAULT_PORT, protocol_config={"allow_all_attrs": True}) self.__server_thread = Thread(target=self.__server.start) self.__server_thread.start() @@ -183,59 +237,4 @@ def __is_alive(self): logger.error("Ping missed. Stopping server.") self.stop_request() break - time.sleep(1) - -# RPC Service -class RpcService(rpyc.Service): - """ - RPC service is the service that RPyC is using - """ - def __init__(self, score_received_sem: Semaphore, client_ready_sem: Semaphore): - super().__init__() - self.__callback_score_prompt = None - self.__last_ping = None - self.__scores_received = [] - self.__score_received_sem = score_received_sem - self.__client_ready_sem = client_ready_sem - - def on_connect(self, conn): - logger.info("Client connected") - - def on_disconnect(self, conn): - logger.info("Client disconnected") - - def exposed_receive_score(self, score: Score): - logger.info(f"Score received: {score}") - self.__scores_received.append(score) - self.__score_received_sem.release() - - def exposed_receive_ping(self): - # A ping should be received every 2s from the client. If a client misses a ping then the server should stoped - self.__last_ping = time.time() - logger.debug("Ping received") - - def exposed_callback_score_prompt(self, callback: Callable[[PromptRequestPiece, Optional[str]], None]): - self.__callback_score_prompt = callback - self.__client_ready_sem.release() - - def is_client_ready(self): - if self.__callback_score_prompt is None: - return False - return True - - def send_score_prompt(self, prompt: PromptRequestPiece, task: Optional[str] = None): - if not self.is_client_ready(): - raise RpcClientNotReadyException() - self.__callback_score_prompt(prompt, task) - - def is_ping_missed(self): - if self.__last_ping is None: - return False - - return time.time() - self.__last_ping > 2 - - def pop_score_received(self) -> Score | None: - try: - return self.__scores_received.pop() - except IndexError: - return None \ No newline at end of file + time.sleep(1) \ No newline at end of file From 3ba7ca6f555e03d20fc634fb17fbb7cb20e96f81 Mon Sep 17 00:00:00 2001 From: Martin Pouliot Date: Tue, 18 Feb 2025 11:21:28 -0500 Subject: [PATCH 07/17] Changed global import to scoped import --- pyrit/score/human_in_the_loop_gradio.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pyrit/score/human_in_the_loop_gradio.py b/pyrit/score/human_in_the_loop_gradio.py index 97dbfe057..1a8eaca96 100644 --- a/pyrit/score/human_in_the_loop_gradio.py +++ b/pyrit/score/human_in_the_loop_gradio.py @@ -3,15 +3,14 @@ from pyrit.models import Score, PromptRequestPiece from typing import Optional -from pyrit.ui.rpc import AppRpcServer - - class HumanInTheLoopScorerGradio(Scorer): + # Import here to avoid importing rpyc in the main module that might not be installed + from pyrit.ui.rpc import AppRpcServer def __init__(self, *, open_browser=False, scorer: Scorer = None, re_scorers: list[Scorer] = None) -> None: self._scorer = scorer self._re_scorers = re_scorers - self._rpc_server = AppRpcServer(open_browser=open_browser) + self._rpc_server = self.AppRpcServer(open_browser=open_browser) self._rpc_server.start() From 53b9b9c76a77236dc07a467e48ae11b35930ffe6 Mon Sep 17 00:00:00 2001 From: Martin Pouliot Date: Tue, 18 Feb 2025 11:30:28 -0500 Subject: [PATCH 08/17] Fixed an import issue --- pyrit/score/human_in_the_loop_gradio.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pyrit/score/human_in_the_loop_gradio.py b/pyrit/score/human_in_the_loop_gradio.py index 1a8eaca96..676bf4df1 100644 --- a/pyrit/score/human_in_the_loop_gradio.py +++ b/pyrit/score/human_in_the_loop_gradio.py @@ -4,13 +4,14 @@ from typing import Optional class HumanInTheLoopScorerGradio(Scorer): - # Import here to avoid importing rpyc in the main module that might not be installed - from pyrit.ui.rpc import AppRpcServer def __init__(self, *, open_browser=False, scorer: Scorer = None, re_scorers: list[Scorer] = None) -> None: + # Import here to avoid importing rpyc in the main module that might not be installed + from pyrit.ui.rpc import AppRpcServer + self._scorer = scorer self._re_scorers = re_scorers - self._rpc_server = self.AppRpcServer(open_browser=open_browser) + self._rpc_server = AppRpcServer(open_browser=open_browser) self._rpc_server.start() From 4e279b149cd4ba79fba5a40bf5d672050be55f7d Mon Sep 17 00:00:00 2001 From: Martin Pouliot Date: Tue, 18 Feb 2025 11:36:58 -0500 Subject: [PATCH 09/17] Added HumanInTheLoopScorerGradio to doc --- doc/api.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/api.rst b/doc/api.rst index 3921926fc..68ce0f8c5 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -365,6 +365,7 @@ API Reference FloatScaleThresholdScorer GandalfScorer HumanInTheLoopScorer + HumanInTheLoopScorerGradio LikertScalePaths MarkdownInjectionScorer PromptShieldScorer From 6616eee36284c35703782b9f1e4d59469ff64501 Mon Sep 17 00:00:00 2001 From: Martin Pouliot Date: Mon, 3 Mar 2025 09:47:14 -0500 Subject: [PATCH 10/17] Added missing copyright --- pyrit/score/human_in_the_loop_gradio.py | 3 +++ pyrit/ui/__init__.py | 2 ++ pyrit/ui/app.py | 3 +++ pyrit/ui/connection_status.py | 3 +++ pyrit/ui/rpc.py | 3 +++ pyrit/ui/rpc_client.py | 3 +++ pyrit/ui/scorer.py | 3 +++ 7 files changed, 20 insertions(+) diff --git a/pyrit/score/human_in_the_loop_gradio.py b/pyrit/score/human_in_the_loop_gradio.py index 676bf4df1..9f1c4422f 100644 --- a/pyrit/score/human_in_the_loop_gradio.py +++ b/pyrit/score/human_in_the_loop_gradio.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + import asyncio from pyrit.score.scorer import Scorer from pyrit.models import Score, PromptRequestPiece diff --git a/pyrit/ui/__init__.py b/pyrit/ui/__init__.py index e69de29bb..b14b47650 100644 --- a/pyrit/ui/__init__.py +++ b/pyrit/ui/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. diff --git a/pyrit/ui/app.py b/pyrit/ui/app.py index 4fba66962..c9358f3c8 100644 --- a/pyrit/ui/app.py +++ b/pyrit/ui/app.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + import os import sys import subprocess diff --git a/pyrit/ui/connection_status.py b/pyrit/ui/connection_status.py index 161ade692..5078bdaf0 100644 --- a/pyrit/ui/connection_status.py +++ b/pyrit/ui/connection_status.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + import gradio as gr from rpc_client import RpcClient diff --git a/pyrit/ui/rpc.py b/pyrit/ui/rpc.py index 1fd8e2a2f..d9ccc3658 100644 --- a/pyrit/ui/rpc.py +++ b/pyrit/ui/rpc.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + import time import logging diff --git a/pyrit/ui/rpc_client.py b/pyrit/ui/rpc_client.py index 4a556f5b5..a008fcf1c 100644 --- a/pyrit/ui/rpc_client.py +++ b/pyrit/ui/rpc_client.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + import rpyc import time import socket diff --git a/pyrit/ui/scorer.py b/pyrit/ui/scorer.py index 606bebf8f..18084f597 100644 --- a/pyrit/ui/scorer.py +++ b/pyrit/ui/scorer.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + import gradio as gr import webview from rpc_client import RpcClient From 097031447b38b1669f994f1f45a90a2d621c4da8 Mon Sep 17 00:00:00 2001 From: Martin Pouliot Date: Mon, 3 Mar 2025 09:53:32 -0500 Subject: [PATCH 11/17] Added docstring to constructor --- pyrit/score/human_in_the_loop_gradio.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pyrit/score/human_in_the_loop_gradio.py b/pyrit/score/human_in_the_loop_gradio.py index 9f1c4422f..71abbd75f 100644 --- a/pyrit/score/human_in_the_loop_gradio.py +++ b/pyrit/score/human_in_the_loop_gradio.py @@ -7,7 +7,15 @@ from typing import Optional class HumanInTheLoopScorerGradio(Scorer): + """ + Create scores from manual human input using Gradio and adds them to the database. + Parameters: + scorer (Scorer): The scorer to use for the initial scoring. + re_scorers (list[Scorer]): The scorers to use for re-scoring. + open_browser(bool): The scorer will open the Gradio interface in a browser instead of opening it in PyWebview + """ + def __init__(self, *, open_browser=False, scorer: Scorer = None, re_scorers: list[Scorer] = None) -> None: # Import here to avoid importing rpyc in the main module that might not be installed from pyrit.ui.rpc import AppRpcServer From 9a45e55deae6e31cfb467a3cad31f4b17dcf6a8e Mon Sep 17 00:00:00 2001 From: Martin Pouliot Date: Mon, 3 Mar 2025 10:02:58 -0500 Subject: [PATCH 12/17] Changed RPC capitalization --- pyrit/score/human_in_the_loop_gradio.py | 4 ++-- pyrit/ui/connection_status.py | 4 ++-- pyrit/ui/rpc.py | 28 ++++++++++++------------- pyrit/ui/rpc_client.py | 2 +- pyrit/ui/scorer.py | 4 ++-- 5 files changed, 21 insertions(+), 21 deletions(-) diff --git a/pyrit/score/human_in_the_loop_gradio.py b/pyrit/score/human_in_the_loop_gradio.py index 71abbd75f..6e8ce89b5 100644 --- a/pyrit/score/human_in_the_loop_gradio.py +++ b/pyrit/score/human_in_the_loop_gradio.py @@ -18,11 +18,11 @@ class HumanInTheLoopScorerGradio(Scorer): def __init__(self, *, open_browser=False, scorer: Scorer = None, re_scorers: list[Scorer] = None) -> None: # Import here to avoid importing rpyc in the main module that might not be installed - from pyrit.ui.rpc import AppRpcServer + from pyrit.ui.rpc import AppRPCServer self._scorer = scorer self._re_scorers = re_scorers - self._rpc_server = AppRpcServer(open_browser=open_browser) + self._rpc_server = AppRPCServer(open_browser=open_browser) self._rpc_server.start() diff --git a/pyrit/ui/connection_status.py b/pyrit/ui/connection_status.py index 5078bdaf0..403645ea8 100644 --- a/pyrit/ui/connection_status.py +++ b/pyrit/ui/connection_status.py @@ -3,12 +3,12 @@ import gradio as gr -from rpc_client import RpcClient +from rpc_client import RPCClient class ConnectionStatusHandler: def __init__(self, is_connected_state: gr.State, - rpc_client: RpcClient): + rpc_client: RPCClient): self.state = is_connected_state self.server_disconnected = False self.rpc_client = rpc_client diff --git a/pyrit/ui/rpc.py b/pyrit/ui/rpc.py index d9ccc3658..919315b21 100644 --- a/pyrit/ui/rpc.py +++ b/pyrit/ui/rpc.py @@ -16,25 +16,25 @@ logger = logging.getLogger(__name__) # Exceptions -class RpcAppException(Exception): +class RPCAppException(Exception): def __init__(self, message: str): super().__init__(message) -class RpcAlreadyRunningException(RpcAppException): +class RPCAlreadyRunningException(RPCAppException): """ This exception is thrown when an RPC server is already running and the user tries to start another one. """ def __init__(self): super().__init__("RPC server is already running.") -class RpcClientNotReadyException(RpcAppException): +class RPCClientNotReadyException(RPCAppException): """ This exception is thrown when the RPC client is not ready to receive messages. """ def __init__(self): super().__init__("RPC client is not ready.") -class RpcServerStoppedException(RpcAppException): +class RPCServerStoppedException(RPCAppException): """ This exception is thrown when the RPC server is stopped. """ @@ -43,10 +43,10 @@ def __init__(self): # RPC Server -class AppRpcServer: +class AppRPCServer: import rpyc # RPC Service - class RpcService(rpyc.Service): + class RPCService(rpyc.Service): """ RPC service is the service that RPyC is using """ @@ -85,7 +85,7 @@ def is_client_ready(self): def send_score_prompt(self, prompt: PromptRequestPiece, task: Optional[str] = None): if not self.is_client_ready(): - raise RpcClientNotReadyException() + raise RPCClientNotReadyException() self.__callback_score_prompt(prompt, task) def is_ping_missed(self): @@ -121,13 +121,13 @@ def start(self): # Check if the server is already running by checking if the port is already in use. # If the port is already in use, throw an exception. if self.__is_instance_running(): - raise RpcAlreadyRunningException() + raise RPCAlreadyRunningException() self.__score_received_sem = Semaphore(0) self.__client_ready_sem = Semaphore(0) # Start the RPC server. - self.__rpc_service = self.RpcService(self.__score_received_sem, self.__client_ready_sem) + self.__rpc_service = self.RPCService(self.__score_received_sem, self.__client_ready_sem) self.__server = self.rpyc.ThreadedServer(self.__rpc_service, port=DEFAULT_PORT, protocol_config={"allow_all_attrs": True}) self.__server_thread = Thread(target=self.__server.start) self.__server_thread.start() @@ -185,7 +185,7 @@ def send_score_prompt(self, prompt: PromptRequestPiece, task: Optional[str] = No Send a score prompt to the client. """ if self.__rpc_service is None: - raise RpcAppException("RPC server is not running.") + raise RPCAppException("RPC server is not running.") self.__rpc_service.send_score_prompt(prompt, task) @@ -194,11 +194,11 @@ def wait_for_score(self) -> Score: Wait for the client to send a score. Should always return a score, but if the synchronisation fails it will return None. """ if self.__score_received_sem is None or self.__rpc_service is None: - raise RpcAppException("RPC server is not running.") + raise RPCAppException("RPC server is not running.") self.__score_received_sem.acquire() if not self.__server_is_running: - raise RpcServerStoppedException() + raise RPCServerStoppedException() score = self.__rpc_service.pop_score_received() if score is None: @@ -212,14 +212,14 @@ def wait_for_client(self): Wait for the client to be ready to receive messages. """ if self.__client_ready_sem is None: - raise RpcAppException("RPC server is not running.") + raise RPCAppException("RPC server is not running.") logger.info("Waiting for client to be ready") self.__client_ready_sem.acquire() if not self.__server_is_running: - raise RpcServerStoppedException() + raise RPCServerStoppedException() logger.info("Client is ready") diff --git a/pyrit/ui/rpc_client.py b/pyrit/ui/rpc_client.py index a008fcf1c..9d0432227 100644 --- a/pyrit/ui/rpc_client.py +++ b/pyrit/ui/rpc_client.py @@ -11,7 +11,7 @@ DEFAULT_PORT = 18812 -class RpcClient: +class RPCClient: def __init__(self, callback_disconnected: Optional[Callable] = None): self.__c = None self.__bgsrv = None diff --git a/pyrit/ui/scorer.py b/pyrit/ui/scorer.py index 18084f597..3c3a0705f 100644 --- a/pyrit/ui/scorer.py +++ b/pyrit/ui/scorer.py @@ -3,13 +3,13 @@ import gradio as gr import webview -from rpc_client import RpcClient +from rpc_client import RPCClient from connection_status import ConnectionStatusHandler class GradioApp: def __init__(self): self.i = 0 - self.rpc_client = RpcClient(self.__disconnected_rpc_callback) + self.rpc_client = RPCClient(self.__disconnected_rpc_callback) self.connect_status = None self.url = "" From bf0931e0ec1b0546c33b3e37bdca517b40d64ec7 Mon Sep 17 00:00:00 2001 From: Martin Pouliot Date: Mon, 3 Mar 2025 10:14:43 -0500 Subject: [PATCH 13/17] Added RPC code description --- pyrit/ui/rpc.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyrit/ui/rpc.py b/pyrit/ui/rpc.py index 919315b21..56371bc38 100644 --- a/pyrit/ui/rpc.py +++ b/pyrit/ui/rpc.py @@ -48,7 +48,10 @@ class AppRPCServer: # RPC Service class RPCService(rpyc.Service): """ - RPC service is the service that RPyC is using + RPC service is the service that RPyC is using. RPC (Remote Procedure Call) is a way to interact with code that + is hosted in another process or on an other machine. RPyC is a library that implements RPC and we are using to + exchange information between PyRIT's main process and Gradio's process. This way the interface is + independent of which PyRIT code is running the process. """ def __init__(self, score_received_sem: Semaphore, client_ready_sem: Semaphore): super().__init__() From e9959bea143e848c6949248faf8903548e1389db Mon Sep 17 00:00:00 2001 From: Martin Pouliot Date: Mon, 3 Mar 2025 10:19:47 -0500 Subject: [PATCH 14/17] Added a comment about Gradio aiofiles dependency --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 01a8b4439..0b3d99c93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ classifiers = [ requires-python = ">=3.10, <3.13" dependencies = [ "aioconsole>=0.7.1", - "aiofiles>=23.2.1", + "aiofiles==23.2.1", # Pin the version to downgrade aiofiles to make sure it works with Gradio. "appdirs>=1.4.0", "art==6.1.0", "azure-cognitiveservices-speech>=1.36.0", From ff7b5fbf5406eec95094327d382ca7f93bb2c851 Mon Sep 17 00:00:00 2001 From: Martin Pouliot Date: Mon, 3 Mar 2025 10:31:35 -0500 Subject: [PATCH 15/17] Changed coding style for private members --- pyrit/ui/connection_status.py | 12 ++-- pyrit/ui/rpc.py | 124 +++++++++++++++++----------------- pyrit/ui/rpc_client.py | 86 +++++++++++------------ pyrit/ui/scorer.py | 24 +++---- 4 files changed, 123 insertions(+), 123 deletions(-) diff --git a/pyrit/ui/connection_status.py b/pyrit/ui/connection_status.py index 403645ea8..ab8c9e78b 100644 --- a/pyrit/ui/connection_status.py +++ b/pyrit/ui/connection_status.py @@ -15,15 +15,15 @@ def __init__(self, self.next_prompt = "" def setup(self, main_interface: gr.Column, loading_animation: gr.Column, next_prompt_state: gr.State): - self.state.change(fn=self.__on_state_change, inputs=[self.state], outputs=[main_interface, loading_animation, next_prompt_state]) + self.state.change(fn=self._on_state_change, inputs=[self.state], outputs=[main_interface, loading_animation, next_prompt_state]) connection_status_timer = gr.Timer(1) connection_status_timer.tick( - fn=self.__check_connection_status, + fn=self._check_connection_status, inputs=[self.state], outputs=[self.state] ).then( - fn=self.__reconnect_if_needed, + fn=self._reconnect_if_needed, outputs=[self.state] ) @@ -36,19 +36,19 @@ def set_disconnected(self): def set_next_prompt(self, next_prompt: str): self.next_prompt = next_prompt - def __on_state_change(self, is_connected: bool): + def _on_state_change(self, is_connected: bool): print("Connection status changed to: ", is_connected, " - ", self.next_prompt) if is_connected: return [gr.Column(visible=True), gr.Row(visible=False), self.next_prompt] return [gr.Column(visible=False), gr.Row(visible=True), self.next_prompt] - def __check_connection_status(self, is_connected: bool): + def _check_connection_status(self, is_connected: bool): if self.server_disconnected or not is_connected: print("Gradio disconnected") return False return True - def __reconnect_if_needed(self): + def _reconnect_if_needed(self): if self.server_disconnected: print("Attempting to reconnect") self.rpc_client.reconnect() diff --git a/pyrit/ui/rpc.py b/pyrit/ui/rpc.py index 56371bc38..d803c83e7 100644 --- a/pyrit/ui/rpc.py +++ b/pyrit/ui/rpc.py @@ -55,11 +55,11 @@ class RPCService(rpyc.Service): """ def __init__(self, score_received_sem: Semaphore, client_ready_sem: Semaphore): super().__init__() - self.__callback_score_prompt = None - self.__last_ping = None - self.__scores_received = [] - self.__score_received_sem = score_received_sem - self.__client_ready_sem = client_ready_sem + self._callback_score_prompt = None + self._last_ping = None + self._scores_received = [] + self._score_received_sem = score_received_sem + self._client_ready_sem = client_ready_sem def on_connect(self, conn): logger.info("Client connected") @@ -69,52 +69,52 @@ def on_disconnect(self, conn): def exposed_receive_score(self, score: Score): logger.info(f"Score received: {score}") - self.__scores_received.append(score) - self.__score_received_sem.release() + self._scores_received.append(score) + self._score_received_sem.release() def exposed_receive_ping(self): # A ping should be received every 2s from the client. If a client misses a ping then the server should stoped - self.__last_ping = time.time() + self._last_ping = time.time() logger.debug("Ping received") def exposed_callback_score_prompt(self, callback: Callable[[PromptRequestPiece, Optional[str]], None]): - self.__callback_score_prompt = callback - self.__client_ready_sem.release() + self._callback_score_prompt = callback + self._client_ready_sem.release() def is_client_ready(self): - if self.__callback_score_prompt is None: + if self._callback_score_prompt is None: return False return True def send_score_prompt(self, prompt: PromptRequestPiece, task: Optional[str] = None): if not self.is_client_ready(): raise RPCClientNotReadyException() - self.__callback_score_prompt(prompt, task) + self._callback_score_prompt(prompt, task) def is_ping_missed(self): - if self.__last_ping is None: + if self._last_ping is None: return False - return time.time() - self.__last_ping > 2 + return time.time() - self._last_ping > 2 def pop_score_received(self) -> Score | None: try: - return self.__scores_received.pop() + return self._scores_received.pop() except IndexError: return None def __init__(self, open_browser: bool = False): - self.__server = None - self.__server_thread = None - self.__rpc_service = None - self.__is_alive_thread = None - self.__is_alive_stop = False - self.__score_received_sem = None - self.__client_ready_sem = None - self.__server_is_running = False - self.__open_browser = open_browser + self._server = None + self._server_thread = None + self._rpc_service = None + self._is_alive_thread = None + self._is_alive_stop = False + self._score_received_sem = None + self._client_ready_sem = None + self._server_is_running = False + self._open_browser = open_browser def start(self): """ @@ -123,30 +123,30 @@ def start(self): # Check if the server is already running by checking if the port is already in use. # If the port is already in use, throw an exception. - if self.__is_instance_running(): + if self._is_instance_running(): raise RPCAlreadyRunningException() - self.__score_received_sem = Semaphore(0) - self.__client_ready_sem = Semaphore(0) + self._score_received_sem = Semaphore(0) + self._client_ready_sem = Semaphore(0) # Start the RPC server. - self.__rpc_service = self.RPCService(self.__score_received_sem, self.__client_ready_sem) - self.__server = self.rpyc.ThreadedServer(self.__rpc_service, port=DEFAULT_PORT, protocol_config={"allow_all_attrs": True}) - self.__server_thread = Thread(target=self.__server.start) - self.__server_thread.start() + self._rpc_service = self.RPCService(self._score_received_sem, self._client_ready_sem) + self._server = self.rpyc.ThreadedServer(self._rpc_service, port=DEFAULT_PORT, protocol_config={"allow_all_attrs": True}) + self._server_thread = Thread(target=self._server.start) + self._server_thread.start() # Start a thread to check if the client is still alive - self.__is_alive_stop = False - self.__is_alive_thread = Thread(target=self.__is_alive) - self.__is_alive_thread.start() + self._is_alive_stop = False + self._is_alive_thread = Thread(target=self._is_alive) + self._is_alive_thread.start() - self.__server_is_running = True + self._server_is_running = True logger.info("RPC server started") if not is_app_running(): logger.info("Launching Gradio UI") - launch_app(open_browser=self.__open_browser) + launch_app(open_browser=self._open_browser) else: logger.info("Gradio UI is already running. Will not launch another instance.") @@ -155,12 +155,12 @@ def stop(self): Stop the RPC server and free up the listening port. """ self.stop_request() - if self.__server is not None: - self.__server_thread.join() + if self._server is not None: + self._server_thread.join() - if self.__is_alive_thread is not None: - self.__is_alive_thread.join() + if self._is_alive_thread is not None: + self._is_alive_thread.join() logger.info("RPC server stopped") @@ -170,63 +170,63 @@ def stop_request(self): """ logger.info("RPC server stopping") - if self.__server is not None: - self.__server.close() - self.__server = None + if self._server is not None: + self._server.close() + self._server = None - if self.__is_alive_thread is not None: - self.__is_alive_stop = True + if self._is_alive_thread is not None: + self._is_alive_stop = True - self.__server_is_running = False + self._server_is_running = False - self.__client_ready_sem.release() - self.__score_received_sem.release() + self._client_ready_sem.release() + self._score_received_sem.release() def send_score_prompt(self, prompt: PromptRequestPiece, task: Optional[str] = None): """ Send a score prompt to the client. """ - if self.__rpc_service is None: + if self._rpc_service is None: raise RPCAppException("RPC server is not running.") - self.__rpc_service.send_score_prompt(prompt, task) + self._rpc_service.send_score_prompt(prompt, task) def wait_for_score(self) -> Score: """ Wait for the client to send a score. Should always return a score, but if the synchronisation fails it will return None. """ - if self.__score_received_sem is None or self.__rpc_service is None: + if self._score_received_sem is None or self._rpc_service is None: raise RPCAppException("RPC server is not running.") - self.__score_received_sem.acquire() - if not self.__server_is_running: + self._score_received_sem.acquire() + if not self._server_is_running: raise RPCServerStoppedException() - score = self.__rpc_service.pop_score_received() + score = self._rpc_service.pop_score_received() if score is None: return None - self.__client_ready_sem.release() + self._client_ready_sem.release() return score def wait_for_client(self): """ Wait for the client to be ready to receive messages. """ - if self.__client_ready_sem is None: + if self._client_ready_sem is None: raise RPCAppException("RPC server is not running.") logger.info("Waiting for client to be ready") - self.__client_ready_sem.acquire() + self._client_ready_sem.acquire() - if not self.__server_is_running: + if not self._server_is_running: raise RPCServerStoppedException() logger.info("Client is ready") - def __is_instance_running(self): + def _is_instance_running(self): """ Check if the RPC server is running. """ @@ -234,12 +234,12 @@ def __is_instance_running(self): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: return s.connect_ex(('localhost', DEFAULT_PORT)) == 0 - def __is_alive(self): + def _is_alive(self): """ Check if a ping has been missed. If a ping has been missed, stop the server. """ - while not self.__is_alive_stop: - if self.__rpc_service.is_ping_missed(): + while not self._is_alive_stop: + if self._rpc_service.is_ping_missed(): logger.error("Ping missed. Stopping server.") self.stop_request() break diff --git a/pyrit/ui/rpc_client.py b/pyrit/ui/rpc_client.py index 9d0432227..b6d699191 100644 --- a/pyrit/ui/rpc_client.py +++ b/pyrit/ui/rpc_client.py @@ -13,33 +13,33 @@ class RPCClient: def __init__(self, callback_disconnected: Optional[Callable] = None): - self.__c = None - self.__bgsrv = None + self._c = None + self._bgsrv = None - self.__ping_thread = None - self.__bgsrv_thread = None - self.__is_running = False + self._ping_thread = None + self._bgsrv_thread = None + self._is_running = False - self.__shutdown_event = None - self.__prompt_received_sem = None + self._shutdown_event = None + self._prompt_received_sem = None - self.__prompt_received = None - self.__callback_disconnected = callback_disconnected + self._prompt_received = None + self._callback_disconnected = callback_disconnected def start(self): # Check if the port is open - self.__wait_for_server_avaible() - self.__prompt_received_sem = Semaphore(0) + self._wait_for_server_avaible() + self._prompt_received_sem = Semaphore(0) - self.__c = rpyc.connect("localhost", DEFAULT_PORT, config={'allow_public_attrs': True}) - self.__is_running = True - self.__shutdown_event = Event() - self.__bgsrv_thread = Thread(target=self.__bgsrv_lifecycle) - self.__bgsrv_thread.start() + self._c = rpyc.connect("localhost", DEFAULT_PORT, config={'allow_public_attrs': True}) + self._is_running = True + self._shutdown_event = Event() + self._bgsrv_thread = Thread(target=self._bgsrv_lifecycle) + self._bgsrv_thread.start() def wait_for_prompt(self) -> PromptRequestPiece: - self.__prompt_received_sem.acquire() - return self.__prompt_received + self._prompt_received_sem.acquire() + return self._prompt_received def send_prompt_response(self, response: bool): score = Score( @@ -48,14 +48,14 @@ def send_prompt_response(self, response: bool): score_category="safety", score_value_description="Safe" if response else "Unsafe", score_rationale="The prompt is safe" if response else "The prompt is unsafe", - score_metadata={"prompt_target_identifier": self.__prompt_received.prompt_target_identifier}, - prompt_request_response_id=self.__prompt_received.conversation_id + score_metadata={"prompt_target_identifier": self._prompt_received.prompt_target_identifier}, + prompt_request_response_id=self._prompt_received.conversation_id ) - self.__c.root.receive_score(score) + self._c.root.receive_score(score) - def __wait_for_server_avaible(self): + def _wait_for_server_avaible(self): # Wait for the server to be available - while not self.__is_server_running(): + while not self._is_server_running(): print("Server is not running. Waiting for server to start...") time.sleep(1) @@ -64,7 +64,7 @@ def stop(self): Stop the client. """ # Send a signal to the thread to stop - self.__shutdown_event.set() + self._shutdown_event.set() def reconnect(self): """ @@ -74,40 +74,40 @@ def reconnect(self): print("Reconnecting to server...") self.start() - def __receive_prompt(self, prompt_request: PromptRequestPiece, task: Optional[str] = None): + def _receive_prompt(self, prompt_request: PromptRequestPiece, task: Optional[str] = None): print(f"Received prompt: {prompt_request}") - self.__prompt_received = prompt_request - self.__prompt_received_sem.release() + self._prompt_received = prompt_request + self._prompt_received_sem.release() - def __ping(self): + def _ping(self): try: - while self.__is_running: - self.__c.root.receive_ping() + while self._is_running: + self._c.root.receive_ping() time.sleep(1.5) except EOFError: print("Connection closed") - if self.__callback_disconnected is not None: - self.__callback_disconnected() + if self._callback_disconnected is not None: + self._callback_disconnected() - def __bgsrv_lifecycle(self): - self.__bgsrv = rpyc.BgServingThread(self.__c) - self.__ping_thread = Thread(target=self.__ping) - self.__ping_thread.start() + def _bgsrv_lifecycle(self): + self._bgsrv = rpyc.BgServingThread(self._c) + self._ping_thread = Thread(target=self._ping) + self._ping_thread.start() # Register callback - self.__c.root.callback_score_prompt(self.__receive_prompt) + self._c.root.callback_score_prompt(self._receive_prompt) # Wait for the server to be disconnected - self.__shutdown_event.wait() + self._shutdown_event.wait() - self.__is_running = False - self.__ping_thread.join() + self._is_running = False + self._ping_thread.join() # Avoid calling stop() twice if the server is already stopped. This can happen if the server is stopped # by the ping request. - if self.__bgsrv._active: - self.__bgsrv.stop() + if self._bgsrv._active: + self._bgsrv.stop() - def __is_server_running(self): + def _is_server_running(self): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: return s.connect_ex(('localhost', DEFAULT_PORT)) == 0 \ No newline at end of file diff --git a/pyrit/ui/scorer.py b/pyrit/ui/scorer.py index 3c3a0705f..d50f26fd7 100644 --- a/pyrit/ui/scorer.py +++ b/pyrit/ui/scorer.py @@ -9,7 +9,7 @@ class GradioApp: def __init__(self): self.i = 0 - self.rpc_client = RPCClient(self.__disconnected_rpc_callback) + self.rpc_client = RPCClient(self._disconnected_rpc_callback) self.connect_status = None self.url = "" @@ -29,24 +29,24 @@ def start_gradio(self, open_browser=False): safe.click( fn=lambda: [gr.update(interactive=False)]*2 + [""],outputs=[safe, unsafe, next_prompt_state] ).then( - fn=self.__safe_clicked, outputs=next_prompt_state + fn=self._safe_clicked, outputs=next_prompt_state ) unsafe.click( fn=lambda: [gr.update(interactive=False)]*2 + [""], outputs=[safe, unsafe, next_prompt_state] ).then( - fn=self.__unsafe_clicked, outputs=next_prompt_state + fn=self._unsafe_clicked, outputs=next_prompt_state ) with gr.Row() as loading_animation: loading_text = gr.Markdown("Connecting to PyRIT") timer = gr.Timer(0.5) - timer.tick(fn=self.__loading_dots, outputs=loading_text) + timer.tick(fn=self._loading_dots, outputs=loading_text) - next_prompt_state.change(fn=self.__on_next_prompt_change, inputs=[next_prompt_state], outputs=[prompt, safe, unsafe]) + next_prompt_state.change(fn=self._on_next_prompt_change, inputs=[next_prompt_state], outputs=[prompt, safe, unsafe]) self.connect_status.setup(main_interface, loading_animation, next_prompt_state) demo.load( - fn=self.__main_inteface_loaded, + fn=self._main_inteface_loaded, outputs=[main_interface, loading_animation, next_prompt_state, is_connected] ) @@ -63,29 +63,29 @@ def start_gradio(self, open_browser=False): if self.rpc_client: self.rpc_client.stop() - def __safe_clicked(self): + def _safe_clicked(self): self.rpc_client.send_prompt_response(True) prompt_request = self.rpc_client.wait_for_prompt() return str(prompt_request.original_value) - def __unsafe_clicked(self): + def _unsafe_clicked(self): self.rpc_client.send_prompt_response(False) prompt_request = self.rpc_client.wait_for_prompt() return str(prompt_request.original_value) - def __on_next_prompt_change(self, next_prompt): + def _on_next_prompt_change(self, next_prompt): if next_prompt == "": return [gr.Markdown(f"Waiting for next prompt..."), gr.update(interactive=False), gr.update(interactive=False)] return [gr.Markdown("Prompt: " + next_prompt), gr.update(interactive=True), gr.update(interactive=True)] - def __loading_dots(self): + def _loading_dots(self): self.i = (self.i + 1) % 4 return gr.Markdown("Connecting to PyRIT" + "." * self.i) - def __disconnected_rpc_callback(self): + def _disconnected_rpc_callback(self): self.connect_status.set_disconnected() - def __main_inteface_loaded(self): + def _main_inteface_loaded(self): print("Showing main interface") self.rpc_client.start() prompt_request = self.rpc_client.wait_for_prompt() From 494121183b9138b64f81db6a80d1cf50c96260f5 Mon Sep 17 00:00:00 2001 From: Martin Pouliot Date: Mon, 3 Mar 2025 10:44:29 -0500 Subject: [PATCH 16/17] Extracted button click logic --- pyrit/ui/scorer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pyrit/ui/scorer.py b/pyrit/ui/scorer.py index d50f26fd7..3f599489e 100644 --- a/pyrit/ui/scorer.py +++ b/pyrit/ui/scorer.py @@ -64,12 +64,13 @@ def start_gradio(self, open_browser=False): self.rpc_client.stop() def _safe_clicked(self): - self.rpc_client.send_prompt_response(True) - prompt_request = self.rpc_client.wait_for_prompt() - return str(prompt_request.original_value) + return self._send_prompt_response(True) def _unsafe_clicked(self): - self.rpc_client.send_prompt_response(False) + return self._send_prompt_response(False) + + def _send_prompt_response(value): + self.rpc_client.send_prompt_response(value) prompt_request = self.rpc_client.wait_for_prompt() return str(prompt_request.original_value) From b4eb898a423e24ca312c9d3b2f4073f5a524ad33 Mon Sep 17 00:00:00 2001 From: Martin Pouliot Date: Mon, 3 Mar 2025 11:16:05 -0500 Subject: [PATCH 17/17] Changed functions to use kw-only args --- pyrit/ui/connection_status.py | 2 +- pyrit/ui/rpc.py | 36 ++++++++++++++++++----------------- pyrit/ui/scorer.py | 5 ++++- 3 files changed, 24 insertions(+), 19 deletions(-) diff --git a/pyrit/ui/connection_status.py b/pyrit/ui/connection_status.py index ab8c9e78b..92c74c4cb 100644 --- a/pyrit/ui/connection_status.py +++ b/pyrit/ui/connection_status.py @@ -14,7 +14,7 @@ def __init__(self, self.rpc_client = rpc_client self.next_prompt = "" - def setup(self, main_interface: gr.Column, loading_animation: gr.Column, next_prompt_state: gr.State): + def setup(self, *, main_interface: gr.Column, loading_animation: gr.Column, next_prompt_state: gr.State): self.state.change(fn=self._on_state_change, inputs=[self.state], outputs=[main_interface, loading_animation, next_prompt_state]) connection_status_timer = gr.Timer(1) diff --git a/pyrit/ui/rpc.py b/pyrit/ui/rpc.py index d803c83e7..110107abf 100644 --- a/pyrit/ui/rpc.py +++ b/pyrit/ui/rpc.py @@ -53,13 +53,13 @@ class RPCService(rpyc.Service): exchange information between PyRIT's main process and Gradio's process. This way the interface is independent of which PyRIT code is running the process. """ - def __init__(self, score_received_sem: Semaphore, client_ready_sem: Semaphore): + def __init__(self, *, score_received_semaphore: Semaphore, client_ready_semaphore: Semaphore): super().__init__() self._callback_score_prompt = None self._last_ping = None self._scores_received = [] - self._score_received_sem = score_received_sem - self._client_ready_sem = client_ready_sem + self._score_received_semaphore = score_received_semaphore + self._client_ready_semaphore = client_ready_semaphore def on_connect(self, conn): logger.info("Client connected") @@ -70,7 +70,7 @@ def on_disconnect(self, conn): def exposed_receive_score(self, score: Score): logger.info(f"Score received: {score}") self._scores_received.append(score) - self._score_received_sem.release() + self._score_received_semaphore.release() def exposed_receive_ping(self): # A ping should be received every 2s from the client. If a client misses a ping then the server should stoped @@ -79,7 +79,7 @@ def exposed_receive_ping(self): def exposed_callback_score_prompt(self, callback: Callable[[PromptRequestPiece, Optional[str]], None]): self._callback_score_prompt = callback - self._client_ready_sem.release() + self._client_ready_semaphore.release() def is_client_ready(self): if self._callback_score_prompt is None: @@ -111,8 +111,8 @@ def __init__(self, open_browser: bool = False): self._rpc_service = None self._is_alive_thread = None self._is_alive_stop = False - self._score_received_sem = None - self._client_ready_sem = None + self._score_received_semaphore = None + self._client_ready_semaphore = None self._server_is_running = False self._open_browser = open_browser @@ -126,11 +126,13 @@ def start(self): if self._is_instance_running(): raise RPCAlreadyRunningException() - self._score_received_sem = Semaphore(0) - self._client_ready_sem = Semaphore(0) + self._score_received_semaphore = Semaphore(0) + self._client_ready_semaphore = Semaphore(0) # Start the RPC server. - self._rpc_service = self.RPCService(self._score_received_sem, self._client_ready_sem) + self._rpc_service = self.RPCService( + score_received_semaphore=self._score_received_semaphore, + client_ready_semaphore=self._client_ready_semaphore) self._server = self.rpyc.ThreadedServer(self._rpc_service, port=DEFAULT_PORT, protocol_config={"allow_all_attrs": True}) self._server_thread = Thread(target=self._server.start) self._server_thread.start() @@ -180,8 +182,8 @@ def stop_request(self): self._server_is_running = False - self._client_ready_sem.release() - self._score_received_sem.release() + self._client_ready_semaphore.release() + self._score_received_semaphore.release() def send_score_prompt(self, prompt: PromptRequestPiece, task: Optional[str] = None): """ @@ -196,10 +198,10 @@ def wait_for_score(self) -> Score: """ Wait for the client to send a score. Should always return a score, but if the synchronisation fails it will return None. """ - if self._score_received_sem is None or self._rpc_service is None: + if self._score_received_semaphore is None or self._rpc_service is None: raise RPCAppException("RPC server is not running.") - self._score_received_sem.acquire() + self._score_received_semaphore.acquire() if not self._server_is_running: raise RPCServerStoppedException() @@ -207,19 +209,19 @@ def wait_for_score(self) -> Score: if score is None: return None - self._client_ready_sem.release() + self._client_ready_semaphore.release() return score def wait_for_client(self): """ Wait for the client to be ready to receive messages. """ - if self._client_ready_sem is None: + if self._client_ready_semaphore is None: raise RPCAppException("RPC server is not running.") logger.info("Waiting for client to be ready") - self._client_ready_sem.acquire() + self._client_ready_semaphore.acquire() if not self._server_is_running: raise RPCServerStoppedException() diff --git a/pyrit/ui/scorer.py b/pyrit/ui/scorer.py index 3f599489e..64d6f2614 100644 --- a/pyrit/ui/scorer.py +++ b/pyrit/ui/scorer.py @@ -43,7 +43,10 @@ def start_gradio(self, open_browser=False): timer.tick(fn=self._loading_dots, outputs=loading_text) next_prompt_state.change(fn=self._on_next_prompt_change, inputs=[next_prompt_state], outputs=[prompt, safe, unsafe]) - self.connect_status.setup(main_interface, loading_animation, next_prompt_state) + self.connect_status.setup( + main_interface=main_interface, + loading_animation=loading_animation, + next_prompt_state=next_prompt_state) demo.load( fn=self._main_inteface_loaded,