@@ -9,8 +9,8 @@
import matplotlib.pyplot as plt
from pathlib import Path
from cartopy.feature import NaturalEarthFeature

import geodatasets
from geodatasets import get_path
from sklearn.calibration import calibration_curve

os.environ['USE_PYGEOS'] = '0'
import geopandas as gpd
@@ -124,18 +124,22 @@ def main(args):
leadtime = args.leadtime
obs_df_path = Path(args.obs_df_path)
ensemble_dir = Path(args.ensemble_dir)
output_directory = args.output_dir
plot_prob_map = args.plot_prob_map

output_directory = ensemble_dir / 'analyze/linear_k1_p1_n0.025'
prob_nc_path = output_directory / 'probabilities.nc'
input_directory = ensemble_dir / 'analyze/linear_k1_p1_n0.025'
prob_nc_path = input_directory / 'probabilities.nc'
if output_directory is None:
output_directory = input_directory
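# NOTE: probabilities.nc is read from the fixed analyze subdirectory of the
# ensemble run; plots and stats default to that same directory unless an
# explicit --output-dir is provided.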

if leadtime == -1:
leadtime = 48

# *.nc file coordinates
thresholds_ft = [3, 6, 9] # in ft
thresholds_ft = [3, 4, 5, 6, 9] # in ft
thresholds_m = [round(i * 0.3048, 4) for i in thresholds_ft] # convert to meter
sources = ['model', 'surrogate']
probabilities = [0.0, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
probabilities = list(np.linspace(0, 1, 101))
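# Probability levels are the decision thresholds swept to build the ROC and
# threat-score curves; np.linspace(0, 1, 101) gives 0.01 resolution.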

# attributes of input files
prediction_variable = 'probabilities'
@@ -145,9 +149,9 @@ def main(args):
max_distance = 1000 # [in meters] to set distance_upper_bound
max_neighbors = 10 # to set k
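# These presumably parameterize a nearest-neighbor lookup (e.g. a KD-tree
# query with k=max_neighbors and distance_upper_bound=max_distance) that maps
# observation locations onto the probability grid; that code is not shown in
# this hunk.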

# creating arrays for ROC and TS plots
blank_arr = np.empty((len(thresholds_ft), 1, 1, len(sources), len(probabilities)))
blank_arr[:] = np.nan

hit_arr = blank_arr.copy()
miss_arr = blank_arr.copy()
false_alarm_arr = blank_arr.copy()
@@ -163,10 +167,16 @@ def main(args):
df_obs_storm.Longitude.values, df_obs_storm.Latitude.values
)

# creating arrays for reliability diagram
blank_arr = np.empty((len(thresholds_ft), 1, 1, len(sources), len(df_obs_storm)))
blank_arr[:] = np.nan
obs_true_arr = blank_arr.copy()
pred_prob_arr = blank_arr.copy()
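# obs_true_arr holds the binary exceedance outcome at each observation point
# and pred_prob_arr the corresponding predicted exceedance probability; the
# two are paired later by sklearn's calibration_curve for the reliability
# diagrams.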

# Load probabilities.nc file
ds_prob = xr.open_dataset(prob_nc_path)

gdf_countries = gpd.read_file(geodatasets.get_path('naturalearth land'))
gdf_countries = gpd.read_file(get_path('naturalearth land'))
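# geodatasets fetches (and caches locally) the Natural Earth land polygons
# used as the basemap for the probability maps.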

# gdf_countries = gpd.GeoSeries(
# NaturalEarthFeature(category='physical', scale='10m', name='land',).geometries(),
@@ -191,16 +201,23 @@ def main(args):
df_obs_storm[f'{source}_prob'] = prediction_prob

# Plot probabilities at obs. points
plot_probabilities(
df_obs_storm,
f'{source}_prob',
gdf_countries,
f'Probability of {source} exceeding {thresholds_ft[threshold_count]} ft \n {storm}, {year}, {leadtime}-hr leadtime',
os.path.join(
output_directory,
f'prob_{source}_above_{thresholds_ft[threshold_count]}ft_{storm}_{year}_{leadtime}-hr.png',
),
if plot_prob_map:
plot_probabilities(
df_obs_storm,
f'{source}_prob',
gdf_countries,
f'Probability of {source} exceeding {thresholds_ft[threshold_count]} ft \n {storm}, {year}, {leadtime}-hr leadtime',
os.path.join(
output_directory,
f'prob_{source}_above_{thresholds_ft[threshold_count]}ft_{storm}_{year}_{leadtime}-hr.png',
),
)

# Record whether each observation exceeds the threshold, and the predicted exceedance probability
obs_true_arr[threshold_count, 0, 0, source_count, :] = (
df_obs_storm[obs_attribute] > threshold
)
pred_prob_arr[threshold_count, 0, 0, source_count, :] = prediction_prob

# Loop through probabilities: calculate hit/miss/... & POD/FAR
prob_count = -1
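# In the collapsed block below, each probability level p is presumably treated
# as a yes/no forecast (exceedance predicted where prediction_prob >= p) and
# compared against the observed exceedances to count hits, misses, false
# alarms and correct negatives, from which POD and FAR are computed.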
@@ -225,6 +242,7 @@ def main(args):
leadtime=[leadtime],
source=sources,
prob=probabilities,
points=range(len(df_obs_storm)),
),
data_vars=dict(
hit=(['threshold', 'storm', 'leadtime', 'source', 'prob'], hit_arr),
@@ -239,13 +257,19 @@ def main(args):
),
POD=(['threshold', 'storm', 'leadtime', 'source', 'prob'], POD_arr),
FAR=(['threshold', 'storm', 'leadtime', 'source', 'prob'], FAR_arr),
obs_true=(['threshold', 'storm', 'leadtime', 'source', 'points'], obs_true_arr),
pred_prob=(['threshold', 'storm', 'leadtime', 'source', 'points'], pred_prob_arr),
),
)
ds_ROC.to_netcdf(os.path.join(output_directory, f'{storm}_{year}_{leadtime}hr_POD_FAR.nc'))
ds_ROC.to_netcdf(
os.path.join(
output_directory, f'{storm}_{year}_{leadtime}hr_probabilistic_evaluation_stats.nc'
)
)
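# The output netCDF stores the contingency counts (hit/miss/false alarm/
# correct negative) and the derived POD/FAR per threshold, source and
# probability level, plus the per-point obs_true/pred_prob pairs used for the
# reliability diagrams.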

# plot ROC curves
marker_list = ['s', 'x']
linestyle_list = ['dashed', 'dotted']
colormarker_list = ['bs', 'kx']
linestyle_list = ['--', '-']
threshold_count = -1
for threshold in thresholds_ft:
threshold_count += 1
@@ -257,23 +281,136 @@ def main(args):
source_count = -1
for source in sources:
source_count += 1
AUC = abs(
np.trapz(
POD_arr[threshold_count, 0, 0, source_count, :],
x=FAR_arr[threshold_count, 0, 0, source_count, :],
)
)
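# AUC via trapezoidal integration of POD against FAR; abs() is needed because
# FAR decreases as the probability level increases, so the raw integral comes
# out negative.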
label = f'{source}, AUC={AUC:.2f}'
colormarker = colormarker_list[source_count]
if source == 'surrogate':
best_res = (
POD_arr[threshold_count, 0, 0, source_count, :]
- FAR_arr[threshold_count, 0, 0, source_count, :]
).argmax()
label = f'{label}, x @ p({probabilities[best_res]:.2f})'
plt.plot(
FAR_arr[threshold_count, 0, 0, source_count, best_res],
POD_arr[threshold_count, 0, 0, source_count, best_res],
colormarker_list[source_count],
)
colormarker = colormarker_list[source_count][0]
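# For the surrogate, the probability level that maximizes POD - FAR (Youden's
# J statistic) is marked with a standalone marker; the curve itself is then
# drawn without markers (fmt string reduced to its color character).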
plt.plot(
FAR_arr[threshold_count, 0, 0, source_count, :],
POD_arr[threshold_count, 0, 0, source_count, :],
colormarker,
label=label,
linestyle=linestyle_list[source_count],
markersize=5,
)
plt.legend(loc='lower right')
npos = int(hit_arr[threshold_count, 0, 0, 0, 0])
nneg = int(false_alarm_arr[threshold_count, 0, 0, 0, 0])
plt.xlabel(f'False Alarm Rate, N={nneg}')
plt.ylabel(f'Probability of Detection, N={npos}')
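# npos/nneg are taken from the counts at the lowest probability level, where
# essentially every point is classified as an exceedance, so hits ~ number of
# observed exceedances and false alarms ~ number of observed non-exceedances.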
plt.xlim([-0.01, 1.01])
plt.ylim([-0.01, 1.01])
plt.grid(True)

plt.title(
f'{storm}_{year}, {leadtime}-hr leadtime, {threshold} ft threshold: N={len(df_obs_storm)}'
)
plt.savefig(
os.path.join(
output_directory,
f'ROC_{storm}_{year}_{leadtime}hr_leadtime_{threshold}_ft.png',
)
)
plt.close()

# plot TS curves
threshold_count = -1
for threshold in thresholds_ft:
threshold_count += 1
fig = plt.figure()
ax = fig.add_subplot(111)
source_count = -1
for source in sources:
source_count += 1
# TS = hits / (hits + misses + false alarms)
TS = hit_arr[threshold_count, 0, 0, source_count, :] / (
hit_arr[threshold_count, 0, 0, source_count, :]
+ miss_arr[threshold_count, 0, 0, source_count, :]
+ false_alarm_arr[threshold_count, 0, 0, source_count, :]
)
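# The threat score (critical success index) ignores correct negatives, so it
# is not inflated by the typically large number of points that stay below the
# threshold; it is plotted against the probability level used as the decision
# threshold.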

plt.plot(
probabilities,
TS,
colormarker_list[source_count][0],
label=f'{source}',
linestyle=linestyle_list[source_count],
)
plt.legend(loc='upper right')
plt.ylabel(f'Threat Score')
plt.xlabel(f'probability of exceedance')
plt.xlim([-0.01, 1.01])
plt.ylim([-0.01, 1.01])
plt.grid(True)

plt.title(
f'{storm}_{year}, {leadtime}-hr leadtime, {threshold} ft threshold: N={len(df_obs_storm)}'
)
plt.savefig(
os.path.join(
output_directory,
f'TS_{storm}_{year}_{leadtime}hr_leadtime_{threshold}_ft.png',
)
)
plt.close()

# plot reliability curves
threshold_count = -1
for threshold in thresholds_ft:
threshold_count += 1
fig = plt.figure()
ax = fig.add_subplot(111)
plt.axline(
(0.0, 0.0), (1.0, 1.0), linestyle='--', color='grey', label='perfect reliability'
)
source_count = -1
for source in sources:
source_count += 1
true_prob, pred_prob = calibration_curve(
obs_true_arr[threshold_count, 0, 0, source_count, :],
pred_prob_arr[threshold_count, 0, 0, source_count, :],
n_bins=5,
strategy='uniform',
)
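# calibration_curve bins the predicted probabilities into n_bins equal-width
# bins and returns, for each non-empty bin, the observed exceedance fraction
# (first value) and the mean predicted probability (second value); a reliable
# forecast falls on the 1:1 reference line drawn above.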
plt.plot(
pred_prob,
true_prob,
colormarker_list[source_count],
label=f'{source}',
marker=marker_list[source_count],
linestyle=linestyle_list[source_count],
markersize=5,
)
plt.legend()
plt.xlabel('False Alarm Rate')
plt.ylabel('Probability of Detection')

plt.title(f'{storm}_{year}, {leadtime}-hr leadtime, {threshold} ft threshold')
plt.xlim([-0.01, 1.01])
plt.ylim([-0.01, 1.01])
plt.legend(loc='lower right')
plt.xlabel(f'Predicted probability of exceedance')
plt.ylabel(f'Observed fraction of exceedances')
plt.grid(True)

plt.title(
f'{storm}_{year}, {leadtime}-hr leadtime, {threshold} ft threshold: N={len(df_obs_storm)}'
)
plt.savefig(
os.path.join(
output_directory,
f'ROC_{storm}_{year}_{leadtime}hr_leadtime_{threshold}_ft.png',
f'REL_{storm}_{year}_{leadtime}hr_leadtime_{threshold}_ft.png',
)
)
plt.close()
@@ -288,6 +425,21 @@ def cli():
parser.add_argument('--obs_df_path', help='path to NHC obs data', type=str)
parser.add_argument('--ensemble-dir', help='path to ensemble.dir', type=str)

# optional
parser.add_argument(
'--output-dir',
help='directory to save the outputs of this function',
default=None,
type=str,
)

parser.add_argument(
'--plot_prob_map',
help='plot maps of the predicted probability at observation points',
default=True,
action=argparse.BooleanOptionalAction,
)
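# argparse.BooleanOptionalAction (Python >= 3.9) generates both
# --plot_prob_map and --no-plot_prob_map flags, with the default set to True.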

main(parser.parse_args())


stormworkflow/slurm/post.sbatch (1 addition, 1 deletion)
@@ -14,7 +14,7 @@ analyze_ensemble \
--ensemble-dir $ENSEMBLE_DIR \
--tracks-dir $ENSEMBLE_DIR/track_files

storm_roc_curve \
storm_roc_ts_rel_curves \
--storm ${storm} \
--year ${year} \
--leadtime ${hr_prelandfall} \