diff --git a/src/pydartdiags/matplots/matplots.py b/src/pydartdiags/matplots/matplots.py index 8b08675..f2bccf0 100644 --- a/src/pydartdiags/matplots/matplots.py +++ b/src/pydartdiags/matplots/matplots.py @@ -227,33 +227,26 @@ def plot_profile( return fig -def plot_rank_histogram(obs_seq, levels, type, ens_size): - +def plot_rank_histogram(obs_seq, type, ens_size, levels=None): qc0 = stats.select_used_qcs(obs_seq.df) # filter only qc=0, qc=2 - qc0 = qc0[qc0["type"] == type] # filter by type - if qc0.empty: - print(f"No rows found for type: {type}") - return None - - if qc0["vert_unit"].nunique() > 1: - print( - f"Multiple vertical units found in the data: {qc0['vert_unit'].unique()} for type: {type}" - ) - return None - - vert_unit = qc0.iloc[0]["vert_unit"] - conversion, unit = _get_plot_unit(vert_unit) # multiplier and unit for y-axis + print("type: ", type) - stats.bin_by_layer(qc0, levels, verticalUnit=vert_unit) # bin by level + if (isinstance(type, int) and type < 0) or (type == "IDENTITY_OBS"): + print("Observation type is for identity observations.") # No filtering by type for identity obs + if qc0.empty: + print(f"No rows found for type: {type}") + return None + type = "IDENTITY_OBS" - midpoints = qc0["midpoint"].unique() + else: + qc0 = qc0[qc0["type"] == type] # filter by type - for level in sorted(midpoints): + if qc0.empty: + print(f"No rows found for type: {type}") + return None - df = qc0[qc0["midpoint"] == level] - # convert to hPa only for Pressure (Pa) - df["midpoint"] = df["midpoint"].astype(float) - df["midpoint"] = df["midpoint"] * conversion + if levels is None: + print(f"Observation sequence does not use vertical coordinate. Proceeding without level binning.") df = stats.calculate_rank(qc0) @@ -284,6 +277,62 @@ def plot_rank_histogram(obs_seq, levels, type, ens_size): ax2.set_xlabel("Observation Rank (among ensemble members)") ax2.set_ylabel("Count") + fig.suptitle(f"{type}", fontsize=14) + + plt.tight_layout(rect=[0, 0.03, 1, 0.95]) + plt.show() + return None + + elif qc0["vert_unit"].nunique() > 1: + print( + f"Multiple vertical units found in the data: {qc0['vert_unit'].unique()} for type: {type}" + ) + return None + + else: + vert_unit = qc0.iloc[0]["vert_unit"] + conversion, unit = _get_plot_unit(vert_unit) # multiplier and unit for y-axis + + stats.bin_by_layer(qc0, levels, verticalUnit=vert_unit) # bin by level + + midpoints = qc0["midpoint"].unique() + + for level in sorted(midpoints): + + df = qc0[qc0["midpoint"] == level] + # convert to hPa only for Pressure (Pa) + df["midpoint"] = df["midpoint"].astype(float) + df["midpoint"] = df["midpoint"] * conversion + + df = stats.calculate_rank(qc0) + + if "posterior_rank" in df.columns: + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6)) + else: + fig, ax1 = plt.subplots() + + # Plot the prior rank histogram + bins = list(range(1, ens_size + 2)) + ax1.hist( + df["prior_rank"], bins=bins, color="blue", alpha=0.5, label="prior rank" + ) + ax1.set_title("Prior Rank Histogram") + ax1.set_xlabel("Observation Rank (among ensemble members)") + ax1.set_ylabel("Count") + + # Plot the posterior rank histogram if it exists + if "posterior_rank" in df.columns: + ax2.hist( + df["posterior_rank"], + bins=bins, + color="green", + alpha=0.5, + label="posterior rank", + ) + ax2.set_title("Posterior Rank Histogram") + ax2.set_xlabel("Observation Rank (among ensemble members)") + ax2.set_ylabel("Count") + fig.suptitle(f"{type} at Level {round(level, 1)} {unit}", fontsize=14) plt.tight_layout(rect=[0, 0.03, 1, 0.95])