diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
deleted file mode 100644
index 5399cb6f..00000000
--- a/.devcontainer/devcontainer.json
+++ /dev/null
@@ -1,18 +0,0 @@
-{
- "image": "mcr.microsoft.com/devcontainers/universal:2",
- "waitFor": "onCreateCommand",
- "updateContentCommand": "pip install -e .",
- "customizations": {
- "codespaces": {
- "openFiles": [
- "examples/matbench_dielectric_eda.ipynb",
- "examples/matbench_perovskites_eda.ipynb",
- "examples/mp_bimodal_e_form.ipynb",
- "examples/mprester_ptable.ipynb"
- ]
- },
- "vscode": {
- "extensions": ["ms-toolsai.jupyter", "ms-python.python"]
- }
- }
-}
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2b9690a7..f038d442 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -8,7 +8,7 @@ default_install_hook_types: [pre-commit, commit-msg]
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.8.4
+ rev: v0.8.6
hooks:
- id: ruff
args: [--fix]
@@ -17,7 +17,7 @@ repos:
types_or: [python, jupyter]
- repo: https://github.com/pre-commit/mirrors-mypy
- rev: v1.14.0
+ rev: v1.14.1
hooks:
- id: mypy
additional_dependencies: [types-requests, types-PyYAML]
@@ -85,7 +85,7 @@ repos:
- id: pyright
- repo: https://github.com/python-jsonschema/check-jsonschema
- rev: 0.30.0
+ rev: 0.31.0
hooks:
- id: check-jsonschema
files: ^pymatviz/keys\.yml$
diff --git a/examples/diatomics/mace_pair_repulsion.py b/examples/diatomics/calc_mlip_diatomic_curves.py
similarity index 65%
rename from examples/diatomics/mace_pair_repulsion.py
rename to examples/diatomics/calc_mlip_diatomic_curves.py
index 0511b5b2..566bb416 100644
--- a/examples/diatomics/mace_pair_repulsion.py
+++ b/examples/diatomics/calc_mlip_diatomic_curves.py
@@ -19,6 +19,7 @@
from ase import Atoms
from ase.data import chemical_symbols
from mace.calculators import MACECalculator, mace_mp
+from tqdm import tqdm
if TYPE_CHECKING:
@@ -77,63 +78,82 @@ def calc_one_pair(
]
-def generate_homo_nuclear(calculator: MACECalculator, label: str) -> None:
+def calc_homo_diatomics(
+ calculator: MACECalculator, model_name: str
+) -> dict[str, list[float]]:
"""Generate potential energy data for homonuclear diatomic molecules.
Args:
calculator: MACECalculator instance.
- label: Label for the output file.
+ model_name: Name of the model for the output file.
"""
distances = np.linspace(0.1, 6.0, 119)
allowed_atomic_numbers = calculator.z_table.zs
# saving the results in a dict: "z0-z1" -> [energy] & saved the distances
results = {"distances": list(distances)}
# homo-nuclear diatomics
- for z0 in allowed_atomic_numbers:
+ pbar = tqdm(allowed_atomic_numbers, desc=f"Homo-nuclear diatomics for {model_name}")
+ for z0 in pbar:
elem1, elem2 = chemical_symbols[z0], chemical_symbols[z0]
formula = f"{elem1}-{elem2}"
- with timer(formula):
- results[formula] = calc_one_pair(elem1, elem2, calculator, distances)
- with lzma.open(f"homo-nuclear-{label}.json.xz", mode="wt") as file:
+ pbar.set_postfix_str(formula)
+ results[formula] = calc_one_pair(elem1, elem2, calculator, distances)
+
+ with lzma.open(f"homo-nuclear-{model_name}.json.xz", mode="wt") as file:
json.dump(results, file)
+ return results
+
-def generate_hetero_nuclear(z0: int, calculator: MACECalculator, label: str) -> None:
+def calc_hetero_diatomics(
+ z0: int, calculator: MACECalculator, model_name: str
+) -> dict[str, list[float]]:
"""Generate potential energy data for hetero-nuclear diatomic molecules with a
- fixed first element.
+ fixed first element and save to .json.xz file.
Args:
z0: Atomic number of the fixed first element.
calculator: MACECalculator instance.
- label: Label for the output file.
+ model_name: Name of the model for the output file.
+
+ Returns:
+ dict[str, list[float]]: Potential energy data for hetero-nuclear
+ diatomic molecules or None if the file already exists.
"""
- out_path = f"hetero-nuclear-diatomics-{z0}-{label}.json.xz"
+ out_path = f"hetero-nuclear-diatomics-{z0}-{model_name}.json.xz"
if os.path.isfile(out_path):
print(f"Skipping {z0} because {out_path} already exists")
- return
+ with lzma.open(out_path, mode="rt") as file:
+ return json.load(file)
+
print(f"Starting {z0}")
distances = np.linspace(0.1, 6.0, 119)
allowed_atomic_numbers = calculator.z_table.zs
# saving the results in a dict: "z0-z1" -> [energy] & saved the distances
results = {"distances": list(distances)}
# hetero-nuclear diatomics
- for z1 in allowed_atomic_numbers:
+ pbar = tqdm(
+ allowed_atomic_numbers, desc=f"Hetero-nuclear diatomics for {model_name}"
+ )
+ for z1 in pbar:
elem1, elem2 = chemical_symbols[z0], chemical_symbols[z1]
formula = f"{elem1}-{elem2}"
- with timer(formula):
- results[formula] = calc_one_pair(elem1, elem2, calculator, distances)
+ pbar.set_postfix_str(formula)
+ results[formula] = calc_one_pair(elem1, elem2, calculator, distances)
+
with lzma.open(out_path, mode="wt") as file:
json.dump(results, file)
+ return results
+
if __name__ == "__main__":
- # first homo-nuclear diatomics
- # for label in ("small", "medium", "large"):
- # calculator = mace_mp(model=label)
- # generate_homo_nuclear(calculator, f"mace-{label}")
-
- # then all hetero-nuclear diatomics
- for label in ("small", "medium", "large"):
- calculator = mace_mp(model=label)
- for z0 in calculator.z_table.zs:
- generate_hetero_nuclear(z0, calculator, f"mace-{label}")
+ mace_mpa_0_medium_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model"
+ for checkpoint_path in (mace_mpa_0_medium_url,):
+ calculator = mace_mp(model=checkpoint_path)
+ model_name = os.path.basename(checkpoint_path).split(".")[0]
+ calc_homo_diatomics(calculator, model_name)
+
+ # calculate all hetero-nuclear diatomics (takes a long time)
+ # for z0 in calculator.z_table.zs:
+ # calc_hetero_diatomics(z0, calculator, model_name)
diff --git a/examples/diatomics/make_vasp_diatomics_inputs.py b/examples/diatomics/make_vasp_diatomics_inputs.py
index 59326854..161ee916 100644
--- a/examples/diatomics/make_vasp_diatomics_inputs.py
+++ b/examples/diatomics/make_vasp_diatomics_inputs.py
@@ -22,6 +22,8 @@ def create_diatomic_inputs(
base_dir: str = "diatomic-calcs",
) -> None:
"""Create VASP input files for all pairs of elements at different separations.
+ The calculations can be run using run_vasp_diatomics.py, which will automatically
+ handle running calculations in sequence and copying WAVECAR files between distances.
Args:
distances (tuple[float, ...]): If tuple and length is 3 and last item is int,
diff --git a/examples/diatomics/plot_homo_diatomic_curves.py b/examples/diatomics/plot_homo_diatomic_curves.py
index 431ceab9..181a9ad5 100644
--- a/examples/diatomics/plot_homo_diatomic_curves.py
+++ b/examples/diatomics/plot_homo_diatomic_curves.py
@@ -17,26 +17,25 @@
# %% plot homo-nuclear and heteronuclear pair repulsion curves
-model = "medium" # "small"
-lzma_path = f"{module_dir}/homo-nuclear-mace-{model}.json.xz"
+model_name = ("mace-small", "mace-medium", "mace-mpa-0-medium")[-1]
+lzma_path = f"{module_dir}/homo-nuclear-{model_name}.json.xz"
with lzma.open(lzma_path, mode="rt") as file:
homo_nuc_diatomics = json.load(file)
# Convert data to format needed for plotting
# Each element in diatomic_curves should be a tuple of (x_values, y_values)
-diatomic_curves: dict[str, tuple[list[float], list[float]]] = {}
+diatomic_curves: dict[str, dict[str, tuple[np.ndarray, np.ndarray]]] = {}
distances = homo_nuc_diatomics.pop("distances", locals().get("distances"))
for symbol in homo_nuc_diatomics:
energies = np.asarray(homo_nuc_diatomics[symbol])
# Get element symbol from the key (format is "Z-Z" where Z is atomic number)
- elem_z = int(symbol.split("-")[0])
- elem_symbol = Element.from_Z(elem_z).symbol
+ elem_symbol = Element(symbol.split("-")[0]).symbol
# Shift energies so the energy at infinite separation (last point) is 0
shifted_energies = energies - energies[-1]
- diatomic_curves[elem_symbol] = distances, shifted_energies
+ diatomic_curves[elem_symbol] = {model_name: (distances, shifted_energies)}
fig = pmv.ptable_scatter_plotly(
@@ -44,18 +43,18 @@
mode="lines",
x_axis_kwargs=dict(range=[0, 6]),
y_axis_kwargs=dict(range=[-8, 15]),
- scale=1.2,
+ scale=1.5,
)
-fig.layout.title.update(text=f"MACE {model.title()} Diatomic Curves", x=0.4, y=0.8)
+fig.layout.title.update(text=f"{model_name.title()} Diatomic Curves", x=0.4, y=0.8)
fig.show()
-pmv.io.save_and_compress_svg(fig, f"homo-nuclear-mace-{model}")
+# pmv.io.save_and_compress_svg(fig, f"homo-nuclear-{model}")
# %% count number of elements with energies below E_TOO_LOW
E_TOO_LOW = -20
-for model in ("small", "medium"):
- lzma_path = f"{module_dir}/homo-nuclear-mace-{model}.json.xz"
+for model_name in ("mace-small", "mace-medium"):
+ lzma_path = f"{module_dir}/homo-nuclear-{model_name}.json.xz"
with lzma.open(lzma_path, mode="rt") as file:
homo_nuc_diatomics = json.load(file)
@@ -65,4 +64,6 @@
for key, y_vals in homo_nuc_diatomics.items()
}
n_lt_10 = sum(val < E_TOO_LOW for val in min_energies.values())
- print(f"diatomic curves for {model=} that dip below {E_TOO_LOW=} eV: {n_lt_10=}")
+ print(
+ f"diatomic curves for {model_name} that dip below {E_TOO_LOW=} eV: {n_lt_10=}"
+ )
diff --git a/pymatviz/ptable/ptable_plotly.py b/pymatviz/ptable/ptable_plotly.py
index 95944acb..a7e1f138 100644
--- a/pymatviz/ptable/ptable_plotly.py
+++ b/pymatviz/ptable/ptable_plotly.py
@@ -948,8 +948,7 @@ def create_section_coords(
font=dict(color="black", size=(font_size or 14) * scale),
)
fig.add_annotation(
- text=f"{display_symbol}",
- **symbol_defaults | xy_ref | (symbol_kwargs or {}),
+ text=display_symbol, **symbol_defaults | xy_ref | (symbol_kwargs or {})
)
# Add hover data
@@ -1380,8 +1379,7 @@ def ptable_scatter_plotly(
),
)
fig.add_annotation(
- text=f"{display_symbol}",
- **symbol_defaults | xy_ref | (symbol_kwargs or {}),
+ text=display_symbol, **symbol_defaults | xy_ref | (symbol_kwargs or {})
)
# Add custom annotations if provided
diff --git a/tests/ptable/plotly/test_ptable_scatter_plotly.py b/tests/ptable/plotly/test_ptable_scatter_plotly.py
index b826ab88..b759eedc 100644
--- a/tests/ptable/plotly/test_ptable_scatter_plotly.py
+++ b/tests/ptable/plotly/test_ptable_scatter_plotly.py
@@ -7,6 +7,7 @@
import numpy as np
import plotly.graph_objects as go
import pytest
+from pymatgen.core import Element
import pymatviz as pmv
@@ -81,7 +82,7 @@ def test_basic_scatter_plot(sample_data: SampleData) -> None:
# Check element annotations
symbol_annotations = [
- ann for ann in fig.layout.annotations if "" in str(ann.text)
+ ann for ann in fig.layout.annotations if ann.text in {e.name for e in Element}
]
assert len(symbol_annotations) == len(sample_data)
@@ -92,8 +93,8 @@ def test_basic_scatter_plot(sample_data: SampleData) -> None:
assert ann.x == 1
assert ann.y == 1
assert ann.font.size == 11 # default font size
- # Check that text is either Fe or O
- assert ann.text in ("Fe", "O")
+ # Check that text is either Fe or O
+ assert ann.text in ("Fe", "O")
@pytest.mark.parametrize(
@@ -193,7 +194,7 @@ def test_scaling(
assert fig.layout.height == 500 * scale
# Check font scaling
symbol_annotations = [
- ann for ann in fig.layout.annotations if "" in str(ann.text)
+ ann for ann in fig.layout.annotations if ann.text in {e.name for e in Element}
]
assert all(ann.font.size == 11 * scale for ann in symbol_annotations)