diff --git a/CHANGELOG.md b/CHANGELOG.md
index 19eca5b..5dbe121 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+### Added
+
+- Added Clickhouse-appropriate versions of the comparison level `PairwiseStringDistanceFunctionLevel` and the comparison `PairwiseStringDistanceFunctionAtThresholds` to the relevant libraries [#51](https://github.com/ADBond/splinkclickhouse/pull/51)
+- `ClickhouseAPI` can now properly register `pandas` tables with string array columns [#51](https://github.com/ADBond/splinkclickhouse/pull/51)
+
 ### Fixed
 
 - Table registration in `chdb` now works for pandas tables whose indexes do not have a `0` entry [#49](https://github.com/ADBond/splinkclickhouse/pull/49).
diff --git a/splinkclickhouse/clickhouse/database_api.py b/splinkclickhouse/clickhouse/database_api.py
index e429fd5..41c62cc 100644
--- a/splinkclickhouse/clickhouse/database_api.py
+++ b/splinkclickhouse/clickhouse/database_api.py
@@ -109,20 +109,24 @@ def _create_table_from_pandas_frame(self, df: pd.DataFrame, table_name: str) ->
         sql = f"CREATE OR REPLACE TABLE {table_name} ("
 
         first_col = True
-        for column in df.columns:
+        for column_name in df.columns:
             if not first_col:
                 sql += ", "
-            col_type = df[column].dtype
+            column = df[column_name]
+            col_type = column.dtype
             first_col = False
 
             if pd.api.types.is_integer_dtype(col_type):
-                sql += f"{column} Nullable(UInt32)"
+                sql += f"{column_name} Nullable(UInt32)"
             elif pd.api.types.is_float_dtype(col_type):
-                sql += f"{column} Nullable(Float64)"
+                sql += f"{column_name} Nullable(Float64)"
+            elif pd.api.types.is_list_like(column[0]):
+                sql += f"{column_name} Array(String)"
             elif pd.api.types.is_string_dtype(col_type):
-                sql += f"{column} Nullable(String)"
+                sql += f"{column_name} Nullable(String)"
             else:
                 raise ValueError(f"Unknown data type {col_type}")
         sql += ") ENGINE MergeTree ORDER BY tuple()"
+
         return sql
 
diff --git a/splinkclickhouse/comparison_level_library.py b/splinkclickhouse/comparison_level_library.py
index e63104c..270734b 100644
--- a/splinkclickhouse/comparison_level_library.py
+++ b/splinkclickhouse/comparison_level_library.py
@@ -10,6 +10,9 @@
 from splink.internals.comparison_level_library import (
     DateMetricType,
 )
+from splink.internals.comparison_level_library import (
+    PairwiseStringDistanceFunctionLevel as SplinkPairwiseStringDistanceFunctionLevel,
+)
 
 from .column_expression import ColumnExpression as CHColumnExpression
 from .dialect import ClickhouseDialect, SplinkDialect
@@ -164,3 +167,44 @@ def create_sql(self, sql_dialect: SplinkDialect) -> str:
             f"<= {self.time_threshold_seconds}"
         )
         return sql
+
+
+class PairwiseStringDistanceFunctionLevel(SplinkPairwiseStringDistanceFunctionLevel):
+    def create_sql(self, sql_dialect: SplinkDialect) -> str:
+        self.col_expression.sql_dialect = sql_dialect
+        col = self.col_expression
+        distance_function_name_transpiled = {
+            "levenshtein": sql_dialect.levenshtein_function_name,
+            "damerau_levenshtein": sql_dialect.damerau_levenshtein_function_name,
+            "jaro_winkler": sql_dialect.jaro_winkler_function_name,
+            "jaro": sql_dialect.jaro_function_name,
+        }[self.distance_function_name]
+
+        aggregator_func = {
+            "min": sql_dialect.array_min_function_name,
+            "max": sql_dialect.array_max_function_name,
+        }[self._aggregator()]
+
+        # The argument order in Clickhouse differs from that expected by Splink;
+        # specifically, the lambda must come first in Clickhouse.
+        # This is not fixable with a UDF, as having the lambda as the second
+        # argument will in general cause the Clickhouse parser to fail.
+        # We also need a workaround to get a 'flatten' equivalent for a single level.
+        return f"""{aggregator_func}(
+            {sql_dialect.array_transform_function_name}(
+                pair -> {distance_function_name_transpiled}(
+                    pair[{sql_dialect.array_first_index}],
+                    pair[{sql_dialect.array_first_index + 1}]
+                ),
+                arrayReduce(
+                    'array_concat_agg',
+                    {sql_dialect.array_transform_function_name}(
+                        x -> {sql_dialect.array_transform_function_name}(
+                            y -> [x, y],
+                            {col.name_r}
+                        ),
+                        {col.name_l}
+                    )
+                )
+            )
+        ) {self._comparator()} {self.distance_threshold}"""
diff --git a/splinkclickhouse/comparison_library.py b/splinkclickhouse/comparison_library.py
index 0d1b4e3..018dbc7 100644
--- a/splinkclickhouse/comparison_library.py
+++ b/splinkclickhouse/comparison_library.py
@@ -16,6 +16,9 @@
 from splink.internals.comparison_library import (
     DateOfBirthComparison as SplinkDateOfBirthComparison,
 )
+from splink.internals.comparison_library import (
+    PairwiseStringDistanceFunctionAtThresholds as SplinkPairwiseStringDistanceFunctionAtThresholds,  # noqa: E501 (can't keep format and check happy)
+)
 from splink.internals.misc import ensure_is_iterable
 
 import splinkclickhouse.comparison_level_library as cll_ch
@@ -305,3 +308,24 @@ def create_comparison_levels(self) -> list[ComparisonLevelCreator]:
         levels.append(cll.ElseLevel())
 
         return levels
+
+
+class PairwiseStringDistanceFunctionAtThresholds(
+    SplinkPairwiseStringDistanceFunctionAtThresholds
+):
+    def create_comparison_levels(self) -> list[ComparisonLevelCreator]:
+        return [
+            cll.NullLevel(self.col_expression),
+            # It is assumed that any string distance treats identical
+            # arrays as the most similar
+            cll.ArrayIntersectLevel(self.col_expression, min_intersection=1),
+            *[
+                cll_ch.PairwiseStringDistanceFunctionLevel(
+                    self.col_expression,
+                    distance_threshold=threshold,
+                    distance_function_name=self.distance_function_name,
+                )
+                for threshold in self.thresholds
+            ],
+            cll.ElseLevel(),
+        ]
diff --git a/splinkclickhouse/dialect.py b/splinkclickhouse/dialect.py
index db10a3e..1f4fb4a 100644
--- a/splinkclickhouse/dialect.py
+++ b/splinkclickhouse/dialect.py
@@ -32,6 +32,22 @@ def jaro_function_name(self) -> str:
     def jaccard_function_name(self) -> str:
         return "stringJaccardIndexUTF8"
 
+    @property
+    def array_first_index(self) -> int:
+        return 1
+
+    @property
+    def array_min_function_name(self) -> str:
+        return "arrayMin"
+
+    @property
+    def array_max_function_name(self) -> str:
+        return "arrayMax"
+
+    @property
+    def array_transform_function_name(self) -> str:
+        return "arrayMap"
+
     def _regex_extract_raw(
         self, name: str, pattern: str, capture_group: int = 0
     ) -> str:
diff --git a/tests/conftest.py b/tests/conftest.py
index eedaa57..4ed5b9c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -27,6 +27,26 @@ def pytest_collection_modifyitems(items, config):
             item.add_marker(mark)
 
 
+_NAMES = (
+    "tom",
+    "tim",
+    "jen",
+    "jan",
+    "ken",
+    "sam",
+    "katherine",
+    "ben",
+    "benjamin",
+    "benny",
+    "jenny",
+    "jennifer",
+    "samuel",
+    "thom",
+    "thomas",
+    "thoams",
+)
+
+
 @fixture
 def chdb_api_factory():
     con = dbapi.connect()
@@ -153,7 +173,7 @@ def input_nodes_with_lat_longs():
     longs = np.random.uniform(low=long_low, high=long_high, size=n_rows)
     # also include some names so we have a second comparison
     names = np.random.choice(
-        ("tom", "tim", "jen", "jan", "ken", "sam", "katherine"),
+        _NAMES,
         size=n_rows,
     )
     return pd.DataFrame(
@@ -160,6 +180,22 @@ def input_nodes_with_lat_longs():
         {
             "unique_id": range(n_rows),
             "latitude": lats,
             "longitude": longs,
         }
     )
+
+
+@fixture
+def input_nodes_with_name_arrays():
+    n_rows = 1_000
+    sizes = np.random.randint(low=1, high=5, size=n_rows)
+    return pd.DataFrame(
+        {
+            "unique_id": range(n_rows),
+            "aliases": map(lambda s: np.random.choice(_NAMES, size=s), sizes),
+            "username": np.random.choice(
+                _NAMES,
+                size=n_rows,
+            ),
+        }
+    )
diff --git a/tests/test_cl_ch.py b/tests/test_cl_ch.py
index 6950f5c..ada4dcb 100644
--- a/tests/test_cl_ch.py
+++ b/tests/test_cl_ch.py
@@ -1,5 +1,5 @@
 import splink.comparison_library as cl
-from pytest import raises
+from pytest import mark, raises
 from splink import DuckDBAPI, Linker, SettingsCreator
 
 import splinkclickhouse.comparison_library as cl_ch
@@ -105,3 +105,37 @@ def test_clickhouse_date_of_birth_comparison(api_info, fake_1000):
     linker = Linker(fake_1000, settings, db_api)
 
     linker.inference.predict()
+
+
+# TODO: for now there's not a straightforward way (afaik) to get an array column
+# into chdb, so for the time being we test only the Clickhouse server version
+@mark.clickhouse
+@mark.clickhouse_no_core
+def test_pairwise_string_distance(clickhouse_api_factory, input_nodes_with_name_arrays):
+    db_api = clickhouse_api_factory()
+
+    settings = SettingsCreator(
+        link_type="dedupe_only",
+        comparisons=[
+            cl.ExactMatch("username"),
+            # we can pretend these are distinct columns
+            cl_ch.PairwiseStringDistanceFunctionAtThresholds(
+                "aliases", "levenshtein", [1, 2]
+            ),
+            cl_ch.PairwiseStringDistanceFunctionAtThresholds(
+                "aliases_2", "damerau_levenshtein", [1, 2, 3]
+            ),
+            cl_ch.PairwiseStringDistanceFunctionAtThresholds(
+                "aliases_3", "jaro", [0.88, 0.7]
+            ),
+            cl_ch.PairwiseStringDistanceFunctionAtThresholds(
+                "aliases_4", "jaro_winkler", [0.88, 0.7]
+            ),
+        ],
+    )
+
+    input_nodes_with_name_arrays["aliases_2"] = input_nodes_with_name_arrays["aliases"]
+    input_nodes_with_name_arrays["aliases_3"] = input_nodes_with_name_arrays["aliases"]
+    input_nodes_with_name_arrays["aliases_4"] = input_nodes_with_name_arrays["aliases"]
+    linker = Linker(input_nodes_with_name_arrays, settings, db_api)
+    linker.inference.predict()
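
Example usage (not part of the patch): a minimal sketch of the new pairwise comparison run against a Clickhouse server, mirroring the test above. The column data, connection details, and the `clickhouse_connect` client / `ClickhouseAPI(client)` construction are assumptions based on the project's existing setup, not something introduced by this diff.

    import clickhouse_connect
    import pandas as pd
    from splink import Linker, SettingsCreator

    import splinkclickhouse.comparison_library as cl_ch
    from splinkclickhouse import ClickhouseAPI

    # a frame with a string-array column, which ClickhouseAPI can now register
    df = pd.DataFrame(
        {
            "unique_id": [0, 1, 2],
            "aliases": [["tom", "thom"], ["thomas", "tom"], ["jen"]],
        }
    )

    # placeholder connection details for a local Clickhouse server
    client = clickhouse_connect.get_client(host="localhost", port=8123)
    db_api = ClickhouseAPI(client)

    settings = SettingsCreator(
        link_type="dedupe_only",
        comparisons=[
            # compares every left/right alias pairing and keeps the best distance;
            # each threshold level compiles to roughly:
            #   arrayMin(arrayMap(pair -> <edit distance fn>(pair[1], pair[2]), <all l/r pairs>)) <= t
            cl_ch.PairwiseStringDistanceFunctionAtThresholds(
                "aliases", "levenshtein", [1, 2]
            ),
        ],
    )

    linker = Linker(df, settings, db_api)
    predictions = linker.inference.predict()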