
Cleanup backend generate calls #253

@nrfulton

Description

The backend generation process is a fairly complex protocol (model options, input and output caching, cache eviction, and so on). We should break it down into discrete steps so that all backends share a single, common implementation of the protocol.

While we are doing this, we might as well clean up mellea so that it has a proper core module.

One proposed organization could look like this:

from __future__ import annotations

from typing import Any


class Component:
    def parts(self) -> list[Component | CBlock | ModelOutputThunk]:
        ...  # OR we have format_for_llm, from which parts() is derived. But we don't have both. This is a major change.

    def parse_output(self, *args: Any, **kwargs: Any) -> Any: ...  # signature TBD


class CBlock: ...  # same as now


class ModelOutputThunk: ...  # same as now
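If every `Component` exposes `parts()`, a backend can walk any component generically instead of special-casing each type. Here is a minimal sketch, assuming `CBlock` just wraps a string and `ModelOutputThunk` holds a value that may not be computed yet; `Instruction` and `flatten` are hypothetical names used for illustration, not existing mellea APIs.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class CBlock:
    # Hypothetical stand-in: a block of literal content.
    value: str


@dataclass
class ModelOutputThunk:
    # Hypothetical stand-in: a model output that may not be computed yet.
    value: str | None = None


class Component:
    def parts(self) -> list[Component | CBlock | ModelOutputThunk]:
        raise NotImplementedError


@dataclass
class Instruction(Component):
    # Hypothetical example component: a description plus optional sub-parts.
    description: CBlock
    requirements: list[Component | CBlock | ModelOutputThunk] = field(default_factory=list)

    def parts(self) -> list[Component | CBlock | ModelOutputThunk]:
        return [self.description, *self.requirements]


def flatten(node: Component | CBlock | ModelOutputThunk) -> list[CBlock | ModelOutputThunk]:
    """Depth-first walk that returns the leaves of a component tree in order."""
    if isinstance(node, Component):
        leaves: list[CBlock | ModelOutputThunk] = []
        for part in node.parts():
            leaves.extend(flatten(part))
        return leaves
    return [node]  # CBlocks and thunks are leaves


inner = Instruction(CBlock("Be concise."))
outer = Instruction(CBlock("Summarize the text."), [inner, CBlock("Use English.")])
print([leaf.value for leaf in flatten(outer)])  # prints ['Summarize the text.', 'Be concise.', 'Use English.']
```

Because the traversal only relies on `parts()`, a new component type needs no backend changes as long as it decomposes into the three kinds of parts.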

"Backends":

from __future__ import annotations

from typing import final


class InferenceEngine:
    async def handle_model_options(self): ...

    async def handle_input_cache(self): ...

    async def handle_output_cache(self): ...

    async def handle_cache_eviction(self): ...

    # ...and so on...

    @final  # FULL IMPLEMENTATION; type checkers reject any override of a @final method.
    async def generate(self, c: Component | CBlock, exec_ctx: IEngineExecutionContext) -> ModelOutputThunk:
        ...
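A non-overridable `generate` is essentially the template-method pattern: the base class fixes the order of the protocol steps and each backend only supplies hooks. Below is a minimal, self-contained sketch of that idea, not mellea's implementation. It uses a plain `str` prompt instead of `Component | CBlock`, and the hook signatures, the step order, the `_call_model` hook, and the `EchoEngine` backend are all assumptions. Since `typing.final` is only enforced by static type checkers, the sketch also adds a runtime guard in `__init_subclass__`.

```python
from __future__ import annotations

import asyncio
from typing import Any, final


class InferenceEngine:
    """Hypothetical base class: generate() runs the protocol, subclasses fill in the steps."""

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # Runtime guard: typing.final alone is only checked by static type checkers.
        if "generate" in cls.__dict__:
            raise TypeError(f"{cls.__name__} must not override generate()")

    def __init__(self) -> None:
        self.output_cache: dict[str, str] = {}

    async def handle_model_options(self, options: dict[str, Any]) -> dict[str, Any]:
        return {"temperature": 0.0, **options}  # defaults, overridable per call

    async def handle_input_cache(self, prompt: str) -> str:
        return prompt  # no input caching by default

    async def handle_output_cache(self, prompt: str, output: str | None) -> str | None:
        if output is None:
            return self.output_cache.get(prompt)  # lookup
        self.output_cache[prompt] = output  # store
        return output

    async def handle_cache_eviction(self) -> None:
        pass  # no eviction policy by default

    async def _call_model(self, prompt: str, options: dict[str, Any]) -> str:
        raise NotImplementedError  # the one hook every backend must implement

    @final
    async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> str:
        opts = await self.handle_model_options(options or {})
        prompt = await self.handle_input_cache(prompt)
        cached = await self.handle_output_cache(prompt, None)
        if cached is not None:
            return cached
        output = await self._call_model(prompt, opts)
        await self.handle_output_cache(prompt, output)
        await self.handle_cache_eviction()
        return output


class EchoEngine(InferenceEngine):
    """Toy backend that counts model calls, so cache hits are observable."""

    def __init__(self) -> None:
        super().__init__()
        self.calls = 0

    async def _call_model(self, prompt: str, options: dict[str, Any]) -> str:
        self.calls += 1
        return f"{prompt.upper()} (t={options['temperature']})"


async def main() -> None:
    engine = EchoEngine()
    print(await engine.generate("hi"))  # prints HI (t=0.0)
    print(await engine.generate("hi"))  # served from the output cache
    print(engine.calls)  # prints 1


asyncio.run(main())
```

Backends differ only in their hooks, so a fix to the protocol itself (for example, where eviction happens) lands in one place.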

Contexts (this needs more thought, both in terms of what role it plays in the code and how it's implemented):

from __future__ import annotations

from collections import OrderedDict
from typing import Callable


class LogDAG: ...  # The main logging class. Need to decide what this looks like.


class IEngineExecutionState(Callable[["InferenceEngine"], "InferenceEngine"]):
    # The Callable interface is an implementation detail that may or may not be correct; you could also imagine this being a protocol and implementing a visitor pattern on IEngine.

    # LinearContext isn't even a thing. That's now a class ChatHistory(Component).
    # Debate: should this be a ContextManager? Try both options.
    def __call__(self, engine: InferenceEngine) -> InferenceEngine:
        ...  # Modifies the underlying inference engine.

    def logs(self) -> OrderedDict[Component | CBlock, LogDAG]: ...

    span_cache: ... = ...  # TBD
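Since the issue leaves open whether an execution state should be a `Callable` or a context manager ("Try both options"), here is a minimal sketch of both, assuming the only thing a state changes is the engine's model options. `Engine`, `WithOptions`, and `temporary_options` are hypothetical names for illustration. The `Callable` form returns a modified copy and leaves the original engine untouched; the context-manager form mutates the engine in place and restores it on exit.

```python
from __future__ import annotations

import copy
from collections.abc import Callable, Iterator
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any


@dataclass
class Engine:
    # Hypothetical minimal engine: it only carries model options.
    options: dict[str, Any] = field(default_factory=dict)


@dataclass
class WithOptions:
    """Option 1: an execution state as a Callable[[Engine], Engine]."""

    overrides: dict[str, Any]

    def __call__(self, engine: Engine) -> Engine:
        modified = copy.deepcopy(engine)  # leave the caller's engine untouched
        modified.options.update(self.overrides)
        return modified


@contextmanager
def temporary_options(engine: Engine, overrides: dict[str, Any]) -> Iterator[Engine]:
    """Option 2: an execution state as a context manager that restores the engine on exit."""
    saved = dict(engine.options)
    engine.options.update(overrides)
    try:
        yield engine
    finally:
        engine.options = saved  # restored even if the body raises


base = Engine({"temperature": 0.0})
state: Callable[[Engine], Engine] = WithOptions({"temperature": 0.7})
print(state(base).options, base.options)  # prints {'temperature': 0.7} {'temperature': 0.0}
with temporary_options(base, {"max_tokens": 64}) as e:
    print(e.options)  # prints {'temperature': 0.0, 'max_tokens': 64}
print(base.options)  # prints {'temperature': 0.0}
```

The `Callable` form composes naturally (states can be chained as functions), while the context-manager form makes the scope of a change explicit but requires care if several generations run concurrently on the same engine.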

Sampling:

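from __future__ import annotations

from abc import ABC, abstractmethod

class Sampler(ABC):
    def _hyper_param_sampler(self) -> list[IEngineContextManager]:
        return [IdentityContext()]  # By default there's no hyper param sampler, but you can override this.

    @abstractmethod
    async def _sampler(self, engine: InferenceEngine, c: Component | CBlock) -> list[ModelOutputThunk]: ...

    async def __call__(self, engine: InferenceEngine, c: Component | CBlock) -> list[ModelOutputThunk]:
        mots: list[ModelOutputThunk] = []
        # 1. Get all of the `IEngineContextManager`s.
        for ctx in self._hyper_param_sampler():
            # 2. For each of those, run _sampler (assuming each context is applied as a Callable on the engine).
            mots.extend(await self._sampler(ctx(engine), c))
        # 3. Return all of the resulting MOTs.
        return mots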
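To see how the default `__call__` composes with an overridden `_hyper_param_sampler`, here is a self-contained sketch. It assumes each context is a `Callable` that returns a modified engine (one of the two options debated above) and uses hypothetical stand-ins throughout: `Engine` is just a temperature, `identity` plays the role of `IdentityContext`, plain strings replace `ModelOutputThunk`, and `BestOfTwo` and `TemperatureSweep` are illustrative samplers.

```python
from __future__ import annotations

import asyncio
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Engine:
    # Hypothetical stand-in: an engine is just a temperature; an output is a string.
    temperature: float = 0.0

    async def generate(self, prompt: str) -> str:
        return f"{prompt}@{self.temperature}"


EngineContext = Callable[[Engine], Engine]


def identity(engine: Engine) -> Engine:
    return engine  # plays the role of IdentityContext


@dataclass(frozen=True)
class Temperature:
    value: float

    def __call__(self, engine: Engine) -> Engine:
        return replace(engine, temperature=self.value)  # new engine, original untouched


class Sampler(ABC):
    def _hyper_param_sampler(self) -> list[EngineContext]:
        return [identity]  # default: no hyper-parameter sampling

    @abstractmethod
    async def _sampler(self, engine: Engine, prompt: str) -> list[str]: ...

    async def __call__(self, engine: Engine, prompt: str) -> list[str]:
        outputs: list[str] = []
        for ctx in self._hyper_param_sampler():  # 1. get every context
            outputs.extend(await self._sampler(ctx(engine), prompt))  # 2. run _sampler under it
        return outputs  # 3. return all results


class BestOfTwo(Sampler):
    async def _sampler(self, engine: Engine, prompt: str) -> list[str]:
        return [await engine.generate(prompt) for _ in range(2)]


class TemperatureSweep(BestOfTwo):
    def _hyper_param_sampler(self) -> list[EngineContext]:
        return [Temperature(t) for t in (0.2, 0.8)]


print(asyncio.run(BestOfTwo()(Engine(), "q")))  # prints ['q@0.0', 'q@0.0']
print(asyncio.run(TemperatureSweep()(Engine(), "q")))  # prints ['q@0.2', 'q@0.2', 'q@0.8', 'q@0.8']
```

Overriding only `_hyper_param_sampler` turns any existing sampler into a sweep over engine settings without touching its sampling logic.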
