From 3adbbbcd50a390f924ed9a28b9de86bf74216b57 Mon Sep 17 00:00:00 2001 From: Quintin Date: Fri, 28 Jul 2023 15:41:18 -0700 Subject: [PATCH 01/15] Group observations by network before calling insert --- crmprtd/process.py | 47 +++++++++++++++++++++++++++++++++++++--------- 1 file changed, 38 insertions(+), 9 deletions(-) diff --git a/crmprtd/process.py b/crmprtd/process.py index ea105356..f194c6d5 100644 --- a/crmprtd/process.py +++ b/crmprtd/process.py @@ -10,6 +10,8 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker +from pycds import Obs, Network, Variable + from crmprtd.constants import InsertStrategy from crmprtd.align import align from crmprtd.insert import insert @@ -109,6 +111,26 @@ def get_normalization_module(network): return import_module(f"crmprtd.networks.{network}.normalize") +def obs_by_network(observations, sesh): + + obs_by_network_dict = {} + for obs in observations: + Ob_var = sesh.query(Variable).filter_by(id=obs.vars_id).first() + if Ob_var: + network_name = Ob_var.network.name + if network_name not in obs_by_network_dict: + obs_by_network_dict[network_name] = [] + obs_by_network_dict[network_name].append(obs) + + # obs_by_network_dict = {} + # for obs in observations: + # network_name = obs.variable.network.name + # if network_name not in obs_by_network_dict: + # obs_by_network_dict[network_name] = [] + # obs_by_network_dict[network_name].append(obs) + return obs_by_network_dict + + def process( connection_string, sample_size, @@ -196,16 +218,23 @@ def process( log.info(obs) return + obs_by_network_dict = obs_by_network(observations, sesh) + log.info("Insert: start") - results = insert( - sesh, - observations, - strategy=insert_strategy, - bulk_chunk_size=bulk_chunk_size, - sample_size=sample_size, - ) - log.info("Insert: done") - log.info("Data insertion results", extra={"results": results, "network": network}) + + for network_key in obs_by_network_dict: + + results = insert( + sesh, + obs_by_network_dict[network_key], + strategy=insert_strategy, + bulk_chunk_size=bulk_chunk_size, + sample_size=sample_size, + ) + log.info("Insert: done") + log.info( + "Data insertion results", extra={"results": results, "network": network_key} + ) # Note: this function was buried in crmprtd.__init__.py but is From 2836f1fae4532b1bc246c518cbf2b25f3c6c8b41 Mon Sep 17 00:00:00 2001 From: Quintin Date: Fri, 28 Jul 2023 17:39:54 -0700 Subject: [PATCH 02/15] Pass black format test --- crmprtd/process.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crmprtd/process.py b/crmprtd/process.py index f194c6d5..f65b504f 100644 --- a/crmprtd/process.py +++ b/crmprtd/process.py @@ -115,7 +115,8 @@ def obs_by_network(observations, sesh): obs_by_network_dict = {} for obs in observations: - Ob_var = sesh.query(Variable).filter_by(id=obs.vars_id).first() + var_id = obs.vars_id + Ob_var = sesh.query(Variable).filter_by(id=var_id).first() if Ob_var: network_name = Ob_var.network.name if network_name not in obs_by_network_dict: From bf146435b83880b6cc96f76e31917f8fd8904b9c Mon Sep 17 00:00:00 2001 From: Quintin Date: Mon, 31 Jul 2023 10:10:01 -0700 Subject: [PATCH 03/15] Add cached function to retrieve network from var_id --- crmprtd/process.py | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/crmprtd/process.py b/crmprtd/process.py index f65b504f..4ed7ac00 100644 --- a/crmprtd/process.py +++ b/crmprtd/process.py @@ -11,9 +11,9 @@ from sqlalchemy.orm import sessionmaker from pycds import Obs, Network, Variable - +import functools from crmprtd.constants import InsertStrategy -from crmprtd.align import align +from crmprtd.align import align, cached_function from crmprtd.insert import insert from crmprtd.download_utils import verify_date from crmprtd.infer import infer @@ -111,24 +111,19 @@ def get_normalization_module(network): return import_module(f"crmprtd.networks.{network}.normalize") -def obs_by_network(observations, sesh): +@functools.lru_cache(maxsize=None) +def get_network(sesh, var_id): + Obs_var = sesh.query(Variable).filter_by(id=var_id).first() + return Obs_var.network.name + +def obs_by_network(observations, sesh): obs_by_network_dict = {} for obs in observations: - var_id = obs.vars_id - Ob_var = sesh.query(Variable).filter_by(id=var_id).first() - if Ob_var: - network_name = Ob_var.network.name - if network_name not in obs_by_network_dict: - obs_by_network_dict[network_name] = [] - obs_by_network_dict[network_name].append(obs) - - # obs_by_network_dict = {} - # for obs in observations: - # network_name = obs.variable.network.name - # if network_name not in obs_by_network_dict: - # obs_by_network_dict[network_name] = [] - # obs_by_network_dict[network_name].append(obs) + network_name = get_network(sesh, obs.vars_id) + if network_name not in obs_by_network_dict: + obs_by_network_dict[network_name] = [] + obs_by_network_dict[network_name].append(obs) return obs_by_network_dict @@ -224,7 +219,6 @@ def process( log.info("Insert: start") for network_key in obs_by_network_dict: - results = insert( sesh, obs_by_network_dict[network_key], From 11566696e2d0da5481a3dea5f925385c6c87d544 Mon Sep 17 00:00:00 2001 From: Quintin Date: Fri, 4 Aug 2023 15:41:03 -0700 Subject: [PATCH 04/15] Move groupby network into insert.py --- crmprtd/insert.py | 87 ++++++++++++++++++++++++++++++++++---------- crmprtd/process.py | 44 +++++----------------- tests/test_insert.py | 27 ++++++++++++++ 3 files changed, 105 insertions(+), 53 deletions(-) diff --git a/crmprtd/insert.py b/crmprtd/insert.py index 54b24e07..731f8e2f 100644 --- a/crmprtd/insert.py +++ b/crmprtd/insert.py @@ -11,14 +11,14 @@ import logging import time import random - +from itertools import groupby from sqlalchemy import and_ from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.exc import IntegrityError, DBAPIError from crmprtd.constants import InsertStrategy from crmprtd.db_exceptions import InsertionError -from pycds import Obs +from pycds import Network, Obs, Variable log = logging.getLogger(__name__) @@ -54,6 +54,22 @@ def max_power_of_two(num): return 2 ** floor(mathlog(num, 2)) +def get_network_name(sesh, obs): + Obs_var = sesh.query(Variable).filter_by(id=obs.vars_id).first() + return Obs_var.network.name + + +def obs_by_network(observations, sesh): + observations.sort(key=lambda obs: get_network_name(sesh, obs)) + obs_by_network_dict = { + network_name: list(obs_group) + for network_name, obs_group in groupby( + observations, key=lambda obs: get_network_name(sesh, obs) + ) + } + return obs_by_network_dict + + def get_bisection_chunk_sizes(remainder): chunk_list = [] while remainder != 0: @@ -130,7 +146,11 @@ def insert_single_obs(sesh, obs): except IntegrityError as e: log.debug( "Failure, observation already exists", - extra={"observation": obs, "exception": e}, + extra={ + "observation": obs, + "exception": e, + "network": get_network_name(sesh, obs), + }, ) db_metrics = DBMetrics(0, 1, 0) except InsertionError as e: @@ -138,11 +158,18 @@ def insert_single_obs(sesh, obs): # SQLAlchemy unless something very tricky is going on. Why is this here? log.warning( "Failure occured during insertion", - extra={"observation": obs, "exception": e}, + extra={ + "observation": obs, + "exception": e, + "network": get_network_name(sesh, obs), + }, ) db_metrics = DBMetrics(0, 0, 1) else: - log.info("Successfully inserted observations: 1") + log.info( + "Successfully inserted observations: 1", + extra={"network": get_network_name(sesh, obs)}, + ) db_metrics = DBMetrics(1, 0, 0) sesh.commit() return db_metrics @@ -187,7 +214,10 @@ def bisect_insert_strategy(sesh, observations): but in the optimal case it reduces the transactions to a constant 1. """ - log.debug("Begin mass observation insertion", extra={"num_obs": len(observations)}) + log.debug( + "Begin mass observation insertion", + extra={"num_obs": len(observations)}, + ) # Base cases if len(observations) < 1: @@ -198,7 +228,10 @@ def bisect_insert_strategy(sesh, observations): else: try: with sesh.begin_nested(): - log.debug("New SAVEPOINT", extra={"num_obs": len(observations)}) + log.debug( + "New SAVEPOINT", + extra={"num_obs": len(observations)}, + ) sesh.add_all(observations) except IntegrityError: log.debug("Failed, splitting observations.") @@ -324,20 +357,36 @@ def insert( # in the database. sesh.commit() + obs_by_network_dict = obs_by_network(observations, sesh) + with Timer() as tmr: - if strategy is InsertStrategy.BULK: - dbm = bulk_insert_strategy(sesh, observations, chunk_size=bulk_chunk_size) - elif strategy is InsertStrategy.SINGLE: - dbm = single_insert_strategy(sesh, observations) - elif strategy is InsertStrategy.CHUNK_BISECT: - dbm = chunk_bisect_insert_strategy(sesh, observations) - elif strategy is InsertStrategy.ADAPTIVE: - if contains_all_duplicates(sesh, observations, sample_size): - dbm = single_insert_strategy(sesh, observations) + dbm = DBMetrics(0, 0, 0) + for network_key in obs_by_network_dict: + if strategy is InsertStrategy.BULK: + dbm += bulk_insert_strategy( + sesh, obs_by_network_dict[network_key], chunk_size=bulk_chunk_size + ) + elif strategy is InsertStrategy.SINGLE: + dbm += single_insert_strategy(sesh, obs_by_network_dict[network_key]) + elif strategy is InsertStrategy.CHUNK_BISECT: + dbm += chunk_bisect_insert_strategy( + sesh, obs_by_network_dict[network_key] + ) + elif strategy is InsertStrategy.ADAPTIVE: + if contains_all_duplicates( + sesh, obs_by_network_dict[network_key], sample_size + ): + dbm += single_insert_strategy( + sesh, obs_by_network_dict[network_key] + ) + else: + dbm += chunk_bisect_insert_strategy( + sesh, obs_by_network_dict[network_key] + ) else: - dbm = chunk_bisect_insert_strategy(sesh, observations) - else: - raise ValueError(f"Insert strategy has an unrecognized value: {strategy}") + raise ValueError( + f"Insert strategy has an unrecognized value: {strategy}" + ) log.info("Data insertion complete") return { diff --git a/crmprtd/process.py b/crmprtd/process.py index 4ed7ac00..ea105356 100644 --- a/crmprtd/process.py +++ b/crmprtd/process.py @@ -10,10 +10,8 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker -from pycds import Obs, Network, Variable -import functools from crmprtd.constants import InsertStrategy -from crmprtd.align import align, cached_function +from crmprtd.align import align from crmprtd.insert import insert from crmprtd.download_utils import verify_date from crmprtd.infer import infer @@ -111,22 +109,6 @@ def get_normalization_module(network): return import_module(f"crmprtd.networks.{network}.normalize") -@functools.lru_cache(maxsize=None) -def get_network(sesh, var_id): - Obs_var = sesh.query(Variable).filter_by(id=var_id).first() - return Obs_var.network.name - - -def obs_by_network(observations, sesh): - obs_by_network_dict = {} - for obs in observations: - network_name = get_network(sesh, obs.vars_id) - if network_name not in obs_by_network_dict: - obs_by_network_dict[network_name] = [] - obs_by_network_dict[network_name].append(obs) - return obs_by_network_dict - - def process( connection_string, sample_size, @@ -214,22 +196,16 @@ def process( log.info(obs) return - obs_by_network_dict = obs_by_network(observations, sesh) - log.info("Insert: start") - - for network_key in obs_by_network_dict: - results = insert( - sesh, - obs_by_network_dict[network_key], - strategy=insert_strategy, - bulk_chunk_size=bulk_chunk_size, - sample_size=sample_size, - ) - log.info("Insert: done") - log.info( - "Data insertion results", extra={"results": results, "network": network_key} - ) + results = insert( + sesh, + observations, + strategy=insert_strategy, + bulk_chunk_size=bulk_chunk_size, + sample_size=sample_size, + ) + log.info("Insert: done") + log.info("Data insertion results", extra={"results": results, "network": network}) # Note: this function was buried in crmprtd.__init__.py but is diff --git a/tests/test_insert.py b/tests/test_insert.py index 397b08c7..a089a83d 100644 --- a/tests/test_insert.py +++ b/tests/test_insert.py @@ -3,6 +3,7 @@ import pytest import pytz +import logging from crmprtd.more_itertools import cycles from pycds import History, Obs @@ -17,6 +18,7 @@ fixed_length_chunks, bulk_insert_strategy, Timer, + obs_by_network, ) @@ -195,3 +197,28 @@ def test_single_insert_obs_not_unique(test_session): ] dbm = single_insert_strategy(test_session, ob) assert dbm.skips == 1 + + +# def test_mult_networks(crmp_session, caplog): +# caplog.set_level(logging.DEBUG, "crmprtd") +# observations = [] + +# obs = Obs( +# history_id=20, +# vars_id=2, +# time=datetime(2012, 9, 24, 6, tzinfo=pytz.utc), +# datum=i, +# ) +# observations.append(obs) +# assert networks_logged(obs, crmp_session, caplog) + + +def networks_logged(observations, test_session, caplog): + obs_by_network_dict = obs_by_network(observations, test_session) + + networks = [] + for record in caplog.records: + if "network" in record.__dict__: + networks.append(getattr(record, "network", {})) + + return obs_by_network_dict.keys() == networks From c9b75eb568810b5ce42377c546e38be8e59ada55 Mon Sep 17 00:00:00 2001 From: Quintin Date: Mon, 7 Aug 2023 15:57:42 -0700 Subject: [PATCH 05/15] Add tests for insert from multiple networks --- crmprtd/insert.py | 16 ++++++++-- tests/test_insert.py | 72 +++++++++++++++++++++++++++++--------------- 2 files changed, 61 insertions(+), 27 deletions(-) diff --git a/crmprtd/insert.py b/crmprtd/insert.py index 731f8e2f..39f942f5 100644 --- a/crmprtd/insert.py +++ b/crmprtd/insert.py @@ -18,7 +18,7 @@ from crmprtd.constants import InsertStrategy from crmprtd.db_exceptions import InsertionError -from pycds import Network, Obs, Variable +from pycds import Obs, Variable log = logging.getLogger(__name__) @@ -56,6 +56,7 @@ def max_power_of_two(num): def get_network_name(sesh, obs): Obs_var = sesh.query(Variable).filter_by(id=obs.vars_id).first() + # print(Obs_var.network.name) return Obs_var.network.name @@ -327,9 +328,18 @@ def bulk_insert_strategy(sesh, observations, chunk_size=1000): dbm += chunk_dbm log.info( f"Bulk insert progress: " - f"{dbm.successes} inserted, {dbm.skips} skipped, {dbm.failures} failed" + f"{dbm.successes} inserted, {dbm.skips} skipped, {dbm.failures} failed", + extra={"network": get_network_name(sesh, chunk[0])}, + ) + if len(observations) > 0: + log.info( + f"Successfully inserted observations: {dbm.successes}", + extra={"network": get_network_name(sesh, observations[0])}, + ) + else: + log.info( + f"Successfully inserted observations: {dbm.successes}", ) - log.info(f"Successfully inserted observations: {dbm.successes}") return dbm diff --git a/tests/test_insert.py b/tests/test_insert.py index a089a83d..038c8f7d 100644 --- a/tests/test_insert.py +++ b/tests/test_insert.py @@ -6,8 +6,9 @@ import logging from crmprtd.more_itertools import cycles -from pycds import History, Obs +from pycds import History, Obs, Variable, Station from crmprtd.insert import ( + insert, bisect_insert_strategy, split, bisection_chunks, @@ -199,26 +200,49 @@ def test_single_insert_obs_not_unique(test_session): assert dbm.skips == 1 -# def test_mult_networks(crmp_session, caplog): -# caplog.set_level(logging.DEBUG, "crmprtd") -# observations = [] - -# obs = Obs( -# history_id=20, -# vars_id=2, -# time=datetime(2012, 9, 24, 6, tzinfo=pytz.utc), -# datum=i, -# ) -# observations.append(obs) -# assert networks_logged(obs, crmp_session, caplog) - - -def networks_logged(observations, test_session, caplog): - obs_by_network_dict = obs_by_network(observations, test_session) - - networks = [] - for record in caplog.records: - if "network" in record.__dict__: - networks.append(getattr(record, "network", {})) - - return obs_by_network_dict.keys() == networks +def test_mult_networks(test_session, caplog): + caplog.set_level(logging.DEBUG, "crmprtd") + + moti = "Brandywine" + ec = "Sechelt" + wmb = "FIVE MILE" + stations = [moti, wmb, ec] + obs = [] + with test_session: + for station in stations: + hist = ( + test_session.query(History) + .filter(History.station_name == station) + .first() + ) + var = hist.station.network.variables[0] + # print(hist.id) + # print(var.network.name) + # print(var.id) + observation = Obs( + history_id=hist.id, + datum=10, + vars_id=var.id, + time=datetime(2012, 9, 24, 6, tzinfo=pytz.utc), + ) + obs.append(observation) + # print(obs) + results = insert(test_session, obs) + assert networks_logged(obs, test_session, caplog) + + +def networks_logged(obs, test_session, caplog): + obs_by_network_dict = obs_by_network(obs, test_session) + + # networks = [] + # for record in caplog.records: + # if "network" in record.__dict__: + # networks.append(getattr(record, "network", {})) + + networks = { + getattr(record, "network") + for record in caplog.records + if getattr(record, "network", None) is not None + } + assert len(networks) == len(obs_by_network_dict) + return obs_by_network_dict.keys() == set(networks) From 6f3bbcf8490efdb17367ee1a162285a3f925361f Mon Sep 17 00:00:00 2001 From: Quintin Date: Tue, 8 Aug 2023 08:43:27 -0700 Subject: [PATCH 06/15] Remove comments and print statements --- crmprtd/insert.py | 1 - tests/test_insert.py | 13 +++---------- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/crmprtd/insert.py b/crmprtd/insert.py index 39f942f5..1c06bc3e 100644 --- a/crmprtd/insert.py +++ b/crmprtd/insert.py @@ -56,7 +56,6 @@ def max_power_of_two(num): def get_network_name(sesh, obs): Obs_var = sesh.query(Variable).filter_by(id=obs.vars_id).first() - # print(Obs_var.network.name) return Obs_var.network.name diff --git a/tests/test_insert.py b/tests/test_insert.py index 038c8f7d..26d0c608 100644 --- a/tests/test_insert.py +++ b/tests/test_insert.py @@ -204,9 +204,10 @@ def test_mult_networks(test_session, caplog): caplog.set_level(logging.DEBUG, "crmprtd") moti = "Brandywine" - ec = "Sechelt" wmb = "FIVE MILE" + ec = "Sechelt" stations = [moti, wmb, ec] + obs = [] with test_session: for station in stations: @@ -216,9 +217,6 @@ def test_mult_networks(test_session, caplog): .first() ) var = hist.station.network.variables[0] - # print(hist.id) - # print(var.network.name) - # print(var.id) observation = Obs( history_id=hist.id, datum=10, @@ -226,7 +224,7 @@ def test_mult_networks(test_session, caplog): time=datetime(2012, 9, 24, 6, tzinfo=pytz.utc), ) obs.append(observation) - # print(obs) + results = insert(test_session, obs) assert networks_logged(obs, test_session, caplog) @@ -234,11 +232,6 @@ def test_mult_networks(test_session, caplog): def networks_logged(obs, test_session, caplog): obs_by_network_dict = obs_by_network(obs, test_session) - # networks = [] - # for record in caplog.records: - # if "network" in record.__dict__: - # networks.append(getattr(record, "network", {})) - networks = { getattr(record, "network") for record in caplog.records From 19510bd9b605e4bbd9b7d8b58b2c3a186ebff034 Mon Sep 17 00:00:00 2001 From: Quintin Date: Tue, 8 Aug 2023 09:26:20 -0700 Subject: [PATCH 07/15] Remove unused imports, fix capitalization --- crmprtd/insert.py | 4 ++-- tests/test_insert.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/crmprtd/insert.py b/crmprtd/insert.py index 1c06bc3e..2abf4599 100644 --- a/crmprtd/insert.py +++ b/crmprtd/insert.py @@ -55,8 +55,8 @@ def max_power_of_two(num): def get_network_name(sesh, obs): - Obs_var = sesh.query(Variable).filter_by(id=obs.vars_id).first() - return Obs_var.network.name + obs_var = sesh.query(Variable).filter_by(id=obs.vars_id).first() + return obs_var.network.name def obs_by_network(observations, sesh): diff --git a/tests/test_insert.py b/tests/test_insert.py index 26d0c608..25aa5f70 100644 --- a/tests/test_insert.py +++ b/tests/test_insert.py @@ -6,7 +6,7 @@ import logging from crmprtd.more_itertools import cycles -from pycds import History, Obs, Variable, Station +from pycds import History, Obs from crmprtd.insert import ( insert, bisect_insert_strategy, From 0600e89c2800174e7dfba23b5f973c81898ce33e Mon Sep 17 00:00:00 2001 From: Quintin Date: Wed, 9 Aug 2023 10:00:10 -0700 Subject: [PATCH 08/15] Simplify network tests and include cached_function decorator --- crmprtd/insert.py | 42 +++++++++++++++++------------------------- tests/test_insert.py | 15 +++++++-------- 2 files changed, 24 insertions(+), 33 deletions(-) diff --git a/crmprtd/insert.py b/crmprtd/insert.py index 2abf4599..359bf382 100644 --- a/crmprtd/insert.py +++ b/crmprtd/insert.py @@ -20,6 +20,7 @@ from crmprtd.db_exceptions import InsertionError from pycds import Obs, Variable +from crmprtd.align import cached_function log = logging.getLogger(__name__) @@ -54,17 +55,18 @@ def max_power_of_two(num): return 2 ** floor(mathlog(num, 2)) +@cached_function(["name"]) def get_network_name(sesh, obs): obs_var = sesh.query(Variable).filter_by(id=obs.vars_id).first() - return obs_var.network.name + return obs_var.network def obs_by_network(observations, sesh): - observations.sort(key=lambda obs: get_network_name(sesh, obs)) + obs_sorted = sorted(observations, key=lambda obs: get_network_name(sesh, obs).name) obs_by_network_dict = { network_name: list(obs_group) for network_name, obs_group in groupby( - observations, key=lambda obs: get_network_name(sesh, obs) + obs_sorted, key=lambda obs: get_network_name(sesh, obs).name ) } return obs_by_network_dict @@ -149,7 +151,7 @@ def insert_single_obs(sesh, obs): extra={ "observation": obs, "exception": e, - "network": get_network_name(sesh, obs), + "network": get_network_name(sesh, obs).name, }, ) db_metrics = DBMetrics(0, 1, 0) @@ -161,14 +163,14 @@ def insert_single_obs(sesh, obs): extra={ "observation": obs, "exception": e, - "network": get_network_name(sesh, obs), + "network": get_network_name(sesh, obs).name, }, ) db_metrics = DBMetrics(0, 0, 1) else: log.info( "Successfully inserted observations: 1", - extra={"network": get_network_name(sesh, obs)}, + extra={"network": get_network_name(sesh, obs).name}, ) db_metrics = DBMetrics(1, 0, 0) sesh.commit() @@ -328,12 +330,12 @@ def bulk_insert_strategy(sesh, observations, chunk_size=1000): log.info( f"Bulk insert progress: " f"{dbm.successes} inserted, {dbm.skips} skipped, {dbm.failures} failed", - extra={"network": get_network_name(sesh, chunk[0])}, + extra={"network": get_network_name(sesh, chunk[0]).name}, ) if len(observations) > 0: log.info( f"Successfully inserted observations: {dbm.successes}", - extra={"network": get_network_name(sesh, observations[0])}, + extra={"network": get_network_name(sesh, observations[0]).name}, ) else: log.info( @@ -370,28 +372,18 @@ def insert( with Timer() as tmr: dbm = DBMetrics(0, 0, 0) - for network_key in obs_by_network_dict: + for obs in obs_by_network_dict.values(): if strategy is InsertStrategy.BULK: - dbm += bulk_insert_strategy( - sesh, obs_by_network_dict[network_key], chunk_size=bulk_chunk_size - ) + dbm += bulk_insert_strategy(sesh, obs, chunk_size=bulk_chunk_size) elif strategy is InsertStrategy.SINGLE: - dbm += single_insert_strategy(sesh, obs_by_network_dict[network_key]) + dbm += single_insert_strategy(sesh, obs) elif strategy is InsertStrategy.CHUNK_BISECT: - dbm += chunk_bisect_insert_strategy( - sesh, obs_by_network_dict[network_key] - ) + dbm += chunk_bisect_insert_strategy(sesh, obs) elif strategy is InsertStrategy.ADAPTIVE: - if contains_all_duplicates( - sesh, obs_by_network_dict[network_key], sample_size - ): - dbm += single_insert_strategy( - sesh, obs_by_network_dict[network_key] - ) + if contains_all_duplicates(sesh, obs, sample_size): + dbm += single_insert_strategy(sesh, obs) else: - dbm += chunk_bisect_insert_strategy( - sesh, obs_by_network_dict[network_key] - ) + dbm += chunk_bisect_insert_strategy(sesh, obs) else: raise ValueError( f"Insert strategy has an unrecognized value: {strategy}" diff --git a/tests/test_insert.py b/tests/test_insert.py index 25aa5f70..7df65af8 100644 --- a/tests/test_insert.py +++ b/tests/test_insert.py @@ -221,21 +221,20 @@ def test_mult_networks(test_session, caplog): history_id=hist.id, datum=10, vars_id=var.id, - time=datetime(2012, 9, 24, 6, tzinfo=pytz.utc), + time=datetime(2012, 9, 25, 6, tzinfo=pytz.utc), ) obs.append(observation) + expected_networks = {"MoTIe", "FLNRO-WMB", "EC_raw"} results = insert(test_session, obs) - assert networks_logged(obs, test_session, caplog) - + assert results["successes"] == 3 + assert networks_are_logged(expected_networks, caplog) -def networks_logged(obs, test_session, caplog): - obs_by_network_dict = obs_by_network(obs, test_session) - networks = { +def networks_are_logged(networks, caplog): + rec_networks = { getattr(record, "network") for record in caplog.records if getattr(record, "network", None) is not None } - assert len(networks) == len(obs_by_network_dict) - return obs_by_network_dict.keys() == set(networks) + return networks == rec_networks From b6c68849196e9830af2c49fa10174b8f6dc7730a Mon Sep 17 00:00:00 2001 From: Quintin Date: Wed, 9 Aug 2023 12:12:41 -0700 Subject: [PATCH 09/15] move cached_function def up a level --- crmprtd/__init__.py | 40 +++++++++++++++++++++++++++++++++++++++- crmprtd/align.py | 39 +-------------------------------------- crmprtd/insert.py | 29 ++++++++++++++++------------- 3 files changed, 56 insertions(+), 52 deletions(-) diff --git a/crmprtd/__init__.py b/crmprtd/__init__.py index 3b0d874e..a3507443 100644 --- a/crmprtd/__init__.py +++ b/crmprtd/__init__.py @@ -51,7 +51,7 @@ import logging import logging.config import os.path - +from types import SimpleNamespace import pkg_resources from pkg_resources import resource_stream from collections import namedtuple @@ -234,3 +234,41 @@ def setup_logging(log_conf, log_filename, error_email, log_level, name): def subset_dict(a_dict, keys_wanted): return {key: a_dict[key] for key in keys_wanted if key in a_dict} + + +def cached_function(attrs): + """A decorator factory that can be used to cache database results + + Neither database sessions (i.e. the sesh parameter of each wrapped + function) nor SQLAlchemy mapped objects (the results of queries) are + cachable or reusable. Therefore one cannot memoize database query + functions using builtin things like the lrucache. + + This wrapper works, by a) assuming that the wrapped function's first + argument is a database session b) assuming that the result of the + query returns a single SQLAlchemy object (e.g. a History instance), + and c) accepting as a parameter a list of attributes to retrieve and + store in the cache result. + + args (except sesh) and kwargs to the wrapped function are used as + the cache key, and results are the parametrized object attributes. + """ + log = logging.getLogger(__name__) + + def wrapper(f): + cache = {} + + def memoize(sesh, *args, **kwargs): + nonlocal cache + key = (args) + tuple(kwargs.items()) + if key not in cache: + obj = f(sesh, *args, **kwargs) + log.debug(f"Cache miss: {f.__name__} {key} -> {repr(obj)}") + cache[key] = obj and SimpleNamespace( + **{attr: getattr(obj, attr) for attr in attrs} + ) + return cache[key] + + return memoize + + return wrapper diff --git a/crmprtd/align.py b/crmprtd/align.py index 9abadf13..cc2690ad 100644 --- a/crmprtd/align.py +++ b/crmprtd/align.py @@ -9,7 +9,6 @@ """ import logging -from types import SimpleNamespace from sqlalchemy import and_, cast from geoalchemy2.functions import ( @@ -25,6 +24,7 @@ from pycds import Obs, History, Network, Variable, Station from crmprtd.db_exceptions import InsertionError +from . import cached_function log = logging.getLogger(__name__) ureg = UnitRegistry() @@ -41,43 +41,6 @@ ureg.define(def_) -def cached_function(attrs): - """A decorator factory that can be used to cache database results - - Neither database sessions (i.e. the sesh parameter of each wrapped - function) nor SQLAlchemy mapped objects (the results of queries) are - cachable or reusable. Therefore one cannot memoize database query - functions using builtin things like the lrucache. - - This wrapper works, by a) assuming that the wrapped function's first - argument is a database session b) assuming that the result of the - query returns a single SQLAlchemy object (e.g. a History instance), - and c) accepting as a parameter a list of attributes to retrieve and - store in the cache result. - - args (except sesh) and kwargs to the wrapped function are used as - the cache key, and results are the parametrized object attributes. - """ - - def wrapper(f): - cache = {} - - def memoize(sesh, *args, **kwargs): - nonlocal cache - key = (args) + tuple(kwargs.items()) - if key not in cache: - obj = f(sesh, *args, **kwargs) - log.debug(f"Cache miss: {f.__name__} {key} -> {repr(obj)}") - cache[key] = obj and SimpleNamespace( - **{attr: getattr(obj, attr) for attr in attrs} - ) - return cache[key] - - return memoize - - return wrapper - - def histories_within_threshold(sesh, network_name, lon, lat, threshold): """ Find existing histories associated with the given network and within a threshold diff --git a/crmprtd/insert.py b/crmprtd/insert.py index 359bf382..c7a7d5ce 100644 --- a/crmprtd/insert.py +++ b/crmprtd/insert.py @@ -20,7 +20,7 @@ from crmprtd.db_exceptions import InsertionError from pycds import Obs, Variable -from crmprtd.align import cached_function +from . import cached_function log = logging.getLogger(__name__) @@ -55,21 +55,24 @@ def max_power_of_two(num): return 2 ** floor(mathlog(num, 2)) -@cached_function(["name"]) def get_network_name(sesh, obs): - obs_var = sesh.query(Variable).filter_by(id=obs.vars_id).first() - return obs_var.network + @cached_function(["name"]) + def network_query(sesh, obs): + obs_var = sesh.query(Variable).filter_by(id=obs.vars_id).first() + return obs_var.network + + return network_query(sesh, obs).name def obs_by_network(observations, sesh): - obs_sorted = sorted(observations, key=lambda obs: get_network_name(sesh, obs).name) - obs_by_network_dict = { + obs_sorted = sorted(observations, key=lambda obs: get_network_name(sesh, obs)) + + return { network_name: list(obs_group) for network_name, obs_group in groupby( - obs_sorted, key=lambda obs: get_network_name(sesh, obs).name + obs_sorted, key=lambda obs: get_network_name(sesh, obs) ) } - return obs_by_network_dict def get_bisection_chunk_sizes(remainder): @@ -151,7 +154,7 @@ def insert_single_obs(sesh, obs): extra={ "observation": obs, "exception": e, - "network": get_network_name(sesh, obs).name, + "network": get_network_name(sesh, obs), }, ) db_metrics = DBMetrics(0, 1, 0) @@ -163,14 +166,14 @@ def insert_single_obs(sesh, obs): extra={ "observation": obs, "exception": e, - "network": get_network_name(sesh, obs).name, + "network": get_network_name(sesh, obs), }, ) db_metrics = DBMetrics(0, 0, 1) else: log.info( "Successfully inserted observations: 1", - extra={"network": get_network_name(sesh, obs).name}, + extra={"network": get_network_name(sesh, obs)}, ) db_metrics = DBMetrics(1, 0, 0) sesh.commit() @@ -330,12 +333,12 @@ def bulk_insert_strategy(sesh, observations, chunk_size=1000): log.info( f"Bulk insert progress: " f"{dbm.successes} inserted, {dbm.skips} skipped, {dbm.failures} failed", - extra={"network": get_network_name(sesh, chunk[0]).name}, + extra={"network": get_network_name(sesh, chunk[0])}, ) if len(observations) > 0: log.info( f"Successfully inserted observations: {dbm.successes}", - extra={"network": get_network_name(sesh, observations[0]).name}, + extra={"network": get_network_name(sesh, observations[0])}, ) else: log.info( From 6c668a68f55779e3eeb4c0e5b95d3f416e4b0482 Mon Sep 17 00:00:00 2001 From: Quintin Date: Wed, 9 Aug 2023 14:23:01 -0700 Subject: [PATCH 10/15] Move log helpers from __init__.py to seperate module --- crmprtd/__init__.py | 39 --------------------------------------- crmprtd/align.py | 2 +- crmprtd/insert.py | 2 +- crmprtd/log_helpers.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 44 insertions(+), 41 deletions(-) create mode 100644 crmprtd/log_helpers.py diff --git a/crmprtd/__init__.py b/crmprtd/__init__.py index a3507443..65666b8f 100644 --- a/crmprtd/__init__.py +++ b/crmprtd/__init__.py @@ -51,7 +51,6 @@ import logging import logging.config import os.path -from types import SimpleNamespace import pkg_resources from pkg_resources import resource_stream from collections import namedtuple @@ -234,41 +233,3 @@ def setup_logging(log_conf, log_filename, error_email, log_level, name): def subset_dict(a_dict, keys_wanted): return {key: a_dict[key] for key in keys_wanted if key in a_dict} - - -def cached_function(attrs): - """A decorator factory that can be used to cache database results - - Neither database sessions (i.e. the sesh parameter of each wrapped - function) nor SQLAlchemy mapped objects (the results of queries) are - cachable or reusable. Therefore one cannot memoize database query - functions using builtin things like the lrucache. - - This wrapper works, by a) assuming that the wrapped function's first - argument is a database session b) assuming that the result of the - query returns a single SQLAlchemy object (e.g. a History instance), - and c) accepting as a parameter a list of attributes to retrieve and - store in the cache result. - - args (except sesh) and kwargs to the wrapped function are used as - the cache key, and results are the parametrized object attributes. - """ - log = logging.getLogger(__name__) - - def wrapper(f): - cache = {} - - def memoize(sesh, *args, **kwargs): - nonlocal cache - key = (args) + tuple(kwargs.items()) - if key not in cache: - obj = f(sesh, *args, **kwargs) - log.debug(f"Cache miss: {f.__name__} {key} -> {repr(obj)}") - cache[key] = obj and SimpleNamespace( - **{attr: getattr(obj, attr) for attr in attrs} - ) - return cache[key] - - return memoize - - return wrapper diff --git a/crmprtd/align.py b/crmprtd/align.py index cc2690ad..13672a4a 100644 --- a/crmprtd/align.py +++ b/crmprtd/align.py @@ -24,7 +24,7 @@ from pycds import Obs, History, Network, Variable, Station from crmprtd.db_exceptions import InsertionError -from . import cached_function +from crmprtd.log_helpers import cached_function log = logging.getLogger(__name__) ureg = UnitRegistry() diff --git a/crmprtd/insert.py b/crmprtd/insert.py index c7a7d5ce..525eebc5 100644 --- a/crmprtd/insert.py +++ b/crmprtd/insert.py @@ -20,7 +20,7 @@ from crmprtd.db_exceptions import InsertionError from pycds import Obs, Variable -from . import cached_function +from crmprtd.log_helpers import cached_function log = logging.getLogger(__name__) diff --git a/crmprtd/log_helpers.py b/crmprtd/log_helpers.py new file mode 100644 index 00000000..0ea79093 --- /dev/null +++ b/crmprtd/log_helpers.py @@ -0,0 +1,42 @@ +from types import SimpleNamespace + + +def sanitize_connection(sesh): + return sesh.bind.url.render_as_string(hide_password=True) + + +def cached_function(attrs): + """A decorator factory that can be used to cache database results + Neither database sessions (i.e. the sesh parameter of each wrapped + function) nor SQLAlchemy mapped objects (the results of queries) are + cachable or reusable. Therefore one cannot memoize database query + functions using builtin things like the lrucache. + This wrapper works, by a) assuming that the wrapped function's first + argument is a database session b) assuming that the result of the + query returns a single SQLAlchemy object (e.g. a History instance), + and c) accepting as a parameter a list of attributes to retrieve and + store in the cache result. + args (except sesh) and kwargs to the wrapped function are used as + the cache key, and results are the parametrized object attributes. + """ + + def wrapper(f): + cache = {} + + def memoize(sesh, *args, **kwargs): + nonlocal cache + key = (args) + tuple(kwargs.items()) + if key not in cache: + obj = f(sesh, *args, **kwargs) + log.debug( + f"Cache miss: {f.__name__} {key} -> {repr(obj)}", + extra={"database": sanitize_connection(sesh)}, + ) + cache[key] = obj and SimpleNamespace( + **{attr: getattr(obj, attr) for attr in attrs} + ) + return cache[key] + + return memoize + + return wrapper From f6c0b4ffab445c256360742cb1ce18e46cfdd996 Mon Sep 17 00:00:00 2001 From: Quintin Date: Wed, 9 Aug 2023 14:24:26 -0700 Subject: [PATCH 11/15] Move log helpers from __init__.py to seperate module --- crmprtd/log_helpers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/crmprtd/log_helpers.py b/crmprtd/log_helpers.py index 0ea79093..d2918d20 100644 --- a/crmprtd/log_helpers.py +++ b/crmprtd/log_helpers.py @@ -1,3 +1,4 @@ +import logging from types import SimpleNamespace @@ -19,6 +20,7 @@ def cached_function(attrs): args (except sesh) and kwargs to the wrapped function are used as the cache key, and results are the parametrized object attributes. """ + log = logging.getLogger(__name__) def wrapper(f): cache = {} From 97d8307204873f47952fc7ccd915789cdc4aa56f Mon Sep 17 00:00:00 2001 From: Quintin Date: Thu, 10 Aug 2023 08:17:38 -0700 Subject: [PATCH 12/15] Rename log_helpers module to db_helpers --- crmprtd/{log_helpers.py => db_helpers.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename crmprtd/{log_helpers.py => db_helpers.py} (100%) diff --git a/crmprtd/log_helpers.py b/crmprtd/db_helpers.py similarity index 100% rename from crmprtd/log_helpers.py rename to crmprtd/db_helpers.py From 5afa9a984177bd0035a4c4be9e10f565839e7fa6 Mon Sep 17 00:00:00 2001 From: Quintin Date: Thu, 10 Aug 2023 08:18:41 -0700 Subject: [PATCH 13/15] Rename log_helpers module to db_helpers --- crmprtd/align.py | 2 +- crmprtd/insert.py | 19 ++++++++++++------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/crmprtd/align.py b/crmprtd/align.py index 13672a4a..9c36a374 100644 --- a/crmprtd/align.py +++ b/crmprtd/align.py @@ -24,7 +24,7 @@ from pycds import Obs, History, Network, Variable, Station from crmprtd.db_exceptions import InsertionError -from crmprtd.log_helpers import cached_function +from crmprtd.db_helpers import cached_function log = logging.getLogger(__name__) ureg = UnitRegistry() diff --git a/crmprtd/insert.py b/crmprtd/insert.py index 525eebc5..2b3a4938 100644 --- a/crmprtd/insert.py +++ b/crmprtd/insert.py @@ -20,7 +20,7 @@ from crmprtd.db_exceptions import InsertionError from pycds import Obs, Variable -from crmprtd.log_helpers import cached_function +from crmprtd.db_helpers import cached_function log = logging.getLogger(__name__) @@ -391,11 +391,16 @@ def insert( raise ValueError( f"Insert strategy has an unrecognized value: {strategy}" ) + results = { + "successes": dbm.successes, + "skips": dbm.skips, + "failures": dbm.failures, + "insertions_per_sec": round(dbm.successes / tmr.run_time, 2), + } + log.info( + "Data insertion results", + extra={"results": results, "network": get_network_name(sesh, obs)}, + ) log.info("Data insertion complete") - return { - "successes": dbm.successes, - "skips": dbm.skips, - "failures": dbm.failures, - "insertions_per_sec": round(dbm.successes / tmr.run_time, 2), - } + return results From a4e9232fcd60f865494f0896216a6a49f89f4d13 Mon Sep 17 00:00:00 2001 From: Quintin Date: Thu, 10 Aug 2023 10:55:23 -0700 Subject: [PATCH 14/15] Log per-network db_metrics in insert.py --- crmprtd/insert.py | 41 ++++++++++++++++++++++++++++------------- crmprtd/process.py | 2 -- 2 files changed, 28 insertions(+), 15 deletions(-) diff --git a/crmprtd/insert.py b/crmprtd/insert.py index 2b3a4938..a0418f9b 100644 --- a/crmprtd/insert.py +++ b/crmprtd/insert.py @@ -371,11 +371,20 @@ def insert( # in the database. sesh.commit() + results_total = { + "successes": 0, + "skips": 0, + "failures": 0, + "insertions_per_sec": 0, + } + if len(observations) < 1: + return results_total + obs_by_network_dict = obs_by_network(observations, sesh) - with Timer() as tmr: + for obs in obs_by_network_dict.values(): dbm = DBMetrics(0, 0, 0) - for obs in obs_by_network_dict.values(): + with Timer() as tmr: if strategy is InsertStrategy.BULK: dbm += bulk_insert_strategy(sesh, obs, chunk_size=bulk_chunk_size) elif strategy is InsertStrategy.SINGLE: @@ -391,16 +400,22 @@ def insert( raise ValueError( f"Insert strategy has an unrecognized value: {strategy}" ) - results = { - "successes": dbm.successes, - "skips": dbm.skips, - "failures": dbm.failures, - "insertions_per_sec": round(dbm.successes / tmr.run_time, 2), - } - log.info( - "Data insertion results", - extra={"results": results, "network": get_network_name(sesh, obs)}, + results = { + "successes": dbm.successes, + "skips": dbm.skips, + "failures": dbm.failures, + "insertions_per_sec": round(dbm.successes / tmr.run_time, 2), + } + log.info( + "Insert for network: {network}: done".format( + network=get_network_name(sesh, obs[0]) ) + ) + log.info( + "Data insertion results", + extra={"results": results, "network": get_network_name(sesh, obs[0])}, + ) - log.info("Data insertion complete") - return results + for k, v in results.items(): + results_total[k] += v + return results_total diff --git a/crmprtd/process.py b/crmprtd/process.py index ea105356..38c65611 100644 --- a/crmprtd/process.py +++ b/crmprtd/process.py @@ -204,8 +204,6 @@ def process( bulk_chunk_size=bulk_chunk_size, sample_size=sample_size, ) - log.info("Insert: done") - log.info("Data insertion results", extra={"results": results, "network": network}) # Note: this function was buried in crmprtd.__init__.py but is From 15ddf0571707a4012da1783ee07b47998f1733c7 Mon Sep 17 00:00:00 2001 From: Quintin Date: Thu, 10 Aug 2023 11:00:50 -0700 Subject: [PATCH 15/15] Format insert for network log --- crmprtd/insert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crmprtd/insert.py b/crmprtd/insert.py index a0418f9b..3cd29974 100644 --- a/crmprtd/insert.py +++ b/crmprtd/insert.py @@ -407,7 +407,7 @@ def insert( "insertions_per_sec": round(dbm.successes / tmr.run_time, 2), } log.info( - "Insert for network: {network}: done".format( + "Insert for network {network}: done".format( network=get_network_name(sesh, obs[0]) ) )