The backend generation process is a fairly complex protocol. We should break it down into steps so that there's a common implementation of the protocol.
While we are doing this, we might as well clean up mellea so that it has a proper core module.
One proposed organization could look like this:
class Component:
    def parts() -> list[Component | CBlock | ModelOutputThunk]:
        ...  # OR we have format_for_llm from which parts() is derived. But we don't have both. This is a major change.
    def parse_output(...) -> ...

class CBlock: same as now
class ModelOutputThunk: same as now

Backends:
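A minimal runnable sketch of the proposed core classes. The names `Component`, `CBlock`, `ModelOutputThunk`, `parts()`, `parse_output()`, and `ChatHistory` come from this proposal; the field names and default bodies are illustrative assumptions, not the eventual implementation:

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class CBlock:
    """Stand-in for the existing CBlock: an opaque block of content."""
    value: str


@dataclass
class ModelOutputThunk(CBlock):
    """Stand-in for the existing ModelOutputThunk: a (possibly pending) model output."""
    computed: bool = False


class Component:
    """Anything the backend protocol can decompose into parts."""

    def parts(self) -> list[Component | CBlock | ModelOutputThunk]:
        raise NotImplementedError

    def parse_output(self, raw: str) -> ModelOutputThunk:
        # Illustrative default: wrap the raw model text; subclasses override.
        return ModelOutputThunk(value=raw, computed=True)


@dataclass
class ChatHistory(Component):
    """The proposal's ChatHistory(Component): a Component over chat turns."""
    turns: list[CBlock] = field(default_factory=list)

    def parts(self) -> list[Component | CBlock | ModelOutputThunk]:
        return list(self.turns)
```

Note this sketch implements `parts()` directly rather than deriving it from `format_for_llm`; the proposal leaves that choice open.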
class InferenceEngine:
    async def handle_model_options(): ...
    async def handle_input_cache(): ...
    async def handle_output_cache(): ...
    async def handle_cache_eviction(): ...
    ...and so on...
    async def generate(c: Component | CBlock, exec_ctx: IEngineExecutionContext) -> ModelOutputThunk:
        FULL IMPLEMENTATION. THIS CANNOT BE OVERRIDDEN.

Contexts (this needs more thought, both in terms of what role it plays in the code and how it's implemented):
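One way to get a non-overridable `generate` over overridable `handle_*` hooks is the template-method pattern. A minimal sketch: the hook names come from the proposal, but `_run_inference`, `EchoEngine`, and the step ordering are assumptions, and `exec_ctx` is omitted for brevity. Note that Python's `@final` is advisory only (type checkers flag overrides; the runtime does not):

```python
import asyncio
from typing import final


class InferenceEngine:
    """Hypothetical shape: handle_* hooks are overridable, generate() is not."""

    async def handle_model_options(self, c) -> None: ...
    async def handle_input_cache(self, c) -> None: ...
    async def handle_output_cache(self, raw: str) -> None: ...

    async def _run_inference(self, prompt: str) -> str:
        raise NotImplementedError  # the one backend-specific step

    @final  # checked by type checkers, not enforced at runtime
    async def generate(self, c) -> str:
        # Fixed protocol: every backend runs the same steps in the same order.
        await self.handle_model_options(c)
        await self.handle_input_cache(c)
        raw = await self._run_inference(str(c))
        await self.handle_output_cache(raw)
        return raw


class EchoEngine(InferenceEngine):
    async def _run_inference(self, prompt: str) -> str:
        return f"echo:{prompt}"


result = asyncio.run(EchoEngine().generate("hi"))
```

Backends then only implement hooks plus the raw inference call, and the common protocol lives in exactly one place.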
class LogDAG: the main logging class. Need to decide what this looks like.

class IEngineExecutionState(Callable[[InferenceEngine], InferenceEngine]):
    # The Callable interface is an implementation detail that may or may not be correct; you could also imagine this being a protocol and implementing a visitor pattern on IEngine.
    # LinearContext isn't even a thing. That's now a class ChatHistory(Component).
    # Debate: should this be a ContextManager? Try both options.
    def __call__(engine: InferenceEngine) -> InferenceEngine:
        ...  # modifies the underlying inference engine.
    def logs() -> OrderedDict[Component | CBlock, LogDAG]
    span_cache: ... = ...

Sampling:
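Before the sampling sketch, it may help to see why the `Callable[[InferenceEngine], InferenceEngine]` framing of execution state is attractive: states compose as plain function composition. Everything here (`identity_state`, `with_temperature`, the `options` dict) is a hypothetical illustration, not the proposed API:

```python
from typing import Callable


class InferenceEngine:
    """Stand-in engine carrying mutable generation options."""

    def __init__(self) -> None:
        self.options: dict[str, float] = {}


# An execution state is just a transformation of the engine.
EngineState = Callable[[InferenceEngine], InferenceEngine]


def identity_state(engine: InferenceEngine) -> InferenceEngine:
    return engine  # no-op state


def with_temperature(t: float) -> EngineState:
    def apply(engine: InferenceEngine) -> InferenceEngine:
        engine.options["temperature"] = t  # modifies the underlying engine
        return engine
    return apply


def compose(*states: EngineState) -> EngineState:
    def apply(engine: InferenceEngine) -> InferenceEngine:
        for state in states:
            engine = state(engine)
        return engine
    return apply


engine = compose(identity_state, with_temperature(0.2))(InferenceEngine())
```

The same composition falls out of the ContextManager alternative too (nested `with` blocks), which is why the debate above is worth trying both ways.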
class Sampler:
    def _hyper_param_sampler(self) -> list[IEngineContextManager]:
        return [IdentityContext()]  # by default there's no hyper-param sampler, but you can override this

    @abstractmethod
    async def _sampler(self, engine: InferenceEngine, c: Component | CBlock) -> list[ModelOutputThunk]: ...

    async def __call__(self, engine: InferenceEngine, c: Component | CBlock) -> list[ModelOutputThunk]:
        # 1. get all of the `IEngineContextManager`s
        # 2. for each of those, run _sampler
        # 3. return all of the resulting ModelOutputThunks
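The three numbered steps can be sketched concretely. In this sketch `ModelOutputThunk` is replaced by plain strings, and `IdentityContext` is assumed to be a no-op context manager; both are stand-ins for the eventual types:

```python
import asyncio
from abc import ABC, abstractmethod
from contextlib import contextmanager


@contextmanager
def IdentityContext():
    yield  # no hyper-parameter changes; mirrors the default above


class Sampler(ABC):
    def _hyper_param_sampler(self):
        # Override to return one context manager per hyper-parameter setting.
        return [IdentityContext()]

    @abstractmethod
    async def _sampler(self, engine, c) -> list[str]: ...

    async def __call__(self, engine, c) -> list[str]:
        results: list[str] = []
        # 1. get all of the context managers
        for ctx in self._hyper_param_sampler():
            # 2. run _sampler under each one
            with ctx:
                results.extend(await self._sampler(engine, c))
        # 3. return all of the resulting outputs
        return results


class OneShotSampler(Sampler):
    async def _sampler(self, engine, c) -> list[str]:
        return [f"{engine}->{c}"]


outs = asyncio.run(OneShotSampler()("mock-engine", "mock-component"))
```

A subclass that overrides `_hyper_param_sampler` with N context managers would get N times the samples back from a single `__call__`, which is the point of separating the two hooks.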