diff --git a/docs/examples/conftest.py b/docs/examples/conftest.py index 430417ff..2fde57e2 100644 --- a/docs/examples/conftest.py +++ b/docs/examples/conftest.py @@ -12,6 +12,7 @@ "simple_rag_with_filter.py", "mcp_example.py", "client.py", + "pii_serve.py", } diff --git a/mellea/backends/__init__.py b/mellea/backends/__init__.py index bda4fd44..4a56665a 100644 --- a/mellea/backends/__init__.py +++ b/mellea/backends/__init__.py @@ -34,7 +34,7 @@ def __init__( self.model_options = model_options if model_options is not None else {} @abc.abstractmethod - def generate_from_context( + async def generate_from_context( self, action: Component | CBlock, ctx: Context, @@ -58,7 +58,7 @@ def generate_from_context( ... @abc.abstractmethod - def generate_from_raw( + async def generate_from_raw( self, actions: list[Component | CBlock], ctx: Context, diff --git a/mellea/backends/dummy.py b/mellea/backends/dummy.py index bde21d8b..e8673cd6 100644 --- a/mellea/backends/dummy.py +++ b/mellea/backends/dummy.py @@ -16,7 +16,7 @@ def __init__(self, responses: list[str] | None): self.responses = responses self.idx = 0 - def generate_from_context( + async def generate_from_context( self, action: Component | CBlock, ctx: Context, diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index 1d82aeb0..e9407b53 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -182,7 +182,7 @@ def __init__( self._added_adapters: dict[str, LocalHFAdapter] = {} self._loaded_adapters: dict[str, LocalHFAdapter] = {} - def generate_from_context( + async def generate_from_context( self, action: Component | CBlock, ctx: Context, @@ -229,21 +229,23 @@ def generate_from_context( if reroute_to_alora: # Keep the alora requirement handling separate for now. 
- mot = self._generate_from_intrinsic( + mot = await self._generate_from_intrinsic( alora_action, ctx, model_options=model_opts ) return mot, ctx.add(alora_action).add(mot) elif isinstance(action, Intrinsic): - mot = self._generate_from_intrinsic(action, ctx, model_options=model_opts) + mot = await self._generate_from_intrinsic( + action, ctx, model_options=model_opts + ) return mot, ctx.add(action).add(mot) - mot = self._generate_from_context_standard( + mot = await self._generate_from_context_standard( action, ctx, _format=format, model_options=model_opts, tool_calls=tool_calls ) return mot, ctx.add(action).add(mot) - def _generate_from_intrinsic( + async def _generate_from_intrinsic( self, action: Intrinsic, ctx: Context, *, model_options: dict[str, Any] ) -> ModelOutputThunk: if not ctx.is_chat_context: @@ -394,7 +396,7 @@ async def granite_common_processing( return output - def _generate_from_context_standard( + async def _generate_from_context_standard( self, action: Component | CBlock, ctx: Context, @@ -627,7 +629,7 @@ async def post_processing( mot._generate_log = generate_log - def generate_from_raw( + async def generate_from_raw( self, actions: list[Component | CBlock], ctx: Context, @@ -663,7 +665,8 @@ def generate_from_raw( ) if format is None: - outputs = self._model.generate( # type: ignore + outputs = await asyncio.to_thread( + self._model.generate, # type: ignore input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], return_dict_in_generate=True, @@ -681,7 +684,8 @@ def generate_from_raw( from outlines.processors import RegexLogitsProcessor from transformers import LogitsProcessorList - outputs = self._model.generate( # type: ignore + outputs = await asyncio.to_thread( + self._model.generate, # type: ignore input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], return_dict_in_generate=True, diff --git a/mellea/backends/litellm.py b/mellea/backends/litellm.py index 54b8f3c2..e7b9b9c8 100644 --- a/mellea/backends/litellm.py +++ b/mellea/backends/litellm.py @@ -110,7 +110,7 @@ def __init__( self._past_event_loops: set[int] = set() - def generate_from_context( + async def generate_from_context( self, action: Component | CBlock, ctx: Context, @@ -123,7 +123,7 @@ def generate_from_context( assert ctx.is_chat_context, NotImplementedError( "The Openai backend only supports chat-like contexts." ) - mot = self._generate_from_chat_context_standard( + mot = await self._generate_from_chat_context_standard( action, ctx, _format=format, @@ -231,7 +231,7 @@ def _make_backend_specific_and_remove( return backend_specific - def _generate_from_chat_context_standard( + async def _generate_from_chat_context_standard( self, action: Component | CBlock, ctx: Context, @@ -448,7 +448,7 @@ async def post_processing( "format": _format, "tools_available": tools, "tools_called": mot.tool_calls, - "seed": thinking, + "thinking": thinking, } generate_log.action = mot._action generate_log.result = mot @@ -474,7 +474,7 @@ def _extract_tools( FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}") return tools - def generate_from_raw( + async def generate_from_raw( self, actions: list[Component | CBlock], ctx: Context, @@ -484,7 +484,73 @@ def generate_from_raw( tool_calls: bool = False, ) -> list[ModelOutputThunk]: """Generate using the completions api. 
Gives the input provided to the model without templating.""" - raise NotImplementedError("This method is not implemented yet.") + extra_body = {} + if format is not None: + FancyLogger.get_logger().warning( + "The official OpenAI completion api does not accept response format / structured decoding; " + "it will be passed as an extra arg." + ) + + # Some versions (like vllm's version) of the OpenAI API support structured decoding for completions requests. + extra_body["guided_json"] = format.model_json_schema() + if tool_calls: + FancyLogger.get_logger().warning( + "The completion endpoint does not support tool calling." + ) + + # We don't do anything fancy for model_opts with generate from raw; litellm has too many potential options depending on provider. + model_opts = self._simplify_and_merge(model_options) + model_specific_options = self._make_backend_specific_and_remove(model_opts) + + if self._has_potential_event_loop_errors(): + FancyLogger().get_logger().warning( + "There is a known bug with litellm. This generation call may fail. If it does, you should ensure that you are either running only synchronous Mellea functions or running async Mellea functions from one asyncio.run() call." + ) + + prompts = [self.formatter.print(action) for action in actions] + + completion_response = await litellm.atext_completion( + model=self._model_id, prompt=prompts, **model_specific_options + ) + + # Necessary for type checker. + assert isinstance(completion_response, litellm.TextCompletionResponse) # type: ignore + + results = [] + date = datetime.datetime.now() + responses = completion_response.choices + if len(responses) != len(prompts): + FancyLogger().get_logger().error( + "litellm appears to have sent your batch request as a single message; this typically happens with providers like ollama that don't support batching" + ) + + for res, action, prompt in zip(responses, actions, prompts): + output = ModelOutputThunk(res.text) # type: ignore + output._context = None # There is no context for generate_from_raw for now + output._action = action + output._model_options = model_opts + output._meta = { + "litellm_chat_response": res.model_dump(), + "usage": completion_response.usage.model_dump() + if completion_response.usage + else None, + } + + self.formatter.parse(action, output) + + generate_log = GenerateLog() + generate_log.prompt = prompt + generate_log.backend = f"litellm::{self.model_id!s}" + generate_log.model_options = model_opts + generate_log.date = date + generate_log.model_output = completion_response + generate_log.extra = {"seed": model_opts.get("seed", None)} + generate_log.action = action + output._generate_log = generate_log + + results.append(output) + + return results def _extract_model_tool_requests( self, diff --git a/mellea/backends/ollama.py b/mellea/backends/ollama.py index 02d6d620..713acdd7 100644 --- a/mellea/backends/ollama.py +++ b/mellea/backends/ollama.py @@ -252,7 +252,7 @@ def _make_backend_specific_and_remove( ) return ModelOption.remove_special_keys(backend_specific) - def generate_from_context( + async def generate_from_context( self, action: Component | CBlock, ctx: Context, @@ -265,7 +265,7 @@ def generate_from_context( assert ctx.is_chat_context, ( "The ollama backend only supports chat-like contexts." 
) - mot = self.generate_from_chat_context( + mot = await self.generate_from_chat_context( action, ctx, _format=format, @@ -275,7 +275,7 @@ def generate_from_context( return mot, ctx.add(action).add(mot) - def generate_from_chat_context( + async def generate_from_chat_context( self, action: Component | CBlock, ctx: Context, @@ -375,6 +375,8 @@ def generate_from_chat_context( # This function should always be called from a running event loop so we don't have to worry about # scheduling the task to a specific event loop here. + + # Use `create_task` so that we don't have to specifically await this task before it starts executing. output._generate = asyncio.create_task( send_to_queue(chat_response, output._async_queue) ) @@ -385,7 +387,7 @@ def generate_from_chat_context( return output - def generate_from_raw( + async def generate_from_raw( self, actions: list[Component | CBlock], ctx: Context, @@ -410,27 +412,20 @@ def generate_from_raw( # See https://github.com/ollama/ollama/blob/main/docs/faq.md#how-does-ollama-handle-concurrent-requests. prompts = [self.formatter.print(action) for action in actions] - async def get_response(): - # Run async so that we can make use of Ollama's concurrency. - coroutines: list[Coroutine[Any, Any, ollama.GenerateResponse]] = [] - for prompt in prompts: - co = self._async_client.generate( - model=self._get_ollama_model_id(), - prompt=prompt, - raw=True, - think=model_opts.get(ModelOption.THINKING, None), - format=format.model_json_schema() if format is not None else None, - options=self._make_backend_specific_and_remove(model_opts), - ) - coroutines.append(co) - - responses = await asyncio.gather(*coroutines, return_exceptions=True) - return responses + # Run async so that we can make use of Ollama's concurrency. + coroutines: list[Coroutine[Any, Any, ollama.GenerateResponse]] = [] + for prompt in prompts: + co = self._async_client.generate( + model=self._get_ollama_model_id(), + prompt=prompt, + raw=True, + think=model_opts.get(ModelOption.THINKING, None), + format=format.model_json_schema() if format is not None else None, + options=self._make_backend_specific_and_remove(model_opts), + ) + coroutines.append(co) - # Run in the same event_loop like other Mellea async code called from a sync function. - responses: list[ollama.GenerateResponse | BaseException] = _run_async_in_thread( - get_response() - ) + responses = await asyncio.gather(*coroutines, return_exceptions=True) results = [] date = datetime.datetime.now() diff --git a/mellea/backends/openai.py b/mellea/backends/openai.py index d28766e9..7d1c70ab 100644 --- a/mellea/backends/openai.py +++ b/mellea/backends/openai.py @@ -284,7 +284,7 @@ def _make_backend_specific_and_remove( return model_opts - def generate_from_context( + async def generate_from_context( self, action: Component | CBlock, ctx: Context, @@ -297,7 +297,7 @@ def generate_from_context( assert ctx.is_chat_context, NotImplementedError( "The Openai backend only supports chat-like contexts." ) - return self.generate_from_chat_context( + return await self.generate_from_chat_context( action, ctx, _format=format, @@ -305,7 +305,7 @@ def generate_from_context( tool_calls=tool_calls, ) - def generate_from_chat_context( + async def generate_from_chat_context( self, action: Component | CBlock, ctx: Context, @@ -349,18 +349,18 @@ def generate_from_chat_context( if reroute_to_alora: # Keep the alora requirement handling separate for now. 
- mot = self._generate_from_intrinsic( + mot = await self._generate_from_intrinsic( alora_action, ctx, model_options=model_options ) return mot, ctx.add(alora_action).add(mot) elif isinstance(action, Intrinsic): - mot = self._generate_from_intrinsic( + mot = await self._generate_from_intrinsic( action, ctx, model_options=model_options ) return mot, ctx.add(action).add(mot) - mot = self._generate_from_chat_context_standard( + mot = await self._generate_from_chat_context_standard( action, ctx, _format=_format, @@ -369,7 +369,7 @@ def generate_from_chat_context( ) return mot, ctx.add(action).add(mot) - def _generate_from_intrinsic( + async def _generate_from_intrinsic( self, action: Intrinsic, ctx: Context, *, model_options: dict | None = None ) -> ModelOutputThunk: model_opts = self._simplify_and_merge( @@ -556,7 +556,7 @@ def messages_to_docs(msgs: list[Message]) -> list[dict[str, str]]: json_docs.append(json_doc) return json_docs - def _generate_from_chat_context_standard( + async def _generate_from_chat_context_standard( self, action: Component | CBlock, ctx: Context, @@ -770,7 +770,7 @@ async def post_processing( generate_log.result = mot mot._generate_log = generate_log - def generate_from_raw( + async def generate_from_raw( self, actions: list[Component | CBlock], ctx: Context, @@ -799,13 +799,15 @@ def generate_from_raw( prompts = [self.formatter.print(action) for action in actions] try: - completion_response: Completion = self._client.completions.create( - model=self._hf_model_id, - prompt=prompts, - extra_body=extra_body, - **self._make_backend_specific_and_remove( - model_opts, is_chat_context=False - ), + completion_response: Completion = ( + await self._async_client.completions.create( + model=self._hf_model_id, + prompt=prompts, + extra_body=extra_body, + **self._make_backend_specific_and_remove( + model_opts, is_chat_context=False + ), + ) ) # type: ignore except openai.BadRequestError as e: if openai_ollama_batching_error in e.message: @@ -822,8 +824,7 @@ def generate_from_raw( for response, action, prompt in zip( completion_response.choices, actions, prompts ): - output = ModelOutputThunk(None) - output.value = response.text + output = ModelOutputThunk(response.text) output._context = None # There is no context for generate_from_raw for now output._action = action output._model_options = model_opts diff --git a/mellea/backends/vllm.py b/mellea/backends/vllm.py index db335c93..f9d6a753 100644 --- a/mellea/backends/vllm.py +++ b/mellea/backends/vllm.py @@ -52,6 +52,8 @@ assert outlines, "outlines needs to be present to make outlines_core work" +format: None = None # typing this variable in order to shadow the global format function and ensure mypy checks for errors + class LocalVLLMBackend(FormatterBackend): """The LocalVLLMBackend uses vLLM's python interface for inference, and uses a Formatter to convert `Component`s into prompts. @@ -235,7 +237,7 @@ def _model(self) -> vllm.AsyncLLMEngine: return self._underlying_model - def generate_from_context( + async def generate_from_context( self, action: Component | CBlock, ctx: Context, @@ -251,22 +253,22 @@ def generate_from_context( # TODO: insert the alora code here. 
- mot = self._generate_from_context_standard( + mot = await self._generate_from_context_standard( action, ctx, - format=format, + _format=format, model_options=model_options, generate_logs=generate_logs, tool_calls=tool_calls, ) return mot, ctx.add(action).add(mot) - def _generate_from_context_standard( + async def _generate_from_context_standard( self, action: Component | CBlock, ctx: Context, *, - format: type[BaseModelSubclass] | None = None, + _format: type[BaseModelSubclass] | None = None, model_options: dict[str, Any], generate_logs: list[GenerateLog] | None = None, tool_calls: bool = False, @@ -281,7 +283,7 @@ def _generate_from_context_standard( # Append tool call information if applicable. tools: dict[str, Callable] = dict() if tool_calls: - if format: + if _format: FancyLogger.get_logger().warning( f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}" ) @@ -315,10 +317,10 @@ def _generate_from_context_standard( ), ) - if format is not None: + if _format is not None: # outlines.generate.json always parses the resulting json into a python dict. # We however want to keep it as a json string for later storing it in ModelOutputThunk - schema: dict[str, Any] = format.model_json_schema() + schema: dict[str, Any] = _format.model_json_schema() schema_json: str = json.dumps(schema) regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema( schema_json @@ -353,6 +355,7 @@ def _generate_from_context_standard( output._post_process = functools.partial( self.post_processing, conversation=ctx_as_chat, + _format=_format, tool_calls=tool_calls, tools=tools, seed=model_options.get(ModelOption.SEED, None), @@ -384,6 +387,7 @@ async def post_processing( self, mot: ModelOutputThunk, conversation: list[dict], + _format: type[BaseModelSubclass] | None, tool_calls: bool, tools: dict[str, Callable], seed, @@ -393,7 +397,7 @@ async def post_processing( assert mot.value is not None # Only scan for tools if we are not doing structured output and tool calls were provided to the model. - if format is None and tool_calls: + if _format is None and tool_calls: mot.tool_calls = to_tool_calls(tools, mot.value) assert mot._action is not None, ( @@ -413,7 +417,7 @@ async def post_processing( generate_log.date = datetime.datetime.now() generate_log.model_output = mot.value generate_log.extra = { - "format": format, + "format": _format, "tools_available": tools, "tools_called": mot.tool_calls, "seed": seed, @@ -423,7 +427,7 @@ async def post_processing( mot._generate_log = generate_log - def generate_from_raw( + async def generate_from_raw( self, actions: list[Component | CBlock], ctx: Context, @@ -473,12 +477,8 @@ async def generate(prompt, request_id): assert result_output.finished return result_output.outputs[0].text - async def generate_all(prompts): - tasks = [generate(p, f"{id(prompts)}-{i}") for i, p in enumerate(prompts)] - return await asyncio.gather(*tasks) - - # Allow calling this from async functions. 
- decoded_results = _run_async_in_thread(generate_all(prompts)) + tasks = [generate(p, f"{id(prompts)}-{i}") for i, p in enumerate(prompts)] + decoded_results = await asyncio.gather(*tasks) results = [ModelOutputThunk(value=text) for text in decoded_results] diff --git a/mellea/backends/watsonx.py b/mellea/backends/watsonx.py index 61192f0c..5821b446 100644 --- a/mellea/backends/watsonx.py +++ b/mellea/backends/watsonx.py @@ -236,7 +236,7 @@ def _make_backend_specific_and_remove( return model_opts - def generate_from_context( + async def generate_from_context( self, action: Component | CBlock, ctx: Context, @@ -249,7 +249,7 @@ def generate_from_context( assert ctx.is_chat_context, NotImplementedError( "The watsonx.ai backend only supports chat-like contexts." ) - mot = self.generate_from_chat_context( + mot = await self.generate_from_chat_context( action, ctx, _format=format, @@ -258,7 +258,7 @@ def generate_from_context( ) return mot, ctx.add(action).add(mot) - def generate_from_chat_context( + async def generate_from_chat_context( self, action: Component | CBlock, ctx: Context, @@ -480,7 +480,7 @@ async def post_processing( generate_log.action = mot._action mot._generate_log = generate_log - def generate_from_raw( + async def generate_from_raw( self, actions: list[Component | CBlock], ctx: Context, @@ -499,12 +499,14 @@ def generate_from_raw( prompts = [self.formatter.print(action) for action in actions] - responses = self._model.generate( + responses = await asyncio.to_thread( + self._model.generate, prompt=prompts, params=self._make_backend_specific_and_remove( model_opts, is_chat_context=False ), ) + results = [] date = datetime.datetime.now() diff --git a/mellea/stdlib/functional.py b/mellea/stdlib/functional.py index 5990edb2..e758be04 100644 --- a/mellea/stdlib/functional.py +++ b/mellea/stdlib/functional.py @@ -499,7 +499,7 @@ async def aact( "Calling the function with NO strategy BUT requirements. No requirement is being checked!" ) - result, new_ctx = backend.generate_from_context( + result, new_ctx = await backend.generate_from_context( action, ctx=context, format=format, diff --git a/mellea/stdlib/requirement.py b/mellea/stdlib/requirement.py index 4cef5e35..e5c4cad9 100644 --- a/mellea/stdlib/requirement.py +++ b/mellea/stdlib/requirement.py @@ -149,7 +149,7 @@ async def validate( # and its template gets populated with the output correctly. req_copy = copy(self) req_copy._output = last_output.value - llm_as_a_judge_result, val_ctx = backend.generate_from_context( + llm_as_a_judge_result, val_ctx = await backend.generate_from_context( req_copy, ctx, format=format, model_options=model_options ) await llm_as_a_judge_result.avalue() @@ -296,7 +296,7 @@ async def validate( # and its template gets populated with the output correctly. 
req_copy = copy(self) req_copy._output = last_output.value - llm_as_a_judge_result, val_ctx = backend.generate_from_context( + llm_as_a_judge_result, val_ctx = await backend.generate_from_context( req_copy, ctx, format=format, model_options=model_options ) await llm_as_a_judge_result.avalue() diff --git a/mellea/stdlib/safety/guardian.py b/mellea/stdlib/safety/guardian.py index b307723f..33834163 100644 --- a/mellea/stdlib/safety/guardian.py +++ b/mellea/stdlib/safety/guardian.py @@ -315,7 +315,7 @@ async def validate( # Use a CBlock for HuggingFace - it won't be added as a message action = CBlock("") # type: ignore - mot, val_ctx = self._backend.generate_from_context( + mot, val_ctx = await self._backend.generate_from_context( action, gctx, model_options=guardian_options ) await mot.avalue() diff --git a/mellea/stdlib/sampling/base.py b/mellea/stdlib/sampling/base.py index 17831c48..06401ec4 100644 --- a/mellea/stdlib/sampling/base.py +++ b/mellea/stdlib/sampling/base.py @@ -151,7 +151,7 @@ async def sample( flog.info(f"Running loop {loop_count} of {self.loop_budget}") # run a generation pass - result, result_ctx = backend.generate_from_context( + result, result_ctx = await backend.generate_from_context( next_action, ctx=next_context, format=format, diff --git a/mellea/stdlib/sampling/best_of_n.py b/mellea/stdlib/sampling/best_of_n.py index 376ff64b..be082159 100644 --- a/mellea/stdlib/sampling/best_of_n.py +++ b/mellea/stdlib/sampling/best_of_n.py @@ -108,7 +108,7 @@ async def sample( flog.info(f"Running loop {loop_count} of {self.loop_budget}") # run a generation pass - result, result_ctx = backend.generate_from_context( + result, result_ctx = await backend.generate_from_context( next_action, ctx=next_context, format=format, diff --git a/mellea/stdlib/sampling/budget_forcing.py b/mellea/stdlib/sampling/budget_forcing.py index 3aaefb90..692fae66 100644 --- a/mellea/stdlib/sampling/budget_forcing.py +++ b/mellea/stdlib/sampling/budget_forcing.py @@ -148,7 +148,7 @@ async def sample( "Only ollama backend supported with budget forcing" ) # run a generation pass with budget forcing - result = think_budget_forcing( + result = await think_budget_forcing( backend, next_action, ctx=context, diff --git a/mellea/stdlib/sampling_algos/budget_forcing_alg.py b/mellea/stdlib/sampling_algos/budget_forcing_alg.py index eae78e6d..2f0a4c01 100644 --- a/mellea/stdlib/sampling_algos/budget_forcing_alg.py +++ b/mellea/stdlib/sampling_algos/budget_forcing_alg.py @@ -8,7 +8,7 @@ from mellea.stdlib.base import CBlock, Component, Context, GenerateLog, ModelOutputThunk -def think_budget_forcing( # noqa: D417 +async def think_budget_forcing( # noqa: D417 backend: OllamaModelBackend, action: CBlock | Component, *, @@ -77,7 +77,7 @@ def think_budget_forcing( # noqa: D417 break model_options[ModelOption.MAX_NEW_TOKENS] = rem_toks - result = backend.generate_from_raw( + result = await backend.generate_from_raw( [CBlock(value=curr_prompt)], model_options=model_options, ctx=ctx, @@ -150,7 +150,7 @@ def think_budget_forcing( # noqa: D417 model_options.pop(ModelOption.MAX_NEW_TOKENS, None) # generate unconditionally # model_options["logprobs"] = 1 # To get number of generated tokens - result = backend.generate_from_raw( + result = await backend.generate_from_raw( [CBlock(curr_prompt)], model_options=model_options, ctx=ctx, diff --git a/test/backends/test_huggingface.py b/test/backends/test_huggingface.py index 79434097..c2a5497f 100644 --- a/test/backends/test_huggingface.py +++ b/test/backends/test_huggingface.py @@ 
-204,7 +204,7 @@ class Email(pydantic.BaseModel): @pytest.mark.qualitative -def test_generate_from_raw(session): +async def test_generate_from_raw(session): prompts = [ "what is 1+1?", "what is 2+2?", @@ -213,22 +213,23 @@ def test_generate_from_raw(session): "what is 4+2+2?", ] - results = session.backend.generate_from_raw( + results = await session.backend.generate_from_raw( actions=[CBlock(value=prompt) for prompt in prompts], ctx=session.ctx ) assert len(results) == len(prompts) + assert results[0].value is not None @pytest.mark.qualitative -def test_generate_from_raw_with_format(session): +async def test_generate_from_raw_with_format(session): prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] class Answer(pydantic.BaseModel): name: str value: int - results = session.backend.generate_from_raw( + results = await session.backend.generate_from_raw( actions=[CBlock(value=prompt) for prompt in prompts], format=Answer, ctx=session.ctx, @@ -248,10 +249,10 @@ class Answer(pydantic.BaseModel): @pytest.mark.qualitative async def test_async_parallel_requests(session): model_opts = {ModelOption.STREAM: True} - mot1, _ = session.backend.generate_from_context( + mot1, _ = await session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext(), model_options=model_opts ) - mot2, _ = session.backend.generate_from_context( + mot2, _ = await session.backend.generate_from_context( CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts ) @@ -283,7 +284,7 @@ async def test_async_parallel_requests(session): @pytest.mark.qualitative async def test_async_avalue(session): - mot1, _ = session.backend.generate_from_context( + mot1, _ = await session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext() ) m1_final_val = await mot1.avalue() diff --git a/test/backends/test_litellm_ollama.py b/test/backends/test_litellm_ollama.py index 5c1ddf86..f6f10807 100644 --- a/test/backends/test_litellm_ollama.py +++ b/test/backends/test_litellm_ollama.py @@ -144,13 +144,29 @@ def is_happy(text: str) -> bool: # should yield to true - but, of course, is model dependent assert h is True +async def test_generate_from_raw(session): + prompts = [ + "what is 1+1?", + "what is 2+2?", + "what is 3+3?", + "what is 4+4?", + "what is 4+2+2?", + ] + + results = await session.backend.generate_from_raw( + actions=[CBlock(value=prompt) for prompt in prompts], ctx=session.ctx + ) + + assert len(results) == 1, "ollama doesn't support batching; litellm should send a single message containing all prompts" + assert results[0].value is not None + async def test_async_parallel_requests(session): model_opts = {ModelOption.STREAM: True} - mot1, _ = session.backend.generate_from_context( + mot1, _ = await session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext(), model_options=model_opts ) - mot2, _ = session.backend.generate_from_context( + mot2, _ = await session.backend.generate_from_context( CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts ) @@ -181,7 +197,7 @@ async def test_async_parallel_requests(session): async def test_async_avalue(session): - mot1, _ = session.backend.generate_from_context( + mot1, _ = await session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext() ) m1_final_val = await mot1.avalue() diff --git a/test/backends/test_litellm_watsonx.py b/test/backends/test_litellm_watsonx.py index 95ecb2b0..76030f8e 100644 --- a/test/backends/test_litellm_watsonx.py +++ b/test/backends/test_litellm_watsonx.py @@ -3,6 +3,7 @@ 
from mellea import MelleaSession from mellea.backends.litellm import LiteLLMBackend +from mellea.stdlib.base import CBlock @pytest.fixture(scope="function") @@ -38,6 +39,23 @@ def test_multiple_sync_funcs(session): session.chat("second") +@pytest.mark.qualitative +async def test_generate_from_raw(session): + prompts = [ + "what is 1+1?", + "what is 2+2?", + "what is 3+3?", + "what is 4+2+2?", + ] + + results = await session.backend.generate_from_raw( + actions=[CBlock(value=prompt) for prompt in prompts], ctx=session.ctx + ) + + assert len(results) == 1, "litellm converts a batch request for watsonx into a single message" + assert results[0].value is not None + + @pytest.mark.qualitative @pytest.mark.xfail( reason="litellm has a bug with watsonx; once that is fixed, this should pass." diff --git a/test/backends/test_ollama.py b/test/backends/test_ollama.py index 5d9a0805..3e990430 100644 --- a/test/backends/test_ollama.py +++ b/test/backends/test_ollama.py @@ -97,25 +97,26 @@ class Email(pydantic.BaseModel): @pytest.mark.qualitative -def test_generate_from_raw(session): +async def test_generate_from_raw(session): prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] - results = session.backend.generate_from_raw( + results = await session.backend.generate_from_raw( actions=[CBlock(value=prompt) for prompt in prompts], ctx=session.ctx ) assert len(results) == len(prompts) + assert results[0].value is not None @pytest.mark.xfail(reason="ollama sometimes fails generated structured outputs") -def test_generate_from_raw_with_format(session): +async def test_generate_from_raw_with_format(session): prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] class Answer(pydantic.BaseModel): name: str value: int - results = session.backend.generate_from_raw( + results = await session.backend.generate_from_raw( actions=[CBlock(value=prompt) for prompt in prompts], ctx=session.ctx, format=Answer, @@ -134,10 +135,10 @@ class Answer(pydantic.BaseModel): async def test_async_parallel_requests(session): model_opts = {ModelOption.STREAM: True} - mot1, _ = session.backend.generate_from_context( + mot1, _ = await session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext(), model_options=model_opts ) - mot2, _ = session.backend.generate_from_context( + mot2, _ = await session.backend.generate_from_context( CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts ) @@ -168,7 +169,7 @@ async def test_async_parallel_requests(session): async def test_async_avalue(session): - mot1, _ = session.backend.generate_from_context( + mot1, _ = await session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext() ) m1_final_val = await mot1.avalue() diff --git a/test/backends/test_openai_ollama.py b/test/backends/test_openai_ollama.py index fab85105..24f656bc 100644 --- a/test/backends/test_openai_ollama.py +++ b/test/backends/test_openai_ollama.py @@ -2,6 +2,7 @@ import asyncio import os +import openai import pydantic import pytest from typing_extensions import Annotated @@ -112,15 +113,14 @@ class Email(pydantic.BaseModel): pass -# @pytest.mark.qualitative -# def test_generate_from_raw(m_session): -# prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] - -# results = m_session.backend.generate_from_raw( -# actions=[CBlock(value=prompt) for prompt in prompts], ctx=m_session.ctx -# ) +@pytest.mark.qualitative +async def test_generate_from_raw(m_session): + prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 
4+4?"] -# assert len(results) == len(prompts) + with pytest.raises(openai.BadRequestError): + results = await m_session.backend.generate_from_raw( + actions=[CBlock(value=prompt) for prompt in prompts], ctx=m_session.ctx + ) # Default OpenAI implementation doesn't support structured outputs for the completions API. # def test_generate_from_raw_with_format(self): @@ -147,10 +147,10 @@ class Email(pydantic.BaseModel): async def test_async_parallel_requests(m_session): model_opts = {ModelOption.STREAM: True} - mot1, _ = m_session.backend.generate_from_context( + mot1, _ = await m_session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext(), model_options=model_opts ) - mot2, _ = m_session.backend.generate_from_context( + mot2, _ = await m_session.backend.generate_from_context( CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts ) @@ -181,7 +181,7 @@ async def test_async_parallel_requests(m_session): async def test_async_avalue(m_session): - mot1, _ = m_session.backend.generate_from_context( + mot1, _ = await m_session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext() ) m1_final_val = await mot1.avalue() @@ -218,4 +218,4 @@ async def get_client_async(): if __name__ == "__main__": import pytest - pytest.main([__file__]) + pytest.main([__file__, "-k", "generate_from_raw"]) diff --git a/test/backends/test_openai_vllm/test_openai_vllm.py b/test/backends/test_openai_vllm/test_openai_vllm.py index 8fc05a1b..537392a4 100644 --- a/test/backends/test_openai_vllm/test_openai_vllm.py +++ b/test/backends/test_openai_vllm/test_openai_vllm.py @@ -97,23 +97,24 @@ class Email(pydantic.BaseModel): # assert email.to.email_address.endswith("example.com") pass - def test_generate_from_raw(self): + async def test_generate_from_raw(self): prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] - results = self.m.backend.generate_from_raw( + results = await self.m.backend.generate_from_raw( actions=[CBlock(value=prompt) for prompt in prompts], ctx=self.m.ctx ) assert len(results) == len(prompts) + assert results[0].value is not None - def test_generate_from_raw_with_format(self): + async def test_generate_from_raw_with_format(self): prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] class Answer(pydantic.BaseModel): name: str value: int - results = self.m.backend.generate_from_raw( + results = await self.m.backend.generate_from_raw( actions=[CBlock(value=prompt) for prompt in prompts], format=Answer, ctx=self.m.ctx, diff --git a/test/backends/test_vllm.py b/test/backends/test_vllm.py index 7bb67bef..cfcda8c2 100644 --- a/test/backends/test_vllm.py +++ b/test/backends/test_vllm.py @@ -102,25 +102,26 @@ class Email(pydantic.BaseModel): @pytest.mark.qualitative -def test_generate_from_raw(session): +async def test_generate_from_raw(session): prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] - results = session.backend.generate_from_raw( + results = await session.backend.generate_from_raw( actions=[CBlock(value=prompt) for prompt in prompts], ctx=session.ctx ) assert len(results) == len(prompts) + assert results[0].value is not None @pytest.mark.qualitative -def test_generate_from_raw_with_format(session): +async def test_generate_from_raw_with_format(session): prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] class Answer(pydantic.BaseModel): name: str value: int - results = session.backend.generate_from_raw( + results = await session.backend.generate_from_raw( 
actions=[CBlock(value=prompt) for prompt in prompts], ctx=session.ctx, format=Answer, @@ -141,10 +142,10 @@ class Answer(pydantic.BaseModel): def test_async_parallel_requests(session): async def parallel_requests(): model_opts = {ModelOption.STREAM: True} - mot1, _ = session.backend.generate_from_context( + mot1, _ = await session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext(), model_options=model_opts ) - mot2, _ = session.backend.generate_from_context( + mot2, _ = await session.backend.generate_from_context( CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts ) @@ -179,7 +180,7 @@ async def parallel_requests(): @pytest.mark.qualitative def test_async_avalue(session): async def avalue(): - mot1, _ = session.backend.generate_from_context( + mot1, _ = await session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext() ) m1_final_val = await mot1.avalue() diff --git a/test/backends/test_watsonx.py b/test/backends/test_watsonx.py index 0d160b89..08615973 100644 --- a/test/backends/test_watsonx.py +++ b/test/backends/test_watsonx.py @@ -130,23 +130,24 @@ class Email(pydantic.BaseModel): @pytest.mark.qualitative -def test_generate_from_raw(session: MelleaSession): +async def test_generate_from_raw(session: MelleaSession): prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] - results = session.backend.generate_from_raw( + results = await session.backend.generate_from_raw( actions=[CBlock(value=prompt) for prompt in prompts], ctx=session.ctx ) assert len(results) == len(prompts) + assert results[0].value is not None @pytest.mark.qualitative async def test_async_parallel_requests(session): model_opts = {ModelOption.STREAM: True} - mot1, _ = session.backend.generate_from_context( + mot1, _ = await session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext(), model_options=model_opts ) - mot2, _ = session.backend.generate_from_context( + mot2, _ = await session.backend.generate_from_context( CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts ) @@ -178,7 +179,7 @@ async def test_async_parallel_requests(session): @pytest.mark.qualitative async def test_async_avalue(session): - mot1, _ = session.backend.generate_from_context( + mot1, _ = await session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext() ) m1_final_val = await mot1.avalue()
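
Usage note: this patch turns every backend's generate_from_context / generate_from_raw into a coroutine, so call sites must await them from a running event loop (as the updated tests and mellea/stdlib/functional.py now do). The sketch below is illustrative only and is not part of the patch; it assumes a MelleaSession has already been constructed the way the test fixtures do, and the SimpleContext import path plus the `session` variable are assumptions not shown in this diff.

import asyncio

from mellea.stdlib.base import CBlock, SimpleContext  # import path assumed from the tests


async def main(session) -> None:
    # generate_from_context is now a coroutine; it still returns
    # (ModelOutputThunk, Context), so unpack and await the thunk's value.
    mot, _ = await session.backend.generate_from_context(
        CBlock("Say Hello."), SimpleContext()
    )
    print(await mot.avalue())

    # generate_from_raw is also a coroutine and still returns a list of thunks.
    results = await session.backend.generate_from_raw(
        actions=[CBlock(value=p) for p in ["what is 1+1?", "what is 2+2?"]],
        ctx=session.ctx,
    )
    print([r.value for r in results])


# From synchronous code, drive the coroutine with a single asyncio.run() call,
# e.g. asyncio.run(main(session)).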