Skip to content
Open
61 changes: 50 additions & 11 deletions crmprtd/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from pycds import Obs, History, Network, Variable, Station
from crmprtd.db_exceptions import InsertionError

from crmprtd.insert import sanitize_connection

log = logging.getLogger(__name__)
ureg = UnitRegistry()
Expand Down Expand Up @@ -67,7 +67,10 @@ def memoize(sesh, *args, **kwargs):
key = (args) + tuple(kwargs.items())
if key not in cache:
obj = f(sesh, *args, **kwargs)
log.debug(f"Cache miss: {f.__name__} {key} -> {repr(obj)}")
log.debug(
f"Cache miss: {f.__name__} {key} -> {repr(obj)}",
extra={"database": sanitize_connection(sesh)},
)
cache[key] = obj and SimpleNamespace(
**{attr: getattr(obj, attr) for attr in attrs}
)
Expand Down Expand Up @@ -196,7 +199,11 @@ def find_nearest_history(
for history in histories:
if close_history.history_id == history.id:
log.debug(
"Matched history", extra={"station_name": history.station_name}
"Matched history",
extra={
"station_name": history.station_name,
"database": sanitize_connection(sesh),
},
)
return history

Expand Down Expand Up @@ -231,7 +238,11 @@ def create_station_and_history_entry(
station = Station(native_id=native_id, network_id=network.id)
log.info(
f"{action} new station entry",
extra={"native_id": station.native_id, "network_name": network.name},
extra={
"native_id": station.native_id,
"network_name": network.name,
"database": sanitize_connection(sesh),
},
)

history = History(station=station, lat=lat, lon=lon)
Expand All @@ -243,13 +254,15 @@ def create_station_and_history_entry(
"native_id": station.native_id,
"lat": lat,
"lon": lon,
"database": sanitize_connection(sesh),
},
)

if diagnostic:
log.info(
f"In diagnostic mode. Skipping insertion of new history entry for: "
f"network_name={network_name}, native_id={native_id}, lat={lat}, lon={lon}"
f"network_name={network_name}, native_id={native_id}, lat={lat}, lon={lon}",
extra={"database": sanitize_connection(sesh)},
)
return None

Expand All @@ -260,7 +273,12 @@ def create_station_and_history_entry(
except Exception as e:
log.warning(
"Unable to insert new stn/hist entries",
extra={"stn": station, "hist": history, "exception": e},
extra={
"stn": station,
"hist": history,
"exception": e,
"database": sanitize_connection(sesh),
},
)
raise InsertionError(native_id=station.id, hid=history.id, e=e)
sesh.commit()
Expand Down Expand Up @@ -320,7 +338,11 @@ def find_or_create_matching_history_and_station(
- If at least one is found within tolerance distance, return one.
- If none are found within tolerance, this is an error condition, return None.
"""
log.debug("Searching for native_id = %s", native_id)
log.debug(
"Searching for native_id = %s",
native_id,
extra={"database": sanitize_connection(sesh)},
)
histories = (
sesh.query(History)
.join(Station)
Expand All @@ -329,15 +351,25 @@ def find_or_create_matching_history_and_station(
)

if histories.count() == 0:
log.debug("Cound not find native_id %s", native_id)
log.debug(
"Cound not find native_id %s",
native_id,
extra={"database": sanitize_connection(sesh)},
)
return create_station_and_history_entry(
sesh, network_name, native_id, lat, lon, diagnostic=diagnostic
)
elif histories.count() == 1:
log.debug("Found exactly one matching history_id")
log.debug(
"Found exactly one matching history_id",
extra={"database": sanitize_connection(sesh)},
)
return histories.one_or_none()
elif histories.count() >= 2:
log.debug("Found multiple history entries. Searching for match.")
log.debug(
"Found multiple history entries. Searching for match.",
extra={"database": sanitize_connection(sesh)},
)
return match_history(
sesh, network_name, native_id, lat, lon, histories, diagnostic=diagnostic
)
Expand Down Expand Up @@ -383,6 +415,7 @@ def align(sesh, row, diagnostic=False):
"time": row.time,
"val": row.val,
"variable_name": row.variable_name,
"database": sanitize_connection(sesh),
},
)
return None
Expand All @@ -391,7 +424,10 @@ def align(sesh, row, diagnostic=False):
if not get_network(sesh, row.network_name):
log.error(
"Network does not exist in db",
extra={"network_name": row.network_name},
extra={
"network_name": row.network_name,
"database": sanitize_connection(sesh),
},
)
return None

Expand All @@ -410,6 +446,7 @@ def align(sesh, row, diagnostic=False):
extra={
"network_name": row.network_name,
"native_id": row.station_id,
"database": sanitize_connection(sesh),
},
)
return None
Expand All @@ -421,6 +458,7 @@ def align(sesh, row, diagnostic=False):
'Variable "%s" from network "%s" is not tracked by crmp',
row.variable_name,
row.network_name,
extra={"database": sanitize_connection(sesh)},
)
return None
else:
Expand All @@ -437,6 +475,7 @@ def align(sesh, row, diagnostic=False):
"unit_db": var_unit,
"data": row.val,
"network_name": row.network_name,
"database": sanitize_connection(sesh),
},
)
return None
Expand Down
7 changes: 4 additions & 3 deletions crmprtd/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
# local
from pycds import Network, Variable
from crmprtd.align import get_variable, find_or_create_matching_history_and_station

from crmprtd.insert import sanitize_connection

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -114,13 +114,14 @@ def infer(sesh, rows, diagnostic=False):
for var in variables:
log.info(
f"INSERT INTO meta_vars (network_id, net_var_name, unit) "
f"VALUES ({var.network.id}, '{var.name}', '{var.unit}')"
f"VALUES ({var.network.id}, '{var.name}', '{var.unit}')",
extra={"database": sanitize_connection(sesh)},
)

if diagnostic:
nested.rollback()
elif len(variables) > 0:
raise ValueError(
f"{len(variables)} Variables need to be inserted (see log). "
f"This is not possible without human intervention."
f"This is not possible without human intervention.",
)
70 changes: 57 additions & 13 deletions crmprtd/insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ def max_power_of_two(num):
return 2 ** floor(mathlog(num, 2))


def sanitize_connection(sesh):
    """Return the session's database connection URL as a string with the
    password masked, suitable for inclusion in log records."""
    url = sesh.bind.url
    return url.render_as_string(hide_password=True)


def get_bisection_chunk_sizes(remainder):
chunk_list = []
while remainder != 0:
Expand Down Expand Up @@ -124,32 +128,48 @@ def insert_single_obs(sesh, obs):
try:
# Create a nested SAVEPOINT context manager to rollback to in the
# event of unique constraint errors
log.debug("New SAVEPOINT for single observation")
log.debug(
"New SAVEPOINT for single observation",
extra={"database": sanitize_connection(sesh)},
)
with sesh.begin_nested():
sesh.add(obs)
except IntegrityError as e:
log.debug(
"Failure, observation already exists",
extra={"observation": obs, "exception": e},
extra={
"observation": obs,
"exception": e,
"database": sanitize_connection(sesh),
},
)
db_metrics = DBMetrics(0, 1, 0)
except InsertionError as e:
# TODO: InsertionError is an defined by crmprtd. It can't be raised by
# SQLAlchemy unless something very tricky is going on. Why is this here?
log.warning(
"Failure occured during insertion",
extra={"observation": obs, "exception": e},
extra={
"observation": obs,
"exception": e,
"database": sanitize_connection(sesh),
},
)
db_metrics = DBMetrics(0, 0, 1)
else:
log.info("Successfully inserted observations: 1")
log.info(
"Successfully inserted observations: 1",
extra={"database": sanitize_connection(sesh)},
)
db_metrics = DBMetrics(1, 0, 0)
sesh.commit()
return db_metrics


def single_insert_strategy(sesh, observations):
log.info("Using Single Insert Strategy")
log.info(
"Using Single Insert Strategy", extra={"database": sanitize_connection(sesh)}
)
dbm = DBMetrics(0, 0, 0)
for obs in observations:
dbm += insert_single_obs(sesh, obs)
Expand Down Expand Up @@ -187,7 +207,10 @@ def bisect_insert_strategy(sesh, observations):
but in the optimal case it reduces the transactions to a constant
1.
"""
log.debug("Begin mass observation insertion", extra={"num_obs": len(observations)})
log.debug(
"Begin mass observation insertion",
extra={"num_obs": len(observations), "database": sanitize_connection(sesh)},
)

# Base cases
if len(observations) < 1:
Expand All @@ -198,7 +221,13 @@ def bisect_insert_strategy(sesh, observations):
else:
try:
with sesh.begin_nested():
log.debug("New SAVEPOINT", extra={"num_obs": len(observations)})
log.debug(
"New SAVEPOINT",
extra={
"num_obs": len(observations),
"database": sanitize_connection(sesh),
},
)
sesh.add_all(observations)
except IntegrityError:
log.debug("Failed, splitting observations.")
Expand All @@ -211,15 +240,21 @@ def bisect_insert_strategy(sesh, observations):
else:
log.info(
f"Successfully inserted observations: {len(observations)}",
extra={"num_obs": len(observations)},
extra={
"num_obs": len(observations),
"database": sanitize_connection(sesh),
},
)
db_metrics = DBMetrics(len(observations), 0, 0)
sesh.commit()
return db_metrics


def chunk_bisect_insert_strategy(sesh, observations):
log.info("Using Chunk + Bisection Strategy")
log.info(
"Using Chunk + Bisection Strategy",
extra={"database": sanitize_connection(sesh)},
)
dbm = DBMetrics(0, 0, 0)
for chunk in bisection_chunks(observations):
dbm += bisect_insert_strategy(sesh, chunk)
Expand Down Expand Up @@ -269,7 +304,10 @@ def insert_bulk_obs(sesh, observations):
except DBAPIError as e:
# Something really unanticipated happened. Duplicate rows do not trigger an
# exception.
log.exception("Unexpected error during bulk insertion")
log.exception(
"Unexpected error during bulk insertion",
extra={"database": sanitize_connection(sesh)},
)
return DBMetrics(0, 0, num_to_insert)
sesh.commit()
num_inserted = len(result)
Expand All @@ -287,16 +325,22 @@ def bulk_insert_strategy(sesh, observations, chunk_size=1000):
:param chunk_size: Size of chunks.
:return: DMMetrics describing result of insertion
"""
log.info("Using Bulk Insert Strategy")
log.info(
"Using Bulk Insert Strategy", extra={"database": sanitize_connection(sesh)}
)
dbm = DBMetrics(0, 0, 0)
for chunk in fixed_length_chunks(observations, chunk_size=chunk_size):
chunk_dbm = insert_bulk_obs(sesh, chunk)
dbm += chunk_dbm
log.info(
f"Bulk insert progress: "
f"{dbm.successes} inserted, {dbm.skips} skipped, {dbm.failures} failed"
f"{dbm.successes} inserted, {dbm.skips} skipped, {dbm.failures} failed",
extra={"database": sanitize_connection(sesh)},
)
log.info(f"Successfully inserted observations: {dbm.successes}")
log.info(
f"Successfully inserted observations: {dbm.successes}",
extra={"database": sanitize_connection(sesh)},
)
return dbm


Expand Down
19 changes: 16 additions & 3 deletions crmprtd/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from crmprtd.constants import InsertStrategy
from crmprtd.align import align
from crmprtd.insert import insert
from crmprtd.insert import insert, sanitize_connection
from crmprtd.download_utils import verify_date
from crmprtd.infer import infer
from crmprtd import add_version_arg, add_logging_args, setup_logging, NETWORKS
Expand Down Expand Up @@ -205,7 +205,14 @@ def process(
sample_size=sample_size,
)
log.info("Insert: done")
log.info("Data insertion results", extra={"results": results, "network": network})
log.info(
"Data insertion results",
extra={
"results": results,
"network": network,
"database": sanitize_connection(sesh),
},
)


# Note: this function was buried in crmprtd.__init__.py but is
Expand Down Expand Up @@ -246,7 +253,13 @@ def run_data_pipeline(
results = insert(sesh, observations, sample_size)

log = logging.getLogger(__name__)
log.info("Data insertion results", extra={"results": results})
log.info(
"Data insertion results",
extra={
"results": results,
"database": sanitize_connection(sesh),
},
)


def main(args=None):
Expand Down
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,3 +631,12 @@ def make_404(request):
text=station_listing,
status_code=200,
)


def records_contain_db_connection(test_session, caplog):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't tell you exactly why, but importing from conftest is not something that is done. Fixtures are defined in conftest.py, but not utility functions.

Maybe you could create this as a utility fixture as explained in this SO answer? I agree w/ the first commenter that it "feels a bit hacky", but there don't appear to be any obviously better options.

@rod-glover have you run across a requirement like this (sharing a test helper function across test modules) in your pytest trials?

A possible explanation of why it hasn't come up to date (to my knowledge) may be that we are encouraged to keep our unit tests concise and simple. And if your test conditions are so complicated that they require more logic in an external function, then maybe they need to be simplified.

There's an argument to be made for either approach. I think I'll be happy however you choose to proceed.

Copy link
Contributor

@rod-glover rod-glover Aug 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have, and I have used three methods: the hacky one, a slightly clunky one (A), and another maybe less clunky one (B).

Clunky A is to define a fixture that returns the helper function. Then use the fixture as normal. Best to scope such fixtures broadly, i.e., session scope.

Clunky B is to treat the test directory (or any subdirectory of it) as a package, with an __init__.py. Put helper functions there, and import them using relative imports. For example

from . import a_helper_function

from .. import  another_helper_function

Alternatively, create a module in such a package and import from it.

from .helpers import yet_another

Right now B is my preferred setup.

UPDATE: Should have read the SO answer first. It is a variation of clunky A, which reads in my code more like:

def helper_function():
    def f():
        # ...

    return f

I don't bother with the Helpers class; that seems ... unwieldy and unnecessary unless you are importing a ton of helper functions ... in which case you have to ask why so many.

Copy link
Contributor

@rod-glover rod-glover Aug 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A possible explanation of why it hasn't come up to date (to my knowledge) may be that we are encouraged to keep our unit tests concise and simple. And if your test conditions are so complicated that they require more logic in an external function, then maybe they need to be simplified.

In general, I like this prinicple. If you write simple functions/methods, and compose them in straightforward ways, then they are easier to understand, test and maintain. That said, some functions need somewhat complicated tests, and/or the same helper function is needed in several different places. So I treat this principle with a certain pragmatism.

Plus I do use helper functions fairly often. YMMV.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, great input!

for record in caplog.records:
if "database" in record.__dict__:
Copy link
Contributor

@jameshiebert jameshiebert Aug 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could be wrong here, but I believe that you can just use the idiom if "database" in record and do not have to specifically access the __dict__ attribute. See: https://docs.python.org/3/reference/expressions.html#membership-test-operations

One thing that you don't check here, but maybe should, is whether the record is found at the correct level of logging. E.g. you check that a log entry is found, but it's possible that it's at a higher level of logging than specified.

logged_db = getattr(record, "database", {})
if logged_db == test_session.bind.url.render_as_string(hide_password=True):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might be able to simplify this to:

return record.database == test_session.bind.url.render_as_string(hide_password=True)

return True
return False
Loading