|
3 | 3 |
|
4 | 4 | """The dispatch background service."""
|
5 | 5 |
|
| 6 | +from __future__ import annotations |
| 7 | + |
6 | 8 | import asyncio
|
7 | 9 | import logging
|
| 10 | +from abc import ABC, abstractmethod |
8 | 11 | from dataclasses import dataclass, field
|
9 | 12 | from datetime import datetime, timedelta, timezone
|
10 | 13 | from heapq import heappop, heappush
|
| 14 | +from typing import Callable |
11 | 15 |
|
12 | 16 | import grpc.aio
|
13 | 17 | from frequenz.channels import Broadcast, Receiver, select, selected_from
|
|
23 | 27 | """The logger for this module."""
|
24 | 28 |
|
25 | 29 |
|
| 30 | +class _MergeStrategy(ABC): |
| 31 | + """Base class for strategies to merge running intervals.""" |
| 32 | + |
| 33 | + @abstractmethod |
| 34 | + def _get_filter_function( |
| 35 | + self, |
| 36 | + scheduler: DispatchScheduler, |
| 37 | + ) -> Callable[[Dispatch], bool]: |
| 38 | + """Get a filter function for dispatches. |
| 39 | +
|
| 40 | + Args: |
| 41 | + scheduler: The dispatch scheduler. |
| 42 | +
|
| 43 | + Returns: |
| 44 | + A filter function. |
| 45 | + """ |
| 46 | + |
| 47 | + |
| 48 | +class MergeByType(_MergeStrategy): |
| 49 | + """Merge running intervals based on the dispatch type.""" |
| 50 | + |
| 51 | + def __init__(self) -> None: |
| 52 | + """Initialize the strategy.""" |
| 53 | + self._scheduler: DispatchScheduler |
| 54 | + self._new_dispatch: Dispatch |
| 55 | + |
| 56 | + def _get_filter_function( |
| 57 | + self, scheduler: DispatchScheduler |
| 58 | + ) -> Callable[[Dispatch], bool]: |
| 59 | + """Get a filter function for dispatches.""" |
| 60 | + self._scheduler = scheduler |
| 61 | + return self._filter_func |
| 62 | + |
| 63 | + def _criteria(self, dispatch: Dispatch) -> bool: |
| 64 | + """Define the criteria for checking other running dispatches.""" |
| 65 | + return dispatch.type == self._new_dispatch.type |
| 66 | + |
| 67 | + def _filter_func(self, new_dispatch: Dispatch) -> bool: |
| 68 | + """Filter dispatches based on the merge strategy. |
| 69 | +
|
| 70 | + Keeps start events. |
| 71 | + Keeps stop events only if no other dispatches matching the |
| 72 | + strategy's criteria are running. |
| 73 | + """ |
| 74 | + if new_dispatch.started: |
| 75 | + return True |
| 76 | + |
| 77 | + self._new_dispatch = new_dispatch |
| 78 | + |
| 79 | + # pylint: disable=protected-access |
| 80 | + other_dispatches_running = any( |
| 81 | + dispatch.started |
| 82 | + for dispatch in self._scheduler._dispatches.values() |
| 83 | + if self._criteria(dispatch) |
| 84 | + ) |
| 85 | + # pylint: enable=protected-access |
| 86 | + |
| 87 | + return not other_dispatches_running |
| 88 | + |
| 89 | + |
| 90 | +class MergeByTypeTarget(MergeByType): |
| 91 | + """Merge running intervals based on the dispatch type and target.""" |
| 92 | + |
| 93 | + def _criteria(self, dispatch: Dispatch) -> bool: |
| 94 | + """Define the criteria for checking other running dispatches.""" |
| 95 | + return ( |
| 96 | + dispatch.type == self._new_dispatch.type |
| 97 | + and dispatch.target == self._new_dispatch.target |
| 98 | + ) |
| 99 | + |
| 100 | + |
26 | 101 | # pylint: disable=too-many-instance-attributes
|
27 | 102 | class DispatchScheduler(BackgroundService):
|
28 | 103 | """Dispatch background service.
|
@@ -119,54 +194,35 @@ def new_lifecycle_events_receiver(self, type: str) -> Receiver[DispatchEvent]:
|
119 | 194 | )
|
120 | 195 |
|
121 | 196 | async def new_running_state_event_receiver(
|
122 |
| - self, type: str, *, unify_running_intervals: bool = True |
| 197 | + self, type: str, *, merge_strategy: _MergeStrategy | None = None |
123 | 198 | ) -> Receiver[Dispatch]:
|
124 | 199 | """Create a new receiver for running state events of the specified type.
|
125 | 200 |
|
126 |
| - If `unify_running_intervals` is True, running intervals from multiple |
127 |
| - dispatches of the same type are considered as one continuous running |
128 |
| - period. In this mode, any stop events are ignored as long as at least |
129 |
| - one dispatch remains active. |
| 201 | + `merge_strategy` can be one of `MergeByType` or `MergeByTypeTarget`. |
| 202 | + If set, running intervals from multiple dispatches will be merged, |
| 203 | + depending on the chosen strategy. |
| 204 | + When merging, stop events are ignored as long as at least one |
| 205 | + merge-criteria-matching dispatch remains active. |
130 | 206 |
|
131 | 207 | Args:
|
132 | 208 | type: The type of events to receive.
|
133 |
| - unify_running_intervals: Whether to unify running intervals. |
134 |
| -
|
| 209 | + merge_strategy: The merge strategy to use. |
135 | 210 | Returns:
|
136 | 211 | A new receiver for running state status.
|
137 | 212 | """
|
138 |
| - # Find all matching dispatches based on the type and collect them |
139 | 213 | dispatches = [
|
140 | 214 | dispatch for dispatch in self._dispatches.values() if dispatch.type == type
|
141 | 215 | ]
|
142 | 216 |
|
143 |
| - # Create receiver with enough capacity to hold all matching dispatches |
144 | 217 | receiver = self._running_state_status_channel.new_receiver(
|
145 | 218 | limit=max(1, len(dispatches))
|
146 | 219 | ).filter(lambda dispatch: dispatch.type == type)
|
147 | 220 |
|
148 |
| - if unify_running_intervals: |
149 |
| - |
150 |
| - def _is_type_still_running(new_dispatch: Dispatch) -> bool: |
151 |
| - """Merge time windows of running dispatches. |
152 |
| -
|
153 |
| - Any event that would cause a stop is filtered if at least one |
154 |
| - dispatch of the same type is running. |
155 |
| - """ |
156 |
| - if new_dispatch.started: |
157 |
| - return True |
158 |
| - |
159 |
| - other_dispatches_running = any( |
160 |
| - dispatch.started |
161 |
| - for dispatch in self._dispatches.values() |
162 |
| - if dispatch.type == type |
163 |
| - ) |
164 |
| - # If no other dispatches are running, we can allow the stop event |
165 |
| - return not other_dispatches_running |
166 |
| - |
167 |
| - receiver = receiver.filter(_is_type_still_running) |
| 221 | + if merge_strategy: |
| 222 | + # pylint: disable=protected-access |
| 223 | + receiver = receiver.filter(merge_strategy._get_filter_function(self)) |
| 224 | + # pylint: enable=protected-access |
168 | 225 |
|
169 |
| - # Send all matching dispatches to the receiver |
170 | 226 | for dispatch in dispatches:
|
171 | 227 | await self._send_running_state_change(dispatch)
|
172 | 228 |
|
|
0 commit comments