diff --git a/prompt2model/utils/api_tools.py b/prompt2model/utils/api_tools.py
index fb3e8d17f..a40151f53 100644
--- a/prompt2model/utils/api_tools.py
+++ b/prompt2model/utils/api_tools.py
@@ -141,80 +141,81 @@ async def generate_batch_completion(
         Returns:
             List of generated responses.
         """
-        openai.aiosession.set(ClientSession())
-        limiter = aiolimiter.AsyncLimiter(requests_per_minute)
+        async with ClientSession() as _:
+            limiter = aiolimiter.AsyncLimiter(requests_per_minute)
 
-        async def _throttled_completion_acreate(
-            model: str,
-            messages: list[dict[str, str]],
-            temperature: float,
-            max_tokens: int,
-            n: int,
-            top_p: float,
-            limiter: aiolimiter.AsyncLimiter,
-        ):
-            async with limiter:
-                for _ in range(3):
-                    try:
-                        return await acompletion(
-                            model=model,
-                            messages=messages,
-                            api_base=self.api_base,
-                            temperature=temperature,
-                            max_tokens=max_tokens,
-                            n=n,
-                            top_p=top_p,
-                        )
-                    except tuple(ERROR_ERRORS_TO_MESSAGES.keys()) as e:
-                        if isinstance(
-                            e,
-                            (
-                                openai.APIStatusError,
-                                openai.APIError,
-                            ),
-                        ):
-                            logging.warning(
-                                ERROR_ERRORS_TO_MESSAGES[type(e)].format(e=e)
+            async def _throttled_completion_acreate(
+                model: str,
+                messages: list[dict[str, str]],
+                temperature: float,
+                max_tokens: int,
+                n: int,
+                top_p: float,
+                limiter: aiolimiter.AsyncLimiter,
+            ):
+                async with limiter:
+                    for _ in range(3):
+                        try:
+                            return await acompletion(
+                                model=model,
+                                messages=messages,
+                                api_base=self.api_base,
+                                temperature=temperature,
+                                max_tokens=max_tokens,
+                                n=n,
+                                top_p=top_p,
                             )
-                        elif isinstance(e, openai.BadRequestError):
-                            logging.warning(ERROR_ERRORS_TO_MESSAGES[type(e)])
-                            return {
-                                "choices": [
-                                    {
-                                        "message": {
-                                            "content": "Invalid Request: Prompt was filtered"  # noqa E501
+                        except tuple(ERROR_ERRORS_TO_MESSAGES.keys()) as e:
+                            if isinstance(
+                                e,
+                                (
+                                    openai.APIStatusError,
+                                    openai.APIError,
+                                ),
+                            ):
+                                logging.warning(
+                                    ERROR_ERRORS_TO_MESSAGES[type(e)].format(e=e)
+                                )
+                            elif isinstance(e, openai.BadRequestError):
+                                logging.warning(ERROR_ERRORS_TO_MESSAGES[type(e)])
+                                return {
+                                    "choices": [
+                                        {
+                                            "message": {
+                                                "content": "Invalid Request: Prompt was filtered"  # noqa E501
+                                            }
                                         }
-                                    }
-                                ]
-                            }
-                        else:
-                            logging.warning(ERROR_ERRORS_TO_MESSAGES[type(e)])
-                await asyncio.sleep(10)
-            return {"choices": [{"message": {"content": ""}}]}
+                                    ]
+                                }
+                            else:
+                                logging.warning(ERROR_ERRORS_TO_MESSAGES[type(e)])
+                    await asyncio.sleep(10)
+                return {"choices": [{"message": {"content": ""}}]}
 
-        num_prompt_tokens = max(count_tokens_from_string(prompt) for prompt in prompts)
-        if self.max_tokens:
-            max_tokens = self.max_tokens - num_prompt_tokens - token_buffer
-        else:
-            max_tokens = 3 * num_prompt_tokens
-
-        async_responses = [
-            _throttled_completion_acreate(
-                model=self.model_name,
-                messages=[
-                    {"role": "user", "content": f"{prompt}"},
-                ],
-                temperature=temperature,
-                max_tokens=max_tokens,
-                n=responses_per_request,
-                top_p=1,
-                limiter=limiter,
+            num_prompt_tokens = max(
+                count_tokens_from_string(prompt) for prompt in prompts
             )
-            for prompt in prompts
-        ]
-        responses = await tqdm_asyncio.gather(*async_responses)
+            if self.max_tokens:
+                max_tokens = self.max_tokens - num_prompt_tokens - token_buffer
+            else:
+                max_tokens = 3 * num_prompt_tokens
+
+            async_responses = [
+                _throttled_completion_acreate(
+                    model=self.model_name,
+                    messages=[
+                        {"role": "user", "content": f"{prompt}"},
+                    ],
+                    temperature=temperature,
+                    max_tokens=max_tokens,
+                    n=responses_per_request,
+                    top_p=1,
+                    limiter=limiter,
+                )
+                for prompt in prompts
+            ]
+            responses = await tqdm_asyncio.gather(*async_responses)
         # Note: will never be none because it's set, but mypy doesn't know that.
-        await openai.aiosession.get().close()
         return responses
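
For context, a minimal usage sketch of the refactored coroutine (not part of the diff). It assumes the enclosing class is prompt2model's APIAgent and that generate_batch_completion keeps the parameters visible in the hunk above (prompts, temperature, responses_per_request, requests_per_minute); the constructor argument, model name, and prompts are placeholders, not values from this change.

# Hypothetical usage sketch; APIAgent constructor arguments are assumptions.
import asyncio

from prompt2model.utils.api_tools import APIAgent

agent = APIAgent(model_name="gpt-3.5-turbo")
responses = asyncio.run(
    agent.generate_batch_completion(
        prompts=["First example prompt", "Second example prompt"],
        temperature=0.7,
        responses_per_request=1,
        requests_per_minute=80,
    )
)
# Each element is a completion-style response; per the retry loop above, after
# three failed attempts the helper substitutes
# {"choices": [{"message": {"content": ""}}]} for that prompt.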