diff --git a/docs/task_on_kart.rst b/docs/task_on_kart.rst index ce52c6d5..09e7a59f 100644 --- a/docs/task_on_kart.rst +++ b/docs/task_on_kart.rst @@ -286,3 +286,24 @@ If you want to dump csv file with other encodings, you can use `encoding` parame def output(self): return self.make_target('file_name.csv', processor=CsvFileProcessor(encoding='cp932')) # This will dump csv as 'cp932' which is used in Windows. + +Cache output in memory instead of dumping to files +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +You can use :class:`~gokart.in_memory.target.InMemoryTarget` to cache output in memory instead of dumping to files by calling :func:`~gokart.in_memory.target.make_inmemory_target`. + +.. code:: python + + from gokart.in_memory.target import make_inmemory_target + + def output(self): + unique_id = self.make_unique_id() if use_unique_id else None + # TaskLock is not supported in InMemoryTarget, so it's dummy + task_lock_params = make_task_lock_params( + file_path='dummy_path', + unique_id=unique_id, + redis_host=None, + redis_port=None, + redis_timeout=self.redis_timeout, + raise_task_lock_exception_on_collision=False, + ) + return make_inmemory_target('dummy_path', task_lock_params, unique_id) \ No newline at end of file diff --git a/gokart/in_memory/__init__.py b/gokart/in_memory/__init__.py new file mode 100644 index 00000000..69e7e4c3 --- /dev/null +++ b/gokart/in_memory/__init__.py @@ -0,0 +1,2 @@ +from .repository import InMemoryCacheRepository # noqa:F401 +from .target import InMemoryTarget, make_inmemory_target # noqa:F401 diff --git a/gokart/in_memory/data.py b/gokart/in_memory/data.py new file mode 100644 index 00000000..a01c3ad2 --- /dev/null +++ b/gokart/in_memory/data.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Protocol + + +class BaseData(Protocol): ... 
class InMemoryCacheRepository:
    """Singleton key/value store that keeps task outputs in memory.

    ``_cache`` and ``_scheduler`` are class attributes, so the stored data and
    the active scheduler are shared by everything that uses this class.
    """

    # key -> InMemoryData (the stored value plus the time it was stored).
    _cache: dict[str, InMemoryData] = {}
    # Hook object notified on every get_value/set_value; swap it with set_scheduler().
    _scheduler: CacheScheduler = DoNothingScheduler()

    def __new__(cls):
        """Return the shared instance of ``cls``, creating it on the first call."""
        # An attribute written as ``cls.__instance`` inside a class body is
        # name-mangled, while a string handed to ``hasattr`` is not, so such a
        # check can never succeed and a fresh object would be built on every
        # call. A single-underscore name is not mangled; looking in
        # ``cls.__dict__`` keeps one instance per class rather than letting a
        # subclass pick up its parent's instance.
        if '_instance' not in cls.__dict__:
            cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def set_scheduler(cls, cache_scheduler: CacheScheduler):
        """Replace the scheduler whose hooks run on every get/set."""
        cls._scheduler = cache_scheduler

    def get_value(self, key: str) -> Any:
        """Return the value stored under ``key``.

        The scheduler's ``get_hook`` runs after the lookup and before the
        value is returned.

        Raises:
            KeyError: if ``key`` is not stored.
        """
        data = self._get_data(key)
        # The hook receives the repository itself and may modify it (e.g.
        # remove ``key``); ``data`` was fetched beforehand, so the value can
        # still be returned.
        self._scheduler.get_hook(self, key, data)
        return data.value

    def get_last_modification_time(self, key: str):
        """Return the timestamp recorded when ``key`` was last stored.

        Raises:
            KeyError: if ``key`` is not stored.
        """
        return self._get_data(key).last_modification_time

    def _get_data(self, key: str) -> InMemoryData:
        """Return the raw ``InMemoryData`` entry for ``key`` (``KeyError`` if absent)."""
        return self._cache[key]

    def set_value(self, key: str, obj: Any) -> None:
        """Store ``obj`` under ``key``, replacing any previous entry.

        The scheduler's ``set_hook`` runs before the entry is written.
        """
        data = InMemoryData.create_data(obj)
        self._scheduler.set_hook(self, key, data)
        self._set_data(key, data)

    def _set_data(self, key: str, data: InMemoryData):
        """Write ``data`` into the shared cache under ``key``."""
        self._cache[key] = data

    def has(self, key: str) -> bool:
        """Return ``True`` when ``key`` is currently stored."""
        return key in self._cache

    def remove(self, key: str) -> None:
        """Delete the entry stored under ``key``.

        Raises:
            AssertionError: if ``key`` is not stored.
        """
        # Raised explicitly instead of via an ``assert`` statement so the
        # check is still performed when Python runs with ``-O``.
        if key not in self._cache:
            raise AssertionError(f'{key} does not exist.')
        del self._cache[key]

    def empty(self) -> bool:
        """Return ``True`` when nothing is stored."""
        return not self._cache

    @classmethod
    def clear(cls) -> None:
        """Drop every stored entry and call ``clear`` on the active scheduler."""
        cls._cache.clear()
        cls._scheduler.clear()

    def get_gen(self) -> Iterator[tuple[str, Any]]:
        """Yield ``(key, value)`` pairs for every stored entry."""
        for key, data in self._cache.items():
            yield key, data.value

    @property
    def size(self) -> int:
        """Number of stored entries."""
        return len(self._cache)

    @property
    def scheduler(self) -> CacheScheduler:
        """The scheduler currently attached to the repository."""
        return self._scheduler
TaskLockParams: + return self._task_lock_params + + def _load(self) -> Any: + return self._repository.get_value(self._data_key) + + def _dump(self, obj: Any) -> None: + return self._repository.set_value(self._data_key, obj) + + def _remove(self) -> None: + self._repository.remove(self._data_key) + + def _last_modification_time(self) -> datetime: + if not self._repository.has(self._data_key): + raise ValueError(f'No object(s) which id is {self._data_key} are stored before.') + time = self._repository.get_last_modification_time(self._data_key) + return time + + def _path(self) -> str: + # TODO: this module name `_path` migit not be appropriate + return self._data_key + + +def _make_data_key(data_key: str, unique_id: Optional[str] = None): + if not unique_id: + return data_key + return data_key + '_' + unique_id + + +def make_inmemory_target(data_key: str, task_lock_params: TaskLockParams, unique_id: Optional[str] = None): + _data_key = _make_data_key(data_key, unique_id) + return InMemoryTarget(_data_key, task_lock_params) diff --git a/gokart/task.py b/gokart/task.py index f577f64b..8eac941d 100644 --- a/gokart/task.py +++ b/gokart/task.py @@ -20,9 +20,10 @@ import gokart import gokart.target -from gokart.conflict_prevention_lock.task_lock import make_task_lock_params, make_task_lock_params_for_run +from gokart.conflict_prevention_lock.task_lock import TaskLockParams, make_task_lock_params, make_task_lock_params_for_run from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_run_with_lock from gokart.file_processor import FileProcessor +from gokart.in_memory.target import make_inmemory_target from gokart.pandas_type_config import PandasTypeConfigMap from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter from gokart.target import TargetOnKart @@ -105,6 +106,9 @@ class TaskOnKart(luigi.Task, Generic[T]): default=True, description='Check if output file exists at run. 
If exists, run() will be skipped.', significant=False ) should_lock_run: bool = ExplicitBoolParameter(default=False, significant=False, description='Whether to use redis lock or not at task run.') + cache_in_memory_by_default: bool = ExplicitBoolParameter( + default=False, significant=False, description='If `True`, output is stored on a memory instead of files unless specified.' + ) @property def priority(self): @@ -134,11 +138,13 @@ def __init__(self, *args, **kwargs): task_lock_params = make_task_lock_params_for_run(task_self=self) self.run = wrap_run_with_lock(run_func=self.run, task_lock_params=task_lock_params) # type: ignore + self.make_default_target = self.make_target if not self.cache_in_memory_by_default else self.make_cache_target + def input(self) -> FlattenableItems[TargetOnKart]: return super().input() def output(self) -> FlattenableItems[TargetOnKart]: - return self.make_target() + return self.make_default_target() def requires(self) -> FlattenableItems['TaskOnKart']: tasks = self.make_task_instance_dictionary() @@ -229,6 +235,21 @@ def make_target(self, relative_file_path: Optional[str] = None, use_unique_id: b file_path=file_path, unique_id=unique_id, processor=processor, task_lock_params=task_lock_params, store_index_in_feather=self.store_index_in_feather ) + def make_cache_target(self, data_key: Optional[str] = None, use_unique_id: bool = True): + _data_key = data_key if data_key else os.path.join(self.__module__.replace('.', '/'), type(self).__name__) + unique_id = self.make_unique_id() if use_unique_id else None + # TODO: combine with redis + task_lock_params = TaskLockParams( + redis_host=None, + redis_port=None, + redis_timeout=None, + redis_key='redis_key', + should_task_lock=False, + raise_task_lock_exception_on_collision=False, + lock_extend_seconds=-1, + ) + return make_inmemory_target(_data_key, task_lock_params, unique_id) + def make_large_data_frame_target(self, relative_file_path: Optional[str] = None, use_unique_id: bool = True, 
class TestInMemoryTarget:
    """Behavioural tests for ``InMemoryTarget`` built via ``make_inmemory_target``."""

    @pytest.fixture
    def task_lock_params(self):
        """Lock parameters with task locking disabled."""
        return TaskLockParams(
            redis_host=None,
            redis_port=None,
            redis_timeout=None,
            redis_key='dummy',
            should_task_lock=False,
            raise_task_lock_exception_on_collision=False,
            lock_extend_seconds=0,
        )

    @pytest.fixture
    def target(self, task_lock_params: TaskLockParams):
        """An in-memory target stored under ``'dummy_key'``."""
        return make_inmemory_target(data_key='dummy_key', task_lock_params=task_lock_params)

    @pytest.fixture(autouse=True)
    def clear_repo(self):
        """Start every test from an empty repository."""
        InMemoryCacheRepository().clear()

    def test_dump_and_load_data(self, target: InMemoryTarget):
        dumped = 'dummy_data'
        target.dump(dumped)
        loaded = target.load()
        assert loaded == dumped

    def test_exist(self, target: InMemoryTarget):
        assert not target.exists()
        target.dump('dummy_data')
        assert target.exists()

    def test_last_modified_time(self, target: InMemoryTarget):
        # Local names deliberately avoid shadowing the builtin ``input``.
        first_value = 'dummy_data'
        target.dump(first_value)
        first_time = target.last_modification_time()
        assert isinstance(first_time, datetime)

        # Wait so the second dump gets a strictly later timestamp.
        sleep(0.1)
        second_value = 'another_data'
        target.dump(second_value)
        second_time = target.last_modification_time()
        assert first_time < second_time

        target.remove()
        # The call itself must raise; asserting on its result adds nothing.
        with pytest.raises(ValueError):
            target.last_modification_time()
class TestInMemoryCacheRepository:
    """Tests for the basic store/lookup/remove operations of the repository."""

    @pytest.fixture
    def repo(self):
        """A repository that has been emptied before the test runs."""
        repo = InMemoryCacheRepository()
        repo.clear()
        return repo

    def test_set(self, repo: InMemoryCacheRepository):
        repo.set_value('dummy_key', dummy_num)
        assert repo.size == 1
        # Compare the whole sequence: asserting inside a loop over the
        # generator would pass silently if it yielded nothing.
        assert list(repo.get_gen()) == [('dummy_key', dummy_num)]

        repo.set_value('another_key', 'another_value')
        assert repo.size == 2

    def test_get(self, repo: InMemoryCacheRepository):
        repo.set_value('dummy_key', dummy_num)
        repo.set_value('another_key', 'another_value')

        # Raise Error when key doesn't exist.
        with pytest.raises(KeyError):
            repo.get_value('not_exist_key')

        assert repo.get_value('dummy_key') == dummy_num
        assert repo.get_value('another_key') == 'another_value'

    def test_empty(self, repo: InMemoryCacheRepository):
        assert repo.empty()
        repo.set_value('dummmy_key', dummy_num)
        assert not repo.empty()

    def test_has(self, repo: InMemoryCacheRepository):
        assert not repo.has('dummy_key')
        repo.set_value('dummy_key', dummy_num)
        assert repo.has('dummy_key')
        assert not repo.has('not_exist_key')

    def test_remove(self, repo: InMemoryCacheRepository):
        repo.set_value('dummy_key', dummy_num)

        with pytest.raises(AssertionError):
            repo.remove('not_exist_key')

        repo.remove('dummy_key')
        assert not repo.has('dummy_key')

    def test_last_modification_time(self, repo: InMemoryCacheRepository):
        repo.set_value('dummy_key', dummy_num)
        date1 = repo.get_last_modification_time('dummy_key')
        # Wait so the second write gets a strictly later timestamp.
        time.sleep(0.1)
        repo.set_value('dummy_key', dummy_num)
        date2 = repo.get_last_modification_time('dummy_key')
        assert date1 < date2
class TestInstantScheduler:
    """Tests for ``InstantScheduler``, whose ``get_hook`` removes the key that was read."""

    @pytest.fixture(autouse=True)
    def set_scheduler(self):
        """Install ``InstantScheduler`` for the test and restore the previous one afterwards.

        The scheduler is stored on the repository class, so leaving it in
        place would change the behaviour of every test that runs later.
        """
        previous = InMemoryCacheRepository().scheduler
        InMemoryCacheRepository.set_scheduler(InstantScheduler())
        yield
        InMemoryCacheRepository.set_scheduler(previous)

    @pytest.fixture(autouse=True)
    def clear_cache(self):
        """Start every test from an empty repository."""
        InMemoryCacheRepository.clear()

    @pytest.fixture
    def repo(self):
        repo = InMemoryCacheRepository()
        return repo

    def test_identity(self):
        scheduler1 = InstantScheduler()
        scheduler2 = InstantScheduler()
        # Schedulers are singletons: both calls must give the same object.
        assert scheduler1 is scheduler2

    def test_scheduler_type(self, repo: InMemoryCacheRepository):
        assert isinstance(repo.scheduler, InstantScheduler)

    def test_volatility(self, repo: InMemoryCacheRepository):
        assert repo.empty()
        repo.set_value('dummy_key', 100)
        assert repo.has('dummy_key')
        # Reading the value once is enough to evict it.
        repo.get_value('dummy_key')
        assert not repo.has('dummy_key')
class TestTaskOnKartWithCache:
    """Tests for ``TaskOnKart.make_cache_target`` and the ``cache_in_memory_by_default`` flag."""

    @pytest.fixture(autouse=True)
    def clear_repository(self):
        """Start every test from an empty repository."""
        InMemoryCacheRepository().clear()

    @pytest.mark.parametrize('data_key', ['sample_key', None])
    @pytest.mark.parametrize('use_unique_id', [True, False])
    def test_key_identity(self, data_key: Optional[str], use_unique_id: bool):
        task = DummyTask(param='param')
        ext = '.pkl'
        relative_file_path = data_key + ext if data_key else None
        target = task.make_target(relative_file_path=relative_file_path, use_unique_id=use_unique_id)
        cached_target = task.make_cache_target(data_key=data_key, use_unique_id=use_unique_id)

        # The cache key must equal the file path with the workspace prefix
        # and the extension stripped.
        target_path = target.path().removeprefix(task.workspace_directory).removesuffix(ext).strip('/')
        assert cached_target.path() == target_path

    def test_make_cached_target(self):
        task = DummyTask(param='param')
        target = task.make_cache_target()
        assert isinstance(target, InMemoryTarget)

    @pytest.mark.parametrize(['cache_in_memory_by_default', 'target_type'], [[True, InMemoryTarget], [False, SingleFileTarget]])
    def test_make_default_target(self, cache_in_memory_by_default: bool, target_type: Type[Union[InMemoryTarget, SingleFileTarget]]):
        task = DummyTask(param='param', cache_in_memory_by_default=cache_in_memory_by_default)
        target = task.output()
        assert isinstance(target, target_type)

    def test_complete_with_cache_in_memory_flag(self, tmpdir):
        task = DummyTask(param='param', cache_in_memory_by_default=True, workspace_directory=tmpdir)
        assert not task.complete()
        # A file on disk must not count as completion when the flag is set.
        file_target = task.make_target()
        file_target.dump('data')
        assert not task.complete()
        cache_target = task.make_cache_target()
        cache_target.dump('data')
        assert task.complete()

    def test_complete_without_cache_in_memory_flag(self, tmpdir):
        task = DummyTask(param='param', workspace_directory=tmpdir)
        assert not task.complete()
        # A cached value must not count as completion when the flag is unset.
        cache_target = task.make_cache_target()
        cache_target.dump('data')
        assert not task.complete()
        file_target = task.make_target()
        file_target.dump('data')
        assert task.complete()

    def test_dump_with_cache_in_memory_flag(self, tmpdir):
        task = DummyTask(param='param', cache_in_memory_by_default=True, workspace_directory=tmpdir)
        file_target = task.make_target()
        cache_target = task.make_cache_target()
        task.dump('data')
        assert not file_target.exists()
        assert cache_target.exists()

    def test_dump_without_cache_in_memory_flag(self, tmpdir):
        task = DummyTask(param='param', workspace_directory=tmpdir)
        file_target = task.make_target()
        cache_target = task.make_cache_target()
        task.dump('data')
        assert file_target.exists()
        assert not cache_target.exists()

    def test_gokart_build(self):
        task = AddTask(
            a=DumpIntTask(value=2, cache_in_memory_by_default=True), b=DumpIntTask(value=3, cache_in_memory_by_default=True), cache_in_memory_by_default=True
        )
        output = gokart.build(task, reset_register=False)
        assert output == 5