diff --git a/pyproject.toml b/pyproject.toml index 446219d..ac857b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,4 +99,4 @@ download_data = "stormworkflow.prep.download_data:cli" setup_ensemble = "stormworkflow.prep.setup_ensemble:cli" combine_ensemble = "stormworkflow.post.combine_ensemble:cli" analyze_ensemble = "stormworkflow.post.analyze_ensemble:cli" -storm_roc_curve = "stormworkflow.post.storm_roc_curve:cli" +storm_roc_ts_rel_curves = "stormworkflow.post.storm_roc_ts_rel_curves:cli" diff --git a/stormworkflow/post/storm_roc_curve.py b/stormworkflow/post/storm_roc_ts_rel_curves.py similarity index 50% rename from stormworkflow/post/storm_roc_curve.py rename to stormworkflow/post/storm_roc_ts_rel_curves.py index fe7bde5..fcee32b 100644 --- a/stormworkflow/post/storm_roc_curve.py +++ b/stormworkflow/post/storm_roc_ts_rel_curves.py @@ -9,8 +9,8 @@ import matplotlib.pyplot as plt from pathlib import Path from cartopy.feature import NaturalEarthFeature - -import geodatasets +from geodatasets import get_path +from sklearn.calibration import calibration_curve os.environ['USE_PYGEOS'] = '0' import geopandas as gpd @@ -124,18 +124,28 @@ def main(args): leadtime = args.leadtime obs_df_path = Path(args.obs_df_path) ensemble_dir = Path(args.ensemble_dir) + output_directory = args.output_dir + suffix = args.output_suffix + plot_prob_map = args.plot_prob_map + psurge_far_path = args.psurge_far_path + psurge_pod_path = args.psurge_pod_path - output_directory = ensemble_dir / 'analyze/linear_k1_p1_n0.025' - prob_nc_path = output_directory / 'probabilities.nc' + input_directory = ensemble_dir / 'analyze/linear_k1_p1_n0.025' + prob_nc_path = input_directory / 'probabilities.nc' + if output_directory is None: + output_directory = input_directory if leadtime == -1: leadtime = 48 + if suffix is None: + suffix = '' + # *.nc file coordinates - thresholds_ft = [3, 6, 9] # in ft + thresholds_ft = [3, 4, 5, 6, 9] # in ft thresholds_m = [round(i * 0.3048, 4) for i in thresholds_ft] # convert to meter sources = ['model', 'surrogate'] - probabilities = [0.0, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] + probabilities = list(np.linspace(0, 1, 101)) # attributes of input files prediction_variable = 'probabilities' @@ -145,9 +155,9 @@ def main(args): max_distance = 1000 # [in meters] to set distance_upper_bound max_neighbors = 10 # to set k + # creating arrays for ROC and TS plots blank_arr = np.empty((len(thresholds_ft), 1, 1, len(sources), len(probabilities))) blank_arr[:] = np.nan - hit_arr = blank_arr.copy() miss_arr = blank_arr.copy() false_alarm_arr = blank_arr.copy() @@ -157,23 +167,33 @@ def main(args): # Load obs file, extract storm obs points and coordinates df_obs = pd.read_csv(obs_df_path) - Event_name = f'{storm}_{year}' - df_obs_storm = df_obs[df_obs.Event == Event_name] + # Get first matching name (first landfall) + CSV_Event_name = df_obs[df_obs.Event.str.fullmatch(f'{storm}\w*_{year}')].Event.iloc[0] + NC_Event_name = CSV_Event_name[:CSV_Event_name.find('_')] + df_obs_storm = df_obs[df_obs.Event == CSV_Event_name] obs_coordinates = stack_station_coordinates( df_obs_storm.Longitude.values, df_obs_storm.Latitude.values ) + # creating arrays for reliability diagram + blank_arr = np.empty((len(thresholds_ft), 1, 1, len(sources), len(df_obs_storm))) + blank_arr[:] = np.nan + obs_true_arr = blank_arr.copy() + pred_prob_arr = blank_arr.copy() + + # Load Psurge results + ds_psurge = None + if not (psurge_far_path is None or psurge_pod_path is None): + ds_psurge = xr.open_mfdataset( + [psurge_far_path, psurge_pod_path], + combine='nested', + ) + # Load probabilities.nc file ds_prob = xr.open_dataset(prob_nc_path) - gdf_countries = gpd.read_file(geodatasets.get_path('naturalearth land')) - - # gdf_countries = gpd.GeoSeries( - # NaturalEarthFeature(category='physical', scale='10m', name='land',).geometries(), - # crs=4326, - # ) + gdf_countries = gpd.read_file(get_path('naturalearth land')) - # Loop through thresholds and sources and find corresponding values from probabilities.nc threshold_count = -1 for threshold in thresholds_m: threshold_count += 1 @@ -191,16 +211,23 @@ def main(args): df_obs_storm[f'{source}_prob'] = prediction_prob # Plot probabilities at obs. points - plot_probabilities( - df_obs_storm, - f'{source}_prob', - gdf_countries, - f'Probability of {source} exceeding {thresholds_ft[threshold_count]} ft \n {storm}, {year}, {leadtime}-hr leadtime', - os.path.join( - output_directory, - f'prob_{source}_above_{thresholds_ft[threshold_count]}ft_{storm}_{year}_{leadtime}-hr.png', - ), + if plot_prob_map: + plot_probabilities( + df_obs_storm, + f'{source}_prob', + gdf_countries, + f'Probability of {source} exceeding {thresholds_ft[threshold_count]} ft \n {storm}, {year}, {leadtime}-hr leadtime', + os.path.join( + output_directory, + f'prob_{source}_above_{thresholds_ft[threshold_count]}ft_{storm}_{year}_{leadtime}-hr_{suffix}.png', + ), + ) + + # Enter observed above threshold and prediction probability into array + obs_true_arr[threshold_count, 0, 0, source_count, :] = ( + df_obs_storm[obs_attribute] > threshold ) + pred_prob_arr[threshold_count, 0, 0, source_count, :] = prediction_prob # Loop through probabilities: calculate hit/miss/... & POD/FAR prob_count = -1 @@ -225,6 +252,7 @@ def main(args): leadtime=[leadtime], source=sources, prob=probabilities, + points=range(len(df_obs_storm)), ), data_vars=dict( hit=(['threshold', 'storm', 'leadtime', 'source', 'prob'], hit_arr), @@ -239,13 +267,19 @@ def main(args): ), POD=(['threshold', 'storm', 'leadtime', 'source', 'prob'], POD_arr), FAR=(['threshold', 'storm', 'leadtime', 'source', 'prob'], FAR_arr), + obs_true=(['threshold', 'storm', 'leadtime', 'source', 'points'], obs_true_arr), + pred_prob=(['threshold', 'storm', 'leadtime', 'source', 'points'], pred_prob_arr), ), ) - ds_ROC.to_netcdf(os.path.join(output_directory, f'{storm}_{year}_{leadtime}hr_POD_FAR.nc')) + ds_ROC.to_netcdf( + os.path.join( + output_directory, f'{storm}_{year}_{leadtime}hr_probabilistic_evaluation_stats_{suffix}.nc' + ) + ) # plot ROC curves - marker_list = ['s', 'x'] - linestyle_list = ['dashed', 'dotted'] + colormarker_list = ['bd', 'kx'] + linestyle_list = ['--', '-'] threshold_count = -1 for threshold in thresholds_ft: threshold_count += 1 @@ -257,23 +291,156 @@ def main(args): source_count = -1 for source in sources: source_count += 1 + AUC = abs( + np.trapz( + POD_arr[threshold_count, 0, 0, source_count, :], + x=FAR_arr[threshold_count, 0, 0, source_count, :], + ) + ) + label = f'{source}, AUC={AUC:.2f}' + colormarker = colormarker_list[source_count] + if source == 'surrogate': + best_res = ( + POD_arr[threshold_count, 0, 0, source_count, :] + - FAR_arr[threshold_count, 0, 0, source_count, :] + ).argmax() + label = f'{label}, x @ p({probabilities[best_res]:.2f})' + plt.plot( + FAR_arr[threshold_count, 0, 0, source_count, best_res], + POD_arr[threshold_count, 0, 0, source_count, best_res], + 'kX', + ) plt.plot( FAR_arr[threshold_count, 0, 0, source_count, :], POD_arr[threshold_count, 0, 0, source_count, :], + colormarker, + label=label, + linestyle=linestyle_list[source_count], + markersize=5, + ) + if ds_psurge is not None and threshold in ds_psurge.threshold: + psurge_far = ds_psurge.sel( + version='v3pt0Apr062023_kdtree', + threshold=threshold, + storm=NC_Event_name, + leadtime=leadtime).FARate.values + psurge_pod = ds_psurge.sel( + version='v3pt0Apr062023_kdtree', + threshold=threshold, + storm=NC_Event_name, + leadtime=leadtime).POD.values + AUC = abs(np.trapz(psurge_pod, x=psurge_far)) + label = f'psurge, AUC={AUC:.2f}' + plt.plot( + psurge_far, + psurge_pod, + 'ro', + label=label, + linestyle=':', + markersize=4, + ) + plt.legend(loc='lower right') + npos = int(hit_arr[threshold_count, 0, 0, 0, 0]) + nneg = int(false_alarm_arr[threshold_count, 0, 0, 0, 0]) + plt.xlabel(f'False Alarm Rate, N={nneg}') + plt.ylabel(f'Probability of Detection, N={npos}') + plt.xlim([-0.01, 1.01]) + plt.ylim([-0.01, 1.01]) + plt.grid(True) + + plt.title( + f'{storm}_{year}, {leadtime}-hr leadtime, {threshold} ft threshold: N={len(df_obs_storm)}' + ) + plt.savefig( + os.path.join( + output_directory, + f'ROC_{storm}_{year}_{leadtime}hr_leadtime_{threshold}_ft_{suffix}.png', + ) + ) + plt.close() + + # plot TS curves + threshold_count = -1 + for threshold in thresholds_ft: + threshold_count += 1 + fig = plt.figure() + ax = fig.add_subplot(111) + source_count = -1 + for source in sources: + source_count += 1 + # TS = hits / (hits + misses + false alarms) + TS = hit_arr[threshold_count, 0, 0, source_count, :] / ( + hit_arr[threshold_count, 0, 0, source_count, :] + + miss_arr[threshold_count, 0, 0, source_count, :] + + false_alarm_arr[threshold_count, 0, 0, source_count, :] + ) + + plt.plot( + probabilities, + TS, + colormarker_list[source_count][0], + label=f'{source}', + linestyle=linestyle_list[source_count], + ) + plt.legend(loc='upper right') + plt.ylabel(f'Threat Score') + plt.xlabel(f'probability of exceedance') + plt.xlim([-0.01, 1.01]) + plt.ylim([-0.01, 1.01]) + plt.grid(True) + + plt.title( + f'{storm}_{year}, {leadtime}-hr leadtime, {threshold} ft threshold: N={len(df_obs_storm)}' + ) + plt.savefig( + os.path.join( + output_directory, + f'TS_{storm}_{year}_{leadtime}hr_leadtime_{threshold}_ft_{suffix}.png', + ) + ) + plt.close() + + # plot reliability curves + threshold_count = -1 + for threshold in thresholds_ft: + threshold_count += 1 + fig = plt.figure() + ax = fig.add_subplot(111) + plt.axline( + (0.0, 0.0), (1.0, 1.0), linestyle='--', color='grey', label='perfect reliability' + ) + source_count = -1 + for source in sources: + source_count += 1 + true_prob, pred_prob = calibration_curve( + obs_true_arr[threshold_count, 0, 0, source_count, :], + pred_prob_arr[threshold_count, 0, 0, source_count, :], + n_bins=5, + strategy='uniform', + ) + plt.plot( + pred_prob, + true_prob, + colormarker_list[source_count], label=f'{source}', - marker=marker_list[source_count], linestyle=linestyle_list[source_count], markersize=5, ) - plt.legend() - plt.xlabel('False Alarm Rate') - plt.ylabel('Probability of Detection') - plt.title(f'{storm}_{year}, {leadtime}-hr leadtime, {threshold} ft threshold') + plt.xlim([-0.01, 1.01]) + plt.ylim([-0.01, 1.01]) + plt.legend(loc='lower right') + plt.xlabel(f'Predicted probability of exceedance') + plt.ylabel(f'Observed fraction of exceedances') + plt.grid(True) + + plt.title( + f'{storm}_{year}, {leadtime}-hr leadtime, {threshold} ft threshold: N={len(df_obs_storm)}' + ) plt.savefig( os.path.join( output_directory, - f'ROC_{storm}_{year}_{leadtime}hr_leadtime_{threshold}_ft.png', + f'REL_{storm}_{year}_{leadtime}hr_leadtime_{threshold}_ft_{suffix}.png', ) ) plt.close() @@ -288,6 +455,36 @@ def cli(): parser.add_argument('--obs_df_path', help='path to NHC obs data', type=str) parser.add_argument('--ensemble-dir', help='path to ensemble.dir', type=str) + # optional + parser.add_argument( + '--output-dir', + help='directory to save the outputs of this function', + default=None, + type=str, + ) + + parser.add_argument( + '--plot_prob_map', + help='plot the prediction probability at observations maps', + default=True, + action=argparse.BooleanOptionalAction, + ) + parser.add_argument( + '--psurge-far-path', + help='path to NHC PSurge FAR results', + type=Path + ) + parser.add_argument( + '--psurge-pod-path', + help='path to NHC PSurge POD results', + type=Path + ) + parser.add_argument( + '--output-suffix', + help='prefix to be used for path of output files', + type=str + ) + main(parser.parse_args()) diff --git a/stormworkflow/slurm/post.sbatch b/stormworkflow/slurm/post.sbatch index 33ef357..81d039d 100644 --- a/stormworkflow/slurm/post.sbatch +++ b/stormworkflow/slurm/post.sbatch @@ -14,7 +14,7 @@ analyze_ensemble \ --ensemble-dir $ENSEMBLE_DIR \ --tracks-dir $ENSEMBLE_DIR/track_files -storm_roc_curve \ +storm_roc_ts_rel_curves \ --storm ${storm} \ --year ${year} \ --leadtime ${hr_prelandfall} \