Skip to content

Commit d87694d

Browse files
authored
fix: correct sorting (#551)
Closes #550
1 parent 710f1e9 commit d87694d

9 files changed

+46
-22
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ test = [
4646
"hist",
4747
"pytest-mock",
4848
"pytest-mpl",
49+
"pytest-xdist",
4950
"pytest>=6.0",
5051
"scikit-hep-testdata",
5152
"scipy>=1.1.0",

src/mplhep/plot.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -198,34 +198,21 @@ def histplot(
198198
else get_histogram_axes_title(hists[0].axes[0])
199199
)
200200

201-
plottables, flow_info = get_plottables(
202-
hists,
203-
bins=final_bins,
204-
w2=w2,
205-
w2method=w2method,
206-
yerr=yerr,
207-
stack=stack,
208-
density=density,
209-
binwnorm=binwnorm,
210-
flow=flow,
211-
)
212-
flow_bins, underflow, overflow = flow_info
213-
214201
_labels: list[str | None]
215202
if label is None:
216-
_labels = [None] * len(plottables)
203+
_labels = [None] * len(hists)
217204
elif isinstance(label, str):
218-
_labels = [label] * len(plottables)
205+
_labels = [label] * len(hists)
219206
elif not np.iterable(label):
220-
_labels = [str(label)] * len(plottables)
207+
_labels = [str(label)] * len(hists)
221208
else:
222209
_labels = [str(lab) for lab in label]
223210

224211
def iterable_not_string(arg):
225212
return isinstance(arg, collections.abc.Iterable) and not isinstance(arg, str)
226213

227214
_chunked_kwargs: list[dict[str, Any]] = []
228-
for _ in range(len(plottables)):
215+
for _ in range(len(hists)):
229216
_chunked_kwargs.append({})
230217
for kwarg, kwarg_content in kwargs.items():
231218
# Check if iterable
@@ -249,22 +236,35 @@ def iterable_not_string(arg):
249236
if sort.split("_")[0] in ["l", "label"] and isinstance(_labels, list):
250237
order = np.argsort(label) # [::-1]
251238
elif sort.split("_")[0] in ["y", "yield"]:
252-
_yields = [np.sum(_h.values) for _h in plottables] # type: ignore[var-annotated]
239+
_yields = [np.sum(_h.values()) for _h in hists] # type: ignore[var-annotated]
253240
order = np.argsort(_yields)
254241
if len(sort.split("_")) == 2 and sort.split("_")[1] == "r":
255242
order = order[::-1]
256243
elif isinstance(sort, (list, np.ndarray)):
257-
if len(sort) != len(plottables):
258-
msg = f"Sort indexing array is of the wrong size - {len(sort)}, {len(plottables)} expected."
244+
if len(sort) != len(hists):
245+
msg = f"Sort indexing array is of the wrong size - {len(sort)}, {len(hists)} expected."
259246
raise ValueError(msg)
260247
order = np.asarray(sort)
261248
else:
262249
msg = f"Sort type: {sort} not understood."
263250
raise ValueError(msg)
264-
plottables = [plottables[ix] for ix in order]
251+
hists = [hists[ix] for ix in order]
265252
_chunked_kwargs = [_chunked_kwargs[ix] for ix in order]
266253
_labels = [_labels[ix] for ix in order]
267254

255+
plottables, flow_info = get_plottables(
256+
hists,
257+
bins=final_bins,
258+
w2=w2,
259+
w2method=w2method,
260+
yerr=yerr,
261+
stack=stack,
262+
density=density,
263+
binwnorm=binwnorm,
264+
flow=flow,
265+
)
266+
flow_bins, underflow, overflow = flow_info
267+
268268
##########
269269
# Plotting
270270
return_artists: list[StairsArtists | ErrorBarArtists] = []
@@ -274,7 +274,7 @@ def iterable_not_string(arg):
274274
elif histtype == "barstep" and len(plottables) == 1:
275275
histtype = "step"
276276

277-
# customize color cycle assignment when stacking to match legend
277+
# # customize color cycle assignment when stacking to match legend
278278
if stack:
279279
plottables = plottables[::-1]
280280
_chunked_kwargs = _chunked_kwargs[::-1]
7.57 KB
Loading
7.57 KB
Loading
7.54 KB
Loading
7.51 KB
Loading
7.57 KB
Loading
7.34 KB
Loading

tests/test_basic.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -707,3 +707,26 @@ def test_histplot_inputs_pass(h, yerr, htype):
707707
fig, ax = plt.subplots()
708708
hep.histplot(h, bins, yerr=yerr, histtype=htype)
709709
plt.close(fig)
710+
711+
712+
@pytest.mark.parametrize(
713+
"sort", [None, "label", "label_r", "yield", "yield_r", [0, 2, 1]]
714+
)
715+
@pytest.mark.mpl_image_compare(style="default", remove_text=True)
716+
def test_histplot_sort(sort):
717+
np.random.seed(0)
718+
h = hist.new.Reg(10, 0, 10).StrCat([], growth=True).Weight()
719+
ixs = ["FOO", "BAR", "ZOO"]
720+
for i, ix in enumerate(ixs):
721+
h.fill(np.random.normal(2 + i * 1.5, 3, int(100 + 200 * i)), ix)
722+
723+
fig, ax = plt.subplots()
724+
hep.histplot(
725+
[h[:, ix] for ix in h.axes[1]],
726+
label=h.axes[1],
727+
stack=True,
728+
histtype="fill",
729+
sort=sort,
730+
)
731+
ax.legend()
732+
return fig

0 commit comments

Comments
 (0)