diff --git a/crmprtd/__init__.py b/crmprtd/__init__.py index 3b0d874e..65666b8f 100644 --- a/crmprtd/__init__.py +++ b/crmprtd/__init__.py @@ -51,7 +51,6 @@ import logging import logging.config import os.path - import pkg_resources from pkg_resources import resource_stream from collections import namedtuple diff --git a/crmprtd/align.py b/crmprtd/align.py index 9abadf13..9c36a374 100644 --- a/crmprtd/align.py +++ b/crmprtd/align.py @@ -9,7 +9,6 @@ """ import logging -from types import SimpleNamespace from sqlalchemy import and_, cast from geoalchemy2.functions import ( @@ -25,6 +24,7 @@ from pycds import Obs, History, Network, Variable, Station from crmprtd.db_exceptions import InsertionError +from crmprtd.db_helpers import cached_function log = logging.getLogger(__name__) ureg = UnitRegistry() @@ -41,43 +41,6 @@ ureg.define(def_) -def cached_function(attrs): - """A decorator factory that can be used to cache database results - - Neither database sessions (i.e. the sesh parameter of each wrapped - function) nor SQLAlchemy mapped objects (the results of queries) are - cachable or reusable. Therefore one cannot memoize database query - functions using builtin things like the lrucache. - - This wrapper works, by a) assuming that the wrapped function's first - argument is a database session b) assuming that the result of the - query returns a single SQLAlchemy object (e.g. a History instance), - and c) accepting as a parameter a list of attributes to retrieve and - store in the cache result. - - args (except sesh) and kwargs to the wrapped function are used as - the cache key, and results are the parametrized object attributes. - """ - - def wrapper(f): - cache = {} - - def memoize(sesh, *args, **kwargs): - nonlocal cache - key = (args) + tuple(kwargs.items()) - if key not in cache: - obj = f(sesh, *args, **kwargs) - log.debug(f"Cache miss: {f.__name__} {key} -> {repr(obj)}") - cache[key] = obj and SimpleNamespace( - **{attr: getattr(obj, attr) for attr in attrs} - ) - return cache[key] - - return memoize - - return wrapper - - def histories_within_threshold(sesh, network_name, lon, lat, threshold): """ Find existing histories associated with the given network and within a threshold diff --git a/crmprtd/db_helpers.py b/crmprtd/db_helpers.py new file mode 100644 index 00000000..d2918d20 --- /dev/null +++ b/crmprtd/db_helpers.py @@ -0,0 +1,44 @@ +import logging +from types import SimpleNamespace + + +def sanitize_connection(sesh): + return sesh.bind.url.render_as_string(hide_password=True) + + +def cached_function(attrs): + """A decorator factory that can be used to cache database results + Neither database sessions (i.e. the sesh parameter of each wrapped + function) nor SQLAlchemy mapped objects (the results of queries) are + cachable or reusable. Therefore one cannot memoize database query + functions using builtin things like the lrucache. + This wrapper works, by a) assuming that the wrapped function's first + argument is a database session b) assuming that the result of the + query returns a single SQLAlchemy object (e.g. a History instance), + and c) accepting as a parameter a list of attributes to retrieve and + store in the cache result. + args (except sesh) and kwargs to the wrapped function are used as + the cache key, and results are the parametrized object attributes. + """ + log = logging.getLogger(__name__) + + def wrapper(f): + cache = {} + + def memoize(sesh, *args, **kwargs): + nonlocal cache + key = (args) + tuple(kwargs.items()) + if key not in cache: + obj = f(sesh, *args, **kwargs) + log.debug( + f"Cache miss: {f.__name__} {key} -> {repr(obj)}", + extra={"database": sanitize_connection(sesh)}, + ) + cache[key] = obj and SimpleNamespace( + **{attr: getattr(obj, attr) for attr in attrs} + ) + return cache[key] + + return memoize + + return wrapper diff --git a/crmprtd/insert.py b/crmprtd/insert.py index 54b24e07..3cd29974 100644 --- a/crmprtd/insert.py +++ b/crmprtd/insert.py @@ -11,15 +11,16 @@ import logging import time import random - +from itertools import groupby from sqlalchemy import and_ from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.exc import IntegrityError, DBAPIError from crmprtd.constants import InsertStrategy from crmprtd.db_exceptions import InsertionError -from pycds import Obs +from pycds import Obs, Variable +from crmprtd.db_helpers import cached_function log = logging.getLogger(__name__) @@ -54,6 +55,26 @@ def max_power_of_two(num): return 2 ** floor(mathlog(num, 2)) +def get_network_name(sesh, obs): + @cached_function(["name"]) + def network_query(sesh, obs): + obs_var = sesh.query(Variable).filter_by(id=obs.vars_id).first() + return obs_var.network + + return network_query(sesh, obs).name + + +def obs_by_network(observations, sesh): + obs_sorted = sorted(observations, key=lambda obs: get_network_name(sesh, obs)) + + return { + network_name: list(obs_group) + for network_name, obs_group in groupby( + obs_sorted, key=lambda obs: get_network_name(sesh, obs) + ) + } + + def get_bisection_chunk_sizes(remainder): chunk_list = [] while remainder != 0: @@ -130,7 +151,11 @@ def insert_single_obs(sesh, obs): except IntegrityError as e: log.debug( "Failure, observation already exists", - extra={"observation": obs, "exception": e}, + extra={ + "observation": obs, + "exception": e, + "network": get_network_name(sesh, obs), + }, ) db_metrics = DBMetrics(0, 1, 0) except InsertionError as e: @@ -138,11 +163,18 @@ def insert_single_obs(sesh, obs): # SQLAlchemy unless something very tricky is going on. Why is this here? log.warning( "Failure occured during insertion", - extra={"observation": obs, "exception": e}, + extra={ + "observation": obs, + "exception": e, + "network": get_network_name(sesh, obs), + }, ) db_metrics = DBMetrics(0, 0, 1) else: - log.info("Successfully inserted observations: 1") + log.info( + "Successfully inserted observations: 1", + extra={"network": get_network_name(sesh, obs)}, + ) db_metrics = DBMetrics(1, 0, 0) sesh.commit() return db_metrics @@ -187,7 +219,10 @@ def bisect_insert_strategy(sesh, observations): but in the optimal case it reduces the transactions to a constant 1. """ - log.debug("Begin mass observation insertion", extra={"num_obs": len(observations)}) + log.debug( + "Begin mass observation insertion", + extra={"num_obs": len(observations)}, + ) # Base cases if len(observations) < 1: @@ -198,7 +233,10 @@ def bisect_insert_strategy(sesh, observations): else: try: with sesh.begin_nested(): - log.debug("New SAVEPOINT", extra={"num_obs": len(observations)}) + log.debug( + "New SAVEPOINT", + extra={"num_obs": len(observations)}, + ) sesh.add_all(observations) except IntegrityError: log.debug("Failed, splitting observations.") @@ -294,9 +332,18 @@ def bulk_insert_strategy(sesh, observations, chunk_size=1000): dbm += chunk_dbm log.info( f"Bulk insert progress: " - f"{dbm.successes} inserted, {dbm.skips} skipped, {dbm.failures} failed" + f"{dbm.successes} inserted, {dbm.skips} skipped, {dbm.failures} failed", + extra={"network": get_network_name(sesh, chunk[0])}, + ) + if len(observations) > 0: + log.info( + f"Successfully inserted observations: {dbm.successes}", + extra={"network": get_network_name(sesh, observations[0])}, + ) + else: + log.info( + f"Successfully inserted observations: {dbm.successes}", ) - log.info(f"Successfully inserted observations: {dbm.successes}") return dbm @@ -324,25 +371,51 @@ def insert( # in the database. sesh.commit() - with Timer() as tmr: - if strategy is InsertStrategy.BULK: - dbm = bulk_insert_strategy(sesh, observations, chunk_size=bulk_chunk_size) - elif strategy is InsertStrategy.SINGLE: - dbm = single_insert_strategy(sesh, observations) - elif strategy is InsertStrategy.CHUNK_BISECT: - dbm = chunk_bisect_insert_strategy(sesh, observations) - elif strategy is InsertStrategy.ADAPTIVE: - if contains_all_duplicates(sesh, observations, sample_size): - dbm = single_insert_strategy(sesh, observations) + results_total = { + "successes": 0, + "skips": 0, + "failures": 0, + "insertions_per_sec": 0, + } + if len(observations) < 1: + return results_total + + obs_by_network_dict = obs_by_network(observations, sesh) + + for obs in obs_by_network_dict.values(): + dbm = DBMetrics(0, 0, 0) + with Timer() as tmr: + if strategy is InsertStrategy.BULK: + dbm += bulk_insert_strategy(sesh, obs, chunk_size=bulk_chunk_size) + elif strategy is InsertStrategy.SINGLE: + dbm += single_insert_strategy(sesh, obs) + elif strategy is InsertStrategy.CHUNK_BISECT: + dbm += chunk_bisect_insert_strategy(sesh, obs) + elif strategy is InsertStrategy.ADAPTIVE: + if contains_all_duplicates(sesh, obs, sample_size): + dbm += single_insert_strategy(sesh, obs) + else: + dbm += chunk_bisect_insert_strategy(sesh, obs) else: - dbm = chunk_bisect_insert_strategy(sesh, observations) - else: - raise ValueError(f"Insert strategy has an unrecognized value: {strategy}") + raise ValueError( + f"Insert strategy has an unrecognized value: {strategy}" + ) + results = { + "successes": dbm.successes, + "skips": dbm.skips, + "failures": dbm.failures, + "insertions_per_sec": round(dbm.successes / tmr.run_time, 2), + } + log.info( + "Insert for network {network}: done".format( + network=get_network_name(sesh, obs[0]) + ) + ) + log.info( + "Data insertion results", + extra={"results": results, "network": get_network_name(sesh, obs[0])}, + ) - log.info("Data insertion complete") - return { - "successes": dbm.successes, - "skips": dbm.skips, - "failures": dbm.failures, - "insertions_per_sec": round(dbm.successes / tmr.run_time, 2), - } + for k, v in results.items(): + results_total[k] += v + return results_total diff --git a/crmprtd/process.py b/crmprtd/process.py index ea105356..38c65611 100644 --- a/crmprtd/process.py +++ b/crmprtd/process.py @@ -204,8 +204,6 @@ def process( bulk_chunk_size=bulk_chunk_size, sample_size=sample_size, ) - log.info("Insert: done") - log.info("Data insertion results", extra={"results": results, "network": network}) # Note: this function was buried in crmprtd.__init__.py but is diff --git a/tests/test_insert.py b/tests/test_insert.py index 397b08c7..7df65af8 100644 --- a/tests/test_insert.py +++ b/tests/test_insert.py @@ -3,10 +3,12 @@ import pytest import pytz +import logging from crmprtd.more_itertools import cycles from pycds import History, Obs from crmprtd.insert import ( + insert, bisect_insert_strategy, split, bisection_chunks, @@ -17,6 +19,7 @@ fixed_length_chunks, bulk_insert_strategy, Timer, + obs_by_network, ) @@ -195,3 +198,43 @@ def test_single_insert_obs_not_unique(test_session): ] dbm = single_insert_strategy(test_session, ob) assert dbm.skips == 1 + + +def test_mult_networks(test_session, caplog): + caplog.set_level(logging.DEBUG, "crmprtd") + + moti = "Brandywine" + wmb = "FIVE MILE" + ec = "Sechelt" + stations = [moti, wmb, ec] + + obs = [] + with test_session: + for station in stations: + hist = ( + test_session.query(History) + .filter(History.station_name == station) + .first() + ) + var = hist.station.network.variables[0] + observation = Obs( + history_id=hist.id, + datum=10, + vars_id=var.id, + time=datetime(2012, 9, 25, 6, tzinfo=pytz.utc), + ) + obs.append(observation) + + expected_networks = {"MoTIe", "FLNRO-WMB", "EC_raw"} + results = insert(test_session, obs) + assert results["successes"] == 3 + assert networks_are_logged(expected_networks, caplog) + + +def networks_are_logged(networks, caplog): + rec_networks = { + getattr(record, "network") + for record in caplog.records + if getattr(record, "network", None) is not None + } + return networks == rec_networks