From 8d3d09b79a1d1cba1f3bd237cc3165a212320983 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Tue, 6 Jan 2026 11:06:37 +0000 Subject: [PATCH 01/23] one col at a time --- splink/internals/estimate_u.py | 103 ++++++++++++++++++++++----------- 1 file changed, 68 insertions(+), 35 deletions(-) diff --git a/splink/internals/estimate_u.py b/splink/internals/estimate_u.py index 5347a53ebf..5ac104ef8b 100644 --- a/splink/internals/estimate_u.py +++ b/splink/internals/estimate_u.py @@ -15,6 +15,7 @@ m_u_records_to_lookup_dict, ) from splink.internals.pipeline import CTEPipeline +from splink.internals.settings import Settings from splink.internals.vertically_concatenate import ( enqueue_df_concat, split_df_concat_with_tf_into_two_tables_sqls, @@ -176,55 +177,87 @@ def estimate_u_values(linker: Linker, max_pairs: float, seed: int = None) -> Non pipeline.enqueue_list_of_sqls(sql_infos) blocked_pairs = linker._db_api.sql_pipeline_to_splink_dataframe(pipeline) - pipeline = CTEPipeline([blocked_pairs, df_sample]) + uid_columns = settings_obj.column_info_settings.unique_id_input_columns + source_dataset_col = settings_obj.column_info_settings.source_dataset_input_column + uid_col = settings_obj.column_info_settings.unique_id_input_column - sqls = compute_comparison_vector_values_from_id_pairs_sqls( - settings_obj._columns_to_select_for_blocking, - settings_obj._columns_to_select_for_comparison_vector_values, - input_tablename_l="__splink__df_concat_sample", - input_tablename_r="__splink__df_concat_sample", - source_dataset_input_column=settings_obj.column_info_settings.source_dataset_input_column, - unique_id_input_column=settings_obj.column_info_settings.unique_id_input_column, - ) + # Build common blocking columns (UID columns that all comparisons need) + common_blocking_cols: list[str] = [] + for uid_column in uid_columns: + common_blocking_cols.extend(uid_column.l_r_names_as_l_r) - pipeline.enqueue_list_of_sqls(sqls) + for i, comparison in enumerate(settings_obj.comparisons): + logger.info( + f"\nEstimating u for: {comparison.output_column_name} " + f"({i+1}/{len(settings_obj.comparisons)})" + ) + original_comparison = original_settings_obj.comparisons[i] - sql = """ - select *, cast(0.0 as float8) as match_probability - from __splink__df_comparison_vectors - """ + pipeline = CTEPipeline([blocked_pairs, df_sample]) - pipeline.enqueue_sql(sql, "__splink__df_predict") + # Blocking needs UIDs + comparison-specific columns + blocking_cols = ( + common_blocking_cols + comparison._columns_to_select_for_blocking() + ) - sql = compute_new_parameters_sql( - estimate_without_term_frequencies=False, - comparisons=settings_obj.comparisons, - ) + # Comparison vector needs UIDs + comparison output + match_key + cv_cols = Settings.columns_to_select_for_comparison_vector_values( + unique_id_input_columns=uid_columns, + comparisons=[comparison], + retain_matching_columns=False, + additional_columns_to_retain=[], + ) - pipeline.enqueue_sql(sql, "__splink__m_u_counts") - df_params = db_api.sql_pipeline_to_splink_dataframe(pipeline) + sqls = compute_comparison_vector_values_from_id_pairs_sqls( + blocking_cols, + cv_cols, + input_tablename_l="__splink__df_concat_sample", + input_tablename_r="__splink__df_concat_sample", + source_dataset_input_column=source_dataset_col, + unique_id_input_column=uid_col, + ) - param_records = df_params.as_pandas_dataframe() - param_records = compute_proportions_for_new_parameters(param_records) - df_params.drop_table_from_database_and_remove_from_cache() - 
df_sample.drop_table_from_database_and_remove_from_cache()
-    blocked_pairs.drop_table_from_database_and_remove_from_cache()
+        pipeline.enqueue_list_of_sqls(sqls)
+
+        # Add dummy match_probability column required by compute_new_parameters_sql
+        sql = """
+        select *, cast(0.0 as float8) as match_probability
+        from __splink__df_comparison_vectors
+        """
+        pipeline.enqueue_sql(sql, "__splink__df_predict")
+
+        # Compute u probability counts for this comparison
+        sql = compute_new_parameters_sql(
+            estimate_without_term_frequencies=False,
+            comparisons=[comparison],
+        )
+        pipeline.enqueue_sql(sql, "__splink__m_u_counts")
+
+        df_params = db_api.sql_pipeline_to_splink_dataframe(pipeline)
-    m_u_records = [
-        r
-        for r in param_records
-        if r["output_column_name"] != "_probability_two_random_records_match"
-    ]
+        # Convert counts to proportions (u probabilities)
+        param_records = df_params.as_pandas_dataframe()
+        param_records = compute_proportions_for_new_parameters(param_records)
+        df_params.drop_table_from_database_and_remove_from_cache()
-    m_u_records_lookup = m_u_records_to_lookup_dict(m_u_records)
+        # Extract just the u records (filter out lambda)
+        m_u_records = [
+            r
+            for r in param_records
+            if r["output_column_name"] != "_probability_two_random_records_match"
+        ]
+        m_u_records_lookup = m_u_records_to_lookup_dict(m_u_records)
-    for c in original_settings_obj.comparisons:
-        for cl in c._comparison_levels_excluding_null:
+        # Apply estimated u values to the original settings object
+        for cl in original_comparison._comparison_levels_excluding_null:
             append_u_probability_to_comparison_level_trained_probabilities(
                 cl,
                 m_u_records_lookup,
-                c.output_column_name,
+                original_comparison.output_column_name,
                 "estimate u by random sampling",
             )
+    df_sample.drop_table_from_database_and_remove_from_cache()
+    blocked_pairs.drop_table_from_database_and_remove_from_cache()
+
     logger.info("\nEstimated u probabilities using random sampling")

From 9e3083b0437163ecb64a40942bac4e81fa98f2a5 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Tue, 6 Jan 2026 11:25:44 +0000
Subject: [PATCH 02/23] don't materialise blocked pairs

---
 splink/internals/estimate_u.py | 28 ++++++++++++++--------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/splink/internals/estimate_u.py b/splink/internals/estimate_u.py
index 5ac104ef8b..8c7b70d21d 100644
--- a/splink/internals/estimate_u.py
+++ b/splink/internals/estimate_u.py
@@ -145,18 +145,18 @@ def estimate_u_values(linker: Linker, max_pairs: float, seed: int = None) -> Non
     pipeline.enqueue_sql(sql, "__splink__df_concat_sample")
     df_sample = db_api.sql_pipeline_to_splink_dataframe(pipeline)
-    pipeline = CTEPipeline(input_dataframes=[df_sample])
-
     settings_obj._blocking_rules_to_generate_predictions = []
     input_tablename_sample_l = "__splink__df_concat_sample"
     input_tablename_sample_r = "__splink__df_concat_sample"
+    split_sqls: list[dict[str, str]] = []
+
     if (
         len(linker._input_tables_dict) == 2
         and linker._settings_obj._link_type == "link_only"
     ):
-        sqls = split_df_concat_with_tf_into_two_tables_sqls(
+        split_sqls = split_df_concat_with_tf_into_two_tables_sqls(
            "__splink__df_concat",
            linker._settings_obj.column_info_settings.source_dataset_column_name,
            sample_switch=True,
        )
        input_tablename_sample_l = "__splink__df_concat_sample_left"
        input_tablename_sample_r = "__splink__df_concat_sample_right"
-        pipeline.enqueue_list_of_sqls(sqls)
-
-    sql_infos = 
block_using_rules_sqls( + blocking_sqls = block_using_rules_sqls( input_tablename_l=input_tablename_sample_l, input_tablename_r=input_tablename_sample_r, blocking_rules=settings_obj._blocking_rules_to_generate_predictions, @@ -174,8 +172,6 @@ def estimate_u_values(linker: Linker, max_pairs: float, seed: int = None) -> Non source_dataset_input_column=settings_obj.column_info_settings.source_dataset_input_column, unique_id_input_column=settings_obj.column_info_settings.unique_id_input_column, ) - pipeline.enqueue_list_of_sqls(sql_infos) - blocked_pairs = linker._db_api.sql_pipeline_to_splink_dataframe(pipeline) uid_columns = settings_obj.column_info_settings.unique_id_input_columns source_dataset_col = settings_obj.column_info_settings.source_dataset_input_column @@ -193,7 +189,12 @@ def estimate_u_values(linker: Linker, max_pairs: float, seed: int = None) -> Non ) original_comparison = original_settings_obj.comparisons[i] - pipeline = CTEPipeline([blocked_pairs, df_sample]) + pipeline = CTEPipeline(input_dataframes=[df_sample]) + + if split_sqls: + pipeline.enqueue_list_of_sqls(split_sqls) + + pipeline.enqueue_list_of_sqls(blocking_sqls) # Blocking needs UIDs + comparison-specific columns blocking_cols = ( @@ -208,16 +209,16 @@ def estimate_u_values(linker: Linker, max_pairs: float, seed: int = None) -> Non additional_columns_to_retain=[], ) - sqls = compute_comparison_vector_values_from_id_pairs_sqls( + cv_sqls = compute_comparison_vector_values_from_id_pairs_sqls( blocking_cols, cv_cols, - input_tablename_l="__splink__df_concat_sample", - input_tablename_r="__splink__df_concat_sample", + input_tablename_l=input_tablename_sample_l, + input_tablename_r=input_tablename_sample_r, source_dataset_input_column=source_dataset_col, unique_id_input_column=uid_col, ) - pipeline.enqueue_list_of_sqls(sqls) + pipeline.enqueue_list_of_sqls(cv_sqls) # Add dummy match_probability column required by compute_new_parameters_sql sql = """ @@ -258,6 +259,5 @@ def estimate_u_values(linker: Linker, max_pairs: float, seed: int = None) -> Non ) df_sample.drop_table_from_database_and_remove_from_cache() - blocked_pairs.drop_table_from_database_and_remove_from_cache() logger.info("\nEstimated u probabilities using random sampling") From 4e1ec5a17ff87c2bf00c9a7c55edc377bbd3ddb9 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Tue, 6 Jan 2026 13:55:17 +0000 Subject: [PATCH 03/23] chunking --- splink/internals/estimate_u.py | 165 ++++++++++++++++++++++----------- 1 file changed, 110 insertions(+), 55 deletions(-) diff --git a/splink/internals/estimate_u.py b/splink/internals/estimate_u.py index 8c7b70d21d..e4c928106f 100644 --- a/splink/internals/estimate_u.py +++ b/splink/internals/estimate_u.py @@ -4,8 +4,11 @@ from copy import deepcopy from typing import TYPE_CHECKING, List +import pandas as pd + from splink.internals.blocking import ( block_using_rules_sqls, + BlockingRule, ) from splink.internals.comparison_vector_values import ( compute_comparison_vector_values_from_id_pairs_sqls, @@ -164,14 +167,8 @@ def estimate_u_values(linker: Linker, max_pairs: float, seed: int = None) -> Non input_tablename_sample_l = "__splink__df_concat_sample_left" input_tablename_sample_r = "__splink__df_concat_sample_right" - blocking_sqls = block_using_rules_sqls( - input_tablename_l=input_tablename_sample_l, - input_tablename_r=input_tablename_sample_r, - blocking_rules=settings_obj._blocking_rules_to_generate_predictions, - link_type=linker._settings_obj._link_type, - 
source_dataset_input_column=settings_obj.column_info_settings.source_dataset_input_column, - unique_id_input_column=settings_obj.column_info_settings.unique_id_input_column, - ) + # We chunk only the RHS. Start with a hardcoded chunk count. + rhs_num_chunks = 10 uid_columns = settings_obj.column_info_settings.unique_id_input_columns source_dataset_col = settings_obj.column_info_settings.source_dataset_input_column @@ -189,64 +186,122 @@ def estimate_u_values(linker: Linker, max_pairs: float, seed: int = None) -> Non ) original_comparison = original_settings_obj.comparisons[i] - pipeline = CTEPipeline(input_dataframes=[df_sample]) + # Keep a running total of m/u counts across RHS chunks. + # Keyed by (output_column_name, comparison_vector_value). + counts_lookup: dict[tuple[str, int], dict[str, float | int | str]] = {} - if split_sqls: - pipeline.enqueue_list_of_sqls(split_sqls) + for rhs_chunk_num in range(1, rhs_num_chunks + 1): + logger.info(f" RHS chunk {rhs_chunk_num}/{rhs_num_chunks}") - pipeline.enqueue_list_of_sqls(blocking_sqls) + pipeline = CTEPipeline(input_dataframes=[df_sample]) - # Blocking needs UIDs + comparison-specific columns - blocking_cols = ( - common_blocking_cols + comparison._columns_to_select_for_blocking() - ) + if split_sqls: + pipeline.enqueue_list_of_sqls(split_sqls) - # Comparison vector needs UIDs + comparison output + match_key - cv_cols = Settings.columns_to_select_for_comparison_vector_values( - unique_id_input_columns=uid_columns, - comparisons=[comparison], - retain_matching_columns=False, - additional_columns_to_retain=[], - ) + # Ensure chunking uses the correct backend dialect, even when there are + # no user-provided blocking rules. + blocking_rules = [ + BlockingRule("1=1", sql_dialect_str=db_api.sql_dialect.sql_dialect_str) + ] - cv_sqls = compute_comparison_vector_values_from_id_pairs_sqls( - blocking_cols, - cv_cols, - input_tablename_l=input_tablename_sample_l, - input_tablename_r=input_tablename_sample_r, - source_dataset_input_column=source_dataset_col, - unique_id_input_column=uid_col, - ) + blocking_sqls = block_using_rules_sqls( + input_tablename_l=input_tablename_sample_l, + input_tablename_r=input_tablename_sample_r, + blocking_rules=blocking_rules, + link_type=linker._settings_obj._link_type, + source_dataset_input_column=settings_obj.column_info_settings.source_dataset_input_column, + unique_id_input_column=settings_obj.column_info_settings.unique_id_input_column, + right_chunk=(rhs_chunk_num, rhs_num_chunks), + ) - pipeline.enqueue_list_of_sqls(cv_sqls) + # Persist each chunk's blocked pairs under a unique name, but keep + # `__splink__blocked_id_pairs` available for downstream SQL. 
+ chunk_blocked_pairs_table = ( + f"__splink__blocked_id_pairs_R{rhs_chunk_num}of{rhs_num_chunks}" + ) + for s in blocking_sqls: + if s.get("output_table_name") == "__splink__blocked_id_pairs": + s["output_table_name"] = chunk_blocked_pairs_table - # Add dummy match_probability column required by compute_new_parameters_sql - sql = """ - select *, cast(0.0 as float8) as match_probability - from __splink__df_comparison_vectors - """ - pipeline.enqueue_sql(sql, "__splink__df_predict") + pipeline.enqueue_list_of_sqls(blocking_sqls) - # Compute u probability counts for this comparison - sql = compute_new_parameters_sql( - estimate_without_term_frequencies=False, - comparisons=[comparison], - ) - pipeline.enqueue_sql(sql, "__splink__m_u_counts") + sql = f"select * from {chunk_blocked_pairs_table}" + pipeline.enqueue_sql(sql, "__splink__blocked_id_pairs") + + # Blocking needs UIDs + comparison-specific columns + blocking_cols = ( + common_blocking_cols + comparison._columns_to_select_for_blocking() + ) + + # Comparison vector needs UIDs + comparison output + match_key + cv_cols = Settings.columns_to_select_for_comparison_vector_values( + unique_id_input_columns=uid_columns, + comparisons=[comparison], + retain_matching_columns=False, + additional_columns_to_retain=[], + ) + + cv_sqls = compute_comparison_vector_values_from_id_pairs_sqls( + blocking_cols, + cv_cols, + input_tablename_l=input_tablename_sample_l, + input_tablename_r=input_tablename_sample_r, + source_dataset_input_column=source_dataset_col, + unique_id_input_column=uid_col, + ) - df_params = db_api.sql_pipeline_to_splink_dataframe(pipeline) + pipeline.enqueue_list_of_sqls(cv_sqls) - # Convert counts to proportions (u probabilities) - param_records = df_params.as_pandas_dataframe() - param_records = compute_proportions_for_new_parameters(param_records) - df_params.drop_table_from_database_and_remove_from_cache() + # Add dummy match_probability column required by compute_new_parameters_sql + sql = """ + select *, cast(0.0 as float8) as match_probability + from __splink__df_comparison_vectors + """ + pipeline.enqueue_sql(sql, "__splink__df_predict") - # Extract just the u records (filter out lambda) - m_u_records = [ - r - for r in param_records - if r["output_column_name"] != "_probability_two_random_records_match" - ] + # Compute u probability counts for this comparison and chunk + sql = compute_new_parameters_sql( + estimate_without_term_frequencies=False, + comparisons=[comparison], + ) + pipeline.enqueue_sql(sql, "__splink__m_u_counts") + + df_params = db_api.sql_pipeline_to_splink_dataframe(pipeline) + chunk_counts = df_params.as_pandas_dataframe() + df_params.drop_table_from_database_and_remove_from_cache() + + # Drop lambda row: it isn't additive across chunks (it's already a + # proportion), and we don't use it here anyway. 
+ chunk_counts = chunk_counts[ + chunk_counts.output_column_name + != "_probability_two_random_records_match" + ] + + for row in chunk_counts.to_dict("records"): + key = (row["output_column_name"], int(row["comparison_vector_value"])) + if key not in counts_lookup: + counts_lookup[key] = { + "comparison_vector_value": int(row["comparison_vector_value"]), + "output_column_name": row["output_column_name"], + "m_count": float(row["m_count"]), + "u_count": float(row["u_count"]), + } + else: + existing_m_count = float(counts_lookup[key]["m_count"]) + existing_u_count = float(counts_lookup[key]["u_count"]) + counts_lookup[key]["m_count"] = existing_m_count + float( + row["m_count"] + ) + counts_lookup[key]["u_count"] = existing_u_count + float( + row["u_count"] + ) + + aggregated_counts_df = pd.DataFrame(list(counts_lookup.values())) + + # Convert aggregated counts to proportions (u probabilities) + param_records = compute_proportions_for_new_parameters(aggregated_counts_df) + + m_u_records = param_records m_u_records_lookup = m_u_records_to_lookup_dict(m_u_records) # Apply estimated u values to the original settings object From 5a1ed0ca542b77b30911554f92ecbd65612f4e5f Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Tue, 6 Jan 2026 14:08:25 +0000 Subject: [PATCH 04/23] use defaultdict --- splink/internals/estimate_u.py | 42 +++++++++++++++++----------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/splink/internals/estimate_u.py b/splink/internals/estimate_u.py index e4c928106f..c9cc4204ff 100644 --- a/splink/internals/estimate_u.py +++ b/splink/internals/estimate_u.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +from collections import defaultdict from copy import deepcopy from typing import TYPE_CHECKING, List @@ -188,7 +189,9 @@ def estimate_u_values(linker: Linker, max_pairs: float, seed: int = None) -> Non # Keep a running total of m/u counts across RHS chunks. # Keyed by (output_column_name, comparison_vector_value). 
- counts_lookup: dict[tuple[str, int], dict[str, float | int | str]] = {} + counts_lookup: defaultdict[tuple[str, int], list[float]] = defaultdict( + lambda: [0.0, 0.0] + ) for rhs_chunk_num in range(1, rhs_num_chunks + 1): logger.info(f" RHS chunk {rhs_chunk_num}/{rhs_num_chunks}") @@ -277,26 +280,23 @@ def estimate_u_values(linker: Linker, max_pairs: float, seed: int = None) -> Non != "_probability_two_random_records_match" ] - for row in chunk_counts.to_dict("records"): - key = (row["output_column_name"], int(row["comparison_vector_value"])) - if key not in counts_lookup: - counts_lookup[key] = { - "comparison_vector_value": int(row["comparison_vector_value"]), - "output_column_name": row["output_column_name"], - "m_count": float(row["m_count"]), - "u_count": float(row["u_count"]), - } - else: - existing_m_count = float(counts_lookup[key]["m_count"]) - existing_u_count = float(counts_lookup[key]["u_count"]) - counts_lookup[key]["m_count"] = existing_m_count + float( - row["m_count"] - ) - counts_lookup[key]["u_count"] = existing_u_count + float( - row["u_count"] - ) - - aggregated_counts_df = pd.DataFrame(list(counts_lookup.values())) + for r in chunk_counts.itertuples(index=False): + key = (r.output_column_name, int(r.comparison_vector_value)) + totals = counts_lookup[key] + totals[0] += float(r.m_count) + totals[1] += float(r.u_count) + + aggregated_counts_df = pd.DataFrame( + [ + { + "output_column_name": ocn, + "comparison_vector_value": cvv, + "m_count": totals[0], + "u_count": totals[1], + } + for (ocn, cvv), totals in counts_lookup.items() + ] + ) # Convert aggregated counts to proportions (u probabilities) param_records = compute_proportions_for_new_parameters(aggregated_counts_df) From 777f515f05383f2e1e33a1555d2d2d762f551372 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Tue, 6 Jan 2026 14:12:53 +0000 Subject: [PATCH 05/23] slight update --- splink/internals/estimate_u.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/splink/internals/estimate_u.py b/splink/internals/estimate_u.py index c9cc4204ff..d2b8cf3d6b 100644 --- a/splink/internals/estimate_u.py +++ b/splink/internals/estimate_u.py @@ -149,7 +149,9 @@ def estimate_u_values(linker: Linker, max_pairs: float, seed: int = None) -> Non pipeline.enqueue_sql(sql, "__splink__df_concat_sample") df_sample = db_api.sql_pipeline_to_splink_dataframe(pipeline) - settings_obj._blocking_rules_to_generate_predictions = [] + blocking_rules_for_u = [ + BlockingRule("1=1", sql_dialect_str=db_api.sql_dialect.sql_dialect_str) + ] input_tablename_sample_l = "__splink__df_concat_sample" input_tablename_sample_r = "__splink__df_concat_sample" @@ -201,16 +203,10 @@ def estimate_u_values(linker: Linker, max_pairs: float, seed: int = None) -> Non if split_sqls: pipeline.enqueue_list_of_sqls(split_sqls) - # Ensure chunking uses the correct backend dialect, even when there are - # no user-provided blocking rules. 
- blocking_rules = [ - BlockingRule("1=1", sql_dialect_str=db_api.sql_dialect.sql_dialect_str) - ] - blocking_sqls = block_using_rules_sqls( input_tablename_l=input_tablename_sample_l, input_tablename_r=input_tablename_sample_r, - blocking_rules=blocking_rules, + blocking_rules=blocking_rules_for_u, link_type=linker._settings_obj._link_type, source_dataset_input_column=settings_obj.column_info_settings.source_dataset_input_column, unique_id_input_column=settings_obj.column_info_settings.unique_id_input_column, From 351837d1abc81170bf3646c12e64eaf86d51b036 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Tue, 6 Jan 2026 14:32:14 +0000 Subject: [PATCH 06/23] tidy before making chunking adaptive --- splink/internals/estimate_u.py | 177 +++++++++++++++++++-------------- 1 file changed, 101 insertions(+), 76 deletions(-) diff --git a/splink/internals/estimate_u.py b/splink/internals/estimate_u.py index d2b8cf3d6b..4cbaaed1ea 100644 --- a/splink/internals/estimate_u.py +++ b/splink/internals/estimate_u.py @@ -8,8 +8,8 @@ import pandas as pd from splink.internals.blocking import ( - block_using_rules_sqls, BlockingRule, + block_using_rules_sqls, ) from splink.internals.comparison_vector_values import ( compute_comparison_vector_values_from_id_pairs_sqls, @@ -19,7 +19,7 @@ m_u_records_to_lookup_dict, ) from splink.internals.pipeline import CTEPipeline -from splink.internals.settings import Settings +from splink.internals.settings import LinkTypeLiteralType, Settings from splink.internals.vertically_concatenate import ( enqueue_df_concat, split_df_concat_with_tf_into_two_tables_sqls, @@ -32,11 +32,86 @@ # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports if TYPE_CHECKING: + from splink.internals.comparison import Comparison + from splink.internals.database_api import DatabaseAPISubClass + from splink.internals.input_column import InputColumn from splink.internals.linker import Linker + from splink.internals.splink_dataframe import SplinkDataFrame logger = logging.getLogger(__name__) +def _u_counts_for_comparison_rhs_chunk( + *, + db_api: "DatabaseAPISubClass", + df_sample: "SplinkDataFrame", + split_sqls: list[dict[str, str]], + input_tablename_sample_l: str, + input_tablename_sample_r: str, + blocking_rules_for_u: list[BlockingRule], + link_type: LinkTypeLiteralType, + source_dataset_input_column: "InputColumn | None", + unique_id_input_column: "InputColumn", + comparison: "Comparison", + blocking_cols: list[str], + cv_cols: list[str], + right_chunk: tuple[int, int] | None, +) -> pd.DataFrame: + pipeline = CTEPipeline(input_dataframes=[df_sample]) + + if split_sqls: + pipeline.enqueue_list_of_sqls(split_sqls) + + blocking_sqls = block_using_rules_sqls( + input_tablename_l=input_tablename_sample_l, + input_tablename_r=input_tablename_sample_r, + blocking_rules=blocking_rules_for_u, + link_type=link_type, + source_dataset_input_column=source_dataset_input_column, + unique_id_input_column=unique_id_input_column, + right_chunk=right_chunk, + ) + + pipeline.enqueue_list_of_sqls(blocking_sqls) + + cv_sqls = compute_comparison_vector_values_from_id_pairs_sqls( + blocking_cols, + cv_cols, + input_tablename_l=input_tablename_sample_l, + input_tablename_r=input_tablename_sample_r, + source_dataset_input_column=source_dataset_input_column, + unique_id_input_column=unique_id_input_column, + ) + + pipeline.enqueue_list_of_sqls(cv_sqls) + + # Add dummy match_probability column required by compute_new_parameters_sql + sql = """ + select *, cast(0.0 as float8) as 
match_probability + from __splink__df_comparison_vectors + """ + pipeline.enqueue_sql(sql, "__splink__df_predict") + + sql = compute_new_parameters_sql( + estimate_without_term_frequencies=False, + comparisons=[comparison], + ) + pipeline.enqueue_sql(sql, "__splink__m_u_counts") + + df_params = db_api.sql_pipeline_to_splink_dataframe(pipeline) + try: + chunk_counts = df_params.as_pandas_dataframe() + finally: + # Drop final output table + df_params.drop_table_from_database_and_remove_from_cache() + + # Drop lambda row: it isn't additive across chunks (it's already a + # proportion), and we don't use it here anyway. + return chunk_counts[ + chunk_counts.output_column_name != "_probability_two_random_records_match" + ] + + def _rows_needed_for_n_pairs(n_pairs): # Number of pairs generated by cartesian product is # p(r) = r(r-1)/2, where r is input rows @@ -174,8 +249,7 @@ def estimate_u_values(linker: Linker, max_pairs: float, seed: int = None) -> Non rhs_num_chunks = 10 uid_columns = settings_obj.column_info_settings.unique_id_input_columns - source_dataset_col = settings_obj.column_info_settings.source_dataset_input_column - uid_col = settings_obj.column_info_settings.unique_id_input_column + # Note: we pass the actual InputColumn objects through to helper calls. # Build common blocking columns (UID columns that all comparisons need) common_blocking_cols: list[str] = [] @@ -195,87 +269,38 @@ def estimate_u_values(linker: Linker, max_pairs: float, seed: int = None) -> Non lambda: [0.0, 0.0] ) - for rhs_chunk_num in range(1, rhs_num_chunks + 1): - logger.info(f" RHS chunk {rhs_chunk_num}/{rhs_num_chunks}") + # Blocking needs UIDs + comparison-specific columns + blocking_cols = ( + common_blocking_cols + comparison._columns_to_select_for_blocking() + ) - pipeline = CTEPipeline(input_dataframes=[df_sample]) + # Comparison vector needs UIDs + comparison output + match_key + cv_cols = Settings.columns_to_select_for_comparison_vector_values( + unique_id_input_columns=uid_columns, + comparisons=[comparison], + retain_matching_columns=False, + additional_columns_to_retain=[], + ) - if split_sqls: - pipeline.enqueue_list_of_sqls(split_sqls) + for rhs_chunk_num in range(1, rhs_num_chunks + 1): + logger.info(f" RHS chunk {rhs_chunk_num}/{rhs_num_chunks}") - blocking_sqls = block_using_rules_sqls( - input_tablename_l=input_tablename_sample_l, - input_tablename_r=input_tablename_sample_r, - blocking_rules=blocking_rules_for_u, + chunk_counts = _u_counts_for_comparison_rhs_chunk( + db_api=db_api, + df_sample=df_sample, + split_sqls=split_sqls, + input_tablename_sample_l=input_tablename_sample_l, + input_tablename_sample_r=input_tablename_sample_r, + blocking_rules_for_u=blocking_rules_for_u, link_type=linker._settings_obj._link_type, source_dataset_input_column=settings_obj.column_info_settings.source_dataset_input_column, unique_id_input_column=settings_obj.column_info_settings.unique_id_input_column, + comparison=comparison, + blocking_cols=blocking_cols, + cv_cols=cv_cols, right_chunk=(rhs_chunk_num, rhs_num_chunks), ) - # Persist each chunk's blocked pairs under a unique name, but keep - # `__splink__blocked_id_pairs` available for downstream SQL. 
- chunk_blocked_pairs_table = ( - f"__splink__blocked_id_pairs_R{rhs_chunk_num}of{rhs_num_chunks}" - ) - for s in blocking_sqls: - if s.get("output_table_name") == "__splink__blocked_id_pairs": - s["output_table_name"] = chunk_blocked_pairs_table - - pipeline.enqueue_list_of_sqls(blocking_sqls) - - sql = f"select * from {chunk_blocked_pairs_table}" - pipeline.enqueue_sql(sql, "__splink__blocked_id_pairs") - - # Blocking needs UIDs + comparison-specific columns - blocking_cols = ( - common_blocking_cols + comparison._columns_to_select_for_blocking() - ) - - # Comparison vector needs UIDs + comparison output + match_key - cv_cols = Settings.columns_to_select_for_comparison_vector_values( - unique_id_input_columns=uid_columns, - comparisons=[comparison], - retain_matching_columns=False, - additional_columns_to_retain=[], - ) - - cv_sqls = compute_comparison_vector_values_from_id_pairs_sqls( - blocking_cols, - cv_cols, - input_tablename_l=input_tablename_sample_l, - input_tablename_r=input_tablename_sample_r, - source_dataset_input_column=source_dataset_col, - unique_id_input_column=uid_col, - ) - - pipeline.enqueue_list_of_sqls(cv_sqls) - - # Add dummy match_probability column required by compute_new_parameters_sql - sql = """ - select *, cast(0.0 as float8) as match_probability - from __splink__df_comparison_vectors - """ - pipeline.enqueue_sql(sql, "__splink__df_predict") - - # Compute u probability counts for this comparison and chunk - sql = compute_new_parameters_sql( - estimate_without_term_frequencies=False, - comparisons=[comparison], - ) - pipeline.enqueue_sql(sql, "__splink__m_u_counts") - - df_params = db_api.sql_pipeline_to_splink_dataframe(pipeline) - chunk_counts = df_params.as_pandas_dataframe() - df_params.drop_table_from_database_and_remove_from_cache() - - # Drop lambda row: it isn't additive across chunks (it's already a - # proportion), and we don't use it here anyway. 
- chunk_counts = chunk_counts[ - chunk_counts.output_column_name - != "_probability_two_random_records_match" - ] - for r in chunk_counts.itertuples(index=False): key = (r.output_column_name, int(r.comparison_vector_value)) totals = counts_lookup[key] From c1eec726bbf72d8849f33db7ba53492f7745f396 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Tue, 6 Jan 2026 15:06:45 +0000 Subject: [PATCH 07/23] early stop with min count --- splink/internals/estimate_u.py | 82 +++++++++++++++++++++++++--------- 1 file changed, 60 insertions(+), 22 deletions(-) diff --git a/splink/internals/estimate_u.py b/splink/internals/estimate_u.py index 4cbaaed1ea..32772e94e9 100644 --- a/splink/internals/estimate_u.py +++ b/splink/internals/estimate_u.py @@ -1,7 +1,6 @@ from __future__ import annotations import logging -from collections import defaultdict from copy import deepcopy from typing import TYPE_CHECKING, List @@ -41,6 +40,49 @@ logger = logging.getLogger(__name__) +class _MUCountsAccumulator: + def __init__(self, comparison: "Comparison") -> None: + self._output_column_name = comparison.output_column_name + self._counts_by_cvv: dict[int, list[float]] = { + int(cl.comparison_vector_value): [0.0, 0.0] + for cl in comparison._comparison_levels_excluding_null + } + + def update_from_chunk_counts(self, chunk_counts: pd.DataFrame) -> None: + for r in chunk_counts.itertuples(index=False): + cvv = int(r.comparison_vector_value) + totals = self._counts_by_cvv.get(cvv) + if totals is None: + continue + totals[0] += float(r.m_count) + totals[1] += float(r.u_count) + + def min_u_count(self) -> float: + if not self._counts_by_cvv: + return 0.0 + return min(totals[1] for totals in self._counts_by_cvv.values()) + + def all_levels_meet_min_u_count(self, min_count: int) -> bool: + return self.min_u_count() >= min_count + + def to_dataframe(self) -> pd.DataFrame: + return pd.DataFrame( + [ + { + "output_column_name": self._output_column_name, + "comparison_vector_value": cvv, + "m_count": totals[0], + "u_count": totals[1], + } + for cvv, totals in sorted(self._counts_by_cvv.items()) + ] + ) + + def pretty_table(self) -> str: + df = self.to_dataframe().drop(columns=["output_column_name"]) + return df.to_string(index=False) + + def _u_counts_for_comparison_rhs_chunk( *, db_api: "DatabaseAPISubClass", @@ -143,6 +185,7 @@ def _proportion_sample_size_link_only( def estimate_u_values(linker: Linker, max_pairs: float, seed: int = None) -> None: logger.info("----- Estimating u probabilities using random sampling -----") + min_count_per_level = 100 pipeline = CTEPipeline() pipeline = enqueue_df_concat(linker, pipeline) @@ -263,11 +306,7 @@ def estimate_u_values(linker: Linker, max_pairs: float, seed: int = None) -> Non ) original_comparison = original_settings_obj.comparisons[i] - # Keep a running total of m/u counts across RHS chunks. - # Keyed by (output_column_name, comparison_vector_value). 
- counts_lookup: defaultdict[tuple[str, int], list[float]] = defaultdict( - lambda: [0.0, 0.0] - ) + counts_accumulator = _MUCountsAccumulator(comparison) # Blocking needs UIDs + comparison-specific columns blocking_cols = ( @@ -301,23 +340,22 @@ def estimate_u_values(linker: Linker, max_pairs: float, seed: int = None) -> Non right_chunk=(rhs_chunk_num, rhs_num_chunks), ) - for r in chunk_counts.itertuples(index=False): - key = (r.output_column_name, int(r.comparison_vector_value)) - totals = counts_lookup[key] - totals[0] += float(r.m_count) - totals[1] += float(r.u_count) + counts_accumulator.update_from_chunk_counts(chunk_counts) - aggregated_counts_df = pd.DataFrame( - [ - { - "output_column_name": ocn, - "comparison_vector_value": cvv, - "m_count": totals[0], - "u_count": totals[1], - } - for (ocn, cvv), totals in counts_lookup.items() - ] - ) + logger.info( + " Current min u_count across levels: " + f"{counts_accumulator.min_u_count():,.0f}/{min_count_per_level}" + ) + logger.info("\n" + counts_accumulator.pretty_table()) + + if counts_accumulator.all_levels_meet_min_u_count(min_count_per_level): + logger.info( + " Stopping early: all levels reached at least " + f"{min_count_per_level} u observations" + ) + break + + aggregated_counts_df = counts_accumulator.to_dataframe() # Convert aggregated counts to proportions (u probabilities) param_records = compute_proportions_for_new_parameters(aggregated_counts_df) From 2e0fa0633af9e08c717a22d8c8c4b82ff5c22b34 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Tue, 6 Jan 2026 15:39:09 +0000 Subject: [PATCH 08/23] handle count 0 properly --- splink/internals/estimate_u.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/splink/internals/estimate_u.py b/splink/internals/estimate_u.py index 32772e94e9..5026c7bc0e 100644 --- a/splink/internals/estimate_u.py +++ b/splink/internals/estimate_u.py @@ -10,6 +10,7 @@ BlockingRule, block_using_rules_sqls, ) +from splink.internals.constants import LEVEL_NOT_OBSERVED_TEXT from splink.internals.comparison_vector_values import ( compute_comparison_vector_values_from_id_pairs_sqls, ) @@ -360,6 +361,20 @@ def estimate_u_values(linker: Linker, max_pairs: float, seed: int = None) -> Non # Convert aggregated counts to proportions (u probabilities) param_records = compute_proportions_for_new_parameters(aggregated_counts_df) + # Principled handling of unobserved levels: + # - We explicitly include every level (via enumeration) so that convergence + # checks can treat missing GROUP BY rows as 0 counts. + # - But for the final trained u values, a level with u_count == 0 should be + # treated as "not observed" (not as u_probability = 0.0). 
+ u_count_by_cvv = { + int(row["comparison_vector_value"]): float(row["u_count"]) + for row in aggregated_counts_df.to_dict("records") + } + for r in param_records: + cvv = int(r["comparison_vector_value"]) + if u_count_by_cvv.get(cvv, 0.0) == 0.0: + r["u_probability"] = LEVEL_NOT_OBSERVED_TEXT + m_u_records = param_records m_u_records_lookup = m_u_records_to_lookup_dict(m_u_records) From 13d799942e4aa4dcf62cc64a15b0442cd6e4ac24 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Tue, 6 Jan 2026 15:49:07 +0000 Subject: [PATCH 09/23] 100x probe --- splink/internals/estimate_u.py | 136 +++++++++++++++++++++++++-------- 1 file changed, 104 insertions(+), 32 deletions(-) diff --git a/splink/internals/estimate_u.py b/splink/internals/estimate_u.py index 5026c7bc0e..5f10a11faf 100644 --- a/splink/internals/estimate_u.py +++ b/splink/internals/estimate_u.py @@ -10,10 +10,10 @@ BlockingRule, block_using_rules_sqls, ) -from splink.internals.constants import LEVEL_NOT_OBSERVED_TEXT from splink.internals.comparison_vector_values import ( compute_comparison_vector_values_from_id_pairs_sqls, ) +from splink.internals.constants import LEVEL_NOT_OBSERVED_TEXT from splink.internals.m_u_records_to_parameters import ( append_u_probability_to_comparison_level_trained_probabilities, m_u_records_to_lookup_dict, @@ -155,6 +155,62 @@ def _u_counts_for_comparison_rhs_chunk( ] +def _process_rhs_chunk_and_check_convergence( + *, + db_api: "DatabaseAPISubClass", + df_sample: "SplinkDataFrame", + split_sqls: list[dict[str, str]], + input_tablename_sample_l: str, + input_tablename_sample_r: str, + blocking_rules_for_u: list[BlockingRule], + link_type: LinkTypeLiteralType, + source_dataset_input_column: "InputColumn | None", + unique_id_input_column: "InputColumn", + comparison: "Comparison", + blocking_cols: list[str], + cv_cols: list[str], + rhs_chunk_num: int, + rhs_num_chunks: int, + counts_accumulator: _MUCountsAccumulator, + min_count_per_level: int, + chunk_label: str, +) -> bool: + logger.info(f" {chunk_label} {rhs_chunk_num}/{rhs_num_chunks}") + + chunk_counts = _u_counts_for_comparison_rhs_chunk( + db_api=db_api, + df_sample=df_sample, + split_sqls=split_sqls, + input_tablename_sample_l=input_tablename_sample_l, + input_tablename_sample_r=input_tablename_sample_r, + blocking_rules_for_u=blocking_rules_for_u, + link_type=link_type, + source_dataset_input_column=source_dataset_input_column, + unique_id_input_column=unique_id_input_column, + comparison=comparison, + blocking_cols=blocking_cols, + cv_cols=cv_cols, + right_chunk=(rhs_chunk_num, rhs_num_chunks), + ) + + counts_accumulator.update_from_chunk_counts(chunk_counts) + + logger.info( + " Current min u_count across levels: " + f"{counts_accumulator.min_u_count():,.0f}/{min_count_per_level}" + ) + logger.info("\n" + counts_accumulator.pretty_table()) + + if counts_accumulator.all_levels_meet_min_u_count(min_count_per_level): + logger.info( + " Stopping early: all levels reached at least " + f"{min_count_per_level} u observations" + ) + return True + + return False + + def _rows_needed_for_n_pairs(n_pairs): # Number of pairs generated by cartesian product is # p(r) = r(r-1)/2, where r is input rows @@ -322,39 +378,55 @@ def estimate_u_values(linker: Linker, max_pairs: float, seed: int = None) -> Non additional_columns_to_retain=[], ) - for rhs_chunk_num in range(1, rhs_num_chunks + 1): - logger.info(f" RHS chunk {rhs_chunk_num}/{rhs_num_chunks}") - - chunk_counts = _u_counts_for_comparison_rhs_chunk( - db_api=db_api, - df_sample=df_sample, - 
split_sqls=split_sqls, - input_tablename_sample_l=input_tablename_sample_l, - input_tablename_sample_r=input_tablename_sample_r, - blocking_rules_for_u=blocking_rules_for_u, - link_type=linker._settings_obj._link_type, - source_dataset_input_column=settings_obj.column_info_settings.source_dataset_input_column, - unique_id_input_column=settings_obj.column_info_settings.unique_id_input_column, - comparison=comparison, - blocking_cols=blocking_cols, - cv_cols=cv_cols, - right_chunk=(rhs_chunk_num, rhs_num_chunks), - ) - - counts_accumulator.update_from_chunk_counts(chunk_counts) - - logger.info( - " Current min u_count across levels: " - f"{counts_accumulator.min_u_count():,.0f}/{min_count_per_level}" - ) - logger.info("\n" + counts_accumulator.pretty_table()) + probe_multiplier = 10 + probe_rhs_num_chunks = rhs_num_chunks * probe_multiplier + + converged = _process_rhs_chunk_and_check_convergence( + db_api=db_api, + df_sample=df_sample, + split_sqls=split_sqls, + input_tablename_sample_l=input_tablename_sample_l, + input_tablename_sample_r=input_tablename_sample_r, + blocking_rules_for_u=blocking_rules_for_u, + link_type=linker._settings_obj._link_type, + source_dataset_input_column=settings_obj.column_info_settings.source_dataset_input_column, + unique_id_input_column=settings_obj.column_info_settings.unique_id_input_column, + comparison=comparison, + blocking_cols=blocking_cols, + cv_cols=cv_cols, + rhs_chunk_num=1, + rhs_num_chunks=probe_rhs_num_chunks, + counts_accumulator=counts_accumulator, + min_count_per_level=min_count_per_level, + chunk_label="RHS probe chunk", + ) - if counts_accumulator.all_levels_meet_min_u_count(min_count_per_level): - logger.info( - " Stopping early: all levels reached at least " - f"{min_count_per_level} u observations" + if not converged: + logger.info(" Probe did not converge; restarting with normal chunking") + counts_accumulator = _MUCountsAccumulator(comparison) + + for rhs_chunk_num in range(1, rhs_num_chunks + 1): + converged = _process_rhs_chunk_and_check_convergence( + db_api=db_api, + df_sample=df_sample, + split_sqls=split_sqls, + input_tablename_sample_l=input_tablename_sample_l, + input_tablename_sample_r=input_tablename_sample_r, + blocking_rules_for_u=blocking_rules_for_u, + link_type=linker._settings_obj._link_type, + source_dataset_input_column=settings_obj.column_info_settings.source_dataset_input_column, + unique_id_input_column=settings_obj.column_info_settings.unique_id_input_column, + comparison=comparison, + blocking_cols=blocking_cols, + cv_cols=cv_cols, + rhs_chunk_num=rhs_chunk_num, + rhs_num_chunks=rhs_num_chunks, + counts_accumulator=counts_accumulator, + min_count_per_level=min_count_per_level, + chunk_label="RHS chunk", ) - break + if converged: + break aggregated_counts_df = counts_accumulator.to_dataframe() From e636b4359e29532d926e28acf49935226d51bb31 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Tue, 6 Jan 2026 16:20:37 +0000 Subject: [PATCH 10/23] tests and api --- splink/internals/estimate_u.py | 15 +++++++++++---- splink/internals/linker_components/training.py | 18 ++++++++++++++++-- tests/test_u_train.py | 6 +++--- 3 files changed, 30 insertions(+), 9 deletions(-) diff --git a/splink/internals/estimate_u.py b/splink/internals/estimate_u.py index 5f10a11faf..b1d32ffa62 100644 --- a/splink/internals/estimate_u.py +++ b/splink/internals/estimate_u.py @@ -240,9 +240,16 @@ def _proportion_sample_size_link_only( return proportion, sample_size -def estimate_u_values(linker: Linker, max_pairs: float, seed: int = None) -> 
None: +def estimate_u_values( + linker: Linker, + max_pairs: float, + seed: int | None = None, + min_count_per_level: int = 100, + num_chunks: int = 10, +) -> None: logger.info("----- Estimating u probabilities using random sampling -----") - min_count_per_level = 100 + if num_chunks < 1: + raise ValueError("num_chunks must be >= 1") pipeline = CTEPipeline() pipeline = enqueue_df_concat(linker, pipeline) @@ -345,8 +352,8 @@ def estimate_u_values(linker: Linker, max_pairs: float, seed: int = None) -> Non input_tablename_sample_l = "__splink__df_concat_sample_left" input_tablename_sample_r = "__splink__df_concat_sample_right" - # We chunk only the RHS. Start with a hardcoded chunk count. - rhs_num_chunks = 10 + # We chunk only the RHS. + rhs_num_chunks = num_chunks uid_columns = settings_obj.column_info_settings.unique_id_input_columns # Note: we pass the actual InputColumn objects through to helper calls. diff --git a/splink/internals/linker_components/training.py b/splink/internals/linker_components/training.py index b0891055d0..7c092e0540 100644 --- a/splink/internals/linker_components/training.py +++ b/splink/internals/linker_components/training.py @@ -163,7 +163,11 @@ def estimate_probability_two_random_records_match( ) def estimate_u_using_random_sampling( - self, max_pairs: float = 1e6, seed: int = None + self, + max_pairs: float = 1e6, + seed: int | None = None, + min_count_per_level: int = 100, + num_chunks: int = 10, ) -> None: """Estimate the u parameters of the linkage model using random sampling. @@ -190,6 +194,10 @@ def estimate_u_using_random_sampling( seed (int): Seed for random sampling. Assign to get reproducible u probabilities. Note, seed for random sampling is only supported for DuckDB and Spark, for SQLite set to None. + min_count_per_level (int): Minimum number of u observations required for + each comparison level before stopping chunking early. Defaults to 100. + num_chunks (int): Number of chunks to split the RHS of the cartesian + product into while estimating u. Defaults to 10. Examples: ```py @@ -209,7 +217,13 @@ def estimate_u_using_random_sampling( "result in more accurate estimates, but with a longer run time." 
) - estimate_u_values(self._linker, max_pairs, seed) + estimate_u_values( + self._linker, + max_pairs=max_pairs, + seed=seed, + min_count_per_level=min_count_per_level, + num_chunks=num_chunks, + ) self._linker._populate_m_u_from_trained_values() self._linker._settings_obj._columns_without_estimated_parameters_message() diff --git a/tests/test_u_train.py b/tests/test_u_train.py index ef0f21d91b..92a0a3ec1e 100644 --- a/tests/test_u_train.py +++ b/tests/test_u_train.py @@ -151,7 +151,7 @@ def test_u_train_link_only_sample(test_helpers, dialect): linker._debug_mode = True linker._db_api.debug_keep_temp_views = True - linker.training.estimate_u_using_random_sampling(max_pairs=max_pairs) + linker.training.estimate_u_using_random_sampling(max_pairs=max_pairs, num_chunks=1) # count how many pairs we _actually_ generated in random sampling check_blocking_sql = """ @@ -280,7 +280,7 @@ def test_u_train_multilink(test_helpers, dialect): ) linker._debug_mode = True linker._db_api.debug_keep_temp_views = True - linker.training.estimate_u_using_random_sampling(max_pairs=1e6) + linker.training.estimate_u_using_random_sampling(max_pairs=1e6, num_chunks=1) cc_name = linker._settings_obj.comparisons[0] check_blocking_sql = """ @@ -318,7 +318,7 @@ def test_u_train_multilink(test_helpers, dialect): ) linker._debug_mode = True linker._db_api.debug_keep_temp_views = True - linker.training.estimate_u_using_random_sampling(max_pairs=1e6) + linker.training.estimate_u_using_random_sampling(max_pairs=1e6, num_chunks=1) cc_name = linker._settings_obj.comparisons[0] check_blocking_sql = """ From 2d40d4f841c2154f2adddffd48df221d3a61778b Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Tue, 6 Jan 2026 16:39:05 +0000 Subject: [PATCH 11/23] fix tests --- splink/internals/estimate_u.py | 52 ++++++++++++++++++---------------- tests/test_find_new_matches.py | 2 +- 2 files changed, 29 insertions(+), 25 deletions(-) diff --git a/splink/internals/estimate_u.py b/splink/internals/estimate_u.py index b1d32ffa62..31efe0d0a3 100644 --- a/splink/internals/estimate_u.py +++ b/splink/internals/estimate_u.py @@ -385,32 +385,36 @@ def estimate_u_values( additional_columns_to_retain=[], ) - probe_multiplier = 10 - probe_rhs_num_chunks = rhs_num_chunks * probe_multiplier - - converged = _process_rhs_chunk_and_check_convergence( - db_api=db_api, - df_sample=df_sample, - split_sqls=split_sqls, - input_tablename_sample_l=input_tablename_sample_l, - input_tablename_sample_r=input_tablename_sample_r, - blocking_rules_for_u=blocking_rules_for_u, - link_type=linker._settings_obj._link_type, - source_dataset_input_column=settings_obj.column_info_settings.source_dataset_input_column, - unique_id_input_column=settings_obj.column_info_settings.unique_id_input_column, - comparison=comparison, - blocking_cols=blocking_cols, - cv_cols=cv_cols, - rhs_chunk_num=1, - rhs_num_chunks=probe_rhs_num_chunks, - counts_accumulator=counts_accumulator, - min_count_per_level=min_count_per_level, - chunk_label="RHS probe chunk", - ) + use_probe = rhs_num_chunks > 1 + + converged = False + if use_probe: + probe_multiplier = 10 + probe_rhs_num_chunks = rhs_num_chunks * probe_multiplier + converged = _process_rhs_chunk_and_check_convergence( + db_api=db_api, + df_sample=df_sample, + split_sqls=split_sqls, + input_tablename_sample_l=input_tablename_sample_l, + input_tablename_sample_r=input_tablename_sample_r, + blocking_rules_for_u=blocking_rules_for_u, + link_type=linker._settings_obj._link_type, + 
source_dataset_input_column=settings_obj.column_info_settings.source_dataset_input_column, + unique_id_input_column=settings_obj.column_info_settings.unique_id_input_column, + comparison=comparison, + blocking_cols=blocking_cols, + cv_cols=cv_cols, + rhs_chunk_num=1, + rhs_num_chunks=probe_rhs_num_chunks, + counts_accumulator=counts_accumulator, + min_count_per_level=min_count_per_level, + chunk_label="RHS probe chunk", + ) if not converged: - logger.info(" Probe did not converge; restarting with normal chunking") - counts_accumulator = _MUCountsAccumulator(comparison) + if use_probe: + logger.info(" Probe did not converge; restarting with normal chunking") + counts_accumulator = _MUCountsAccumulator(comparison) for rhs_chunk_num in range(1, rhs_num_chunks + 1): converged = _process_rhs_chunk_and_check_convergence( diff --git a/tests/test_find_new_matches.py b/tests/test_find_new_matches.py index 14c4d6506e..91b902c50f 100644 --- a/tests/test_find_new_matches.py +++ b/tests/test_find_new_matches.py @@ -85,7 +85,7 @@ def test_matches_work(test_helpers, dialect): linker = helper.linker_with_registration(df, get_settings_dict()) # Train our model to get more reasonable outputs... - linker.training.estimate_u_using_random_sampling(max_pairs=1e6) + linker.training.estimate_u_using_random_sampling(max_pairs=1e6, num_chunks=1) linker.visualisations.match_weights_chart().save("mwc.html") blocking_rule = block_on("first_name", "surname") From b347e00a84ce9dbfa7d6dbec2ea00b2411d9d24b Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Tue, 6 Jan 2026 17:01:16 +0000 Subject: [PATCH 12/23] allow user to skip early exit --- splink/internals/estimate_u.py | 14 +++++++++----- splink/internals/linker_components/training.py | 8 +++++--- tests/test_find_new_matches.py | 6 +++++- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/splink/internals/estimate_u.py b/splink/internals/estimate_u.py index 31efe0d0a3..bc3e257057 100644 --- a/splink/internals/estimate_u.py +++ b/splink/internals/estimate_u.py @@ -172,7 +172,7 @@ def _process_rhs_chunk_and_check_convergence( rhs_chunk_num: int, rhs_num_chunks: int, counts_accumulator: _MUCountsAccumulator, - min_count_per_level: int, + min_count_per_level: int | None, chunk_label: str, ) -> bool: logger.info(f" {chunk_label} {rhs_chunk_num}/{rhs_num_chunks}") @@ -195,11 +195,15 @@ def _process_rhs_chunk_and_check_convergence( counts_accumulator.update_from_chunk_counts(chunk_counts) + logger.info("\n" + counts_accumulator.pretty_table()) + + if min_count_per_level is None: + return False + logger.info( " Current min u_count across levels: " f"{counts_accumulator.min_u_count():,.0f}/{min_count_per_level}" ) - logger.info("\n" + counts_accumulator.pretty_table()) if counts_accumulator.all_levels_meet_min_u_count(min_count_per_level): logger.info( @@ -244,7 +248,7 @@ def estimate_u_values( linker: Linker, max_pairs: float, seed: int | None = None, - min_count_per_level: int = 100, + min_count_per_level: int | None = 100, num_chunks: int = 10, ) -> None: logger.info("----- Estimating u probabilities using random sampling -----") @@ -385,7 +389,7 @@ def estimate_u_values( additional_columns_to_retain=[], ) - use_probe = rhs_num_chunks > 1 + use_probe = (rhs_num_chunks > 1) and (min_count_per_level is not None) converged = False if use_probe: @@ -436,7 +440,7 @@ def estimate_u_values( min_count_per_level=min_count_per_level, chunk_label="RHS chunk", ) - if converged: + if converged and (min_count_per_level is not None): break aggregated_counts_df = 
counts_accumulator.to_dataframe() diff --git a/splink/internals/linker_components/training.py b/splink/internals/linker_components/training.py index 7c092e0540..c4825ae888 100644 --- a/splink/internals/linker_components/training.py +++ b/splink/internals/linker_components/training.py @@ -166,7 +166,7 @@ def estimate_u_using_random_sampling( self, max_pairs: float = 1e6, seed: int | None = None, - min_count_per_level: int = 100, + min_count_per_level: int | None = 100, num_chunks: int = 10, ) -> None: """Estimate the u parameters of the linkage model using random sampling. @@ -194,8 +194,10 @@ def estimate_u_using_random_sampling( seed (int): Seed for random sampling. Assign to get reproducible u probabilities. Note, seed for random sampling is only supported for DuckDB and Spark, for SQLite set to None. - min_count_per_level (int): Minimum number of u observations required for - each comparison level before stopping chunking early. Defaults to 100. + min_count_per_level (int | None): Minimum number of u observations + required for each comparison level before stopping chunking early. + If None, disables the probe phase and disables early stopping (all + chunks are processed). Defaults to 100. num_chunks (int): Number of chunks to split the RHS of the cartesian product into while estimating u. Defaults to 10. diff --git a/tests/test_find_new_matches.py b/tests/test_find_new_matches.py index 91b902c50f..059bbd9ae0 100644 --- a/tests/test_find_new_matches.py +++ b/tests/test_find_new_matches.py @@ -85,7 +85,11 @@ def test_matches_work(test_helpers, dialect): linker = helper.linker_with_registration(df, get_settings_dict()) # Train our model to get more reasonable outputs... - linker.training.estimate_u_using_random_sampling(max_pairs=1e6, num_chunks=1) + linker.training.estimate_u_using_random_sampling( + max_pairs=1e6, + num_chunks=10, + min_count_per_level=None, + ) linker.visualisations.match_weights_chart().save("mwc.html") blocking_rule = block_on("first_name", "surname") From 3ebb0c49fbfdc2b6124e14316936e9d37868e307 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Tue, 6 Jan 2026 17:06:18 +0000 Subject: [PATCH 13/23] update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d5410cbc73..b00afe1d1c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- `estimate_u_using_random_sampling()` now estimates u probabilities using chunking and can stop early once each comparison level has enough u observations (controlled by `min_count_per_level`). 
This makes u estimation a lot faster and less memory intensive - Support for chunking to allow processing of very large datasets in blocking and prediction [#2850](https://github.com/moj-analytical-services/splink/pull/2850) - New `table_management` functions to explicitly manage table caching [#2848](https://github.com/moj-analytical-services/splink/pull/2848) From 398dd212d015aa800154aa632ee4a505fef9cb0e Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Wed, 7 Jan 2026 09:24:28 +0000 Subject: [PATCH 14/23] add nice logging --- splink/internals/estimate_u.py | 78 ++++++++++++++++++++++++++++++---- 1 file changed, 69 insertions(+), 9 deletions(-) diff --git a/splink/internals/estimate_u.py b/splink/internals/estimate_u.py index bc3e257057..e7f30a528d 100644 --- a/splink/internals/estimate_u.py +++ b/splink/internals/estimate_u.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import time from copy import deepcopy from typing import TYPE_CHECKING, List @@ -44,11 +45,25 @@ class _MUCountsAccumulator: def __init__(self, comparison: "Comparison") -> None: self._output_column_name = comparison.output_column_name + self._label_by_cvv: dict[int, str] = {} self._counts_by_cvv: dict[int, list[float]] = { int(cl.comparison_vector_value): [0.0, 0.0] for cl in comparison._comparison_levels_excluding_null } + # Prefer unique labels if there are duplicates. + comparison_levels = list(comparison._comparison_levels_excluding_null) + for cl in comparison_levels: + cvv = int(cl.comparison_vector_value) + label = cl.label_for_charts + if hasattr(cl, "_label_for_charts_no_duplicates"): + try: + label = cl._label_for_charts_no_duplicates(comparison_levels) + except Exception: + # Defensive: never let logging break estimation + label = cl.label_for_charts + self._label_by_cvv[cvv] = str(label) + def update_from_chunk_counts(self, chunk_counts: pd.DataFrame) -> None: for r in chunk_counts.itertuples(index=False): cvv = int(r.comparison_vector_value) @@ -63,6 +78,21 @@ def min_u_count(self) -> float: return 0.0 return min(totals[1] for totals in self._counts_by_cvv.values()) + def min_u_count_level(self) -> tuple[float, int | None, str | None]: + """Return (min_u_count, cvv, label) for the current accumulator.""" + if not self._counts_by_cvv: + return 0.0, None, None + min_cvv: int | None = None + min_u: float | None = None + for cvv in sorted(self._counts_by_cvv): + u_count = float(self._counts_by_cvv[cvv][1]) + if (min_u is None) or (u_count < min_u): + min_u = u_count + min_cvv = cvv + assert min_u is not None + label = self._label_by_cvv.get(min_cvv) if min_cvv is not None else None + return min_u, min_cvv, label + def all_levels_meet_min_u_count(self, min_count: int) -> bool: return self.min_u_count() >= min_count @@ -174,8 +204,16 @@ def _process_rhs_chunk_and_check_convergence( counts_accumulator: _MUCountsAccumulator, min_count_per_level: int | None, chunk_label: str, + probe_percent_of_max_pairs: float | None = None, ) -> bool: - logger.info(f" {chunk_label} {rhs_chunk_num}/{rhs_num_chunks}") + if probe_percent_of_max_pairs is not None: + logger.info( + f" Running probe chunk (~{probe_percent_of_max_pairs:.2f}% of max_pairs)" + ) + else: + logger.info(f" Running chunk {rhs_chunk_num}/{rhs_num_chunks}") + + t0 = time.perf_counter() chunk_counts = _u_counts_for_comparison_rhs_chunk( db_api=db_api, @@ -195,23 +233,37 @@ def _process_rhs_chunk_and_check_convergence( counts_accumulator.update_from_chunk_counts(chunk_counts) - logger.info("\n" + counts_accumulator.pretty_table()) + 
chunk_elapsed_s = time.perf_counter() - t0 + + logger.debug("\n" + counts_accumulator.pretty_table()) if min_count_per_level is None: + logger.info(f" Chunk took {chunk_elapsed_s:.1f} seconds") return False - logger.info( - " Current min u_count across levels: " - f"{counts_accumulator.min_u_count():,.0f}/{min_count_per_level}" + min_u, min_cvv, min_label = counts_accumulator.min_u_count_level() + level_desc = ( + f"{min_label} (cvv={min_cvv})" if (min_label is not None) else str(min_cvv) ) + if probe_percent_of_max_pairs is not None: + logger.info(f" Min u_count: {min_u:,.0f} for comparison level {level_desc}") + else: + logger.info( + f" Count of {min_u:,.0f} for level {level_desc}. " + f"Chunk took {chunk_elapsed_s:.1f} seconds." + ) + if counts_accumulator.all_levels_meet_min_u_count(min_count_per_level): logger.info( - " Stopping early: all levels reached at least " - f"{min_count_per_level} u observations" + f" Exiting early since min {min_u:,.0f} exceeds " + f"min_count_per_level = {min_count_per_level}" ) return True + if probe_percent_of_max_pairs is None: + logger.info(" Min u_count not hit, continuing.") + return False @@ -252,6 +304,11 @@ def estimate_u_values( num_chunks: int = 10, ) -> None: logger.info("----- Estimating u probabilities using random sampling -----") + logger.info( + "Estimating u with: " + f"max_pairs = {max_pairs:,.0f}, min_count_per_level = {min_count_per_level}, " + f"num_chunks = {num_chunks}" + ) if num_chunks < 1: raise ValueError("num_chunks must be >= 1") pipeline = CTEPipeline() @@ -370,7 +427,7 @@ def estimate_u_values( for i, comparison in enumerate(settings_obj.comparisons): logger.info( f"\nEstimating u for: {comparison.output_column_name} " - f"({i+1}/{len(settings_obj.comparisons)})" + f"(Comparison {i+1} of {len(settings_obj.comparisons)})" ) original_comparison = original_settings_obj.comparisons[i] @@ -413,11 +470,14 @@ def estimate_u_values( counts_accumulator=counts_accumulator, min_count_per_level=min_count_per_level, chunk_label="RHS probe chunk", + probe_percent_of_max_pairs=100.0 / (rhs_num_chunks * probe_multiplier), ) if not converged: if use_probe: - logger.info(" Probe did not converge; restarting with normal chunking") + logger.info( + "\n Probe did not converge; restarting with normal chunking" + ) counts_accumulator = _MUCountsAccumulator(comparison) for rhs_chunk_num in range(1, rhs_num_chunks + 1): From 4d3bfeb9e841b163f1b4b6f8be5bac1331966e66 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Wed, 7 Jan 2026 09:35:20 +0000 Subject: [PATCH 15/23] add nice logging2 --- splink/internals/estimate_u.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/splink/internals/estimate_u.py b/splink/internals/estimate_u.py index e7f30a528d..b23c660265 100644 --- a/splink/internals/estimate_u.py +++ b/splink/internals/estimate_u.py @@ -256,7 +256,7 @@ def _process_rhs_chunk_and_check_convergence( if counts_accumulator.all_levels_meet_min_u_count(min_count_per_level): logger.info( - f" Exiting early since min {min_u:,.0f} exceeds " + f" Exiting early since min count of {min_u:,.0f} exceeds " f"min_count_per_level = {min_count_per_level}" ) return True @@ -476,7 +476,7 @@ def estimate_u_values( if not converged: if use_probe: logger.info( - "\n Probe did not converge; restarting with normal chunking" + " Probe did not converge; restarting with normal chunking\n" ) counts_accumulator = _MUCountsAccumulator(comparison) From 0a5bb7950d5fcccd7dacbc248cea339bca33359e Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Mon, 12 
Jan 2026 15:40:51 +0000 Subject: [PATCH 16/23] slightly clarify interface --- splink/internals/estimate_u.py | 2 +- splink/internals/linker_components/training.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/splink/internals/estimate_u.py b/splink/internals/estimate_u.py index b23c660265..7e4d1ddb80 100644 --- a/splink/internals/estimate_u.py +++ b/splink/internals/estimate_u.py @@ -446,7 +446,7 @@ def estimate_u_values( additional_columns_to_retain=[], ) - use_probe = (rhs_num_chunks > 1) and (min_count_per_level is not None) + use_probe = rhs_num_chunks > 1 converged = False if use_probe: diff --git a/splink/internals/linker_components/training.py b/splink/internals/linker_components/training.py index c4825ae888..cdc032c0ec 100644 --- a/splink/internals/linker_components/training.py +++ b/splink/internals/linker_components/training.py @@ -195,11 +195,11 @@ def estimate_u_using_random_sampling( probabilities. Note, seed for random sampling is only supported for DuckDB and Spark, for SQLite set to None. min_count_per_level (int | None): Minimum number of u observations - required for each comparison level before stopping chunking early. - If None, disables the probe phase and disables early stopping (all - chunks are processed). Defaults to 100. - num_chunks (int): Number of chunks to split the RHS of the cartesian - product into while estimating u. Defaults to 10. + required for each comparison level before stopping estimation early. + If None, disables early stopping (all chunks are processed). + Defaults to 100. + num_chunks (int): Number of chunks to split the workload while estimating u. + If set to 1, disables the probe phase. Defaults to 10. Examples: ```py From 43cc6fd936000fa96e16cfcaf1497e0597dbdeb9 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Mon, 12 Jan 2026 15:56:18 +0000 Subject: [PATCH 17/23] simplify --- splink/internals/estimate_u.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/splink/internals/estimate_u.py b/splink/internals/estimate_u.py index 7e4d1ddb80..6a4863ff98 100644 --- a/splink/internals/estimate_u.py +++ b/splink/internals/estimate_u.py @@ -51,17 +51,10 @@ def __init__(self, comparison: "Comparison") -> None: for cl in comparison._comparison_levels_excluding_null } - # Prefer unique labels if there are duplicates. 
comparison_levels = list(comparison._comparison_levels_excluding_null) for cl in comparison_levels: cvv = int(cl.comparison_vector_value) - label = cl.label_for_charts - if hasattr(cl, "_label_for_charts_no_duplicates"): - try: - label = cl._label_for_charts_no_duplicates(comparison_levels) - except Exception: - # Defensive: never let logging break estimation - label = cl.label_for_charts + label = cl._label_for_charts_no_duplicates(comparison_levels) self._label_by_cvv[cvv] = str(label) def update_from_chunk_counts(self, chunk_counts: pd.DataFrame) -> None: From ac21f3cc98421d98275130df17b36b072a61e4cf Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Mon, 12 Jan 2026 16:13:20 +0000 Subject: [PATCH 18/23] better comments --- splink/internals/estimate_u.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/splink/internals/estimate_u.py b/splink/internals/estimate_u.py index 6a4863ff98..8497863950 100644 --- a/splink/internals/estimate_u.py +++ b/splink/internals/estimate_u.py @@ -406,13 +406,13 @@ def estimate_u_values( input_tablename_sample_l = "__splink__df_concat_sample_left" input_tablename_sample_r = "__splink__df_concat_sample_right" - # We chunk only the RHS. + # At this point we've computed our data sample and we're ready to 'block and count' + + # Only chunk on RHS. Input data is sample and thus always small enough. rhs_num_chunks = num_chunks uid_columns = settings_obj.column_info_settings.unique_id_input_columns - # Note: we pass the actual InputColumn objects through to helper calls. - # Build common blocking columns (UID columns that all comparisons need) common_blocking_cols: list[str] = [] for uid_column in uid_columns: common_blocking_cols.extend(uid_column.l_r_names_as_l_r) From 1c45e9d1dd321ceaf277a5ce20a2cc4770b458b3 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Mon, 12 Jan 2026 16:28:43 +0000 Subject: [PATCH 19/23] simplify --- splink/internals/estimate_u.py | 76 +++++++++------------------------- 1 file changed, 20 insertions(+), 56 deletions(-) diff --git a/splink/internals/estimate_u.py b/splink/internals/estimate_u.py index 8497863950..0b5af52522 100644 --- a/splink/internals/estimate_u.py +++ b/splink/internals/estimate_u.py @@ -107,7 +107,7 @@ def pretty_table(self) -> str: return df.to_string(index=False) -def _u_counts_for_comparison_rhs_chunk( +def _run_rhs_chunk_and_check_convergence( *, db_api: "DatabaseAPISubClass", df_sample: "SplinkDataFrame", @@ -121,8 +121,21 @@ def _u_counts_for_comparison_rhs_chunk( comparison: "Comparison", blocking_cols: list[str], cv_cols: list[str], - right_chunk: tuple[int, int] | None, -) -> pd.DataFrame: + rhs_chunk_num: int, + rhs_num_chunks: int, + counts_accumulator: _MUCountsAccumulator, + min_count_per_level: int | None, + probe_percent_of_max_pairs: float | None = None, +) -> bool: + if probe_percent_of_max_pairs is not None: + logger.info( + f" Running probe chunk (~{probe_percent_of_max_pairs:.2f}% of max_pairs)" + ) + else: + logger.info(f" Running chunk {rhs_chunk_num}/{rhs_num_chunks}") + + t0 = time.perf_counter() + pipeline = CTEPipeline(input_dataframes=[df_sample]) if split_sqls: @@ -135,7 +148,7 @@ def _u_counts_for_comparison_rhs_chunk( link_type=link_type, source_dataset_input_column=source_dataset_input_column, unique_id_input_column=unique_id_input_column, - right_chunk=right_chunk, + right_chunk=(rhs_chunk_num, rhs_num_chunks), ) pipeline.enqueue_list_of_sqls(blocking_sqls) @@ -173,57 +186,10 @@ def _u_counts_for_comparison_rhs_chunk( # Drop lambda row: it isn't additive across 
chunks (it's already a # proportion), and we don't use it here anyway. - return chunk_counts[ + chunk_counts = chunk_counts[ chunk_counts.output_column_name != "_probability_two_random_records_match" ] - -def _process_rhs_chunk_and_check_convergence( - *, - db_api: "DatabaseAPISubClass", - df_sample: "SplinkDataFrame", - split_sqls: list[dict[str, str]], - input_tablename_sample_l: str, - input_tablename_sample_r: str, - blocking_rules_for_u: list[BlockingRule], - link_type: LinkTypeLiteralType, - source_dataset_input_column: "InputColumn | None", - unique_id_input_column: "InputColumn", - comparison: "Comparison", - blocking_cols: list[str], - cv_cols: list[str], - rhs_chunk_num: int, - rhs_num_chunks: int, - counts_accumulator: _MUCountsAccumulator, - min_count_per_level: int | None, - chunk_label: str, - probe_percent_of_max_pairs: float | None = None, -) -> bool: - if probe_percent_of_max_pairs is not None: - logger.info( - f" Running probe chunk (~{probe_percent_of_max_pairs:.2f}% of max_pairs)" - ) - else: - logger.info(f" Running chunk {rhs_chunk_num}/{rhs_num_chunks}") - - t0 = time.perf_counter() - - chunk_counts = _u_counts_for_comparison_rhs_chunk( - db_api=db_api, - df_sample=df_sample, - split_sqls=split_sqls, - input_tablename_sample_l=input_tablename_sample_l, - input_tablename_sample_r=input_tablename_sample_r, - blocking_rules_for_u=blocking_rules_for_u, - link_type=link_type, - source_dataset_input_column=source_dataset_input_column, - unique_id_input_column=unique_id_input_column, - comparison=comparison, - blocking_cols=blocking_cols, - cv_cols=cv_cols, - right_chunk=(rhs_chunk_num, rhs_num_chunks), - ) - counts_accumulator.update_from_chunk_counts(chunk_counts) chunk_elapsed_s = time.perf_counter() - t0 @@ -445,7 +411,7 @@ def estimate_u_values( if use_probe: probe_multiplier = 10 probe_rhs_num_chunks = rhs_num_chunks * probe_multiplier - converged = _process_rhs_chunk_and_check_convergence( + converged = _run_rhs_chunk_and_check_convergence( db_api=db_api, df_sample=df_sample, split_sqls=split_sqls, @@ -462,7 +428,6 @@ def estimate_u_values( rhs_num_chunks=probe_rhs_num_chunks, counts_accumulator=counts_accumulator, min_count_per_level=min_count_per_level, - chunk_label="RHS probe chunk", probe_percent_of_max_pairs=100.0 / (rhs_num_chunks * probe_multiplier), ) @@ -474,7 +439,7 @@ def estimate_u_values( counts_accumulator = _MUCountsAccumulator(comparison) for rhs_chunk_num in range(1, rhs_num_chunks + 1): - converged = _process_rhs_chunk_and_check_convergence( + converged = _run_rhs_chunk_and_check_convergence( db_api=db_api, df_sample=df_sample, split_sqls=split_sqls, @@ -491,7 +456,6 @@ def estimate_u_values( rhs_num_chunks=rhs_num_chunks, counts_accumulator=counts_accumulator, min_count_per_level=min_count_per_level, - chunk_label="RHS chunk", ) if converged and (min_count_per_level is not None): break From 54ba4151f3a9ab3db2e10f6abc8642bbf4775518 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Mon, 12 Jan 2026 16:36:31 +0000 Subject: [PATCH 20/23] curry/bind early for more succinct and readable code --- splink/internals/estimate_u.py | 53 ++++++++++++++-------------------- 1 file changed, 22 insertions(+), 31 deletions(-) diff --git a/splink/internals/estimate_u.py b/splink/internals/estimate_u.py index 0b5af52522..0491c66060 100644 --- a/splink/internals/estimate_u.py +++ b/splink/internals/estimate_u.py @@ -3,6 +3,7 @@ import logging import time from copy import deepcopy +from functools import partial from typing import TYPE_CHECKING, List import pandas as 
pd @@ -181,11 +182,9 @@ def _run_rhs_chunk_and_check_convergence( try: chunk_counts = df_params.as_pandas_dataframe() finally: - # Drop final output table df_params.drop_table_from_database_and_remove_from_cache() - # Drop lambda row: it isn't additive across chunks (it's already a - # proportion), and we don't use it here anyway. + # Drop lambda row: it isn't additive and we don't use it here anyway chunk_counts = chunk_counts[ chunk_counts.output_column_name != "_probability_two_random_records_match" ] @@ -407,27 +406,32 @@ def estimate_u_values( use_probe = rhs_num_chunks > 1 + # Bind invariant args once to avoid repetition + run_rhs_chunk = partial( + _run_rhs_chunk_and_check_convergence, + db_api=db_api, + df_sample=df_sample, + split_sqls=split_sqls, + input_tablename_sample_l=input_tablename_sample_l, + input_tablename_sample_r=input_tablename_sample_r, + blocking_rules_for_u=blocking_rules_for_u, + link_type=linker._settings_obj._link_type, + source_dataset_input_column=settings_obj.column_info_settings.source_dataset_input_column, + unique_id_input_column=settings_obj.column_info_settings.unique_id_input_column, + comparison=comparison, + blocking_cols=blocking_cols, + cv_cols=cv_cols, + min_count_per_level=min_count_per_level, + ) + converged = False if use_probe: probe_multiplier = 10 probe_rhs_num_chunks = rhs_num_chunks * probe_multiplier - converged = _run_rhs_chunk_and_check_convergence( - db_api=db_api, - df_sample=df_sample, - split_sqls=split_sqls, - input_tablename_sample_l=input_tablename_sample_l, - input_tablename_sample_r=input_tablename_sample_r, - blocking_rules_for_u=blocking_rules_for_u, - link_type=linker._settings_obj._link_type, - source_dataset_input_column=settings_obj.column_info_settings.source_dataset_input_column, - unique_id_input_column=settings_obj.column_info_settings.unique_id_input_column, - comparison=comparison, - blocking_cols=blocking_cols, - cv_cols=cv_cols, + converged = run_rhs_chunk( rhs_chunk_num=1, rhs_num_chunks=probe_rhs_num_chunks, counts_accumulator=counts_accumulator, - min_count_per_level=min_count_per_level, probe_percent_of_max_pairs=100.0 / (rhs_num_chunks * probe_multiplier), ) @@ -439,23 +443,10 @@ def estimate_u_values( counts_accumulator = _MUCountsAccumulator(comparison) for rhs_chunk_num in range(1, rhs_num_chunks + 1): - converged = _run_rhs_chunk_and_check_convergence( - db_api=db_api, - df_sample=df_sample, - split_sqls=split_sqls, - input_tablename_sample_l=input_tablename_sample_l, - input_tablename_sample_r=input_tablename_sample_r, - blocking_rules_for_u=blocking_rules_for_u, - link_type=linker._settings_obj._link_type, - source_dataset_input_column=settings_obj.column_info_settings.source_dataset_input_column, - unique_id_input_column=settings_obj.column_info_settings.unique_id_input_column, - comparison=comparison, - blocking_cols=blocking_cols, - cv_cols=cv_cols, + converged = run_rhs_chunk( rhs_chunk_num=rhs_chunk_num, rhs_num_chunks=rhs_num_chunks, counts_accumulator=counts_accumulator, - min_count_per_level=min_count_per_level, ) if converged and (min_count_per_level is not None): break From 2eba5eced82873fe636145a3333148fefe343776 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Mon, 12 Jan 2026 16:46:49 +0000 Subject: [PATCH 21/23] name things better --- splink/internals/estimate_u.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/splink/internals/estimate_u.py b/splink/internals/estimate_u.py index 0491c66060..a58bc35f29 100644 --- a/splink/internals/estimate_u.py +++ 
b/splink/internals/estimate_u.py @@ -108,7 +108,7 @@ def pretty_table(self) -> str: return df.to_string(index=False) -def _run_rhs_chunk_and_check_convergence( +def _accumulate_u_counts_from_chunk_and_check_min_count( *, db_api: "DatabaseAPISubClass", df_sample: "SplinkDataFrame", @@ -407,8 +407,8 @@ def estimate_u_values( use_probe = rhs_num_chunks > 1 # Bind invariant args once to avoid repetition - run_rhs_chunk = partial( - _run_rhs_chunk_and_check_convergence, + run_chunk = partial( + _accumulate_u_counts_from_chunk_and_check_min_count, db_api=db_api, df_sample=df_sample, split_sqls=split_sqls, @@ -424,18 +424,18 @@ def estimate_u_values( min_count_per_level=min_count_per_level, ) - converged = False + min_count_condition_met = False if use_probe: probe_multiplier = 10 probe_rhs_num_chunks = rhs_num_chunks * probe_multiplier - converged = run_rhs_chunk( + min_count_condition_met = run_chunk( rhs_chunk_num=1, rhs_num_chunks=probe_rhs_num_chunks, counts_accumulator=counts_accumulator, probe_percent_of_max_pairs=100.0 / (rhs_num_chunks * probe_multiplier), ) - if not converged: + if not min_count_condition_met: if use_probe: logger.info( " Probe did not converge; restarting with normal chunking\n" @@ -443,12 +443,12 @@ def estimate_u_values( counts_accumulator = _MUCountsAccumulator(comparison) for rhs_chunk_num in range(1, rhs_num_chunks + 1): - converged = run_rhs_chunk( + min_count_condition_met = run_chunk( rhs_chunk_num=rhs_chunk_num, rhs_num_chunks=rhs_num_chunks, counts_accumulator=counts_accumulator, ) - if converged and (min_count_per_level is not None): + if min_count_condition_met and (min_count_per_level is not None): break aggregated_counts_df = counts_accumulator.to_dataframe() From 50ec0aea8d9fd92541761a960d9dd3667a84c2a2 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Mon, 12 Jan 2026 17:07:04 +0000 Subject: [PATCH 22/23] improve comment --- splink/internals/estimate_u.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/splink/internals/estimate_u.py b/splink/internals/estimate_u.py index a58bc35f29..5f83450886 100644 --- a/splink/internals/estimate_u.py +++ b/splink/internals/estimate_u.py @@ -456,7 +456,8 @@ def estimate_u_values( # Convert aggregated counts to proportions (u probabilities) param_records = compute_proportions_for_new_parameters(aggregated_counts_df) - # Principled handling of unobserved levels: + # Handling of unobserved levels is consistent with splink 4 + # 'LEVEL_NOT_OBSERVED_TEXT' behaviour whilst enabling the 'break early' check # - We explicitly include every level (via enumeration) so that convergence # checks can treat missing GROUP BY rows as 0 counts. # - But for the final trained u values, a level with u_count == 0 should be From 1171090e85edd95e81798621df031ec087ea91b9 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Mon, 12 Jan 2026 17:10:33 +0000 Subject: [PATCH 23/23] update tests for clarity --- tests/test_find_new_matches.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/test_find_new_matches.py b/tests/test_find_new_matches.py index 059bbd9ae0..91b902c50f 100644 --- a/tests/test_find_new_matches.py +++ b/tests/test_find_new_matches.py @@ -85,11 +85,7 @@ def test_matches_work(test_helpers, dialect): linker = helper.linker_with_registration(df, get_settings_dict()) # Train our model to get more reasonable outputs... 
-    linker.training.estimate_u_using_random_sampling(
-        max_pairs=1e6,
-        num_chunks=10,
-        min_count_per_level=None,
-    )
+    linker.training.estimate_u_using_random_sampling(max_pairs=1e6, num_chunks=1)
     linker.visualisations.match_weights_chart().save("mwc.html")

     blocking_rule = block_on("first_name", "surname")
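
Usage note (illustrative only, not part of any patch above): the calls below exercise the parameters documented in the updated estimate_u_using_random_sampling docstring. They assume an already-configured splink `linker`; everything other than the keyword arguments is an assumption, and the max_pairs values are arbitrary.

    # Assumed: `linker` is an existing, already-configured splink Linker.

    # Default behaviour: a small probe chunk first, then chunked estimation that
    # stops early once every comparison level has at least 100 u observations.
    linker.training.estimate_u_using_random_sampling(max_pairs=1e7)

    # Disable early stopping: all chunks are processed (still 10 chunks by default).
    linker.training.estimate_u_using_random_sampling(max_pairs=1e7, min_count_per_level=None)

    # Single chunk: disables the probe phase and processes the whole sample in one pass.
    linker.training.estimate_u_using_random_sampling(max_pairs=1e6, num_chunks=1)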
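
The probe sizing used in estimate_u_values is easiest to see as a small worked sketch. This is a standalone restatement of the arithmetic in the patches (probe_multiplier = 10, and the probe is chunk 1 of num_chunks * probe_multiplier), not code taken from the diff; the function names here are hypothetical.

    def probe_fraction(num_chunks: int, probe_multiplier: int = 10) -> float:
        # The probe is a single chunk out of num_chunks * probe_multiplier,
        # i.e. roughly max_pairs / (num_chunks * probe_multiplier) candidate pairs.
        return 1.0 / (num_chunks * probe_multiplier)

    def early_stop(min_u_count_across_levels: float, min_count_per_level: int | None) -> bool:
        # If min_count_per_level is None, early stopping is disabled entirely.
        if min_count_per_level is None:
            return False
        return min_u_count_across_levels >= min_count_per_level

    assert probe_fraction(10) == 0.01       # defaults: the probe covers ~1% of max_pairs
    assert early_stop(150.0, 100)           # every level has enough u observations
    assert not early_stop(150.0, None)      # early stopping disabled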