Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion crmprtd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
import logging
import logging.config
import os.path

import pkg_resources
from pkg_resources import resource_stream
from collections import namedtuple
Expand Down
39 changes: 1 addition & 38 deletions crmprtd/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
"""

import logging
from types import SimpleNamespace

from sqlalchemy import and_, cast
from geoalchemy2.functions import (
Expand All @@ -25,6 +24,7 @@
from pycds import Obs, History, Network, Variable, Station
from crmprtd.db_exceptions import InsertionError

from crmprtd.db_helpers import cached_function

log = logging.getLogger(__name__)
ureg = UnitRegistry()
Expand All @@ -41,43 +41,6 @@
ureg.define(def_)


def cached_function(attrs):
    """Decorator factory that memoizes database query results.

    Database sessions (the ``sesh`` parameter of each wrapped function)
    and SQLAlchemy mapped objects (query results) are neither cachable
    nor reusable, so builtin memoizers such as ``lru_cache`` cannot be
    used directly. This factory instead assumes that a) the wrapped
    function's first argument is a database session, b) the function
    returns a single SQLAlchemy object (e.g. a History instance), and
    c) ``attrs`` names the attributes to copy off that object into the
    cached value.

    Everything except ``sesh`` (positional args plus keyword items) forms
    the cache key; cached values are ``SimpleNamespace`` objects carrying
    the requested attributes.
    """

    def wrapper(f):
        cache = {}

        def memoize(sesh, *args, **kwargs):
            key = args + tuple(kwargs.items())
            if key in cache:
                return cache[key]
            result = f(sesh, *args, **kwargs)
            log.debug(f"Cache miss: {f.__name__} {key} -> {repr(result)}")
            # Copy only the requested attributes so the cached value is
            # detached from the session; a falsy result is cached as-is.
            cache[key] = result and SimpleNamespace(
                **{attr: getattr(result, attr) for attr in attrs}
            )
            return cache[key]

        return memoize

    return wrapper


def histories_within_threshold(sesh, network_name, lon, lat, threshold):
"""
Find existing histories associated with the given network and within a threshold
Expand Down
44 changes: 44 additions & 0 deletions crmprtd/db_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import functools
import logging
from types import SimpleNamespace


def sanitize_connection(sesh):
    """Render the session's bound connection URL with the password masked."""
    url = sesh.bind.url
    return url.render_as_string(hide_password=True)


def cached_function(attrs):
    """A decorator factory that can be used to cache database results.

    Neither database sessions (i.e. the sesh parameter of each wrapped
    function) nor SQLAlchemy mapped objects (the results of queries) are
    cachable or reusable. Therefore one cannot memoize database query
    functions using builtins like functools.lru_cache.

    This wrapper works by a) assuming that the wrapped function's first
    argument is a database session, b) assuming that the result of the
    query returns a single SQLAlchemy object (e.g. a History instance),
    and c) accepting as a parameter a list of attribute names to retrieve
    from that object and store in the cached result.

    args (except sesh) and kwargs to the wrapped function are used as
    the cache key; cached values are SimpleNamespace objects carrying the
    requested attributes (or the falsy query result itself, cached so
    repeated misses do not re-query).
    """
    log = logging.getLogger(__name__)

    def wrapper(f):
        cache = {}

        @functools.wraps(f)
        def memoize(sesh, *args, **kwargs):
            # Sort kwargs so equivalent calls that differ only in keyword
            # order share a cache entry; the session is deliberately
            # excluded from the key.
            key = args + tuple(sorted(kwargs.items()))
            if key not in cache:
                obj = f(sesh, *args, **kwargs)
                log.debug(
                    f"Cache miss: {f.__name__} {key} -> {repr(obj)}",
                    extra={"database": sanitize_connection(sesh)},
                )
                # Detach the result from the session by copying only the
                # requested attributes.
                cache[key] = obj and SimpleNamespace(
                    **{attr: getattr(obj, attr) for attr in attrs}
                )
            return cache[key]

        return memoize

    return wrapper
131 changes: 102 additions & 29 deletions crmprtd/insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@
import logging
import time
import random

from itertools import groupby
from sqlalchemy import and_
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.exc import IntegrityError, DBAPIError

from crmprtd.constants import InsertStrategy
from crmprtd.db_exceptions import InsertionError
from pycds import Obs
from pycds import Obs, Variable

from crmprtd.db_helpers import cached_function

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -54,6 +55,26 @@ def max_power_of_two(num):
return 2 ** floor(mathlog(num, 2))


@cached_function(["name"])
def _network_for_vars_id(sesh, vars_id):
    # Look up the network owning the variable with the given id. Cached at
    # module level: the original defined (and decorated) this query inside
    # get_network_name, which rebuilt an empty cache on every call, and keyed
    # the cache on the Obs instance (identity-hashed), so it never hit.
    # Keying on the scalar vars_id makes the cache effective.
    obs_var = sesh.query(Variable).filter_by(id=vars_id).first()
    return obs_var.network


def get_network_name(sesh, obs):
    """Return the name of the network that owns the observation's variable.

    Results are cached by obs.vars_id, so repeated lookups for observations
    of the same variable issue a single database query.
    """
    return _network_for_vars_id(sesh, obs.vars_id).name


def obs_by_network(observations, sesh):
    """Group observations into a dict keyed by network name.

    Keys appear in sorted network-name order; each value is the list of
    observations belonging to that network.
    """
    # Resolve each observation's network name exactly once. The original
    # called get_network_name both as the sort key and as the groupby key,
    # doubling the (potentially database-backed) lookups.
    keyed = sorted(
        ((get_network_name(sesh, obs), obs) for obs in observations),
        key=lambda pair: pair[0],
    )
    return {
        network_name: [obs for _, obs in group]
        for network_name, group in groupby(keyed, key=lambda pair: pair[0])
    }


def get_bisection_chunk_sizes(remainder):
chunk_list = []
while remainder != 0:
Expand Down Expand Up @@ -130,19 +151,30 @@ def insert_single_obs(sesh, obs):
except IntegrityError as e:
log.debug(
"Failure, observation already exists",
extra={"observation": obs, "exception": e},
extra={
"observation": obs,
"exception": e,
"network": get_network_name(sesh, obs),
},
)
db_metrics = DBMetrics(0, 1, 0)
except InsertionError as e:
# TODO: InsertionError is an defined by crmprtd. It can't be raised by
# SQLAlchemy unless something very tricky is going on. Why is this here?
log.warning(
"Failure occured during insertion",
extra={"observation": obs, "exception": e},
extra={
"observation": obs,
"exception": e,
"network": get_network_name(sesh, obs),
},
)
db_metrics = DBMetrics(0, 0, 1)
else:
log.info("Successfully inserted observations: 1")
log.info(
"Successfully inserted observations: 1",
extra={"network": get_network_name(sesh, obs)},
)
db_metrics = DBMetrics(1, 0, 0)
sesh.commit()
return db_metrics
Expand Down Expand Up @@ -187,7 +219,10 @@ def bisect_insert_strategy(sesh, observations):
but in the optimal case it reduces the transactions to a constant
1.
"""
log.debug("Begin mass observation insertion", extra={"num_obs": len(observations)})
log.debug(
"Begin mass observation insertion",
extra={"num_obs": len(observations)},
)

# Base cases
if len(observations) < 1:
Expand All @@ -198,7 +233,10 @@ def bisect_insert_strategy(sesh, observations):
else:
try:
with sesh.begin_nested():
log.debug("New SAVEPOINT", extra={"num_obs": len(observations)})
log.debug(
"New SAVEPOINT",
extra={"num_obs": len(observations)},
)
sesh.add_all(observations)
except IntegrityError:
log.debug("Failed, splitting observations.")
Expand Down Expand Up @@ -294,9 +332,18 @@ def bulk_insert_strategy(sesh, observations, chunk_size=1000):
dbm += chunk_dbm
log.info(
f"Bulk insert progress: "
f"{dbm.successes} inserted, {dbm.skips} skipped, {dbm.failures} failed"
f"{dbm.successes} inserted, {dbm.skips} skipped, {dbm.failures} failed",
extra={"network": get_network_name(sesh, chunk[0])},
)
if len(observations) > 0:
log.info(
f"Successfully inserted observations: {dbm.successes}",
extra={"network": get_network_name(sesh, observations[0])},
)
else:
log.info(
f"Successfully inserted observations: {dbm.successes}",
)
log.info(f"Successfully inserted observations: {dbm.successes}")
return dbm


Expand Down Expand Up @@ -324,25 +371,51 @@ def insert(
# in the database.
sesh.commit()

with Timer() as tmr:
if strategy is InsertStrategy.BULK:
dbm = bulk_insert_strategy(sesh, observations, chunk_size=bulk_chunk_size)
elif strategy is InsertStrategy.SINGLE:
dbm = single_insert_strategy(sesh, observations)
elif strategy is InsertStrategy.CHUNK_BISECT:
dbm = chunk_bisect_insert_strategy(sesh, observations)
elif strategy is InsertStrategy.ADAPTIVE:
if contains_all_duplicates(sesh, observations, sample_size):
dbm = single_insert_strategy(sesh, observations)
results_total = {
"successes": 0,
"skips": 0,
"failures": 0,
"insertions_per_sec": 0,
}
if len(observations) < 1:
return results_total

obs_by_network_dict = obs_by_network(observations, sesh)

for obs in obs_by_network_dict.values():
dbm = DBMetrics(0, 0, 0)
with Timer() as tmr:
if strategy is InsertStrategy.BULK:
dbm += bulk_insert_strategy(sesh, obs, chunk_size=bulk_chunk_size)
elif strategy is InsertStrategy.SINGLE:
dbm += single_insert_strategy(sesh, obs)
elif strategy is InsertStrategy.CHUNK_BISECT:
dbm += chunk_bisect_insert_strategy(sesh, obs)
elif strategy is InsertStrategy.ADAPTIVE:
if contains_all_duplicates(sesh, obs, sample_size):
dbm += single_insert_strategy(sesh, obs)
else:
dbm += chunk_bisect_insert_strategy(sesh, obs)
else:
dbm = chunk_bisect_insert_strategy(sesh, observations)
else:
raise ValueError(f"Insert strategy has an unrecognized value: {strategy}")
raise ValueError(
f"Insert strategy has an unrecognized value: {strategy}"
)
results = {
"successes": dbm.successes,
"skips": dbm.skips,
"failures": dbm.failures,
"insertions_per_sec": round(dbm.successes / tmr.run_time, 2),
}
log.info(
"Insert for network {network}: done".format(
network=get_network_name(sesh, obs[0])
)
)
log.info(
"Data insertion results",
extra={"results": results, "network": get_network_name(sesh, obs[0])},
)

log.info("Data insertion complete")
return {
"successes": dbm.successes,
"skips": dbm.skips,
"failures": dbm.failures,
"insertions_per_sec": round(dbm.successes / tmr.run_time, 2),
}
for k, v in results.items():
results_total[k] += v
return results_total
2 changes: 0 additions & 2 deletions crmprtd/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,6 @@ def process(
bulk_chunk_size=bulk_chunk_size,
sample_size=sample_size,
)
log.info("Insert: done")
log.info("Data insertion results", extra={"results": results, "network": network})


# Note: this function was buried in crmprtd.__init__.py but is
Expand Down
43 changes: 43 additions & 0 deletions tests/test_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@

import pytest
import pytz
import logging

from crmprtd.more_itertools import cycles
from pycds import History, Obs
from crmprtd.insert import (
insert,
bisect_insert_strategy,
split,
bisection_chunks,
Expand All @@ -17,6 +19,7 @@
fixed_length_chunks,
bulk_insert_strategy,
Timer,
obs_by_network,
)


Expand Down Expand Up @@ -195,3 +198,43 @@ def test_single_insert_obs_not_unique(test_session):
]
dbm = single_insert_strategy(test_session, ob)
assert dbm.skips == 1


def test_mult_networks(test_session, caplog):
    """Insert one observation per network and verify per-network logging."""
    caplog.set_level(logging.DEBUG, "crmprtd")

    # One station from each of the three test networks.
    station_names = ["Brandywine", "FIVE MILE", "Sechelt"]
    obs_time = datetime(2012, 9, 25, 6, tzinfo=pytz.utc)

    observations = []
    with test_session:
        for name in station_names:
            hist = (
                test_session.query(History)
                .filter(History.station_name == name)
                .first()
            )
            variable = hist.station.network.variables[0]
            observations.append(
                Obs(
                    history_id=hist.id,
                    datum=10,
                    vars_id=variable.id,
                    time=obs_time,
                )
            )

    results = insert(test_session, observations)
    assert results["successes"] == 3
    assert networks_are_logged({"MoTIe", "FLNRO-WMB", "EC_raw"}, caplog)


def networks_are_logged(networks, caplog):
    """Return True iff the captured log records carry exactly these networks.

    Records without a "network" attribute (or with it set to None) are
    ignored.
    """
    logged = set()
    for record in caplog.records:
        network = getattr(record, "network", None)
        if network is not None:
            logged.add(network)
    return networks == logged