diff --git a/docs/examples/melp/lazy.py b/docs/examples/melp/lazy.py
new file mode 100644
index 00000000..cbdb33c0
--- /dev/null
+++ b/docs/examples/melp/lazy.py
@@ -0,0 +1,39 @@
+import asyncio
+from mellea.stdlib.base import (
+    SimpleContext,
+    Context,
+    CBlock,
+    ModelOutputThunk,
+    SimpleComponent,
+)
+from mellea.backends import Backend
+from mellea.backends.ollama import OllamaModelBackend
+
+backend = OllamaModelBackend("granite4:latest")
+
+
+async def fib(backend: Backend, ctx: Context, x: CBlock, y: CBlock) -> ModelOutputThunk:
+    # Each addition runs in a fresh SimpleContext so that the calls stay independent.
+    sc = SimpleComponent(
+        instruction="What is x+y? Respond with the number only.", x=x, y=y
+    )
+    mot, _ = await backend.generate_from_context(action=sc, ctx=SimpleContext())
+    return mot
+
+
+async def main(backend: Backend, ctx: Context):
+    fibs = []
+    for i in range(100):
+        if i == 0 or i == 1:
+            fibs.append(CBlock(f"{i + 1}"))
+        else:
+            fibs.append(await fib(backend, ctx, fibs[i - 1], fibs[i - 2]))
+
+    for x in fibs:
+        match x:
+            case ModelOutputThunk():
+                print(await x.avalue())
+            case CBlock():
+                print(x.value)
+
+
+asyncio.run(main(backend, SimpleContext()))
diff --git a/docs/examples/melp/lazy_fib.py b/docs/examples/melp/lazy_fib.py
new file mode 100644
index 00000000..feec18e0
--- /dev/null
+++ b/docs/examples/melp/lazy_fib.py
@@ -0,0 +1,44 @@
+import asyncio
+from mellea.stdlib.base import (
+    SimpleContext,
+    Context,
+    CBlock,
+    ModelOutputThunk,
+    SimpleComponent,
+)
+from mellea.backends import Backend
+from mellea.backends.ollama import OllamaModelBackend
+
+backend = OllamaModelBackend("granite4:latest")
+
+
+async def fib(backend: Backend, ctx: Context, x: CBlock, y: CBlock) -> ModelOutputThunk:
+    sc = SimpleComponent(
+        instruction="What is x+y? Respond with the number only.", x=x, y=y
+    )
+    mot, _ = await backend.generate_from_context(action=sc, ctx=SimpleContext())
+    return mot
+
+
+async def fib_main(backend: Backend, ctx: Context):
+    fibs = []
+    for i in range(20):
+        if i == 0 or i == 1:
+            fibs.append(CBlock(f"{i}"))
+        else:
+            mot = await fib(backend, ctx, fibs[i - 1], fibs[i - 2])
+            fibs.append(mot)
+
+    print(await fibs[-1].avalue())
+    # Alternative: print the whole sequence instead of just the last element.
+    # for x in fibs:
+    #     match x:
+    #         case ModelOutputThunk():
+    #             n = await x.avalue()
+    #             print(n)
+    #         case CBlock():
+    #             print(x.value)
+
+
+asyncio.run(fib_main(backend, SimpleContext()))
diff --git a/docs/examples/melp/lazy_fib_sample.py b/docs/examples/melp/lazy_fib_sample.py
new file mode 100644
index 00000000..0bec2907
--- /dev/null
+++ b/docs/examples/melp/lazy_fib_sample.py
@@ -0,0 +1,66 @@
+import asyncio
+from mellea.stdlib.base import (
+    SimpleContext,
+    Context,
+    CBlock,
+    ModelOutputThunk,
+    SimpleComponent,
+)
+from mellea.backends import Backend
+from mellea.backends.ollama import OllamaModelBackend
+
+backend = OllamaModelBackend("granite4:latest")
+
+
+async def _fib_sample(
+    backend: Backend, ctx: Context, x: CBlock, y: CBlock
+) -> ModelOutputThunk | None:
+    sc = SimpleComponent(
+        instruction="What is x+y? Respond with the number only.", x=x, y=y
+    )
+    answer_mot, _ = await backend.generate_from_context(action=sc, ctx=SimpleContext())
+
+    # This await is fundamental: it means computation must occur here.
+    # We need to be able to read this off at computation-graph construction time.
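+    # (Assumption, sketching the design intent: a lazy runtime would have to
+    # detect that this output feeds a host-language branch, namely the int()
+    # parse below, and therefore force computation at exactly this point.)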
+    value = await answer_mot.avalue()
+
+    try:
+        int(value)
+        return answer_mot
+    except ValueError:
+        return None
+
+
+async def fib_sampling_version(
+    backend: Backend, ctx: Context, x: CBlock, y: CBlock
+) -> ModelOutputThunk | None:
+    # Retry up to five times; give up and return None if no sample parses as an int.
+    for _ in range(5):
+        sample = await _fib_sample(backend, ctx, x, y)
+        if sample is not None:
+            return sample
+    return None
+
+
+async def fib_sampling_version_main(backend: Backend, ctx: Context):
+    fibs = []
+    for i in range(20):
+        if i == 0 or i == 1:
+            fibs.append(CBlock(f"{i}"))
+        else:
+            mot = await fib_sampling_version(backend, ctx, fibs[i - 1], fibs[i - 2])
+            fibs.append(mot)
+
+    for x in fibs:
+        match x:
+            case ModelOutputThunk():
+                n = await x.avalue()
+                print(n)
+            case CBlock():
+                print(x.value)
+
+
+asyncio.run(fib_sampling_version_main(backend, SimpleContext()))
diff --git a/docs/examples/melp/simple_example.py b/docs/examples/melp/simple_example.py
new file mode 100644
index 00000000..7ac1059b
--- /dev/null
+++ b/docs/examples/melp/simple_example.py
@@ -0,0 +1,38 @@
+import asyncio
+from mellea.stdlib.base import Context, CBlock, SimpleContext, ModelOutputThunk
+from mellea.backends import Backend
+from mellea.backends.ollama import OllamaModelBackend
+
+
+async def main(backend: Backend, ctx: Context):
+    """
+    In this example, we show how executing multiple MOTs in parallel should work.
+    """
+    m_states = "Missouri", "Minnesota", "Montana", "Massachusetts"
+
+    poem_thunks = []
+    for state_name in m_states:
+        mot, ctx = await backend.generate_from_context(
+            CBlock(f"Write a poem about {state_name}"), ctx
+        )
+        poem_thunks.append(mot)
+
+    # Notice that what we have now is a list of ModelOutputThunks, none of which are computed.
+    for poem_thunk in poem_thunks:
+        assert isinstance(poem_thunk, ModelOutputThunk)
+        print(f"Computed: {poem_thunk.is_computed()}")
+
+    # Let's run all of these in parallel.
+    await asyncio.gather(*[c.avalue() for c in poem_thunks])
+
+    # Check that all of the thunks are now computed.
+    for poem_thunk in poem_thunks:
+        print(f"Computed: {poem_thunk.is_computed()}")
+
+    # And let's print out the final results.
+    for poem_thunk in poem_thunks:
+        print(poem_thunk.value)
+
+
+backend = OllamaModelBackend(model_id="granite4:latest")
+asyncio.run(main(backend, SimpleContext()))
diff --git a/docs/examples/melp/states.py b/docs/examples/melp/states.py
new file mode 100644
index 00000000..2383bf4a
--- /dev/null
+++ b/docs/examples/melp/states.py
@@ -0,0 +1,44 @@
+from mellea.stdlib.base import SimpleContext, Context, CBlock, SimpleComponent
+from mellea.backends import Backend
+from mellea.backends.ollama import OllamaModelBackend
+import asyncio
+
+
+async def main(backend: Backend, ctx: Context):
+    a_states = "Alaska,Arizona,Arkansas".split(",")
+    m_states = "Missouri", "Minnesota", "Montana", "Massachusetts"
+
+    a_state_pops = dict()
+    for state in a_states:
+        a_state_pops[state], _ = await backend.generate_from_context(
+            CBlock(f"What is the population of {state}? Respond with an integer only."),
+            SimpleContext(),
+        )
+    a_total_pop = SimpleComponent(
+        instruction=CBlock(
+            "What is the total population of these states? Respond with an integer only."
+        ),
+        **a_state_pops,
+    )
+    a_state_total, _ = await backend.generate_from_context(a_total_pop, SimpleContext())
+
+    m_state_pops = dict()
+    for state in m_states:
+        m_state_pops[state], _ = await backend.generate_from_context(
+            CBlock(f"What is the population of {state}? Respond with an integer only."),
+            SimpleContext(),
+        )
+    m_total_pop = SimpleComponent(
+        instruction=CBlock(
+            "What is the total population of these states? Respond with an integer only."
+        ),
+        **m_state_pops,
+    )
+    m_state_total, _ = await backend.generate_from_context(m_total_pop, SimpleContext())
+
+    print(await a_state_total.avalue())
+    print(await m_state_total.avalue())
+
+
+backend = OllamaModelBackend(model_id="granite4:latest")
+asyncio.run(main(backend, SimpleContext()))
diff --git a/docs/rewrite/session_deepdive/0.py b/docs/rewrite/session_deepdive/0.py
new file mode 100644
index 00000000..ae95a18a
--- /dev/null
+++ b/docs/rewrite/session_deepdive/0.py
@@ -0,0 +1,10 @@
+from mellea import MelleaSession
+from mellea.stdlib.base import SimpleContext
+from mellea.backends.ollama import OllamaModelBackend
+
+
+m = MelleaSession(
+    backend=OllamaModelBackend("granite4:latest"), context=SimpleContext()
+)
+response = m.chat("What is 1+1?")
+print(response.content)
diff --git a/docs/rewrite/session_deepdive/1.py b/docs/rewrite/session_deepdive/1.py
new file mode 100644
index 00000000..1886a0bb
--- /dev/null
+++ b/docs/rewrite/session_deepdive/1.py
@@ -0,0 +1,11 @@
+import mellea.stdlib.functional as mfuncs
+from mellea.stdlib.base import SimpleContext
+from mellea.backends.ollama import OllamaModelBackend
+
+response, next_context = mfuncs.chat(
+    "What is 1+1?",
+    context=SimpleContext(),
+    backend=OllamaModelBackend("granite4:latest"),
+)
+
+print(response.content)
diff --git a/docs/rewrite/session_deepdive/2.py b/docs/rewrite/session_deepdive/2.py
new file mode 100644
index 00000000..e20346e2
--- /dev/null
+++ b/docs/rewrite/session_deepdive/2.py
@@ -0,0 +1,11 @@
+import mellea.stdlib.functional as mfuncs
+from mellea.stdlib.base import SimpleContext, CBlock
+from mellea.backends.ollama import OllamaModelBackend
+
+response, next_context = mfuncs.act(
+    CBlock("What is 1+1?"),
+    context=SimpleContext(),
+    backend=OllamaModelBackend("granite4:latest"),
+)
+
+print(response.value)
diff --git a/docs/rewrite/session_deepdive/3.py b/docs/rewrite/session_deepdive/3.py
new file mode 100644
index 00000000..c43797bd
--- /dev/null
+++ b/docs/rewrite/session_deepdive/3.py
@@ -0,0 +1,16 @@
+import mellea.stdlib.functional as mfuncs
+from mellea.stdlib.base import SimpleContext, CBlock, Context
+from mellea.backends.ollama import OllamaModelBackend
+from mellea.backends import Backend
+import asyncio
+
+
+async def main(backend: Backend, ctx: Context):
+    response, next_context = await mfuncs.aact(
+        CBlock("What is 1+1?"), context=ctx, backend=backend
+    )
+
+    print(response.value)
+
+
+asyncio.run(main(OllamaModelBackend("granite4:latest"), SimpleContext()))
diff --git a/docs/rewrite/session_deepdive/4.py b/docs/rewrite/session_deepdive/4.py
new file mode 100644
index 00000000..e6df38fc
--- /dev/null
+++ b/docs/rewrite/session_deepdive/4.py
@@ -0,0 +1,19 @@
+from mellea.stdlib.base import SimpleContext, CBlock, Context
+from mellea.backends.ollama import OllamaModelBackend
+from mellea.backends import Backend
+import asyncio
+
+
+async def main(backend: Backend, ctx: Context):
+    response, next_context = await backend.generate_from_context(
+        CBlock("What is 1+1?"),
+        ctx=ctx,  # TODO: we should rationalize ctx and context across mfuncs and base/backend.
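+        # The ModelOutputThunk returned by this call is lazy; the prints below
+        # show that it only becomes computed once it is awaited.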
+    )
+
+    print(f"Currently computed: {response.is_computed()}")
+    print(await response.avalue())
+    print(f"Currently computed: {response.is_computed()}")
+
+
+asyncio.run(main(OllamaModelBackend("granite4:latest"), SimpleContext()))
diff --git a/docs/rewrite/session_deepdive/5.py b/docs/rewrite/session_deepdive/5.py
new file mode 100644
index 00000000..acac7bba
--- /dev/null
+++ b/docs/rewrite/session_deepdive/5.py
@@ -0,0 +1,31 @@
+from mellea.stdlib.base import (
+    SimpleContext,
+    CBlock,
+    Context,
+    SimpleComponent,
+)
+from mellea.backends.ollama import OllamaModelBackend
+from mellea.backends import Backend
+import asyncio
+
+
+async def main(backend: Backend, ctx: Context):
+    x, _ = await backend.generate_from_context(CBlock("What is 1+1?"), ctx=ctx)
+
+    y, _ = await backend.generate_from_context(CBlock("What is 2+2?"), ctx=ctx)
+
+    response, _ = await backend.generate_from_context(
+        SimpleComponent(instruction="What is x+y?", x=x, y=y),
+        ctx=ctx,  # TODO: we should rationalize ctx and context across mfuncs and base/backend.
+    )
+
+    print(f"x currently computed: {x.is_computed()}")
+    print(f"y currently computed: {y.is_computed()}")
+    print(f"response currently computed: {response.is_computed()}")
+    print(await response.avalue())
+    print(f"response currently computed: {response.is_computed()}")
+
+
+asyncio.run(main(OllamaModelBackend("granite4:latest"), SimpleContext()))
diff --git a/docs/rewrite/streaming/1.py b/docs/rewrite/streaming/1.py
new file mode 100644
index 00000000..688cbf96
--- /dev/null
+++ b/docs/rewrite/streaming/1.py
@@ -0,0 +1,25 @@
+from mellea.stdlib.base import SimpleContext, CBlock, Context, SimpleComponent
+from mellea.backends.ollama import OllamaModelBackend
+from mellea.backends import Backend
+import asyncio
+
+
+async def main(backend: Backend, ctx: Context):
+    x, _ = await backend.generate_from_context(CBlock("What is 1+1?"), ctx=ctx)
+
+    y, _ = await backend.generate_from_context(CBlock("What is 2+2?"), ctx=ctx)
+
+    response, _ = await backend.generate_from_context(
+        SimpleComponent(instruction="What is x+y?", x=x, y=y),
+        ctx=ctx,  # TODO: we should rationalize ctx and context across mfuncs and base/backend.
+    )
+
+    print(f"x currently computed: {x.is_computed()}")
+    print(f"y currently computed: {y.is_computed()}")
+    print(f"response currently computed: {response.is_computed()}")
+    print(await response.avalue())
+    print(f"response currently computed: {response.is_computed()}")
+
+
+asyncio.run(main(OllamaModelBackend("granite4:latest"), SimpleContext()))
diff --git a/mellea/backends/__init__.py b/mellea/backends/__init__.py
index 4a56665a..88e3aec5 100644
--- a/mellea/backends/__init__.py
+++ b/mellea/backends/__init__.py
@@ -3,12 +3,15 @@
 from __future__ import annotations
 
 import abc
+import asyncio
+import itertools
 from typing import TypeVar
 
 import pydantic
 
 from mellea.backends.model_ids import ModelIdentifier
 from mellea.backends.types import ModelOption
+from mellea.helpers.fancy_logger import FancyLogger
 from mellea.stdlib.base import CBlock, Component, Context, GenerateLog, ModelOutputThunk
 
 BaseModelSubclass = TypeVar(
@@ -76,3 +79,43 @@ async def generate_from_raw(
             model_options: Any model options to upsert into the defaults for this call.
             tool_calls: Always set to false unless supported by backend.
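+
+        Returns:
+            A list of ModelOutputThunks, one per action.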
""" + + async def do_generate_walk( + self, action: CBlock | Component | ModelOutputThunk + ) -> None: + """Does the generation walk.""" + _to_compute = list(generate_walk(action)) + coroutines = [x.avalue() for x in _to_compute] + # The following log message might get noisy. Feel free to remove if so. + if len(_to_compute) > 0: + FancyLogger.get_logger().info( + f"generate_from_chat_context awaited on {len(_to_compute)} uncomputed mots." + ) + await asyncio.gather(*coroutines) + + async def do_generate_walks( + self, actions: list[CBlock | Component | ModelOutputThunk] + ) -> None: + """Does the generation walk.""" + _to_compute = [] + for action in actions: + _to_compute.extend(list(generate_walk(action))) + coroutines = [x.avalue() for x in _to_compute] + # The following log message might get noisy. Feel free to remove if so. + if len(_to_compute) > 0: + FancyLogger.get_logger().info( + f"generate_from_chat_context awaited on {len(_to_compute)} uncomputed mots." + ) + await asyncio.gather(*coroutines) + + +def generate_walk(c: CBlock | Component | ModelOutputThunk) -> list[ModelOutputThunk]: + """Returns the generation walk ordering for a Span.""" + match c: + case ModelOutputThunk() if not c.is_computed(): + return [c] + case CBlock(): + return [] + case Component(): + parts_walk = [generate_walk(p) for p in c.parts()] + return list(itertools.chain.from_iterable(parts_walk)) # aka flatten diff --git a/mellea/backends/_utils.py b/mellea/backends/_utils.py index 08720bc0..28dc6d5f 100644 --- a/mellea/backends/_utils.py +++ b/mellea/backends/_utils.py @@ -1,13 +1,20 @@ from __future__ import annotations import inspect +import itertools from collections.abc import Callable from typing import Any, Literal from mellea.backends.formatter import Formatter from mellea.backends.tools import parse_tools from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import CBlock, Component, Context, ModelToolCall +from mellea.stdlib.base import ( + CBlock, + Component, + Context, + ModelOutputThunk, + ModelToolCall, +) from mellea.stdlib.chat import Message from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index 28113642..d17a7844 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -196,6 +196,8 @@ async def generate_from_context( tool_calls: bool = False, ): """Generate using the huggingface model.""" + await self.do_generate_walk(action) + # Upsert model options. model_opts = self._simplify_and_merge(model_options) @@ -677,6 +679,8 @@ async def generate_from_raw( tool_calls: bool = False, ) -> list[ModelOutputThunk]: """Generate using the completions api. Gives the input provided to the model without templating.""" + await self.do_generate_walks(actions) + if tool_calls: FancyLogger.get_logger().warning( "The raw endpoint does not support tool calling at the moment." 
diff --git a/mellea/backends/litellm.py b/mellea/backends/litellm.py
index 555431c5..80adbc8b 100644
--- a/mellea/backends/litellm.py
+++ b/mellea/backends/litellm.py
@@ -241,6 +241,8 @@ async def _generate_from_chat_context_standard(
         model_options: dict | None = None,
         tool_calls: bool = False,
     ) -> ModelOutputThunk:
+        await self.do_generate_walk(action)
+
         model_opts = self._simplify_and_merge(model_options)
         linearized_context = ctx.view_for_generation()
         assert linearized_context is not None, (
@@ -484,6 +486,7 @@ async def generate_from_raw(
         tool_calls: bool = False,
     ) -> list[ModelOutputThunk]:
         """Generate using the completions api. Gives the input provided to the model without templating."""
+        await self.do_generate_walks(actions)
         extra_body = {}
         if format is not None:
             FancyLogger.get_logger().warning(
diff --git a/mellea/backends/ollama.py b/mellea/backends/ollama.py
index 713acdd7..b2b8cb39 100644
--- a/mellea/backends/ollama.py
+++ b/mellea/backends/ollama.py
@@ -294,6 +294,9 @@ async def generate_from_chat_context(
         Raises:
             RuntimeError: If not called from a thread with a running event loop.
         """
+        # Start by awaiting any necessary computation.
+        await self.do_generate_walk(action)
+
         model_opts = self._simplify_and_merge(model_options)
 
         linearized_context = ctx.view_for_generation()
@@ -408,9 +411,16 @@ async def generate_from_raw(
 
         model_opts = self._simplify_and_merge(model_options)
 
+        # Await any uncomputed thunks that the actions transitively depend on.
+        await self.do_generate_walks(actions)
+
+        prompts = [self.formatter.print(action) for action in actions]
+
         # Ollama doesn't support "batching". There's some ability for concurrency. Use that here.
         # See https://github.com/ollama/ollama/blob/main/docs/faq.md#how-does-ollama-handle-concurrent-requests.
-        prompts = [self.formatter.print(action) for action in actions]
 
         # Run async so that we can make use of Ollama's concurrency.
         coroutines: list[Coroutine[Any, Any, ollama.GenerateResponse]] = []
diff --git a/mellea/backends/openai.py b/mellea/backends/openai.py
index ba825753..c9a7299a 100644
--- a/mellea/backends/openai.py
+++ b/mellea/backends/openai.py
@@ -316,6 +316,8 @@ async def generate_from_chat_context(
         tool_calls: bool = False,
     ) -> tuple[ModelOutputThunk, Context]:
         """Generates a new completion from the provided Context using this backend's `Formatter`."""
+        await self.do_generate_walk(action)
+
         # Requirements can be automatically rerouted to a requirement adapter.
         if isinstance(action, Requirement):
             # See docs/dev/requirement_aLoRA_rerouting.md
@@ -786,6 +788,8 @@ async def generate_from_raw(
         tool_calls: bool = False,
     ) -> list[ModelOutputThunk]:
         """Generate using the completions api. Gives the input provided to the model without templating."""
+        await self.do_generate_walks(actions)
+
         extra_body = {}
         if format is not None:
             FancyLogger.get_logger().warning(
diff --git a/mellea/backends/vllm.py b/mellea/backends/vllm.py
index f9d6a753..07f483ee 100644
--- a/mellea/backends/vllm.py
+++ b/mellea/backends/vllm.py
@@ -248,6 +248,8 @@ async def generate_from_context(
         tool_calls: bool = False,
     ) -> tuple[ModelOutputThunk, Context]:
-        """Generate using the huggingface model."""
+        """Generate using the vLLM model."""
+        await self.do_generate_walk(action)
+
         # Upsert model options.
         model_options = self._simplify_and_merge(model_options)
@@ -437,6 +439,8 @@ async def generate_from_raw(
         tool_calls: bool = False,
     ) -> list[ModelOutputThunk]:
         """Generate using the completions api. Gives the input provided to the model without templating."""
+        await self.do_generate_walks(actions)
+
         if tool_calls:
             FancyLogger.get_logger().warning(
                 "The completion endpoint does not support tool calling at the moment."
diff --git a/mellea/backends/watsonx.py b/mellea/backends/watsonx.py
index 5821b446..721e4f05 100644
--- a/mellea/backends/watsonx.py
+++ b/mellea/backends/watsonx.py
@@ -269,6 +269,8 @@ async def generate_from_chat_context(
         tool_calls: bool = False,
     ) -> ModelOutputThunk:
         """Generates a new completion from the provided Context using this backend's `Formatter`."""
+        await self.do_generate_walk(action)
+
         model_opts = self._simplify_and_merge(
             model_options, is_chat_context=ctx.is_chat_context
         )
@@ -490,6 +492,8 @@ async def generate_from_raw(
         tool_calls: bool = False,
     ) -> list[ModelOutputThunk]:
         """Generates a completion text. Gives the input provided to the model without templating."""
+        await self.do_generate_walks(actions)
+
         if format is not None:
             FancyLogger.get_logger().warning(
                 "WatsonxAI completion api does not accept response format, ignoring it for this request."
diff --git a/mellea/stdlib/base.py b/mellea/stdlib/base.py
index 111d44f6..2260b548 100644
--- a/mellea/stdlib/base.py
+++ b/mellea/stdlib/base.py
@@ -656,3 +656,90 @@ class ModelToolCall:
     def call_func(self) -> Any:
         """A helper function for calling the function/tool represented by this object."""
         return self.func(**self.args)
+
+
+class SimpleComponent(Component):
+    """A Component that is made up of named spans."""
+
+    def __init__(self, **kwargs):
+        """Initializes a simple component from the constructor's kwargs."""
+        for key in kwargs.keys():
+            if type(kwargs[key]) is str:
+                kwargs[key] = CBlock(value=kwargs[key])
+        self._kwargs_type_check(kwargs)
+        self._kwargs = kwargs
+
+    def parts(self):
+        """Returns the values of the kwargs."""
+        return list(self._kwargs.values())
+
+    def _kwargs_type_check(self, kwargs):
+        for key in kwargs.keys():
+            value = kwargs[key]
+            assert isinstance(value, (Component, CBlock)), (
+                f"Expected span but found {type(value)} of value: {value}"
+            )
+            assert type(key) is str
+        return True
+
+    @staticmethod
+    def make_simple_string(kwargs):
+        """Uses <|key|>value to represent a simple component."""
+        return "\n".join([f"<|{key}|>{value}" for (key, value) in kwargs.items()])
+
+    @staticmethod
+    def make_json_string(kwargs):
+        """Uses a JSON object representation."""
+        import json
+
+        str_args = dict()
+        for key in kwargs.keys():
+            match kwargs[key]:
+                case ModelOutputThunk() | CBlock():
+                    str_args[key] = kwargs[key].value
+                case Component():
+                    # NOTE: format_for_llm may return a TemplateRepresentation rather
+                    # than a str; json.dumps assumes string-serializable values here.
+                    str_args[key] = kwargs[key].format_for_llm()
+        return json.dumps(str_args)
+
+    def format_for_llm(self):
+        """Uses the JSON string representation."""
+        return SimpleComponent.make_json_string(self._kwargs)
+
+
+class HeapContext(Context):
+    """A HeapContext is a Context constructed by reading off all of the caller's locals() and globals() whose values are CBlocks, Components, or ModelOutputThunks."""
+
+    def __init__(self):
+        """Captures the heap at construction time. Should this happen at the use site instead?"""
+        import inspect
+
+        self._heap = dict()
+
+        # Read from the *caller's* frame: calling globals()/locals() directly here
+        # would only see this module's globals and this method's locals.
+        frame = inspect.currentframe()
+        assert frame is not None and frame.f_back is not None
+        caller = frame.f_back
+
+        for key, value in {**caller.f_globals, **caller.f_locals}.items():
+            match value:
+                case ModelOutputThunk() | Component() | CBlock():
+                    self._heap[key] = value
+                case _:
+                    continue
+
+    @property
+    def is_chat_context(self):
+        """Heap contexts are not chat contexts."""
+        return False
+
+    def add(self, c: Component | CBlock) -> Context:
+        """Returns a new context obtained by adding `c` to this context as the "last item", using _ to denote the last expression."""
+        new_context = HeapContext()
+        new_context._heap = self._heap.copy()
+        new_context._heap["_"] = c
+        return new_context
+
+    def view_for_generation(self) -> list[Component | CBlock] | None:
+        """Provides a linear list of context components to use for generation, or None if that is not possible to construct."""
+        return [SimpleComponent(**self._heap)]
diff --git a/mellea/stdlib/chat.py b/mellea/stdlib/chat.py
index 574e6fa6..f4c38c49 100644
--- a/mellea/stdlib/chat.py
+++ b/mellea/stdlib/chat.py
@@ -17,7 +17,11 @@
 
 
 class Message(Component):
-    """A single Message in a Chat history."""
+    """A single Message in a Chat history.
+
+    TODO: we may want to deprecate this Component entirely.
+    The fact that some Component gets rendered as a chat message is `Formatter` miscellanea.
+    """
 
     Role = Literal["system", "user", "assistant", "tool"]
 
@@ -38,22 +42,33 @@ def __init__(
             documents (list[Document]): documents associated with the message if any.
""" self.role = role - self.content = content + self.content = content # TODO this should be private. + self._content_cblock = CBlock(self.content) self._images = images + # TODO this should replace _images. + self._images_cblocks: list[CBlock] | None = None + if self._images is not None: + self._images_cblocks = [CBlock(str(i)) for i in self._images] self._docs = documents @property def images(self) -> None | list[str]: """Returns the images associated with this message as list of base 64 strings.""" - if self._images is not None: - return [str(i) for i in self._images] + if self._images_cblocks is not None: + return [str(i.value) for i in self._images_cblocks] return None def parts(self): """Returns all of the constituent parts of an Instruction.""" - raise Exception( - "Disallowing use of `parts` until we figure out exactly what it's supposed to be for" + assert self._images is None, ( + "TODO: images are not handled correctly in the mellea core." ) + parts = [self._content_cblock] + if self._docs is not None: + parts.extend(self._docs) + if self._images is not None: + parts.extend(self._images) + return parts def format_for_llm(self) -> TemplateRepresentation: """Formats the content for a Language Model. @@ -65,8 +80,8 @@ def format_for_llm(self) -> TemplateRepresentation: obj=self, args={ "role": self.role, - "content": self.content, - "images": self.images, + "content": self._content_cblock, + "images": self._images_cblocks, "documents": self._docs, }, template_order=["*", "Message"], diff --git a/mellea/stdlib/docs/richdocument.py b/mellea/stdlib/docs/richdocument.py index d0ea9d4e..87b99130 100644 --- a/mellea/stdlib/docs/richdocument.py +++ b/mellea/stdlib/docs/richdocument.py @@ -26,10 +26,13 @@ def __init__(self, doc: DoclingDocument): self._doc = doc def parts(self) -> list[Component | CBlock]: - """A `RichDocument` has no parts.""" - raise NotImplementedError( - "Disallowing use of `parts` until we figure out exactly what it's supposed to be for" - ) + """RichDocument has no parts. + + In the future, we should allow chunking of DoclingDocuments to correspond to parts(). + """ + # TODO: we could separate a DoclingDocument into chunks and then treat those chunks as parts. + # for now, do nothing. + return [] def format_for_llm(self) -> TemplateRepresentation | str: """Return Document content as Markdown. 
@@ -93,6 +96,10 @@ def __init__(self, obj: Table, query: str) -> None:
         """
         super().__init__(obj, query)
 
+    def parts(self):
+        """The list of CBlocks/Components on which this TableQuery depends."""
+        return [self._obj]
+
     def format_for_llm(self) -> TemplateRepresentation:
         """Template arguments for Formatter."""
         assert isinstance(self._obj, Table)
@@ -119,6 +126,10 @@ def __init__(self, obj: Table, transformation: str) -> None:
         """
         super().__init__(obj, transformation)
 
+    def parts(self):
+        """The list of CBlocks/Components on which this TableTransform depends."""
+        return [self._obj]
+
     def format_for_llm(self) -> TemplateRepresentation:
         """Template arguments for Formatter."""
         assert isinstance(self._obj, Table)
@@ -156,6 +167,10 @@ def from_markdown(cls, md: str) -> Table | None:
         else:
             return None
 
+    def parts(self):
+        """The current implementation does not reuse any underlying CBlocks or Components, so parts is empty."""
+        return []
+
     def to_markdown(self) -> str:
         """Get the `Table` as markdown."""
         return self._ti.export_to_markdown(self._doc)
diff --git a/mellea/stdlib/genslot.py b/mellea/stdlib/genslot.py
index 0124d674..bf0ab56c 100644
--- a/mellea/stdlib/genslot.py
+++ b/mellea/stdlib/genslot.py
@@ -74,7 +74,7 @@ class ArgumentDict(TypedDict):
 
 
 class Argument:
-    """An Argument Component."""
+    """An Argument."""
 
     def __init__(
         self,
@@ -82,7 +82,7 @@ def __init__(
         name: str | None = None,
         value: str | None = None,
     ):
-        """An Argument Component."""
+        """An Argument."""
        self._argument_dict: ArgumentDict = {
             "name": name,
             "annotation": annotation,
@@ -140,10 +140,10 @@ def __init__(
 
 
 class Function:
-    """A Function Component."""
+    """A Function."""
 
     def __init__(self, func: Callable):
-        """A Function Component."""
+        """A Function."""
         self._func: Callable = func
         self._function_dict: FunctionDict = describe_function(func)
@@ -382,7 +382,10 @@ def _context_backend_extract_args_and_kwargs(
 
     def parts(self):
-        """Not implemented."""
-        raise NotImplementedError
+        """Returns the arguments and requirements that this generative slot depends on."""
+        cs: list = []
+        cs.extend(self._arguments)
+        cs.extend(self.requirements)
+        return cs
 
     def format_for_llm(self) -> TemplateRepresentation:
         """Formats the instruction for Formatter use."""
diff --git a/mellea/stdlib/instruction.py b/mellea/stdlib/instruction.py
index f8d07efb..695f9117 100644
--- a/mellea/stdlib/instruction.py
+++ b/mellea/stdlib/instruction.py
@@ -121,9 +121,17 @@ def __init__(
 
     def parts(self):
         """Returns all of the constituent parts of an Instruction."""
-        raise Exception(
-            "Disallowing use of `parts` until we figure out exactly what it's supposed to be for"
-        )
+        # Add all of the optionally defined CBlocks/Components, then filter out Nones at the end.
+        cs = [self._description, self._prefix, self._output_prefix]
+        match self._grounding_context:
+            case None:
+                pass
+            case CBlock():
+                cs.append(self._grounding_context)
+            case _:
+                cs.extend(list(self._grounding_context.values()))
+        cs.extend(self._requirements)
+        cs.extend(self._icl_examples)
+        cs = list(filter(lambda x: x is not None, cs))
+        return cs
 
     def format_for_llm(self) -> TemplateRepresentation:
         """Formats the instruction for Formatter use."""
diff --git a/mellea/stdlib/intrinsics/intrinsic.py b/mellea/stdlib/intrinsics/intrinsic.py
index 4c54a55d..a90b5a5c 100644
--- a/mellea/stdlib/intrinsics/intrinsic.py
+++ b/mellea/stdlib/intrinsics/intrinsic.py
@@ -51,7 +51,9 @@ def parts(self) -> list[Component | CBlock]:
 
         Will need to be implemented by subclasses since not all intrinsics are output as text / messages.
""" - raise NotImplementedError("parts isn't implemented by default") + raise NotImplementedError( + "There is no default definition of parts() for an Intrinsic function." + ) def format_for_llm(self) -> TemplateRepresentation | str: """`Intrinsic` doesn't implement `format_for_default`. diff --git a/mellea/stdlib/mify.py b/mellea/stdlib/mify.py index ec5e5c25..2059e426 100644 --- a/mellea/stdlib/mify.py +++ b/mellea/stdlib/mify.py @@ -28,13 +28,13 @@ class MifiedProtocol(MObjectProtocol, Protocol): _stringify_func: Callable[[object], str] | None = None def parts(self) -> list[Component | CBlock]: - """Returns a list of parts for MObject. + """TODO: we need to rewrite this component to use format_for_llm and initializer correctly. + + For now an empty list is the correct behavior. [no-index] """ - raise NotImplementedError( - "Disallowing use of `parts` until we figure out exactly what it's supposed to be for" - ) + return [] def get_query_object(self, query: str) -> Query: """Returns the instantiated query object. diff --git a/mellea/stdlib/mobject.py b/mellea/stdlib/mobject.py index b2c76b47..19852264 100644 --- a/mellea/stdlib/mobject.py +++ b/mellea/stdlib/mobject.py @@ -24,7 +24,7 @@ def __init__(self, obj: Component, query: str) -> None: def parts(self) -> list[Component | CBlock]: """Get the parts of the query.""" - return [] + return [self._obj] def format_for_llm(self) -> TemplateRepresentation | str: """Format the query for llm.""" @@ -64,7 +64,7 @@ def __init__(self, obj: Component, transformation: str) -> None: def parts(self) -> list[Component | CBlock]: """Get the parts of the transform.""" - return [] + return [self._obj] def format_for_llm(self) -> TemplateRepresentation | str: """Format the transform for llm.""" @@ -154,10 +154,8 @@ def __init__( self._transform_type = transform_type def parts(self) -> list[Component | CBlock]: - """Returns a list of parts for MObject.""" - raise NotImplementedError( - "Disallowing use of `parts` until we figure out exactly what it's supposed to be for" - ) + """MObject has no parts because of how format_for_llm is defined.""" + return [] def get_query_object(self, query: str) -> Query: """Returns the instantiated query object. diff --git a/mellea/stdlib/requirement.py b/mellea/stdlib/requirement.py index e5c4cad9..f30aa4fc 100644 --- a/mellea/stdlib/requirement.py +++ b/mellea/stdlib/requirement.py @@ -163,9 +163,7 @@ async def validate( def parts(self): """Returns all of the constituent parts of a Requirement.""" - raise Exception( - "Disallowing use of `parts` until we figure out exactly what it's supposed to be for" - ) + return [] def format_for_llm(self) -> TemplateRepresentation | str: """Some object protocol magic happens here with management of the output."""