Merge pull request #51 from ADBond/feature/pariwise-string-comparison
Pairwise string comparison
ADBond authored Dec 16, 2024
2 parents ee129fa + 06e53fe commit e8de8cd
Showing 7 changed files with 170 additions and 7 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added

- Added Clickhouse-appropriate versions of the comparison level `PairwiseStringDistanceFunctionLevel` and the comparison `PairwiseStringDistanceFunctionAtThresholds` to the relevant libraries [#51](https://github.com/ADBond/splinkclickhouse/pull/51)
- `ClickhouseAPI` can now properly register `pandas` tables with string array columns [#51](https://github.com/ADBond/splinkclickhouse/pull/51)

### Fixed

- Table registration in `chdb` now works for pandas tables whose indexes do not have a `0` entry [#49](https://github.com/ADBond/splinkclickhouse/pull/49).
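For orientation, here is a minimal sketch of how the new comparison might be wired into a Splink model, mirroring the test added further down in this PR; the column name, thresholds, and commented-out linker step are illustrative, and `db_api` is assumed to be an already-configured Clickhouse-backed database API.

```python
from splink import Linker, SettingsCreator

import splinkclickhouse.comparison_library as cl_ch

# Illustrative settings: band the pairwise Levenshtein distance between
# two arrays of name aliases at thresholds 1 and 2.
settings = SettingsCreator(
    link_type="dedupe_only",
    comparisons=[
        cl_ch.PairwiseStringDistanceFunctionAtThresholds(
            "aliases", "levenshtein", [1, 2]
        ),
    ],
)

# linker = Linker(df_with_alias_arrays, settings, db_api)
# linker.inference.predict()
```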
14 changes: 9 additions & 5 deletions splinkclickhouse/clickhouse/database_api.py
@@ -109,20 +109,24 @@ def _create_table_from_pandas_frame(self, df: pd.DataFrame, table_name: str) ->
         sql = f"CREATE OR REPLACE TABLE {table_name} ("

         first_col = True
-        for column in df.columns:
+        for column_name in df.columns:
             if not first_col:
                 sql += ", "
-            col_type = df[column].dtype
+            column = df[column_name]
+            col_type = column.dtype
             first_col = False

             if pd.api.types.is_integer_dtype(col_type):
-                sql += f"{column} Nullable(UInt32)"
+                sql += f"{column_name} Nullable(UInt32)"
             elif pd.api.types.is_float_dtype(col_type):
-                sql += f"{column} Nullable(Float64)"
+                sql += f"{column_name} Nullable(Float64)"
+            elif pd.api.types.is_list_like(column[0]):
+                sql += f"{column_name} Array(String)"
             elif pd.api.types.is_string_dtype(col_type):
-                sql += f"{column} Nullable(String)"
+                sql += f"{column_name} Nullable(String)"
             else:
                 raise ValueError(f"Unknown data type {col_type}")

         sql += ") ENGINE MergeTree ORDER BY tuple()"

         return sql
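As a quick illustration of the dtype mapping above, here is a standalone sketch (not part of the PR; the sample frame is hypothetical) that applies the same checks and prints the Clickhouse column definition each column would receive:

```python
import pandas as pd

# Hypothetical sample frame: integer, float, string-array, and string columns.
df = pd.DataFrame(
    {
        "unique_id": [1, 2],
        "score": [0.5, 0.7],
        "aliases": [["tom", "thom"], ["jen"]],
        "username": ["tom", "jen"],
    }
)

# Mirrors the branch logic in _create_table_from_pandas_frame above.
for column_name in df.columns:
    column = df[column_name]
    col_type = column.dtype
    if pd.api.types.is_integer_dtype(col_type):
        ch_type = "Nullable(UInt32)"
    elif pd.api.types.is_float_dtype(col_type):
        ch_type = "Nullable(Float64)"
    elif pd.api.types.is_list_like(column[0]):
        ch_type = "Array(String)"
    elif pd.api.types.is_string_dtype(col_type):
        ch_type = "Nullable(String)"
    else:
        raise ValueError(f"Unknown data type {col_type}")
    print(f"{column_name} {ch_type}")

# Expected output:
# unique_id Nullable(UInt32)
# score Nullable(Float64)
# aliases Array(String)
# username Nullable(String)
```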
44 changes: 44 additions & 0 deletions splinkclickhouse/comparison_level_library.py
@@ -10,6 +10,9 @@
from splink.internals.comparison_level_library import (
    DateMetricType,
)
from splink.internals.comparison_level_library import (
    PairwiseStringDistanceFunctionLevel as SplinkPairwiseStringDistanceFunctionLevel,
)

from .column_expression import ColumnExpression as CHColumnExpression
from .dialect import ClickhouseDialect, SplinkDialect
@@ -164,3 +167,44 @@ def create_sql(self, sql_dialect: SplinkDialect) -> str:
            f"<= {self.time_threshold_seconds}"
        )
        return sql


class PairwiseStringDistanceFunctionLevel(SplinkPairwiseStringDistanceFunctionLevel):
    def create_sql(self, sql_dialect: SplinkDialect) -> str:
        self.col_expression.sql_dialect = sql_dialect
        col = self.col_expression
        distance_function_name_transpiled = {
            "levenshtein": sql_dialect.levenshtein_function_name,
            "damerau_levenshtein": sql_dialect.damerau_levenshtein_function_name,
            "jaro_winkler": sql_dialect.jaro_winkler_function_name,
            "jaro": sql_dialect.jaro_function_name,
        }[self.distance_function_name]

        aggregator_func = {
            "min": sql_dialect.array_min_function_name,
            "max": sql_dialect.array_max_function_name,
        }[self._aggregator()]

        # the order of the arguments in Clickhouse differs from that expected by
        # Splink: specifically, the lambda must come first in Clickhouse.
        # This is not fixable with a UDF, as having the lambda as the second
        # argument will in general cause the Clickhouse parser to fail.
        # We also need a workaround to get a 'flatten' equivalent for a single level.
        return f"""{aggregator_func}(
            {sql_dialect.array_transform_function_name}(
                pair -> {distance_function_name_transpiled}(
                    pair[{sql_dialect.array_first_index}],
                    pair[{sql_dialect.array_first_index + 1}]
                ),
                arrayReduce(
                    'array_concat_agg',
                    {sql_dialect.array_transform_function_name}(
                        x -> {sql_dialect.array_transform_function_name}(
                            y -> [x, y],
                            {col.name_r}
                        ),
                        {col.name_l}
                    )
                )
            )
        ) {self._comparator()} {self.distance_threshold}"""
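To make the generated pattern concrete, here is a hand-written ClickHouse query of the same shape. The literal arrays stand in for `{col.name_l}` and `{col.name_r}`, and `editDistance` stands in for whatever `levenshtein_function_name` resolves to in this dialect; both stand-ins are assumptions for illustration, not taken from the PR.

```sql
-- Illustrative only: minimum pairwise edit distance between two alias arrays.
-- The inner arrayMap builds [x, y] pairs, arrayReduce('array_concat_agg', ...)
-- flattens one nesting level, and arrayMin aggregates the per-pair distances.
SELECT arrayMin(
    arrayMap(
        pair -> editDistance(pair[1], pair[2]),
        arrayReduce(
            'array_concat_agg',
            arrayMap(x -> arrayMap(y -> [x, y], ['tom', 'thom']), ['thomas', 'tim'])
        )
    )
) AS min_pairwise_distance
-- With these inputs the result is 1 (editDistance('tim', 'tom')).
```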
24 changes: 24 additions & 0 deletions splinkclickhouse/comparison_library.py
@@ -16,6 +16,9 @@
from splink.internals.comparison_library import (
    DateOfBirthComparison as SplinkDateOfBirthComparison,
)
from splink.internals.comparison_library import (
    PairwiseStringDistanceFunctionAtThresholds as SplinkPairwiseStringDistanceFunctionAtThresholds,  # noqa: E501 (can't keep format and check happy)
)
from splink.internals.misc import ensure_is_iterable

import splinkclickhouse.comparison_level_library as cll_ch
@@ -305,3 +308,24 @@ def create_comparison_levels(self) -> list[ComparisonLevelCreator]:

        levels.append(cll.ElseLevel())
        return levels


class PairwiseStringDistanceFunctionAtThresholds(
    SplinkPairwiseStringDistanceFunctionAtThresholds
):
    def create_comparison_levels(self) -> list[ComparisonLevelCreator]:
        return [
            cll.NullLevel(self.col_expression),
            # It is assumed that any string distance treats identical
            # arrays as the most similar
            cll.ArrayIntersectLevel(self.col_expression, min_intersection=1),
            *[
                cll_ch.PairwiseStringDistanceFunctionLevel(
                    self.col_expression,
                    distance_threshold=threshold,
                    distance_function_name=self.distance_function_name,
                )
                for threshold in self.thresholds
            ],
            cll.ElseLevel(),
        ]
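A small sketch (illustrative, assuming splinkclickhouse is importable) of the level structure this comparison builds for a Levenshtein comparison at thresholds 1 and 2, following `create_comparison_levels` above:

```python
import splinkclickhouse.comparison_library as cl_ch

# Positional arguments mirror the test added further down in this PR.
comparison = cl_ch.PairwiseStringDistanceFunctionAtThresholds(
    "aliases", "levenshtein", [1, 2]
)

for level in comparison.create_comparison_levels():
    print(type(level).__name__)

# Expected, per create_comparison_levels above:
# NullLevel
# ArrayIntersectLevel
# PairwiseStringDistanceFunctionLevel  (distance_threshold=1)
# PairwiseStringDistanceFunctionLevel  (distance_threshold=2)
# ElseLevel
```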
16 changes: 16 additions & 0 deletions splinkclickhouse/dialect.py
@@ -32,6 +32,22 @@ def jaro_function_name(self) -> str:
    def jaccard_function_name(self) -> str:
        return "stringJaccardIndexUTF8"

    @property
    def array_first_index(self) -> int:
        return 1

    @property
    def array_min_function_name(self) -> str:
        return "arrayMin"

    @property
    def array_max_function_name(self) -> str:
        return "arrayMax"

    @property
    def array_transform_function_name(self) -> str:
        return "arrayMap"

    def _regex_extract_raw(
        self, name: str, pattern: str, capture_group: int = 0
    ) -> str:
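The new `array_first_index` of 1 reflects that ClickHouse arrays are 1-indexed, which is why the comparison level above indexes `pair[1]` and `pair[2]` rather than `pair[0]` and `pair[1]`. A one-line check, illustrative only:

```sql
SELECT ['a', 'b', 'c'][1] AS first_element  -- returns 'a': ClickHouse arrays are 1-indexed
```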
38 changes: 37 additions & 1 deletion tests/conftest.py
@@ -27,6 +27,26 @@ def pytest_collection_modifyitems(items, config):
            item.add_marker(mark)


_NAMES = (
    "tom",
    "tim",
    "jen",
    "jan",
    "ken",
    "sam",
    "katherine",
    "ben",
    "benjamin",
    "benny",
    "jenny",
    "jennifer",
    "samuel",
    "thom",
    "thomas",
    "thoams",
)


@fixture
def chdb_api_factory():
    con = dbapi.connect()
@@ -153,7 +173,7 @@ def input_nodes_with_lat_longs():
     longs = np.random.uniform(low=long_low, high=long_high, size=n_rows)
     # also include some names so we have a second comparison
     names = np.random.choice(
-        ("tom", "tim", "jen", "jan", "ken", "sam", "katherine"),
+        _NAMES,
         size=n_rows,
     )
     return pd.DataFrame(
@@ -164,3 +184,19 @@
            "longitude": longs,
        }
    )


@fixture
def input_nodes_with_name_arrays():
    n_rows = 1_000
    sizes = np.random.randint(low=1, high=5, size=n_rows)
    return pd.DataFrame(
        {
            "unique_id": range(n_rows),
            "aliases": map(lambda s: np.random.choice(_NAMES, size=s), sizes),
            "username": np.random.choice(
                _NAMES,
                size=n_rows,
            ),
        }
    )
36 changes: 35 additions & 1 deletion tests/test_cl_ch.py
@@ -1,5 +1,5 @@
 import splink.comparison_library as cl
-from pytest import raises
+from pytest import mark, raises
 from splink import DuckDBAPI, Linker, SettingsCreator

 import splinkclickhouse.comparison_library as cl_ch
@@ -105,3 +105,37 @@ def test_clickhouse_date_of_birth_comparison(api_info, fake_1000):

    linker = Linker(fake_1000, settings, db_api)
    linker.inference.predict()


# TODO: for now there is no straightforward way (afaik) to get an array column
# into chdb, so for the time being we test only the Clickhouse server version
@mark.clickhouse
@mark.clickhouse_no_core
def test_pairwise_string_distance(clickhouse_api_factory, input_nodes_with_name_arrays):
    db_api = clickhouse_api_factory()

    settings = SettingsCreator(
        link_type="dedupe_only",
        comparisons=[
            cl.ExactMatch("username"),
            # we can pretend these are distinct columns
            cl_ch.PairwiseStringDistanceFunctionAtThresholds(
                "aliases", "levenshtein", [1, 2]
            ),
            cl_ch.PairwiseStringDistanceFunctionAtThresholds(
                "aliases_2", "damerau_levenshtein", [1, 2, 3]
            ),
            cl_ch.PairwiseStringDistanceFunctionAtThresholds(
                "aliases_3", "jaro", [0.88, 0.7]
            ),
            cl_ch.PairwiseStringDistanceFunctionAtThresholds(
                "aliases_4", "jaro_winkler", [0.88, 0.7]
            ),
        ],
    )

    input_nodes_with_name_arrays["aliases_2"] = input_nodes_with_name_arrays["aliases"]
    input_nodes_with_name_arrays["aliases_3"] = input_nodes_with_name_arrays["aliases"]
    input_nodes_with_name_arrays["aliases_4"] = input_nodes_with_name_arrays["aliases"]
    linker = Linker(input_nodes_with_name_arrays, settings, db_api)
    linker.inference.predict()