diff --git a/python/pathway/tests/test_persistence.py b/python/pathway/tests/test_persistence.py
index 327667d1..f0dbdf32 100644
--- a/python/pathway/tests/test_persistence.py
+++ b/python/pathway/tests/test_persistence.py
@@ -3,17 +3,24 @@
 import json
 import multiprocessing
 import os
+import pathlib
 import time
+from typing import Callable
 
+import pandas as pd
 import pytest
 
 import pathway as pw
+from pathway.internals import api
 from pathway.internals.parse_graph import G
 from pathway.tests.utils import (
     CsvPathwayChecker,
+    consolidate,
     needs_multiprocessing_fork,
+    run,
     wait_result_with_checker,
     write_csv,
+    write_lines,
 )
 
 
@@ -212,7 +219,7 @@ def pw_identity_program():
     pw.io.jsonlines.write(table, output_path)
     pw.run(persistence_config=persistence_config)
 
-    file_contents = {}
+    file_contents: dict[str, str] = {}
     next_file_contents = 0
     for sequence in scenario:
         expected_diffs = []
@@ -259,8 +266,303 @@ def pw_identity_program():
         actual_diffs = []
         with open(output_path, "r") as f:
             for row in f:
-                row = json.loads(row)
-                actual_diffs.append([row["data"], row["diff"]])
+                row_parsed = json.loads(row)
+                actual_diffs.append([row_parsed["data"], row_parsed["diff"]])
         actual_diffs.sort()
         expected_diffs.sort()
         assert actual_diffs == expected_diffs
+
+
+def combine_columns(df: pd.DataFrame) -> pd.Series:
+    result = None
+    for column in df.columns:
+        if column == "time":
+            continue
+        if result is None:
+            result = df[column].astype(str)
+        else:
+            result += "," + df[column].astype(str)
+    return result
+
+
+def get_one_table_runner(
+    tmp_path: pathlib.Path,
+    mode: api.PersistenceMode,
+    logic: Callable[[pw.Table], pw.Table],
+    schema: type[pw.Schema],
+) -> tuple[Callable[[list[str], set[str]], None], pathlib.Path]:
+    input_path = tmp_path / "1"
+    os.makedirs(input_path)
+    output_path = tmp_path / "out.csv"
+    persistent_storage_path = tmp_path / "p"
+    count = 0
+
+    def run_computation(inputs, expected):
+        nonlocal count
+        count += 1
+        G.clear()
+        path = input_path / str(count)
+        write_lines(path, inputs)
+        t_1 = pw.io.csv.read(input_path, schema=schema, mode="static")
+        res = logic(t_1)
+        pw.io.csv.write(res, output_path)
+        run(
+            persistence_config=pw.persistence.Config(
+                pw.persistence.Backend.filesystem(persistent_storage_path),
+                persistence_mode=mode,
+            )
+        )
+        result = consolidate(pd.read_csv(output_path))
+        assert set(combine_columns(result)) == expected
+
+    return run_computation, input_path
+
+
+def get_two_tables_runner(
+    tmp_path: pathlib.Path,
+    mode: api.PersistenceMode,
+    logic: Callable[[pw.Table, pw.Table], pw.Table],
+    schema: type[pw.Schema],
+    terminate_on_error: bool = True,
+) -> tuple[
+    Callable[[list[str], list[str], set[str]], None], pathlib.Path, pathlib.Path
+]:
+
+    input_path_1 = tmp_path / "1"
+    input_path_2 = tmp_path / "2"
+    os.makedirs(input_path_1)
+    os.makedirs(input_path_2)
+    output_path = tmp_path / "out.csv"
+    persistent_storage_path = tmp_path / "p"
+    count = 0
+
+    def run_computation(inputs_1, inputs_2, expected):
+        nonlocal count
+        count += 1
+        G.clear()
+        path_1 = input_path_1 / str(count)
+        path_2 = input_path_2 / str(count)
+        write_lines(path_1, inputs_1)
+        write_lines(path_2, inputs_2)
+        t_1 = pw.io.csv.read(input_path_1, schema=schema, mode="static")
+        t_2 = pw.io.csv.read(input_path_2, schema=schema, mode="static")
+        res = logic(t_1, t_2)
+        pw.io.csv.write(res, output_path)
+        run(
+            persistence_config=pw.persistence.Config(
+                pw.persistence.Backend.filesystem(persistent_storage_path),
+                persistence_mode=mode,
+            ),
+            terminate_on_error=terminate_on_error,
+            # hack to allow changes from different files to arrive at different points in time
+        )
+        result = consolidate(pd.read_csv(output_path))
+        assert set(combine_columns(result)) == expected
+
+    return run_computation, input_path_1, input_path_2
+
+
+@pytest.mark.parametrize(
+    "mode", [api.PersistenceMode.PERSISTING, api.PersistenceMode.OPERATOR_PERSISTING]
+)
+def test_restrict(tmp_path, mode):
+    class InputSchema(pw.Schema):
+        a: int = pw.column_definition(primary_key=True)
+
+    def logic(t_1: pw.Table, t_2: pw.Table) -> pw.Table:
+        t_2.promise_universe_is_subset_of(t_1)
+        return t_1.restrict(t_2)
+
+    run, _, input_path_2 = get_two_tables_runner(
+        tmp_path, mode, logic, InputSchema, terminate_on_error=False
+    )
+
+    run(["a", "1", "2", "3"], ["a", "1"], {"1,1"})
+    run(["a"], ["a", "3"], {"3,1"})
+    run(["a", "4", "5"], ["a", "5"], {"5,1"})
+    run(["a", "6"], ["a", "4", "6"], {"4,1", "6,1"})
+    os.remove(input_path_2 / "3")
+    run(["a"], ["a"], {"5,-1"})
+
+
+@pytest.mark.parametrize(
+    "mode", [api.PersistenceMode.PERSISTING, api.PersistenceMode.OPERATOR_PERSISTING]
+)
+def test_with_universe_of(tmp_path, mode):
+    class InputSchema(pw.Schema):
+        a: int = pw.column_definition(primary_key=True)
+        b: int
+
+    def logic(t_1: pw.Table, t_2: pw.Table) -> pw.Table:
+        return t_1.with_universe_of(t_2).with_columns(c=t_2.b)
+
+    run, input_path_1, input_path_2 = get_two_tables_runner(
+        tmp_path, mode, logic, InputSchema, terminate_on_error=False
+    )
+
+    run(["a,b", "1,2", "2,3"], ["a,b", "1,3", "2,4"], {"1,2,3,1", "2,3,4,1"})
+    run(["a,b", "3,3", "5,1"], ["a,b", "3,4", "5,0"], {"3,3,4,1", "5,1,0,1"})
+    os.remove(input_path_1 / "2")
+    os.remove(input_path_2 / "2")
+    run(
+        ["a,b", "3,4"],
+        ["a,b", "3,5"],
+        {
+            "3,3,4,-1",
+            "5,1,0,-1",
+            "3,4,5,1",
+        },
+    )
+
+
+@pytest.mark.parametrize(
+    "mode", [api.PersistenceMode.PERSISTING, api.PersistenceMode.OPERATOR_PERSISTING]
+)
+def test_intersect(tmp_path, mode):
+    class InputSchema(pw.Schema):
+        a: int = pw.column_definition(primary_key=True)
+
+    def logic(t_1: pw.Table, t_2: pw.Table) -> pw.Table:
+        return t_1.intersect(t_2)
+
+    run, _, input_path_2 = get_two_tables_runner(tmp_path, mode, logic, InputSchema)
+
+    run(["a", "1", "2", "3"], ["a", "1"], {"1,1"})
+    run(["a"], ["a", "3"], {"3,1"})
+    run(["a", "4", "5"], ["a", "5", "6"], {"5,1"})
+    run(["a", "6"], ["a", "4"], {"4,1", "6,1"})
+    os.remove(input_path_2 / "3")
+    run(["a"], ["a"], {"5,-1", "6,-1"})
+
+
+@pytest.mark.parametrize(
+    "mode", [api.PersistenceMode.PERSISTING, api.PersistenceMode.OPERATOR_PERSISTING]
+)
+def test_difference(tmp_path, mode):
+    class InputSchema(pw.Schema):
+        a: int = pw.column_definition(primary_key=True)
+
+    def logic(t_1: pw.Table, t_2: pw.Table) -> pw.Table:
+        return t_1.difference(t_2)
+
+    run, _, input_path_2 = get_two_tables_runner(tmp_path, mode, logic, InputSchema)
+
+    run(["a", "1", "2", "3"], ["a", "1"], {"2,1", "3,1"})
+    run(["a"], ["a", "3"], {"3,-1"})
+    run(["a", "4", "5"], ["a", "5", "6"], {"4,1"})
+    run(["a", "6"], ["a", "4"], {"4,-1"})
+    os.remove(input_path_2 / "3")
+    run(["a"], ["a"], {"5,1", "6,1"})
+
+
+@pytest.mark.parametrize(
+    "mode", [api.PersistenceMode.PERSISTING, api.PersistenceMode.OPERATOR_PERSISTING]
+)
+def test_sorting_ix(tmp_path, mode):
+    class InputSchema(pw.Schema):
+        a: int = pw.column_definition(primary_key=True)
+
+    def logic(t_1: pw.Table) -> pw.Table:
+        t_1 += t_1.sort(pw.this.a)
+        t_1_filtered = t_1.filter(pw.this.prev.is_not_none())
+        return t_1_filtered.select(b=t_1.ix(pw.this.prev).a, a=pw.this.a)
+
+    run, input_path = get_one_table_runner(tmp_path, mode, logic, InputSchema)
+
+    run(["a", "1", "6"], {"1,6,1"})
+    run(["a", "3"], {"1,6,-1", "1,3,1", "3,6,1"})
+    run(["a", "4", "5"], {"3,6,-1", "3,4,1", "4,5,1", "5,6,1"})
+    os.remove(input_path / "2")
+    run(["a"], {"1,3,-1", "3,4,-1", "1,4,1"})
+    run(["a", "2"], {"1,4,-1", "1,2,1", "2,4,1"})
+
+
+@pytest.mark.parametrize(
+    "mode", [api.PersistenceMode.PERSISTING, api.PersistenceMode.OPERATOR_PERSISTING]
+)
+def test_update_rows(tmp_path, mode):
+    class InputSchema(pw.Schema):
+        a: int = pw.column_definition(primary_key=True)
+        b: int
+
+    def logic(t_1: pw.Table, t_2: pw.Table) -> pw.Table:
+        return t_1.update_rows(t_2)
+
+    run, _, input_path_2 = get_two_tables_runner(tmp_path, mode, logic, InputSchema)
+
+    run(["a,b", "1,2", "2,4"], ["a,b", "1,3", "3,5"], {"1,3,1", "2,4,1", "3,5,1"})
+    run(["a,b", "3,3"], ["a,b", "2,6", "5,1"], {"2,4,-1", "2,6,1", "5,1,1"})
+    os.remove(input_path_2 / "1")
+    run(["a,b"], ["a,b"], {"3,5,-1", "3,3,1", "1,3,-1", "1,2,1"})
+    run(["a,b", "7,10"], ["a,b", "3,8"], {"3,3,-1", "3,8,1", "7,10,1"})
+
+
+@pytest.mark.parametrize(
+    "mode", [api.PersistenceMode.PERSISTING, api.PersistenceMode.OPERATOR_PERSISTING]
+)
+def test_update_cells(tmp_path, mode):
+    class InputSchema(pw.Schema):
+        a: int = pw.column_definition(primary_key=True)
+        b: int
+
+    def logic(t_1: pw.Table, t_2: pw.Table) -> pw.Table:
+        t_2.promise_universe_is_subset_of(t_1)
+        return t_1.update_cells(t_2)
+
+    run, _, input_path_2 = get_two_tables_runner(
+        tmp_path, mode, logic, InputSchema, terminate_on_error=False
+    )
+
+    run(["a,b", "1,2", "2,4"], ["a,b", "1,3"], {"1,3,1", "2,4,1"})
+    run(["a,b", "3,3"], ["a,b", "2,6"], {"2,4,-1", "2,6,1", "3,3,1"})
+    os.remove(input_path_2 / "1")
+    run(["a,b"], ["a,b"], {"1,3,-1", "1,2,1"})
+    run(["a,b", "7,10"], ["a,b", "3,8"], {"3,3,-1", "3,8,1", "7,10,1"})
+
+
+@pytest.mark.parametrize(
+    "mode", [api.PersistenceMode.PERSISTING, api.PersistenceMode.OPERATOR_PERSISTING]
+)
+def test_join(tmp_path, mode):
+    class InputSchema(pw.Schema):
+        a: int = pw.column_definition(primary_key=True)
+        b: int
+
+    def logic(t_1: pw.Table, t_2: pw.Table) -> pw.Table:
+        return t_1.join(t_2, t_1.a == t_2.a).select(
+            pw.this.a, b=pw.left.b, c=pw.right.b
+        )
+
+    run, _, input_path_2 = get_two_tables_runner(tmp_path, mode, logic, InputSchema)
+
+    run(["a,b", "1,2", "2,4"], ["a,b", "1,3"], {"1,2,3,1"})
+    run(["a,b", "3,3"], ["a,b", "2,6", "1,4"], {"2,4,6,1", "1,2,4,1"})
+    os.remove(input_path_2 / "1")
+    run(["a,b"], ["a,b"], {"1,2,3,-1"})
+    run(["a,b", "1,4"], ["a,b", "1,8"], {"1,2,8,1", "1,4,8,1", "1,4,4,1"})
+
+
+@pytest.mark.parametrize(
+    "mode", [api.PersistenceMode.PERSISTING, api.PersistenceMode.OPERATOR_PERSISTING]
+)
+def test_groupby(tmp_path, mode):
+    class InputSchema(pw.Schema):
+        a: int
+        b: int
+
+    def logic(t_1: pw.Table) -> pw.Table:
+        return t_1.groupby(pw.this.a).reduce(
+            pw.this.a,
+            c=pw.reducers.count(),
+            s=pw.reducers.sum(pw.this.b),
+            m=pw.reducers.max(pw.this.b),
+        )
+
+    run, input_path = get_one_table_runner(tmp_path, mode, logic, InputSchema)
+
+    run(["a,b", "1,3", "2,4"], {"1,1,3,3,1", "2,1,4,4,1"})
+    run(["a,b", "1,1"], {"1,1,3,3,-1", "1,2,4,3,1"})
+    run(["a,b", "2,5"], {"2,1,4,4,-1", "2,2,9,5,1"})
+    os.remove(input_path / "2")
+    run(["a,b"], {"1,1,3,3,1", "1,2,4,3,-1"})
+    run(["a,b", "2,0"], {"2,2,9,5,-1", "2,3,9,5,1"})
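The tests above all exercise the same recovery pattern: one static pipeline is executed repeatedly against a single persistent-storage directory, and every rerun must emit only the diffs caused by input files added or removed since the previous run. A minimal sketch of that pattern (the paths, schema, and the run_once helper are illustrative, not part of the suite):

import pathway as pw

class InputSchema(pw.Schema):
    a: int = pw.column_definition(primary_key=True)

def run_once(input_dir, output_path, storage_dir):
    # Rereads the whole directory; with persistence enabled, work already
    # covered by snapshots from the previous call is skipped on recovery.
    # (The tests call G.clear() between runs to reset the graph.)
    table = pw.io.csv.read(input_dir, schema=InputSchema, mode="static")
    pw.io.csv.write(table, output_path)
    pw.run(
        persistence_config=pw.persistence.Config(
            pw.persistence.Backend.filesystem(storage_dir),
        )
    )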
diff --git a/python/pathway/tests/utils.py b/python/pathway/tests/utils.py
index f606b68f..1df03b1c 100644
--- a/python/pathway/tests/utils.py
+++ b/python/pathway/tests/utils.py
@@ -693,6 +693,11 @@ def write_lines(path: str | pathlib.Path, data: str | list[str]):
         f.writelines(data)
 
 
+def read_lines(path: str | pathlib.Path) -> list[str]:
+    with open(path) as f:
+        return f.readlines()
+
+
 def get_aws_s3_settings():
     return pw.io.s3.AwsS3Settings(
         bucket_name="aws-integrationtest",
@@ -777,3 +782,28 @@ def deprecated_call_here(
     *, match: str | re.Pattern[str] | None = None
 ) -> AbstractContextManager[pytest.WarningsRecorder]:
     return warns_here((DeprecationWarning, PendingDeprecationWarning), match=match)
+
+
+def consolidate(df: pd.DataFrame) -> pd.DataFrame:
+    values = None
+    for column in df.columns:
+        if column in ["time", "diff"]:
+            continue
+        if values is None:
+            values = df[column].astype(str)
+        else:
+            values = values + "," + df[column].astype(str)
+    df["_all_values"] = values
+
+    total = {}
+    for _, row in df.iterrows():
+        value = row["_all_values"]
+        if value not in total:
+            total[value] = 0
+        total[value] += row["diff"]
+
+    for i in range(df.shape[0]):
+        value = df.at[i, "_all_values"]
+        df.at[i, "diff"] = total[value]
+        total[value] = 0
+    return df[df["diff"] != 0].drop(columns=["_all_values"])
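consolidate() above collapses rows whose cumulative diff is zero, so a rerun's CSV output can be compared as a set of net changes. A quick illustration with a hypothetical frame:

import pandas as pd
from pathway.tests.utils import consolidate

# The row a=1 was emitted (+1) and later retracted (-1), so it nets out;
# only a=2 survives with diff == 1.
df = pd.DataFrame({"a": [1, 1, 2], "time": [2, 4, 4], "diff": [1, -1, 1]})
print(consolidate(df))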
diff --git a/src/engine/dataflow.rs b/src/engine/dataflow.rs
index dfbb55ff..32512c66 100644
--- a/src/engine/dataflow.rs
+++ b/src/engine/dataflow.rs
@@ -68,7 +68,7 @@ use log::{error, info};
 use ndarray::ArrayD;
 use once_cell::unsync::{Lazy, OnceCell};
 use persist::{
-    EmptyPersistenceWrapper, PersistableCollection, PersistenceWrapper,
+    effective_persistent_id, EmptyPersistenceWrapper, PersistableCollection, PersistenceWrapper,
     TimestampBasedPersistenceWrapper,
 };
 use pyo3::PyObject;
@@ -342,12 +342,16 @@ enum ColumnData<S: MaybeTotalScope> {
     Collection {
         collection: Values<S>,
         arranged: OnceCell<ValuesArranged<S>>,
        consolidated: OnceCell<Values<S>>,
+        persisted_arranged: OnceCell<ValuesArranged<S>>,
         keys: OnceCell<Keys<S>>,
+        keys_persisted_arranged: OnceCell<KeysArranged<S>>,
     },
     Arranged {
         arranged: ValuesArranged<S>,
         collection: OnceCell<Values<S>>,
+        persisted_arranged: OnceCell<ValuesArranged<S>>,
         keys: OnceCell<Keys<S>>,
+        keys_persisted_arranged: OnceCell<KeysArranged<S>>,
     },
 }
 
@@ -357,7 +361,9 @@ impl<S: MaybeTotalScope> ColumnData<S> {
             collection,
             arranged: OnceCell::new(),
             consolidated: OnceCell::new(),
+            persisted_arranged: OnceCell::new(),
             keys: OnceCell::new(),
+            keys_persisted_arranged: OnceCell::new(),
         }
     }
 
@@ -365,7 +371,9 @@ impl<S: MaybeTotalScope> ColumnData<S> {
         Self::Arranged {
             collection: OnceCell::new(),
             arranged,
+            persisted_arranged: OnceCell::new(),
             keys: OnceCell::new(),
+            keys_persisted_arranged: OnceCell::new(),
         }
     }
 
@@ -389,6 +397,32 @@ impl<S: MaybeTotalScope> ColumnData<S> {
         }
     }
 
+    fn persisted_arranged(
+        &self,
+        persistence_wrapper: &mut Box<dyn PersistenceWrapper<S>>,
+        pollers: &mut Vec<Poller>,
+        connector_threads: &mut Vec<std::thread::JoinHandle<()>>,
+    ) -> Result<&ValuesArranged<S>> {
+        match self {
+            Self::Collection {
+                persisted_arranged, ..
+            }
+            | Self::Arranged {
+                persisted_arranged, ..
+            } => persisted_arranged.get_or_try_init(|| {
+                Ok(self
+                    .collection()
+                    .maybe_persist_internal(
+                        persistence_wrapper,
+                        pollers,
+                        connector_threads,
+                        "values_arranged",
+                    )?
+                    .arrange())
+            }),
+        }
+    }
+
     fn keys(&self) -> &Keys<S> {
         match self {
             Self::Collection { keys, .. } | Self::Arranged { keys, .. } => keys.get_or_init(|| {
@@ -398,9 +432,32 @@ impl<S: MaybeTotalScope> ColumnData<S> {
         }
     }
 
-    fn keys_arranged(&self) -> KeysArranged<S> {
-        self.keys().arrange()
-        // FIXME: maybe sth better if it is possible to extract arranged keys from an arranged collection
+    fn keys_persisted_arranged(
+        &self,
+        persistence_wrapper: &mut Box<dyn PersistenceWrapper<S>>,
+        pollers: &mut Vec<Poller>,
+        connector_threads: &mut Vec<std::thread::JoinHandle<()>>,
+    ) -> Result<&KeysArranged<S>> {
+        match self {
+            Self::Collection {
+                keys_persisted_arranged,
+                ..
+            }
+            | Self::Arranged {
+                keys_persisted_arranged,
+                ..
+            } => keys_persisted_arranged.get_or_try_init(|| {
+                Ok(self
+                    .keys()
+                    .maybe_persist_internal(
+                        persistence_wrapper,
+                        pollers,
+                        connector_threads,
+                        "keys_arranged",
+                    )?
+                    .arrange())
+            }),
+        }
     }
 
     fn consolidated(&self) -> &Values<S> {
@@ -487,17 +544,18 @@ impl<S: MaybeTotalScope> Table<S> {
         }
     }
 
-    fn from_arranged(values: ValuesArranged<S>) -> Self {
-        let data = Rc::new(ColumnData::from_arranged(values));
-        Self::from_data(data)
-    }
-
     fn values(&self) -> &Values<S> {
         self.data.collection()
     }
 
-    fn values_arranged(&self) -> &ValuesArranged<S> {
-        self.data.arranged()
+    fn values_persisted_arranged(
+        &self,
+        persistence_wrapper: &mut Box<dyn PersistenceWrapper<S>>,
+        pollers: &mut Vec<Poller>,
+        connector_threads: &mut Vec<std::thread::JoinHandle<()>>,
+    ) -> Result<&ValuesArranged<S>> {
+        self.data
+            .persisted_arranged(persistence_wrapper, pollers, connector_threads)
     }
 
     fn values_consolidated(&self) -> &Values<S> {
@@ -508,8 +566,14 @@ impl<S: MaybeTotalScope> Table<S> {
         self.data.keys()
     }
 
-    fn keys_arranged(&self) -> KeysArranged<S> {
-        self.data.keys_arranged()
+    fn keys_persisted_arranged(
+        &self,
+        persistence_wrapper: &mut Box<dyn PersistenceWrapper<S>>,
+        pollers: &mut Vec<Poller>,
+        connector_threads: &mut Vec<std::thread::JoinHandle<()>>,
+    ) -> Result<&KeysArranged<S>> {
+        self.data
+            .keys_persisted_arranged(persistence_wrapper, pollers, connector_threads)
     }
 }
 
@@ -771,7 +835,6 @@ struct DataflowGraphInner<S: MaybeTotalScope> {
     probes: HashMap<usize, ProbeHandle<S::MaybeTotalTimestamp>>,
     ignore_asserts: bool,
     persistence_wrapper: Box<dyn PersistenceWrapper<S>>,
-    persisted_states_count: u64,
     config: Arc<Config>,
     terminate_on_error: bool,
     default_error_log: Option<ErrorLog>,
@@ -955,6 +1018,101 @@ enum MaybeUpdate<T> {
     Update(T),
 }
 
+trait MaybePersist<S>
+where
+    S: MaybeTotalScope,
+    Self: Sized,
+{
+    fn maybe_persist(&self, graph: &mut DataflowGraphInner<S>, name: &str) -> Result<Self> {
+        self.maybe_persist_internal(
+            &mut graph.persistence_wrapper,
+            &mut graph.pollers,
+            &mut graph.connector_threads,
+            name,
+        )
+    }
+
+    fn maybe_persist_internal(
+        &self,
+        persistence_wrapper: &mut Box<dyn PersistenceWrapper<S>>,
+        pollers: &mut Vec<Poller>,
+        connector_threads: &mut Vec<std::thread::JoinHandle<()>>,
+        name: &str,
+    ) -> Result<Self>;
+
+    fn filter_out_persisted(&self, graph: &mut Box<dyn PersistenceWrapper<S>>) -> Result<Self>;
+}
+
+impl<S, D, R> MaybePersist<S> for Collection<S, D, R>
+where
+    S: MaybeTotalScope,
+    D: ExchangeData + Shard,
+    R: ExchangeData + Semigroup,
+    Collection<S, D, R>: Into<PersistableCollection<S>> + From<PersistableCollection<S>>,
+{
+    fn maybe_persist_internal(
+        &self,
+        persistence_wrapper: &mut Box<dyn PersistenceWrapper<S>>,
+        pollers: &mut Vec<Poller>,
+        connector_threads: &mut Vec<std::thread::JoinHandle<()>>,
+        name: &str,
+    ) -> Result<Self> {
+        // TODO: generate better persistent ids that can be used even if graph changes
+        let effective_persistent_id = effective_persistent_id(
+            persistence_wrapper,
+            false,
+            None,
+            RequiredPersistenceMode::OperatorPersistence,
+            |next_state_id| {
+                let generated_external_id = format!("{name}-{next_state_id}");
+                info!("Persistent ID autogenerated for {name}: {generated_external_id}");
+                generated_external_id
+            },
+        )?;
+        let persistent_id = effective_persistent_id
+            .clone()
+            .map(IntoPersistentId::into_persistent_id);
+
+        if let Some(persistent_id) = persistent_id {
+            let (persisted_collection, poller, thread_handle) = persistence_wrapper
+                .as_mut()
+                .maybe_persist_named(self.clone().into(), name, persistent_id)?;
+            if let Some(poller) = poller {
+                pollers.push(poller);
+            }
+            if let Some(thread_handle) = thread_handle {
+                connector_threads.push(thread_handle);
+            }
+            Ok(persisted_collection.into())
+        } else {
+            Ok(self.clone())
+        }
+    }
+
+    fn filter_out_persisted(
+        &self,
+        persistence_wrapper: &mut Box<dyn PersistenceWrapper<S>>,
+    ) -> Result<Self> {
+        // Check if persistent id would be generated for the operator.
+        // If yes, it means operator persistence is enabled and we need to filter out old persisted rows.
+        let with_persistent_id = effective_persistent_id(
+            persistence_wrapper,
+            false,
+            None,
+            RequiredPersistenceMode::OperatorPersistence,
+            |_| String::new(),
+        )?
+        .is_some();
+        if with_persistent_id {
+            Ok(persistence_wrapper
+                .filter_out_persisted(self.clone().into())
+                .into())
+        } else {
+            Ok(self.clone())
+        }
+    }
+}
+
 #[allow(clippy::unnecessary_wraps)] // we want to always return Result for symmetry
 impl<S: MaybeTotalScope> DataflowGraphInner<S> {
     #[allow(clippy::too_many_arguments)]
@@ -984,7 +1142,6 @@ impl<S: MaybeTotalScope> DataflowGraphInner<S> {
             probes: HashMap::new(),
             ignore_asserts,
             persistence_wrapper,
-            persisted_states_count: 0,
             config,
             terminate_on_error,
             default_error_log,
@@ -1009,6 +1166,36 @@
         self.config.processes()
     }
 
+    fn get_table_values_persisted_arranged(
+        &mut self,
+        handle: TableHandle,
+    ) -> Result<ValuesArranged<S>> {
+        self.tables
+            .get(handle)
+            .ok_or(Error::InvalidTableHandle)?
+            .values_persisted_arranged(
+                &mut self.persistence_wrapper,
+                &mut self.pollers,
+                &mut self.connector_threads,
+            )
+            .cloned()
+    }
+
+    fn get_table_keys_persisted_arranged(
+        &mut self,
+        handle: TableHandle,
+    ) -> Result<KeysArranged<S>> {
+        self.tables
+            .get(handle)
+            .ok_or(Error::InvalidTableHandle)?
+            .keys_persisted_arranged(
+                &mut self.persistence_wrapper,
+                &mut self.pollers,
+                &mut self.connector_threads,
+            )
+            .cloned()
+    }
+
     fn empty_universe(&mut self) -> Result<UniverseHandle> {
         self.static_universe(Vec::new())
     }
@@ -1062,7 +1249,6 @@
                     )
                 },
             )
-            .distinct()
             .negate(),
         );
         let error_logger = self.create_error_logger()?;
@@ -1696,25 +1882,27 @@
         same_universes: bool,
         table_properties: Arc<TableProperties>,
     ) -> Result<TableHandle> {
+        let original_values_arranged =
+            self.get_table_values_persisted_arranged(original_table_handle)?;
+        let new_values_arranged = self.get_table_values_persisted_arranged(new_table_handle)?;
         let original_table = self
             .tables
             .get(original_table_handle)
             .ok_or(Error::InvalidTableHandle)?;
-
         let new_table = self
             .tables
             .get(new_table_handle)
             .ok_or(Error::InvalidTableHandle)?;
 
-        let result = new_table.values_arranged().join_core(
-            original_table.values_arranged(),
-            |key, new_values, orig_values| {
+        let result = new_values_arranged
+            .join_core(&original_values_arranged, |key, new_values, orig_values| {
                 once((
                     *key,
                     Value::from([new_values.clone(), orig_values.clone()].as_slice()),
                 ))
-            },
-        );
+            })
+            .filter_out_persisted(&mut self.persistence_wrapper)?;
+
         let result = self.make_output_keys_match_input_keys(new_table.values(), &result)?;
 
         if !self.ignore_asserts && same_universes {
@@ -1732,29 +1920,31 @@
         other_table_handles: Vec<TableHandle>,
         table_properties: Arc<TableProperties>,
     ) -> Result<TableHandle> {
-        let table = self
-            .tables
-            .get(table_handle)
-            .ok_or(Error::InvalidTableHandle)?;
-        let mut new_values = table.data.clone();
+        let mut restricted_keys: Option<KeysArranged<S>> = None;
         for other_table_handle in other_table_handles {
-            let other_table = self
-                .tables
-                .get(other_table_handle)
-                .ok_or(Error::InvalidTableHandle)?;
-            new_values = Rc::new(ColumnData::from_collection(
-                new_values
-                    .arranged()
-                    .join_core(&other_table.keys_arranged(), |k, values, ()| {
-                        once((*k, values.clone()))
-                    })
-                    .into(),
-            ));
+            let other_table_keys_arranged =
+                self.get_table_keys_persisted_arranged(other_table_handle)?;
+            restricted_keys = if let Some(restricted_keys) = restricted_keys {
+                Some(
+                    restricted_keys
+                        .join_core(&other_table_keys_arranged, |k, (), ()| once((*k, ())))
+                        .arrange(),
+                )
+            } else {
+                Some(other_table_keys_arranged)
+            };
         }
 
-        Ok(self
-            .tables
-            .alloc(Table::from_data(new_values).with_properties(table_properties)))
+        if let Some(restricted_keys) = restricted_keys {
+            let data = self
+                .get_table_values_persisted_arranged(table_handle)?
+                .join_core(&restricted_keys, |k, values, ()| once((*k, values.clone())))
+                .filter_out_persisted(&mut self.persistence_wrapper)?;
+            let table = Table::from_collection(data);
+            Ok(self.tables.alloc(table.with_properties(table_properties)))
+        } else {
+            Ok(table_handle)
+        }
     }
 
     fn reindex_table(
@@ -1798,20 +1988,18 @@
         right_table_handle: TableHandle,
         table_properties: Arc<TableProperties>,
     ) -> Result<TableHandle> {
+        let left_values_arranged = self.get_table_values_persisted_arranged(left_table_handle)?;
+        let right_keys_arranged = self.get_table_keys_persisted_arranged(right_table_handle)?;
         let left_table = self
             .tables
             .get(left_table_handle)
             .ok_or(Error::InvalidTableHandle)?;
-        let right_table = self
-            .tables
-            .get(right_table_handle)
-            .ok_or(Error::InvalidTableHandle)?;
 
-        let intersection = left_table
-            .values_arranged()
-            .join_core(&right_table.keys_arranged(), |k, values, ()| {
+        let intersection = left_values_arranged
+            .join_core(&right_keys_arranged, |k, values, ()| {
                 once((*k, values.clone()))
-            });
+            })
+            .filter_out_persisted(&mut self.persistence_wrapper)?;
 
         let new_values = left_table
             .values()
@@ -1942,6 +2130,7 @@
                     SortingCell::new(instance, key, id)
                 },
             )
+            .maybe_persist(self, "sort_table")?
            .arrange();
 
         let prev_next: ArrangedByKey<S, Key, (Option<Key>, Option<Key>)> =
@@ -1959,8 +2148,8 @@
             })
             .arrange();
 
-        let new_values = table
-            .values_arranged()
+        let new_values = self
+            .get_table_values_persisted_arranged(table_handle)?
             .join_core(&prev_next, |key, values, prev_next| {
                 once((
                     *key,
@@ -1971,7 +2160,8 @@
                         .collect(),
                     ),
                 ))
-            });
+            })
+            .filter_out_persisted(&mut self.persistence_wrapper)?;
 
         Ok(self
             .tables
@@ -2004,6 +2194,7 @@
                     (k, MaybeUpdate::Update(v))
                 }),
             )
+            .maybe_persist(self, "update_rows")?
.arrange_named("update_rows_arrange::both")) } @@ -2016,7 +2207,7 @@ impl DataflowGraphInner { let error_logger = self.create_error_logger()?; let both_arranged = self.update_rows_arrange(table_handle, update_handle)?; - let updated_values = both_arranged.reduce_abelian( + let updated_values: ValuesArranged = both_arranged.reduce_abelian( "update_rows_table::updated", move |key, input, output| { let values = match input { @@ -2033,10 +2224,13 @@ impl DataflowGraphInner { output.push((values.clone(), 1)); }, ); + let result = updated_values + .as_collection(|k: &Key, v: &Value| (*k, v.clone())) + .filter_out_persisted(&mut self.persistence_wrapper)?; Ok(self .tables - .alloc(Table::from_arranged(updated_values).with_properties(table_properties))) + .alloc(Table::from_collection(result).with_properties(table_properties))) } fn update_cells_table( @@ -2052,7 +2246,7 @@ impl DataflowGraphInner { let error_reporter = self.error_reporter.clone(); - let updated_values = both_arranged.reduce_abelian( + let updated_values: ValuesArranged = both_arranged.reduce_abelian( "update_cells_table::updated", move |key, input, output| { let (original_values, selected_values, selected_paths) = match input { @@ -2093,9 +2287,13 @@ impl DataflowGraphInner { }, ); + let result = updated_values + .as_collection(|k, v| (*k, v.clone())) + .filter_out_persisted(&mut self.persistence_wrapper)?; + Ok(self .tables - .alloc(Table::from_arranged(updated_values).with_properties(table_properties))) + .alloc(Table::from_collection(result).with_properties(table_properties))) } fn gradual_broadcast( @@ -2167,10 +2365,6 @@ impl DataflowGraphInner { ix_key_policy: IxKeyPolicy, table_properties: Arc, ) -> Result { - let to_ix_table = self - .tables - .get(to_ix_handle) - .ok_or(Error::InvalidTableHandle)?; let key_table = self .tables .get(key_handle) @@ -2206,22 +2400,25 @@ impl DataflowGraphInner { } }), }; + let to_ix_table_values_arranged = self.get_table_values_persisted_arranged(to_ix_handle)?; + let new_table = if ix_key_policy == IxKeyPolicy::SkipMissing { let valued_to_keys_arranged: ArrangedByKey = values_to_keys .map_named( "ix_skip_missing_arrange_keys", |(source_key, (result_key, _result_value))| (source_key, result_key), ) + .maybe_persist(self, "ix")? .arrange(); valued_to_keys_arranged.join_core( - to_ix_table.values_arranged(), + &to_ix_table_values_arranged, |_source_key, result_key, to_ix_row| once((*result_key, to_ix_row.clone())), ) } else { let values_to_keys_arranged: ArrangedByKey = - values_to_keys.arrange(); + values_to_keys.maybe_persist(self, "ix")?.arrange(); values_to_keys_arranged.join_core( - to_ix_table.values_arranged(), + &to_ix_table_values_arranged, |_source_key, (result_key, result_row), to_ix_row| { once(( *result_key, @@ -2229,7 +2426,8 @@ impl DataflowGraphInner { )) }, ) - }; + } + .filter_out_persisted(&mut self.persistence_wrapper)?; let new_table = match ix_key_policy { IxKeyPolicy::ForwardNone => { let none_keys = @@ -2244,6 +2442,10 @@ impl DataflowGraphInner { let new_table = if ix_key_policy == IxKeyPolicy::SkipMissing { new_table } else { + let key_table = self + .tables + .get(key_handle) + .ok_or(Error::InvalidTableHandle)?; self.make_output_keys_match_input_keys(key_table.values(), &new_table)? 
         };
 
@@ -2342,10 +2544,6 @@
             .tables
             .get(left_data.table_handle)
             .ok_or(Error::InvalidTableHandle)?;
-        let right_table = self
-            .tables
-            .get(right_data.table_handle)
-            .ok_or(Error::InvalidTableHandle)?;
 
         let error_reporter_left = self.error_reporter.clone();
         let error_reporter_right = self.error_reporter.clone();
@@ -2369,7 +2567,13 @@
             });
         let join_left = left_with_join_key
             .flat_map(|(join_key, left_key_values)| Some((join_key?, left_key_values)));
-        let join_left_arranged: ArrangedByKey<S, Key, (Key, Value)> = join_left.arrange();
+        let join_left_arranged: ArrangedByKey<S, Key, (Key, Value)> =
+            join_left.maybe_persist(self, "join")?.arrange();
+
+        let right_table = self
+            .tables
+            .get(right_data.table_handle)
+            .ok_or(Error::InvalidTableHandle)?;
 
         let right_with_join_key = right_table
             .values()
@@ -2386,12 +2590,13 @@
             });
         let join_right = right_with_join_key
             .flat_map(|(join_key, right_key_values)| Some((join_key?, right_key_values)));
-        let join_right_arranged: ArrangedByKey<S, Key, (Key, Value)> = join_right.arrange();
+        let join_right_arranged: ArrangedByKey<S, Key, (Key, Value)> =
+            join_right.maybe_persist(self, "join")?.arrange();
 
         let join_left_right = join_left_arranged
             .join_core(&join_right_arranged, |join_key, left_key, right_key| {
                 once((*join_key, left_key.clone(), right_key.clone()))
-            });
+            }); // TODO modify join_core internals to avoid recomputing join on restart
 
         let join_left_right_to_result_fn = match join_type {
             JoinType::LeftKeysFull | JoinType::LeftKeysSubset => {
@@ -2402,26 +2607,28 @@
                         .with_shard_of(join_key)
                 },
             };
-        let result_left_right = join_left_right.map_named(
-            "join::result_left_right",
-            move |(join_key, (left_key, left_values), (right_key, right_values))| {
-                (
-                    join_left_right_to_result_fn(join_key, left_key, right_key),
-                    Value::from(
-                        [
-                            Value::Pointer(left_key),
-                            left_values,
-                            Value::Pointer(right_key),
-                            right_values,
-                        ]
-                        .as_slice(),
-                    ),
-                )
-            },
-        );
+        let result_left_right = join_left_right
+            .filter_out_persisted(&mut self.persistence_wrapper)?
+            .map_named(
+                "join::result_left_right",
+                move |(join_key, (left_key, left_values), (right_key, right_values))| {
+                    (
+                        join_left_right_to_result_fn(join_key, left_key, right_key),
+                        Value::from(
+                            [
+                                Value::Pointer(left_key),
+                                left_values,
+                                Value::Pointer(right_key),
+                                right_values,
+                            ]
+                            .as_slice(),
+                        ),
+                    )
+                },
+            );
 
-        let left_outer = || {
-            left_with_join_key.concat(
+        let mut left_outer = || -> Result<_> {
+            Ok(left_with_join_key.concat(
                 &join_left_right
                     .map_named(
                         "join::left_outer_res",
@@ -2430,12 +2637,13 @@
                         },
                     )
                     .distinct()
+                    .filter_out_persisted(&mut self.persistence_wrapper)?
                     .negate()
                     .map_named("join::left_outer_wrap", |(key, values)| (Some(key), values)),
-            )
+            ))
         };
         let result_left_outer = match join_type {
-            JoinType::LeftOuter | JoinType::FullOuter => Some(left_outer().map_named(
+            JoinType::LeftOuter | JoinType::FullOuter => Some(left_outer()?.map_named(
                 "join::result_left_outer",
                 |(join_key, (left_key, left_values))| {
                     let result_key = Key::for_values(&[Value::from(left_key), Value::None])
@@ -2444,7 +2652,7 @@
                     (left_key, left_values, result_key)
                 },
             )),
-            JoinType::LeftKeysFull => Some(left_outer().map_named(
+            JoinType::LeftKeysFull => Some(left_outer()?.map_named(
                 "join::result_left_outer",
                 |(_join_key, (left_key, left_values))| (left_key, left_values, left_key),
             )),
@@ -2475,22 +2683,23 @@
             result_left_right
         };
 
-        let right_outer = || {
-            right_with_join_key.concat(
+        let mut right_outer = || -> Result<_> {
+            Ok(right_with_join_key.concat(
                 &join_left_right
                     .map_named(
                         "join::right_outer_res",
                         |(join_key, _left_key, right_key_values)| (join_key, right_key_values),
                     )
                     .distinct()
+                    .filter_out_persisted(&mut self.persistence_wrapper)?
                     .negate()
                     .map_named("join::right_outer_wrap", |(key, values)| {
                         (Some(key), values)
                     }),
-            )
+            ))
         };
         let result_right_outer = match join_type {
-            JoinType::RightOuter | JoinType::FullOuter => Some(right_outer().map_named(
+            JoinType::RightOuter | JoinType::FullOuter => Some(right_outer()?.map_named(
                 "join::right_result_outer",
                 |(join_key, (right_key, right_values))| {
                     let result_key = Key::for_values(&[Value::None, Value::from(right_key)])
@@ -2662,85 +2871,6 @@
             .tables
             .alloc(Table::from_collection(new_values).with_properties(table_properties)))
     }
-
-    fn effective_persistent_id(
-        &self,
-        reader_is_internal: bool,
-        external_persistent_id: Option<&ExternalPersistentId>,
-        required_persistence_mode: RequiredPersistenceMode,
-        logic: impl FnOnce() -> String,
-    ) -> Result<Option<ExternalPersistentId>> {
-        let has_persistent_storage = self
-            .persistence_wrapper
-            .get_worker_persistent_storage()
-            .is_some();
-        if let Some(external_persistent_id) = external_persistent_id {
-            if !has_persistent_storage {
-                return Err(Error::NoPersistentStorage(external_persistent_id.clone()));
-            }
-        }
-        if external_persistent_id.is_some() {
-            Ok(external_persistent_id.cloned())
-        } else if has_persistent_storage && !reader_is_internal {
-            let worker_persistent_storage = self
-                .persistence_wrapper
-                .get_worker_persistent_storage()
-                .unwrap()
-                .lock()
-                .unwrap();
-            if worker_persistent_storage.persistent_id_generation_enabled(required_persistence_mode)
-                && worker_persistent_storage.table_persistence_enabled()
-            {
-                Ok(Some(logic()))
-            } else {
-                Ok(None)
-            }
-        } else {
-            Ok(None)
-        }
-    }
-
-    fn maybe_persist<D, R>(
-        &mut self,
-        collection: Collection<S, D, R>,
-        name: &str,
-    ) -> Result<Collection<S, D, R>>
-    where
-        D: ExchangeData + Shard,
-        R: ExchangeData + Semigroup,
-        Collection<S, D, R>: Into<PersistableCollection<S>> + From<PersistableCollection<S>>,
-    {
-        self.persisted_states_count += 1;
-        // TODO: generate better persistent ids that can be used even if graph changes
-        let effective_persistent_id = self.effective_persistent_id(
-            false,
-            None,
-            RequiredPersistenceMode::OperatorPersistence,
-            || {
-                let generated_external_id = format!("{name}-{}", self.persisted_states_count);
-                info!("Persistent ID autogenerated for {name}: {generated_external_id}");
-                generated_external_id
-            },
-        )?;
-        let persistent_id = effective_persistent_id
-            .clone()
-            .map(IntoPersistentId::into_persistent_id);
-
-        if let Some(persistent_id) = persistent_id {
-            let (persisted_collection, poller, thread_handle) = self
-                .persistence_wrapper
-                .maybe_persist_named(collection.into(), name, persistent_id)?;
-            if let Some(poller) = poller {
-                self.pollers.push(poller);
-            }
-            if let Some(thread_handle) = thread_handle {
-                self.connector_threads.push(thread_handle);
-            }
-            Ok(persisted_collection.into())
-        } else {
-            Ok(collection)
-        }
-    }
 }
 
 trait DataflowReducer<S: MaybeTotalScope> {
     fn reduce(
@@ -2765,22 +2895,22 @@ where
         _trace: Trace,
         graph: &mut DataflowGraphInner<S>,
     ) -> Result<Values<S>> {
-        let initialized = values.map_named("DataFlowReducer::reduce::init", {
-            let self_ = self.clone();
-            let error_logger = error_logger.clone();
-            move |(source_key, result_key, values)| {
-                let state = if values.contains(&Value::Error) {
-                    None
-                } else {
-                    self_
-                        .init(&source_key, &values)
-                        .ok_with_logger(error_logger.as_ref())
-                };
-                (result_key, state)
-            }
-        });
-        Ok(graph
-            .maybe_persist(initialized, "DataFlowReducer::reduce")?
+        Ok(values
+            .map_named("DataFlowReducer::reduce::init", {
+                let self_ = self.clone();
+                let error_logger = error_logger.clone();
+                move |(source_key, result_key, values)| {
+                    let state = if values.contains(&Value::Error) {
+                        None
+                    } else {
+                        self_
+                            .init(&source_key, &values)
+                            .ok_with_logger(error_logger.as_ref())
+                    };
+                    (result_key, state)
+                }
+            })
+            .maybe_persist(graph, "DataFlowReducer::reduce")?
             .reduce({
                 let self_ = self.clone();
                 move |_key, input, output| {
@@ -2819,7 +2949,7 @@ impl<S: MaybeTotalScope> DataflowReducer<S> for IntSumReducer {
         _trace: Trace,
         graph: &mut DataflowGraphInner<S>,
     ) -> Result<Values<S>> {
-        let initialized = values
+        Ok(values
             .map_named("IntSumReducer::reduce::init", {
                 let self_ = self.clone();
                 move |(source_key, result_key, values)| {
@@ -2833,9 +2963,8 @@
                     (result_key, state)
                 }
            })
-            .explode(|(key, state)| once((key, state)));
-        Ok(graph
-            .maybe_persist(initialized, "IntSumReducer::reduce")?
+            .explode(|(key, state)| once((key, state)))
+            .maybe_persist(graph, "IntSumReducer::reduce")?
            .count()
             .map_named("IntSumReducer::reduce", move |(key, state)| {
                 (key, self.finish(state))
@@ -2852,12 +2981,12 @@ impl<S: MaybeTotalScope> DataflowReducer<S> for CountReducer {
         _trace: Trace,
         graph: &mut DataflowGraphInner<S>,
     ) -> Result<Values<S>> {
-        let initialized = values.map_named(
-            "CountReducer::reduce::init",
-            |(_source_key, result_key, _values)| (result_key),
-        );
-        Ok(graph
-            .maybe_persist(initialized, "CountReducer::reduce")?
+        Ok(values
+            .map_named(
+                "CountReducer::reduce::init",
+                |(_source_key, result_key, _values)| (result_key),
+            )
+            .maybe_persist(graph, "CountReducer::reduce")?
             .count()
             .map_named("CountReducer::reduce", |(key, count)| {
                 (key, Value::from(count as i64))
@@ -3108,15 +3237,19 @@ where
                     once((*key, new_values))
                 });
             }
-            joined.map_named("group_by_table::wrap", |(key, values)| {
-                (key, Value::Tuple(values))
-            })
+            joined
+                .map_named("group_by_table::wrap", |(key, values)| {
+                    (key, Value::Tuple(values))
+                })
+                .filter_out_persisted(&mut self.persistence_wrapper)?
         } else {
             with_new_key
                 .map_named("group_by_table::empty", |(_key, new_key, _values)| {
                     (new_key, Value::Tuple(Arc::from([])))
                 })
+                .maybe_persist(self, "groupby")?
                 .distinct()
+                .filter_out_persisted(&mut self.persistence_wrapper)?
         };
         Ok(self
             .tables
             .alloc(Table::from_collection(new_values).with_properties(table_properties)))
     }
@@ -3144,13 +3277,13 @@
             .ok_or(Error::InvalidTableHandle)?;
 
         let error_reporter = self.error_reporter.clone();
 
-        self.persisted_states_count += 1;
-        let effective_persistent_id = self.effective_persistent_id(
+        let effective_persistent_id = effective_persistent_id(
+            &mut self.persistence_wrapper,
             false,
             external_persistent_id,
             RequiredPersistenceMode::InputOrOperatorPersistence,
-            || {
-                let generated_external_id = format!("deduplicate-{}", self.persisted_states_count);
+            |next_state_id| {
+                let generated_external_id = format!("deduplicate-{next_state_id}");
                 info!("Persistent ID autogenerated for deduplicate: {generated_external_id}");
                 generated_external_id
             },
@@ -3323,11 +3456,12 @@ impl<S: MaybeTotalScope<MaybeTotalTimestamp = Timestamp>> DataflowGraphInner<S>
         table_properties: Arc<TableProperties>,
         external_persistent_id: Option<&ExternalPersistentId>,
     ) -> Result<TableHandle> {
-        let effective_persistent_id = self.effective_persistent_id(
+        let effective_persistent_id = effective_persistent_id(
+            &mut self.persistence_wrapper,
             reader.is_internal(),
             external_persistent_id,
             RequiredPersistenceMode::InputOrOperatorPersistence,
-            || {
+            |_| {
                 let generated_external_id = reader.name(None, self.connector_monitors.len());
                 reader
                     .update_persistent_id(Some(generated_external_id.clone().into_persistent_id()));
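The dataflow changes above funnel into two primitives: maybe_persist, which snapshots an operator's internal collection under an autogenerated persistent ID, and filter_out_persisted, which drops rows replayed from such a snapshot before they reach downstream operators — replayed rows re-enter the dataflow at a reserved persistence timestamp. A conceptual sketch of the filtering step, written in Python for brevity (PERSISTENCE_TIME and the triple layout are illustrative):

# Rows are (data, time, diff) triples, as in a differential dataflow collection.
PERSISTENCE_TIME = 0  # stand-in for PersistenceTime::persistence_time()

def filter_out_persisted(rows):
    # Keep only rows produced by the current run, not replayed snapshot rows.
    return [(d, t, diff) for (d, t, diff) in rows if t != PERSISTENCE_TIME]

rows = [("snapshot-row", PERSISTENCE_TIME, 1), ("new-row", 2, 1)]
assert filter_out_persisted(rows) == [("new-row", 2, 1)]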
diff --git a/src/engine/dataflow/persist.rs b/src/engine/dataflow/persist.rs
index daed041e..516dc7b2 100644
--- a/src/engine/dataflow/persist.rs
+++ b/src/engine/dataflow/persist.rs
@@ -13,21 +13,62 @@ use differential_dataflow::{AsCollection, Collection, ExchangeData};
 use log::error;
 use ordered_float::OrderedFloat;
 use timely::dataflow::channels::pact::Exchange;
-use timely::dataflow::operators::{Capability, Operator};
+use timely::dataflow::operators::{Capability, Filter, Operator};
 use timely::dataflow::Scope;
 use timely::{order::TotalOrder, progress::Timestamp as TimelyTimestampTrait};
 
 use crate::engine::reduce::IntSumState;
-use crate::engine::{Key, Result, Timestamp, Value};
+use crate::engine::{Error, Key, Result, Timestamp, Value};
 use crate::persistence::config::PersistenceManagerConfig;
 use crate::persistence::operator_snapshot::{OperatorSnapshotReader, OperatorSnapshotWriter};
-use crate::persistence::tracker::{SharedWorkerPersistentStorage, WorkerPersistentStorage};
-use crate::persistence::{PersistenceTime, PersistentId};
-
-use super::maybe_total::MaybeTotalScope;
-use super::{shard::Shard, Poller};
+use crate::persistence::tracker::{
+    RequiredPersistenceMode, SharedWorkerPersistentStorage, WorkerPersistentStorage,
+};
+use crate::persistence::{ExternalPersistentId, PersistenceTime, PersistentId};
+
+use crate::engine::dataflow::maybe_total::MaybeTotalScope;
+use crate::engine::dataflow::shard::Shard;
+use crate::engine::dataflow::{MaybeUpdate, Poller, SortingCell};
+
+pub(super) fn effective_persistent_id<S>(
+    persistence_wrapper: &mut Box<dyn PersistenceWrapper<S>>,
+    reader_is_internal: bool,
+    external_persistent_id: Option<&ExternalPersistentId>,
+    required_persistence_mode: RequiredPersistenceMode,
+    logic: impl FnOnce(u64) -> String,
+) -> Result<Option<ExternalPersistentId>>
+where
+    S: MaybeTotalScope,
+{
+    let has_persistent_storage = persistence_wrapper
+        .get_worker_persistent_storage()
+        .is_some();
+    if let Some(external_persistent_id) = external_persistent_id {
+        if has_persistent_storage {
+            Ok(Some(external_persistent_id.clone()))
+        } else {
+            Err(Error::NoPersistentStorage(external_persistent_id.clone()))
+        }
+    } else if has_persistent_storage && !reader_is_internal {
+        let next_state_id = persistence_wrapper.next_state_id();
+        let worker_persistent_storage = persistence_wrapper
+            .get_worker_persistent_storage()
+            .unwrap()
+            .lock()
+            .unwrap();
+        if worker_persistent_storage.persistent_id_generation_enabled(required_persistence_mode)
+            && worker_persistent_storage.table_persistence_enabled()
+        {
+            Ok(Some(logic(next_state_id)))
+        } else {
+            Ok(None)
+        }
+    } else {
+        Ok(None)
+    }
+}
 
-pub trait PersistenceWrapper<S>
+pub(super) trait PersistenceWrapper<S>
 where
     S: MaybeTotalScope,
 {
@@ -43,6 +84,11 @@ where
         Option<Poller>,
         Option<std::thread::JoinHandle<()>>,
     )>;
+    fn filter_out_persisted(
+        &self,
+        collection: PersistableCollection<S>,
+    ) -> PersistableCollection<S>;
+    fn next_state_id(&mut self) -> u64;
 }
 
 pub struct EmptyPersistenceWrapper;
@@ -71,6 +117,17 @@ where
     )> {
         Ok((collection, None, None))
     }
+
+    fn filter_out_persisted(
+        &self,
+        collection: PersistableCollection<S>,
+    ) -> PersistableCollection<S> {
+        collection
+    }
+
+    fn next_state_id(&mut self) -> u64 {
+        0
+    }
 }
 
 /// Why is `PersistableCollection` needed? We could have generic `maybe_persist_named` instead?
@@ -87,7 +144,7 @@
 /// To handle this, operator snapshot writer is created in a separate object (instance of `TimestampBasedPersistenceWrapper`)
 /// that is aware that `MaybeTotalTimestamp` = `Timestamp`.
-pub enum PersistableCollection<S: MaybeTotalScope> {
+pub(super) enum PersistableCollection<S: MaybeTotalScope> {
     KeyValueIsize(Collection<S, (Key, Value), isize>),
     KeyIntSumState(Collection<S, Key, IntSumState>),
     KeyIsize(Collection<S, Key, isize>),
@@ -99,6 +156,13 @@
         Collection<S, (Key, Option<Vec<(Option<Value>, Key, Value)>>), isize>,
     ),
     KeyOptionKeyValue(Collection<S, (Key, Option<(Key, Value)>), isize>),
+    SortingCellIsize(Collection<S, SortingCell, isize>),
+    KeyMaybeUpdateIsize(Collection<S, (Key, MaybeUpdate<Value>), isize>),
+    KeyKeyIsize(Collection<S, (Key, Key), isize>),
+    KeyKeyValueIsize(Collection<S, (Key, (Key, Value)), isize>),
+    KeyIntSumStateIsize(Collection<S, (Key, IntSumState), isize>),
+    KeyIsizeIsize(Collection<S, (Key, isize), isize>),
+    KeyKeyValueKeyValueIsize(Collection<S, (Key, (Key, Value), (Key, Value)), isize>),
 }
 
 macro_rules! impl_conversion {
@@ -158,10 +222,34 @@ impl_conversion!(
     (Key, Option<(Key, Value)>),
     isize
 );
+impl_conversion!(PersistableCollection::SortingCellIsize, SortingCell, isize);
+impl_conversion!(
+    PersistableCollection::KeyMaybeUpdateIsize,
+    (Key, MaybeUpdate<Value>),
+    isize
+);
+impl_conversion!(PersistableCollection::KeyKeyIsize, (Key, Key), isize);
+impl_conversion!(
+    PersistableCollection::KeyKeyValueIsize,
+    (Key, (Key, Value)),
+    isize
+);
+impl_conversion!(
+    PersistableCollection::KeyIntSumStateIsize,
+    (Key, IntSumState),
+    isize
+);
+impl_conversion!(PersistableCollection::KeyIsizeIsize, (Key, isize), isize);
+impl_conversion!(
+    PersistableCollection::KeyKeyValueKeyValueIsize,
+    (Key, (Key, Value), (Key, Value)),
+    isize
+);
 
 pub struct TimestampBasedPersistenceWrapper {
     persistence_config: PersistenceManagerConfig,
     worker_persistent_storage: SharedWorkerPersistentStorage,
+    persisted_states_count: u64,
 }
 
 impl TimestampBasedPersistenceWrapper {
@@ -172,6 +260,7 @@ impl TimestampBasedPersistenceWrapper {
         Ok(Self {
             persistence_config,
             worker_persistent_storage,
+            persisted_states_count: 0,
         })
     }
 
@@ -199,6 +288,22 @@ impl TimestampBasedPersistenceWrapper {
     }
 }
 
+fn generic_filter_out_persisted<S, D, R>(
+    collection: &Collection<S, D, R>,
+) -> PersistableCollection<S>
+where
+    S: MaybeTotalScope,
+    D: ExchangeData + Shard,
+    R: ExchangeData + Semigroup,
+    Collection<S, D, R>: Into<PersistableCollection<S>>,
+{
+    collection
+        .inner
+        .filter(|(_data, time, _diff)| *time != PersistenceTime::persistence_time())
+        .as_collection()
+        .into()
+}
+
 impl<S: MaybeTotalScope<MaybeTotalTimestamp = Timestamp>> PersistenceWrapper<S>
     for TimestampBasedPersistenceWrapper
 {
@@ -248,8 +353,90 @@
             PersistableCollection::KeyOptionKeyValue(collection) => {
                 self.generic_maybe_persist(&collection, name, persistent_id)
             }
+            PersistableCollection::SortingCellIsize(collection) => {
+                self.generic_maybe_persist(&collection, name, persistent_id)
+            }
+            PersistableCollection::KeyMaybeUpdateIsize(collection) => {
+                self.generic_maybe_persist(&collection, name, persistent_id)
+            }
+            PersistableCollection::KeyKeyIsize(collection) => {
+                self.generic_maybe_persist(&collection, name, persistent_id)
+            }
+            PersistableCollection::KeyKeyValueIsize(collection) => {
+                self.generic_maybe_persist(&collection, name, persistent_id)
+            }
+            PersistableCollection::KeyIntSumStateIsize(collection) => {
+                self.generic_maybe_persist(&collection, name, persistent_id)
+            }
+            PersistableCollection::KeyIsizeIsize(collection) => {
+                self.generic_maybe_persist(&collection, name, persistent_id)
+            }
+            PersistableCollection::KeyKeyValueKeyValueIsize(collection) => {
+                self.generic_maybe_persist(&collection, name, persistent_id)
+            }
         }
     }
+
+    fn filter_out_persisted(
+        &self,
+        collection: PersistableCollection<S>,
+    ) -> PersistableCollection<S> {
+        match collection {
+            PersistableCollection::KeyValueIsize(collection) => {
+                generic_filter_out_persisted(&collection)
+            }
+            PersistableCollection::KeyIntSumState(collection) => {
+                generic_filter_out_persisted(&collection)
+            }
+            PersistableCollection::KeyIsize(collection) => {
+                generic_filter_out_persisted(&collection)
+            }
+            PersistableCollection::KeyOptionOrderderFloatIsize(collection) => {
+                generic_filter_out_persisted(&collection)
+            }
+            PersistableCollection::KeyOptionValueIsize(collection) => {
+                generic_filter_out_persisted(&collection)
+            }
+            PersistableCollection::KeyOptionValueKeyIsize(collection) => {
+                generic_filter_out_persisted(&collection)
+            }
+            PersistableCollection::KeyOptionVecValueIsize(collection) => {
+                generic_filter_out_persisted(&collection)
+            }
+            PersistableCollection::KeyOptionVecOptionValueKeyValue(collection) => {
+                generic_filter_out_persisted(&collection)
+            }
+            PersistableCollection::KeyOptionKeyValue(collection) => {
+                generic_filter_out_persisted(&collection)
+            }
+            PersistableCollection::SortingCellIsize(collection) => {
+                generic_filter_out_persisted(&collection)
+            }
+            PersistableCollection::KeyMaybeUpdateIsize(collection) => {
+                generic_filter_out_persisted(&collection)
+            }
+            PersistableCollection::KeyKeyIsize(collection) => {
+                generic_filter_out_persisted(&collection)
+            }
+            PersistableCollection::KeyKeyValueIsize(collection) => {
+                generic_filter_out_persisted(&collection)
+            }
+            PersistableCollection::KeyIntSumStateIsize(collection) => {
+                generic_filter_out_persisted(&collection)
+            }
+            PersistableCollection::KeyIsizeIsize(collection) => {
+                generic_filter_out_persisted(&collection)
+            }
+            PersistableCollection::KeyKeyValueKeyValueIsize(collection) => {
+                generic_filter_out_persisted(&collection)
+            }
+        }
+    }
+
+    fn next_state_id(&mut self) -> u64 {
+        self.persisted_states_count += 1;
+        self.persisted_states_count
+    }
 }
 
 struct CapabilityOrdWrapper<T: TimelyTimestampTrait>(Capability<T>);
@@ -331,7 +518,7 @@ where
             .spawn(move || {
                 let data = reader.load_persisted();
                 if let Err(e) = sender.send(data) {
-                    error!("Failed to send data from persistence: {e}"); // FIXME possibly exit
+                    error!("Failed to send data from persistence: {e}");
                 }
             })
             .expect("persistence read thread creation should succeed");
@@ -347,8 +534,7 @@
                 ControlFlow::Continue(Some(next_try_at))
             }
             Ok(Err(backend_error)) => {
-                error!("Error while reading persisted data: {backend_error}");
-                ControlFlow::Continue(Some(next_try_at))
+                panic!("Error while reading persisted data: {backend_error}"); // TODO make pollers return Result
             }
             Err(TryRecvError::Empty) => ControlFlow::Continue(Some(next_try_at)),
             Err(TryRecvError::Disconnected) => {
diff --git a/src/persistence/config.rs b/src/persistence/config.rs
index 48ddc02b..3bab08b9 100644
--- a/src/persistence/config.rs
+++ b/src/persistence/config.rs
@@ -9,24 +9,25 @@
 use std::collections::HashMap;
 use std::fs;
 use std::io::Error as IoError;
 use std::path::{Path, PathBuf};
-use std::sync::{Arc, Mutex};
+use std::sync::{mpsc, Arc, Mutex};
 use std::time::Duration;
 
 use s3::bucket::Bucket as S3Bucket;
 
-use crate::persistence::input_snapshot::{
-    Event, InputSnapshotReader, InputSnapshotWriter, MockSnapshotReader, ReadInputSnapshot,
-    SnapshotMode,
-};
-
 use crate::connectors::{PersistenceMode, SnapshotAccess};
 use crate::deepcopy::DeepCopy;
-use crate::engine::{Timestamp, TotalFrontier};
+use crate::engine::error::DynError;
+use crate::engine::license::License;
+use crate::engine::{Result, Timestamp, TotalFrontier};
 use crate::fs_helpers::ensure_directory;
 use crate::persistence::backends::{
     FilesystemKVStorage, MockKVStorage, PersistenceBackend, S3KVStorage,
 };
 use crate::persistence::cached_object_storage::CachedObjectStorage;
+use crate::persistence::input_snapshot::{
+    Event, InputSnapshotReader, InputSnapshotWriter, MockSnapshotReader, ReadInputSnapshot,
+    SnapshotMode,
+};
 use crate::persistence::operator_snapshot::{
     ConcreteSnapshotMerger, ConcreteSnapshotReader, ConcreteSnapshotWriter,
     MultiConcreteSnapshotReader,
@@ -97,6 +98,15 @@ impl PersistenceManagerOuterConfig {
     pub fn into_inner(self, worker_id: usize, total_workers: usize) -> PersistenceManagerConfig {
         PersistenceManagerConfig::new(self, worker_id, total_workers)
     }
+
+    pub fn validate(&self, license: &License) -> Result<()> {
+        if matches!(self.persistence_mode, PersistenceMode::OperatorPersisting) {
+            license
+                .check_entitlements(["full_persistence"])
+                .map_err(DynError::from)?;
+        }
+        Ok(())
+    }
 }
 
 /// The main persistent manager config, which, however can only be
@@ -417,11 +427,36 @@ impl PersistenceManagerConfig {
         Ok(assigned_paths)
     }
 
-    pub fn create_operator_snapshot_readers(
-        &self,
+    fn create_operator_snapshot_merger<D, R>(
+        &mut self,
+        persistent_id: PersistentId,
+        receiver: mpsc::Receiver<()>,
+    ) -> Result<ConcreteSnapshotMerger, PersistenceBackendError>
+    where
+        D: ExchangeData,
+        R: ExchangeData + Semigroup,
+    {
+        let merger_backend = self.get_writer_backend(persistent_id)?;
+        let metadata_backend = self.backend.create()?;
+        let time_querier = FinalizedTimeQuerier::new(metadata_backend, self.total_workers);
+        let merger = ConcreteSnapshotMerger::new::<D, R>(
+            merger_backend,
+            self.snapshot_interval,
+            time_querier,
+            receiver,
+        );
+        Ok(merger)
+    }
+
+    pub fn create_operator_snapshot_readers<D, R>(
+        &mut self,
         persistent_id: PersistentId,
         threshold_time: TotalFrontier<Timestamp>,
-    ) -> Result<MultiConcreteSnapshotReader, PersistenceBackendError> {
+    ) -> Result<(MultiConcreteSnapshotReader, ConcreteSnapshotMerger), PersistenceBackendError>
+    where
+        D: ExchangeData,
+        R: ExchangeData + Semigroup,
+    {
         info!("Using threshold time: {threshold_time:?}");
         let mut readers: Vec<ConcreteSnapshotReader> = Vec::new();
         let backends =
@@ -430,27 +465,22 @@ impl PersistenceManagerConfig {
             let reader = ConcreteSnapshotReader::new(backend, threshold_time);
             readers.push(reader);
         }
-        Ok(MultiConcreteSnapshotReader::new(readers))
+        let (sender, receiver) = mpsc::channel(); // pair used to block merger until reader finishes
+        let reader = MultiConcreteSnapshotReader::new(readers, sender);
+        let merger = self.create_operator_snapshot_merger::<D, R>(persistent_id, receiver)?;
+        Ok((reader, merger))
     }
 
     pub fn create_operator_snapshot_writer<D, R>(
         &mut self,
         persistent_id: PersistentId,
-    ) -> Result<(ConcreteSnapshotWriter<D, R>, ConcreteSnapshotMerger), PersistenceBackendError>
+    ) -> Result<ConcreteSnapshotWriter<D, R>, PersistenceBackendError>
     where
         D: ExchangeData,
         R: ExchangeData + Semigroup,
     {
         let backend = self.get_writer_backend(persistent_id)?;
-        let merger_backend = self.get_writer_backend(persistent_id)?;
         let writer = ConcreteSnapshotWriter::new(backend, self.snapshot_interval);
-        let metadata_backend = self.backend.create()?;
-        let time_querier = FinalizedTimeQuerier::new(metadata_backend, self.total_workers);
-        let merger = ConcreteSnapshotMerger::new::<D, R>(
-            merger_backend,
-            self.snapshot_interval,
-            time_querier,
-        );
-        Ok((writer, merger))
+        Ok(writer)
     }
 }
diff --git a/src/persistence/operator_snapshot.rs b/src/persistence/operator_snapshot.rs
index 6906d2b0..d96bb241 100644
--- a/src/persistence/operator_snapshot.rs
+++ b/src/persistence/operator_snapshot.rs
@@ -197,11 +197,15 @@ where
 
 pub struct MultiConcreteSnapshotReader {
     snapshot_readers: Vec<ConcreteSnapshotReader>,
+    sender: mpsc::Sender<()>,
 }
 
 impl MultiConcreteSnapshotReader {
-    pub fn new(snapshot_readers: Vec<ConcreteSnapshotReader>) -> Self {
-        Self { snapshot_readers }
+    pub fn new(snapshot_readers: Vec<ConcreteSnapshotReader>, sender: mpsc::Sender<()>) -> Self {
+        Self {
+            snapshot_readers,
+            sender,
+        }
     }
 }
 
@@ -220,6 +224,7 @@ where
             result.append(&mut v);
         }
         consolidate(&mut result);
+        self.sender.send(()).expect("merger should exist"); // inform merger that it can start its work
         Ok(result)
     }
 }
@@ -344,13 +349,14 @@ impl ConcreteSnapshotMerger {
         backend: Box<dyn PersistenceBackend>,
         snapshot_interval: core::time::Duration,
         time_querier: FinalizedTimeQuerier,
+        receiver: mpsc::Receiver<()>,
     ) -> Self
     where
        D: ExchangeData,
         R: ExchangeData + Semigroup,
     {
         let (finish_sender, thread_handle) =
-            Self::start::<D, R>(backend, snapshot_interval, time_querier);
+            Self::start::<D, R>(backend, snapshot_interval, time_querier, receiver);
         Self {
             finish_sender,
             thread_handle: Some(thread_handle),
@@ -436,10 +442,14 @@ impl ConcreteSnapshotMerger {
         receiver: &mpsc::Receiver<()>,
         timeout: core::time::Duration,
         time_querier: &mut FinalizedTimeQuerier,
+        reader_finished_receiver: &mpsc::Receiver<()>,
     ) where
         D: ExchangeData,
         R: ExchangeData + Semigroup,
     {
+        if reader_finished_receiver.recv().is_err() {
+            error!("Can't start snapshot merger as snapshot reader didn't finish gracefully");
+        }
         let mut next_try_at = Instant::now();
         loop {
             let now = Instant::now();
@@ -464,6 +474,7 @@ impl ConcreteSnapshotMerger {
         backend: Box<dyn PersistenceBackend>,
         timeout: core::time::Duration,
         mut time_querier: FinalizedTimeQuerier,
+        reader_finished_receiver: mpsc::Receiver<()>,
     ) -> (mpsc::Sender<()>, thread::JoinHandle<()>)
     where
         D: ExchangeData,
@@ -473,7 +484,15 @@ impl ConcreteSnapshotMerger {
         let (sender, receiver) = mpsc::channel();
         let thread_handle = thread::Builder::new()
             .name("SnapshotMerger".to_string()) // TODO maybe better name
-            .spawn(move || Self::run::<D, R>(backend, &receiver, timeout, &mut time_querier))
+            .spawn(move || {
+                Self::run::<D, R>(
+                    backend,
+                    &receiver,
+                    timeout,
+                    &mut time_querier,
+                    &reader_finished_receiver,
+                );
+            })
             .expect("persistence read thread creation should succeed");
         (sender, thread_handle)
     }
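The mpsc pair threaded through MultiConcreteSnapshotReader and ConcreteSnapshotMerger is a one-shot handshake: the merger must not compact snapshot files while a reader may still be loading them, so its thread blocks until the reader signals that load_persisted has finished. Conceptually (a Python sketch with illustrative names):

import queue
import threading

reader_done = queue.Queue()

def snapshot_reader():
    # ... load and consolidate persisted entries ...
    reader_done.put(())  # corresponds to self.sender.send(()) in load_persisted

def snapshot_merger():
    reader_done.get()  # corresponds to reader_finished_receiver.recv() in run()
    # ... now safe to merge and garbage-collect snapshot chunks ...

threading.Thread(target=snapshot_reader).start()
threading.Thread(target=snapshot_merger).start()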
diff --git a/src/persistence/tracker.rs b/src/persistence/tracker.rs
index 896d4f67..73ad39bc 100644
--- a/src/persistence/tracker.rs
+++ b/src/persistence/tracker.rs
@@ -263,17 +263,19 @@ impl WorkerPersistentStorage {
     }
 
     pub fn create_operator_snapshot_reader<D, R>(
-        &self,
+        &mut self,
         persistent_id: PersistentId,
     ) -> Result<Box<dyn OperatorSnapshotReader<D, R> + Send>, PersistenceBackendError>
     where
         D: ExchangeData,
         R: ExchangeData + Semigroup,
     {
-        Ok(Box::new(self.config.create_operator_snapshot_readers(
+        let (reader, merger) = self.config.create_operator_snapshot_readers::<D, R>(
             persistent_id,
             self.metadata_storage.past_runs_threshold_time(),
-        )?))
+        )?;
+        self.operator_snapshot_mergers.push(merger);
+        Ok(Box::new(reader))
     }
 
     pub fn create_operator_snapshot_writer<D, R>(
@@ -284,12 +286,11 @@ impl WorkerPersistentStorage {
         D: ExchangeData,
         R: ExchangeData + Semigroup,
     {
-        let (writer, merger) = self.config.create_operator_snapshot_writer(persistent_id)?;
+        let writer = self.config.create_operator_snapshot_writer(persistent_id)?;
         let writer = Arc::new(Mutex::new(writer));
         let writer_flushable: Arc> = writer.clone();
         self.operator_snapshot_writers
             .insert(persistent_id, writer_flushable);
-        self.operator_snapshot_mergers.push(merger);
         Ok(writer)
     }
 }
diff --git a/src/python_api.rs b/src/python_api.rs
index ef56f738..d20aeb0f 100644
--- a/src/python_api.rs
+++ b/src/python_api.rs
@@ -3305,14 +3305,16 @@ pub fn run_with_new_graph(
     let config = Config::from_env().map_err(|msg| {
         PyErr::from_type_bound(ENGINE_ERROR_TYPE.bind(py).clone(), msg.to_string())
     })?;
+    let license = License::new(license_key)?;
     let persistence_config = {
         if let Some(persistence_config) = persistence_config {
-            Some(persistence_config.prepare(py)?)
+            let persistence_config = persistence_config.prepare(py)?;
+            persistence_config.validate(&license)?;
+            Some(persistence_config)
         } else {
             None
         }
     };
-    let license = License::new(license_key)?;
     let telemetry_config =
         EngineTelemetryConfig::create(&license, run_id, monitoring_server, trace_parent)?;
     let results: Vec> = run_with_wakeup_receiver(py, |wakeup_receiver| {
diff --git a/tests/integration/test_operator_persistence.rs b/tests/integration/test_operator_persistence.rs
index 89802c41..d4cb45f7 100644
--- a/tests/integration/test_operator_persistence.rs
+++ b/tests/integration/test_operator_persistence.rs
@@ -25,7 +25,7 @@ use serde::Deserialize;
 use std::collections::HashMap;
 use std::fmt::Debug;
 use std::ops::ControlFlow;
-use std::sync::{Arc, Mutex, MutexGuard};
+use std::sync::{mpsc, Arc, Mutex, MutexGuard};
 use std::thread;
 use std::time::{Duration, SystemTime};
 use timely::communication::allocator::Generic;
@@ -420,11 +420,16 @@ fn test_operator_snapshot_reader_reads_correct_files_1() {
             .times(1)
             .returning(|_key| Ok(()));
     }
-    let mut reader = MultiConcreteSnapshotReader::new(vec![ConcreteSnapshotReader::new(
-        Box::new(backend),
-        TotalFrontier::At(Timestamp(34)),
-    )]);
+    let (sender, receiver) = mpsc::channel();
+    let mut reader = MultiConcreteSnapshotReader::new(
+        vec![ConcreteSnapshotReader::new(
+            Box::new(backend),
+            TotalFrontier::At(Timestamp(34)),
+        )],
+        sender,
+    );
     assert_eq!(reader.load_persisted().unwrap(), vec![(3, 7)]);
+    receiver.recv().unwrap();
 }
 
 #[test]
@@ -480,11 +485,16 @@ fn test_operator_snapshot_reader_consolidates() {
             .times(1)
             .returning(|_key| Ok(()));
     }
-    let mut reader = MultiConcreteSnapshotReader::new(vec![ConcreteSnapshotReader::new(
-        Box::new(backend),
-        TotalFrontier::At(Timestamp(22)),
-    )]);
+    let (sender, receiver) = mpsc::channel();
+    let mut reader = MultiConcreteSnapshotReader::new(
+        vec![ConcreteSnapshotReader::new(
+            Box::new(backend),
+            TotalFrontier::At(Timestamp(22)),
+        )],
+        sender,
+    );
     let mut result = reader.load_persisted().unwrap();
+    receiver.recv().unwrap();
     result.sort();
     assert_eq!(
         result,