Skip to content

Commit

Permalink
Hetero-nuclear diatomics example with MACE (#259)
Browse files Browse the repository at this point in the history
* density_scatter_plotly add keyword colorbar_kwargs for customizing colorbar

- increase tolerance in `get_image_sites` in `helpers.py` from 0.02 to 0.03
- add new test for colorbar customization in `test_scatter.py`

* add plot_hetero_diatomic_curves.py with ptable and 3D line plots

set Atoms(pbc=False) in mace_pair_repulsion.py for safety (already the default)

* drop windows CI

* add make_vasp_diatomics_inputs.py to create VASP input files for calculating MP compatible diatomic curves
  • Loading branch information
janosh committed Dec 15, 2024
1 parent 4b125b1 commit 706a73e
Show file tree
Hide file tree
Showing 10 changed files with 289 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
tests:
strategy:
matrix:
os: [ubuntu-latest, windows-latest]
os: [ubuntu-latest]
split: [1, 2, 3, 4]
uses: janosh/workflows/.github/workflows/pytest.yml@main
with:
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ build/
# auto-generated by jupyter nbconvert
examples/*.html
examples/**/*.json.gz
examples/**/*.json.xz
examples/dataset_exploration/**/*.pdf
gnome
17 changes: 16 additions & 1 deletion assets/scripts/structure_viz/structure_3d_plotly.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# %%
from matminer.datasets import load_dataset
from pymatgen.core import Lattice, Structure

import pymatviz as pmv
from pymatviz.enums import ElemColorScheme, Key, SiteCoords
Expand All @@ -16,4 +17,18 @@
hover_text=SiteCoords.cartesian_fractional,
)
fig.show()
pmv.io.save_and_compress_svg(fig, "matbench-phonons-structures-3d-plotly")
# pmv.io.save_and_compress_svg(fig, "matbench-phonons-structures-3d-plotly")


# %% BaTiO3 = https://materialsproject.org/materials/mp-5020
# One (species, coords) pair per site of the 5-atom cubic cell: Ba on the corner,
# Ti at the body center and the three O on face centers.
# NOTE(review): coords are presumably fractional (Structure's default) - confirm.
batio3_sites = [
    ("Ba", (0, 0, 0)),
    ("Ti", (0.5, 0.5, 0.5)),
    ("O", (0.5, 0.5, 0)),
    ("O", (0.5, 0, 0.5)),
    ("O", (0, 0.5, 0.5)),
]
batio3 = Structure(
    lattice=Lattice.cubic(4.0338),
    species=[symbol for symbol, _ in batio3_sites],
    coords=[site_coords for _, site_coords in batio3_sites],
)

# draw the unit cell edges as thin white lines
unit_cell_style = {"edge": dict(color="white", width=2)}
fig = pmv.structure_3d_plotly(batio3, show_unit_cell=unit_cell_style)
fig.show()
# pmv.io.save_and_compress_svg(fig, "batio3-structure-3d-plotly")
27 changes: 17 additions & 10 deletions examples/diatomics/mace_pair_repulsion.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
"""Predict pair-repulsion curves for diatomic molecules with MACE-MP.
All credit for this code to Tamas Stenczel. Authored in https://github.com/stenczelt/MACE-MP-work
for MACE-MP paper https://arxiv.org/abs/2401.00096
Thanks to Tamas Stenczel who first did this type of PES smoothness and physicality
analysis in https://github.com/stenczelt/MACE-MP-work for the MACE-MP paper
https://arxiv.org/abs/2401.00096 (see fig. 56).
"""

# %%
from __future__ import annotations

import json
import lzma
import os
import time
from contextlib import contextmanager
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -50,7 +52,7 @@ def generate_diatomics(
list[Atoms]: List of diatomic molecules.
"""
return [
Atoms(f"{elem1}{elem2}", positions=[[0, 0, 0], [dist, 0, 0]])
Atoms(f"{elem1}{elem2}", positions=[[0, 0, 0], [dist, 0, 0]], pbc=False)
for dist in distances
]

Expand All @@ -75,7 +77,7 @@ def calc_one_pair(
]


def generate_homonuclear(calculator: MACECalculator, label: str) -> None:
def generate_homo_nuclear(calculator: MACECalculator, label: str) -> None:
"""Generate potential energy data for homonuclear diatomic molecules.
Args:
Expand All @@ -96,7 +98,7 @@ def generate_homonuclear(calculator: MACECalculator, label: str) -> None:
json.dump(results, file)


def generate_fixed_any(z0: int, calculator: MACECalculator, label: str) -> None:
def generate_hetero_nuclear(z0: int, calculator: MACECalculator, label: str) -> None:
"""Generate potential energy data for hetero-nuclear diatomic molecules with a
fixed first element.
Expand All @@ -105,6 +107,11 @@ def generate_fixed_any(z0: int, calculator: MACECalculator, label: str) -> None:
calculator: MACECalculator instance.
label: Label for the output file.
"""
out_path = f"hetero-nuclear-diatomics-{z0}-{label}.json.xz"
if os.path.isfile(out_path):
print(f"Skipping {z0} because {out_path} already exists")
return
print(f"Starting {z0}")
distances = np.linspace(0.1, 6.0, 119)
allowed_atomic_numbers = calculator.z_table.zs
# saving the results in a dict: "z0-z1" -> [energy] & saved the distances
Expand All @@ -115,18 +122,18 @@ def generate_fixed_any(z0: int, calculator: MACECalculator, label: str) -> None:
formula = f"{elem1}-{elem2}"
with timer(formula):
results[formula] = calc_one_pair(elem1, elem2, calculator, distances)
with lzma.open(f"{label}-{z0}-X.json.xz", mode="wt") as file:
with lzma.open(out_path, mode="wt") as file:
json.dump(results, file)


if __name__ == "__main__":
# first homo-nuclear diatomics
for label in ("small", "medium", "large"):
calculator = mace_mp(model=label)
generate_homonuclear(calculator, f"mace-{label}")
# for label in ("small", "medium", "large"):
# calculator = mace_mp(model=label)
# generate_homo_nuclear(calculator, f"mace-{label}")

# then all hetero-nuclear diatomics
for label in ("small", "medium", "large"):
calculator = mace_mp(model=label)
for z0 in calculator.z_table.zs:
generate_fixed_any(z0, calculator, f"mace-{label}")
generate_hetero_nuclear(z0, calculator, f"mace-{label}")
92 changes: 92 additions & 0 deletions examples/diatomics/make_vasp_diatomics_inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import os
import warnings
from collections.abc import Sequence

import numpy as np
from pymatgen.core import Element, Lattice, Structure
from pymatgen.io.vasp.inputs import Kpoints
from pymatgen.io.vasp.sets import BadInputSetWarning, MPStaticSet

from pymatviz.utils import ROOT


# silence verbose pymatgen warnings: ignore every BadInputSetWarning as well as any
# warning whose message starts with "No Pauling electronegativity for"
warnings.filterwarnings("ignore", category=BadInputSetWarning)
warnings.filterwarnings("ignore", message="No Pauling electronegativity for")


def create_diatomic_inputs(
    distances: Sequence[float] = (1, 10, 40),
    box_size: tuple[float, float, float] = (10, 10, 20),
    elements: Sequence[str] | set[str] = (),
    base_dir: str = "diatomic-calcs",
) -> None:
    """Create VASP input files for all pairs of elements at different separations.

    One calculation directory is written per (elem1, elem2, distance) combination
    at ``{base_dir}/{elem1}/{elem1}-{elem2}/dist={distance:.3f}``. Both orderings
    of a hetero-nuclear pair (A-B and B-A) get their own directory.

    Args:
        distances (Sequence[float]): If a tuple of length 3 whose last item is an
            int, the values are interpreted as (min_dist, max_dist, n_points) and
            expanded to n_points logarithmically spaced distances with
            np.logspace. Anything else is used as-is as the distances to sample.
            Defaults to (1, 10, 40).
        box_size (tuple[float, float, float]): Edge lengths of the orthorhombic
            simulation box in Å. The dimer is centered in the box with its bond
            along the third box axis. Defaults to (10, 10, 20).
        elements (Sequence[str] | set[str]): Elements to include. If empty (of any
            container type), all elements known to pymatgen are used except the
            superheavy/radioactive ones listed in the function body.
            Defaults to ().
        base_dir (str): Base directory to store the input files. Defaults to
            "diatomic-calcs".
    """
    if (
        isinstance(distances, tuple)
        and len(distances) == 3
        and isinstance(distances[-1], int)
    ):
        min_dist, max_dist, n_points = distances
        distances = np.logspace(np.log10(min_dist), np.log10(max_dist), n_points)

    box = Lattice.orthorhombic(*box_size)
    # the box center does not depend on the pair or distance, so compute it once
    center = np.array(box_size) / 2

    # treat any empty container (tuple, list, set) as "use all elements" rather
    # than only the empty tuple, which would silently produce no inputs at all
    if len(elements) == 0:
        # skip superheavy elements (most have no POTCARs and are radioactive)
        skip_elements = {
            *"Am At Bk Cf Cm Es Fr Fm Md No Lr Rf Po Db Sg Bh Hs Mt Ds Cn Nh Fl Mc "
            "Lv Ra Rg Ts Og".split()
        }
        elements = sorted({*map(str, Element)} - skip_elements)

    os.makedirs(base_dir, exist_ok=True)
    # Loop over all ordered pairs of elements
    for elem1 in elements:
        elem1_dir = f"{base_dir}/{elem1}"
        os.makedirs(elem1_dir, exist_ok=True)

        for elem2 in elements:
            elem2_dir = f"{elem1_dir}/{elem1}-{elem2}"
            os.makedirs(elem2_dir, exist_ok=True)

            for dist in distances:
                # Place the two atoms symmetrically around the box center,
                # separated by dist along the third axis
                half_bond = np.array([0, 0, dist / 2])
                coords_1 = center - half_bond
                coords_2 = center + half_bond

                # Create the structure and input set
                dimer = Structure(
                    box, [elem1, elem2], (coords_1, coords_2), coords_are_cartesian=True
                )

                # Create directory for this distance
                dist_dir = f"{elem2_dir}/{dist=:.3f}"
                os.makedirs(dist_dir, exist_ok=True)

                # Generate VASP input files
                vasp_input_set = MPStaticSet(
                    dimer,
                    user_kpoints_settings=Kpoints(),  # sample a single k-point at Gamma
                    # disable symmetry since spglib in VASP sometimes detects false
                    # symmetries in dimers and fails
                    user_incar_settings={"ISYM": 0, "LH5": True},
                )
                vasp_input_set.write_input(dist_dir)

            print(f"Created inputs for {elem1}-{elem2} pair")


if __name__ == "__main__":
    # write all generated inputs below the repo-local tmp/ directory
    output_dir = f"{ROOT}/tmp/diatomic-calcs"  # noqa: S108
    create_diatomic_inputs(base_dir=output_dir)
139 changes: 139 additions & 0 deletions examples/diatomics/plot_hetero_diatomic_curves.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
"""Plot MLIP pair repulsion curves in a periodic table layout and as 3D lines with
elements stacked in the z-direction.
"""

# %%
import json
import lzma
import os

import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from pymatgen.core import Element

import pymatviz as pmv


pmv.set_plotly_template("pymatviz_dark")
module_dir = os.path.dirname(__file__)
__date__ = "2024-03-31"

# model and atomic number of the fixed first element whose curves get plotted
model_name = "mace-small"
z1 = 5
elem1 = Element.from_Z(z1)

# load the xz-compressed JSON file for this element/model from next to this script
lzma_path = f"{module_dir}/hetero-nuclear-diatomics-{z1}-{model_name}.json.xz"
with lzma.open(lzma_path, mode="rt") as file:
    hetero_nuc_diatomics = json.load(file)

# axis limits shared by all plots below
x_range = [0, 6]
y_range = [-8, 15]


# %% plot homo-nuclear and heteronuclear pair repulsion curves
# Reshape the loaded data into {element symbol: (x_values, y_values)} for plotting.
# The "distances" entry is removed from the dict so only element-pair keys remain.
# If this cell is re-run (key already popped), fall back to a `distances` variable
# left over from the previous run.
distances = hetero_nuc_diatomics.pop("distances", locals().get("distances"))
diatomic_curves: dict[str, tuple[list[float], list[float]]] = {}

for elem_pair, pair_energies in hetero_nuc_diatomics.items():
    energies = np.asarray(pair_energies)
    # keys are split on "-" and the 2nd part is used as an element symbol below
    elem2 = elem_pair.split("-")[1]

    # Shift energies so the energy at the largest separation (last point) is 0
    shifted_energies = energies - energies[-1]

    diatomic_curves[elem2] = (distances, shifted_energies)


# %%
# one small line plot per element, arranged in a periodic table layout
fig = pmv.ptable_scatter_plotly(
    diatomic_curves,
    mode="lines",
    x_axis_kwargs={"range": x_range},
    y_axis_kwargs={"range": y_range},
    scale=1.2,
)

title = f"<b>{model_name.title()}</b> Heteronuclear Diatomic Curves for <b>{elem1.long_name}</b>"  # noqa: E501
fig.layout.title.update(text=title, x=0.4, y=0.8)
fig.show()

svg_name = f"hetero-nuclear-{model_name}-{elem1}"
pmv.io.save_and_compress_svg(fig, svg_name)


# %%
# 3D line plot: one energy-vs-distance curve per second element, stacked along the
# z-axis by atomic number and colored by the depth of its energy minimum.
fig = go.Figure()
# Sort elements by atomic number for consistent z-ordering
sorted_elements = sorted(diatomic_curves, key=lambda symbol: Element(symbol).Z)

# Per-element minimum energy (for coloring) and data restricted to distance >= 0.5
min_energies: dict[str, float] = {}
filtered_distances: dict[str, np.ndarray] = {}
filtered_energies: dict[str, np.ndarray] = {}

# First pass: collect min energies and filter data.
# Loop-local names (curve_dists/curve_energies) are used on purpose so this cell
# does not overwrite the module-level `distances` that the curve-conversion cell
# falls back to when it is re-run.
for elem2 in sorted_elements:
    curve_dists, curve_energies = map(np.asarray, diatomic_curves[elem2])
    mask = curve_dists >= 0.5  # Filter data points where distance >= 0.5
    filtered_distances[elem2] = curve_dists[mask]
    filtered_energies[elem2] = curve_energies[mask]
    min_energies[elem2] = curve_energies[mask].min()

min_energy_global = min(min_energies.values())
# log(1 + depth of the deepest well), used to normalize colors into [0, 1]
log_depth_global = np.log(-min_energy_global + 1)

# Second pass: create a line trace and a text label for each element
for idx, elem2 in enumerate(sorted_elements):
    curve_dists = filtered_distances[elem2]
    curve_energies = filtered_energies[elem2]
    z_pos = Element(elem2).Z  # Use atomic number for z-position

    # Create a constant z array for the line
    z_vals = [z_pos] * len(curve_dists)

    # Normalize the minimum energy for this element to get color
    min_energy = min_energies[elem2]
    # Use log scale for better color distribution. If no curve has an attractive
    # well the global log depth is 0 (or NaN), which would make the ratio NaN, so
    # fall back to the low end of the color scale in that case.
    if log_depth_global > 0:
        normalized_energy = float(np.log(-min_energy + 1) / log_depth_global)
    else:
        normalized_energy = 0.0
    line_color = px.colors.sample_colorscale("Reds", [normalized_energy])[0]

    fig.add_scatter3d(
        x=curve_dists,
        y=curve_energies,
        z=z_vals,
        name=f"{elem1}-{elem2} (min={min_energy:.1f} eV)",
        mode="lines",
        line=dict(width=4, color=line_color),
        showlegend=True,
    )

    # Create 4-fold staggered pattern for element labels so neighbors don't overlap
    x_offset = (idx % 4) * 0.3  # 4 positions, spaced by 0.3 Å

    fig.add_scatter3d(
        x=[curve_dists[-1] - x_offset],  # Last x point with staggered offset
        y=[curve_energies[-1] + 0.1],  # Last y point, nudged up slightly
        z=[z_pos],
        mode="text",
        text=[elem2],
        textfont=dict(size=20, color=line_color),
        showlegend=False,
    )


title = f"<b>{model_name.title()}</b> Heteronuclear Diatomic Curves for <b>{elem1.long_name}</b>"  # noqa: E501
fig.layout.title = dict(text=title, x=0.5, y=0.98)
fig.layout.scene = dict(
    xaxis_title="Distance (Å)",
    yaxis_title="Energy (eV)",
    zaxis_title="Atomic Number (Z)",
    camera=dict(
        eye=dict(x=1.3, y=1.3, z=0),
        up=dict(x=0, y=1, z=0),
    ),
    aspectratio=dict(x=1, y=1, z=3),  # Stretch the z-axis (atomic number) 3x
    xaxis=dict(range=x_range),
    yaxis=dict(range=y_range),
)
fig.layout.update(showlegend=False, margin=dict(l=0, r=0, t=0, b=0))

fig.show()
pmv.io.save_and_compress_svg(fig, f"hetero-nuclear-{model_name}-{elem1}-lines-3d")
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
"""Plot MLIP pair repulsion curves in a periodic table layout.
Thanks to Tamas Stenczel who first did this type of PES smoothness and physicality
analysis in https://github.com/stenczelt/MACE-MP-work for the MACE-MP paper
https://arxiv.org/abs/2401.00096
"""
"""Plot MLIP pair repulsion curves in a periodic table layout."""

# %%
import json
Expand Down
6 changes: 6 additions & 0 deletions pymatviz/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def density_scatter_plotly(
n_bins: int | None | Literal[False] = None,
bin_counts_col: str | None = None,
facet_col: str | None = None,
colorbar_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> go.Figure:
"""Scatter plot colored by density using plotly backend.
Expand Down Expand Up @@ -201,6 +202,8 @@ def density_scatter_plotly(
facet_col (str | None, optional): Column name to use for creating faceted
subplots. If provided, the plot will be split into multiple subplots based
on unique values in this column. Defaults to None.
colorbar_kwargs (dict, optional): Passed to fig.layout.coloraxis.colorbar.update(),
overriding the default dict(thickness=15). E.g. dict(thickness=10) to make the
colorbar thinner.
**kwargs: Passed to px.scatter().
Returns:
Expand Down Expand Up @@ -253,6 +256,9 @@ def density_scatter_plotly(
**kwargs,
)

colorbar_defaults = dict(thickness=15)
fig.layout.coloraxis.colorbar.update(colorbar_defaults | (colorbar_kwargs or {}))

if log_density:
_update_colorbar_for_log_density(fig, color_vals, bin_counts_col)

Expand Down
Loading

0 comments on commit 706a73e

Please sign in to comment.