Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions docs/task_on_kart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,24 @@ If you want to dump csv file with other encodings, you can use `encoding` parame
def output(self):
return self.make_target('file_name.csv', processor=CsvFileProcessor(encoding='cp932'))
# This will dump csv as 'cp932' which is used in Windows.

Cache output in memory instead of dumping to files
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
You can use :class:`~InMemoryTarget` to cache output in memory instead of dumping to files by calling :func:`~gokart.target.make_inmemory_target`.

.. code:: python

from gokart.in_memory.target import make_inmemory_target

def output(self):
unique_id = self.make_unique_id() if use_unique_id else None
# TaskLock is not supported in InMemoryTarget, so it's dummy
task_lock_params = make_task_lock_params(
file_path='dummy_path',
unique_id=unique_id,
redis_host=None,
redis_port=None,
redis_timeout=self.redis_timeout,
raise_task_lock_exception_on_collision=False,
)
return make_inmemory_target('dummy_path', task_lock_params, unique_id)
2 changes: 2 additions & 0 deletions gokart/in_memory/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .repository import InMemoryCacheRepository # noqa:F401
from .target import InMemoryTarget, make_inmemory_target # noqa:F401
16 changes: 16 additions & 0 deletions gokart/in_memory/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Protocol


class BaseData(Protocol): ...


@dataclass
class InMemoryData(BaseData):
value: Any
last_modification_time: datetime

@classmethod
def create_data(self, value: Any) -> 'InMemoryData':
return InMemoryData(value=value, last_modification_time=datetime.now())
103 changes: 103 additions & 0 deletions gokart/in_memory/repository.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from abc import ABC, abstractmethod
from typing import Any, Iterator

from .data import InMemoryData


class CacheScheduler(ABC):
def __new__(cls):
if not hasattr(cls, '__instance'):
setattr(cls, '__instance', super().__new__(cls))
return getattr(cls, '__instance')

@abstractmethod
def get_hook(self, repository: 'InMemoryCacheRepository', key: str, data: InMemoryData): ...

@abstractmethod
def set_hook(self, repository: 'InMemoryCacheRepository', key: str, data: InMemoryData): ...

@abstractmethod
def clear(self): ...


class DoNothingScheduler(CacheScheduler):
def get_hook(self, repository: 'InMemoryCacheRepository', key: str, data: InMemoryData):
pass

def set_hook(self, repository: 'InMemoryCacheRepository', key: str, data: InMemoryData):
pass

def clear(self):
pass


# TODO: ambiguous class name
class InstantScheduler(CacheScheduler):
def get_hook(self, repository: 'InMemoryCacheRepository', key: str, data: InMemoryData):
repository.remove(key)

def set_hook(self, repository: 'InMemoryCacheRepository', key: str, data: InMemoryData):
pass

def clear(self):
pass


class InMemoryCacheRepository:
_cache: dict[str, InMemoryData] = {}
_scheduler: CacheScheduler = DoNothingScheduler()

def __new__(cls):
if not hasattr(cls, '__instance'):
cls.__instance = super().__new__(cls)
return cls.__instance

@classmethod
def set_scheduler(cls, cache_scheduler: CacheScheduler):
cls._scheduler = cache_scheduler

def get_value(self, key: str) -> Any:
data = self._get_data(key)
self._scheduler.get_hook(self, key, data)
return data.value

def get_last_modification_time(self, key: str):
return self._get_data(key).last_modification_time

def _get_data(self, key: str) -> InMemoryData:
return self._cache[key]

def set_value(self, key: str, obj: Any) -> None:
data = InMemoryData.create_data(obj)
self._scheduler.set_hook(self, key, data)
self._set_data(key, data)

def _set_data(self, key: str, data: InMemoryData):
self._cache[key] = data

def has(self, key: str) -> bool:
return key in self._cache

def remove(self, key: str) -> None:
assert self.has(key), f'{key} does not exist.'
del self._cache[key]

def empty(self) -> bool:
return not self._cache

@classmethod
def clear(cls) -> None:
cls._cache.clear()
cls._scheduler.clear()

def get_gen(self) -> Iterator[tuple[str, Any]]:
for key, data in self._cache.items():
yield key, data.value

@property
def size(self) -> int:
return len(self._cache)

@property
def scheduler(self) -> CacheScheduler:
return self._scheduler
54 changes: 54 additions & 0 deletions gokart/in_memory/target.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from datetime import datetime
from logging import warning
from typing import Any, Optional

from gokart.in_memory.repository import InMemoryCacheRepository
from gokart.target import TargetOnKart, TaskLockParams

_repository = InMemoryCacheRepository()


class InMemoryTarget(TargetOnKart):
def __init__(self, data_key: str, task_lock_param: TaskLockParams):
if task_lock_param.should_task_lock:
warning(f'Redis in {self.__class__.__name__} is not supported now.')

self._data_key = data_key
self._task_lock_params = task_lock_param
self._repository = InMemoryCacheRepository()

def _exists(self) -> bool:
return self._repository.has(self._data_key)

def _get_task_lock_params(self) -> TaskLockParams:
return self._task_lock_params

def _load(self) -> Any:
return self._repository.get_value(self._data_key)

def _dump(self, obj: Any) -> None:
return self._repository.set_value(self._data_key, obj)

def _remove(self) -> None:
self._repository.remove(self._data_key)

def _last_modification_time(self) -> datetime:
if not self._repository.has(self._data_key):
raise ValueError(f'No object(s) which id is {self._data_key} are stored before.')
time = self._repository.get_last_modification_time(self._data_key)
return time

def _path(self) -> str:
# TODO: this module name `_path` migit not be appropriate
return self._data_key


def _make_data_key(data_key: str, unique_id: Optional[str] = None):
if not unique_id:
return data_key
return data_key + '_' + unique_id


def make_inmemory_target(data_key: str, task_lock_params: TaskLockParams, unique_id: Optional[str] = None):
_data_key = _make_data_key(data_key, unique_id)
return InMemoryTarget(_data_key, task_lock_params)
25 changes: 23 additions & 2 deletions gokart/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@

import gokart
import gokart.target
from gokart.conflict_prevention_lock.task_lock import make_task_lock_params, make_task_lock_params_for_run
from gokart.conflict_prevention_lock.task_lock import TaskLockParams, make_task_lock_params, make_task_lock_params_for_run
from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_run_with_lock
from gokart.file_processor import FileProcessor
from gokart.in_memory.target import make_inmemory_target
from gokart.pandas_type_config import PandasTypeConfigMap
from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter
from gokart.target import TargetOnKart
Expand Down Expand Up @@ -105,6 +106,9 @@ class TaskOnKart(luigi.Task, Generic[T]):
default=True, description='Check if output file exists at run. If exists, run() will be skipped.', significant=False
)
should_lock_run: bool = ExplicitBoolParameter(default=False, significant=False, description='Whether to use redis lock or not at task run.')
cache_in_memory_by_default: bool = ExplicitBoolParameter(
default=False, significant=False, description='If `True`, output is stored on a memory instead of files unless specified.'
)

@property
def priority(self):
Expand Down Expand Up @@ -134,11 +138,13 @@ def __init__(self, *args, **kwargs):
task_lock_params = make_task_lock_params_for_run(task_self=self)
self.run = wrap_run_with_lock(run_func=self.run, task_lock_params=task_lock_params) # type: ignore

self.make_default_target = self.make_target if not self.cache_in_memory_by_default else self.make_cache_target

def input(self) -> FlattenableItems[TargetOnKart]:
return super().input()

def output(self) -> FlattenableItems[TargetOnKart]:
return self.make_target()
return self.make_default_target()

def requires(self) -> FlattenableItems['TaskOnKart']:
tasks = self.make_task_instance_dictionary()
Expand Down Expand Up @@ -229,6 +235,21 @@ def make_target(self, relative_file_path: Optional[str] = None, use_unique_id: b
file_path=file_path, unique_id=unique_id, processor=processor, task_lock_params=task_lock_params, store_index_in_feather=self.store_index_in_feather
)

def make_cache_target(self, data_key: Optional[str] = None, use_unique_id: bool = True):
_data_key = data_key if data_key else os.path.join(self.__module__.replace('.', '/'), type(self).__name__)
unique_id = self.make_unique_id() if use_unique_id else None
# TODO: combine with redis
task_lock_params = TaskLockParams(
redis_host=None,
redis_port=None,
redis_timeout=None,
redis_key='redis_key',
should_task_lock=False,
raise_task_lock_exception_on_collision=False,
lock_extend_seconds=-1,
)
return make_inmemory_target(_data_key, task_lock_params, unique_id)

def make_large_data_frame_target(self, relative_file_path: Optional[str] = None, use_unique_id: bool = True, max_byte=int(2**26)) -> TargetOnKart:
formatted_relative_file_path = (
relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace('.', '/'), f'{type(self).__name__}.zip')
Expand Down
56 changes: 56 additions & 0 deletions test/in_memory/test_in_memory_target.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from datetime import datetime
from time import sleep

import pytest

from gokart.conflict_prevention_lock.task_lock import TaskLockParams
from gokart.in_memory import InMemoryCacheRepository, InMemoryTarget, make_inmemory_target


class TestInMemoryTarget:
@pytest.fixture
def task_lock_params(self):
return TaskLockParams(
redis_host=None,
redis_port=None,
redis_timeout=None,
redis_key='dummy',
should_task_lock=False,
raise_task_lock_exception_on_collision=False,
lock_extend_seconds=0,
)

@pytest.fixture
def target(self, task_lock_params: TaskLockParams):
return make_inmemory_target(data_key='dummy_key', task_lock_params=task_lock_params)

@pytest.fixture(autouse=True)
def clear_repo(self):
InMemoryCacheRepository().clear()

def test_dump_and_load_data(self, target: InMemoryTarget):
dumped = 'dummy_data'
target.dump(dumped)
loaded = target.load()
assert loaded == dumped

def test_exist(self, target: InMemoryTarget):
assert not target.exists()
target.dump('dummy_data')
assert target.exists()

def test_last_modified_time(self, target: InMemoryTarget):
input = 'dummy_data'
target.dump(input)
time = target.last_modification_time()
assert isinstance(time, datetime)

sleep(0.1)
another_input = 'another_data'
target.dump(another_input)
another_time = target.last_modification_time()
assert time < another_time

target.remove()
with pytest.raises(ValueError):
assert target.last_modification_time()
Loading