Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"python.testing.pytestArgs": [
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
50 changes: 34 additions & 16 deletions crmprtd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,13 @@

from datetime import datetime, timezone

from importlib import import_module
from importlib.metadata import version
from importlib.resources import files
from collections import namedtuple

import dateutil

from crmprtd.argparse_helpers import OneAndDoneAction

SWOB_PARTNERS = (
Expand Down Expand Up @@ -295,6 +298,21 @@ def add_province_args(parser):
required=True,
)

def add_network_arg(parser):
    """Network is fundamental to how we process, so we always require it.

    Adds the required -N/--network option (stored as ``network_name``) to
    *parser*; valid values are concrete network names plus network aliases.
    """
    help_text = (
        "Network identifier (a network name or network alias) from which to "
        "download observations. A network alias can stand for one or more "
        "individual networks (e.g., 'ytnt' stands for many networks)."
    )
    valid_choices = NETWORKS + network_alias_names
    parser.add_argument(
        "-N",
        "--network",
        dest="network_name",
        choices=valid_choices,
        required=True,
        help=help_text,
    )


def add_bulk_args(parser):
"""
Expand All @@ -311,15 +329,10 @@ def add_bulk_args(parser):
),
)
parser.add_argument(
"-N",
"--network",
choices=NETWORKS + network_alias_names,
required=True,
help=(
"Network identifier (a network name or network alias) from which to "
"download observations. A network alias can stand for one or more "
"individual networks (e.g., 'ytnt' stands for many networks)."
),
"--force",
dest="force",
action="store_true",
help="Continue processing if individual operations fail",
)


Expand All @@ -335,6 +348,7 @@ def add_time_range_args(parser):
"-S",
"--start_date",
dest="stime",
type=dateutil.parser.parse,
required=True,
help=(
"Start time (UTC) of range to process (format: '%%Y-%%m-%%d %%H:%%M:%%S'). "
Expand All @@ -345,10 +359,11 @@ def add_time_range_args(parser):
"-E",
"--end_date",
dest="etime",
default=datetime.now(timezone.utc),
type=dateutil.parser.parse,
help=(
"End time (UTC) of range to process (format: '%%Y-%%m-%%d %%H:%%M:%%S'). "
"Interpreted with strptime and rounded to the nearest hour. Defaults to the current UTC time."
"Interpreted with strptime and rounded to the nearest hour. "
"Defaults to start time if not provided."
),
)
parser.add_argument(
Expand All @@ -361,14 +376,17 @@ def add_time_range_args(parser):
)


def ensure_log_directory(log_filename):
def ensure_directory(filename: str):
    """
    Ensure the directory containing *filename* exists, creating it if necessary.

    Args:
        filename: Path to a file whose parent directory should exist.
            No-op when falsy (empty string or None).
    """
    # NOTE(review): the next four lines reference the pre-rename parameter
    # `log_filename` and appear to be leftover diff residue from the old
    # `ensure_log_directory` — confirm they should be removed.
    if log_filename:
        log_dir = os.path.dirname(log_filename)
        if log_dir and not os.path.exists(log_dir):
            os.makedirs(log_dir, exist_ok=True)
    if filename:
        dir_path = os.path.dirname(filename)
        # exist_ok=True makes the existence pre-check redundant but harmless;
        # it also guards against a race with concurrent creators.
        if dir_path and not os.path.exists(dir_path):
            os.makedirs(dir_path, exist_ok=True)

def get_defaults_module(network: str):
    """Return the defaults module for *network* (crmprtd.networks.<network>.defaults)."""
    module_path = f"crmprtd.networks.{network}.defaults"
    return import_module(module_path)
83 changes: 45 additions & 38 deletions crmprtd/bulk_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,52 +5,50 @@
import os
import sys
import logging
import pytz
from datetime import datetime, timedelta
from zoneinfo import ZoneInfo
from datetime import datetime, timedelta, timezone
from argparse import ArgumentParser
from time import sleep
from importlib.resources import files

# Import from crmprtd
from crmprtd import (
add_logging_args,
add_network_arg,
add_province_args,
get_defaults_module,
setup_logging,
add_bulk_args,
add_time_range_args,
ensure_log_directory,
ensure_directory,
network_alias_names,
network_aliases,
)
from crmprtd.download_cache_process import (
main as download_cache_process_main,
describe_network,
default_cache_filename,
)


def process(current_time, opts, args):
network_defaults = get_defaults_module(opts.network_name)

# Generate cache filename if directory specified
cache_filename = None
if opts.directory:

cache_filename = default_cache_filename(
cache_filename = network_defaults.default_cache_filename(
timestamp=current_time,
network_name=opts.network,
tag=opts.tag,
frequency=opts.frequency if opts.network == "ec" else None,
province=opts.province if opts.network == "ec" else None,
**opts
)

# Replace the default cache directory with the user-specified directory
cache_filename = cache_filename.replace(
f"~/{opts.network}/cache/", f"{opts.directory}/{opts.network}/cache/"
f"~/{opts.network_name}/cache/", f"{opts.directory}/{opts.network_name}/cache/"
)

# ensure directory exists
cache_dir = os.path.dirname(cache_filename)
if not os.path.exists(cache_dir):
os.makedirs(cache_dir, exist_ok=True)
ensure_directory(os.path.dirname(cache_filename))

# Build argument list for download_cache_process main function
# Start with base arguments from the original args list
Expand All @@ -60,7 +58,7 @@ def process(current_time, opts, args):
fun_args = [
*base_args,
"--network",
opts.network,
opts.network_name,
# "--log_conf",
# opts.log_conf,
"--log_filename",
Expand All @@ -72,11 +70,11 @@ def process(current_time, opts, args):
]

# Add frequency if provided (only for EC network, as other networks don't use it)
if opts.network == "ec" and opts.frequency:
if opts.network_name == "ec" and opts.frequency:
fun_args.extend(["--frequency", opts.frequency])

# Add province if provided
if opts.network == "ec" and opts.province:
if opts.network_name == "ec" and opts.province:
fun_args.extend(["--province", opts.province])

# Add tag if provided
Expand All @@ -99,9 +97,18 @@ def run(opts, args):
Main function to run bulk pipeline operations using download_cache_process
for time ranges with specified frequency
"""
# Create log directory if it doesn't exist
if opts.log_filename:
ensure_log_directory(opts.log_filename)

network_defaults = get_defaults_module(opts.network_name)

if not opts.log_filename:
opts.log_filename = network_defaults.default_log_filename(
network_name=opts.network_name,
tag=opts.tag,
frequency=opts.frequency,
province=opts.province if opts.network_name == "ec" else None,
)

ensure_directory(opts.log_filename)

# Setup logging first
setup_logging(
Expand All @@ -115,7 +122,7 @@ def run(opts, args):
log = logging.getLogger("crmprtd")

log.info(f"Parsed opts: {opts}")
log.info(f"Network description: {describe_network(opts.network)}")
log.info(f"Network description: {describe_network(opts.network_name)}")

try:
stime = datetime.strptime(opts.stime, "%Y-%m-%d %H:%M:%S")
Expand Down Expand Up @@ -145,13 +152,13 @@ def run(opts, args):
log.info(f"Processing time: {iter_time_str}")

try:
if opts.network == "ec" and not opts.province:
if opts.network_name == "ec" and not opts.province:
log.error(
"For network 'ec', province must be specified using --province option"
)
raise ValueError("Province must be specified for EC network")
# Process each province if specified
if opts.network == "ec":
if opts.network_name == "ec":
for p in opts.province:
log.info(f"Processing province: {p}")
copts = copy.copy(opts)
Expand Down Expand Up @@ -195,6 +202,12 @@ def main():
"downloads are desired, omit inclusion of connection string arguments."
)

add_network_arg(parser)

opts, args = parser.parse_known_args(sysargs)

network_defaults = get_defaults_module(opts.network_name)

add_logging_args(parser)
add_bulk_args(parser)
add_time_range_args(parser)
Expand All @@ -205,13 +218,8 @@ def main():
dest="tag",
help="Tag to include in cache and log filenames",
)

# Control options
parser.add_argument(
"--force",
dest="force",
action="store_true",
help="Continue processing if individual operations fail",
)
parser.add_argument(
"--delay",
dest="delay",
Expand All @@ -229,31 +237,30 @@ def main():
# Set defaults
parser.set_defaults(
log_conf=default_log_conf,
log_filename="/tmp/crmp/bulk_pipeline.log",
log_filename=None,
log_level="INFO",
error_email="[email protected]",
etime=datetime.now(pytz.timezone("UTC")).strftime("%Y-%m-%d %H:%M:%S"),
etime=network_defaults.default_end_time(),
dry_run=False,
force=False,
delay=3,
)

opts, args = parser.parse_known_args(sysargs)

if opts.network == "ec":
# Additional arguments for specific networks, currently only EC so I've not migrated
# them to network defaults or similar yet.
if opts.network_name == "ec":
add_province_args(parser)
opts, args = parser.parse_known_args(sysargs)
# Normalize to lowercase.
opts.province = {p.lower() for p in opts.province}

# Validate arguments
if not opts.network:
parser.error("Network (-N/--network) is required")

if opts.network in network_alias_names:
for alias in network_aliases[opts.network]:
# Network aliases can represent multiple networks, so we loop through them here
if opts.network_name in network_alias_names:
for alias in network_aliases[opts.network_name]:
copts = copy.copy(opts)
copts.network = alias
copts.network_name = alias
run(copts, args)
else:
run(opts, args)
Expand Down
27 changes: 13 additions & 14 deletions crmprtd/bulk_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,26 @@
import sys

# Import the process function directly instead of main
from crmprtd.download_cache_process import default_log_filename
from crmprtd.process import main as process
from crmprtd import add_logging_args, setup_logging, add_bulk_args
from crmprtd import add_logging_args, ensure_directory, setup_logging, add_bulk_args


def run(opts, args):
"""
Main function to process multiple files in a directory using crmprtd.process with
optional pattern matching
"""
# Create log directory if it doesn't exist
if opts.log_filename:
log_dir = os.path.dirname(opts.log_filename)
if log_dir and not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)

if not opts.log_filename:
opts.log_filename = default_log_filename(
network_name=opts.network,
tag=opts.tag,
frequency=opts.frequency,
province=opts.province,
)

ensure_directory(opts.log_filename)

# Setup logging first
setup_logging(
Expand Down Expand Up @@ -170,13 +176,6 @@ def main():
default="*.xml",
help="File pattern to match in directory (default: *.xml)",
)
# Processing options
parser.add_argument(
"-f",
"--force",
action="store_true",
help="Continue processing remaining files if one fails",
)
parser.add_argument(
"-M",
"--move_processed",
Expand All @@ -194,7 +193,7 @@ def main():
parser.set_defaults(
connection_string="dbname=crmprtd user=crmp",
log_conf=default_log_conf,
log_filename="/tmp/crmp/bulk_process.log",
log_filename=None,
log_level="INFO",
error_email="[email protected]",
force=False,
Expand Down
Loading
Loading