Skip to content

Commit

Permalink
Hetero-nuclear diatomics example with MACE (#259)
Browse files Browse the repository at this point in the history
* density_scatter_plotly add keyword colorbar_kwargs for customizing colorbar

- increase tolerance in `get_image_sites` in `helpers.py` from 0.02 to 0.03
- add new test for colorbar customization in `test_scatter.py`

* add plot_hetero_diatomic_curves.py with ptable and 3D line plots

set Atoms(pbc=False) in mace_pair_repulsion.py for safety (already the default)

* drop windows CI
  • Loading branch information
janosh committed Dec 15, 2024
1 parent 4b125b1 commit d6041b3
Show file tree
Hide file tree
Showing 9 changed files with 197 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
tests:
strategy:
matrix:
os: [ubuntu-latest, windows-latest]
os: [ubuntu-latest]
split: [1, 2, 3, 4]
uses: janosh/workflows/.github/workflows/pytest.yml@main
with:
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ build/
# auto-generated by jupyter nbconvert
examples/*.html
examples/**/*.json.gz
examples/**/*.json.xz
examples/dataset_exploration/**/*.pdf
gnome
17 changes: 16 additions & 1 deletion assets/scripts/structure_viz/structure_3d_plotly.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# %%
from matminer.datasets import load_dataset
from pymatgen.core import Lattice, Structure

import pymatviz as pmv
from pymatviz.enums import ElemColorScheme, Key, SiteCoords
Expand All @@ -16,4 +17,18 @@
hover_text=SiteCoords.cartesian_fractional,
)
fig.show()
pmv.io.save_and_compress_svg(fig, "matbench-phonons-structures-3d-plotly")
# pmv.io.save_and_compress_svg(fig, "matbench-phonons-structures-3d-plotly")


# %% BaTiO3 = https://materialsproject.org/materials/mp-5020
# Cubic cell with Ba at the origin, Ti at the body center and O at three face
# centers (coords are given without coords_are_cartesian, i.e. presumably
# fractional - confirm against pymatgen's Structure defaults).
batio3_coords = [
    (0, 0, 0),  # Ba
    (0.5, 0.5, 0.5),  # Ti
    (0.5, 0.5, 0),  # O
    (0.5, 0, 0.5),  # O
    (0, 0.5, 0.5),  # O
]
batio3 = Structure(
    lattice=Lattice.cubic(4.0338),
    species=["Ba", "Ti", "O", "O", "O"],
    coords=batio3_coords,
)

# Style the unit cell edges as white lines of width 2.
unit_cell_style = {"edge": {"color": "white", "width": 2}}
fig = pmv.structure_3d_plotly(batio3, show_unit_cell=unit_cell_style)
fig.show()
# pmv.io.save_and_compress_svg(fig, "batio3-structure-3d-plotly")
27 changes: 17 additions & 10 deletions examples/diatomics/mace_pair_repulsion.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
"""Predict pair-repulsion curves for diatomic molecules with MACE-MP.
All credit for this code to Tamas Stenczel. Authored in https://github.com/stenczelt/MACE-MP-work
for MACE-MP paper https://arxiv.org/abs/2401.00096
Thanks to Tamas Stenczel who first did this type of PES smoothness and physicality
analysis in https://github.com/stenczelt/MACE-MP-work for the MACE-MP paper
https://arxiv.org/abs/2401.00096 (see fig. 56).
"""

# %%
from __future__ import annotations

import json
import lzma
import os
import time
from contextlib import contextmanager
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -50,7 +52,7 @@ def generate_diatomics(
list[Atoms]: List of diatomic molecules.
"""
return [
Atoms(f"{elem1}{elem2}", positions=[[0, 0, 0], [dist, 0, 0]])
Atoms(f"{elem1}{elem2}", positions=[[0, 0, 0], [dist, 0, 0]], pbc=False)
for dist in distances
]

Expand All @@ -75,7 +77,7 @@ def calc_one_pair(
]


def generate_homonuclear(calculator: MACECalculator, label: str) -> None:
def generate_homo_nuclear(calculator: MACECalculator, label: str) -> None:
"""Generate potential energy data for homonuclear diatomic molecules.
Args:
Expand All @@ -96,7 +98,7 @@ def generate_homonuclear(calculator: MACECalculator, label: str) -> None:
json.dump(results, file)


def generate_fixed_any(z0: int, calculator: MACECalculator, label: str) -> None:
def generate_hetero_nuclear(z0: int, calculator: MACECalculator, label: str) -> None:
"""Generate potential energy data for hetero-nuclear diatomic molecules with a
fixed first element.
Expand All @@ -105,6 +107,11 @@ def generate_fixed_any(z0: int, calculator: MACECalculator, label: str) -> None:
calculator: MACECalculator instance.
label: Label for the output file.
"""
out_path = f"hetero-nuclear-diatomics-{z0}-{label}.json.xz"
if os.path.isfile(out_path):
print(f"Skipping {z0} because {out_path} already exists")
return
print(f"Starting {z0}")
distances = np.linspace(0.1, 6.0, 119)
allowed_atomic_numbers = calculator.z_table.zs
# saving the results in a dict: "z0-z1" -> [energy] & saved the distances
Expand All @@ -115,18 +122,18 @@ def generate_fixed_any(z0: int, calculator: MACECalculator, label: str) -> None:
formula = f"{elem1}-{elem2}"
with timer(formula):
results[formula] = calc_one_pair(elem1, elem2, calculator, distances)
with lzma.open(f"{label}-{z0}-X.json.xz", mode="wt") as file:
with lzma.open(out_path, mode="wt") as file:
json.dump(results, file)


if __name__ == "__main__":
# first homo-nuclear diatomics
for label in ("small", "medium", "large"):
calculator = mace_mp(model=label)
generate_homonuclear(calculator, f"mace-{label}")
# for label in ("small", "medium", "large"):
# calculator = mace_mp(model=label)
# generate_homo_nuclear(calculator, f"mace-{label}")

# then all hetero-nuclear diatomics
for label in ("small", "medium", "large"):
calculator = mace_mp(model=label)
for z0 in calculator.z_table.zs:
generate_fixed_any(z0, calculator, f"mace-{label}")
generate_hetero_nuclear(z0, calculator, f"mace-{label}")
139 changes: 139 additions & 0 deletions examples/diatomics/plot_hetero_diatomic_curves.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
"""Plot MLIP pair repulsion curves in a periodic table layout and as 3D lines with
elements stacked in the z-direction.
"""

# %%
import json
import lzma
import os

import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from pymatgen.core import Element

import pymatviz as pmv


# Switch plotly to pymatviz's dark template before any figure is created.
pmv.set_plotly_template("pymatviz_dark")
module_dir = os.path.dirname(__file__)
__date__ = "2024-03-31"

# Model label and atomic number of the fixed first element whose curves are plotted.
model_name = "mace-small"
z1 = 5
elem1 = Element.from_Z(z1)

# Load the xz-compressed JSON file for this element/model, looked up in the same
# directory as this script.
lzma_path = f"{module_dir}/hetero-nuclear-diatomics-{z1}-{model_name}.json.xz"
with lzma.open(lzma_path, mode="rt") as json_file:
    hetero_nuc_diatomics = json.load(json_file)

# Axis limits reused by the figures below (x = distance, y = energy).
x_range = [0, 6]
y_range = [-8, 15]


# %% plot homo-nuclear and heteronuclear pair repulsion curves
# Convert the loaded data into the format needed for plotting: diatomic_curves maps
# the second element's symbol to a tuple of (x_values, y_values).
diatomic_curves: dict[str, tuple[list[float], np.ndarray]] = {}
# pop() removes the "distances" entry so only element-pair keys remain in the dict.
# When this cell is re-run after the key was already popped, fall back to the value
# kept in the namespace by the earlier run.
distances = hetero_nuc_diatomics.pop("distances", locals().get("distances"))
if distances is None:
    # fail here with a clear message instead of later with an obscure error
    raise KeyError(
        "hetero_nuc_diatomics has no 'distances' key and no earlier value to reuse"
    )

for elem_pair, pair_energies in hetero_nuc_diatomics.items():
    energies = np.asarray(pair_energies)
    # Keys look like "<first>-<second>"; the second part is later passed to
    # pymatgen's Element(...), so it is expected to be an element symbol.
    elem2 = elem_pair.split("-")[1]

    # Shift energies so the energy at the largest sampled separation (last point) is 0
    shifted_energies = energies - energies[-1]

    diatomic_curves[elem2] = distances, shifted_energies


# %% periodic-table layout of the curves, keyed by the second element's symbol
ptable_axis_kwargs = {
    "x_axis_kwargs": {"range": x_range},
    "y_axis_kwargs": {"range": y_range},
}
fig = pmv.ptable_scatter_plotly(
    diatomic_curves, mode="lines", scale=1.2, **ptable_axis_kwargs
)

# Title shows the model name (title-cased) and the fixed element's long name.
title = (
    f"<b>{model_name.title()}</b> Heteronuclear Diatomic Curves "
    f"for <b>{elem1.long_name}</b>"
)
fig.layout.title.update(text=title, x=0.4, y=0.8)
fig.show()
pmv.io.save_and_compress_svg(fig, f"hetero-nuclear-{model_name}-{elem1}")


# %% 3D line plot: one curve per second element, placed along z by atomic number
fig = go.Figure()
# Sort elements by atomic number for consistent z-ordering
sorted_elements = sorted(diatomic_curves, key=lambda symbol: Element(symbol).Z)

# Per-element data restricted to distances >= 0.5 plus each curve's minimum energy
min_energies: dict[str, float] = {}
filtered_distances: dict[str, np.ndarray] = {}
filtered_energies: dict[str, np.ndarray] = {}

# First pass: filter data and collect min energies.
# Loop-local names are used on purpose so the module-level `distances` (reused by
# the data-conversion cell when it is re-run) is not overwritten with filtered data.
for elem2 in sorted_elements:
    all_distances, all_energies = map(np.asarray, diatomic_curves[elem2])
    mask = all_distances >= 0.5  # Keep data points where distance >= 0.5
    filtered_distances[elem2] = all_distances[mask]
    filtered_energies[elem2] = all_energies[mask]
    min_energies[elem2] = all_energies[mask].min()

min_energy_global = min(min_energies.values())
# Denominator of the log-scaled color normalization. It is 0 when the deepest
# minimum is exactly 0 (e.g. no curve dips below its zero-shifted last point), in
# which case dividing by it would give NaN colors.
log_depth_global = np.log(-min_energy_global + 1)

# Second pass: create a line trace and a text label for each element
for idx, elem2 in enumerate(sorted_elements):
    elem_distances = filtered_distances[elem2]
    elem_energies = filtered_energies[elem2]
    z_pos = Element(elem2).Z  # Use atomic number for z-position

    # Create a constant z array for the line
    z_vals = [z_pos] * len(elem_distances)

    # Normalize the minimum energy for this element to get color
    min_energy = min_energies[elem2]
    # Use log scale for better color distribution
    if log_depth_global > 0:
        normalized_energy = np.log(-min_energy + 1) / log_depth_global
    else:
        normalized_energy = 0.0  # all minima equal: use the start of the colorscale
    line_color = px.colors.sample_colorscale("Reds", [normalized_energy])[0]

    fig.add_scatter3d(
        x=elem_distances,
        y=elem_energies,
        z=z_vals,
        name=f"{elem1}-{elem2} (min={min_energy:.1f} eV)",
        mode="lines",
        line=dict(width=4, color=line_color),
        showlegend=True,
    )

    # Create 4-fold staggered pattern for element labels
    x_offset = (idx % 4) * 0.3  # 4 positions, spaced by 0.3 Å

    fig.add_scatter3d(
        x=[elem_distances[-1] - x_offset],  # Last x point with staggered offset
        y=[elem_energies[-1] + 0.1],  # Last y point, nudged up by 0.1
        z=[z_pos],
        mode="text",
        text=[elem2],
        textfont=dict(size=20, color=line_color),
        showlegend=False,
    )


title = (
    f"<b>{model_name.title()}</b> Heteronuclear Diatomic Curves "
    f"for <b>{elem1.long_name}</b>"
)
fig.layout.title = dict(text=title, x=0.5, y=0.98)
fig.layout.scene = dict(
    xaxis_title="Distance (Å)",
    yaxis_title="Energy (eV)",
    zaxis_title="Atomic Number (Z)",
    camera=dict(
        eye=dict(x=1.3, y=1.3, z=0),
        up=dict(x=0, y=1, z=0),
    ),
    aspectratio=dict(x=1, y=1, z=3),  # Stretch the z-axis (atomic number) 3x
    xaxis=dict(range=x_range),
    yaxis=dict(range=y_range),
)
# Legend entries are hidden for the whole figure; elements are labeled in-plot
fig.layout.update(showlegend=False, margin=dict(l=0, r=0, t=0, b=0))

fig.show()
pmv.io.save_and_compress_svg(fig, f"hetero-nuclear-{model_name}-{elem1}-lines-3d")
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
"""Plot MLIP pair repulsion curves in a periodic table layout.
Thanks to Tamas Stenczel who first did this type of PES smoothness and physicality
analysis in https://github.com/stenczelt/MACE-MP-work for the MACE-MP paper
https://arxiv.org/abs/2401.00096
"""
"""Plot MLIP pair repulsion curves in a periodic table layout."""

# %%
import json
Expand Down
6 changes: 6 additions & 0 deletions pymatviz/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def density_scatter_plotly(
n_bins: int | None | Literal[False] = None,
bin_counts_col: str | None = None,
facet_col: str | None = None,
colorbar_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> go.Figure:
"""Scatter plot colored by density using plotly backend.
Expand Down Expand Up @@ -201,6 +202,8 @@ def density_scatter_plotly(
facet_col (str | None, optional): Column name to use for creating faceted
subplots. If provided, the plot will be split into multiple subplots based
on unique values in this column. Defaults to None.
colorbar_kwargs (dict, optional): Passed to fig.layout.coloraxis.colorbar.
E.g. dict(thickness=15) to make colorbar thinner.
**kwargs: Passed to px.scatter().
Returns:
Expand Down Expand Up @@ -253,6 +256,9 @@ def density_scatter_plotly(
**kwargs,
)

colorbar_defaults = dict(thickness=15)
fig.layout.coloraxis.colorbar.update(colorbar_defaults | (colorbar_kwargs or {}))

if log_density:
_update_colorbar_for_log_density(fig, color_vals, bin_counts_col)

Expand Down
2 changes: 1 addition & 1 deletion pymatviz/structure_viz/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _angles_to_rotation_matrix(


def get_image_sites(
site: PeriodicSite, lattice: Lattice, tol: float = 0.02
site: PeriodicSite, lattice: Lattice, tol: float = 0.03
) -> np.ndarray:
"""Get images for a given site in a lattice.
Expand Down
15 changes: 15 additions & 0 deletions tests/test_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,3 +298,18 @@ def test_density_scatter_plotly_facet_hover_template() -> None:
for trace in fig.data:
assert "total_bill" in trace.hovertemplate
assert "tip" in trace.hovertemplate


def test_density_scatter_plotly_colorbar_kwargs() -> None:
    """Check that colorbar_kwargs are applied to the figure's coloraxis colorbar."""
    custom_kwargs = {"title": "Custom Title", "thickness": 30, "len": 0.8, "x": 1.1}

    fig = density_scatter_plotly(
        df=DF_TIPS, x="total_bill", y="tip", colorbar_kwargs=custom_kwargs
    )

    colorbar = fig.layout.coloraxis.colorbar
    # the title passed as a plain string is read back via the title's text attribute
    assert colorbar.title.text == "Custom Title"
    # every remaining custom value must be found on the colorbar unchanged
    assert (colorbar.thickness, colorbar.len, colorbar.x) == (30, 0.8, 1.1)

0 comments on commit d6041b3

Please sign in to comment.