Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions hewr/R/process_loc_forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,9 @@ read_and_combine_data <- function(model_dir) {
)

dat <-
tidyr::expand_grid(
tibble::tibble(
epiweekly = c(FALSE, TRUE),
root = c("combined_training_data", "combined_eval_data")
root = "combined_data",
) |>
dplyr::mutate(
prefix = ifelse(.data$epiweekly, "epiweekly_", ""),
Expand Down Expand Up @@ -543,22 +543,24 @@ process_loc_forecast <- function(
fs::path(
model_dir,
"data",
"combined_training_data",
"combined_data",
ext = "tsv"
),
col_types = data_col_types
)
) |>
dplyr::filter(.data$data_type == "train")

# Used for augmenting denominator forecasts with training period denominator
epiweekly_training_dat <- readr::read_tsv(
fs::path(
model_dir,
"data",
"epiweekly_combined_training_data",
"epiweekly_combined_data",
ext = "tsv"
),
col_types = data_col_types
)
) |>
dplyr::filter(.data$data_type == "train")

required_columns_e <- c(
".chain",
Expand Down Expand Up @@ -678,11 +680,12 @@ process_forecast <- function(
fs::path(
model_dir,
"data",
"epiweekly_combined_training_data",
"epiweekly_combined_data",
ext = "tsv"
),
col_types = data_col_types
)
) |>
dplyr::filter(.data$data_type == "train")

required_columns_e <- c(
".chain",
Expand Down
3 changes: 2 additions & 1 deletion hewr/R/timeseries_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#' @export
load_training_data <- function(
model_dir,
base_data_name = "combined_training_data",
base_data_name = "combined_data",
epiweekly = FALSE
) {
resolution <- dplyr::if_else(epiweekly, "epiweekly", "daily")
Expand All @@ -45,6 +45,7 @@ load_training_data <- function(
.value = readr::col_double()
)
) |>
dplyr::filter(.data$data_type == "train") |>
dplyr::select(-"lab_site_index") |>
dplyr::filter(stringr::str_ends(.data$.variable, "ed_visits")) |>
tidyr::pivot_wider(names_from = ".variable", values_from = ".value")
Expand Down
2 changes: 1 addition & 1 deletion hewr/man/load_training_data.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions pipelines/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
Functions to run Pyrenew-HEW forecasting pipeline. Uses the following data directories:

- facility-level-nssp-data-dir `nssp-etl/gold`
- state-level-nssp-data-dir `nssp-archival-vintages/gold`
- param-data-dir `prod_param_estimates`
- nwss-data-dir `nwss_vintages`

And the following files:

- eval-data-path `nssp-etl/latest_comprehensive.parquet`
- priors-path `pipelines/priors/prod_priors.py`
6 changes: 0 additions & 6 deletions pipelines/batch/setup_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,16 +247,11 @@ def main(
"--loc {loc} "
f"--n-training-days {n_training_days} "
"--facility-level-nssp-data-dir nssp-etl/gold "
"--state-level-nssp-data-dir "
"nssp-archival-vintages/gold "
"--param-data-dir params "
"--output-dir {output_dir} "
"--credentials-path config/creds.toml "
"--report-date {report_date} "
f"--exclude-last-n-days {exclude_last_n_days} "
f"--model-letters {model_letters} "
"--eval-data-path "
"nssp-etl/latest_comprehensive.parquet "
f"{additional_args}"
"'"
)
Expand Down Expand Up @@ -344,7 +339,6 @@ def style_bool(val):
base_call=base_call.format(
loc=loc,
disease=disease,
report_date="latest",
output_dir=str(Path("output", output_subdir)),
),
container_settings=container_settings,
Expand Down
20 changes: 0 additions & 20 deletions pipelines/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,27 +24,13 @@ def add_common_forecast_arguments(parser: argparse.ArgumentParser) -> None:
),
)

parser.add_argument(
"--report-date",
type=str,
default="latest",
help="Report date in YYYY-MM-DD format or latest (default: latest).",
)

parser.add_argument(
"--facility-level-nssp-data-dir",
type=Path,
default=Path("private_data", "nssp_etl_gold"),
help="Directory in which to look for facility-level NSSP ED visit data.",
)

parser.add_argument(
"--state-level-nssp-data-dir",
type=Path,
default=Path("private_data", "nssp_state_level_gold"),
help="Directory in which to look for state-level NSSP ED visit data.",
)

parser.add_argument(
"--param-data-dir",
type=Path,
Expand Down Expand Up @@ -92,12 +78,6 @@ def add_common_forecast_arguments(parser: argparse.ArgumentParser) -> None:
help="Path to a TOML file containing credentials such as API keys.",
)

parser.add_argument(
"--eval-data-path",
type=Path,
help="Path to a parquet file containing comprehensive truth data.",
)

parser.add_argument(
"--nhsn-data-path",
type=Path,
Expand Down
131 changes: 6 additions & 125 deletions pipelines/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from pathlib import Path
from typing import Any

import polars as pl

from pipelines.cli_utils import run_command


Expand Down Expand Up @@ -44,68 +42,6 @@ def get_available_reports(
]


def parse_and_validate_report_date(
    report_date: str,
    available_facility_level_reports: list[dt.date],
    available_loc_level_reports: list[dt.date],
    logger: logging.Logger,
) -> tuple[dt.date, dt.date | None]:
    """
    Parse and validate report date, determine location-level report date to use.

    Parameters
    ----------
    report_date : str
        Report date as string ("latest" or "YYYY-MM-DD" format).
    available_facility_level_reports : list[dt.date]
        List of available facility-level report dates.
    available_loc_level_reports : list[dt.date]
        List of available location-level report dates.
    logger : logging.Logger
        Process logger.

    Returns
    -------
    tuple[dt.date, dt.date | None]
        Tuple of (report_date, loc_report_date).

    Raises
    ------
    ValueError
        If report date is invalid or data is missing.
    """
    # Bounds of the location-level vintages. NOTE(review): min()/max() raise
    # ValueError on an empty list — assumes the caller guarantees at least
    # one location-level report.
    first_available_loc_report = min(available_loc_level_reports)
    last_available_loc_report = max(available_loc_level_reports)

    # "latest" resolves against the *facility-level* vintages; any other
    # value must parse as an ISO date (strptime raises ValueError otherwise).
    if report_date == "latest":
        report_date = max(available_facility_level_reports)
    else:
        report_date = dt.datetime.strptime(report_date, "%Y-%m-%d").date()

    if report_date in available_loc_level_reports:
        # An exact location-level vintage exists for the requested date.
        loc_report_date = report_date
    elif report_date > last_available_loc_report:
        # Requested date is newer than any location-level vintage: fall back
        # to the most recent one available.
        loc_report_date = last_available_loc_report
    elif report_date > first_available_loc_report:
        # The date falls inside the covered range but has no vintage: the
        # dataset has a gap, treated as an error rather than interpolated.
        raise ValueError(
            "Dataset appear to be missing some state-level "
            f"reports. First entry is {first_available_loc_report}, "
            f"last is {last_available_loc_report}, but no entry "
            f"for {report_date}"
        )
    else:
        raise ValueError(
            "Requested report date is earlier than the first "
            "state-level vintage. This is not currently supported"
        )

    logger.info(f"Report date: {report_date}")
    # loc_report_date is always assigned on the non-raising paths above;
    # this guard is purely defensive.
    if loc_report_date is not None:
        logger.info(f"Using location-level data as of: {loc_report_date}")

    return report_date, loc_report_date


def _parse_single_date(date_str: str) -> tuple[dt.date, dt.date]:
"""
Parse a single date string into a date range tuple.
Expand Down Expand Up @@ -258,67 +194,6 @@ def calculate_training_dates(
return first_training_date, last_training_date


def load_nssp_data(
    report_date: dt.date,
    loc_report_date: dt.date | None,
    available_facility_level_reports: list[dt.date],
    available_loc_level_reports: list[dt.date],
    facility_level_nssp_data_dir: Path,
    state_level_nssp_data_dir: Path,
    logger: logging.Logger,
) -> tuple[pl.LazyFrame | None, pl.LazyFrame | None]:
    """
    Load facility-level and location-level NSSP data.

    Either frame may be ``None`` when the corresponding vintage is missing;
    an error is raised only when *both* are missing.

    Parameters
    ----------
    report_date : dt.date
        The report date.
    loc_report_date : dt.date | None
        The location-level report date to use.
    available_facility_level_reports : list[dt.date]
        List of available facility-level report dates.
    available_loc_level_reports : list[dt.date]
        List of available location-level report dates.
    facility_level_nssp_data_dir : Path
        Directory containing facility-level NSSP data.
    state_level_nssp_data_dir : Path
        Directory containing state-level NSSP data.
    logger : logging.Logger
        Logger for informational messages.

    Returns
    -------
    tuple[pl.LazyFrame | None, pl.LazyFrame | None]
        Tuple of (facility_level_nssp_data, loc_level_nssp_data).

    Raises
    ------
    ValueError
        If no data is available for the requested report date.
    """
    facility_level_nssp_data, loc_level_nssp_data = None, None

    if report_date in available_facility_level_reports:
        logger.info("Facility level data available for the given report date")
        # Vintage files are named "<YYYY-MM-DD>.parquet" by report date.
        facility_datafile = f"{report_date}.parquet"
        # scan_parquet is lazy: nothing is read until the frame is collected.
        facility_level_nssp_data = pl.scan_parquet(
            Path(facility_level_nssp_data_dir, facility_datafile)
        )
    if loc_report_date in available_loc_level_reports:
        logger.info("location-level data available for the given report date.")
        loc_datafile = f"{loc_report_date}.parquet"
        loc_level_nssp_data = pl.scan_parquet(
            Path(state_level_nssp_data_dir, loc_datafile)
        )
    if facility_level_nssp_data is None and loc_level_nssp_data is None:
        # Either source alone is acceptable; fail only when neither exists.
        raise ValueError(
            f"No data available for the requested report date {report_date}"
        )

    return facility_level_nssp_data, loc_level_nssp_data


def run_r_script(
script_name: str,
args: list[str],
Expand Down Expand Up @@ -489,6 +364,12 @@ def plot_and_save_loc_forecast(
return None


def py_scalar_to_r_scalar(py_scalar):
    """
    Render a Python scalar as an R literal for interpolation into R code.

    Parameters
    ----------
    py_scalar
        Any value with a meaningful ``str()`` representation, or ``None``.

    Returns
    -------
    str
        ``"NULL"`` for ``None``; otherwise the value's string form wrapped
        in single quotes, with backslashes and embedded single quotes
        backslash-escaped so the result is a valid R character constant.
    """
    if py_scalar is None:
        return "NULL"
    # Escape backslashes first, then quotes, so already-escaped characters
    # are not double-processed. Without this, a value containing "'" would
    # produce syntactically invalid R.
    escaped = str(py_scalar).replace("\\", "\\\\").replace("'", "\\'")
    return f"'{escaped}'"


def create_hubverse_table(model_fit_path: Path) -> None:
"""Create hubverse table from model fit using R script.

Expand Down
Loading