diff --git a/src/baskerville/models/config.py b/src/baskerville/models/config.py
index 8e8cc9f5..c5e2c5e1 100644
--- a/src/baskerville/models/config.py
+++ b/src/baskerville/models/config.py
@@ -246,6 +246,7 @@ class EngineConfig(Config):
     simulation = None
     datetime_format = '%Y-%m-%d %H:%M:%S'
     cache_path = None
+    save_cache_to_storage = False
     storage_path = None
    cache_expire_time = None
     cache_load_past = False
diff --git a/src/baskerville/models/pipeline_tasks/service_provider.py b/src/baskerville/models/pipeline_tasks/service_provider.py
index cfabdb27..b7c69c8a 100644
--- a/src/baskerville/models/pipeline_tasks/service_provider.py
+++ b/src/baskerville/models/pipeline_tasks/service_provider.py
@@ -106,6 +106,7 @@ def initialize_request_set_cache_service(self):
            expire_if_longer_than=self.config.engine.cache_expire_time,
            path=os.path.join(self.config.engine.storage_path,
                              FOLDER_CACHE),
+            save_to_storage=self.config.engine.save_cache_to_storage,
            logger=self.logger
        )
        if self.config.engine.cache_load_past:
diff --git a/src/baskerville/models/request_set_cache.py b/src/baskerville/models/request_set_cache.py
index b168e47d..9e5c6da6 100644
--- a/src/baskerville/models/request_set_cache.py
+++ b/src/baskerville/models/request_set_cache.py
@@ -31,7 +31,8 @@ def __init__(
             session_getter=get_spark_session,
             group_by_fields=('target', 'ip'),
             format_='parquet',
-            path='request_set_cache'
+            path='request_set_cache',
+            save_to_storage=False,
     ):
         self.__cache = None
         self.__persistent_cache = None
@@ -54,17 +55,20 @@ def __init__(
         self._count = 0
         self._last_updated = datetime.datetime.utcnow()
         self._changed = False
-        self.file_manager = FileManager(path, self.session_getter())
+        self._save_to_storage = save_to_storage
 
-        self.file_name = os.path.join(
-            path, f'{self.__class__.__name__}.{self.format_}')
-        self.temp_file_name = os.path.join(
-            path, f'{self.__class__.__name__}temp.{self.format_}')
+        if self._save_to_storage:
+            self.file_manager = FileManager(path, self.session_getter())
 
-        if self.file_manager.path_exists(self.file_name):
-            self.file_manager.delete_path(self.file_name)
-        if self.file_manager.path_exists(self.temp_file_name):
-            self.file_manager.delete_path(self.temp_file_name)
+            self.file_name = os.path.join(
+                path, f'{self.__class__.__name__}.{self.format_}')
+            self.temp_file_name = os.path.join(
+                path, f'{self.__class__.__name__}temp.{self.format_}')
+
+            if self.file_manager.path_exists(self.file_name):
+                self.file_manager.delete_path(self.file_name)
+            if self.file_manager.path_exists(self.temp_file_name):
+                self.file_manager.delete_path(self.temp_file_name)
 
     @property
     def cache(self):
@@ -242,28 +246,42 @@ def filter_by(self, df, columns=None):
         if not columns:
             columns = df.columns
 
-        if self.file_manager.path_exists(self.persistent_cache_file):
-            self.__cache = self.session_getter().read.format(
-                self.format_
-            ).load(self.persistent_cache_file).join(
-                df,
-                on=columns,
-                how='inner'
-            ).drop(
-                'a.ip'
-            ) #.persist(self.storage_level)
+        if self._save_to_storage:
+            if self.file_manager.path_exists(self.persistent_cache_file):
+                self.__cache = self.session_getter().read.format(
+                    self.format_
+                ).load(self.persistent_cache_file).join(
+                    df,
+                    on=columns,
+                    how='inner'
+                ).drop(
+                    'a.ip'
+                ) #.persist(self.storage_level)
+            else:
+                if self.__cache:
+                    self.__cache = self.__cache.join(
+                        df,
+                        on=columns,
+                        how='inner'
+                    ).drop(
+                        'a.ip'
+                    )# .persist(self.storage_level)
+                else:
+                    self.load_empty(self.schema)
         else:
-            if self.__cache:
-                self.__cache = self.__cache.join(
+            # memory only, no saving to parquet
+            if self.__persistent_cache:
+                self.__cache = self.__persistent_cache.join(
                     df,
                     on=columns,
                     how='inner'
                 ).drop(
                     'a.ip'
-                )# .persist(self.storage_level)
+                ) # .persist(self.storage_level)
             else:
                 self.load_empty(self.schema)
 
+
         # if self.__persistent_cache:
         #     self.__cache = self.__persistent_cache.join(
         #         df,
@@ -311,7 +329,7 @@ def update_self(
         source_df = source_df.select(columns)
 
         # read the whole thing again
-        if self.file_manager.path_exists(self.file_name):
+        if self._save_to_storage and self.file_manager.path_exists(self.file_name):
             if self.__persistent_cache:
                 self.__persistent_cache.unpersist()
             self.__persistent_cache = self.session_getter().read.format(
@@ -364,17 +382,24 @@ def update_self(
             '*'
         ).where(F.col('updated_at') >= update_date)
 
-        # write back to parquet - different file/folder though
-        # because self.parquet_name is already in use
-        # rename temp to self.parquet_name
-        if self.file_manager.path_exists(self.temp_file_name):
-            self.file_manager.delete_path(self.temp_file_name)
+        if self._save_to_storage:
+            # write back to parquet - different file/folder though
+            # because self.parquet_name is already in use
+            # rename temp to self.parquet_name
+            if self.file_manager.path_exists(self.temp_file_name):
+                self.file_manager.delete_path(self.temp_file_name)
 
-        self.__persistent_cache.write.mode(
-            'overwrite'
-        ).format(
-            self.format_
-        ).save(self.temp_file_name)
+            self.__persistent_cache.write.mode(
+                'overwrite'
+            ).format(
+                self.format_
+            ).save(self.temp_file_name)
+
+            # rename temp to self.parquet_name
+            if self.file_manager.path_exists(self.file_name):
+                self.file_manager.delete_path(self.file_name)
+
+            self.file_manager.rename_path(self.temp_file_name, self.file_name)
 
         # we don't need anything in memory anymore
         source_df.unpersist(blocking=True)
@@ -382,12 +407,6 @@
         del source_df
         self.empty_all()
-        # rename temp to self.parquet_name
-        if self.file_manager.path_exists(self.file_name):
-            self.file_manager.delete_path(self.file_name)
-
-        self.file_manager.rename_path(self.temp_file_name, self.file_name)
-
     def refresh(self, update_date, hosts, extra_filters=None):
         df = self._load(
             update_date=update_date, hosts=hosts, extra_filters=extra_filters
         )
@@ -452,11 +471,13 @@ def empty(self):
 
     def empty_all(self):
         if self.__cache is not None:
             self.__cache.unpersist(blocking=True)
-        if self.__persistent_cache is not None:
-            self.__persistent_cache.unpersist(blocking=True)
-        self.__cache = None
-        self.__persistent_cache = None
+
+        if self._save_to_storage:
+            if self.__persistent_cache is not None:
+                self.__persistent_cache.unpersist(blocking=True)
+            self.__persistent_cache = None
+
         gc.collect()
         self.session_getter().sparkContext._jvm.System.gc()