diff --git a/.gitignore b/.gitignore index 4a5c204..eb35b02 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,5 @@ tests/__pycache__/ tests/data/ tests/*.txt docs/examples/*/forcings +docs/examples/*/restart *.txt \ No newline at end of file diff --git a/README.md b/README.md index 19035cd..08e24c0 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,9 @@ where the list of `nwm-id`s are the NHD reaches associated with that NextGen hyd | nwm_file | Path to a text file containing nwm file names. One filename per line. [Tool](#nwm_file) to create this file | :white_check_mark: | | gpkg_file | Geopackage file to define spatial domain. Use [hfsubset](https://github.com/lynker-spatial/hfsubsetCLI) to generate a geopackage with a `forcing-weights` layer. Accepts local absolute path, s3 URI or URL. Also acceptable is a weights parquet generated with [weights_hf2ds.py](https://github.com/CIROH-UA/forcingprocessor/blob/main/src/forcingprocessor/weights_hf2ds.py), though the plotting option will no longer be available. | :white_check_mark: | | map_file | Path to a json containing the NWM to NGEN mapping for channel routing data extraction. Absolute path or s3 URI | | +| restart_map_file | Path to a json containing the NWM to NGEN catchment mapping for t-route restart generation. Absolute path or s3 URI | | +| crosswalk_file | Path to a netCDF containing the exact order of the catchments in the t-route restart file. Absolute path or s3 URI | | +| routelink_file | Path to a netCDF containing the NWM channel geometry data, needed for t-route restart generation. Absolute path or s3 URI | | ### 2. 
Storage diff --git a/docs/examples/troute-restart_example/RouteLink_CONUS_subset.nc b/docs/examples/troute-restart_example/RouteLink_CONUS_subset.nc new file mode 100644 index 0000000..9bc0188 Binary files /dev/null and b/docs/examples/troute-restart_example/RouteLink_CONUS_subset.nc differ diff --git a/docs/examples/troute-restart_example/channel_restart_20250718_000000.nc b/docs/examples/troute-restart_example/channel_restart_20250718_000000.nc new file mode 100644 index 0000000..96916ea Binary files /dev/null and b/docs/examples/troute-restart_example/channel_restart_20250718_000000.nc differ diff --git a/docs/examples/troute-restart_example/conf.json b/docs/examples/troute-restart_example/conf.json new file mode 100644 index 0000000..5640738 --- /dev/null +++ b/docs/examples/troute-restart_example/conf.json @@ -0,0 +1,20 @@ +{ + "forcing" : { + "nwm_file" : "filenamelist.txt", + "restart_map_file" : "hf2.2_subset_cat_map.json", + "crosswalk_file" : "crosswalk_subset.nc", + "routelink_file" : "RouteLink_CONUS_subset.nc" + }, + + "storage":{ + "storage_type" : "local", + "output_path" : "./restart", + "output_file_type" : ["netcdf"] + }, + + "run" : { + "verbose" : true, + "collect_stats" : true, + "nprocs" : 8 + } +} \ No newline at end of file diff --git a/docs/examples/troute-restart_example/crosswalk_subset.nc b/docs/examples/troute-restart_example/crosswalk_subset.nc new file mode 100644 index 0000000..f4788a2 Binary files /dev/null and b/docs/examples/troute-restart_example/crosswalk_subset.nc differ diff --git a/docs/examples/troute-restart_example/filenamelist.txt b/docs/examples/troute-restart_example/filenamelist.txt new file mode 100644 index 0000000..8394e8a --- /dev/null +++ b/docs/examples/troute-restart_example/filenamelist.txt @@ -0,0 +1 @@ +https://noaa-nwm-pds.s3.amazonaws.com/nwm.20250718/analysis_assim/nwm.t00z.analysis_assim.channel_rt.tm00.conus.nc \ No newline at end of file diff --git 
a/docs/examples/troute-restart_example/hf2.2_subset_cat_map.json b/docs/examples/troute-restart_example/hf2.2_subset_cat_map.json new file mode 100644 index 0000000..1e2e74b --- /dev/null +++ b/docs/examples/troute-restart_example/hf2.2_subset_cat_map.json @@ -0,0 +1,18 @@ +{ + "cat-20469": [ + 166196253.0, + 166196257.0, + 166196256.0 + ], + "cat-20479": [ + 4599061.0, + 4599715.0, + 166196261.0 + ], + "cat-19499": [ + 25068048.0 + ], + "cat-19630": [ + 25077314.0 + ] +} \ No newline at end of file diff --git a/src/forcingprocessor/processor.py b/src/forcingprocessor/processor.py index 0bb21a2..5ec9c72 100644 --- a/src/forcingprocessor/processor.py +++ b/src/forcingprocessor/processor.py @@ -20,6 +20,7 @@ from forcingprocessor.plot_forcings import plot_ngen_forcings from forcingprocessor.utils import make_forcing_netcdf, get_window, log_time, convert_url2key, report_usage, nwm_variables, ngen_variables from forcingprocessor.channel_routing_tools import channelrouting_nwm2ngen, write_netcdf_chrt +from forcingprocessor.troute_restart_tools import create_restart, write_netcdf_restart B2MB = 1048576 @@ -790,6 +791,9 @@ def prep_ngen_data(conf): else: gpkg_files = gpkg_file map_file_path = conf['forcing'].get("map_file",None) + restart_map_file_path = conf['forcing'].get("restart_map_file", None) + crosswalk_file_path = conf['forcing'].get("crosswalk_file", None) + routelink_file_path = conf['forcing'].get("routelink_file", None) if map_file_path: # NWM to NGEN channel routing processing requires json map data_source = "channel_routing" if "s3://" in map_file_path: @@ -799,6 +803,30 @@ def prep_ngen_data(conf): else: with open(map_file_path, "r", encoding="utf-8") as map_file: full_nwm_ngen_map = json.load(map_file) + elif restart_map_file_path: + data_source = "troute_restarts" + + if "s3://" in restart_map_file_path: + s3 = s3fs.S3FileSystem(anon=True) + with s3.open(restart_map_file_path, "r") as map_file: + cat_map = json.load(map_file) + else: + with 
open(restart_map_file_path, "r", encoding="utf-8") as map_file: + cat_map = json.load(map_file) + + if "s3://" in crosswalk_file_path: + s3 = s3fs.S3FileSystem(anon=True) + with s3.open(crosswalk_file_path, "rb") as crosswalk_file: + crosswalk_ds = xr.open_dataset(crosswalk_file) + else: + crosswalk_ds = xr.open_dataset(crosswalk_file_path) + + if "s3://" in routelink_file_path: + s3 = s3fs.S3FileSystem(anon=True) + with s3.open(routelink_file_path, "rb") as routelink_file: + routelink_ds = xr.open_dataset(routelink_file) + else: + routelink_ds = xr.open_dataset(routelink_file_path) else: data_source = "forcings" @@ -815,8 +843,8 @@ def prep_ngen_data(conf): global ii_plot, nts_plot, ngen_vars_plot ii_plot = conf.get("plot",False) if ii_plot: - if data_source == "channel_routing": - raise RuntimeError("Plotting not supported for channel routing processing.") + if data_source == "channel_routing" or data_source == "troute_restarts": + raise RuntimeError("Plotting not supported for channel routing or restart processing.") nts_plot = conf["plot"].get("nts_plot",10) ngen_vars_plot = conf["plot"].get("ngen_vars",ngen_variables) @@ -880,7 +908,7 @@ def prep_ngen_data(conf): window = [x_max, x_min, y_max, y_min] weight_time = time.perf_counter() - tw log_time("CALC_WINDOW_END", log_file) - else: + elif data_source == "channel_routing": log_time("READMAP_START", log_file) tw = time.perf_counter() if ii_verbose: @@ -893,6 +921,8 @@ def prep_ngen_data(conf): nwm_ngen_map[jcatch] = full_nwm_ngen_map[jcatch] ncatchments = len(nwm_ngen_map) log_time("READMAP_END", log_file) + else: + ncatchments = 1 log_time("STORE_METADATA_START", log_file) global forcing_path @@ -903,6 +933,8 @@ def prep_ngen_data(conf): output_path = Path(output_path) if data_source == "channel_routing": forcing_path = Path(output_path, 'outputs', 'ngen') + elif data_source == "troute_restarts": + forcing_path = Path(output_path, 'restart') else: forcing_path = Path(output_path, 'forcings') meta_path = 
Path(output_path, 'metadata') @@ -950,8 +982,12 @@ def prep_ngen_data(conf): # s3://noaa-nwm-pds/nwm.20241029/forcing_short_range/nwm.t00z.short_range.forcing.f001.conus.nc if data_source == "forcings": pattern = r"nwm\.(\d{8})/forcing_(\w+)/nwm\.(\w+)(\d{2})z\.\w+\.forcing\.(\w+)(\d{2})\.conus\.nc" - else: + elif data_source == "channel_routing": pattern = r"nwm\.(\d{8})/(\w+)/nwm\.(\w+)(\d{2})z\.\w+\.channel_rt[^\.]*\.(\w+)(\d{2})\.conus\.nc" + else: + #s3://noaa-nwm-pds/nwm.20241029/analysis_assim/nwm.t16z.analysis_assim.channel_rt.tm00.conus.nc + pattern = r"nwm\.(\d{8})/analysis_assim/nwm\.t(\d{2})z\.analysis_assim\.channel_rt\.tm00\.conus\.nc" + pass # Extract forecast cycle and lead time from the first and last file names global URLBASE, FCST_CYCLE, LEAD_START, LEAD_END @@ -959,17 +995,24 @@ def prep_ngen_data(conf): FCST_CYCLE=None LEAD_START=None LEAD_END=None - if match: - URLBASE = match.group(2) - FCST_CYCLE = match.group(3) + match.group(4) - LEAD_START = match.group(5) + match.group(6) - else: - print(f"Could not extract forecast cycle and lead start from the first NWM forcing file: {nwm_forcing_files[0]}") - match = re.search(pattern, nwm_forcing_files[-1]) - if match: - LEAD_END = match.group(5) + match.group(6) + if data_source != "troute_restarts": + if match: + URLBASE = match.group(2) + FCST_CYCLE = match.group(3) + match.group(4) + LEAD_START = match.group(5) + match.group(6) + else: + print(f"Could not extract forecast cycle and lead start from the first NWM forcing file: {nwm_forcing_files[0]}") + match = re.search(pattern, nwm_forcing_files[-1]) + if match: + LEAD_END = match.group(5) + match.group(6) + else: + print(f"Could not extract lead end from the last NWM forcing file: {nwm_forcing_files[-1]}") else: - print(f"Could not extract lead end from the last NWM forcing file: {nwm_forcing_files[-1]}") + if match: + restart_date = match.group(1) + restart_hour = match.group(2) + else: + print("Could not extract restart date and time") # 
Determine the file system type based on the first NWM forcing file global fs_type @@ -998,24 +1041,57 @@ def prep_ngen_data(conf): # data_array=data_array[0][None,:] # t_ax = t_ax # nwm_data=nwm_data[0][None,:] - if data_source == "forcings": - data_array, t_ax, nwm_data, nwm_file_sizes_MB = multiprocess_data_extract(nwm_forcing_files,nprocs,weights_df,fs) + if data_source == "forcings" or data_source == "channel_routing": + if data_source == "forcings": + data_array, t_ax, nwm_data, nwm_file_sizes_MB = multiprocess_data_extract(nwm_forcing_files,nprocs,weights_df,fs) + else: + data_array, t_ax, nwm_file_sizes_MB = multiprocess_chrt_extract( + nwm_forcing_files,nprocs,nwm_ngen_map,fs) + + if datetime.strptime(t_ax[0],'%Y-%m-%d %H:%M:%S') > datetime.strptime(t_ax[-1],'%Y-%m-%d %H:%M:%S'): + # Hack to ensure data is always written out with time moving forward. + t_ax=list(reversed(t_ax)) + data_array = np.flip(data_array,axis=0) + tmp = LEAD_START + LEAD_START = LEAD_END + LEAD_END = tmp + + t_extract = time.perf_counter() - t0 + complexity = (nfiles * ncatchments) / 10000 + score = complexity / t_extract + if ii_verbose: print(f'Data extract processs: {nprocs:.2f}\nExtract time: {t_extract:.2f}\nComplexity: {complexity:.2f}\nScore: {score:.2f}\n', end=None,flush=True) + else: - data_array, t_ax, nwm_file_sizes_MB = multiprocess_chrt_extract( - nwm_forcing_files,nprocs,nwm_ngen_map,fs) - - if datetime.strptime(t_ax[0],'%Y-%m-%d %H:%M:%S') > datetime.strptime(t_ax[-1],'%Y-%m-%d %H:%M:%S'): - # Hack to ensure data is always written out with time moving forward. 
- t_ax=list(reversed(t_ax)) - data_array = np.flip(data_array,axis=0) - tmp = LEAD_START - LEAD_START = LEAD_END - LEAD_END = tmp - - t_extract = time.perf_counter() - t0 - complexity = (nfiles * ncatchments) / 10000 - score = complexity / t_extract - if ii_verbose: print(f'Data extract processs: {nprocs:.2f}\nExtract time: {t_extract:.2f}\nComplexity: {complexity:.2f}\nScore: {score:.2f}\n', end=None,flush=True) + nwm_file = nwm_forcing_files[0] + nwm_file_sizes_MB = [] + if fs_type == 'google': + fs_arg = gcsfs.GCSFileSystem() + elif fs_type == 's3': + fs_arg = s3fs.S3FileSystem(anon=True) + else: + fs_arg = None + if fs_arg: + if nwm_file.find('https://') >= 0: + _, bucket_key = convert_url2key(nwm_file,fs_type) + else: + bucket_key = nwm_file + file_obj = fs_arg.open(bucket_key, mode='rb') + nwm_file_sizes_MB.append(file_obj.details['size']) + elif 'https://' in nwm_file: + response = requests.get(nwm_file, timeout=10) + + if response.status_code == 200: + file_obj = BytesIO(response.content) + else: + raise RuntimeError(f"{nwm_file} does not exist") + nwm_file_sizes_MB.append(len(response.content) / B2MB) + else: + file_obj = nwm_file + nwm_file_sizes_MB.append(os.path.getsize(nwm_file) / B2MB) + + with xr.open_dataset(file_obj) as nwm_ds: + data_array = create_restart(cat_map, crosswalk_ds, nwm_ds, routelink_ds) + log_time("PROCESSING_END", log_file) log_time("FILEWRITING_START", log_file) @@ -1023,22 +1099,30 @@ def prep_ngen_data(conf): if "netcdf" in output_file_type: if data_source == "forcings": netcdf_cat_file_sizes_MB = multiprocess_write_netcdf(data_array, jcatchment_dict, t_ax) - else: + elif data_source == "channel_routing": if FCST_CYCLE is None: filename = 'qlaterals.nc' else: filename = f'ngen.{FCST_CYCLE}z.{URLBASE}.channel_routing.{LEAD_START}_{LEAD_END}.nc' netcdf_cat_file_sizes_MB = write_netcdf_chrt( storage_type, forcing_path, data_array, t_ax, filename) + else: + filename = "channel_restart_" + restart_date + "_" + restart_hour + "0000.nc" 
+ netcdf_cat_file_sizes_MB = write_netcdf_restart( + storage_type, forcing_path, data_array, filename + ) # write_netcdf(data_array,"1", t_ax, jcatchment_dict['1']) if ii_verbose: print(f'Writing catchment forcings to {output_path}!', end=None,flush=True) if ii_plot or ii_collect_stats or any(x in output_file_type for x in ["csv","parquet","tar"]): if data_source == "forcings": forcing_cat_ids, filenames, individual_cat_file_sizes_MB, individual_cat_file_sizes_MB_zipped, tar_buffs = multiprocess_write_df( data_array,t_ax,list(weights_df.index),nprocs,forcing_path,data_source) - else: + elif data_source == "channel_routing": forcing_cat_ids, filenames, individual_cat_file_sizes_MB, individual_cat_file_sizes_MB_zipped, tar_buffs = multiprocess_write_df( data_array,t_ax,list(nwm_ngen_map.keys()),nprocs,forcing_path,data_source) + else: + print("Dataframes don't get written for t-route restarts") + write_time += time.perf_counter() - t0 write_rate = ncatchments / write_time if ii_verbose: print(f'\n\nWrite processs: {nprocs}\nWrite time: {write_time:.2f}\nWrite rate {write_rate:.2f} files/second\n', end=None,flush=True) @@ -1118,7 +1202,7 @@ def prep_ngen_data(conf): "netcdf_catch_file_size_med_MB" : [netcdf_catch_file_size_med], "netcdf_catch_file_size_std_MB" : [netcdf_catch_file_size_std] } - else: + elif data_source == "channel_routing": metadata = { "runtime_s" : [round(runtime,2)], "nvars_intput" : [1], @@ -1138,6 +1222,14 @@ def prep_ngen_data(conf): "netcdf_catch_file_size_med_MB" : [netcdf_catch_file_size_med], "netcdf_catch_file_size_std_MB" : [netcdf_catch_file_size_std] } + else: + # metadata for troute restart gen + metadata = { + "runtime_s" : [round(runtime,2)], + "nwmfiles_input" : [len(nwm_forcing_files)], + "nwm_file_size" : [nwm_file_size_avg], + "netcdf_catch_file_size_MB" : [netcdf_catch_file_size_avg], + } if data_source == "forcings": data_avg = np.average(data_array,axis=0) @@ -1147,7 +1239,7 @@ def prep_ngen_data(conf): data_med = 
np.median(data_array,axis=0) med_df = pd.DataFrame(data_med.T,columns=ngen_variables) med_df.insert(0,"catchment id",forcing_cat_ids) - else: + elif data_source == "channel_routing": data_avg = np.average(data_array[:,:,1],axis=0) avg_df = pd.DataFrame(data_avg.T, columns=['q_lateral']) avg_df.insert(0,"nexus id",list(nwm_ngen_map.keys())) @@ -1155,6 +1247,10 @@ def prep_ngen_data(conf): data_med = np.median(data_array[:,:,1],axis=0) med_df = pd.DataFrame(data_med.T,columns=['q_lateral']) med_df.insert(0,"nexus id",list(nwm_ngen_map.keys())) + else: + # troute restarts won't need stats calculated for them since there's no time axis + avg_df = pd.DataFrame() + med_df = pd.DataFrame() del data_array @@ -1171,8 +1267,10 @@ def prep_ngen_data(conf): local_metapath = metaf_path write_df(metadata_df, "metadata.csv", storage_type, data_source_arg="na", local_path=local_metapath, key_prefix=meta_key, bucket=meta_bucket, client=s3) - write_df(avg_df, "catchments_avg.csv", storage_type, data_source_arg="na", local_path=local_metapath, key_prefix=meta_key, bucket=meta_bucket, client=s3) - write_df(med_df, "catchments_median.csv", storage_type, data_source_arg="na", local_path=local_metapath, key_prefix=meta_key, bucket=meta_bucket, client=s3) + if not avg_df.empty: + write_df(avg_df, "catchments_avg.csv", storage_type, data_source_arg="na", local_path=local_metapath, key_prefix=meta_key, bucket=meta_bucket, client=s3) + if not med_df.empty: + write_df(med_df, "catchments_median.csv", storage_type, data_source_arg="na", local_path=local_metapath, key_prefix=meta_key, bucket=meta_bucket, client=s3) meta_time = time.perf_counter() - t000 log_time("METADATA_END", log_file) diff --git a/src/forcingprocessor/troute_restart_tools.py b/src/forcingprocessor/troute_restart_tools.py new file mode 100644 index 0000000..922f5d0 --- /dev/null +++ b/src/forcingprocessor/troute_restart_tools.py @@ -0,0 +1,257 @@ +""" +Tools to extract and write streamflow and depth values into a restart 
"""
Tools to extract and write streamflow and depth values into a restart format ingestible by
t-route. Translates between NWM and NGEN IDs!
"""

from pathlib import Path
import tempfile
import os
import xarray as xr
import numpy as np
import pandas as pd
import boto3
from forcingprocessor.utils import convert_url2key

B2MB = 1048576  # 1024 * 1024; divisor used to report file sizes in MB


def average_nwm_variables(
    nwm_ids_flat: np.ndarray, cat_ids_flat: np.ndarray, nwm_ds: xr.Dataset
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Vectorized averaging calculations for NWM data

    Parameters:
    - nwm_ids_flat: array of NWM ids (np.ndarray)
    - cat_ids_flat: array of NextGen cat-ids, aligned element-by-element with
      nwm_ids_flat (np.ndarray)
    - nwm_ds: NWM analysis/assimilation data; read through its "feature_id",
      "streamflow" and "velocity" variables (xr.Dataset)

    Returns:
    - nwm_agg: mean streamflow and velocity per cat_id, indexed by cat_id (pd.DataFrame)
    - mapping_df: DataFrame version of flat maps with columns "feature_id" and
      "cat_id" (pd.DataFrame)
    """

    # --- NWM dataset: streamflow and velocity ---
    # Keep only the feature_ids that appear in the NWM -> NextGen map
    wanted = np.isin(nwm_ds["feature_id"].values, nwm_ids_flat)
    nwm_sub = nwm_ds.isel(feature_id=wanted)

    # feature_id -> cat_id lookup table, also handed back to the caller
    mapping_df = pd.DataFrame({"feature_id": nwm_ids_flat, "cat_id": cat_ids_flat})

    nwm_df = pd.DataFrame(
        {
            "feature_id": nwm_sub["feature_id"].values,
            "streamflow": nwm_sub["streamflow"].values,
            "velocity": nwm_sub["velocity"].values,
        }
    )

    # Attach the cat_id to every NWM reach, then average the reaches of each catchment
    nwm_df = nwm_df.merge(mapping_df, on="feature_id", how="left")
    nwm_agg = nwm_df.groupby("cat_id")[["streamflow", "velocity"]].mean()

    return nwm_agg, mapping_df


def average_rtlink_variables(
    nwm_ids_flat: np.ndarray, mapping_df: pd.DataFrame, routelink_ds: xr.Dataset
) -> pd.DataFrame:
    """
    Vectorized averaging calculations for RouteLink data

    Parameters
    - nwm_ids_flat: array of NWM ids (np.ndarray)
    - mapping_df: dataframe of flat map with columns "feature_id" and "cat_id" (pd.DataFrame)
    - routelink_ds: NWM RouteLink data; read through its "link", "TopWdth", "BtmWdth"
      and "ChSlp" variables, and subset along a dimension named "feature_id" (xr.Dataset)

    Returns:
    - rl_agg: mean TopWdth, BtmWdth and ChSlp per cat_id, indexed by cat_id (pd.DataFrame)
    """

    # --- Routelink dataset: TopWdth, BtmWdth, ChSlp ---
    # RouteLink stores the NWM reach id in "link"
    wanted = np.isin(routelink_ds["link"].values, nwm_ids_flat)
    rl_sub = routelink_ds.isel(feature_id=wanted)

    rl_df = pd.DataFrame(
        {
            "feature_id": rl_sub["link"].values,
            "TopWdth": rl_sub["TopWdth"].values,
            "BtmWdth": rl_sub["BtmWdth"].values,
            "ChSlp": rl_sub["ChSlp"].values,
        }
    )

    rl_df = rl_df.merge(mapping_df, on="feature_id", how="left")
    rl_agg = rl_df.groupby("cat_id")[["TopWdth", "BtmWdth", "ChSlp"]].mean()

    return rl_agg


def quadratic_formula(b_coeff: np.ndarray, c_coeff: np.ndarray) -> np.ndarray:
    """
    Vectorized quadratic formula solver for h^2 + b*h + c = 0 (leading coefficient
    fixed at 1). Only returns the "+ sqrt" root.

    Parameters:
    - b_coeff: linear coefficients (np.ndarray)
    - c_coeff: constant coefficients (np.ndarray)

    Returns:
    - h_positive: (-b + sqrt(b^2 - 4c)) / 2, element-wise (np.ndarray)
    """
    discriminant = b_coeff**2 - 4 * c_coeff
    h_positive = (-b_coeff + np.sqrt(discriminant)) / 2

    return h_positive


def solve_depth_geom(
    streamflow: np.ndarray,
    velocity: np.ndarray,
    tw: np.ndarray,
    bw: np.ndarray,
    cs: np.ndarray,
) -> np.ndarray:
    """
    Solves for depth h using CHRTOUT file variables and channel geometry variables.

    The wetted area is streamflow / velocity. Below bankfull, h solves
    h^2 + cs*bw*h - cs*area = 0. At or above bankfull, the area in excess of the
    bankfull area is spread over a width of 3 * tw and added to the bankfull depth.

    Parameters:
    - streamflow: Streamflow from CHRTOUT file. (m^3/s)
    - velocity: Velocity from CHRTOUT file. (m/s)
    - tw: Top width of the main channel. (m)
    - bw: Bottom width of the main channel. (m)
    - cs: Channel slope (dimensionless).

    Returns:
    - depths: depth for every element of the inputs. Elements whose area is zero,
      NaN or infinite (no flow, zero velocity, missing data) get a depth of 0.
    """

    # 0/0 and x/0 are expected for dry or stagnant reaches; they are zeroed right below,
    # so the accompanying RuntimeWarnings are noise.
    with np.errstate(divide="ignore", invalid="ignore"):
        area = streamflow / velocity  # cross-sectional area of initial flow
    area = np.where(np.isfinite(area), area, 0)  # NaN / inf areas count as no flow

    db = (cs * (tw - bw)) / 2  # bankfull depth
    area_bankfull = (tw + bw) / 2 * db  # cross-sectional area at bankfull conditions
    # assume trapezoidal main channel

    depths = np.zeros_like(area)  # reaches that match neither branch keep depth 0

    # Only reaches that actually carry flow are solved. Without the `wet` guard a reach
    # with zero area and zero bankfull area (e.g. geometry zero-filled for an unmapped
    # catchment) would be treated as "above bankfull" and evaluate 0 / (0 * 3) = NaN.
    wet = area > 0

    above_bankfull = wet & (area >= area_bankfull)
    area_flood = area[above_bankfull] - area_bankfull[above_bankfull]
    df = area_flood / (tw[above_bankfull] * 3)  # depth added on top of bankfull
    depths[above_bankfull] = db[above_bankfull] + df

    # Below bankfull - solve quadratic formula directly (vectorized)
    below_bankfull = wet & ~above_bankfull

    # Quadratic: h^2 + cs*bw*h - cs*area = 0, i.e. b = cs*bw and c = -cs*area
    h_positive = quadratic_formula(
        cs[below_bankfull] * bw[below_bankfull],
        -cs[below_bankfull] * area[below_bankfull],
    )
    depths[below_bankfull] = h_positive

    return depths


def create_restart(
    cat_map_temp: dict,
    crosswalk_ds: xr.Dataset,
    nwm_ds: xr.Dataset,
    routelink_ds: xr.Dataset,
) -> xr.Dataset:
    """
    Creates t-route restart file.

    Parameters:
    - cat_map_temp: NGEN to NWM catchment json file; keys look like "cat-<id>" and
      values are lists of NWM ids (dict)
    - crosswalk_ds: "crosswalk" NetCDF file that has all the
        NextGen catchments in the order that the restart file will have them in (xr.Dataset)
    - nwm_ds: NWM analysis/assimilation NetCDF (xr.Dataset)
    - routelink_ds: NWM RouteLink channel geometry NetCDF (xr.Dataset)

    Returns:
    - restart: t-route ingestible restart file with variables "hlink" (depth),
      "qlink1" and "qlink2" (both the averaged streamflow) along dimension "links",
      and a "Restart_Time" attribute taken from the first entry of nwm_ds["time"] (xr.Dataset)
    """
    # Drop the first four characters of every key (the "cat-" prefix)
    stripped_map = {key[4:]: ids for key, ids in cat_map_temp.items()}

    link_ids = crosswalk_ds["link"].values

    # One entry per crosswalk link, in crosswalk order; links without a mapping get []
    cat_map = {}
    for link_id in link_ids:
        ids = stripped_map.get(str(link_id))
        cat_map[link_id] = [] if ids is None else ids

    # Flatten {cat_id: [nwm_id, ...]} into two aligned arrays
    nwm_ids_flat = np.array(
        [nwm_id for ids in cat_map.values() for nwm_id in ids], dtype=float
    )
    cat_ids_flat = np.array([cat_id for cat_id, ids in cat_map.items() for _ in ids])

    nwm_agg, mapping_df = average_nwm_variables(nwm_ids_flat, cat_ids_flat, nwm_ds)
    rl_agg = average_rtlink_variables(nwm_ids_flat, mapping_df, routelink_ds)

    # Re-impose crosswalk order; links with no NWM / RouteLink data are zero-filled
    result_df = pd.DataFrame({"cat_id": link_ids})
    result_df = result_df.join(nwm_agg, on="cat_id").join(rl_agg, on="cat_id").fillna(0)

    # depth calculation
    result_df["depth"] = solve_depth_geom(
        streamflow=np.array(result_df["streamflow"].values),
        velocity=np.array(result_df["velocity"].values),
        tw=np.array(result_df["TopWdth"].values),
        bw=np.array(result_df["BtmWdth"].values),
        cs=np.array(result_df["ChSlp"].values),
    )

    restart_time = pd.Timestamp(nwm_ds["time"].values[0]).strftime("%Y-%m-%d_%H:%M:%S")

    # create netcdf
    restart = xr.Dataset(
        data_vars={
            "hlink": (["links"], result_df["depth"].values),
            "qlink1": (["links"], result_df["streamflow"].values),
            "qlink2": (["links"], result_df["streamflow"].values),
        },
        coords={"links": range(len(result_df))},
        attrs={"Restart_Time": restart_time},
    )

    return restart


def write_netcdf_restart(storage_type: str, prefix: Path, ds: xr.Dataset, name: str):
    """
    Write restart data to a NetCDF file.

    Parameters:
        storage_type (str): "s3" uploads to S3; any other value writes to the local filesystem
        prefix (Path): filename prefix (directory or S3 location the file is placed under)
        ds (xr.Dataset): restart file
        name (str): string for the filename
    Returns:
        netcdf_cat_file_size (list): single-element list with the file size of the
            output netcdf in MB
    """
    if storage_type == "s3":
        s3_client = boto3.session.Session().client("s3")
        nc_filename = str(prefix) + "/" + name
        bucket, key = convert_url2key(nc_filename, "s3")
        # Write to a temporary local file first, then upload that file by name
        with tempfile.NamedTemporaryFile(suffix=".nc") as tmpfile:
            ds.to_netcdf(tmpfile.name, engine="netcdf4")
            netcdf_cat_file_size = os.path.getsize(tmpfile.name) / B2MB
            print(f"Uploading netcdf forcings to S3: bucket={bucket}, key={key}")
            s3_client.upload_file(tmpfile.name, bucket, key)
    else:
        nc_filename = Path(prefix, name)
        ds.to_netcdf(nc_filename, engine="netcdf4")
        print(f"netcdf has been written to {nc_filename}")
        netcdf_cat_file_size = os.path.getsize(nc_filename) / B2MB
    return [netcdf_cat_file_size]
"""
Unit tests for restart utility functions.
"""

import os
import shutil
from pathlib import Path
from datetime import datetime, timedelta, timezone
import re
import numpy as np
import pandas as pd
import pytest
import xarray as xr

from forcingprocessor.troute_restart_tools import (
    average_nwm_variables,
    average_rtlink_variables,
    create_restart,
    quadratic_formula,
    solve_depth_geom,
)
from forcingprocessor.processor import prep_ngen_data
from forcingprocessor.nwm_filenames_generator import generate_nwmfiles

# ---------------------------------------------------------------------------
# minimum viable examples
# ---------------------------------------------------------------------------

# Four NWM reaches with easy-to-average streamflow / velocity values
simple_nwm_ds = xr.Dataset(
    {
        "streamflow": ("feature_id", [10.0, 20.0, 30.0, 40.0]),
        "velocity": ("feature_id", [1.0, 2.0, 3.0, 4.0]),
    },
    coords={
        "feature_id": [101, 102, 103, 104],
        "time": [np.datetime64("2024-01-15T12:00:00.000000000")],
    },
)

# Channel geometry for the same four reaches
simple_routelink_ds = xr.Dataset(
    {
        "link": ("feature_id", [101, 102, 103, 104]),
        "TopWdth": ("feature_id", [10.0, 12.0, 14.0, 16.0]),
        "BtmWdth": ("feature_id", [5.0, 6.0, 7.0, 8.0]),
        "ChSlp": ("feature_id", [0.5, 0.5, 0.5, 0.5]),
    },
    coords={},
)

# Two NextGen catchments, in the order the restart file must follow
simple_crosswalk_ds = xr.Dataset(coords={"link": [1, 2]})

simple_cat_map = {
    "cat-1": [101.0, 102.0],
    "cat-2": [103.0, 104.0],
}


# ---------------------------------------------------------------------------
# unit tests
# ---------------------------------------------------------------------------


def test_averages_streamflow_and_velocity():
    # cat 1 -> features 101 (sf=10, v=1) and 102 (sf=20, v=2) => mean sf=15, v=1.5
    # cat 2 -> feature 103 (sf=30, v=3) => mean sf=30, v=3.0
    nwm_ids = np.array([101.0, 102.0, 103.0])
    cat_ids = np.array([1, 1, 2])
    agg, mapping = average_nwm_variables(nwm_ids, cat_ids, simple_nwm_ds)

    assert agg.loc[1, "streamflow"] == pytest.approx(15.0)  # test averaging
    assert agg.loc[1, "velocity"] == pytest.approx(1.5)
    assert agg.loc[2, "streamflow"] == pytest.approx(30.0)
    assert agg.loc[2, "velocity"] == pytest.approx(3.0)

    assert len(mapping) == 3  # test mapping
    assert set(mapping.columns) == {"feature_id", "cat_id"}

    assert 3 not in agg.index  # test subsetting
    assert len(agg) == 2


def test_averages_routelink():
    nwm_ids = np.array([101.0, 102.0, 103.0, 104.0])
    mapping = pd.DataFrame(
        {"feature_id": [101.0, 102.0, 103.0, 104.0], "cat_id": [1, 1, 2, 2]}
    )
    agg = average_rtlink_variables(nwm_ids, mapping, simple_routelink_ds)

    assert agg.loc[1, "TopWdth"] == pytest.approx(11.0)  # test averaging
    assert agg.loc[2, "TopWdth"] == pytest.approx(15.0)
    assert agg.loc[1, "BtmWdth"] == pytest.approx(5.5)
    assert agg.loc[2, "BtmWdth"] == pytest.approx(7.5)
    assert agg.loc[1, "ChSlp"] == pytest.approx(0.5)
    assert agg.loc[2, "ChSlp"] == pytest.approx(0.5)

    assert len(agg) == 2  # test layout


def test_quadratic_formula():
    # Two equations: [x^2+2x-3, x^2-4] -> roots [1, 2]
    b = np.array([2.0, 0.0])
    c = np.array([-3.0, -4.0])
    result = quadratic_formula(b, c)
    np.testing.assert_allclose(result, [1.0, 2.0])


def test_solve_depth_geom():
    # Covers: zero flow, zero velocity, two below-bankfull flows, one flood flow, NaN flow
    sf = np.array([0.0, 5.0, 2.0, 8.0, 1000.0, np.nan])
    v = np.array([1.0, 0.0, 1.0, 1.0, 1.0, 1.0])
    tw = np.array([10.0, 10.0, 10.0, 10.0, 10.0, 10.0])
    bw = np.array([5.0, 5.0, 5.0, 5.0, 5.0, 5.0])
    cs = np.array([0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
    depths = solve_depth_geom(sf, v, tw, bw, cs)

    assert depths.shape == (6,)
    assert depths[0] == pytest.approx(0.0)
    assert depths[1] == pytest.approx(0.0)
    assert depths[2] == pytest.approx(0.35078106)
    assert depths[3] == pytest.approx(1.10849528)
    assert depths[4] == pytest.approx(34.27083333)
    assert depths[5] == pytest.approx(0.0)


def test_restart():
    result = create_restart(
        simple_cat_map, simple_crosswalk_ds, simple_nwm_ds, simple_routelink_ds
    )

    n_links = len(simple_crosswalk_ds["link"])
    assert result["hlink"].shape == (n_links,)
    assert result["qlink1"].shape == (n_links,)
    assert result["qlink2"].shape == (n_links,)

    assert result.attrs["Restart_Time"] == "2024-01-15_12:00:00"

    assert result["hlink"].values[0] == pytest.approx(1.25)
    assert result["hlink"].values[1] == pytest.approx(1.04315438)
    assert (
        result["qlink1"].values[0] == result["qlink2"].values[0] == pytest.approx(15.0)
    )
    assert (
        result["qlink1"].values[1] == result["qlink2"].values[1] == pytest.approx(35.0)
    )


# ---------------------------------------------------------------------------
# test prep_ngen_conf
# ---------------------------------------------------------------------------

HF_VERSION="v2.2"
date = datetime.now(timezone.utc)
date = date.strftime('%Y%m%d')
HOURMINUTE = '0000'
TODAY_START = date + HOURMINUTE
yesterday = datetime.now(timezone.utc) - timedelta(hours=24)
yesterday = yesterday.strftime('%Y%m%d')
test_dir = Path(__file__).parent
data_dir = (test_dir/'data').resolve()
forcings_dir = (data_dir/'restart').resolve()
pwd = Path.cwd()
# Start every session from an empty tests/data directory. shutil / pathlib are used
# instead of shelling out to `rm -rf` / `mkdir` so paths containing spaces or shell
# metacharacters are handled correctly.
if os.path.exists(data_dir):
    shutil.rmtree(data_dir)
data_dir.mkdir(parents=True, exist_ok=True)
FILENAMELIST = str((pwd/"filenamelist.txt").resolve())

conf = {
    "forcing" : {
        "nwm_file" : FILENAMELIST,
        "restart_map_file" : f"{pwd}/docs/examples/troute-restart_example/hf2.2_subset_cat_map.json",
        "crosswalk_file" : f"{pwd}/docs/examples/troute-restart_example/crosswalk_subset.nc",
        "routelink_file" : f"{pwd}/docs/examples/troute-restart_example/RouteLink_CONUS_subset.nc"
    },

    "storage":{
        "storage_type" : "local",
        "output_path" : str(data_dir),
        "output_file_type" : ["netcdf"]
    },

    "run" : {
        "verbose" : False,
        "collect_stats" : False,
        "nprocs" : 1
    }
    }

nwmurl_conf = {
        "forcing_type" : "operational_archive",
        "start_date"   : "",
        "end_date"     : "",
        "runinput"     : 5,
        "varinput"     : 1,
        "geoinput"     : 1,
        "meminput"     : 0,
        "urlbaseinput" : 7,
        "fcst_cycle"   : [0],
        "lead_time"    : [0]
    }

# `autouse` must be passed to the decorator. Declared as a function parameter it only
# defined an unused argument, so the fixture was never requested and never ran.
@pytest.fixture(autouse=True)
def clean_dir():
    if os.path.exists(forcings_dir):
        shutil.rmtree(forcings_dir)

def test_nomads_prod():
    nwmurl_conf['start_date'] = TODAY_START
    nwmurl_conf['end_date']   = TODAY_START
    nwmurl_conf["urlbaseinput"] = 1
    generate_nwmfiles(nwmurl_conf)
    conf['run']['collect_stats'] = True # test metadata generation once
    try:
        prep_ngen_data(conf)
    finally:
        # restore the shared conf even when the run fails, so later tests are unaffected
        conf['run']['collect_stats'] = False
    assert_file=Path(data_dir / f"restart/channel_restart_{date}_{HOURMINUTE}00.nc").resolve()
    assert assert_file.exists()
    os.remove(assert_file)

def test_nwm_google_apis():
    nwmurl_conf['start_date'] = TODAY_START
    nwmurl_conf['end_date']   = TODAY_START
    nwmurl_conf["urlbaseinput"] = 3
    generate_nwmfiles(nwmurl_conf)
    prep_ngen_data(conf)
    assert_file=Path(data_dir / f"restart/channel_restart_{date}_{HOURMINUTE}00.nc").resolve()
    assert assert_file.exists()
    os.remove(assert_file)

def test_google_cloud_storage():
    nwmurl_conf['start_date'] = "202407100100"
    nwmurl_conf['end_date']   = "202407100100"
    nwmurl_conf["urlbaseinput"] = 4
    generate_nwmfiles(nwmurl_conf)
    prep_ngen_data(conf)
    assert_file=(data_dir/"restart/channel_restart_20240710_000000.nc").resolve()
    assert assert_file.exists()
    os.remove(assert_file)

def test_gs():
    nwmurl_conf['start_date'] = TODAY_START
    nwmurl_conf['end_date']   = TODAY_START
    nwmurl_conf["urlbaseinput"] = 5
    generate_nwmfiles(nwmurl_conf)
    assert_file=Path(data_dir / f"restart/channel_restart_{date}_{HOURMINUTE}00.nc").resolve()
    prep_ngen_data(conf)
    assert assert_file.exists()
    os.remove(assert_file)

def test_gcs():
    nwmurl_conf['start_date'] = "202407100100"
    nwmurl_conf['end_date']   = "202407100100"
    nwmurl_conf["urlbaseinput"] = 6
    generate_nwmfiles(nwmurl_conf)
    prep_ngen_data(conf)
    assert_file=(data_dir/"restart/channel_restart_20240710_000000.nc").resolve()
    assert assert_file.exists()
    os.remove(assert_file)

def test_noaa_nwm_pds_https():
    nwmurl_conf['start_date'] = TODAY_START
    nwmurl_conf['end_date']   = TODAY_START
    nwmurl_conf["urlbaseinput"] = 7
    generate_nwmfiles(nwmurl_conf)
    prep_ngen_data(conf)
    assert_file=Path(data_dir / f"restart/channel_restart_{date}_{HOURMINUTE}00.nc").resolve()
    assert assert_file.exists()
    os.remove(assert_file)

def test_noaa_nwm_pds_s3():
    nwmurl_conf['start_date'] = TODAY_START
    nwmurl_conf['end_date']   = TODAY_START
    nwmurl_conf["urlbaseinput"] = 8
    generate_nwmfiles(nwmurl_conf)
    prep_ngen_data(conf)
    assert_file=Path(data_dir / f"restart/channel_restart_{date}_{HOURMINUTE}00.nc").resolve()
    assert assert_file.exists()
    os.remove(assert_file)

def test_s3_output():
    # NOTE(review): start_date / end_date are whatever the previously executed test
    # left in nwmurl_conf, and nothing is asserted about the uploaded object.
    test_bucket = "ciroh-community-ngen-datastream"
    conf['storage']['output_path'] = f's3://{test_bucket}/test/cicd/forcingprocessor/pytest'
    nwmurl_conf["urlbaseinput"] = 4
    try:
        generate_nwmfiles(nwmurl_conf)
        prep_ngen_data(conf)
    finally:
        # always point the shared conf back at the local data dir
        conf['storage']['output_path'] = str(data_dir)

def test_netcdf_output_type():
    generate_nwmfiles(nwmurl_conf)
    conf['storage']['output_file_type'] = ["netcdf"]
    prep_ngen_data(conf)
    assert_file=Path(data_dir / f"restart/channel_restart_{date}_{HOURMINUTE}00.nc").resolve()
    assert assert_file.exists()
    os.remove(assert_file)