Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 71 additions & 22 deletions src/pydartdiags/matplots/matplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,33 +227,26 @@ def plot_profile(
return fig


def plot_rank_histogram(obs_seq, levels, type, ens_size):

def plot_rank_histogram(obs_seq, type, ens_size, levels=None):
qc0 = stats.select_used_qcs(obs_seq.df) # filter only qc=0, qc=2
qc0 = qc0[qc0["type"] == type] # filter by type
if qc0.empty:
print(f"No rows found for type: {type}")
return None

if qc0["vert_unit"].nunique() > 1:
print(
f"Multiple vertical units found in the data: {qc0['vert_unit'].unique()} for type: {type}"
)
return None

vert_unit = qc0.iloc[0]["vert_unit"]
conversion, unit = _get_plot_unit(vert_unit) # multiplier and unit for y-axis
print("type: ", type)

stats.bin_by_layer(qc0, levels, verticalUnit=vert_unit) # bin by level
if (isinstance(type, int) and type < 0) or (type == "IDENTITY_OBS"):
print("Observation type is for identity observations.") # No filtering by type for identity obs
if qc0.empty:
print(f"No rows found for type: {type}")
return None
type = "IDENTITY_OBS"

midpoints = qc0["midpoint"].unique()
else:
qc0 = qc0[qc0["type"] == type] # filter by type

for level in sorted(midpoints):
if qc0.empty:
print(f"No rows found for type: {type}")
return None

df = qc0[qc0["midpoint"] == level]
# convert to hPa only for Pressure (Pa)
df["midpoint"] = df["midpoint"].astype(float)
df["midpoint"] = df["midpoint"] * conversion
if levels is None:
print(f"Observation sequence does not use vertical coordinate. Proceeding without level binning.")

df = stats.calculate_rank(qc0)

Expand Down Expand Up @@ -284,6 +277,62 @@ def plot_rank_histogram(obs_seq, levels, type, ens_size):
ax2.set_xlabel("Observation Rank (among ensemble members)")
ax2.set_ylabel("Count")

fig.suptitle(f"{type}", fontsize=14)

plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()
return None

elif qc0["vert_unit"].nunique() > 1:
print(
f"Multiple vertical units found in the data: {qc0['vert_unit'].unique()} for type: {type}"
)
return None

else:
vert_unit = qc0.iloc[0]["vert_unit"]
conversion, unit = _get_plot_unit(vert_unit) # multiplier and unit for y-axis

stats.bin_by_layer(qc0, levels, verticalUnit=vert_unit) # bin by level

midpoints = qc0["midpoint"].unique()

for level in sorted(midpoints):

df = qc0[qc0["midpoint"] == level]
# convert to hPa only for Pressure (Pa)
df["midpoint"] = df["midpoint"].astype(float)
df["midpoint"] = df["midpoint"] * conversion

df = stats.calculate_rank(qc0)

if "posterior_rank" in df.columns:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
else:
fig, ax1 = plt.subplots()

# Plot the prior rank histogram
bins = list(range(1, ens_size + 2))
ax1.hist(
df["prior_rank"], bins=bins, color="blue", alpha=0.5, label="prior rank"
)
ax1.set_title("Prior Rank Histogram")
ax1.set_xlabel("Observation Rank (among ensemble members)")
ax1.set_ylabel("Count")

# Plot the posterior rank histogram if it exists
if "posterior_rank" in df.columns:
ax2.hist(
df["posterior_rank"],
bins=bins,
color="green",
alpha=0.5,
label="posterior rank",
)
ax2.set_title("Posterior Rank Histogram")
ax2.set_xlabel("Observation Rank (among ensemble members)")
ax2.set_ylabel("Count")

fig.suptitle(f"{type} at Level {round(level, 1)} {unit}", fontsize=14)

plt.tight_layout(rect=[0, 0.03, 1, 0.95])
Expand Down
Loading