Skip to content

Commit

Permalink
bump ruff to 0.9 and auto-fix
Browse files Browse the repository at this point in the history
ruff ignore (noqa): pymatviz/(typing|io).py:1:1: A005 Module (typing|io) shadows a Python standard-library module
  • Loading branch information
janosh committed Jan 9, 2025
1 parent 4dc5f1e commit 6989ba5
Show file tree
Hide file tree
Showing 18 changed files with 133 additions and 131 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.6
rev: v0.9.0
hooks:
- id: ruff
args: [--fix]
Expand Down
2 changes: 1 addition & 1 deletion assets/scripts/ptable_plotly/ptable_scatter_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,4 @@
title = f"<b>Periodic Table Scatter Plots</b><br>{mode=}"
fig.layout.title.update(text=title, x=0.4, y=0.85, font_size=20)
fig.show()
pmv.io.save_and_compress_svg(fig, f"ptable-scatter-plotly-{mode.replace('+','-')}")
pmv.io.save_and_compress_svg(fig, f"ptable-scatter-plotly-{mode.replace('+', '-')}")
2 changes: 1 addition & 1 deletion pymatviz/classify/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def confusion_matrix(
spaces = [i for i, c in enumerate(label) if c == " "]
if spaces:
split_point = min(spaces, key=lambda x: abs(x - mid))
label = f"{label[:split_point]}<br>{label[split_point + 1:]}" # noqa: PLW2901
label = f"{label[:split_point]}<br>{label[split_point + 1 :]}" # noqa: PLW2901
formatted_labels[key] += [label]

fmt_tile_vals = np.array(
Expand Down
2 changes: 1 addition & 1 deletion pymatviz/typing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Typing related: TypeAlias, generic types and so on."""
"""Typing related: TypeAlias, generic types and so on.""" # noqa: A005

from __future__ import annotations

Expand Down
6 changes: 3 additions & 3 deletions tests/phonons/test_phonon_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ def test_phonon_bands(
assert fig.layout.font.size == 16

actual_x_labels = fig.layout.xaxis.ticktext
assert (
actual_x_labels == expected_x_labels
), f"{actual_x_labels=}, {expected_x_labels=}"
assert actual_x_labels == expected_x_labels, (
f"{actual_x_labels=}, {expected_x_labels=}"
)
assert fig.layout.xaxis.range is None
assert fig.layout.yaxis.range == pytest.approx((0, 5.36385427095))

Expand Down
6 changes: 3 additions & 3 deletions tests/powerups/test_both_powerups.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,9 @@ def test_annotate_metrics(
y_pred, y_true, metrics=metrics, fmt=fmt, prefix=prefix, suffix=suffix, fig=fig
)
anno_text_with_fixes = _extract_anno_from_fig(out_fig)
assert (
anno_text_with_fixes == prefix + expected_text + suffix
), f"{anno_text_with_fixes=}"
assert anno_text_with_fixes == prefix + expected_text + suffix, (
f"{anno_text_with_fixes=}"
)


def test_annotate_metrics_faceted_plotly(plotly_faceted_scatter: go.Figure) -> None:
Expand Down
61 changes: 31 additions & 30 deletions tests/ptable/plotly/test_ptable_heatmap_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
def test_ptable_heatmap_plotly(glass_formulas: list[str]) -> None:
fig = pmv.ptable_heatmap_plotly(glass_formulas)
assert isinstance(fig, go.Figure)
assert (
len(fig.layout.annotations) == 18 * 10
), "not all periodic table tiles have annotations"
assert (
sum(anno.text != "" for anno in fig.layout.annotations) == 118
), "no annotations should be empty"
assert len(fig.layout.annotations) == 18 * 10, (
"not all periodic table tiles have annotations"
)
assert sum(anno.text != "" for anno in fig.layout.annotations) == 118, (
"no annotations should be empty"
)

# test hover_props and show_values=False
pmv.ptable_heatmap_plotly(
Expand Down Expand Up @@ -134,9 +134,9 @@ def test_ptable_heatmap_plotly_kwarg_combos(
assert fig.layout.font.size == font_size * scale

if len(font_colors) == 2:
assert all(
anno.font.color in font_colors for anno in fig.layout.annotations
), f"{font_colors=}"
assert all(anno.font.color in font_colors for anno in fig.layout.annotations), (
f"{font_colors=}"
)
elif len(font_colors) == 1:
assert all(
anno.font.color == font_colors[0] for anno in fig.layout.annotations
Expand Down Expand Up @@ -197,12 +197,12 @@ def test_ptable_heatmap_plotly_cscale_range(
assert trace.zmin == pytest.approx(data_min)
assert trace.zmax == pytest.approx(data_max)
else:
assert trace.zmin == pytest.approx(
cscale_range[0] or data_min
), f"{cscale_range=}"
assert trace.zmax == pytest.approx(
cscale_range[1] or data_max
), f"{cscale_range=}"
assert trace.zmin == pytest.approx(cscale_range[0] or data_min), (
f"{cscale_range=}"
)
assert trace.zmax == pytest.approx(cscale_range[1] or data_max), (
f"{cscale_range=}"
)


def test_ptable_heatmap_plotly_cscale_range_raises() -> None:
Expand Down Expand Up @@ -414,35 +414,36 @@ def test_ptable_heatmap_plotly_element_symbol_map() -> None:
# Check if custom symbols are used in tile texts
tile_texts = [anno.text for anno in fig.layout.annotations if anno.text]
for custom_symbol in element_symbol_map.values():
assert any(
custom_symbol in text for text in tile_texts
), f"Custom symbol {custom_symbol} not found in tile texts"
assert any(custom_symbol in text for text in tile_texts), (
f"Custom symbol {custom_symbol} not found in tile texts"
)

# Check if original element symbols are still used in hover texts
hover_texts = fig.data[-1].text.flat
for elem in values:
hover_text = next(
text for text in hover_texts if text.startswith(df_ptable.loc[elem, "name"])
)
assert (
f"({elem})" in hover_text
), f"Original symbol {elem} not found in hover text"
assert f"({elem})" in hover_text, (
f"Original symbol {elem} not found in hover text"
)

# Test with partial mapping
partial_map = {"Fe": "This be Iron"}
fig = pmv.ptable_heatmap_plotly(values, element_symbol_map=partial_map)
tile_texts = [anno.text for anno in fig.layout.annotations if anno.text]
assert any(
"This be Iron" in text for text in tile_texts
), "Custom symbol not found in tile texts"
assert any(
"O</span>" in text for text in tile_texts
), "Original symbol 'O' not found in tile texts"
assert any("This be Iron" in text for text in tile_texts), (
"Custom symbol not found in tile texts"
)
assert any("O</span>" in text for text in tile_texts), (
"Original symbol 'O' not found in tile texts"
)

# Test with None value
fig = pmv.ptable_heatmap_plotly(values, element_symbol_map=None)
tile_texts = [anno.text for anno in fig.layout.annotations if anno.text]
for elem in values:
assert any(
f"{elem}</span>" in text for text in tile_texts
), f"Original symbol {elem} not found in tile texts for element_symbol_map=None"
assert any(f"{elem}</span>" in text for text in tile_texts), (
f"Original symbol {elem} not found in tile texts for "
"element_symbol_map=None"
)
18 changes: 9 additions & 9 deletions tests/rdf/test_rdf_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def test_calculate_rdf_normalization(composition: list[str], n_atoms: int) -> No
)

# Check if RDF starts from 0 at r=0
assert (
rdf[0] == 0
), f"{rdf[0]=} should start from 0 at r=0 for {el1}-{el2} pair"
assert rdf[0] == 0, (
f"{rdf[0]=} should start from 0 at r=0 for {el1}-{el2} pair"
)

# Check there are no negative values in the RDF
assert all(rdf >= 0), f"RDF contains negative values for {el1}-{el2} pair"
Expand All @@ -61,9 +61,9 @@ def test_calculate_rdf_normalization(composition: list[str], n_atoms: int) -> No
)

# Check if the RDF has the correct number of bins
assert (
len(rdf) == n_bins
), f"RDF should have {n_bins=}, got {len(rdf)} for {el1}-{el2} pair"
assert len(rdf) == n_bins, (
f"RDF should have {n_bins=}, got {len(rdf)} for {el1}-{el2} pair"
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -156,9 +156,9 @@ def test_calculate_rdf_different_species() -> None:
assert np.any(rdf_si_ge > 0), "Si-Ge RDF should have non-zero values"

peak_index = int(4.33 / cutoff * n_bins)
assert (
rdf_si_ge[peak_index] > 0
), "Expected peak in Si-Ge RDF at sqrt(3)/2 * lattice constant"
assert rdf_si_ge[peak_index] > 0, (
"Expected peak in Si-Ge RDF at sqrt(3)/2 * lattice constant"
)


@pytest.mark.parametrize(
Expand Down
36 changes: 18 additions & 18 deletions tests/rdf/test_rdf_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,31 +93,31 @@ def test_element_pair_rdfs_cutoff_and_bin_size(
# Check that we have the correct number of traces (one for each element pair)
n_elements = len({site.specie.symbol for site in struct})
expected_traces = n_elements * (n_elements + 1) // 2
assert (
len(fig.data) == expected_traces
), f"Expected {expected_traces} traces, got {len(fig.data)}"
assert len(fig.data) == expected_traces, (
f"Expected {expected_traces} traces, got {len(fig.data)}"
)

max_cell_len = max(struct.lattice.abc)
for trace in fig.data:
if param == "cutoff":
expected_cutoff = abs(value) * max_cell_len if value < 0 else value
# Check that the x-axis data doesn't exceed the cutoff
assert np.all(
trace.x <= expected_cutoff
), f"X-axis data exceeds cutoff of {expected_cutoff}"
assert np.all(trace.x <= expected_cutoff), (
f"X-axis data exceeds cutoff of {expected_cutoff}"
)
# Check that the maximum x value is close to the cutoff
assert max(trace.x) == pytest.approx(
expected_cutoff
), f"Maximum x={max(trace.x):.4} not close to cutoff {expected_cutoff}"
assert max(trace.x) == pytest.approx(expected_cutoff), (
f"Maximum x={max(trace.x):.4} not close to cutoff {expected_cutoff}"
)
elif param == "bin_size":
# When bin_size is specified but cutoff is None,
# the default cutoff is 2 * max_cell_len
default_cutoff = 2 * max_cell_len
# Check that the number of bins is approximately correct
expected_bins = int(default_cutoff / value)
assert (
0.85 <= expected_bins / len(trace.x) <= 1
), f"{expected_bins=}, got {len(trace.x)}"
assert 0.85 <= expected_bins / len(trace.x) <= 1, (
f"{expected_bins=}, got {len(trace.x)}"
)


def test_element_pair_rdfs_subplot_layout(structures: list[Structure]) -> None:
Expand Down Expand Up @@ -283,15 +283,15 @@ def test_full_rdf_cutoff_and_bin_size(

if param == "cutoff":
assert np.all(trace.x <= value), f"X-axis data exceeds cutoff of {value}"
assert max(trace.x) == pytest.approx(
value
), f"Maximum x value {max(trace.x):.4} not close to cutoff {value}"
assert max(trace.x) == pytest.approx(value), (
f"Maximum x value {max(trace.x):.4} not close to cutoff {value}"
)
elif param == "bin_size":
default_cutoff = 15
expected_bins = int(np.ceil(default_cutoff / value))
assert (
abs(len(trace.x) - expected_bins) <= 1
), f"Expected around {expected_bins} bins, got {len(trace.x)}"
assert abs(len(trace.x) - expected_bins) <= 1, (
f"Expected around {expected_bins} bins, got {len(trace.x)}"
)


def test_full_rdf_consistency(structures: list[Structure]) -> None:
Expand Down
6 changes: 3 additions & 3 deletions tests/structure_viz/test_structure_viz_mpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,6 @@ def test_structure_2d_color_schemes() -> None:
if isinstance(patch, plt.matplotlib.patches.Wedge)
}

assert (
jmol_colors != vesta_colors
), f"{jmol_colors=}\n\nshould not equal\n\n{vesta_colors=}"
assert jmol_colors != vesta_colors, (
f"{jmol_colors=}\n\nshould not equal\n\n{vesta_colors=}"
)
36 changes: 18 additions & 18 deletions tests/structure_viz/test_structure_viz_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,9 @@ def test_structure_2d_plotly(kwargs: dict[str, Any]) -> None:
vector_traces = [
trace for trace in fig.data if (trace.name or "").startswith("vector")
]
assert (
len(vector_traces) > 0
), "No vector traces found when show_site_vectors is True"
assert len(vector_traces) > 0, (
"No vector traces found when show_site_vectors is True"
)
for vector_trace in vector_traces:
assert vector_trace.mode == "lines+markers"
assert vector_trace.marker.symbol == "arrow"
Expand Down Expand Up @@ -335,14 +335,14 @@ def test_structure_3d_plotly(kwargs: dict[str, Any]) -> None:
text in site_trace.text for text in kwargs["site_labels"].values()
), "Expected site labels not found in trace text"
elif kwargs["site_labels"] in ("symbol", "species"):
assert len(site_trace.text) == len(
DISORDERED_STRUCT
), "Mismatch in number of site labels"
assert len(site_trace.text) == len(DISORDERED_STRUCT), (
"Mismatch in number of site labels"
)
else:
# If site_labels is False, ensure that the trace has no text
assert (
site_trace.text is None or len(site_trace.text) == 0
), "Unexpected site labels found"
assert site_trace.text is None or len(site_trace.text) == 0, (
"Unexpected site labels found"
)

# Check for sites and arrows
if kwargs.get("show_sites"):
Expand All @@ -355,9 +355,9 @@ def test_structure_3d_plotly(kwargs: dict[str, Any]) -> None:
vector_traces = [
trace for trace in fig.data if (trace.name or "").startswith("vector")
]
assert (
len(vector_traces) > 0
), f"No vector traces even though {show_site_vectors=}"
assert len(vector_traces) > 0, (
f"No vector traces even though {show_site_vectors=}"
)
for vector_trace in vector_traces:
if vector_trace.type == "scatter3d":
assert vector_trace.mode == "lines"
Expand Down Expand Up @@ -526,13 +526,13 @@ def test_hover_text(
assert "<b>" in site_hover_text, f"{site_hover_text=}"
assert "</b>" in site_hover_text, f"{site_hover_text=}"
elif hover_text == SiteCoords.cartesian:
assert re.search(
rf"Coordinates \({re_3_coords}\)", site_hover_text
), f"{site_hover_text=}"
assert re.search(rf"Coordinates \({re_3_coords}\)", site_hover_text), (
f"{site_hover_text=}"
)
elif hover_text == SiteCoords.fractional:
assert re.search(
rf"Coordinates \[{re_3_coords}\]", site_hover_text
), f"{site_hover_text=}"
assert re.search(rf"Coordinates \[{re_3_coords}\]", site_hover_text), (
f"{site_hover_text=}"
)
elif hover_text == SiteCoords.cartesian_fractional:
assert re.search(
rf"Coordinates \({re_3_coords}\) \[{re_3_coords}\]", site_hover_text
Expand Down
12 changes: 6 additions & 6 deletions tests/test_brillouin.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,15 +182,15 @@ def test_brillouin_zone_3d_trace_counts(material_id: str) -> None:
# Assert exact counts for each trace type
exp_mesh3d = 1
assert trace_counts["mesh3d"] == exp_mesh3d, f"{trace_counts=}, {exp_mesh3d=}"
assert (
trace_counts["scatter3d"] == exp_scatter3d
), f"{trace_counts=}, {exp_scatter3d=}"
assert trace_counts["scatter3d"] == exp_scatter3d, (
f"{trace_counts=}, {exp_scatter3d=}"
)
exp_cone = 6
assert trace_counts["cone"] == exp_cone, f"{trace_counts=}, {exp_cone=}"
assert scatter_modes["lines"] == exp_lines, f"{scatter_modes=}, {exp_lines=}"
exp_markers_text = 1
assert (
scatter_modes["markers+text"] == exp_markers_text
), f"{scatter_modes=}, {exp_markers_text=}"
assert scatter_modes["markers+text"] == exp_markers_text, (
f"{scatter_modes=}, {exp_markers_text=}"
)
exp_text = 6
assert scatter_modes["text"] == exp_text, f"{scatter_modes=}, {exp_text=}"
Loading

0 comments on commit 6989ba5

Please sign in to comment.