1
1
"""This module contains the base classes for the metrics used across synthesized."""
2
+
2
3
import os
3
4
import typing as ty
4
5
from abc import ABC , abstractmethod
@@ -71,6 +72,7 @@ def _add_to_database(
71
72
dataset_rows : ty .Optional [int ] = None ,
72
73
dataset_cols : ty .Optional [int ] = None ,
73
74
category : ty .Optional [str ] = None ,
75
+ session : ty .Optional [Session ] = None ,
74
76
):
75
77
"""
76
78
Adds the metric result to the database. The metric result should be specified as value.
@@ -101,7 +103,23 @@ def _add_to_database(
101
103
if hasattr (value , "item" ):
102
104
value = value .item ()
103
105
104
- with self ._session as session :
106
+ if session is None :
107
+ with self ._session as session :
108
+ metric_id = utils .get_metric_id (self .name , session , category = category )
109
+ version_id = utils .get_version_id (version , session )
110
+ dataset_id = utils .get_df_id (
111
+ dataset_name , session , num_rows = dataset_rows , num_columns = dataset_cols
112
+ )
113
+ result = model .Result (
114
+ metric_id = metric_id ,
115
+ dataset_id = dataset_id ,
116
+ version_id = version_id ,
117
+ value = value ,
118
+ run_id = run_id ,
119
+ )
120
+ session .add (result )
121
+ session .commit ()
122
+ else :
105
123
metric_id = utils .get_metric_id (self .name , session , category = category )
106
124
version_id = utils .get_version_id (version , session )
107
125
dataset_id = utils .get_df_id (
@@ -115,7 +133,6 @@ def _add_to_database(
115
133
run_id = run_id ,
116
134
)
117
135
session .add (result )
118
- session .commit ()
119
136
120
137
121
138
class OneColumnMetric (_Metric ):
@@ -167,7 +184,7 @@ def check_column_types(cls, sr: pd.Series, check: Check = ColumnCheck()) -> bool
167
184
def _compute_metric (self , sr : pd .Series ):
168
185
...
169
186
170
- def __call__ (self , sr : pd .Series , dataset_name : ty .Optional [str ] = None ):
187
+ def __call__ (self , sr : pd .Series , dataset_name : ty .Optional [str ] = None , session = None ):
171
188
if not self .check_column_types (sr , self .check ):
172
189
value = None
173
190
else :
@@ -181,6 +198,7 @@ def __call__(self, sr: pd.Series, dataset_name: ty.Optional[str] = None):
181
198
dataset_rows = len (sr ),
182
199
category = "OneColumnMetric" ,
183
200
dataset_cols = 1 ,
201
+ session = session ,
184
202
)
185
203
186
204
return value
@@ -237,7 +255,9 @@ def check_column_types(cls, sr_a: pd.Series, sr_b: pd.Series, check: Check = Col
237
255
def _compute_metric (self , sr_a : pd .Series , sr_b : pd .Series ):
238
256
...
239
257
240
- def __call__ (self , sr_a : pd .Series , sr_b : pd .Series , dataset_name : ty .Optional [str ] = None ):
258
+ def __call__ (
259
+ self , sr_a : pd .Series , sr_b : pd .Series , dataset_name : ty .Optional [str ] = None , session = None
260
+ ):
241
261
if not self .check_column_types (sr_a , sr_b , self .check ):
242
262
value = None
243
263
else :
@@ -251,6 +271,7 @@ def __call__(self, sr_a: pd.Series, sr_b: pd.Series, dataset_name: ty.Optional[s
251
271
dataset_rows = len (sr_a ),
252
272
category = "TwoColumnMetric" ,
253
273
dataset_cols = 1 ,
274
+ session = session ,
254
275
)
255
276
256
277
return value
0 commit comments