Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pairwise string comparison #51

Merged
merged 9 commits into from
Dec 16, 2024
Merged
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added

- Added Clickhouse-appropriate versions of comparison level `PairwiseStringDistanceFunctionLevel` and comparison `PairwiseStringDistanceFunctionAtThresholds` to the relevant libraries [#51](https://github.com/ADBond/splinkclickhouse/pull/51)
- `ClickhouseAPI` can now properly register `pandas` tables with string array columns [#51](https://github.com/ADBond/splinkclickhouse/pull/51)

### Fixed

- Table registration in `chdb` now works for pandas tables whose indexes do not have a `0` entry [#49](https://github.com/ADBond/splinkclickhouse/pull/49).
Expand Down
14 changes: 9 additions & 5 deletions splinkclickhouse/clickhouse/database_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,20 +109,24 @@ def _create_table_from_pandas_frame(self, df: pd.DataFrame, table_name: str) ->
sql = f"CREATE OR REPLACE TABLE {table_name} ("

first_col = True
for column in df.columns:
for column_name in df.columns:
if not first_col:
sql += ", "
col_type = df[column].dtype
column = df[column_name]
col_type = column.dtype
first_col = False

if pd.api.types.is_integer_dtype(col_type):
sql += f"{column} Nullable(UInt32)"
sql += f"{column_name} Nullable(UInt32)"
elif pd.api.types.is_float_dtype(col_type):
sql += f"{column} Nullable(Float64)"
sql += f"{column_name} Nullable(Float64)"
elif pd.api.types.is_list_like(column[0]):
sql += f"{column_name} Array(String)"
elif pd.api.types.is_string_dtype(col_type):
sql += f"{column} Nullable(String)"
sql += f"{column_name} Nullable(String)"
else:
raise ValueError(f"Unknown data type {col_type}")

sql += ") ENGINE MergeTree ORDER BY tuple()"

return sql
44 changes: 44 additions & 0 deletions splinkclickhouse/comparison_level_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from splink.internals.comparison_level_library import (
DateMetricType,
)
from splink.internals.comparison_level_library import (
PairwiseStringDistanceFunctionLevel as SplinkPairwiseStringDistanceFunctionLevel,
)

from .column_expression import ColumnExpression as CHColumnExpression
from .dialect import ClickhouseDialect, SplinkDialect
Expand Down Expand Up @@ -164,3 +167,44 @@ def create_sql(self, sql_dialect: SplinkDialect) -> str:
f"<= {self.time_threshold_seconds}"
)
return sql


class PairwiseStringDistanceFunctionLevel(SplinkPairwiseStringDistanceFunctionLevel):
    """Clickhouse-dialect override of Splink's pairwise string-distance level.

    Re-implements SQL generation because Clickhouse's higher-order array
    functions require the lambda as the *first* argument, which the base
    class's SQL shape does not produce.
    """

    def create_sql(self, sql_dialect: SplinkDialect) -> str:
        """Return the SQL condition for this comparison level.

        Builds all (left, right) pairs from the two array columns, applies
        the configured string-distance function to each pair, aggregates the
        scores (min or max, per ``self._aggregator()``), and compares the
        result against ``self.distance_threshold``.
        """
        self.col_expression.sql_dialect = sql_dialect
        col = self.col_expression
        # translate Splink's generic distance-function key into the
        # dialect-specific Clickhouse function name
        distance_function_name_transpiled = {
            "levenshtein": sql_dialect.levenshtein_function_name,
            "damerau_levenshtein": sql_dialect.damerau_levenshtein_function_name,
            "jaro_winkler": sql_dialect.jaro_winkler_function_name,
            "jaro": sql_dialect.jaro_function_name,
        }[self.distance_function_name]

        # the base class decides (via _aggregator()) whether pairwise scores
        # are combined with a min or a max
        aggregator_func = {
            "min": sql_dialect.array_min_function_name,
            "max": sql_dialect.array_max_function_name,
        }[self._aggregator()]

        # order of the arguments is different in Clickhouse than that expected
        # by Splink; specifically the lambda must come first in Clickhouse.
        # This is not fixable with a UDF, as having it in the second argument
        # in general will cause the Clickhouse parser to fail.
        # Also need to use a workaround (arrayReduce with 'array_concat_agg')
        # to get a 'flatten' equivalent for a single level
        return f"""{aggregator_func}(
            {sql_dialect.array_transform_function_name}(
                pair -> {distance_function_name_transpiled}(
                    pair[{sql_dialect.array_first_index}],
                    pair[{sql_dialect.array_first_index + 1}]
                ),
                arrayReduce(
                    'array_concat_agg',
                    {sql_dialect.array_transform_function_name}(
                        x -> {sql_dialect.array_transform_function_name}(
                            y -> [x, y],
                            {col.name_r}
                        ),
                        {col.name_l}
                    )
                )
            )
        ) {self._comparator()} {self.distance_threshold}"""
24 changes: 24 additions & 0 deletions splinkclickhouse/comparison_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
from splink.internals.comparison_library import (
DateOfBirthComparison as SplinkDateOfBirthComparison,
)
from splink.internals.comparison_library import (
PairwiseStringDistanceFunctionAtThresholds as SplinkPairwiseStringDistanceFunctionAtThresholds, # noqa: E501 (can't keep format and check happy)
)
from splink.internals.misc import ensure_is_iterable

import splinkclickhouse.comparison_level_library as cll_ch
Expand Down Expand Up @@ -305,3 +308,24 @@ def create_comparison_levels(self) -> list[ComparisonLevelCreator]:

levels.append(cll.ElseLevel())
return levels


class PairwiseStringDistanceFunctionAtThresholds(
    SplinkPairwiseStringDistanceFunctionAtThresholds
):
    def create_comparison_levels(self) -> list[ComparisonLevelCreator]:
        """Build the levels for this comparison: null, exact-array-overlap,
        one pairwise-distance level per configured threshold, then else."""
        levels: list[ComparisonLevelCreator] = [
            cll.NullLevel(self.col_expression),
            # It is assumed that any string distance treats identical
            # arrays as the most similar
            cll.ArrayIntersectLevel(self.col_expression, min_intersection=1),
        ]
        levels.extend(
            cll_ch.PairwiseStringDistanceFunctionLevel(
                self.col_expression,
                distance_threshold=threshold,
                distance_function_name=self.distance_function_name,
            )
            for threshold in self.thresholds
        )
        levels.append(cll.ElseLevel())
        return levels
16 changes: 16 additions & 0 deletions splinkclickhouse/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,22 @@ def jaro_function_name(self) -> str:
def jaccard_function_name(self) -> str:
return "stringJaccardIndexUTF8"

@property
def array_first_index(self) -> int:
    """Index of the first element of an array (Clickhouse arrays are 1-based)."""
    return 1

@property
def array_min_function_name(self) -> str:
    """Clickhouse function returning the minimum element of an array."""
    return "arrayMin"

@property
def array_max_function_name(self) -> str:
    """Clickhouse function returning the maximum element of an array."""
    return "arrayMax"

@property
def array_transform_function_name(self) -> str:
    """Clickhouse higher-order function applying a lambda to each array element."""
    return "arrayMap"

def _regex_extract_raw(
self, name: str, pattern: str, capture_group: int = 0
) -> str:
Expand Down
38 changes: 37 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,26 @@ def pytest_collection_modifyitems(items, config):
item.add_marker(mark)


# Shared pool of first names for test fixtures. Deliberately includes
# near-duplicates ("thomas"/"thoams", "ben"/"benny"/"benjamin", ...) so that
# fuzzy string comparisons produce a spread of distances rather than only
# exact matches.
_NAMES = (
    "tom",
    "tim",
    "jen",
    "jan",
    "ken",
    "sam",
    "katherine",
    "ben",
    "benjamin",
    "benny",
    "jenny",
    "jennifer",
    "samuel",
    "thom",
    "thomas",
    "thoams",
)


@fixture
def chdb_api_factory():
con = dbapi.connect()
Expand Down Expand Up @@ -153,7 +173,7 @@ def input_nodes_with_lat_longs():
longs = np.random.uniform(low=long_low, high=long_high, size=n_rows)
# also include some names so we have a second comparison
names = np.random.choice(
("tom", "tim", "jen", "jan", "ken", "sam", "katherine"),
_NAMES,
size=n_rows,
)
return pd.DataFrame(
Expand All @@ -164,3 +184,19 @@ def input_nodes_with_lat_longs():
"longitude": longs,
}
)


@fixture
def input_nodes_with_name_arrays():
    """Return a 1000-row frame with a string column and a string-array column.

    ``aliases`` holds a variable-length numpy string array per row (1-4
    names drawn from ``_NAMES``); ``username`` holds a single name per row.
    """
    n_rows = 1_000
    # number of aliases per row: 1 to 4 inclusive
    sizes = np.random.randint(low=1, high=5, size=n_rows)
    return pd.DataFrame(
        {
            "unique_id": range(n_rows),
            # list comprehension rather than map(lambda ...): idiomatic (ruff
            # C417) and materialises the column eagerly, so the RNG draws
            # happen here instead of whenever pandas consumes the iterator
            "aliases": [np.random.choice(_NAMES, size=size) for size in sizes],
            "username": np.random.choice(
                _NAMES,
                size=n_rows,
            ),
        }
    )
36 changes: 35 additions & 1 deletion tests/test_cl_ch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import splink.comparison_library as cl
from pytest import raises
from pytest import mark, raises
from splink import DuckDBAPI, Linker, SettingsCreator

import splinkclickhouse.comparison_library as cl_ch
Expand Down Expand Up @@ -105,3 +105,37 @@ def test_clickhouse_date_of_birth_comparison(api_info, fake_1000):

linker = Linker(fake_1000, settings, db_api)
linker.inference.predict()


# TODO: for now there's not a straightforward way (afaik) to get an array column
# into chdb. So for the time being we test only clickhouse server version
@mark.clickhouse
@mark.clickhouse_no_core
def test_pairwise_string_distance(clickhouse_api_factory, input_nodes_with_name_arrays):
    """Smoke-test PairwiseStringDistanceFunctionAtThresholds for each of the
    four supported distance functions: settings build and predict() runs."""
    db_api = clickhouse_api_factory()

    settings = SettingsCreator(
        link_type="dedupe_only",
        comparisons=[
            cl.ExactMatch("username"),
            # can pretend these are distinct
            cl_ch.PairwiseStringDistanceFunctionAtThresholds(
                "aliases", "levenshtein", [1, 2]
            ),
            cl_ch.PairwiseStringDistanceFunctionAtThresholds(
                "aliases_2", "damerau_levenshtein", [1, 2, 3]
            ),
            cl_ch.PairwiseStringDistanceFunctionAtThresholds(
                "aliases_3", "jaro", [0.88, 0.7]
            ),
            cl_ch.PairwiseStringDistanceFunctionAtThresholds(
                "aliases_4", "jaro_winkler", [0.88, 0.7]
            ),
        ],
    )

    # duplicate the array column so each comparison gets its own column
    input_nodes_with_name_arrays["aliases_2"] = input_nodes_with_name_arrays["aliases"]
    input_nodes_with_name_arrays["aliases_3"] = input_nodes_with_name_arrays["aliases"]
    input_nodes_with_name_arrays["aliases_4"] = input_nodes_with_name_arrays["aliases"]
    linker = Linker(input_nodes_with_name_arrays, settings, db_api)
    linker.inference.predict()
Loading