Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,4 @@ download_data = "stormworkflow.prep.download_data:cli"
setup_ensemble = "stormworkflow.prep.setup_ensemble:cli"
combine_ensemble = "stormworkflow.post.combine_ensemble:cli"
analyze_ensemble = "stormworkflow.post.analyze_ensemble:cli"
storm_roc_curve = "stormworkflow.post.storm_roc_curve:cli"
storm_roc_ts_rel_curves = "stormworkflow.post.storm_roc_ts_rel_curves:cli"
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import matplotlib.pyplot as plt
from pathlib import Path
from cartopy.feature import NaturalEarthFeature

import geodatasets
from geodatasets import get_path
from sklearn.calibration import calibration_curve

os.environ['USE_PYGEOS'] = '0'
import geopandas as gpd
Expand Down Expand Up @@ -124,18 +124,28 @@ def main(args):
leadtime = args.leadtime
obs_df_path = Path(args.obs_df_path)
ensemble_dir = Path(args.ensemble_dir)
output_directory = args.output_dir
suffix = args.output_suffix
plot_prob_map = args.plot_prob_map
psurge_far_path = args.psurge_far_path
psurge_pod_path = args.psurge_pod_path

output_directory = ensemble_dir / 'analyze/linear_k1_p1_n0.025'
prob_nc_path = output_directory / 'probabilities.nc'
input_directory = ensemble_dir / 'analyze/linear_k1_p1_n0.025'
prob_nc_path = input_directory / 'probabilities.nc'
if output_directory is None:
output_directory = input_directory

if leadtime == -1:
leadtime = 48

if suffix is None:
suffix = ''

# *.nc file coordinates
thresholds_ft = [3, 6, 9] # in ft
thresholds_ft = [3, 4, 5, 6, 9] # in ft
thresholds_m = [round(i * 0.3048, 4) for i in thresholds_ft] # convert to meter
sources = ['model', 'surrogate']
probabilities = [0.0, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
probabilities = list(np.linspace(0, 1, 101))

# attributes of input files
prediction_variable = 'probabilities'
Expand All @@ -145,9 +155,9 @@ def main(args):
max_distance = 1000 # [in meters] to set distance_upper_bound
max_neighbors = 10 # to set k

# creating arrays for ROC and TS plots
blank_arr = np.empty((len(thresholds_ft), 1, 1, len(sources), len(probabilities)))
blank_arr[:] = np.nan

hit_arr = blank_arr.copy()
miss_arr = blank_arr.copy()
false_alarm_arr = blank_arr.copy()
Expand All @@ -157,23 +167,33 @@ def main(args):

# Load obs file, extract storm obs points and coordinates
df_obs = pd.read_csv(obs_df_path)
Event_name = f'{storm}_{year}'
df_obs_storm = df_obs[df_obs.Event == Event_name]
# Take the first event whose name fully matches "<storm><optional word characters>_<year>" (first landfall)
CSV_Event_name = df_obs[df_obs.Event.str.fullmatch(f'{storm}\w*_{year}')].Event.iloc[0]
NC_Event_name = CSV_Event_name[:CSV_Event_name.find('_')]
df_obs_storm = df_obs[df_obs.Event == CSV_Event_name]
obs_coordinates = stack_station_coordinates(
df_obs_storm.Longitude.values, df_obs_storm.Latitude.values
)

# creating arrays for reliability diagram
blank_arr = np.empty((len(thresholds_ft), 1, 1, len(sources), len(df_obs_storm)))
blank_arr[:] = np.nan
obs_true_arr = blank_arr.copy()
pred_prob_arr = blank_arr.copy()

# Load Psurge results
ds_psurge = None
if not (psurge_far_path is None or psurge_pod_path is None):
ds_psurge = xr.open_mfdataset(
[psurge_far_path, psurge_pod_path],
combine='nested',
)

# Load probabilities.nc file
ds_prob = xr.open_dataset(prob_nc_path)

gdf_countries = gpd.read_file(geodatasets.get_path('naturalearth land'))

# gdf_countries = gpd.GeoSeries(
# NaturalEarthFeature(category='physical', scale='10m', name='land',).geometries(),
# crs=4326,
# )
gdf_countries = gpd.read_file(get_path('naturalearth land'))

# Loop through thresholds and sources and find corresponding values from probabilities.nc
threshold_count = -1
for threshold in thresholds_m:
threshold_count += 1
Expand All @@ -191,16 +211,23 @@ def main(args):
df_obs_storm[f'{source}_prob'] = prediction_prob

# Plot probabilities at obs. points
plot_probabilities(
df_obs_storm,
f'{source}_prob',
gdf_countries,
f'Probability of {source} exceeding {thresholds_ft[threshold_count]} ft \n {storm}, {year}, {leadtime}-hr leadtime',
os.path.join(
output_directory,
f'prob_{source}_above_{thresholds_ft[threshold_count]}ft_{storm}_{year}_{leadtime}-hr.png',
),
if plot_prob_map:
plot_probabilities(
df_obs_storm,
f'{source}_prob',
gdf_countries,
f'Probability of {source} exceeding {thresholds_ft[threshold_count]} ft \n {storm}, {year}, {leadtime}-hr leadtime',
os.path.join(
output_directory,
f'prob_{source}_above_{thresholds_ft[threshold_count]}ft_{storm}_{year}_{leadtime}-hr_{suffix}.png',
),
)

# Store the observed exceedance flags (observation > threshold) and the predicted probabilities in the reliability-diagram arrays
obs_true_arr[threshold_count, 0, 0, source_count, :] = (
df_obs_storm[obs_attribute] > threshold
)
pred_prob_arr[threshold_count, 0, 0, source_count, :] = prediction_prob

# Loop through probabilities: calculate hit/miss/... & POD/FAR
prob_count = -1
Expand All @@ -225,6 +252,7 @@ def main(args):
leadtime=[leadtime],
source=sources,
prob=probabilities,
points=range(len(df_obs_storm)),
),
data_vars=dict(
hit=(['threshold', 'storm', 'leadtime', 'source', 'prob'], hit_arr),
Expand All @@ -239,13 +267,19 @@ def main(args):
),
POD=(['threshold', 'storm', 'leadtime', 'source', 'prob'], POD_arr),
FAR=(['threshold', 'storm', 'leadtime', 'source', 'prob'], FAR_arr),
obs_true=(['threshold', 'storm', 'leadtime', 'source', 'points'], obs_true_arr),
pred_prob=(['threshold', 'storm', 'leadtime', 'source', 'points'], pred_prob_arr),
),
)
ds_ROC.to_netcdf(os.path.join(output_directory, f'{storm}_{year}_{leadtime}hr_POD_FAR.nc'))
ds_ROC.to_netcdf(
os.path.join(
output_directory, f'{storm}_{year}_{leadtime}hr_probabilistic_evaluation_stats_{suffix}.nc'
)
)

# plot ROC curves
marker_list = ['s', 'x']
linestyle_list = ['dashed', 'dotted']
colormarker_list = ['bd', 'kx']
linestyle_list = ['--', '-']
threshold_count = -1
for threshold in thresholds_ft:
threshold_count += 1
Expand All @@ -257,23 +291,156 @@ def main(args):
source_count = -1
for source in sources:
source_count += 1
AUC = abs(
np.trapz(
POD_arr[threshold_count, 0, 0, source_count, :],
x=FAR_arr[threshold_count, 0, 0, source_count, :],
)
)
label = f'{source}, AUC={AUC:.2f}'
colormarker = colormarker_list[source_count]
if source == 'surrogate':
best_res = (
POD_arr[threshold_count, 0, 0, source_count, :]
- FAR_arr[threshold_count, 0, 0, source_count, :]
).argmax()
label = f'{label}, x @ p({probabilities[best_res]:.2f})'
plt.plot(
FAR_arr[threshold_count, 0, 0, source_count, best_res],
POD_arr[threshold_count, 0, 0, source_count, best_res],
'kX',
)
plt.plot(
FAR_arr[threshold_count, 0, 0, source_count, :],
POD_arr[threshold_count, 0, 0, source_count, :],
colormarker,
label=label,
linestyle=linestyle_list[source_count],
markersize=5,
)
if ds_psurge is not None and threshold in ds_psurge.threshold:
psurge_far = ds_psurge.sel(
version='v3pt0Apr062023_kdtree',
threshold=threshold,
storm=NC_Event_name,
leadtime=leadtime).FARate.values
psurge_pod = ds_psurge.sel(
version='v3pt0Apr062023_kdtree',
threshold=threshold,
storm=NC_Event_name,
leadtime=leadtime).POD.values
AUC = abs(np.trapz(psurge_pod, x=psurge_far))
label = f'psurge, AUC={AUC:.2f}'
plt.plot(
psurge_far,
psurge_pod,
'ro',
label=label,
linestyle=':',
markersize=4,
)
plt.legend(loc='lower right')
npos = int(hit_arr[threshold_count, 0, 0, 0, 0])
nneg = int(false_alarm_arr[threshold_count, 0, 0, 0, 0])
plt.xlabel(f'False Alarm Rate, N={nneg}')
plt.ylabel(f'Probability of Detection, N={npos}')
plt.xlim([-0.01, 1.01])
plt.ylim([-0.01, 1.01])
plt.grid(True)

plt.title(
f'{storm}_{year}, {leadtime}-hr leadtime, {threshold} ft threshold: N={len(df_obs_storm)}'
)
plt.savefig(
os.path.join(
output_directory,
f'ROC_{storm}_{year}_{leadtime}hr_leadtime_{threshold}_ft_{suffix}.png',
)
)
plt.close()

# plot TS curves
threshold_count = -1
for threshold in thresholds_ft:
threshold_count += 1
fig = plt.figure()
ax = fig.add_subplot(111)
source_count = -1
for source in sources:
source_count += 1
# TS = hits / (hits + misses + false alarms)
TS = hit_arr[threshold_count, 0, 0, source_count, :] / (
hit_arr[threshold_count, 0, 0, source_count, :]
+ miss_arr[threshold_count, 0, 0, source_count, :]
+ false_alarm_arr[threshold_count, 0, 0, source_count, :]
)

plt.plot(
probabilities,
TS,
colormarker_list[source_count][0],
label=f'{source}',
linestyle=linestyle_list[source_count],
)
plt.legend(loc='upper right')
plt.ylabel(f'Threat Score')
plt.xlabel(f'probability of exceedance')
plt.xlim([-0.01, 1.01])
plt.ylim([-0.01, 1.01])
plt.grid(True)

plt.title(
f'{storm}_{year}, {leadtime}-hr leadtime, {threshold} ft threshold: N={len(df_obs_storm)}'
)
plt.savefig(
os.path.join(
output_directory,
f'TS_{storm}_{year}_{leadtime}hr_leadtime_{threshold}_ft_{suffix}.png',
)
)
plt.close()

# plot reliability curves
threshold_count = -1
for threshold in thresholds_ft:
threshold_count += 1
fig = plt.figure()
ax = fig.add_subplot(111)
plt.axline(
(0.0, 0.0), (1.0, 1.0), linestyle='--', color='grey', label='perfect reliability'
)
source_count = -1
for source in sources:
source_count += 1
true_prob, pred_prob = calibration_curve(
obs_true_arr[threshold_count, 0, 0, source_count, :],
pred_prob_arr[threshold_count, 0, 0, source_count, :],
n_bins=5,
strategy='uniform',
)
plt.plot(
pred_prob,
true_prob,
colormarker_list[source_count],
label=f'{source}',
marker=marker_list[source_count],
linestyle=linestyle_list[source_count],
markersize=5,
)
plt.legend()
plt.xlabel('False Alarm Rate')
plt.ylabel('Probability of Detection')

plt.title(f'{storm}_{year}, {leadtime}-hr leadtime, {threshold} ft threshold')
plt.xlim([-0.01, 1.01])
plt.ylim([-0.01, 1.01])
plt.legend(loc='lower right')
plt.xlabel(f'Predicted probability of exceedance')
plt.ylabel(f'Observed fraction of exceedances')
plt.grid(True)

plt.title(
f'{storm}_{year}, {leadtime}-hr leadtime, {threshold} ft threshold: N={len(df_obs_storm)}'
)
plt.savefig(
os.path.join(
output_directory,
f'ROC_{storm}_{year}_{leadtime}hr_leadtime_{threshold}_ft.png',
f'REL_{storm}_{year}_{leadtime}hr_leadtime_{threshold}_ft_{suffix}.png',
)
)
plt.close()
Expand All @@ -288,6 +455,36 @@ def cli():
parser.add_argument('--obs_df_path', help='path to NHC obs data', type=str)
parser.add_argument('--ensemble-dir', help='path to ensemble.dir', type=str)

# optional
parser.add_argument(
'--output-dir',
help='directory to save the outputs of this function',
default=None,
type=str,
)

parser.add_argument(
'--plot_prob_map',
help='plot the prediction probability at observations maps',
default=True,
action=argparse.BooleanOptionalAction,
)
parser.add_argument(
'--psurge-far-path',
help='path to NHC PSurge FAR results',
type=Path
)
parser.add_argument(
'--psurge-pod-path',
help='path to NHC PSurge POD results',
type=Path
)
parser.add_argument(
'--output-suffix',
help='prefix to be used for path of output files',
type=str
)

main(parser.parse_args())


Expand Down
2 changes: 1 addition & 1 deletion stormworkflow/slurm/post.sbatch
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ analyze_ensemble \
--ensemble-dir $ENSEMBLE_DIR \
--tracks-dir $ENSEMBLE_DIR/track_files

storm_roc_curve \
storm_roc_ts_rel_curves \
--storm ${storm} \
--year ${year} \
--leadtime ${hr_prelandfall} \
Expand Down
Loading