Skip to content

Commit

Permalink
calculate MACE-MPA-0 diatomic curves
Browse files Browse the repository at this point in the history
- modify plot_homo_diatomic_curves.py to support multiple models in single ptable_scatter_plotly()
- update tests in test_ptable_scatter_plotly.py to reflect changes in element annotations
- bump pre-commit hooks for ruff (v0.8.6), mypy (v1.14.1), and check-jsonschema (0.31.0).
- remove obsolete .devcontainer/devcontainer.json file.
  • Loading branch information
janosh committed Jan 9, 2025
1 parent eb12217 commit 4dc5f1e
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 65 deletions.
18 changes: 0 additions & 18 deletions .devcontainer/devcontainer.json

This file was deleted.

6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.4
rev: v0.8.6
hooks:
- id: ruff
args: [--fix]
Expand All @@ -17,7 +17,7 @@ repos:
types_or: [python, jupyter]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.14.0
rev: v1.14.1
hooks:
- id: mypy
additional_dependencies: [types-requests, types-PyYAML]
Expand Down Expand Up @@ -85,7 +85,7 @@ repos:
- id: pyright

- repo: https://github.com/python-jsonschema/check-jsonschema
rev: 0.30.0
rev: 0.31.0
hooks:
- id: check-jsonschema
files: ^pymatviz/keys\.yml$
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from ase import Atoms
from ase.data import chemical_symbols
from mace.calculators import MACECalculator, mace_mp
from tqdm import tqdm


if TYPE_CHECKING:
Expand Down Expand Up @@ -77,63 +78,82 @@ def calc_one_pair(
]


def generate_homo_nuclear(calculator: MACECalculator, label: str) -> None:
def calc_homo_diatomics(
calculator: MACECalculator, model_name: str
) -> dict[str, list[float]]:
"""Generate potential energy data for homonuclear diatomic molecules.
Args:
calculator: MACECalculator instance.
label: Label for the output file.
model_name: Name of the model for the output file.
"""
distances = np.linspace(0.1, 6.0, 119)
allowed_atomic_numbers = calculator.z_table.zs
# saving the results in a dict: "z0-z1" -> [energy] & saved the distances
results = {"distances": list(distances)}
# homo-nuclear diatomics
for z0 in allowed_atomic_numbers:
pbar = tqdm(allowed_atomic_numbers, desc=f"Homo-nuclear diatomics for {model_name}")
for z0 in pbar:
elem1, elem2 = chemical_symbols[z0], chemical_symbols[z0]
formula = f"{elem1}-{elem2}"
with timer(formula):
results[formula] = calc_one_pair(elem1, elem2, calculator, distances)
with lzma.open(f"homo-nuclear-{label}.json.xz", mode="wt") as file:
pbar.set_postfix_str(formula)
results[formula] = calc_one_pair(elem1, elem2, calculator, distances)

with lzma.open(f"homo-nuclear-{model_name}.json.xz", mode="wt") as file:
json.dump(results, file)

return results


def generate_hetero_nuclear(z0: int, calculator: MACECalculator, label: str) -> None:
def calc_hetero_diatomics(
z0: int, calculator: MACECalculator, model_name: str
) -> dict[str, list[float]]:
"""Generate potential energy data for hetero-nuclear diatomic molecules with a
fixed first element.
fixed first element and save to .json.xz file.
Args:
z0: Atomic number of the fixed first element.
calculator: MACECalculator instance.
label: Label for the output file.
model_name: Name of the model for the output file.
Returns:
dict[str, list[float]]: Potential energy data for hetero-nuclear
diatomic molecules or None if the file already exists.
"""
out_path = f"hetero-nuclear-diatomics-{z0}-{label}.json.xz"
out_path = f"hetero-nuclear-diatomics-{z0}-{model_name}.json.xz"
if os.path.isfile(out_path):
print(f"Skipping {z0} because {out_path} already exists")
return
with lzma.open(out_path, mode="rt") as file:
return json.load(file)

print(f"Starting {z0}")
distances = np.linspace(0.1, 6.0, 119)
allowed_atomic_numbers = calculator.z_table.zs
# saving the results in a dict: "z0-z1" -> [energy] & saved the distances
results = {"distances": list(distances)}
# hetero-nuclear diatomics
for z1 in allowed_atomic_numbers:
pbar = tqdm(
allowed_atomic_numbers, desc=f"Hetero-nuclear diatomics for {model_name}"
)
for z1 in pbar:
elem1, elem2 = chemical_symbols[z0], chemical_symbols[z1]
formula = f"{elem1}-{elem2}"
with timer(formula):
results[formula] = calc_one_pair(elem1, elem2, calculator, distances)
pbar.set_postfix_str(formula)
results[formula] = calc_one_pair(elem1, elem2, calculator, distances)

with lzma.open(out_path, mode="wt") as file:
json.dump(results, file)

return results


if __name__ == "__main__":
# first homo-nuclear diatomics
# for label in ("small", "medium", "large"):
# calculator = mace_mp(model=label)
# generate_homo_nuclear(calculator, f"mace-{label}")

# then all hetero-nuclear diatomics
for label in ("small", "medium", "large"):
calculator = mace_mp(model=label)
for z0 in calculator.z_table.zs:
generate_hetero_nuclear(z0, calculator, f"mace-{label}")
mace_mpa_0_medium_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model"
for checkpoint_path in (mace_mpa_0_medium_url,):
calculator = mace_mp(model=checkpoint_path)
model_name = os.path.basename(checkpoint_path).split(".")[0]
calc_homo_diatomics(calculator, model_name)

# calculate all hetero-nuclear diatomics (takes a long time)
# for z0 in calculator.z_table.zs:
# calc_hetero_diatomics(z0, calculator, model_name)
2 changes: 2 additions & 0 deletions examples/diatomics/make_vasp_diatomics_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def create_diatomic_inputs(
base_dir: str = "diatomic-calcs",
) -> None:
"""Create VASP input files for all pairs of elements at different separations.
The calculations can be run using run_vasp_diatomics.py, which will automatically
handle running calculations in sequence and copying WAVECAR files between distances.
Args:
distances (tuple[float, ...]): If tuple and length is 3 and last item is int,
Expand Down
25 changes: 13 additions & 12 deletions examples/diatomics/plot_homo_diatomic_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,45 +17,44 @@


# %% plot homo-nuclear and heteronuclear pair repulsion curves
model = "medium" # "small"
lzma_path = f"{module_dir}/homo-nuclear-mace-{model}.json.xz"
model_name = ("mace-small", "mace-medium", "mace-mpa-0-medium")[-1]
lzma_path = f"{module_dir}/homo-nuclear-{model_name}.json.xz"
with lzma.open(lzma_path, mode="rt") as file:
homo_nuc_diatomics = json.load(file)

# Convert data to format needed for plotting
# Each element in diatomic_curves should be a tuple of (x_values, y_values)
diatomic_curves: dict[str, tuple[list[float], list[float]]] = {}
diatomic_curves: dict[str, dict[str, tuple[np.ndarray, np.ndarray]]] = {}
distances = homo_nuc_diatomics.pop("distances", locals().get("distances"))

for symbol in homo_nuc_diatomics:
energies = np.asarray(homo_nuc_diatomics[symbol])
# Get element symbol from the key (format is "Z-Z" where Z is atomic number)
elem_z = int(symbol.split("-")[0])
elem_symbol = Element.from_Z(elem_z).symbol
elem_symbol = Element(symbol.split("-")[0]).symbol

# Shift energies so the energy at infinite separation (last point) is 0
shifted_energies = energies - energies[-1]

diatomic_curves[elem_symbol] = distances, shifted_energies
diatomic_curves[elem_symbol] = {model_name: (distances, shifted_energies)}


fig = pmv.ptable_scatter_plotly(
diatomic_curves,
mode="lines",
x_axis_kwargs=dict(range=[0, 6]),
y_axis_kwargs=dict(range=[-8, 15]),
scale=1.2,
scale=1.5,
)

fig.layout.title.update(text=f"MACE {model.title()} Diatomic Curves", x=0.4, y=0.8)
fig.layout.title.update(text=f"{model_name.title()} Diatomic Curves", x=0.4, y=0.8)
fig.show()
pmv.io.save_and_compress_svg(fig, f"homo-nuclear-mace-{model}")
# pmv.io.save_and_compress_svg(fig, f"homo-nuclear-{model}")


# %% count number of elements with energies below E_TOO_LOW
E_TOO_LOW = -20
for model in ("small", "medium"):
lzma_path = f"{module_dir}/homo-nuclear-mace-{model}.json.xz"
for model_name in ("mace-small", "mace-medium"):
lzma_path = f"{module_dir}/homo-nuclear-{model_name}.json.xz"
with lzma.open(lzma_path, mode="rt") as file:
homo_nuc_diatomics = json.load(file)

Expand All @@ -65,4 +64,6 @@
for key, y_vals in homo_nuc_diatomics.items()
}
n_lt_10 = sum(val < E_TOO_LOW for val in min_energies.values())
print(f"diatomic curves for {model=} that dip below {E_TOO_LOW=} eV: {n_lt_10=}")
print(
f"diatomic curves for {model_name} that dip below {E_TOO_LOW=} eV: {n_lt_10=}"
)
6 changes: 2 additions & 4 deletions pymatviz/ptable/ptable_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,8 +948,7 @@ def create_section_coords(
font=dict(color="black", size=(font_size or 14) * scale),
)
fig.add_annotation(
text=f"<b>{display_symbol}</b>",
**symbol_defaults | xy_ref | (symbol_kwargs or {}),
text=display_symbol, **symbol_defaults | xy_ref | (symbol_kwargs or {})
)

# Add hover data
Expand Down Expand Up @@ -1380,8 +1379,7 @@ def ptable_scatter_plotly(
),
)
fig.add_annotation(
text=f"<b>{display_symbol}</b>",
**symbol_defaults | xy_ref | (symbol_kwargs or {}),
text=display_symbol, **symbol_defaults | xy_ref | (symbol_kwargs or {})
)

# Add custom annotations if provided
Expand Down
9 changes: 5 additions & 4 deletions tests/ptable/plotly/test_ptable_scatter_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import plotly.graph_objects as go
import pytest
from pymatgen.core import Element

import pymatviz as pmv

Expand Down Expand Up @@ -81,7 +82,7 @@ def test_basic_scatter_plot(sample_data: SampleData) -> None:

# Check element annotations
symbol_annotations = [
ann for ann in fig.layout.annotations if "<b>" in str(ann.text)
ann for ann in fig.layout.annotations if ann.text in {e.name for e in Element}
]
assert len(symbol_annotations) == len(sample_data)

Expand All @@ -92,8 +93,8 @@ def test_basic_scatter_plot(sample_data: SampleData) -> None:
assert ann.x == 1
assert ann.y == 1
assert ann.font.size == 11 # default font size
# Check that text is either <b>Fe</b> or <b>O</b>
assert ann.text in ("<b>Fe</b>", "<b>O</b>")
# Check that text is either Fe or O
assert ann.text in ("Fe", "O")


@pytest.mark.parametrize(
Expand Down Expand Up @@ -193,7 +194,7 @@ def test_scaling(
assert fig.layout.height == 500 * scale
# Check font scaling
symbol_annotations = [
ann for ann in fig.layout.annotations if "<b>" in str(ann.text)
ann for ann in fig.layout.annotations if ann.text in {e.name for e in Element}
]
assert all(ann.font.size == 11 * scale for ann in symbol_annotations)

Expand Down

0 comments on commit 4dc5f1e

Please sign in to comment.