Skip to content

Commit 8fbe1cc

Browse files
[MAINTENANCE] Improve type hints in ExecutionEngine.resolve_metrics() flow and delete unnecessary checks (great-expectations#6804)
1 parent 8c8f5c5 commit 8fbe1cc

File tree

4 files changed

+91
-68
lines changed

4 files changed

+91
-68
lines changed

assets/docker/postgresql/docker-compose.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
version: '3.2'
22
services:
33
travis_db:
4-
image: postgres:12
4+
image: postgres:15.1
55
command:
66
- postgres
77
- "-c"

great_expectations/execution_engine/execution_engine.py

+12-30
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,6 @@
2020
import great_expectations.exceptions as gx_exceptions
2121
from great_expectations.core.batch_manager import BatchManager
2222
from great_expectations.core.metric_domain_types import MetricDomainTypes
23-
from great_expectations.core.metric_function_types import (
24-
MetricFunctionTypes,
25-
MetricPartialFunctionTypes,
26-
)
2723
from great_expectations.core.util import (
2824
AzureUrl,
2925
DBFSPath,
@@ -42,6 +38,10 @@
4238
from great_expectations.validator.metric_configuration import MetricConfiguration
4339

4440
if TYPE_CHECKING:
41+
# noinspection PyPep8Naming
42+
import pyspark.sql.functions as F
43+
import sqlalchemy as sa
44+
4545
from great_expectations.core.batch import (
4646
BatchData,
4747
BatchDataType,
@@ -82,7 +82,7 @@ class MetricComputationConfiguration(DictDot):
8282
"""
8383

8484
metric_configuration: MetricConfiguration
85-
metric_fn: Any
85+
metric_fn: sa.func | F
8686
metric_provider_kwargs: dict
8787
compute_domain_kwargs: Optional[dict] = None
8888
accessor_domain_kwargs: Optional[dict] = None
@@ -441,6 +441,8 @@ def _build_direct_and_bundled_metric_computation_configurations(
441441
Directly-computable "MetricConfiguration" must have non-NULL metric function ("metric_fn"). Aggregate metrics
442442
have NULL metric function, but non-NULL partial metric function ("metric_partial_fn"); aggregates are bundled.
443443
444+
See documentation in "MetricProvider._register_metric_functions()" for in-depth description of this mechanism.
445+
444446
Args:
445447
metrics_to_resolve: the metrics to evaluate
446448
metrics: already-computed metrics currently available to the engine
@@ -465,7 +467,8 @@ def _build_direct_and_bundled_metric_computation_configurations(
465467
str, Union[MetricValue, Tuple[Any, dict, dict]]
466468
]
467469
metric_class: MetricProvider
468-
metric_fn: Any
470+
metric_fn: Union[Callable, None]
471+
metric_aggregate_fn: sa.func | F
469472
metric_provider_kwargs: dict
470473
compute_domain_kwargs: dict
471474
accessor_domain_kwargs: dict
@@ -491,10 +494,10 @@ def _build_direct_and_bundled_metric_computation_configurations(
491494
if metric_fn is None:
492495
try:
493496
(
494-
metric_fn,
497+
metric_aggregate_fn,
495498
compute_domain_kwargs,
496499
accessor_domain_kwargs,
497-
) = resolved_metric_dependencies_by_metric_name.pop( # type: ignore[misc,assignment]
500+
) = resolved_metric_dependencies_by_metric_name.pop(
498501
"metric_partial_fn"
499502
)
500503
except KeyError as e:
@@ -505,34 +508,13 @@ def _build_direct_and_bundled_metric_computation_configurations(
505508
metric_fn_bundle_configurations.append(
506509
MetricComputationConfiguration(
507510
metric_configuration=metric_to_resolve,
508-
metric_fn=metric_fn,
511+
metric_fn=metric_aggregate_fn,
509512
metric_provider_kwargs=metric_provider_kwargs,
510513
compute_domain_kwargs=compute_domain_kwargs,
511514
accessor_domain_kwargs=accessor_domain_kwargs,
512515
)
513516
)
514517
else:
515-
metric_fn_type: MetricFunctionTypes = getattr(
516-
metric_fn, "metric_fn_type", MetricFunctionTypes.VALUE
517-
)
518-
if isinstance(
519-
metric_fn_type, (MetricFunctionTypes, MetricPartialFunctionTypes)
520-
) and metric_fn_type not in [
521-
MetricPartialFunctionTypes.MAP_FN,
522-
MetricPartialFunctionTypes.MAP_SERIES,
523-
MetricPartialFunctionTypes.WINDOW_FN,
524-
MetricPartialFunctionTypes.MAP_CONDITION_FN,
525-
MetricPartialFunctionTypes.MAP_CONDITION_SERIES,
526-
MetricPartialFunctionTypes.WINDOW_CONDITION_FN,
527-
MetricPartialFunctionTypes.AGGREGATE_FN,
528-
MetricFunctionTypes.VALUE,
529-
MetricFunctionTypes.MAP_VALUES,
530-
MetricFunctionTypes.MAP_VALUES,
531-
MetricFunctionTypes.AGGREGATE_VALUE,
532-
]:
533-
logger.warning(
534-
f'Unrecognized metric function type while trying to resolve "{metric_to_resolve.id}".'
535-
)
536518
metric_fn_direct_configurations.append(
537519
MetricComputationConfiguration(
538520
metric_configuration=metric_to_resolve,

great_expectations/expectations/metrics/map_metric_provider.py

+37-20
Original file line numberDiff line numberDiff line change
@@ -3306,6 +3306,7 @@ def _register_metric_functions(cls):
33063306
metric_fn_type=MetricFunctionTypes.VALUE,
33073307
)
33083308
if metric_fn_type == MetricPartialFunctionTypes.MAP_CONDITION_FN:
3309+
# Documentation in "MetricProvider._register_metric_functions()" explains registration protocol.
33093310
if domain_type == MetricDomainTypes.COLUMN:
33103311
register_metric(
33113312
metric_name=metric_name
@@ -3443,6 +3444,7 @@ def _register_metric_functions(cls):
34433444
metric_fn_type=MetricFunctionTypes.VALUE,
34443445
)
34453446
if metric_fn_type == MetricPartialFunctionTypes.MAP_CONDITION_FN:
3447+
# Documentation in "MetricProvider._register_metric_functions()" explains registration protocol.
34463448
if domain_type == MetricDomainTypes.COLUMN:
34473449
register_metric(
34483450
metric_name=metric_name
@@ -3576,22 +3578,30 @@ def _get_evaluation_dependencies(
35763578
execution_engine: Optional[ExecutionEngine] = None,
35773579
runtime_configuration: Optional[dict] = None,
35783580
):
3579-
metric_name = metric.metric_name
3581+
dependencies: Dict[str, MetricConfiguration] = {}
3582+
35803583
base_metric_value_kwargs = {
35813584
k: v for k, v in metric.metric_value_kwargs.items() if k != "result_format"
35823585
}
3583-
dependencies: Dict[str, MetricConfiguration] = {}
35843586

3585-
metric_suffix = ".unexpected_count"
3587+
metric_name: str = metric.metric_name
3588+
3589+
metric_suffix: str = ".unexpected_count"
3590+
3591+
# Documentation in "MetricProvider._register_metric_functions()" explains registration/dependency protocol.
35863592
if metric_name.endswith(metric_suffix):
3587-
try:
3588-
_ = get_metric_provider(
3589-
f"{metric_name}.{MetricPartialFunctionTypes.AGGREGATE_FN.metric_suffix}",
3590-
execution_engine,
3591-
)
3592-
has_aggregate_fn = True
3593-
except gx_exceptions.MetricProviderError:
3594-
has_aggregate_fn = False
3593+
has_aggregate_fn: bool = False
3594+
3595+
if execution_engine is not None:
3596+
try:
3597+
_ = get_metric_provider(
3598+
f"{metric_name}.{MetricPartialFunctionTypes.AGGREGATE_FN.metric_suffix}",
3599+
execution_engine,
3600+
)
3601+
has_aggregate_fn = True
3602+
except gx_exceptions.MetricProviderError:
3603+
pass
3604+
35953605
if has_aggregate_fn:
35963606
dependencies["metric_partial_fn"] = MetricConfiguration(
35973607
metric_name=f"{metric_name}.{MetricPartialFunctionTypes.AGGREGATE_FN.metric_suffix}",
@@ -3606,15 +3616,22 @@ def _get_evaluation_dependencies(
36063616
)
36073617

36083618
# MapMetric uses "condition" metric to build "unexpected_count.aggregate_fn" and other listed metrics as well.
3609-
for metric_suffix in [
3610-
f".unexpected_count.{MetricPartialFunctionTypeSuffixes.AGGREGATE_FUNCTION.value}",
3611-
".unexpected_value_counts",
3612-
".unexpected_index_query",
3613-
".unexpected_index_list",
3614-
".filtered_row_count",
3615-
".unexpected_values",
3616-
".unexpected_rows",
3617-
]:
3619+
unexpected_condition_dependent_metric_name_suffixes: List[str] = list(
3620+
filter(
3621+
lambda element: metric_name.endswith(element),
3622+
[
3623+
f".unexpected_count.{MetricPartialFunctionTypeSuffixes.AGGREGATE_FUNCTION.value}",
3624+
".unexpected_value_counts",
3625+
".unexpected_index_query",
3626+
".unexpected_index_list",
3627+
".filtered_row_count",
3628+
".unexpected_values",
3629+
".unexpected_rows",
3630+
],
3631+
)
3632+
)
3633+
if len(unexpected_condition_dependent_metric_name_suffixes) == 1:
3634+
metric_suffix = unexpected_condition_dependent_metric_name_suffixes[0]
36183635
if metric_name.endswith(metric_suffix):
36193636
dependencies["unexpected_condition"] = MetricConfiguration(
36203637
metric_name=f"{metric_name[:-len(metric_suffix)]}.{MetricPartialFunctionTypeSuffixes.CONDITION.value}",

great_expectations/expectations/metrics/metric_provider.py

+41-17
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,30 @@ def _register_metric_functions(cls) -> None:
139139
# This is not a metric (valid metrics possess exactly one metric function).
140140
return
141141

142+
"""
143+
Basic metric implementations (defined by specifying "metric_name" class variable in "metric_class") use
144+
either "@metric_value" decorator (with default "metric_fn_type" set to "MetricFunctionTypes.VALUE"); or
145+
"@metric_partial" decorator with specification "partial_fn_type=MetricPartialFunctionTypes.AGGREGATE_FN"
146+
(which ultimately sets "metric_fn_type" of inner function to this value); or "@column_aggregate_value"
147+
decorator (with default "metric_fn_type" set to "MetricFunctionTypes.VALUE"); or (applicable for column
148+
domain metrics only) "@column_aggregate_partial" decorator with "partial_fn_type" explicitly set to
149+
"MetricPartialFunctionTypes.AGGREGATE_FN". When "metric_fn_type" of metric implementation function is
150+
of "aggregate partial" type ("MetricPartialFunctionTypes.AGGREGATE_FN"), underlying backend (e.g., SQL
151+
or Spark) employs "deferred execution" (gather computation needs to build execution plan, then execute
152+
all computations combined). Deferred aggregate function calls are bundled (applies to SQL and Spark).
153+
To instruct "ExecutionEngine" accordingly, original metric is registered with its "declared" name, but
154+
with "metric_provider" function omitted (set to "None"), and additional "AGGREGATE_FN" metric, with its
155+
"metric_provider" set to (decorated) implementation function, defined in metric class, is registered.
156+
Then "AGGREGATE_FN" metric can be specified with key "metric_partial_fn" as evaluation metric dependency.
157+
By convention, aggregate partial metric implementation functions return three-valued tuple, containing
158+
deferred execution metric implementation function of corresponding "ExecutionEngine" backend (called
159+
"metric_aggregate") as well as "compute_domain_kwargs" and "accessor_domain_kwargs", which are relevant
160+
for bundled computation and result access, respectively. When "ExecutionEngine.resolve_metrics()" finds
161+
no "metric_provider" (metric_fn being "None"), it then obtains this three-valued tuple from dictionary
162+
of "resolved_metric_dependencies_by_metric_name" using previously declared "metric_partial_fn" key (as
163+
described above), composes full metric execution configuration structure, and adds this configuration
164+
to list of metrics to be resolved as one bundle (specifics pertaining to "ExecutionEngine" subclasses).
165+
"""
142166
if metric_fn_type not in [
143167
MetricFunctionTypes.VALUE,
144168
MetricPartialFunctionTypes.AGGREGATE_FN,
@@ -217,22 +241,22 @@ def _get_evaluation_dependencies(
217241
runtime_configuration: Optional[dict] = None,
218242
):
219243
dependencies: Dict[str, MetricConfiguration] = {}
220-
if execution_engine is not None:
221-
metric_name = metric.metric_name
222-
try:
223-
_ = get_metric_provider(
224-
f"{metric_name}.{MetricPartialFunctionTypes.AGGREGATE_FN.metric_suffix}",
225-
execution_engine,
226-
)
227-
has_aggregate_fn = True
228-
except gx_exceptions.MetricProviderError:
229-
has_aggregate_fn = False
230-
231-
if has_aggregate_fn:
232-
dependencies["metric_partial_fn"] = MetricConfiguration(
233-
metric_name=f"{metric_name}.{MetricPartialFunctionTypes.AGGREGATE_FN.metric_suffix}",
234-
metric_domain_kwargs=metric.metric_domain_kwargs,
235-
metric_value_kwargs=metric.metric_value_kwargs,
236-
)
244+
245+
if execution_engine is None:
246+
return dependencies
247+
248+
try:
249+
metric_name: str = metric.metric_name
250+
_ = get_metric_provider(
251+
f"{metric_name}.{MetricPartialFunctionTypes.AGGREGATE_FN.metric_suffix}",
252+
execution_engine,
253+
)
254+
dependencies["metric_partial_fn"] = MetricConfiguration(
255+
metric_name=f"{metric_name}.{MetricPartialFunctionTypes.AGGREGATE_FN.metric_suffix}",
256+
metric_domain_kwargs=metric.metric_domain_kwargs,
257+
metric_value_kwargs=metric.metric_value_kwargs,
258+
)
259+
except gx_exceptions.MetricProviderError:
260+
pass
237261

238262
return dependencies

0 commit comments

Comments
 (0)