|
1 | 1 | import typing as t
|
2 | 2 | from abc import ABC, abstractmethod
|
| 3 | +from collections import defaultdict |
3 | 4 |
|
4 | 5 | if t.TYPE_CHECKING:
|
5 | 6 | from superduper.base.datalayer import Datalayer
|
6 | 7 | from superduper.components.component import Component
|
7 | 8 |
|
8 | 9 |
|
| 10 | +class Bookkeeping(ABC): |
| 11 | + """Mixin class for tracking components and associated tools.""" |
| 12 | + |
| 13 | + def __init__(self): |
| 14 | + self.component_uuid_mapping = defaultdict(set) |
| 15 | + self.uuid_component_mapping = {} |
| 16 | + self.tool_uuid_mapping = defaultdict(set) |
| 17 | + self.uuid_tool_mapping = {} |
| 18 | + self.tools = {} |
| 19 | + |
| 20 | + def build_tool(self, component: 'Component'): |
| 21 | + """Build a tool from a component. |
| 22 | +
|
| 23 | + :param component: Component to build tool from. |
| 24 | + """ |
| 25 | + pass |
| 26 | + |
| 27 | + def get_tool(self, uuid: str): |
| 28 | + """Get the tool from a uuid. |
| 29 | +
|
| 30 | + :param uuid: UUID of the tool. |
| 31 | + """ |
| 32 | + tool_id = self.uuid_tool_mapping[uuid] |
| 33 | + return self.tools[tool_id] |
| 34 | + |
| 35 | + def put_component(self, component: 'Component', **kwargs): |
| 36 | + """Put a component to the backend. |
| 37 | +
|
| 38 | + :param component: Component to put. |
| 39 | + :param kwargs: kwargs dictionary. |
| 40 | + """ |
| 41 | + tool = self.build_tool(component) |
| 42 | + tool.db = self.db |
| 43 | + self.component_uuid_mapping[(component.component, component.identifier)].add( |
| 44 | + component.uuid |
| 45 | + ) |
| 46 | + self.uuid_component_mapping[component.uuid] = ( |
| 47 | + component.component, |
| 48 | + component.identifier, |
| 49 | + ) |
| 50 | + self.uuid_tool_mapping[component.uuid] = tool.identifier |
| 51 | + self.tool_uuid_mapping[tool.identifier].add(component.uuid) |
| 52 | + self.tools[tool.identifier] = tool |
| 53 | + tool.initialize(**kwargs) |
| 54 | + |
| 55 | + def drop_component(self, component: str, identifier: str): |
| 56 | + """Drop the component from backend. |
| 57 | +
|
| 58 | + :param component: Component name. |
| 59 | + :param identifier: Component identifier. |
| 60 | + """ |
| 61 | + uuids = self.component_uuid_mapping[(component, identifier)] |
| 62 | + tool_ids = [] |
| 63 | + for uuid in uuids: |
| 64 | + del self.uuid_component_mapping[uuid] |
| 65 | + tool_id = self.uuid_tool_mapping[uuid] |
| 66 | + tool_ids.append(tool_id) |
| 67 | + del self.uuid_tool_mapping[uuid] |
| 68 | + self.tool_uuid_mapping[tool_id].remove(uuid) |
| 69 | + if not self.tool_uuid_mapping[tool_id]: |
| 70 | + self.tools[tool_id].drop() |
| 71 | + del self.tools[tool_id] |
| 72 | + del self.component_uuid_mapping[(component, identifier)] |
| 73 | + |
| 74 | + def drop(self): |
| 75 | + """Drop the backend.""" |
| 76 | + for tool in self.tools.values(): |
| 77 | + tool.drop() |
| 78 | + self.component_uuid_mapping = defaultdict(set) |
| 79 | + self.uuid_component_mapping = {} |
| 80 | + self.tool_uuid_mapping = defaultdict(set) |
| 81 | + self.uuid_tool_mapping = {} |
| 82 | + self.tools = {} |
| 83 | + |
| 84 | + def list_components(self): |
| 85 | + """List components, and identifiers deployed.""" |
| 86 | + return list(self.component_uuid_mapping.keys()) |
| 87 | + |
| 88 | + def list_tools(self): |
| 89 | + """List tools deployed.""" |
| 90 | + return list(self.tools.keys()) |
| 91 | + |
| 92 | + def list_uuids(self): |
| 93 | + """List uuids deployed.""" |
| 94 | + return list(self.uuid_component_mapping.keys()) |
| 95 | + |
| 96 | + |
9 | 97 | class BaseBackend(ABC):
|
10 | 98 | """Base backend class for cluster client."""
|
11 | 99 |
|
@@ -34,34 +122,19 @@ def initialize(self):
|
34 | 122 | """To be called on program start."""
|
35 | 123 | pass
|
36 | 124 |
|
37 |
| - def put_component(self, component: 'Component', **kwargs): |
| 125 | + @abstractmethod |
| 126 | + def put_component(self, component: 'Component'): |
38 | 127 | """Add a component to the deployment.
|
39 | 128 |
|
40 | 129 | :param component: ``Component`` to put.
|
41 |
| - :param kwargs: kwargs dictionary. |
42 | 130 | """
|
43 |
| - # This is to make sure that we only have 1 version |
44 |
| - # of each component implemented at any given time |
45 |
| - # TODO: get identifier in string component argument. |
46 |
| - identifier = '' |
47 |
| - if isinstance(component, str): |
48 |
| - uuid = component |
49 |
| - else: |
50 |
| - uuid = component.uuid |
51 |
| - identifier = component.identifier |
52 |
| - |
53 |
| - if uuid in self.list_uuids(): |
54 |
| - return |
55 |
| - if identifier in self.list_components(): |
56 |
| - del self[component.identifier] |
57 |
| - |
58 |
| - self._put(component, **kwargs) |
59 | 131 |
|
60 | 132 | @abstractmethod
|
61 | 133 | def drop_component(self, component: str, identifier: str):
|
62 | 134 | """Drop the component from backend.
|
63 | 135 |
|
64 |
| - :param identifier: Component identifier |
| 136 | + :param component: Component name. |
| 137 | + :param identifier: Component identifier. |
65 | 138 | """
|
66 | 139 |
|
67 | 140 | @property
|
|
0 commit comments