Skip to content

Commit

Permalink
Add iteration_index
Browse files Browse the repository at this point in the history
  • Loading branch information
doruirimescu committed Jun 6, 2024
1 parent f1a7509 commit 02e08c5
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 33 deletions.
21 changes: 13 additions & 8 deletions src/stateful_data_processor/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from logging import Logger, getLogger
import signal
from stateful_data_processor.file_rw import FileRW
from typing import Optional, Any
from typing import Optional, Any, Collection


"""
Expand All @@ -24,7 +24,10 @@

class StatefulDataProcessor:
def __init__(
self, file_rw: FileRW, logger: Optional[Logger]=None, should_read: Optional[bool] = True
self,
file_rw: FileRW,
logger: Optional[Logger] = None,
should_read: Optional[bool] = True,
):
self.file_rw = file_rw
if logger is None:
Expand All @@ -51,13 +54,13 @@ def __init__(
signal.signal(signal.SIGTERM, self._signal_handler)

@abstractmethod
def process_data(self, items, *args, **kwargs):
def process_data(self, items: Collection[Any], *args, **kwargs):
    """Template method: hand *items* over to ``_iterate_items``.

    All extra positional and keyword arguments are forwarded unchanged.
    Override this method when a subclass needs to fetch or reshape the
    data before iteration.
    """
    self._iterate_items(items, *args, **kwargs)

def _iterate_items(self, items, *args, **kwargs):
def _iterate_items(self, items: Collection[Any], *args, **kwargs):
"""General iteration method for processing items. This should be called from process_data.
This method will iterate through the items and call process_item for each item.
If an item is already processed, it will skip it.
Expand All @@ -68,17 +71,19 @@ def _iterate_items(self, items, *args, **kwargs):
self.logger.info("All items already processed, skipping...")
return

for item in items:
for iteration_index, item in enumerate(items):
if item in self.data:
self.logger.info(f"Item {item} already processed, skipping...")
continue

self.process_item(item, *args, **kwargs)
self.process_item(item, iteration_index, *args, **kwargs)
self.logger.info(f"Processed item {item} {len(self.data)} / {items_len}")
self.logger.info("Finished processing all items.")

@abstractmethod
def process_item(self, item: Any, *args: Any, **kwargs: Any) -> Any:
def process_item(
    self, item: Any, iteration_index: int, *args: Any, **kwargs: Any
) -> Any:
    """Process one *item*.

    *iteration_index* is the zero-based position of the item within the
    iteration; extra arguments are whatever the caller passed to ``run``.
    Subclasses must implement this and record results in ``self.data``.
    """

Expand All @@ -89,7 +94,7 @@ def _signal_handler(self, signum, frame):
self.logger.info("Data saved, exiting.")
exit(0)

def run(self, items, *args, **kwargs):
def run(self, items: Collection[Any], *args, **kwargs):
"""Main method to run the processor."""
if not items:
self.logger.error("No items to process.")
Expand Down
90 changes: 65 additions & 25 deletions tests/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class QueueHandler(logging.Handler):
"""
This is a logging handler which sends log messages to a multiprocessing queue.
"""

def __init__(self, log_queue: multiprocessing.Queue):
super().__init__()
self.log_queue = log_queue
Expand All @@ -25,25 +26,36 @@ def emit(self, record):
except Exception:
self.handleError(record)

class SymbolGetter:
    """Stand-in for a component that would fetch symbols from a database or API."""

    def get_symbols(self) -> List[str]:
        """Return the fixed demo list of symbols."""
        return list("abc")


class SymbolProcessor(StatefulDataProcessor):
"""
This class processes a list of symbols.
"""
def process_item(self, item: str, delay=0.0, *args: Any, **kwargs: Any) -> None:

def process_item(
self, item: str, iteration_index: int, delay=0.0, *args: Any, **kwargs: Any
) -> None:
processed = item + "!"
self.data[item] = processed
time.sleep(delay)


class NumberProcessor(StatefulDataProcessor):
    """Processor that squares each number and prefixes it with a lookup letter."""

    LOOKUP = ["a", "b", "c"]

    def process_item(
        self, item: int, iteration_index: int, delay=0.0, *args: Any, **kwargs: Any
    ) -> None:
        """Square *item* and store it under the letter at *iteration_index*.

        The item and iteration_index come from the process_data method;
        delay is an argument that must be supplied by the user through the
        run method.
        """
        letter = self.LOOKUP[iteration_index]
        self.data[item] = f"{letter}{item * item}"
        time.sleep(delay)


class TestStatefulDataProcessor(unittest.TestCase):
def setUp(self):
self.file_rw = JsonFileRW(TEST_FILE_JSON_PATH)
Expand All @@ -56,7 +68,9 @@ def tearDown(self) -> None:
del self.file_rw

def test_items_must_be_unique(self):
processor = SymbolProcessor(self.file_rw, should_read=False, logger=self.mock_logger)
processor = SymbolProcessor(
self.file_rw, should_read=False, logger=self.mock_logger
)
processor.run(items=["a", "a", "b"], delay=0)
calls = [
call("Items must be unique."),
Expand All @@ -80,15 +94,15 @@ def test_process_data(self):
wait_for_file(TEST_FILE_JSON_PATH)

def test_processes_data_and_retrieves_completed_state_after_deletion(self):
processor = SymbolProcessor(
self.file_rw, should_read=False
)
processor = SymbolProcessor(self.file_rw, should_read=False)
processor.run(items=["a", "b", "c"], delay=0)

wait_for_file(TEST_FILE_JSON_PATH)
del processor

processor = SymbolProcessor(self.file_rw, should_read=True, logger=self.mock_logger)
processor = SymbolProcessor(
self.file_rw, should_read=True, logger=self.mock_logger
)
calls = [call(f"Read from file: {TEST_FILE_JSON_PATH} data of len 3")]
self.mock_logger.info.assert_has_calls(calls, any_order=True)
self.assertEqual(processor.data, {"a": "a!", "b": "b!", "c": "c!"})
Expand All @@ -98,15 +112,16 @@ def test_processes_data_and_retrieves_completed_state_after_deletion(self):
self.assertEqual(data, {"a": "a!", "b": "b!", "c": "c!"})

def test_skip_already_processed_items(self):
processor = SymbolProcessor(self.file_rw, should_read=False, logger=self.mock_logger)
processor = SymbolProcessor(
self.file_rw, should_read=False, logger=self.mock_logger
)
processor.run(items=["a", "b", "c"], delay=0)
self.assertEqual(processor.data, {"a": "a!", "b": "b!", "c": "c!"})

processor.run(items=["a", "b", "c"], delay=0)
calls = [call("All items already processed, skipping...")]
self.mock_logger.info.assert_has_calls(calls, any_order=True)


def test_resumes_after_termination_with_saved_state(self):
log_queue = multiprocessing.Queue()

Expand All @@ -116,11 +131,15 @@ def test_resumes_after_termination_with_saved_state(self):
queue_handler = QueueHandler(log_queue)
logger.addHandler(queue_handler)

symbol_processor = SymbolProcessor(self.file_rw, should_read=False, logger=logger)
symbol_processor = SymbolProcessor(
self.file_rw, should_read=False, logger=logger
)

# Add a large enough delay to ensure the process is terminated before it processes another item

p = multiprocessing.Process(target=symbol_processor.run, kwargs={"items": ["a", "b", "c"], "delay": 5})
p = multiprocessing.Process(
target=symbol_processor.run, kwargs={"items": ["a", "b", "c"], "delay": 5}
)
p.start()

# wait for process to start and process one item
Expand All @@ -135,14 +154,35 @@ def test_resumes_after_termination_with_saved_state(self):
while not log_queue.empty():
log_message = log_queue.get()
self.mock_logger.info(log_message)
calls=[call("Interrupt signal received, saving data..."),
call("Data saved, exiting.")]
calls = [
call("Interrupt signal received, saving data..."),
call("Data saved, exiting."),
]
self.mock_logger.info.assert_has_calls(calls, any_order=True)

processor = SymbolProcessor(self.file_rw, should_read=True, logger=self.mock_logger)
processor = SymbolProcessor(
self.file_rw, should_read=True, logger=self.mock_logger
)
processor.run(items=["a", "b", "c"], delay=0)
calls = [call("Item a already processed, skipping..."),
call("Processed item b 2 / 3"),
call("Processed item c 3 / 3"),
call("Finished processing all items.")]
calls = [
call("Item a already processed, skipping..."),
call("Processed item b 2 / 3"),
call("Processed item c 3 / 3"),
call("Finished processing all items."),
]
self.mock_logger.info.assert_has_calls(calls, any_order=True)

def test_number_processor(self):
    """Items are squared, letter-tagged by index, and progress is logged."""
    proc = NumberProcessor(self.file_rw, should_read=False, logger=self.mock_logger)
    proc.run(items=[1, 2, 3], delay=0)

    expected = {1: "a1", 2: "b4", 3: "c9"}
    self.assertEqual(proc.data, expected)

    expected_logs = [
        call("Processed item 1 1 / 3"),
        call("Processed item 2 2 / 3"),
        call("Processed item 3 3 / 3"),
        call("Finished processing all items."),
    ]
    self.mock_logger.info.assert_has_calls(expected_logs, any_order=True)

0 comments on commit 02e08c5

Please sign in to comment.