diff --git a/flexeval/core/language_model/vllm_serve_lm.py b/flexeval/core/language_model/vllm_serve_lm.py
index 7bbd09c7..ec469074 100644
--- a/flexeval/core/language_model/vllm_serve_lm.py
+++ b/flexeval/core/language_model/vllm_serve_lm.py
@@ -7,17 +7,19 @@
 import threading
 import time
 from collections.abc import Callable
-from typing import IO, Any
+from typing import IO, TYPE_CHECKING, Any
 
 import requests
 import torch
 from loguru import logger
 
 from flexeval.core.language_model.base import LMOutput
-from flexeval.core.string_processor import StringProcessor
 
 from .openai_api import OpenAIChatAPI
 
+if TYPE_CHECKING:
+    from flexeval.core.string_processor import StringProcessor
+
 
 def find_free_port() -> int:
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@@ -184,6 +186,9 @@ def __init__(
         model_limit_new_tokens: int | None = None,
         tools: list[dict[str, Any]] | None = None,
         max_parallel_requests: int | None = None,
+        max_num_trials: int = 5,
+        first_wait_time: int = 1,
+        max_wait_time: int = 1,
     ) -> None:
         logging.getLogger("httpx").setLevel(logging.WARNING)
         logging.getLogger("httpcore").setLevel(logging.WARNING)
@@ -204,6 +209,9 @@
             model_limit_new_tokens=model_limit_new_tokens,
             tools=tools,
             max_parallel_requests=max_parallel_requests,
+            max_num_trials=max_num_trials,
+            first_wait_time=first_wait_time,
+            max_wait_time=max_wait_time,
     )

    @staticmethod
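
Note: this patch makes two changes. First, the StringProcessor import moves behind TYPE_CHECKING, since it is only needed for type annotations and the runtime import can be skipped. Second, three retry parameters (max_num_trials, first_wait_time, max_wait_time) are exposed on the constructor and forwarded to OpenAIChatAPI. The sketch below illustrates the backoff semantics such parameters conventionally configure; the parameter names and defaults are taken from the diff, but the loop body is an assumption, not the actual OpenAIChatAPI implementation.

    # A minimal sketch, assuming OpenAIChatAPI retries failed requests with
    # capped exponential backoff. Parameter names/defaults come from the diff;
    # the loop itself is illustrative, not the library's real implementation.
    import time
    from collections.abc import Callable
    from typing import TypeVar

    T = TypeVar("T")

    def call_with_retries(
        request_fn: Callable[[], T],
        max_num_trials: int = 5,
        first_wait_time: int = 1,
        max_wait_time: int = 1,
    ) -> T:
        wait_time = first_wait_time
        for trial in range(max_num_trials):
            try:
                return request_fn()
            except Exception:
                if trial == max_num_trials - 1:
                    raise  # out of trials: surface the last error
                time.sleep(wait_time)
                # Grow the wait geometrically, capped at max_wait_time.
                wait_time = min(wait_time * 2, max_wait_time)
        raise AssertionError("unreachable")

With the defaults in the diff (first_wait_time=1, max_wait_time=1), the cap reduces this to a constant one-second wait between up to five attempts; raising max_wait_time would allow the wait to actually grow.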