Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit bc0f757

Browse files
Add coltype-aware matching and warning
1 parent 64eb985 commit bc0f757

File tree

3 files changed

+39
-24
lines changed

3 files changed

+39
-24
lines changed

data_diff/hashdiff_tables.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from runtype import dataclass
99

10-
from sqeleton.abcs import ColType_UUID, NumericType, PrecisionType, StringType, Boolean
10+
from sqeleton.abcs import ColType_UUID, NumericType, PrecisionType, StringType, Boolean, JSONType
1111

1212
from .info_tree import InfoTree
1313
from .utils import safezip, diffs_are_equiv_jsons
@@ -24,10 +24,7 @@
2424
logger = logging.getLogger("hashdiff_tables")
2525

2626

27-
def diff_sets(a: list, b: list, has_json_cols: bool = None) -> Iterator:
28-
# check unless the only item is the key. TODO: pass a boolean to know whether the schema has json columns or not
29-
has_json_cols = len(a[0]) > 1
30-
27+
def diff_sets(a: list, b: list, json_cols: dict = None) -> Iterator:
3128
sa = set(a)
3229
sb = set(b)
3330

@@ -41,9 +38,17 @@ def diff_sets(a: list, b: list, has_json_cols: bool = None) -> Iterator:
4138
if row not in sa:
4239
d[row[0]].append(("+", row))
4340

41+
warned_diff_cols = set()
4442
for _k, v in sorted(d.items(), key=lambda i: i[0]):
45-
if has_json_cols and diffs_are_equiv_jsons(v):
46-
continue # don't count this as a diff, maybe do and send a warning, maybe parametrized ??
43+
if json_cols:
44+
parsed_match, overriden_diff_cols = diffs_are_equiv_jsons(v, json_cols)
45+
if parsed_match:
46+
to_warn = overriden_diff_cols - warned_diff_cols
47+
for w in to_warn:
48+
logger.warning(f"Equivalent JSON objects with different string representations detected "
49+
f"in column '{w}'. These cases are NOT reported as differences.")
50+
warned_diff_cols.add(w)
51+
continue
4752
yield from v
4853

4954

@@ -199,7 +204,9 @@ def _bisect_and_diff_segments(
199204
# This saves time, as bisection speed is limited by ping and query performance.
200205
if max_rows < self.bisection_threshold or max_space_size < self.bisection_factor * 2:
201206
rows1, rows2 = self._threaded_call("get_values", [table1, table2])
202-
diff = list(diff_sets(rows1, rows2))
207+
json_cols = {i: colname for i, colname in enumerate(table1.extra_columns)
208+
if isinstance(table1._schema[colname], JSONType)}
209+
diff = list(diff_sets(rows1, rows2, json_cols))
203210

204211
info_tree.info.set_diff(diff)
205212
info_tree.info.rowcounts = {1: len(rows1), 2: len(rows2)}

data_diff/utils.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,20 +75,25 @@ def get_timestamp(_match):
7575
return re.sub("%t", get_timestamp, name)
7676

7777

78-
def _jsons_equal(a, b):
78+
def _jsons_equiv(a: str, b: str):
7979
try:
8080
return json.loads(a) == json.loads(b)
8181
except (ValueError, TypeError, json.decoder.JSONDecodeError): # not valid jsons
8282
return False
8383

8484

85-
def diffs_are_equiv_jsons(v):
86-
if (len(v) != 2) or ({v[0][0], v[1][0]} != {'+', '-'}): # ignore rows that are missing in one of the tables
85+
def diffs_are_equiv_jsons(diff: list, json_cols: dict):
    """Decide whether a 2-row diff pair differs only in JSON columns whose
    string representations parse to equivalent objects.

    diff: list of ("+"/"-", row) entries for one key; row[0] is the key,
          row[1:] are the extra columns.
    json_cols: {index-within-extra-columns: column name} for JSONType columns.

    Returns a tuple (match, overriden_diff_cols):
      match -- True when every column pair is equal as-is, or is a JSON
               column whose two representations parse to equal objects.
      overriden_diff_cols -- names of JSON columns whose textual difference
               was overridden by JSON equivalence.
    """
    # Only a matched '-'/'+' pair can be JSON-equivalent; a row present in
    # just one table is always a real diff.
    if (len(diff) != 2) or ({diff[0][0], diff[1][0]} != {'+', '-'}):
        # Bug fix: return a 2-tuple like the success path below, so the
        # caller's unpacking (`parsed_match, cols = ...`) cannot raise
        # "cannot unpack non-iterable bool".
        return False, set()
    match = True
    overriden_diff_cols = set()
    # row[1][0] is the key, so extra columns start at row[1][1:].
    for i, (col_a, col_b) in enumerate(safezip(diff[0][1][1:], diff[1][1][1:])):
        # We only attempt to parse columns of JSONType, but we still need
        # to check that non-json columns match exactly.
        match = col_a == col_b
        if not match and (i in json_cols):
            if _jsons_equiv(col_a, col_b):
                overriden_diff_cols.add(json_cols[i])
                match = True
        if not match:
            break
    return match, overriden_diff_cols

tests/test_database_types.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@ def expand_params(testcase_func, param_num, param):
569569
return name
570570

571571

572-
def _insert_to_table(conn, table_path, values, type):
572+
def _insert_to_table(conn, table_path, values, coltype):
573573
tbl = table(table_path)
574574

575575
current_n_rows = conn.query(tbl.count(), int)
@@ -578,31 +578,34 @@ def _insert_to_table(conn, table_path, values, type):
578578
return
579579
elif current_n_rows > 0:
580580
conn.query(drop_table(table_name))
581-
_create_table_with_indexes(conn, table_path, type)
581+
_create_table_with_indexes(conn, table_path, coltype)
582582

583583
# if BENCHMARK and N_SAMPLES > 10_000:
584584
# description = f"{conn.name}: {table}"
585585
# values = rich.progress.track(values, total=N_SAMPLES, description=description)
586586

587-
if type == "boolean":
587+
if coltype == "boolean":
588588
values = [(i, bool(sample)) for i, sample in values]
589-
elif re.search(r"(time zone|tz)", type):
589+
elif re.search(r"(time zone|tz)", coltype):
590590
values = [(i, sample.replace(tzinfo=timezone.utc)) for i, sample in values]
591591

592592
if isinstance(conn, db.Clickhouse):
593-
if type.startswith("DateTime64"):
593+
if coltype.startswith("DateTime64"):
594594
values = [(i, f"{sample.replace(tzinfo=None)}") for i, sample in values]
595595

596-
elif type == "DateTime":
596+
elif coltype == "DateTime":
597597
# Clickhouse's DateTime does not allow to store micro/milli/nano seconds
598598
values = [(i, str(sample)[:19]) for i, sample in values]
599599

600-
elif type.startswith("Decimal("):
601-
precision = int(type[8:].rstrip(")").split(",")[1])
600+
elif coltype.startswith("Decimal("):
601+
precision = int(coltype[8:].rstrip(")").split(",")[1])
602602
values = [(i, round(sample, precision)) for i, sample in values]
603-
elif isinstance(conn, db.BigQuery) and type == "datetime":
603+
elif isinstance(conn, db.BigQuery) and coltype == "datetime":
604604
values = [(i, Code(f"cast(timestamp '{sample}' as datetime)")) for i, sample in values]
605605

606+
if isinstance(conn, db.Redshift) and coltype == "json":
607+
values = [(i, Code(f"JSON_PARSE('{sample}')")) for i, sample in values]
608+
606609
insert_rows_in_batches(conn, tbl, values, columns=["id", "col"])
607610
conn.query(commit)
608611

0 commit comments

Comments
 (0)