Skip to content

Commit

Permalink
phonon_bands enable custom acoustic/optical bands (#249)
Browse files Browse the repository at this point in the history
* fix NaN handling in ptable_heatmap_splits_plotly

* phonon_bands line_kwargs allow separate dicts for acoustic/optical modes or callable

new admissible types:
line_kwargs: dict[Literal["acoustic", "optical"], dict[str, Any]]
# function taking (band_data, band_idx)
| Callable[[np.ndarray, int], dict[str, Any]]

* test_phonon_bands with acoustic/optical subdict and callable line_kwargs
  • Loading branch information
janosh authored Nov 19, 2024
1 parent 4380b05 commit 0d24ace
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 47 deletions.
25 changes: 21 additions & 4 deletions assets/scripts/phonons/phonon_bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,32 @@
):
docs = {}
for path in glob(f"{pmv.utils.TEST_FILES}/phonons/{mp_id}-{formula}-*.json.lzma"):
key = path.split("-")[-1].split(".")[0]
model_label = (
"CHGNet"
if "chgnet" in path
else "MACE"
if "mace" in path
else "M3GNet"
if "m3gnet" in path
else "PBE"
)
with zopen(path) as file:
docs[key] = json.loads(file.read(), cls=MontyDecoder)
docs[model_label] = json.loads(file.read(), cls=MontyDecoder)

ph_bands: dict[str, PhononBands] = {
key: getattr(doc, Key.ph_band_structure) for key, doc in docs.items()
}

fig = pmv.phonon_bands(ph_bands)
fig.layout.title = dict(text=f"Phonon Bands of {formula} ({mp_id})", x=0.5, y=0.98)
acoustic_lines: dict[str, str | float] = dict(width=1.5)
optical_lines: dict[str, str | float] = dict(width=1)
if len(ph_bands) == 1:
acoustic_lines |= dict(dash="dash", color="red", name="Acoustic")
optical_lines |= dict(dash="dot", color="blue", name="Optical")

fig = pmv.phonon_bands(
ph_bands, line_kwargs=dict(acoustic=acoustic_lines, optical=optical_lines)
)
fig.layout.title = dict(text=f"{formula} ({mp_id}) Phonon Bands", x=0.5, y=0.98)
fig.layout.margin = dict(l=0, r=0, b=0, t=40)
fig.show()
pmv.io.save_and_compress_svg(fig, f"phonon-bands-{mp_id}")
85 changes: 63 additions & 22 deletions pymatviz/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from typing import Any, TypeAlias

import numpy as np
Expand Down Expand Up @@ -154,16 +154,22 @@ def _shaded_range(

def phonon_bands(
band_structs: PhononBands | dict[str, PhononBands],
line_kwargs: dict[str, Any] | None = None,
line_kwargs: (
dict[str, Any] # single dict for all lines
# separate dicts for modes
| dict[Literal["acoustic", "optical"], dict[str, Any]]
# function taking (band_data, band_idx)
| Callable[[np.ndarray, int], dict[str, Any]]
| None
) = None,
branches: Sequence[str] = (),
branch_mode: BranchMode = "union",
shaded_ys: dict[tuple[YMin | YMax, YMin | YMax], dict[str, Any]]
| bool
| None = None,
**kwargs: Any,
) -> go.Figure:
"""Plot single or multiple pymatgen band structures using Plotly, focusing on the
minimum set of overlapping branches.
"""Plot single or multiple pymatgen band structures using Plotly.
Warning: Only tested with phonon band structures so far but plan is to extend to
electronic band structures.
Expand All @@ -172,7 +178,13 @@ def phonon_bands(
band_structs (PhononBandStructureSymmLine | dict[str, PhononBandStructure]):
Single BandStructureSymmLine or PhononBandStructureSymmLine object or a dict
with labels mapped to multiple such objects.
line_kwargs (dict[str, Any]): Passed to Plotly's Figure.add_scatter method.
line_kwargs (dict | dict[str, dict] | Callable): Line style configuration.
Can be:
- A single dict applied to all lines
- A dict with keys "acoustic" and "optical" containing style dicts for each
mode type
- A callable taking (band_data, band_idx) and returning a style dict
Common style options include color, width, dash. Defaults to None.
branches (Sequence[str]): Branches to plot. Defaults to empty tuple, meaning all
branches are plotted.
branch_mode ("union" | "intersection"): Whether to plot union or intersection
Expand Down Expand Up @@ -251,29 +263,58 @@ def phonon_bands(
for bs_idx, (label, bs) in enumerate(band_structs.items()):
color = colors[bs_idx % len(colors)]
line_style = line_styles[bs_idx % len(line_styles)]
line_defaults = dict(color=color, width=1.5, dash=line_style)
# 1st bands determine x-axis scale (there are usually slight scale differences
# between bands)
first_bs = first_bs or bs
for branch_idx, branch in enumerate(bs.branches):

for branch in bs.branches:
if branch["name"] not in common_branches:
continue
start_idx = branch["start_index"]
end_idx = branch["end_index"] + 1 # Include the end point
# using the same first_bs x-axis for all band structures to avoid band
# shifting
end_idx = branch["end_index"] + 1
distances = first_bs.distance[start_idx:end_idx]
for band in range(bs.nb_bands):
frequencies = bs.bands[band][start_idx:end_idx]
# group traces for toggling and set legend name only for 1st band

for band_idx in range(bs.nb_bands):
frequencies = bs.bands[band_idx][start_idx:end_idx]
# Determine if this is an acoustic or optical band
is_acoustic = band_idx < 3
mode_type = "acoustic" if is_acoustic else "optical"

# Default line style
line_defaults = dict(
color=color, width=1.5 if is_acoustic else 1, dash=line_style
)
trace_name = label
existing_names = {trace.name for trace in fig.data}

# Apply line style based on line_kwargs type
if callable(line_kwargs):
# Pass band data and index to callback
custom_style = line_kwargs(frequencies, band_idx)
line_defaults |= custom_style
elif isinstance(line_kwargs, dict):
# check for custom line styles for one or both modes
if {"acoustic", "optical"} <= set(line_kwargs):
mode_styles = line_kwargs.get(mode_type, {}) # type: ignore[call-overload]
# use custom trace name if provided (needs to be popped before
# passed to line kwargs)
if mode_name := mode_styles.pop("name", None):
trace_name = mode_name
# don't show default trace name in legend if got custom name
existing_names.add(trace_name)

# Use mode-specific styles
line_defaults |= mode_styles
else: # Apply single style dict to all lines
line_defaults |= line_kwargs # type: ignore[arg-type]

is_new_name = trace_name not in existing_names
fig.add_scatter(
x=distances,
y=frequencies,
mode="lines",
line=line_defaults | line_kwargs,
legendgroup=label,
name=label,
showlegend=branch_idx == band == 0,
line=line_defaults,
legendgroup=trace_name,
name=trace_name,
showlegend=is_new_name,
**kwargs,
)

Expand Down Expand Up @@ -484,10 +525,10 @@ def phonon_bands_and_dos(
subplot_kwargs (dict[str, Any]): Passed to Plotly's make_subplots method.
Defaults to dict(shared_yaxes=True, column_widths=(0.8, 0.2),
horizontal_spacing=0.01).
all_line_kwargs (dict[str, Any]): Passed to trace.update for each in fig.data.
Modify line appearance for all traces. Defaults to None.
all_line_kwargs (dict[str, Any]): Passed to trace.update for each trace in
fig.data. Modifies line appearance for all traces. Defaults to None.
per_line_kwargs (dict[str, str]): Map of line labels to kwargs for trace.update.
Modify line appearance for specific traces. Defaults to None.
Modifies line appearance for specific traces. Defaults to None.
**kwargs: Passed to Plotly's Figure.add_scatter method.
Returns:
Expand Down
26 changes: 11 additions & 15 deletions pymatviz/ptable/ptable_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,15 +759,14 @@ def ptable_heatmap_splits_plotly(
"""
import plotly.colors

# Process input data
if isinstance(data, pd.Series):
if isinstance(data, pd.Series): # Process input data
data = data.to_dict()
elif isinstance(data, pd.DataFrame):
data = data.to_dict(orient="index")

# Find global min and max values for color scaling
all_values = [v for values in data.values() for v in values if not np.isnan(v)]
vmin, vmax = min(all_values), max(all_values)
all_values = np.array(list(data.values()), dtype=float)
cbar_min, cbar_max = np.nanmin(all_values), np.nanmax(all_values)

# Initialize figure with subplots
n_rows, n_cols = 10, 18
Expand Down Expand Up @@ -864,20 +863,17 @@ def create_section_coords(
xy_ref = dict(xref=f"x{subplot_key}", yref=f"y{subplot_key}")

# Get values and colors
values = data.get(symbol, [])
if len(values) == 0:
values = [1]
values = np.asarray(data.get(symbol, []), dtype=float)

# Create sections
sections = create_section_coords(len(values), orientation) # type: ignore[arg-type]
for idx, (xs, ys) in enumerate(sections):
color = (
plotly.colors.sample_colorscale(
colorscale, (values[idx] - vmin) / (vmax - vmin)
if len(values) <= idx or np.isnan(values[idx]):
color = nan_color
else:
color = plotly.colors.sample_colorscale(
colorscale, (values[idx] - cbar_min) / (cbar_max - cbar_min)
)[0]
if not np.isnan(values[idx])
else nan_color
)
fig.add_scatter(
x=xs,
y=ys,
Expand Down Expand Up @@ -998,8 +994,8 @@ def create_section_coords(
marker=dict(
colorscale=colorscale,
showscale=True,
cmin=vmin,
cmax=vmax,
cmin=cbar_min,
cmax=cbar_max,
colorbar=colorbar,
),
hoverinfo="none",
Expand Down
64 changes: 58 additions & 6 deletions tests/test_phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import re
from glob import glob
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import plotly.graph_objects as go
import pytest
Expand All @@ -17,8 +17,11 @@


if TYPE_CHECKING:
from collections.abc import Callable
from typing import Literal

import numpy as np

BandsDoses = dict[str, dict[str, PhononBands | PhononDos]]
bs_key, dos_key = "phonon_bandstructure", "phonon_dos"
# enable loading PhononDBDocParsed with @module set to uninstalled ffonons.dbs.phonondb
Expand Down Expand Up @@ -64,37 +67,86 @@ def phonon_doses() -> dict[str, PhononDos]:


@pytest.mark.parametrize(
("branches", "branch_mode"), [(["GAMMA-X", "X-U"], "union"), ((), "intersection")]
("branches", "branch_mode", "line_kwargs"),
[
# test original single dict behavior
(["GAMMA-X", "X-U"], "union", dict(width=2)),
# test empty tuple branches with intersection mode
((), "intersection", None),
# test separate acoustic/optical styling
(
["GAMMA-X"],
"union",
{
"acoustic": dict(width=2.5, dash="solid", name="Acoustic modes"),
"optical": dict(width=1, dash="dash", name="Optical modes"),
},
),
# test callable line_kwargs
((), "union", lambda _freqs, idx: dict(dash="solid" if idx < 3 else "dash")),
],
)
def test_phonon_bands(
phonon_bands_doses_mp_2758: BandsDoses,
branches: tuple[str, str],
branch_mode: pmv.phonons.BranchMode,
line_kwargs: dict[str, Any] | Callable[[np.ndarray, int], dict[str, Any]] | None,
) -> None:
# test single band structure
fig = pmv.phonon_bands(
phonon_bands_doses_mp_2758["bands"]["DFT"],
branch_mode=branch_mode,
branches=branches,
line_kwargs=line_kwargs,
)
assert isinstance(fig, go.Figure)
assert fig.layout.xaxis.title.text == "Wave Vector"
assert fig.layout.yaxis.title.text == "Frequency (THz)"
assert fig.layout.font.size == 16

x_labels: tuple[str, ...]
if branches == ():
x_labels = ["Γ", "X", "X", "U|K", "Γ", "Γ", "L", "L", "W", "W", "X"]
x_labels = ("Γ", "X", "X", "U|K", "Γ", "Γ", "L", "L", "W", "W", "X")
else:
x_labels = ["Γ", "X", "X", "U|K"]
assert list(fig.layout.xaxis.ticktext) == x_labels
x_labels = ("Γ", "U|K") if len(branches) == 1 else ("Γ", "X", "X", "U|K")
assert fig.layout.xaxis.ticktext == x_labels
assert fig.layout.xaxis.range is None
assert fig.layout.yaxis.range == pytest.approx((0, 5.36385427095))

# test line styling
if isinstance(line_kwargs, dict) and "acoustic" in line_kwargs:
# check that acoustic and optical modes have different styles
trace_names = {trace.name for trace in fig.data}
assert trace_names == {"", "Acoustic modes", "Optical modes"}

acoustic_traces = [t for t in fig.data if t.name == "Acoustic modes"]
optical_traces = [t for t in fig.data if t.name == "Optical modes"]

assert all(t.line.width == 2.5 for t in acoustic_traces)
assert all(t.line.dash == "solid" for t in acoustic_traces)
assert all(t.line.width == 1 for t in optical_traces)
assert all(t.line.dash == "dash" for t in optical_traces)
elif callable(line_kwargs):
# check that line width increases with band index
traces_by_width = sorted(fig.data, key=lambda t: t.line.width)
assert traces_by_width[0].line.width < traces_by_width[-1].line.width

# check acoustic/optical line style separation
acoustic_traces = [t for t in fig.data if t.line.dash == "solid"]
optical_traces = [t for t in fig.data if t.line.dash == "dash"]
assert len(acoustic_traces) == 18 # 6 segments for the first 3 bands
assert len(optical_traces) > 0 # should have some optical bands

# test dict of band structures
fig = pmv.phonon_bands(
phonon_bands_doses_mp_2758["bands"], branch_mode=branch_mode, branches=branches
phonon_bands_doses_mp_2758["bands"],
branch_mode=branch_mode,
branches=branches,
line_kwargs=line_kwargs,
)
assert isinstance(fig, go.Figure)
assert {trace.name for trace in fig.data} == {"DFT", "MACE"}
assert fig.layout.xaxis.ticktext == x_labels


def test_phonon_bands_raises(
Expand Down

0 comments on commit 0d24ace

Please sign in to comment.