Skip to content

Commit ae2bc77

Browse files
authored
Bulk commit to DB and Cache IDs (#162)
* Bulk commit and cache database responses
1 parent 68fbadb commit ae2bc77

File tree

7 files changed

+129
-28
lines changed

7 files changed

+129
-28
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,15 @@ repos:
2323
hooks:
2424
- id: black
2525
- repo: https://github.com/pre-commit/mirrors-mypy
26-
rev: v1.5.1
26+
rev: v1.7.1
2727
hooks:
2828
- id: mypy
2929
files: src
3030
additional_dependencies:
3131
- numpy>=1.21
3232
- sqlalchemy[mypy]
3333
- alembic
34+
- types-cachetools
3435
args: [--install-types, --non-interactive]
3536
# Note that using the --install-types is problematic if running in
3637
# parallel as mutating the pre-commit env at runtime breaks cache.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ dependencies = [
4545
"pandas >= 1.2",
4646
"scipy >= 1.5",
4747
"seaborn >= 0.11.0",
48+
"cachetools >= 5.0",
4849
]
4950

5051
[project.optional-dependencies]

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
alembic==1.7.6
2+
cachetools==5.3.1
23
contourpy==1.1.0
34
cycler==0.11.0
45
fonttools==4.42.1

src/insight/database/utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
"""Utils for fetching information from the backend DB."""
2+
23
import os
34
import re
45
import typing as ty
56

67
import pandas as pd
8+
from cachetools import cached
9+
from cachetools.keys import hashkey
710
from sqlalchemy import create_engine
811
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
912
from sqlalchemy.orm import Session, sessionmaker
@@ -14,6 +17,8 @@
1417
NamedModelType = ty.TypeVar("NamedModelType", model.Dataset, model.Metric, model.Version)
1518

1619
_database_fail_note = "Failure to communicate with the database"
20+
_DATASET_ID_MAPPING: ty.Optional[ty.Dict[str, int]] = None
21+
_METRIC_ID_MAPPING: ty.Optional[ty.Dict[str, int]] = None
1722

1823

1924
def get_df(url_or_path: str):
@@ -25,6 +30,7 @@ def get_df(url_or_path: str):
2530
return df
2631

2732

33+
@cached(cache={}, key=lambda df_name, session, **kwargs: hashkey(df_name))
2834
def get_df_id(
2935
df_name: str,
3036
session: Session,
@@ -40,6 +46,17 @@ def get_df_id(
4046
num_columns (int): The number of columns in the dataframe. Optional.
4147
4248
"""
49+
global _DATASET_ID_MAPPING # pylint: disable=global-statement
50+
# create a mapping of dataset names to dataset ids
51+
if _DATASET_ID_MAPPING is None:
52+
with session:
53+
df_names = session.query(model.Dataset).all()
54+
_DATASET_ID_MAPPING = {df.name: df.id for df in df_names if df.name is not None}
55+
56+
df_id = _DATASET_ID_MAPPING.get(df_name)
57+
if df_id is not None:
58+
return df_id
59+
4360
dataset = get_object_from_db_by_name(df_name, session, model.Dataset)
4461
if dataset is None:
4562
with session:
@@ -51,6 +68,7 @@ def get_df_id(
5168
return int(dataset.id)
5269

5370

71+
@cached(cache={}, key=lambda metric, session, **kwargs: hashkey(metric))
5472
def get_metric_id(metric: str, session: Session, category: ty.Optional[str] = None) -> int:
5573
"""Get the id of a metric in the database. If it doesn't exist, create it.
5674
@@ -59,6 +77,17 @@ def get_metric_id(metric: str, session: Session, category: ty.Optional[str] = No
5977
session (Session): The database session.
6078
category (str): The category of the metric. Optional.
6179
"""
80+
global _METRIC_ID_MAPPING # pylint: disable=global-statement
81+
# create a mapping of metric names to metric ids
82+
if _METRIC_ID_MAPPING is None:
83+
with session:
84+
metrics = session.query(model.Dataset).all()
85+
_METRIC_ID_MAPPING = {m.name: m.id for m in metrics if m.name is not None}
86+
87+
metric_id = _METRIC_ID_MAPPING.get(metric)
88+
if metric_id is not None:
89+
return metric_id
90+
6291
db_metric = get_object_from_db_by_name(metric, session, model.Metric)
6392

6493
if db_metric is None:
@@ -71,6 +100,7 @@ def get_metric_id(metric: str, session: Session, category: ty.Optional[str] = No
71100
return int(db_metric.id)
72101

73102

103+
@cached(cache={}, key=lambda version, session: hashkey(version))
74104
def get_version_id(version: str, session: Session) -> int:
75105
"""Get the id of a version in the database. If it doesn't exist, create it.
76106

src/insight/metrics/base.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""This module contains the base classes for the metrics used across synthesized."""
2+
23
import os
34
import typing as ty
45
from abc import ABC, abstractmethod
@@ -71,6 +72,7 @@ def _add_to_database(
7172
dataset_rows: ty.Optional[int] = None,
7273
dataset_cols: ty.Optional[int] = None,
7374
category: ty.Optional[str] = None,
75+
session: ty.Optional[Session] = None,
7476
):
7577
"""
7678
Adds the metric result to the database. The metric result should be specified as value.
@@ -101,7 +103,23 @@ def _add_to_database(
101103
if hasattr(value, "item"):
102104
value = value.item()
103105

104-
with self._session as session:
106+
if session is None:
107+
with self._session as session:
108+
metric_id = utils.get_metric_id(self.name, session, category=category)
109+
version_id = utils.get_version_id(version, session)
110+
dataset_id = utils.get_df_id(
111+
dataset_name, session, num_rows=dataset_rows, num_columns=dataset_cols
112+
)
113+
result = model.Result(
114+
metric_id=metric_id,
115+
dataset_id=dataset_id,
116+
version_id=version_id,
117+
value=value,
118+
run_id=run_id,
119+
)
120+
session.add(result)
121+
session.commit()
122+
else:
105123
metric_id = utils.get_metric_id(self.name, session, category=category)
106124
version_id = utils.get_version_id(version, session)
107125
dataset_id = utils.get_df_id(
@@ -115,7 +133,6 @@ def _add_to_database(
115133
run_id=run_id,
116134
)
117135
session.add(result)
118-
session.commit()
119136

120137

121138
class OneColumnMetric(_Metric):
@@ -167,7 +184,7 @@ def check_column_types(cls, sr: pd.Series, check: Check = ColumnCheck()) -> bool
167184
def _compute_metric(self, sr: pd.Series):
168185
...
169186

170-
def __call__(self, sr: pd.Series, dataset_name: ty.Optional[str] = None):
187+
def __call__(self, sr: pd.Series, dataset_name: ty.Optional[str] = None, session=None):
171188
if not self.check_column_types(sr, self.check):
172189
value = None
173190
else:
@@ -181,6 +198,7 @@ def __call__(self, sr: pd.Series, dataset_name: ty.Optional[str] = None):
181198
dataset_rows=len(sr),
182199
category="OneColumnMetric",
183200
dataset_cols=1,
201+
session=session,
184202
)
185203

186204
return value
@@ -237,7 +255,9 @@ def check_column_types(cls, sr_a: pd.Series, sr_b: pd.Series, check: Check = Col
237255
def _compute_metric(self, sr_a: pd.Series, sr_b: pd.Series):
238256
...
239257

240-
def __call__(self, sr_a: pd.Series, sr_b: pd.Series, dataset_name: ty.Optional[str] = None):
258+
def __call__(
259+
self, sr_a: pd.Series, sr_b: pd.Series, dataset_name: ty.Optional[str] = None, session=None
260+
):
241261
if not self.check_column_types(sr_a, sr_b, self.check):
242262
value = None
243263
else:
@@ -251,6 +271,7 @@ def __call__(self, sr_a: pd.Series, sr_b: pd.Series, dataset_name: ty.Optional[s
251271
dataset_rows=len(sr_a),
252272
category="TwoColumnMetric",
253273
dataset_cols=1,
274+
session=session,
254275
)
255276

256277
return value

src/insight/metrics/metrics_usage.py

Lines changed: 58 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,23 @@ def __init__(self, metric: OneColumnMetric):
1717
self.name = f"{metric.name}_map"
1818

1919
def _compute_result(self, df: pd.DataFrame) -> pd.DataFrame:
20-
columns_map = {
21-
col: self._metric(df[col], dataset_name=df.attrs.get("name", "") + f"_{col}")
22-
for col in df.columns
23-
}
24-
result = pd.DataFrame(data=columns_map.values(), index=df.columns, columns=[self.name])
20+
dataset_name = df.attrs.get("name", "")
21+
if self._session is not None:
22+
with self._session as session:
23+
columns_map = {
24+
col: self._metric(
25+
df[col], dataset_name=f"{dataset_name}_{col}", session=session
26+
)
27+
for col in df.columns
28+
}
29+
session.commit()
30+
else:
31+
columns_map = {
32+
col: self._metric(df[col], dataset_name=f"{dataset_name}_{col}", session=None)
33+
for col in df.columns
34+
}
2535

36+
result = pd.DataFrame(data=columns_map.values(), index=df.columns, columns=[self.name])
2637
result.name = self._metric.name
2738
return result
2839

@@ -57,12 +68,24 @@ def _compute_result(self, df: pd.DataFrame) -> pd.DataFrame:
5768
columns = df.columns
5869
matrix = pd.DataFrame(index=columns, columns=columns)
5970

60-
for col_a, col_b in permutations(columns, 2):
61-
matrix[col_a][col_b] = self._metric(
62-
df[col_a],
63-
df[col_b],
64-
dataset_name=df.attrs.get("name", "") + f"_{col_a}_{col_b}",
65-
)
71+
if self._session is not None:
72+
with self._session as session:
73+
for col_a, col_b in permutations(columns, 2):
74+
matrix[col_a][col_b] = self._metric(
75+
df[col_a],
76+
df[col_b],
77+
dataset_name=df.attrs.get("name", "") + f"_{col_a}_{col_b}",
78+
session=session,
79+
)
80+
session.commit()
81+
else:
82+
for col_a, col_b in permutations(columns, 2):
83+
matrix[col_a][col_b] = self._metric(
84+
df[col_a],
85+
df[col_b],
86+
dataset_name=df.attrs.get("name", "") + f"_{col_a}_{col_b}",
87+
session=None,
88+
)
6689

6790
return pd.DataFrame(matrix.astype(np.float32)) # explicit casting for mypy
6891

@@ -105,16 +128,31 @@ def __init__(self, metric: TwoColumnMetric):
105128
self.name = f"{metric.name}_map"
106129

107130
def _compute_result(self, df_old: pd.DataFrame, df_new: pd.DataFrame) -> pd.DataFrame:
108-
columns_map = {
109-
col: self._metric(
110-
df_old[col],
111-
df_new[col],
112-
dataset_name=df_old.attrs.get("name", "") + f"_{col}",
113-
)
114-
for col in df_old.columns
115-
}
116-
result = pd.DataFrame(data=columns_map.values(), index=df_old.columns, columns=[self.name])
117131

132+
if self._session is not None:
133+
with self._session as session:
134+
columns_map = {
135+
col: self._metric(
136+
df_old[col],
137+
df_new[col],
138+
dataset_name=df_old.attrs.get("name", "") + f"_{col}",
139+
session=session,
140+
)
141+
for col in df_old.columns
142+
}
143+
session.commit()
144+
else:
145+
columns_map = {
146+
col: self._metric(
147+
df_old[col],
148+
df_new[col],
149+
dataset_name=df_old.attrs.get("name", "") + f"_{col}",
150+
session=None,
151+
)
152+
for col in df_old.columns
153+
}
154+
155+
result = pd.DataFrame(data=columns_map.values(), index=df_old.columns, columns=[self.name])
118156
result.name = self._metric.name
119157
return result
120158

tests/test_database/test_db.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,18 @@ def tables(engine):
2323
Base.metadata.drop_all(engine)
2424

2525

26-
@pytest.fixture
27-
def db_session(engine, tables):
26+
@pytest.fixture(scope="function")
27+
def clear_utils_cache():
28+
yield utils
29+
utils.get_df_id.cache_clear()
30+
utils.get_metric_id.cache_clear()
31+
utils.get_version_id.cache_clear()
32+
utils._DATASET_ID_MAPPING = None
33+
utils._METRIC_ID_MAPPING = None
34+
35+
36+
@pytest.fixture(scope="function")
37+
def db_session(engine, tables, clear_utils_cache):
2838
connection = engine.connect()
2939
transaction = connection.begin()
3040
session = Session(bind=connection, expire_on_commit=False)
@@ -35,7 +45,6 @@ def db_session(engine, tables):
3545
base.TwoColumnMetric._session = session
3646
base.DataFrameMetric._session = session
3747
base.TwoDataFrameMetric._session = session
38-
3948
yield session
4049

4150
# Return class variables to their original state.

0 commit comments

Comments (0)