"
for t in reg_targets
)
@@ -747,7 +1135,7 @@ def slide(body: str) -> str:
"
Descriptor: invertible KMD-1d, computed on the fly (descriptor → composition via KMD.inverse).
"
"
Continual finetuning: tasks added one at a time; AE head always on.
"
f"
Rehearsal: learned tasks keep only {self.config.replay_ratio:.0%} of their training targets per step.
"
- "
Inverse design: optimize the latent toward regression targets + quasicrystal probability, then decode a composition.
"
+ "
Inverse design: from the highest-QC training compositions, optimize the latent to raise quasicrystal probability (primary) with low formation energy & high κ_lat (secondary), then decode a composition.
mean P(QC) over seeds: {qc_before:.3f} → {qc_after:.3f} (round-trip)
"
+ + "
Secondary — regression targets (in latent)
"
+ inv_lines
- + f"
Quasicrystal probability (round-trip): {qc_before:.3f} → {qc_after:.3f}
"
- + "
Decoded compositions (KMD.inverse)
"
+ + "
Decoded compositions (KMD.inverse)
"
+ decoded
+ "
"
),
slide(
"
Takeaways
"
"
One shared encoder serves regression, kernel regression, classification & reconstruction across 4 inorganic datasets.
"
- "
5% rehearsal keeps well-learned tasks (density, formation energy, material type) near their peak while new heads are added.
"
+ "
5% rehearsal keeps well-learned tasks (Density, Formation Energy) near their peak while new heads are added.
"
"
Latent-space optimization with regression + classification conditions hits the targets and decodes back to real compositions via the invertible KMD descriptor.
"
"
"
),
@@ -841,13 +1230,20 @@ def _load_toml(path: Path) -> dict[str, Any]:
return tomllib.loads(Path(path).read_text(encoding="utf-8"))
-def _parse_args(argv: list[str] | None = None) -> ContinualRehearsalConfig:
+def _parse_args(argv: list[str] | None = None) -> tuple[ContinualRehearsalConfig, argparse.Namespace]:
parser = argparse.ArgumentParser(description="Continual rehearsal + inverse-design demo.")
parser.add_argument("--config-file", type=Path, default=None)
parser.add_argument("--output-dir", type=Path, default=None)
parser.add_argument("--sample-per-dataset", type=int, default=None)
parser.add_argument("--max-epochs-per-step", type=int, default=None)
parser.add_argument("--accelerator", type=str, default=None)
+ parser.add_argument(
+ "--inverse-only",
+ type=Path,
+ default=None,
+ metavar="CKPT",
+ help="Skip training; load a final_model.pt checkpoint and run only the inverse-design stage.",
+ )
args = parser.parse_args(argv)
data = _load_toml(args.config_file) if args.config_file else {}
@@ -871,11 +1267,16 @@ def _parse_args(argv: list[str] | None = None) -> ContinualRehearsalConfig:
logger.warning(f"Ignoring unknown config key '{key}'.")
continue
kwargs[key] = Path(value) if key in path_fields and value is not None else value
- return ContinualRehearsalConfig(**kwargs)
+ return ContinualRehearsalConfig(**kwargs), args
def main(argv: list[str] | None = None) -> None:
- ContinualRehearsalRunner(_parse_args(argv)).run()
+ config, args = _parse_args(argv)
+ runner = ContinualRehearsalRunner(config)
+ if args.inverse_only is not None:
+ runner.run_inverse_only(args.inverse_only)
+ else:
+ runner.run()
if __name__ == "__main__":
diff --git a/src/foundation_model/scripts/continual_rehearsal_demo_test.py b/src/foundation_model/scripts/continual_rehearsal_demo_test.py
new file mode 100644
index 0000000..4d33591
--- /dev/null
+++ b/src/foundation_model/scripts/continual_rehearsal_demo_test.py
@@ -0,0 +1,242 @@
+# Copyright 2025 TsumiNa.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Tests for the configuration / pure helpers in :mod:`continual_rehearsal_demo`.
+
+The runner's training loop is exercised end-to-end by smoke runs (it needs real parquet
+data + a GPU/MPS device), so this file targets the *units that don't need either*:
+
+* ``ContinualRehearsalConfig`` validation in ``__post_init__``.
+* The element-system seed dedup / explicit-append logic.
+* The ``_plot_kr_sequences`` regression (the function used to raise ``NameError`` when
+ ``comps`` was empty — see the PR #18 code review).
+* The material-type 5→3 class merge map shape.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+import numpy as np
+import pytest
+
+from foundation_model.scripts.continual_rehearsal_common import plot_kr_sequences
+from foundation_model.scripts.continual_rehearsal_demo import (
+ DEFAULT_SEQUENCE,
+ MATERIAL_TYPE_CLASSES,
+ MATERIAL_TYPE_DISPLAY_ORDER,
+ QC_CLASSES,
+ TASK_SPECS,
+ ContinualRehearsalConfig,
+ ContinualRehearsalRunner,
+ _MATERIAL_TYPE_MERGE,
+)
+
+
+# --- ContinualRehearsalConfig ---------------------------------------------------------------
+
+
+def _base_kwargs(**overrides):
+ """Minimal valid config kwargs — only the fields without sane defaults need to be filled in.
+
+ Paths are dummies; the validators inside ``__post_init__`` don't touch the filesystem.
+ """
+ defaults = {
+ "qc_data_path": Path("/tmp/qc.parquet"),
+ "qc_preprocessing_path": None,
+ "superconductor_path": Path("/tmp/sc.parquet"),
+ "magnetic_path": Path("/tmp/mag.parquet"),
+ "phonix_path": Path("/tmp/ph.parquet"),
+ "output_dir": Path("/tmp/out"),
+ "task_sequence": list(DEFAULT_SEQUENCE),
+ }
+ defaults.update(overrides)
+ return defaults
+
+
+def test_config_default_post_init_accepts_default_sequence():
+ cfg = ContinualRehearsalConfig(**_base_kwargs())
+ assert cfg.task_sequence == list(DEFAULT_SEQUENCE)
+ # Every task in the default sequence is registered in TASK_SPECS — would otherwise raise.
+ assert set(cfg.task_sequence) <= set(TASK_SPECS)
+
+
+def test_config_rejects_unknown_task():
+ with pytest.raises(ValueError, match="Unknown task"):
+ ContinualRehearsalConfig(**_base_kwargs(task_sequence=["density", "this_task_does_not_exist"]))
+
+
+def test_config_rejects_bad_replay_ratio():
+ with pytest.raises(ValueError, match="replay_ratio must be in"):
+ ContinualRehearsalConfig(**_base_kwargs(replay_ratio=-0.1))
+ with pytest.raises(ValueError, match="replay_ratio must be in"):
+ ContinualRehearsalConfig(**_base_kwargs(replay_ratio=1.5))
+
+
+def test_config_rejects_reg_task_target_length_mismatch():
+ with pytest.raises(ValueError, match="inverse_reg_tasks and inverse_reg_targets"):
+ ContinualRehearsalConfig(
+ **_base_kwargs(inverse_reg_tasks=["formation_energy", "klat"], inverse_reg_targets=[-2.0])
+ )
+
+
+def test_config_rejects_unknown_seed_strategy():
+ with pytest.raises(ValueError, match="inverse_seed_strategy must be"):
+ ContinualRehearsalConfig(**_base_kwargs(inverse_seed_strategy="oracle"))
+
+
+def test_config_explicit_strategy_requires_compositions():
+ with pytest.raises(ValueError, match="requires inverse_seed_compositions"):
+ ContinualRehearsalConfig(**_base_kwargs(inverse_seed_strategy="explicit", inverse_seed_compositions=[]))
+
+
+def test_config_rejects_nonpositive_n_seeds():
+ """``inverse_n_seeds <= 0`` would silently return only the explicit-append entries; fail
+ loudly at config-load time so the misuse points at the TOML, not at a downstream shape error."""
+ with pytest.raises(ValueError, match="inverse_n_seeds must be > 0"):
+ ContinualRehearsalConfig(**_base_kwargs(inverse_n_seeds=0))
+ with pytest.raises(ValueError, match="inverse_n_seeds must be > 0"):
+ ContinualRehearsalConfig(**_base_kwargs(inverse_n_seeds=-3))
+
+
+def test_config_rejects_ae_align_scale_out_of_range():
+ """``ae_align_scale ∉ [0, 1]`` is rejected by the model at runtime; we catch it earlier so
+ the error message points at the TOML field rather than a deep model backtrace."""
+ with pytest.raises(ValueError, match="inverse_ae_align_scale must be in"):
+ ContinualRehearsalConfig(**_base_kwargs(inverse_ae_align_scale=-0.1))
+ with pytest.raises(ValueError, match="inverse_ae_align_scale must be in"):
+ ContinualRehearsalConfig(**_base_kwargs(inverse_ae_align_scale=1.5))
+
+
+# --- material-type 5→3 merge map ------------------------------------------------------------
+
+
+def test_material_type_merge_covers_all_5_classes_and_3_targets():
+ # Source labels are 0..4 (5 classes); merged labels are 0..2 (3 classes: AC / QC / others).
+ assert set(_MATERIAL_TYPE_MERGE.keys()) == {0, 1, 2, 3, 4}
+ assert set(_MATERIAL_TYPE_MERGE.values()) == {0, 1, 2}
+ # QC label index must agree with QC_CLASSES.
+ assert QC_CLASSES == [_MATERIAL_TYPE_MERGE[1]] == [_MATERIAL_TYPE_MERGE[3]]
+
+
+def test_material_type_class_names_and_display_order_consistent():
+ # 3 merged classes, both lists carry exactly those names.
+ assert len(MATERIAL_TYPE_CLASSES) == 3
+ assert sorted(MATERIAL_TYPE_CLASSES) == sorted(MATERIAL_TYPE_DISPLAY_ORDER)
+
+
+# --- element-system dedup (classmethod, no runner state needed) ------------------------------
+
+
+def test_dedupe_by_element_system_keeps_first_per_set():
+ # First occurrence per element-set wins. Mg-Al-Cu appears twice; only the first survives.
+ candidates = [
+ "Mg12 Cu3 Ni3", # {Mg, Cu, Ni}
+ "Mg2 Cu1 Ni1", # {Mg, Cu, Ni} ← duplicate set, dropped
+ "Y8.7 Mg34.6 Zn56.8", # {Y, Mg, Zn}
+ "Y1 Mg1 Zn1", # {Y, Mg, Zn} ← duplicate set, dropped
+ "Au65 Ga20 Gd15", # {Au, Ga, Gd}
+ ]
+ out = ContinualRehearsalRunner._dedupe_by_element_system(candidates, n=10)
+ assert out == ["Mg12 Cu3 Ni3", "Y8.7 Mg34.6 Zn56.8", "Au65 Ga20 Gd15"]
+
+
+def test_dedupe_by_element_system_respects_n_cap():
+ candidates = [
+ "Mg1", # {Mg}
+ "Al1", # {Al}
+ "Cu1", # {Cu}
+ "Ni1", # {Ni}
+ ]
+ out = ContinualRehearsalRunner._dedupe_by_element_system(candidates, n=2)
+ assert out == ["Mg1", "Al1"]
+
+
+def test_dedupe_by_element_system_ignores_empty_strings():
+ out = ContinualRehearsalRunner._dedupe_by_element_system(["", "Mg1", " ", "Al1"], n=5)
+ assert out == ["Mg1", "Al1"]
+
+
+def test_merge_strategy_and_explicit_drops_strategy_seeds_sharing_element_system():
+ """When an explicit-append seed (Au-Ga-Gd) shares an element-system with a strategy seed,
+ the *strategy* seed is dropped — the explicit-append wins because it's the user's deliberate
+ pick. Mirrors ``_select_seeds._finalise``'s contract end-to-end."""
+ strategy = [
+ "Mg12 Cu3 Ni3", # {Mg, Cu, Ni} — kept
+ "Au70 Ga20 Gd10", # {Au, Ga, Gd} — *dropped*, overlaps the explicit append
+ "Y8 Mg34 Zn58", # {Y, Mg, Zn} — kept
+ "Al6 Co1 Cu3", # {Al, Co, Cu} — kept
+ ]
+ appended = ["Au65 Ga20 Gd15"] # {Au, Ga, Gd}
+ out = ContinualRehearsalRunner._merge_strategy_and_explicit(strategy, appended, n_strategy=3)
+ assert out == ["Mg12 Cu3 Ni3", "Y8 Mg34 Zn58", "Al6 Co1 Cu3", "Au65 Ga20 Gd15"]
+
+
+def test_merge_strategy_and_explicit_caps_strategy_after_dedup():
+ """``n_strategy`` is the post-dedup cap on the strategy portion. Total output length is
+ ``n_strategy + len(appended)`` — the appended entries are always preserved."""
+ strategy = ["Mg1 Cu1", "Al1 Fe1", "Zn1 Cd1"]
+ appended = ["Au1 Ga1"]
+ out = ContinualRehearsalRunner._merge_strategy_and_explicit(strategy, appended, n_strategy=2)
+ assert out == ["Mg1 Cu1", "Al1 Fe1", "Au1 Ga1"]
+
+
+def test_merge_strategy_and_explicit_handles_empty_appended():
+ """No explicit-append entries ⇒ just truncates the (already-deduped) strategy list."""
+ out = ContinualRehearsalRunner._merge_strategy_and_explicit(
+ ["Mg1 Cu1", "Al1 Fe1", "Zn1 Cd1"], [], n_strategy=2
+ )
+ assert out == ["Mg1 Cu1", "Al1 Fe1"]
+
+
+def test_element_system_extracts_symbols_ignoring_amounts():
+ # Static-method shape: returns a frozenset of element symbols, no stoichiometry leaks through.
+ es = ContinualRehearsalRunner._element_system("Au65 Ga20 Gd15")
+ assert es == frozenset({"Au", "Ga", "Gd"})
+ # Multi-digit / float amounts handled the same way.
+ es = ContinualRehearsalRunner._element_system("Mg36.3 Al32 Zn31.7")
+ assert es == frozenset({"Mg", "Al", "Zn"})
+
+
+# --- plot_kr_sequences empty-comps regression (P1 bug from PR #18 code review) -------------
+# The function is now in ``continual_rehearsal_common`` (PR #18 refactor); pre-refactor it lived
+# as a bound method on each runner and the empty-comps NameError silently shipped on the demo
+# side for several PRs. These tests pin the post-refactor behaviour from both call sites.
+
+
+def test_plot_kr_sequences_handles_empty_comps_without_crashing(tmp_path):
+ """Empty ``comps`` used to raise ``NameError: line_true`` from ``fig.legend(...)``. Now it
+ logs a warning and returns early; no file is written."""
+ out_dir = tmp_path / "step01_density"
+ out_dir.mkdir()
+ plot_kr_sequences(
+ comps=[],
+ t_list=[],
+ true_parts=[],
+ pred=np.array([]),
+ task_name="dos_density",
+ step_dir=out_dir,
+ title="DOS density",
+ )
+ assert not (out_dir / "dos_density_sequences.png").exists()
+
+
+def test_plot_kr_sequences_renders_when_comps_nonempty(tmp_path):
+ """Smoke: one composition's sequence renders a PNG with no errors."""
+ import torch
+
+ out_dir = tmp_path / "step01_density"
+ out_dir.mkdir()
+ t = torch.linspace(0.0, 1.0, 8)
+ true_part = np.linspace(0.0, 1.0, 8)
+ pred = np.linspace(0.05, 0.95, 8)
+ plot_kr_sequences(
+ comps=["Mg1 Cu1"],
+ t_list=[t],
+ true_parts=[true_part],
+ pred=pred,
+ task_name="dos_density",
+ step_dir=out_dir,
+ title="DOS density",
+ )
+ assert (out_dir / "dos_density_sequences.png").exists()
diff --git a/src/foundation_model/scripts/continual_rehearsal_full.py b/src/foundation_model/scripts/continual_rehearsal_full.py
new file mode 100644
index 0000000..648b556
--- /dev/null
+++ b/src/foundation_model/scripts/continual_rehearsal_full.py
@@ -0,0 +1,2855 @@
+# Copyright 2025 TsumiNa.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Continual multi-task rehearsal + inverse-design — **full / formal** run.
+
+A larger, "formal training" sibling of :mod:`continual_rehearsal_demo`. It covers the complete
+inorganic task catalogue (24 supervised tasks + always-on autoencoder) over four datasets and,
+relative to the demo, adds:
+
+* **Tiered rehearsal** — a configurable high-replay set (the inverse-design-relevant tail tasks,
+ e.g. formation_energy / magnetic_moment / tc / klat / material_type) keeps ``replay_ratio_high``
+ of its labels when replayed as an *old* task, while every other learned task keeps ``replay_ratio``.
+* **EarlyStopping** on ``val_final_loss`` (full data ⇒ ``max_epochs_per_step`` is just a ceiling).
+* **Per-stage raw artifacts** — at every step, every active head's test ``(composition, true, pred)``
+ is dumped to parquet (kernel heads additionally store the ``t`` series), alongside a per-task
+ ``_metrics.json`` and a per-step ``checkpoint.pt`` (model state + active-task metadata).
+ Everything lives under ``training/stepNN_/`` so any intermediate stage can be revisited.
+* **Final checkpoint** — ``training/final_model.pt`` + ``training/final_model_taskconfigs.json``.
+* **Multiple inverse-design scenarios** — the same final model is optimized through **eight
+ PR #18 paths per scenario** (3 latent ``ae_align_scale`` sweep points + 5 composition configs:
+ strict seed / blended seed / alloy palette / alloy + low diversity / random init), with
+ results, an 8-path comparison plot, an element-frequency heatmap (discovered elements
+ highlighted in bold orange), and `targets.json` written to ``inverse_design//``.
+* **Slide-prep deliverables (no auto PPT / HTML)** — the runner emits ``SLIDE_PREP.md`` (9-section
+ outline + raw-data pointers), ``ANALYSIS.md`` (long-form English narrative), ``README.md``
+ (directory index), and per-scenario ``comparison.png`` / ``element_frequency_heatmap.png``
+ inside ``inverse_design//``. The three scenarios are first-class results — the runner
+ does **not** promote any single scenario as the headline (that was a demo-only convention).
+ The slide author builds the deck externally; every figure is reproducible from the raw arrays
+ without retraining.
+
+No layers are frozen: every step jointly trains the shared encoder + all active task heads
+(``freeze_shared_encoder=False``, per-task ``freeze_parameters=False``). The "continual" behaviour
+comes purely from the rehearsal mask, not from freezing.
+
+Run:
+ ./run_continual_rehearsal_full.sh samples/continual_rehearsal_full_config.toml
+ python -m foundation_model.scripts.continual_rehearsal_full --config-file
+"""
+
+from __future__ import annotations
+
+import argparse
+import datetime as _datetime
+import json
+import re
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any
+
+import matplotlib
+
+matplotlib.use("Agg") # headless
+
+import joblib # type: ignore[import-untyped]
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import torch
+from lightning import Trainer, seed_everything
+from lightning.pytorch.callbacks import Callback, EarlyStopping
+from loguru import logger
+from sklearn.metrics import accuracy_score, f1_score, mean_absolute_error, r2_score # type: ignore[import-untyped]
+from torch.utils.data import DataLoader
+
+from foundation_model.data.composition_sources import normalize_composition
+from foundation_model.data.datamodule import CompoundDataModule
+from foundation_model.models.flexible_multi_task_model import FlexibleMultiTaskModel
+from foundation_model.models.model_config import (
+ ClassificationTaskConfig,
+ KernelRegressionTaskConfig,
+ MLPEncoderConfig,
+ OptimizerConfig,
+ RegressionTaskConfig,
+)
+from foundation_model.utils.kmd_plus import DEFAULT_ELEMENTS, KMD, element_features, formula_to_composition
+
+# Shared dump/plot helpers live in the common module. The material_type constants and the
+# scatter colour are consumed *inside* the common functions, so this file no longer needs to
+# import them directly — they used to be imported here because the bound-method plot helpers
+# inlined them.
+from foundation_model.scripts.continual_rehearsal_common import (
+ dump_kr_predictions,
+ dump_metrics,
+ dump_predictions,
+ plot_confusion,
+ plot_element_frequency_heatmap,
+ plot_kr_sequences,
+ plot_parity,
+)
+from foundation_model.scripts.continual_rehearsal_demo import (
+ _PALETTE,
+ _apply_plot_style,
+ _as_float_array,
+ _composition_key,
+ _init_kernels,
+)
+from foundation_model.scripts.paper_inverse_comparison import (
+ _emit_trajectory_outputs,
+ _path_slug,
+ _plot_qc_vs_reg_scatter,
+ _plot_seed_to_optimized_mapping,
+)
+
+# --- Task catalogue ----------------------------------------------------------
+# source: dataset the task's targets come from. qc columns are pre-normalized; raw NEMAD/phonix
+# regression columns are log1p + z-scored (train-only stats) + clipped at load time.
+TASK_SPECS: dict[str, dict[str, Any]] = {
+ # --- qc: regression (9) ---
+ "density": {"source": "qc", "kind": "reg", "column": "Density (normalized)"},
+ "efermi": {"source": "qc", "kind": "reg", "column": "Efermi (normalized)"},
+ "final_energy": {"source": "qc", "kind": "reg", "column": "Final energy per atom (normalized)"},
+ "formation_energy": {"source": "qc", "kind": "reg", "column": "Formation energy per atom (normalized)"},
+ "total_magnetization": {"source": "qc", "kind": "reg", "column": "Total magnetization (normalized)"},
+ "volume": {"source": "qc", "kind": "reg", "column": "Volume (normalized)"},
+ "dielectric_total": {"source": "qc", "kind": "reg", "column": "Dielectric total (normalized)"},
+ "dielectric_ionic": {"source": "qc", "kind": "reg", "column": "Dielectric ionic (normalized)"},
+ "dielectric_electronic": {"source": "qc", "kind": "reg", "column": "Dielectric electronic (normalized)"},
+ # --- qc: kernel regression (7) ---
+ "dos_density": {"source": "qc", "kind": "kr", "column": "DOS density (normalized)", "t_column": "DOS energy"},
+ "electrical_resistivity": {
+ "source": "qc",
+ "kind": "kr",
+ "column": "Electrical resistivity (normalized)",
+ "t_column": "Electrical resistivity (T/K)",
+ },
+ "power_factor": {
+ "source": "qc",
+ "kind": "kr",
+ "column": "Power factor (normalized)",
+ "t_column": "Power factor (T/K)",
+ },
+ "seebeck": {
+ "source": "qc",
+ "kind": "kr",
+ "column": "Seebeck coefficient (normalized)",
+ "t_column": "Seebeck coefficient (T/K)",
+ },
+ "thermal_conductivity": {
+ "source": "qc",
+ "kind": "kr",
+ "column": "Thermal conductivity (normalized)",
+ "t_column": "Thermal conductivity (T/K)",
+ },
+ "zt": {"source": "qc", "kind": "kr", "column": "ZT (normalized)", "t_column": "ZT (T/K)"},
+ "magnetic_susceptibility": {
+ "source": "qc",
+ "kind": "kr",
+ "column": "Magnetic susceptibility (normalized)",
+ "t_column": "Magnetic susceptibility (T/K)",
+ },
+ # --- qc: classification (1) ---
+ "material_type": {"source": "qc", "kind": "clf", "column": "Material type (label)", "num_classes": 3},
+ # --- phonix-db: regression (2) ---
+ "kp": {"source": "phonix", "kind": "reg", "column": "kp[W/mK]"},
+ "klat": {"source": "phonix", "kind": "reg", "column": "klat[W/mK]"},
+ # --- NEMAD superconductor: regression (1) ---
+ "tc": {"source": "superconductor", "kind": "reg", "column": "Transition temperature[K]"},
+ # --- NEMAD magnetic: regression (4) ---
+ "magnetic_moment": {"source": "magnetic", "kind": "reg", "column": "Magnetic moment[μB/f.u.]"},
+ "magnetization": {"source": "magnetic", "kind": "reg", "column": "Magnetization[A·m²/mol]"},
+ "curie": {"source": "magnetic", "kind": "reg", "column": "Curie temperature[K]"},
+ "neel": {"source": "magnetic", "kind": "reg", "column": "Neel temperature[K]"},
+}
+
+# Raw (non-qc) regression targets span orders of magnitude; log1p-compress, z-score, clip tails.
+_RAW_TARGET_CLIP = 5.0
+
+# Default 24-task sequence: 19 free-order tasks, then the fixed inverse-design tail (kept freshest).
+DEFAULT_SEQUENCE = [
+ # qc regression (free)
+ "density",
+ "efermi",
+ "final_energy",
+ "total_magnetization",
+ "volume",
+ "dielectric_total",
+ "dielectric_ionic",
+ "dielectric_electronic",
+ # qc kernel regression (free)
+ "dos_density",
+ "electrical_resistivity",
+ "power_factor",
+ "seebeck",
+ "thermal_conductivity",
+ "zt",
+ "magnetic_susceptibility",
+ # magnetic + phonix (free)
+ "magnetization",
+ "curie",
+ "neel",
+ "kp",
+ # fixed tail (inverse-design heads, freshest at the end)
+ "formation_energy",
+ "magnetic_moment",
+ "tc",
+ "klat",
+ "material_type",
+]
+# The inverse-design-relevant tail: kept at the higher replay ratio when replayed as an old task.
+DEFAULT_FIXED_TAIL = ["formation_energy", "magnetic_moment", "tc", "klat", "material_type"]
+
+# 5 fine labels merged into AC / QC / others (index == merged class id).
+# ``MATERIAL_TYPE_CLASSES`` / ``MATERIAL_TYPE_DISPLAY_ORDER`` now live in
+# :mod:`continual_rehearsal_common` and are imported above; the runner-specific merge map stays.
+_MATERIAL_TYPE_MERGE = {0: 0, 2: 0, 1: 1, 3: 1, 4: 2}
+QC_CLASSES = [1] # merged quasicrystal class index — inverse-design classification objective.
+
+# --- Presentation -------------------------------------------------------------
+TASK_DISPLAY: dict[str, str] = {
+ "density": "Density",
+ "efermi": "E_Fermi",
+ "final_energy": "Final Energy / atom",
+ "formation_energy": "Formation Energy",
+ "total_magnetization": "Total Magnetization",
+ "volume": "Volume",
+ "dielectric_total": "Dielectric (total)",
+ "dielectric_ionic": "Dielectric (ionic)",
+ "dielectric_electronic": "Dielectric (electronic)",
+ "dos_density": "DOS Density",
+ "electrical_resistivity": "Electrical Resistivity",
+ "power_factor": "Power Factor",
+ "seebeck": "Seebeck Coefficient",
+ "thermal_conductivity": "Thermal Conductivity",
+ "zt": "ZT",
+ "magnetic_susceptibility": "Magnetic Susceptibility",
+ "material_type": "Material Type",
+ "kp": "Phonon Conductivity (κₚ)",
+ "klat": "Lattice Conductivity (κ_lat)",
+ "tc": "Critical Temperature (Tc)",
+ "magnetic_moment": "Magnetic Moment",
+ "magnetization": "Magnetization",
+ "curie": "Curie Temperature",
+ "neel": "Néel Temperature",
+}
+SOURCE_DISPLAY = {
+ "qc": "qc_ac_te_mp",
+ "phonix": "phonix-db",
+ "superconductor": "NEMAD superconductor",
+ "magnetic": "NEMAD magnetic",
+}
+KIND_LABEL = {"reg": "regression", "kr": "kernel regression", "clf": "classification"}
+
+# --- Inverse design — paths + element constraints ----------------------------
+# 48-element alloy palette for the composition-space ``C-alloy`` path (plan §5, extended). Covers
+# classic i-QC / d-QC formers (Mg–Zn–RE, Al–Mn, Al–Cu–Fe, Al–Ni–Co, Au–Ga–RE …), the Sc–Zn
+# 4th-period TMs, the Y–Cd 5th-period TMs (Tc excluded for radioactivity), the full Hf–Pt 5d TM
+# row (added 2026-05 — broadens the heavy-TM coverage for the composition search and lets the
+# optimiser reach refractory / noble-metal i-QC families like Hf–Pd / Ta–Ni / Ir-based phases),
+# Au (Au–Ga–Ln seeds need it), group 13/14 enablers (B/Al/Ga/In/Tl, Si/Ge), and the 12 easy
+# lanthanides. Pm/Tc are radioactive; Tm/Lu are scarce. The three explicit-append Au–Ga–Ln seeds
+# (Gd/Tb/Dy) all fit in this palette.
+ALLOY_PALETTE: list[str] = [
+ "Mg",
+ "Ca",
+ "B",
+ "Al",
+ "Ga",
+ "In",
+ "Tl",
+ "Si",
+ "Ge",
+ "Sc",
+ "Ti",
+ "V",
+ "Cr",
+ "Mn",
+ "Fe",
+ "Co",
+ "Ni",
+ "Cu",
+ "Zn",
+ "Y",
+ "Zr",
+ "Nb",
+ "Mo",
+ "Ru",
+ "Rh",
+ "Pd",
+ "Ag",
+ "Cd",
+ # 5d transition metals (Hf–Pt). Added 2026-05 to extend the previous 41-element palette;
+ # placed between Cd (end of 5th-period TMs) and Au so the 6th-period TM block is contiguous.
+ "Hf",
+ "Ta",
+ "W",
+ "Re",
+ "Os",
+ "Ir",
+ "Pt",
+ "Au",
+ "La",
+ "Ce",
+ "Pr",
+ "Nd",
+ "Sm",
+ "Eu",
+ "Gd",
+ "Tb",
+ "Dy",
+ "Ho",
+ "Er",
+ "Yb",
+]
+
+# Inverse-design comparison configurations, one row per box in ``comparison.png``. Mirrors the
+# PR #18 demo's ``paper_inverse_comparison.py``: a 3-point ``ae_align_scale`` sweep on the latent
+# side (failure α=0 / mid α=0.25 / max α=1.0) plus five composition configurations that layer
+# blend, palette and diversity-scale knobs against a random-init control. The ``allowed`` field
+# uses the sentinel ``"__palette__"`` to refer to ``config.inverse_composition_allowed_elements``
+# (the 48-element ``ALLOY_PALETTE`` by default); every other field is fixed at the module level so
+# the comparison is a stable plan-§5 ablation across runs.
+_PALETTE_SENTINEL = "__palette__"
+INVERSE_PATH_CONFIGS: list[dict[str, Any]] = [
+ {"key": "latent_align0p0", "label": "latent α=0", "method": "latent", "ae_align_scale": 0.0},
+ {"key": "latent_align0p25", "label": "latent α=0.25", "method": "latent", "ae_align_scale": 0.25},
+ {"key": "latent_align1p0", "label": "latent α=1", "method": "latent", "ae_align_scale": 1.0},
+ {
+ "key": "comp_seed",
+ "label": "comp (seed)",
+ "method": "composition",
+ "init": "seed",
+ "blend": 1.0,
+ "allowed": "all",
+ "diversity": 1.0,
+ },
+ {
+ "key": "comp_seed_blend",
+ "label": "comp (seed, 5% all)",
+ "method": "composition",
+ "init": "seed",
+ "blend": 0.95,
+ "allowed": "all",
+ "diversity": 1.0,
+ },
+ {
+ "key": "comp_seed_blend_palette",
+ "label": "comp (seed, 5% all, element list)",
+ "method": "composition",
+ "init": "seed",
+ "blend": 0.95,
+ "allowed": _PALETTE_SENTINEL,
+ "diversity": 1.0,
+ },
+ {
+ # Ablation: clamp diversity to 0 → max entropy penalty → forced peaky few-element recipes.
+ "key": "comp_seed_blend_palette_lowdiv",
+ "label": "comp (seed, 5% all, element list, low diversity)",
+ "method": "composition",
+ "init": "seed",
+ "blend": 0.95,
+ "allowed": _PALETTE_SENTINEL,
+ "diversity": 0.0,
+ },
+ {
+ "key": "comp_random",
+ "label": "comp (random)",
+ "method": "composition",
+ "init": "random",
+ "blend": 0.95,
+ "allowed": "all",
+ "diversity": 1.0,
+ },
+]
+INVERSE_PATHS: list[str] = [c["key"] for c in INVERSE_PATH_CONFIGS]
+INVERSE_PATH_CONFIGS_BY_KEY: dict[str, dict[str, Any]] = {c["key"]: c for c in INVERSE_PATH_CONFIGS}
+
+# Per-regression-task panel title (units + arrow). Matches the demo's REG_TASK_TITLES so plots
+# read the same across both runners. Falls back to the bare task name if a task isn't listed.
+REG_TASK_TITLES: dict[str, str] = {
+ "formation_energy": "Formation energy [eV/atom] ↓",
+ "klat": "klat [W/mK] ↑",
+ "magnetic_moment": "Magnetic moment [μB/f.u.] ↑",
+ "tc": "Critical temperature [K] ↑",
+}
+
+
+def _seed_weights_from_compositions(seeds: list[str], n_components: int) -> torch.Tensor:
+ """Element-weight tensor ``(B, n_components)`` for seeding ``optimize_composition``.
+
+ Order matches DEFAULT_ELEMENTS. Raises if any seed cannot be parsed — we fail fast rather than
+ silently dropping rows (callers rely on per-seed correspondence with the latent path).
+ """
+ rows = []
+ for c in seeds:
+ w = formula_to_composition(c)
+ if w is None:
+ raise ValueError(f"Cannot parse seed composition '{c}' to element weights.")
+ rows.append(np.asarray(w, dtype=np.float64))
+ return torch.tensor(np.stack(rows), dtype=torch.float64)
+
+
+def _format_weights(weights: np.ndarray, top_k: int = 6, eps: float = 1e-3) -> list[str]:
+ """Render element-weight rows as compact formula strings (top-K elements above ``eps``)."""
+ out: list[str] = []
+ for row in weights:
+ order = np.argsort(row)[::-1]
+ parts = [f"{DEFAULT_ELEMENTS[i]}{row[i]:.3f}" for i in order[:top_k] if row[i] > eps]
+ out.append(" ".join(parts) if parts else "")
+ return out
+
+
+def _display(task: str) -> str:
+ return TASK_DISPLAY.get(task, task.replace("_", " ").title())
+
+
+def _scale_label(task: str) -> str:
+ return "normalized" if TASK_SPECS[task]["source"] == "qc" else "log1p, z-scored"
+
+
+def _title(task: str) -> str:
+ return f"{_display(task)} ({_scale_label(task)})"
+
+
+def _arrow(value: float) -> str:
+ return "↓" if value < 0 else "↑"
+
+
+@dataclass
+class InverseScenario:
+ """One inverse-design objective set (primary = QC probability; secondary = regression targets)."""
+
+ name: str
+ reg_tasks: list[str]
+ reg_targets: list[float]
+
+ def __post_init__(self) -> None:
+ if len(self.reg_tasks) != len(self.reg_targets):
+ raise ValueError(f"Scenario '{self.name}': reg_tasks and reg_targets must have equal length.")
+
+
+@dataclass
+class ContinualRehearsalFullConfig:
+ """Configuration for the full continual rehearsal + inverse-design run."""
+
+ qc_data_path: Path = Path("data/qc_ac_te_mp_dos_reformat_20260515.pd.parquet")
+ qc_preprocessing_path: Path | None = None
+ superconductor_path: Path = Path("data/NEMAD_superconductor_20260425.parquet")
+ magnetic_path: Path = Path("data/NEMAD_magnetic_20260419.parquet")
+ phonix_path: Path = Path("data/phonix-db-filtered_20260425.parquet")
+ output_dir: Path = Path("artifacts/continual_rehearsal_full")
+
+ task_sequence: list[str] = field(default_factory=lambda: list(DEFAULT_SEQUENCE))
+ fixed_tail: list[str] = field(default_factory=lambda: list(DEFAULT_FIXED_TAIL))
+ replay_ratio: float = 0.05 # ordinary old-task replay ratio
+ replay_ratio_high: float = 0.10 # replay ratio for fixed_tail tasks when replayed as old
+ sample_per_dataset: int | None = None # cap rows per dataset (for fast/smoke runs)
+
+ max_epochs_per_step: int = 100 # ceiling; EarlyStopping usually stops sooner
+ early_stop_patience: int = 8
+ early_stop_min_delta: float = 1e-4
+ batch_size: int = 256
+ num_workers: int = 0
+
+ n_grids: int = 8
+ latent_dim: int = 128
+ encoder_hidden: int = 256
+ head_hidden_dim: int = 64
+ head_lr: float = 5e-3
+ encoder_lr: float = 5e-3
+ n_kernel: int = 15
+ kr_lr: float = 5e-4
+ kr_decay: float = 5e-5
+
+ # Inverse design (shared across scenarios). Primary objective is QC probability ↑; each
+ # scenario runs the eight PR #18 paths (3 latent + 5 composition configs) — see plan §5.
+ inverse_n_seeds: int = 20 # 17 top-QC dedup + 3 explicit Au-Ga-Ln formers (plan §5)
+ inverse_steps: int = 300
+ inverse_lr: float = 0.05
+ inverse_class_weight: float = 5.0
+ # 48-element ``ALLOY_PALETTE`` for the composition rows that whitelist elements. Configurable
+ # in case the slide author wants a wider or narrower palette; everything else (ae_align_scale
+ # sweep, seed_blend, diversity_scale) is fixed at the module level in ``INVERSE_PATH_CONFIGS``
+ # so the comparison is a stable ablation across runs.
+ inverse_composition_allowed_elements: list[str] = field(default_factory=lambda: list(ALLOY_PALETTE))
+ inverse_seed_strategy: str = "top_qc" # "top_qc" | "random" | "explicit"
+ # Held-out test split is the right default for the formal full run: the model has seen the
+ # train compositions during training, so its top-QC ranking there is part memorisation; test
+ # compositions are held out, so the ranking is a genuine prediction → seeds are real novel QC
+ # candidates. (Override to "train" only when reproducing the demo / paper baseline.)
+ inverse_seed_split: str = "test" # "train" | "val" | "test" | "all"
+ inverse_seed_compositions: list[str] = field(default_factory=list)
+ # Compositions appended to the strategy-selected seeds regardless of QC ranking. Each must
+ # have a computable descriptor (fail-fast in _select_seeds). The strategy budget is reduced
+ # by len(explicit_append) so total seeds == inverse_n_seeds. Defaults to the three Au-Ga-Ln
+ # i-QC formers used in plan §5 (Au65 Ga20 Gd/Tb/Dy15).
+ inverse_seed_explicit_append: list[str] = field(
+ default_factory=lambda: ["Au65 Ga20 Gd15", "Au65 Ga20 Tb15", "Au65 Ga20 Dy15"]
+ )
+ inverse_scenarios: list[InverseScenario] = field(
+ default_factory=lambda: [
+ InverseScenario("scenario1_fe_down_moment_up", ["formation_energy", "magnetic_moment"], [-2.0, 2.0]),
+ InverseScenario("scenario2_fe_tc_moment", ["formation_energy", "tc", "magnetic_moment"], [-2.0, 2.0, 2.0]),
+ InverseScenario("scenario3_fe_down_klat_up", ["formation_energy", "klat"], [-2.0, 2.0]),
+ ]
+ )
+
+ random_seed: int = 2025
+ datamodule_random_seed: int = 42
+ accelerator: str = "auto"
+ devices: int = 1
+
+ def __post_init__(self) -> None:
+ unknown = [t for t in self.task_sequence if t not in TASK_SPECS]
+ if unknown:
+ raise ValueError(f"Unknown task(s) {unknown}. Available: {sorted(TASK_SPECS)}")
+ if len(set(self.task_sequence)) != len(self.task_sequence):
+ raise ValueError("task_sequence contains duplicates.")
+ bad_tail = [t for t in self.fixed_tail if t not in self.task_sequence]
+ if bad_tail:
+ raise ValueError(f"fixed_tail tasks {bad_tail} are not in task_sequence.")
+ for ratio_name, ratio in (("replay_ratio", self.replay_ratio), ("replay_ratio_high", self.replay_ratio_high)):
+ if not 0.0 <= ratio <= 1.0:
+ raise ValueError(f"{ratio_name} must be in [0, 1].")
+ if not self.inverse_composition_allowed_elements:
+ raise ValueError("inverse_composition_allowed_elements must be non-empty.")
+ unknown_palette = [e for e in self.inverse_composition_allowed_elements if e not in DEFAULT_ELEMENTS]
+ if unknown_palette:
+ raise ValueError(
+ f"inverse_composition_allowed_elements contains symbols not in DEFAULT_ELEMENTS: {unknown_palette}"
+ )
+ if self.inverse_seed_strategy not in {"top_qc", "random", "explicit"}:
+ raise ValueError("inverse_seed_strategy must be 'top_qc', 'random', or 'explicit'.")
+ if self.inverse_seed_split not in {"train", "val", "test", "all"}:
+ raise ValueError("inverse_seed_split must be 'train', 'val', 'test', or 'all'.")
+ if self.inverse_seed_strategy == "explicit" and not self.inverse_seed_compositions:
+ raise ValueError("inverse_seed_strategy='explicit' requires inverse_seed_compositions.")
+ # Every scenario's tasks must be regression tasks present in the sequence.
+ for sc in self.inverse_scenarios:
+ for t in sc.reg_tasks:
+ if t not in self.task_sequence:
+ raise ValueError(f"Scenario '{sc.name}': task '{t}' not in task_sequence.")
+ if TASK_SPECS[t]["kind"] != "reg":
+ raise ValueError(f"Scenario '{sc.name}': task '{t}' must be a (scalar) regression task.")
+ if "material_type" not in self.task_sequence:
+ raise ValueError("task_sequence must contain 'material_type' (QC classifier for inverse design).")
+
+
+class _DropLastTrainCompoundDataModule(CompoundDataModule):
+ """``CompoundDataModule`` variant whose train loader sets ``drop_last=True``.
+
+ PyTorch ``BatchNorm1d`` in training mode raises ``ValueError: Expected more than 1 value per
+ channel`` on a batch of size 1. With ``shuffle=True`` and ``drop_last=False`` (the upstream
+ default), any train subset whose size ``mod batch_size == 1`` will eventually feed that
+ single-row tail batch into the encoder's ``fc_layers`` BN and crash mid-epoch — exactly what
+ happened in the first attempted full-data MPS run (Step 1, ``density``).
+
+ Dropping the final partial batch costs at most ``batch_size − 1`` rows per epoch (~256 / 35k
+ rows in the qc train split ≈ 0.7 %), which is well within the noise of the rehearsal mask. We
+ only touch the train loader; val / test / predict keep ``drop_last=False`` so every held-out
+ row is evaluated. ``_train_sampler`` (used only by the DDP path) is left untouched — we are
+ not using DDP here.
+ """
+
+ def train_dataloader(self):
+ base = super().train_dataloader()
+ if base is None:
+ return None
+ return DataLoader(
+ base.dataset,
+ batch_size=base.batch_size,
+ shuffle=True,
+ num_workers=base.num_workers,
+ pin_memory=base.pin_memory,
+ collate_fn=base.collate_fn,
+ drop_last=True,
+ )
+
+
+class ContinualRehearsalFullRunner:
+ def __init__(self, config: ContinualRehearsalFullConfig):
+ self.config = config
+ self.output_dir = Path(config.output_dir)
+ # Plan §6 layout: training/ for per-step artifacts (incl. final_model.pt and forgetting
+ # trajectory), inverse_design/ for the dual-path scenarios, slide-prep / analysis / readme
+ # at the top level. Subdirs are created lazily where needed.
+ self.training_dir = self.output_dir / "training"
+ self.inverse_root = self.output_dir / "inverse_design"
+ self.output_dir.mkdir(parents=True, exist_ok=True)
+ self.training_dir.mkdir(parents=True, exist_ok=True)
+ _apply_plot_style()
+ self._task_colors = {name: _PALETTE[i % len(_PALETTE)] for i, name in enumerate(config.task_sequence)}
+ self._kmd = KMD(element_features.values, method="1d", n_grids=config.n_grids, sigma="auto", scale=True)
+ self.x_dim = int(self._kmd.transform(np.eye(1, len(DEFAULT_ELEMENTS))).shape[1])
+ self._desc_cache: dict[str, np.ndarray] = {}
+ self._load_data()
+
+ # ------------------------------------------------------------------ data
+
+ def _load_data(self) -> None:
+ cfg = self.config
+ rng = np.random.default_rng(cfg.datamodule_random_seed)
+ self.task_frames: dict[str, pd.DataFrame] = {}
+ split_by_key: dict[str, str] = {}
+
+ sources = {
+ "qc": self._load_qc(),
+ "superconductor": pd.read_parquet(cfg.superconductor_path),
+ "magnetic": pd.read_parquet(cfg.magnetic_path),
+ "phonix": pd.read_parquet(cfg.phonix_path),
+ }
+
+ keyed: dict[str, pd.DataFrame] = {}
+ for name, df in sources.items():
+ df = df.copy()
+ if cfg.sample_per_dataset is not None and cfg.sample_per_dataset < len(df):
+ if name == "qc" and "Material type (label)" in df.columns:
+ df = self._stratified_qc_sample(df, cfg.sample_per_dataset, rng)
+ else:
+ df = df.iloc[rng.choice(len(df), size=cfg.sample_per_dataset, replace=False)]
+ df["__key__"] = [_composition_key(v) for v in df["composition"]]
+ df = df.dropna(subset=["__key__"]).drop_duplicates(subset="__key__", keep="first").set_index("__key__")
+ keyed[name] = df
+ if "split" in df.columns:
+ for k, s in df["split"].items():
+ split_by_key.setdefault(str(k), str(s))
+ else:
+ for k in df.index:
+ split_by_key.setdefault(str(k), rng.choice(["train", "val", "test"], p=[0.7, 0.15, 0.15]))
+
+ for task_name in cfg.task_sequence:
+ spec = TASK_SPECS[task_name]
+ df = keyed[spec["source"]]
+ col = spec["column"]
+ if col not in df.columns:
+ raise KeyError(f"Task '{task_name}': column '{col}' missing in {spec['source']} data.")
+ frame = pd.DataFrame(index=df.index)
+ values = df[col]
+ if task_name == "material_type":
+ values = values.map(_MATERIAL_TYPE_MERGE)
+ if spec["source"] != "qc" and spec["kind"] == "reg":
+ v = np.log1p(df[col].astype(float).clip(lower=0.0))
+ is_train = np.array([split_by_key.get(str(k)) == "train" for k in df.index])
+ ref = v[is_train] if is_train.any() else v
+ mean = float(ref.mean())
+ std = float(ref.std(ddof=0)) or 1.0
+ values = ((v - mean) / std).clip(-_RAW_TARGET_CLIP, _RAW_TARGET_CLIP)
+ frame[col] = values
+ if spec["kind"] == "kr":
+ frame[spec["t_column"]] = df[spec["t_column"]]
+ frame["split"] = [split_by_key.get(str(k), "train") for k in frame.index]
+ self.task_frames[task_name] = frame
+
+ self.split_by_key = split_by_key
+ n_keys = len(set().union(*[set(f.index) for f in self.task_frames.values()]))
+ logger.info(f"Built {len(self.task_frames)} task frames over {n_keys} unique compositions; x_dim={self.x_dim}.")
+
+ def _load_qc(self) -> pd.DataFrame:
+ cfg = self.config
+ df = pd.read_parquet(cfg.qc_data_path)
+ if cfg.qc_preprocessing_path is not None and Path(cfg.qc_preprocessing_path).exists():
+ dropped = joblib.load(cfg.qc_preprocessing_path).get("dropped_idx", [])
+ df = df.loc[~df.index.isin(dropped)]
+ return df
+
+ @staticmethod
+ def _stratified_qc_sample(df: pd.DataFrame, cap: int, rng: np.random.Generator) -> pd.DataFrame:
+ """Cap qc rows while keeping every minority (non-"others") material-type row."""
+ labels = df["Material type (label)"]
+ minority = df[labels != 4]
+ others = df[labels == 4]
+ n_others = max(cap - len(minority), 0)
+ if n_others < len(others):
+ others = others.iloc[rng.choice(len(others), size=n_others, replace=False)]
+ out = pd.concat([minority, others])
+ if len(out) > cap:
+ out = out.iloc[rng.choice(len(out), size=cap, replace=False)]
+ return out
+
+ def _class_weights(self, task_name: str) -> list[float]:
+ spec = TASK_SPECS[task_name]
+ frame = self.task_frames[task_name]
+ num_classes = int(spec["num_classes"])
+ train = frame.loc[frame["split"] == "train", spec["column"]].dropna().astype(int)
+ counts = np.bincount(train, minlength=num_classes).astype(float)
+ counts[counts == 0] = 1.0
+ weights = counts.sum() / (num_classes * counts)
+ return weights.tolist()
+
+ def descriptor_fn(self, compositions: list[str]) -> pd.DataFrame:
+ uncached = [c for c in dict.fromkeys(compositions) if c not in self._desc_cache]
+ if uncached:
+ weights = np.zeros((len(uncached), len(DEFAULT_ELEMENTS)), dtype=float)
+ valid: list[str] = []
+ for key in uncached:
+ try:
+ w = formula_to_composition(key)
+ except Exception:
+ w = None
+ if w is None or float(w.sum()) <= 0:
+ continue
+ weights[len(valid)] = w
+ valid.append(key)
+ if valid:
+ desc = self._kmd.transform(weights[: len(valid)])
+ for j, key in enumerate(valid):
+ self._desc_cache[key] = desc[j]
+ present = [c for c in compositions if c in self._desc_cache]
+ if not present:
+ return pd.DataFrame()
+ return pd.DataFrame(np.stack([self._desc_cache[c] for c in present]), index=present)
+
+ # ------------------------------------------------------------------ configs
+
+ def _build_task_config(self, task_name: str):
+ cfg = self.config
+ spec = TASK_SPECS[task_name]
+ ld, hd = cfg.latent_dim, cfg.head_hidden_dim
+ if spec["kind"] == "reg":
+ return RegressionTaskConfig(
+ name=task_name,
+ data_column=spec["column"],
+ dims=[ld, hd, 1],
+ optimizer=OptimizerConfig(lr=cfg.head_lr, weight_decay=1e-5),
+ )
+ if spec["kind"] == "clf":
+ return ClassificationTaskConfig(
+ name=task_name,
+ data_column=spec["column"],
+ dims=[ld, hd, 32],
+ num_classes=spec["num_classes"],
+ class_weights=self._class_weights(task_name),
+ optimizer=OptimizerConfig(lr=cfg.head_lr, weight_decay=1e-5),
+ )
+ train_t = self._collect_train_t(task_name)
+ centers, sigmas = _init_kernels(train_t, cfg.n_kernel)
+ return KernelRegressionTaskConfig(
+ name=task_name,
+ data_column=spec["column"],
+ t_column=spec["t_column"],
+ x_dim=[ld, 128, 64],
+ t_dim=[16, 8],
+ kernel_num_centers=cfg.n_kernel,
+ kernel_centers_init=centers or None,
+ kernel_sigmas_init=sigmas or None,
+ kernel_learnable_centers=True,
+ kernel_learnable_sigmas=True,
+ enable_mu3=False,
+ optimizer=OptimizerConfig(lr=cfg.kr_lr, weight_decay=cfg.kr_decay),
+ )
+
+ def _collect_train_t(self, task_name: str) -> np.ndarray:
+ spec = TASK_SPECS[task_name]
+ frame = self.task_frames[task_name]
+ mask = frame[spec["column"]].notna() & (frame["split"] == "train")
+ cells = frame.loc[mask, spec["t_column"]].dropna()
+ if cells.empty:
+ return np.array([])
+ return np.concatenate([_as_float_array(c) for c in cells])
+
+ # ------------------------------------------------------------------ run
+
+ def _build_empty_model(self) -> FlexibleMultiTaskModel:
+ """The bare model used as the starting point for both ``run`` and ``run_inverse_only``."""
+ cfg = self.config
+ encoder_config = MLPEncoderConfig(hidden_dims=[self.x_dim, cfg.encoder_hidden, cfg.latent_dim])
+ return FlexibleMultiTaskModel(
+ task_configs=[],
+ encoder_config=encoder_config,
+ enable_autoencoder=True,
+ shared_block_optimizer=OptimizerConfig(lr=cfg.encoder_lr, weight_decay=1e-2),
+ )
+
+ def _build_full_model(self) -> FlexibleMultiTaskModel:
+ """Rebuild the post-training model (all tasks added in sequence order) so a saved
+ ``final_model.pt`` ``state_dict`` can be loaded for inverse-only runs."""
+ model = self._build_empty_model()
+ for task_name in self.config.task_sequence:
+ model.add_task(self._build_task_config(task_name))
+ return model
+
+ def run(
+ self,
+ *,
+ record_trajectory: bool = True,
+ per_seed_trajectories: bool = False,
+ animation_formats: tuple[str, ...] = ("gif",),
+ ) -> None:
+ cfg = self.config
+ seed_everything(cfg.random_seed, workers=True)
+ model = self._build_empty_model()
+
+ task_configs: dict[str, Any] = {}
+ metric_history: dict[str, list[tuple[int, float]]] = {name: [] for name in cfg.task_sequence}
+ records: list[dict[str, Any]] = []
+ fixed_tail = set(cfg.fixed_tail)
+
+ for step, task_name in enumerate(cfg.task_sequence):
+ logger.info(f"=== Step {step + 1}/{len(cfg.task_sequence)}: add task '{task_name}' ===")
+ task_configs[task_name] = self._build_task_config(task_name)
+ model.add_task(task_configs[task_name])
+
+ active = cfg.task_sequence[: step + 1]
+ # New task fully active; old tasks replayed — fixed-tail tasks at the higher ratio.
+ for name in active:
+ if name == task_name:
+ ratio = 1.0
+ elif name in fixed_tail:
+ ratio = cfg.replay_ratio_high
+ else:
+ ratio = cfg.replay_ratio
+ task_configs[name].task_masking_ratio = ratio
+
+ datamodule = _DropLastTrainCompoundDataModule(
+ task_configs=[task_configs[name] for name in active],
+ descriptor_fn=self.descriptor_fn,
+ task_frames={name: self.task_frames[name] for name in active},
+ composition_column="composition",
+ random_seed=cfg.datamodule_random_seed,
+ batch_size=cfg.batch_size,
+ num_workers=cfg.num_workers,
+ )
+ callbacks: list[Callback] = [
+ EarlyStopping(
+ monitor="val_final_loss",
+ mode="min",
+ patience=cfg.early_stop_patience,
+ min_delta=cfg.early_stop_min_delta,
+ )
+ ]
+ trainer = Trainer(
+ max_epochs=cfg.max_epochs_per_step,
+ accelerator=cfg.accelerator,
+ devices=cfg.devices,
+ logger=False,
+ enable_checkpointing=False,
+ enable_progress_bar=False,
+ callbacks=callbacks,
+ )
+ trainer.fit(model, datamodule=datamodule)
+
+ test_keys: set[str] | None = None
+ if datamodule.split_series is not None:
+ resolved = datamodule.split_series
+ test_keys = set(resolved.index[resolved == "test"].astype(str))
+
+ step_dir = self.training_dir / f"step{step + 1:02d}_{task_name}"
+ step_dir.mkdir(parents=True, exist_ok=True)
+ step_metrics: dict[str, dict[str, float]] = {}
+ for name in active:
+ # Plot only the freshly-added head; dump raw (composition, true, pred) + per-task
+ # metrics.json for every active head so the forgetting trajectory is backed by
+ # raw data and per-task numbers at each stage.
+ metric = self._evaluate_task(model, name, step_dir, is_new=(name == task_name), test_keys=test_keys)
+ step_metrics[name] = metric
+ metric_history[name].append((step + 1, metric["primary"]))
+ # Per-step model checkpoint (mirrors the demo, PR #18). Lets analysts revisit any
+ # intermediate stage ("what did the encoder look like just after task K was added?")
+ # without retraining the prefix, and feeds downstream finetune scripts.
+ step_ckpt = step_dir / "checkpoint.pt"
+ torch.save(
+ {
+ "model": model.state_dict(),
+ "task_sequence": list(cfg.task_sequence),
+ "step": step + 1,
+ "new_task": task_name,
+ "active_tasks": list(active),
+ },
+ step_ckpt,
+ )
+ records.append(
+ {"step": step + 1, "new_task": task_name, "epochs_run": trainer.current_epoch, "metrics": step_metrics}
+ )
+ summary = ", ".join(f"{k}={v['primary']:.3f}" for k, v in step_metrics.items())
+ rel_ckpt = step_ckpt.relative_to(self.output_dir)
+ logger.info(f"Step {step + 1} ({trainer.current_epoch} epochs): {summary} (ckpt: {rel_ckpt})")
+
+ self._plot_forgetting(metric_history)
+ (self.training_dir / "experiment_records.json").write_text(json.dumps(records, indent=2), encoding="utf-8")
+ self._write_metrics_table(records)
+ self._save_final_model(model, task_configs)
+
+ inverse = self._inverse_design(
+ model,
+ record_trajectory=record_trajectory,
+ per_seed_trajectories=per_seed_trajectories,
+ animation_formats=animation_formats,
+ )
+ (self.inverse_root / "inverse_design.json").write_text(json.dumps(inverse, indent=2), encoding="utf-8")
+
+ # Slide-prep deliverables (plan §6) — no more PPT/HTML; the slide author works from
+ # SLIDE_PREP.md + the raw arrays + the standard image set. The three scenarios are
+ # treated as equal first-class results — no demo-style "headline scenario" promotion.
+ self._write_inverse_summary_md(inverse)
+ self._write_analysis_md(records, inverse)
+ self._write_slide_prep_md(records, inverse)
+ self._write_readme(records, inverse)
+ logger.info(f"Done. Outputs in {self.output_dir}")
+
+ def _save_final_model(self, model, task_configs: dict[str, Any]) -> None:
+ # Schema matches the demo's ``final_model.pt`` (PR #18) so the same downstream consumers —
+ # ``paper_inverse_comparison.py`` / ``finetune_inverse_heads.py`` / ``--inverse-only`` —
+ # can ingest checkpoints from either runner without translation.
+ ckpt = self.training_dir / "final_model.pt"
+ torch.save({"model": model.state_dict(), "task_sequence": list(self.config.task_sequence)}, ckpt)
+ spec_dump = {
+ name: {
+ "kind": TASK_SPECS[name]["kind"],
+ "column": TASK_SPECS[name]["column"],
+ "source": TASK_SPECS[name]["source"],
+ }
+ for name in self.config.task_sequence
+ }
+ (self.training_dir / "final_model_taskconfigs.json").write_text(
+ json.dumps(spec_dump, indent=2), encoding="utf-8"
+ )
+ logger.info(f"Saved final model checkpoint to {ckpt}")
+
+ def run_inverse_only(
+ self,
+ ckpt_path: Path,
+ *,
+ record_trajectory: bool = True,
+ per_seed_trajectories: bool = False,
+ animation_formats: tuple[str, ...] = ("gif",),
+ ) -> None:
+ """Skip training; load a saved ``final_model.pt`` and run only the inverse-design stage.
+
+ Use this to iterate on inverse-design knobs (seed split, palette, scenarios, …) without
+ repeating the multi-hour training. Data loading + descriptor computation still happen —
+ they're prerequisites for seed selection and the composition-path kernel — but no
+ ``Trainer.fit`` is called.
+
+ After the inverse-design pass we also **refresh the slide-prep deliverables**
+ (``ANALYSIS.md`` / ``SLIDE_PREP.md`` / ``README.md``) by loading the previous run's
+ ``training/experiment_records.json`` — without that, those documents would still quote
+ the inverse-design numbers from the previous pass. The training-derived sections
+ (forgetting trajectory, headline-task R² / accuracy) come from ``records`` unchanged.
+ If the records file is missing (e.g. inverse-only against a checkpoint from a different
+ run that didn't expose it), the deliverables are skipped with a warning.
+ """
+ logger.info(f"=== Inverse-only mode: loading model checkpoint {ckpt_path} ===")
+ seed_everything(self.config.random_seed, workers=True)
+ model = self._build_full_model()
+ state = torch.load(ckpt_path, map_location="cpu", weights_only=True)
+ state_dict = state["model"] if isinstance(state, dict) and "model" in state else state
+ model.load_state_dict(state_dict)
+ model.eval()
+ inverse = self._inverse_design(
+ model,
+ record_trajectory=record_trajectory,
+ per_seed_trajectories=per_seed_trajectories,
+ animation_formats=animation_formats,
+ )
+ (self.inverse_root / "inverse_design.json").write_text(json.dumps(inverse, indent=2), encoding="utf-8")
+ self._write_inverse_summary_md(inverse)
+
+ # Refresh the slide-prep deliverables so their inverse-design tables / seed lists match
+ # the values we just re-ran. The training records live next to the checkpoint.
+ records_path = self.training_dir / "experiment_records.json"
+ if records_path.exists():
+ records = json.loads(records_path.read_text(encoding="utf-8"))
+ self._write_analysis_md(records, inverse)
+ self._write_slide_prep_md(records, inverse)
+ self._write_readme(records, inverse)
+ logger.info(f"Refreshed ANALYSIS.md / SLIDE_PREP.md / README.md from {records_path}")
+ else:
+ logger.warning(
+ f"{records_path} not found — keeping previous ANALYSIS.md / SLIDE_PREP.md / "
+ "README.md unchanged. Inverse-design numbers in those docs may be stale."
+ )
+ logger.info(f"Inverse-only done. Outputs in {self.output_dir}")
+
+ def _write_metrics_table(self, records: list[dict[str, Any]]) -> None:
+ final = records[-1]["metrics"] if records else {}
+ intro = {r["new_task"]: r["metrics"][r["new_task"]] for r in records}
+ rows = []
+ for task in self.config.task_sequence:
+ spec = TASK_SPECS[task]
+ metric_name = "accuracy" if spec["kind"] == "clf" else "R2"
+ rows.append(
+ {
+ "task": task,
+ "display": _display(task),
+ "type": KIND_LABEL[spec["kind"]],
+ "dataset": SOURCE_DISPLAY[spec["source"]],
+ "metric": metric_name,
+ "at_intro": intro.get(task, {}).get("primary", float("nan")),
+ "final": final.get(task, {}).get("primary", float("nan")),
+ "final_mae": final.get(task, {}).get("mae", float("nan")),
+ "samples": final.get(task, {}).get("samples", 0),
+ }
+ )
+ pd.DataFrame(rows).to_csv(self.training_dir / "metrics_table.csv", index=False)
+
+ # ------------------------------------------------------------------ eval
+
+ def _test_rows(self, task_name: str, test_keys: set[str] | None = None) -> list[str]:
+ spec = TASK_SPECS[task_name]
+ frame = self.task_frames[task_name]
+ mask = frame[spec["column"]].notna()
+ mask &= frame.index.isin(test_keys) if test_keys is not None else (frame["split"] == "test")
+ return list(frame.index[mask])
+
+ def _descriptor_tensor(self, comps: list[str], device) -> tuple[torch.Tensor, list[str]]:
+ desc = self.descriptor_fn(comps)
+ comps = [c for c in comps if c in desc.index]
+ return torch.tensor(desc.loc[comps].values, dtype=torch.float32, device=device), comps
+
+ def _evaluate_task(self, model, task_name, step_dir, *, is_new, test_keys=None) -> dict[str, float]:
+ spec = TASK_SPECS[task_name]
+ kind = spec["kind"]
+ model.eval()
+ device = next(model.parameters()).device
+ comps = self._test_rows(task_name, test_keys)
+ if not comps:
+ return {"primary": float("nan"), "samples": 0}
+ frame = self.task_frames[task_name]
+ head = model.task_heads[task_name]
+
+ with torch.no_grad():
+ if kind in ("reg", "clf"):
+ x, comps = self._descriptor_tensor(comps, device)
+ if not comps:
+ return {"primary": float("nan"), "samples": 0}
+ h = torch.tanh(model.encoder(x))
+ if kind == "reg":
+ pred = head(h).squeeze(-1).cpu().numpy()
+ true = frame.loc[comps, spec["column"]].astype(float).to_numpy()
+ r2 = float(r2_score(true, pred))
+ metric = {
+ "r2": r2,
+ "mae": float(mean_absolute_error(true, pred)),
+ "samples": len(comps),
+ "primary": r2,
+ }
+ dump_predictions(task_name, step_dir, comps=list(comps), true=true, pred=pred)
+ dump_metrics(task_name, step_dir, metric)
+ if is_new:
+ plot_parity(true, pred, task_name, r2, step_dir, title=_title(task_name))
+ return metric
+ logits = head(h)
+ pred = logits.argmax(dim=-1).cpu().numpy()
+ true = frame.loc[comps, spec["column"]].astype(int).to_numpy()
+ acc = float(accuracy_score(true, pred))
+ metric = {
+ "accuracy": acc,
+ "macro_f1": float(f1_score(true, pred, average="macro", zero_division=0)),
+ "samples": len(comps),
+ "primary": acc,
+ }
+ dump_predictions(task_name, step_dir, comps=list(comps), true=true, pred=pred)
+ dump_metrics(task_name, step_dir, metric)
+ if is_new:
+ plot_confusion(
+ true,
+ pred,
+ task_name,
+ acc,
+ step_dir,
+ spec["num_classes"],
+ title=_display(task_name),
+ special_material_type=(task_name == "material_type"),
+ )
+ return metric
+
+ # kernel regression
+ keep, t_list, true_parts = [], [], []
+ for comp in comps:
+ if comp not in self._desc_cache and self.descriptor_fn([comp]).empty:
+ continue
+ y_arr = _as_float_array(frame.at[comp, spec["column"]])
+ t_arr = _as_float_array(frame.at[comp, spec["t_column"]])
+ if y_arr.size == 0 or y_arr.size != t_arr.size:
+ continue
+ keep.append(comp)
+ t_list.append(torch.tensor(t_arr, dtype=torch.float32, device=device))
+ true_parts.append(y_arr)
+ if not keep:
+ return {"primary": float("nan"), "samples": 0}
+ xk, _ = self._descriptor_tensor(keep, device)
+ h_k = torch.tanh(model.encoder(xk))
+ expanded_h, expanded_t = model._expand_for_kernel_regression(h_k, t_list)
+ pred = head(expanded_h, t=expanded_t).squeeze(-1).cpu().numpy()
+ true = np.concatenate(true_parts)
+ r2 = float(r2_score(true, pred))
+ metric = {
+ "r2": r2,
+ "mae": float(mean_absolute_error(true, pred)),
+ "samples": len(keep),
+ "points": int(true.size),
+ "primary": r2,
+ }
+ dump_kr_predictions(
+ task_name,
+ step_dir,
+ comps=keep,
+ t_list=[t.cpu().numpy() for t in t_list],
+ true_parts=true_parts,
+ pred=pred,
+ )
+ dump_metrics(task_name, step_dir, metric)
+ if is_new:
+ plot_kr_sequences(keep, t_list, true_parts, pred, task_name, step_dir, title=_title(task_name))
+ return metric
+
+ # --- per-task artifact dump helpers --------------------------------------
+ # ``dump_predictions`` / ``dump_kr_predictions`` / ``dump_metrics`` now live in
+ # :mod:`continual_rehearsal_common`; imported at the top of this file and called inline
+ # in ``_evaluate_task``. The bound-method versions were verbatim copies of demo's and
+ # caused drift (PR #18 code review).
+
+ # ------------------------------------------------------------------ inverse design
+
+ @staticmethod
+ def _element_system(composition: str) -> frozenset[str]:
+ """Element symbols (no amounts) in a composition string — used for system-level dedup."""
+ return frozenset(re.findall(r"[A-Z][a-z]?", composition))
+
+ @classmethod
+ def _dedupe_by_element_system(cls, candidates: list[str], n: int) -> list[str]:
+ """Walk ``candidates`` in order, keep the first occurrence of each element set, cap at ``n``.
+
+ Empty / malformed compositions (those that parse to an empty element-set) are silently
+ skipped so a bad row in the source dataframe doesn't blow up the seed picker — matches
+ the demo runner's behaviour at ``continual_rehearsal_demo._dedupe_by_element_system``
+ (the two used to differ; aligning them prevents drift when this gets shared into
+ ``continual_rehearsal_common``).
+ """
+ seen: set[frozenset[str]] = set()
+ out: list[str] = []
+ for comp in candidates:
+ key = cls._element_system(comp)
+ if not key or key in seen:
+ continue
+ seen.add(key)
+ out.append(comp)
+ if len(out) >= n:
+ break
+ return out
+
+ def _select_seeds(self, model, device, qc_prob_fn) -> dict[str, list[str]]:
+ """Pick seed compositions for inverse design (mirrors demo's PR #18 behaviour).
+
+ Returns ``{"strategy_seeds": […], "explicit_seeds": […]}``. Element-system dedup keeps the
+ best representative per element set (so 17 strategy seeds = 17 distinct alloy families,
+ not 17 ratio variants of three). ``inverse_seed_explicit_append`` is fail-fast validated
+ (each appended composition must have a computable descriptor) and the strategy budget is
+ reduced by its length so the total length equals ``inverse_n_seeds``.
+ """
+ cfg = self.config
+ n = cfg.inverse_n_seeds
+
+ # Pre-validate the explicit-append seeds so we fail fast on bad input.
+ appended: list[str] = []
+ for raw in cfg.inverse_seed_explicit_append:
+ norm = normalize_composition(raw) or str(raw)
+ if norm not in self._desc_cache and self.descriptor_fn([norm]).empty:
+ raise ValueError(
+ f"inverse_seed_explicit_append entry {raw!r} has no computable descriptor "
+ "(check the formula and that all elements are in DEFAULT_ELEMENTS)."
+ )
+ appended.append(norm)
+ # Dedup the appended list itself (in case the user listed near-duplicates).
+ appended = self._dedupe_by_element_system(appended, len(appended))
+ n_strategy = max(0, n - len(appended))
+
+ def _finalise(strategy_seeds: list[str]) -> dict[str, list[str]]:
+ """Combine strategy seeds + explicit-append, skipping any duplicate element systems."""
+ seen_keys = {self._element_system(c) for c in appended}
+ kept_strategy = [c for c in strategy_seeds if self._element_system(c) not in seen_keys][:n_strategy]
+ return {"strategy_seeds": kept_strategy, "explicit_seeds": appended}
+
+ if cfg.inverse_seed_strategy == "explicit":
+ seeds = [normalize_composition(c) or str(c) for c in cfg.inverse_seed_compositions]
+ seeds = [c for c in seeds if c in self._desc_cache or not self.descriptor_fn([c]).empty]
+ return _finalise(self._dedupe_by_element_system(seeds, n_strategy))
+
+ # Candidate pool: chosen split of the material_type frame, with a valid descriptor.
+ frame = self.task_frames["material_type"]
+ index = (
+ frame.index if cfg.inverse_seed_split == "all" else frame.index[frame["split"] == cfg.inverse_seed_split]
+ )
+ pool = [c for c in index if c in self._desc_cache or not self.descriptor_fn([c]).empty]
+ if not pool:
+ return {"strategy_seeds": [], "explicit_seeds": appended}
+
+ if cfg.inverse_seed_strategy == "random":
+ rng = np.random.default_rng(cfg.random_seed)
+ shuffled = [pool[i] for i in rng.permutation(len(pool))]
+ return _finalise(self._dedupe_by_element_system(shuffled, n_strategy))
+
+ # "top_qc": highest predicted QC probability, then element-system dedup.
+ x, pool = self._descriptor_tensor(pool, device)
+ probs = qc_prob_fn(x)
+ ranked = [pool[i] for i in np.argsort(probs)[::-1]]
+ return _finalise(self._dedupe_by_element_system(ranked, n_strategy))
+
+ def _decode_compositions_from_descriptor(self, descriptors: np.ndarray) -> list[str]:
+ """Latent-path composition output: AE-decoded descriptor → KMD.inverse → formula string."""
+ try:
+ weights = self._kmd.inverse(descriptors)
+ except Exception as exc: # pragma: no cover - QP edge cases
+ logger.warning(f"KMD.inverse failed ({exc}); skipping composition decoding.")
+ return [""] * descriptors.shape[0]
+ return _format_weights(weights)
+
+ def _inverse_design(
+ self,
+ model,
+ *,
+ record_trajectory: bool = False,
+ per_seed_trajectories: bool = False,
+ animation_formats: tuple[str, ...] = ("gif",),
+ ) -> dict[str, Any]:
+ """Run the 8 inverse-design configurations against each scenario on the same seeds.
+
+ The configurations are defined at module level in :data:`INVERSE_PATH_CONFIGS`, mirroring
+ the demo's ``paper_inverse_comparison.py``:
+
+ * **latent** (3 rows): ``optimize_latent`` with ``ae_align_scale ∈ {0.0, 0.25, 1.0}``
+ (failure / mid / max alignment).
+ * **composition** (5 rows): ``optimize_composition`` with seed_blend / palette / diversity
+ knobs swept — strict seed, blended seed, blended + palette, blended + palette + low
+ diversity, and random init (no seed) as the no-seed-bias control.
+
+ Saves per-path JSON + plot under ``inverse_design///`` plus a per-scenario
+ ``summary.json`` aggregating headline stats, and a top-level ``seeds.json`` recording the
+ strategy- vs explicit-appended seed split.
+
+ When ``record_trajectory`` is set we additionally emit per-step trajectory artefacts
+ (``trajectories/.npz`` + ``trajectories/trajectory__.{png,gif,…}``) per
+ scenario, using ``paper_inverse_comparison._emit_trajectory_outputs`` so the figures
+ match the demo verbatim. ``animation_formats`` controls the animation outputs; pass
+ ``("none",)`` to skip animations (the static plot still appears). ``per_seed_trajectories``
+ additionally emits one plot+animation per ``(path × seed)``.
+ """
+ cfg = self.config
+ device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
+ model.eval()
+ inv_root = self.output_dir / "inverse_design"
+ inv_root.mkdir(parents=True, exist_ok=True)
+
+ def _qc_prob(x: torch.Tensor) -> np.ndarray:
+ with torch.no_grad():
+ h = torch.tanh(model.encoder(x))
+ probs = torch.softmax(model.task_heads["material_type"](h), dim=-1)
+ return probs[:, QC_CLASSES].sum(dim=-1).cpu().numpy()
+
+ def _reg_preds(x: torch.Tensor, tasks: list[str]) -> dict[str, np.ndarray]:
+ with torch.no_grad():
+ h = torch.tanh(model.encoder(x))
+ return {t: model.task_heads[t](h).squeeze(-1).cpu().numpy() for t in tasks}
+
+ # Same seeds for every scenario, so all eight paths are directly comparable.
+ seed_split = self._select_seeds(model, device, _qc_prob)
+ seeds_all = seed_split["strategy_seeds"] + seed_split["explicit_seeds"]
+ if not seeds_all:
+ logger.warning("No seeds available for inverse design.")
+ return {}
+ x_seed, seeds = self._descriptor_tensor(seeds_all, device)
+ if not seeds:
+ logger.warning("No seeds have computable descriptors; aborting inverse design.")
+ return {}
+
+ # Composition path shares: kernel + per-seed initial weight tensor (B, n_components).
+ kmd_kernel = self._kmd.kernel_torch(device=device, dtype=dtype)
+ w_seed = _seed_weights_from_compositions(seeds, n_components=len(DEFAULT_ELEMENTS)).to(
+ device=device, dtype=dtype
+ )
+
+ # Top-level seeds.json with the strategy / explicit split (single source of truth across
+ # all scenarios). Per-path subdirs record their own ``seeds`` field for completeness.
+ seeds_meta = {
+ "strategy": cfg.inverse_seed_strategy,
+ "strategy_split": cfg.inverse_seed_split,
+ "n_target": cfg.inverse_n_seeds,
+ "n_used": len(seeds),
+ "strategy_seeds": [c for c in seed_split["strategy_seeds"] if c in seeds],
+ "explicit_seeds": [c for c in seed_split["explicit_seeds"] if c in seeds],
+ "all_seeds_used": seeds,
+ }
+ (inv_root / "seeds.json").write_text(json.dumps(seeds_meta, indent=2), encoding="utf-8")
+
+ # The shared ``plot_element_frequency_heatmap`` reads the seed list directly so it can
+ # mark x-tick labels that are absent from every seed as "discovered" — we no longer
+ # need to pre-compute a seed_element_pool here.
+
+ out: dict[str, Any] = {"seeds": seeds_meta, "scenarios": {}}
+ for sc in cfg.inverse_scenarios:
+ logger.info(f"=== Inverse design [{sc.name}]: targets={dict(zip(sc.reg_tasks, sc.reg_targets))} ===")
+ sc_dir = inv_root / sc.name
+ sc_dir.mkdir(parents=True, exist_ok=True)
+ reg_targets = {t: v for t, v in zip(sc.reg_tasks, sc.reg_targets)}
+
+ # Per-scenario targets.json (plan §5) — separate from results so a slide author can
+ # quote the objective without parsing the full result dump.
+ (sc_dir / "targets.json").write_text(
+ json.dumps(
+ {
+ "name": sc.name,
+ "primary": {"task": "material_type", "class_indices": QC_CLASSES, "direction": "max"},
+ "secondary": [
+ {"task": t, "target": v, "direction": "min" if v < 0 else "max"}
+ for t, v in reg_targets.items()
+ ],
+ },
+ indent=2,
+ ),
+ encoding="utf-8",
+ )
+
+ before_qc = _qc_prob(x_seed)
+ before_reg = _reg_preds(x_seed, sc.reg_tasks)
+
+ paths: dict[str, dict[str, Any]] = {}
+ for path_cfg in INVERSE_PATH_CONFIGS:
+ key = path_cfg["key"]
+ path_dir = sc_dir / key
+ if path_cfg["method"] == "latent":
+ paths[key] = self._run_latent_path(
+ model,
+ x_seed,
+ seeds,
+ reg_targets,
+ path_dir,
+ ae_align_scale=path_cfg["ae_align_scale"],
+ label=path_cfg["label"],
+ _qc_prob_fn=_qc_prob,
+ _reg_preds_fn=_reg_preds,
+ record_trajectory=record_trajectory,
+ )
+ else:
+ # Composition row: resolve the palette sentinel and seed/random init.
+ allowed = (
+ list(cfg.inverse_composition_allowed_elements)
+ if path_cfg["allowed"] == _PALETTE_SENTINEL
+ else path_cfg["allowed"]
+ )
+ init = path_cfg["init"]
+ paths[key] = self._run_composition_path(
+ model,
+ kmd_kernel,
+ w_seed if init == "seed" else None,
+ seeds,
+ reg_targets,
+ path_dir,
+ init=init,
+ blend=path_cfg["blend"] if init == "seed" else None,
+ allowed=allowed,
+ diversity=path_cfg["diversity"],
+ label=path_cfg["label"],
+ _qc_prob_fn=_qc_prob,
+ _reg_preds_fn=_reg_preds,
+ record_trajectory=record_trajectory,
+ )
+
+ scenario_summary = {
+ "name": sc.name,
+ "reg_targets": reg_targets,
+ "n_seeds": len(seeds),
+ "qc_before_mean": float(before_qc.mean()),
+ "paths": {
+ path_name: {
+ "qc_after_mean": float(np.mean(p["qc_after_decode"])),
+ "qc_after_std": float(np.std(p["qc_after_decode"])),
+ "reg_after_decode_mean": {t: float(np.mean(p["reg_after_decode"][t])) for t in reg_targets},
+ "reg_after_decode_std": {t: float(np.std(p["reg_after_decode"][t])) for t in reg_targets},
+ }
+ for path_name, p in paths.items()
+ },
+ }
+ (sc_dir / "summary.json").write_text(json.dumps(scenario_summary, indent=2), encoding="utf-8")
+ self._plot_inverse_scenario(sc, before_qc, before_reg, paths, reg_targets, sc_dir)
+ # Shared heatmap: pass per-path ``label`` + ``decoded_composition`` lists so the
+ # x-tick / colourbar / title styling matches the demo's paper_inverse_comparison.
+ heatmap_methods = [
+ {
+ "label": INVERSE_PATH_CONFIGS_BY_KEY[key]["label"],
+ "decoded_composition": p.get("decoded_composition", []) or [],
+ }
+ for key, p in paths.items()
+ if key in INVERSE_PATH_CONFIGS_BY_KEY
+ ]
+ plot_element_frequency_heatmap(heatmap_methods, list(seeds), sc_dir / "element_frequency_heatmap.png")
+
+ # ── per-scenario figures copied from the demo's ``paper_inverse_comparison.py`` ──
+ # The runner used to emit only the (boxplot) ``comparison.png`` and the
+ # ``element_frequency_heatmap.png``; the per-seed scatter and 1:1 mapping figures
+ # lived only in the demo. We import and call the demo's helpers directly so the
+ # two surfaces never drift on plot style or legend ordering. Inputs are built once
+ # per scenario from the same ``paths`` dict the existing plotters consume — no extra
+ # forward passes, no training touch-up.
+ results_for_demo_helpers = [
+ {
+ "method": paths[c["key"]]["method"],
+ "label": paths[c["key"]]["label"],
+ "qc_after_decode": paths[c["key"]]["qc_after_decode"],
+ "reg_after_decode": paths[c["key"]]["reg_after_decode"],
+ # ``_plot_seed_to_optimized_mapping`` doesn't need these but the scatter
+ # helper's legend grouping reads ``method``; carry them anyway so a future
+ # change picking up ``align_scale`` doesn't break silently.
+ "align_scale": paths[c["key"]].get("ae_align_scale"),
+ "decoded_composition": paths[c["key"]].get("decoded_composition", []),
+ }
+ for c in INVERSE_PATH_CONFIGS
+ if c["key"] in paths
+ ]
+ _plot_qc_vs_reg_scatter(
+ results_for_demo_helpers,
+ reg_targets,
+ sc_dir / "qc_vs_secondary_scatter.png",
+ title=f"QC probability vs secondary properties · {sc.name}",
+ seed_qc=before_qc,
+ seed_reg=before_reg,
+ )
+ # Per-path seed → optimised composition mapping. Skip ``comp_random`` (no per-row
+ # seed correspondence — its ``seeds`` field is a ``random_start_N`` placeholder).
+ for c in INVERSE_PATH_CONFIGS:
+ key = c["key"]
+ if key not in paths or key == "comp_random":
+ continue
+ p = paths[key]
+ decoded = p.get("decoded_composition", [])
+ if not decoded:
+ continue
+ _plot_seed_to_optimized_mapping(
+ seeds=list(seeds),
+ decoded=list(decoded),
+ out_path=sc_dir / f"seed_to_optimized__{key}.png",
+ title=f"Seed → optimised composition · {c['label']}",
+ seed_qc=before_qc,
+ seed_reg=before_reg,
+ optimized_qc=np.asarray(p["qc_after_decode"]),
+ optimized_reg={t: np.asarray(p["reg_after_decode"][t]) for t in reg_targets},
+ reg_targets=reg_targets,
+ )
+
+ # ── trajectory persistence + figures ──
+ # When ``record_trajectory`` is on, every path's ``_run_*_path`` returned a result
+ # carrying ``trajectory_targets`` (steps, B, T) and ``trajectory_weights``
+ # (steps, B, n_components). For a 300-step / B=20 / 94-component run those arrays
+ # together weigh ~3 MB per path × 8 paths × 3 scenarios ≈ 72 MB — too heavy to inline
+ # into ``inverse_design.json``. Persist as compressed npz next to each scenario's
+ # plots, then pop the inline arrays so the json stays browsable. Filenames use
+ # ``paper_inverse_comparison._path_slug`` so the demo's trajectory consumers can
+ # ingest these files directly.
+ if record_trajectory:
+ traj_dir = sc_dir / "trajectories"
+ traj_dir.mkdir(exist_ok=True)
+ results_for_traj: list[dict[str, Any]] = []
+ for key, p in paths.items():
+ if "trajectory_targets" not in p or "trajectory_weights" not in p:
+ continue
+ # ``_path_slug`` reads ``method``, ``label``, and (for latent) ``align_scale``.
+ # Our latent rows store ``ae_align_scale``; mirror it onto ``align_scale`` for
+ # the slug call (and so the demo's ``_emit_trajectory_outputs`` can group
+ # latents by α).
+ slug_record: dict[str, Any] = {
+ "method": p["method"],
+ "label": p["label"],
+ "align_scale": p.get("ae_align_scale"),
+ }
+ slug = _path_slug(slug_record)
+ npz_path = traj_dir / f"{slug}.npz"
+ np.savez_compressed(
+ npz_path,
+ targets=np.asarray(p["trajectory_targets"], dtype=np.float32),
+ weights=np.asarray(p["trajectory_weights"], dtype=np.float32),
+ )
+ # Drop the huge arrays now that they live on disk; carry a reference in their
+ # place so ``inverse_design.json`` consumers can find them.
+ p.pop("trajectory_targets", None)
+ p.pop("trajectory_weights", None)
+ p["trajectory_file"] = str(npz_path.relative_to(sc_dir))
+ # ``_emit_trajectory_outputs`` reads the npz via ``out_dir / r["trajectory_file"]``,
+ # so the result dict here has to use the *scenario-relative* path too.
+ results_for_traj.append(
+ {
+ **slug_record,
+ "qc_after_decode": p["qc_after_decode"],
+ "reg_after_decode": p["reg_after_decode"],
+ "trajectory_file": p["trajectory_file"],
+ }
+ )
+ if results_for_traj:
+ _emit_trajectory_outputs(
+ results=results_for_traj,
+ reg_targets=reg_targets,
+ seed_qc=before_qc,
+ seed_reg=before_reg,
+ out_dir=sc_dir,
+ traj_dir=traj_dir,
+ per_seed=per_seed_trajectories,
+ animation_formats=animation_formats,
+ )
+
+ # Explicit guard: ``list and float`` was a clever but fragile non-empty check —
+ # an empty ``qc_after_decode`` (no successful seeds for a path) returned the empty
+ # list, which then crashed ``f"{...:.3f}"`` with ``TypeError`` on format. NaN keeps
+ # the join uniform and is the natural "no data" sentinel for downstream readers.
+ def _qc_mean(path_name: str) -> float:
+ qc = paths[path_name].get("qc_after_decode") or []
+ return float(np.mean(qc)) if qc else float("nan")
+
+ qc_summary = " · ".join(f"{name}={_qc_mean(name):.3f}" for name in INVERSE_PATHS)
+ logger.info(f"[{sc.name}] QC after-decode mean — {qc_summary}")
+
+ out["scenarios"][sc.name] = {**scenario_summary, "paths_details": paths}
+ return out
+
+ # --- inverse path runners -------------------------------------------------
+
+ def _run_latent_path(
+ self,
+ model,
+ x_seed: torch.Tensor,
+ seeds: list[str],
+ reg_targets: dict[str, float],
+ path_dir: Path,
+ *,
+ ae_align_scale: float,
+ label: str,
+ _qc_prob_fn,
+ _reg_preds_fn,
+ record_trajectory: bool = False,
+ ) -> dict[str, Any]:
+ """Latent-space optimisation with cycle-consistency at a fixed ``ae_align_scale``.
+
+ When ``record_trajectory`` is set we (a) ask ``optimize_latent`` to keep its per-step
+ AE-decoded input, and (b) decode each step through ``KMD.inverse`` to recover the per-step
+ composition recipe — same trick the demo's ``_run_latent_method`` uses, so the trajectory
+ is on the same surface as the final ``reg_after_decode`` values. The huge ``(steps, B, *)``
+ arrays land in ``result["trajectory_targets"]`` / ``result["trajectory_weights"]``; the
+ caller is responsible for persisting them as a compressed npz and popping them off the
+ result dict so they don't bloat ``inverse_design.json``.
+ """
+ cfg = self.config
+ path_dir.mkdir(parents=True, exist_ok=True)
+ reg_names = list(reg_targets)
+
+ before_qc = _qc_prob_fn(x_seed)
+ before_reg = _reg_preds_fn(x_seed, reg_names)
+
+ res = model.optimize_latent(
+ initial_input=x_seed,
+ task_targets=reg_targets,
+ class_targets={"material_type": QC_CLASSES},
+ class_target_weight=cfg.inverse_class_weight,
+ ae_align_scale=ae_align_scale,
+ optimize_space="latent",
+ steps=cfg.inverse_steps,
+ lr=cfg.inverse_lr,
+ record_input_trajectory=record_trajectory,
+ )
+ achieved_latent = res.optimized_target[:, 0, :].cpu().numpy()
+ optimized_desc = res.optimized_input[:, 0, :]
+ optimized_desc_np = optimized_desc.detach().cpu().numpy()
+ after_qc = _qc_prob_fn(optimized_desc)
+ after_reg = _reg_preds_fn(optimized_desc, reg_names)
+ try:
+ optimized_weights = self._kmd.inverse(optimized_desc_np)
+ except Exception as exc: # pragma: no cover
+ logger.warning(f"KMD.inverse failed for latent path ({exc}); weights left empty.")
+ optimized_weights = np.zeros((optimized_desc_np.shape[0], len(DEFAULT_ELEMENTS)))
+ decoded = _format_weights(optimized_weights)
+
+ result = {
+ "method": "latent",
+ "label": label,
+ "ae_align_scale": ae_align_scale,
+ "seeds": list(seeds),
+ "qc_before": before_qc.tolist(),
+ "qc_after_decode": after_qc.tolist(),
+ "reg_before": {t: before_reg[t].tolist() for t in reg_names},
+ "reg_achieved_latent": {t: achieved_latent[:, j].tolist() for j, t in enumerate(reg_names)},
+ "reg_after_decode": {t: after_reg[t].tolist() for t in reg_names},
+ "decoded_composition": decoded,
+ "optimized_descriptor": optimized_desc_np.tolist(),
+ "optimized_weights": optimized_weights.tolist(),
+ }
+ # Trajectory arrays (kept out of result.json — caller persists them as a separate npz).
+ if record_trajectory and res.input_trajectory is not None and res.trajectory is not None:
+ # ``res.trajectory`` is (B, R=1, steps, T) — squeeze restart, permute to (steps, B, T).
+ result["trajectory_targets"] = res.trajectory[:, 0, :, :].cpu().numpy().transpose(1, 0, 2)
+ # ``res.input_trajectory`` is (B, R=1, steps, input_dim) → (steps, B, input_dim);
+ # ``KMD.inverse`` then maps each step's descriptor batch → (B, n_components).
+ per_step_inputs = res.input_trajectory[:, 0, :, :].cpu().numpy().transpose(1, 0, 2)
+ result["trajectory_weights"] = np.stack(
+ [self._kmd.inverse(per_step_inputs[s]) for s in range(per_step_inputs.shape[0])]
+ ) # (steps, B, n_components) — one QP solve per (step × seed), ~10 % overhead.
+ # Write result.json without the trajectory arrays (they live in the npz once persisted).
+ json_payload = {k: v for k, v in result.items() if k not in {"trajectory_targets", "trajectory_weights"}}
+ (path_dir / "result.json").write_text(json.dumps(json_payload, indent=2), encoding="utf-8")
+ return result
+
+ def _run_composition_path(
+ self,
+ model,
+ kmd_kernel: torch.Tensor,
+ w_seed: torch.Tensor | None,
+ seeds: list[str],
+ reg_targets: dict[str, float],
+ path_dir: Path,
+ *,
+ init: str,
+ blend: float | None,
+ allowed: str | list[str],
+ diversity: float,
+ label: str,
+ _qc_prob_fn,
+ _reg_preds_fn,
+ record_trajectory: bool = False,
+ ) -> dict[str, Any]:
+ """Composition-space optimisation via differentiable KMD (``optimize_composition``).
+
+ ``init="seed"`` uses ``w_seed`` + ``seed_blend``; ``init="random"`` ignores ``w_seed`` and
+ runs ``n_starts = len(seeds)`` so the per-row budget matches the latent run.
+
+ When ``record_trajectory`` is set, the per-step weight + reg-target trajectories come
+ straight from ``optimize_composition`` (composition's optim variable already lives on the
+ right surface, so no per-step KMD.inverse is needed — unlike the latent path).
+ """
+ cfg = self.config
+ path_dir.mkdir(parents=True, exist_ok=True)
+ reg_names = list(reg_targets)
+
+ if init == "seed":
+ if w_seed is None:
+ raise ValueError("Composition path with init='seed' requires w_seed.")
+ init_kwargs: dict[str, Any] = {"initial_weights": w_seed, "seed_blend": blend}
+ elif init == "random":
+ init_kwargs = {"initial_weights": None, "n_starts": len(seeds)}
+ else:
+ raise ValueError(f"Unknown init mode in composition path: {init!r}")
+
+ res = model.optimize_composition(
+ kmd_kernel,
+ task_targets=reg_targets,
+ class_targets={"material_type": QC_CLASSES},
+ class_target_weight=cfg.inverse_class_weight,
+ diversity_scale=diversity,
+ allowed_elements=allowed,
+ steps=cfg.inverse_steps,
+ lr=cfg.inverse_lr,
+ record_weights_trajectory=record_trajectory,
+ **init_kwargs,
+ )
+ # Composition's result tensors are 2D — ``(B, x_dim)`` / ``(B, n_components)`` /
+ # ``(B, T)`` — no restart axis, so no ``[:, 0, :]`` slicing (unlike ``optimize_latent``).
+ optimized_desc = res.optimized_descriptor # (B, x_dim) — w @ K, no AE round-trip
+ optimized_desc_np = optimized_desc.detach().cpu().numpy()
+ w_final = res.optimized_weights.detach().cpu().numpy()
+ achieved_latent = res.optimized_target.detach().cpu().numpy() # (B, T)
+ after_qc = _qc_prob_fn(optimized_desc)
+ after_reg = _reg_preds_fn(optimized_desc, reg_names)
+ decoded = _format_weights(w_final)
+
+ # Random init has no per-row correspondence with the seed list — preserve the seed list
+ # only when the init was seeded; otherwise label the rows as random restarts.
+ seed_labels = list(seeds) if init == "seed" else [f"random_start_{i}" for i in range(len(seeds))]
+
+ result = {
+ "method": "composition",
+ "label": label,
+ "init": init,
+ "seed_blend": blend,
+ "allowed_elements": allowed,
+ "diversity_scale": diversity,
+ "seeds": seed_labels,
+ "qc_after_decode": after_qc.tolist(),
+ "reg_achieved_latent": {t: achieved_latent[:, j].tolist() for j, t in enumerate(reg_names)},
+ "reg_after_decode": {t: after_reg[t].tolist() for t in reg_names},
+ "decoded_composition": decoded,
+ "optimized_descriptor": optimized_desc_np.tolist(),
+ "optimized_weights": w_final.tolist(),
+ }
+ # Trajectory arrays — same shape convention as the latent path so ``_emit_trajectory_outputs``
+ # consumes both interchangeably. ``res.trajectory`` is already (steps, B, T) and
+ # ``res.weights_trajectory`` is already (steps, B, n_components) — no transpose / decode.
+ if record_trajectory and res.weights_trajectory is not None and res.trajectory is not None:
+ result["trajectory_targets"] = res.trajectory.cpu().numpy()
+ result["trajectory_weights"] = res.weights_trajectory.cpu().numpy()
+ json_payload = {k: v for k, v in result.items() if k not in {"trajectory_targets", "trajectory_weights"}}
+ (path_dir / "result.json").write_text(json.dumps(json_payload, indent=2), encoding="utf-8")
+ return result
+
+ # ------------------------------------------------------------------ plots
+ # ``plot_parity`` / ``plot_confusion`` / ``plot_kr_sequences`` now live in
+ # :mod:`continual_rehearsal_common`; they were verbatim copies of demo's and caused PR
+ # #18's K=0 ``NameError`` to ship in demo for several PRs. The runner-specific plots
+ # below (``_plot_forgetting`` uses ``self._task_colors``; the inverse-design plotters use
+ # the 8-path layout) stay as bound methods.
+
+ def _plot_forgetting(self, metric_history):
+ n_tasks = sum(1 for pts in metric_history.values() if pts)
+ fig, ax = plt.subplots(figsize=(14, max(5.5, 0.32 * n_tasks + 3)))
+ all_steps: set[int] = set()
+ for task_name, points in metric_history.items():
+ if not points:
+ continue
+ steps = [s for s, _ in points]
+ vals = [v for _, v in points]
+ all_steps.update(steps)
+ is_clf = TASK_SPECS[task_name]["kind"] == "clf"
+ ax.plot(
+ steps,
+ vals,
+ marker="s" if is_clf else "o",
+ ms=5,
+ ls="--" if is_clf else "-",
+ color=self._task_colors.get(task_name, "#888888"),
+ label=_display(task_name) + (" · accuracy" if is_clf else ""),
+ )
+ if all_steps:
+ ax.set_xticks(sorted(all_steps))
+ ax.set_xlabel("Continual finetuning step (a new task is added at each step)")
+ ax.set_ylabel("Primary metric · R² (regression) / accuracy (classification)")
+ ax.set_title("Per-task performance across continual finetuning")
+ ncol = 1 if n_tasks <= 20 else 2
+ ax.legend(fontsize=8, ncol=ncol, loc="upper left", bbox_to_anchor=(1.01, 1.0), borderaxespad=0.0)
+ out_path = self.training_dir / "forgetting_trajectory.png"
+ fig.savefig(out_path)
+ plt.close(fig)
+ logger.info(f"Saved forgetting trajectory to {out_path}")
+
+ def _plot_inverse_scenario(
+ self,
+ scenario,
+ before_qc: np.ndarray,
+ before_reg: dict[str, np.ndarray],
+ paths: dict[str, dict[str, Any]],
+ reg_targets: dict[str, float],
+ sc_dir: Path,
+ ) -> None:
+ """Compare the 8 inverse-design configurations side-by-side on QC + each reg target.
+
+ Mirrors the demo's ``paper_inverse_comparison.py`` plot — same suptitle, panel titles
+ (via ``REG_TASK_TITLES``), x-tick labels (``INVERSE_PATH_CONFIGS[*]["label"]``), and
+ two-tone colour code (green ``#55A868`` for latent rows, blue ``#2563EB`` for composition
+ rows). We keep our boxplot style (vs the demo's bar+errorbar) to surface the full per-seed
+ distribution. Per the user override, the QC panel title is ``"Probability (QC)"``.
+ """
+ reg_names = list(reg_targets)
+ n_panels = 1 + len(reg_names)
+ fig, axes = plt.subplots(1, n_panels, figsize=(5.6 * n_panels, 5.6), squeeze=False)
+ axes = axes[0]
+
+ configs_in_order = [c for c in INVERSE_PATH_CONFIGS if c["key"] in paths]
+ path_labels = [c["label"] for c in configs_in_order]
+ # Two-tone colour code, matching the demo.
+ face_colors = ["#55A868" if c["method"] == "latent" else "#2563EB" for c in configs_in_order]
+ x_pos = list(range(len(configs_in_order)))
+
+ def _boxplot(ax, vals_per_path: list[list[float]]) -> None:
+ """Two-tone per-row boxplot. Box face matches the row's method colour at α=0.25."""
+ bp = ax.boxplot(
+ vals_per_path,
+ positions=x_pos,
+ widths=0.6,
+ patch_artist=True,
+ medianprops=dict(color="#222222", lw=1.4),
+ flierprops=dict(marker="o", mec="none", ms=3, alpha=0.55),
+ )
+ for patch, fc in zip(bp["boxes"], face_colors):
+ patch.set(facecolor=fc, alpha=0.25, edgecolor=fc)
+ for whisker, fc in zip(bp["whiskers"], [c for c in face_colors for _ in range(2)]):
+ whisker.set_color(fc)
+ for cap, fc in zip(bp["caps"], [c for c in face_colors for _ in range(2)]):
+ cap.set_color(fc)
+ for flier, fc in zip(bp["fliers"], face_colors):
+ flier.set(markerfacecolor=fc)
+
+ def _set_xticks(ax) -> None:
+ ax.set_xticks(x_pos)
+ ax.set_xticklabels(path_labels, rotation=45, ha="right", fontsize=9)
+
+ # Panel 1: QC probability. Title is the user-specified override "Probability (QC)";
+ # ylabel + target line follow the demo.
+ axq = axes[0]
+ qc_vals = [list(paths[c["key"]]["qc_after_decode"]) for c in configs_in_order]
+ _boxplot(axq, qc_vals)
+ axq.axhline(1.0, color="#C44E52", ls="--", lw=1.4, label="target = 1.0")
+ _set_xticks(axq)
+ axq.set_ylim(-0.02, 1.05)
+ axq.set_ylabel("P(quasicrystal)")
+ axq.set_title("Probability (QC)")
+ axq.legend(fontsize=9, loc="lower right")
+
+ # Remaining panels: per regression target. Title pulled from REG_TASK_TITLES with units
+ # and an arrow indicating whether the target is below (↓) or above (↑) the model's baseline.
+ for ax, (t, tgt) in zip(axes[1:], reg_targets.items()):
+ vals = [list(paths[c["key"]]["reg_after_decode"][t]) for c in configs_in_order]
+ _boxplot(ax, vals)
+ ax.axhline(tgt, color="#C44E52", ls="--", lw=1.4, label=f"target = {tgt:+.1f}")
+ _set_xticks(ax)
+ ax.set_ylabel("Predicted value")
+ ax.set_title(REG_TASK_TITLES.get(t, t))
+ ax.legend(fontsize=9, loc="best")
+
+ fig.suptitle(
+ "Inverse-design comparison: latent (ae_align_scale sweep) vs differentiable KMD (configs)",
+ y=1.00,
+ )
+ out = sc_dir / "comparison.png"
+ fig.savefig(out, dpi=150, bbox_inches="tight")
+ plt.close(fig)
+ logger.info(f"Saved inverse-design comparison plot to {out}")
+
+ # ------------------------------------------------------------------ slide-prep (plan §6)
+
+ def _counts(self) -> dict[str, int]:
+ seq = self.config.task_sequence
+ return {
+ "n_tasks": len(seq),
+ "n_reg": sum(1 for t in seq if TASK_SPECS[t]["kind"] == "reg"),
+ "n_kr": sum(1 for t in seq if TASK_SPECS[t]["kind"] == "kr"),
+ "n_clf": sum(1 for t in seq if TASK_SPECS[t]["kind"] == "clf"),
+ }
+
+ def _dataset_summary(self) -> list[tuple[str, int, int]]:
+ """(dataset display, #tasks, #unique compositions used) per source, in stable order."""
+ rows = []
+ for src in ("qc", "phonix", "superconductor", "magnetic"):
+ tasks = [t for t in self.config.task_sequence if TASK_SPECS[t]["source"] == src]
+ if not tasks:
+ continue
+ keys = set().union(*[set(self.task_frames[t].index) for t in tasks])
+ rows.append((SOURCE_DISPLAY[src], len(tasks), len(keys)))
+ return rows
+
+ def _final_target_metrics(self, records: list[dict[str, Any]]) -> dict[str, dict[str, float]]:
+ """Final-step metrics for the headline tasks the summary must report."""
+ final = records[-1]["metrics"] if records else {}
+ headline = ["formation_energy", "magnetic_moment", "tc", "klat", "material_type"]
+ return {t: final.get(t, {}) for t in headline if t in self.config.task_sequence}
+
+ # --- element-frequency heatmap ------------------------------------------
+ # The runner used to carry its own bound-method heatmap that consumed the per-path
+ # ``optimized_weights`` directly; we now share ``plot_element_frequency_heatmap`` from
+ # ``continual_rehearsal_common`` with the demo runner (same x-tick discovered-element
+ # styling, same colourbar label, same title format). The shared helper reads from the
+ # already-decoded ``decoded_composition`` strings (already in ``paths[key]``), so we
+ # don't need ``DEFAULT_ELEMENTS`` order or an ``eps`` threshold here.
+
+ # --- markdown writers (plan §6) -------------------------------------------
+
+ def _write_inverse_summary_md(self, inverse: dict[str, Any]) -> None:
+ """Compact cross-scenario summary (plan §6).
+
+ Scenarios have **heterogeneous** regression-target sets (e.g. scenario2 has 3 reg targets
+ vs 2 for the others), so a single flat table would let later rows spill past the header.
+ We keep the cross-scenario table to **QC only** (the metric every scenario shares), and
+ emit a per-scenario reg-target block underneath.
+ """
+ scenarios = inverse.get("scenarios", {}) if isinstance(inverse, dict) else {}
+ if not scenarios:
+ return
+ lines: list[str] = [
+ "# Inverse design — compact cross-scenario summary\n",
+ "Auto-generated. The headline QC table aggregates across all scenarios; per-scenario "
+ "reg-target tables follow. Full per-seed arrays in "
+ "`inverse_design///result.json`.\n",
+ ]
+
+ # Cross-scenario QC table — the one metric every scenario shares.
+ lines.append("## QC probability after decode\n")
+ lines.append("| scenario | path | QC mean | QC std |")
+ lines.append("|---|---|---:|---:|")
+ for name, data in scenarios.items():
+ paths_meta = data.get("paths", {})
+ for path_name in INVERSE_PATHS:
+ meta = paths_meta.get(path_name, {})
+ qc_m = meta.get("qc_after_mean", float("nan"))
+ qc_s = meta.get("qc_after_std", float("nan"))
+ lines.append(f"| {name} | {path_name} | {qc_m:.3f} | {qc_s:.3f} |")
+ lines.append("")
+
+ # Per-scenario regression targets (columns match that scenario's reg_targets).
+ for name, data in scenarios.items():
+ reg_targets = data.get("reg_targets", {})
+ paths_meta = data.get("paths", {})
+ lines.append(f"## {name} — regression targets (after decode)\n")
+ secondary = " · ".join(f"{_display(t)} {_arrow(v)} {v:+.1f}" for t, v in reg_targets.items())
+ lines.append(f"Targets: {secondary}\n")
+ header = ["path", *[_display(t) for t in reg_targets]]
+ lines.append("| " + " | ".join(header) + " |")
+ lines.append("|" + "|".join(["---"] * len(header)) + "|")
+ for path_name in INVERSE_PATHS:
+ meta = paths_meta.get(path_name, {})
+ row = [path_name]
+ for t in reg_targets:
+ row.append(f"{meta.get('reg_after_decode_mean', {}).get(t, float('nan')):+.2f}")
+ lines.append("| " + " | ".join(row) + " |")
+ lines.append("")
+
+ (self.inverse_root / "SUMMARY.md").write_text("\n".join(lines), encoding="utf-8")
+ logger.info(f"Saved inverse-design SUMMARY.md to {self.inverse_root / 'SUMMARY.md'}")
+
+ def _write_analysis_md(self, records: list[dict[str, Any]], inverse: dict[str, Any]) -> None:
+ """Long-form analysis (English, plan §0a). Reads as speaker-notes feedstock for SLIDE_PREP."""
+ c = self._counts()
+ intro = {r["new_task"]: r["metrics"][r["new_task"]]["primary"] for r in records}
+ final = records[-1]["metrics"] if records else {}
+ lines: list[str] = []
+ lines.append("# Analysis — continual rehearsal + inverse design\n")
+ lines.append(
+ "Long-form narrative analysis of this run. The structured slide outline lives in\n"
+ "[`SLIDE_PREP.md`](SLIDE_PREP.md); the compact cross-scenario table lives in\n"
+ "[`inverse_design/SUMMARY.md`](inverse_design/SUMMARY.md). Numbers below are\n"
+ "regenerable from the raw arrays under `training/` and `inverse_design/`.\n",
+ )
+
+ lines.append("## Run scale\n")
+ lines.append(
+ f"- **{c['n_tasks']} supervised tasks**: {c['n_reg']} regression · "
+ f"{c['n_kr']} kernel regression · {c['n_clf']} classification, plus the always-on autoencoder.\n"
+ )
+ lines.append("- Datasets (tasks · unique compositions used):")
+ for name, ntask, nkeys in self._dataset_summary():
+ lines.append(f" - {name}: {ntask} · {nkeys}")
+ lines.append("")
+
+ lines.append("## Continual learning — is there forgetting?\n")
+ drops = []
+ for task in self.config.task_sequence:
+ i = intro.get(task)
+ f_v = final.get(task, {}).get("primary")
+ if i is not None and f_v is not None and np.isfinite(i) and np.isfinite(f_v):
+ drops.append((task, i, f_v, f_v - i))
+ early = drops[: max(1, len(drops) // 2)]
+ mean_early_delta = float(np.mean([d for *_, d in early])) if early else float("nan")
+ verdict = "stable (no clear forgetting)" if mean_early_delta > -0.05 else "some forgetting"
+ lines.append(
+ f"Mean (final − at-intro) primary metric over the *earlier-trained half* is "
+ f"**{mean_early_delta:+.3f}** → **{verdict}**. The full per-step trajectory is in "
+ "`training/forgetting_trajectory.png`; per-task raw `(composition, true, pred)` for "
+ "every step is in `training/stepNN_/_pred.parquet` + `_metrics.json` "
+ "— rebuild any panel from those without retraining.\n"
+ )
+ lines.append("| task | at intro | final | Δ |")
+ lines.append("|---|---:|---:|---:|")
+ for task, i, f_v, d in drops:
+ lines.append(f"| {_display(task)} | {i:+.3f} | {f_v:+.3f} | {d:+.3f} |")
+ lines.append("")
+
+ lines.append("## Final model — headline targets (inverse-design heads)\n")
+ lines.append("| task | metric | value |")
+ lines.append("|---|---|---:|")
+ for task, m in self._final_target_metrics(records).items():
+ spec = TASK_SPECS[task]
+ metric_name = "accuracy" if spec["kind"] == "clf" else "R²"
+ val = m.get("primary", float("nan"))
+ lines.append(f"| {_display(task)} | {metric_name} | {val:+.3f} |")
+ lines.append("")
+
+ lines.append("## Inverse design — 3 scenarios × 4 paths\n")
+ lines.append(
+ "Each scenario shares the same 20 seeds (17 top-QC element-system-dedup + 3 explicit "
+ "Au-Ga-Ln). Path semantics: **latent** uses `optimize_latent(ae_align_scale=0.5)` "
+ "(PR #18 sweet spot); **composition_strict** locks the seed element support "
+ "(`seed_blend=1.0`); **composition_alloy** is the paper-headline path "
+ f"(`seed_blend≈0.95`, {len(ALLOY_PALETTE)}-element ALLOY_PALETTE — allows discovery of QC-prone "
+ "elements outside the seeds); **composition_random** ablates the seed entirely "
+ "(`n_starts=N`) to surface the model's global QC attractor — useful to motivate the "
+ "need for chemistry-constrained palettes when the global attractor falls on "
+ "unsynthesisable elements.\n"
+ )
+ scenarios = inverse.get("scenarios", {}) if isinstance(inverse, dict) else {}
+ for name, data in scenarios.items():
+ reg_targets = data.get("reg_targets", {})
+ paths_meta = data.get("paths", {})
+ paths_details = data.get("paths_details", {})
+ secondary = ", ".join(f"{_display(t)} {_arrow(v)} {v:+.1f}" for t, v in reg_targets.items())
+ lines.append(f"### {name}\n")
+ lines.append(f"- Secondary targets: {secondary}")
+ lines.append(f"- Seed mean QC (before): **{data.get('qc_before_mean', float('nan')):.3f}**")
+ lines.append("")
+ header_cells = ["path", "QC after (mean ± std)"] + [_display(t) for t in reg_targets]
+ lines.append("| " + " | ".join(header_cells) + " |")
+ lines.append("|" + "|".join(["---"] * len(header_cells)) + "|")
+ for path_name in INVERSE_PATHS:
+ meta = paths_meta.get(path_name, {})
+ qc_m = meta.get("qc_after_mean", float("nan"))
+ qc_s = meta.get("qc_after_std", float("nan"))
+ row_cells = [path_name, f"{qc_m:.3f} ± {qc_s:.3f}"]
+ for t in reg_targets:
+ row_cells.append(f"{meta.get('reg_after_decode_mean', {}).get(t, float('nan')):+.2f}")
+ lines.append("| " + " | ".join(row_cells) + " |")
+ lines.append("")
+ lines.append("One decoded example per path:")
+ for path_name in INVERSE_PATHS:
+ decoded = paths_details.get(path_name, {}).get("decoded_composition", [])
+ if decoded:
+ lines.append(f"- **{path_name}**: `{decoded[0]}`")
+ lines.append("")
+ lines.append(
+ f"Element-discovery heatmap: `inverse_design/{name}/element_frequency_heatmap.png`. "
+ f"Side-by-side path comparison: `inverse_design/{name}/comparison.png`. "
+ f"Per-path raw arrays: `inverse_design/{name}//result.json` (keys `optimized_weights` "
+ "(B, 94), `optimized_descriptor` (B, x_dim), `qc_after_decode`, `reg_after_decode`).\n"
+ )
+
+ (self.output_dir / "ANALYSIS.md").write_text("\n".join(lines), encoding="utf-8")
+ logger.info(f"Saved Markdown analysis to {self.output_dir / 'ANALYSIS.md'}")
+
+ def _write_slide_prep_md(self, records: list[dict[str, Any]], inverse: dict[str, Any]) -> None:
+ """9-slide structured handoff for the external slide author.
+
+ Mirrors the polish level of the demo's ``inverse_design_run/SLIDE_PREP.md``:
+ every section names a takeaway, slide content, speaker notes, and the canonical figure
+ (with raw-data paths so the slide author can rebuild the figure if the auto-emitted one
+ doesn't fit their layout). Numbers are computed from this run's data; interpretation
+ threads are templated stubs the slide author fills in after sanity-checking against the
+ plan §5 expected baselines (also reproduced inline).
+
+ When ``sample_per_dataset`` is set (i.e. this is a smoke / partial run rather than the
+ formal full run), a disclaimer is rendered at the top of the document; the numbers are
+ still real but the magnitudes will not match the plan §5 expected baselines.
+ """
+ cfg = self.config
+ counts = self._counts()
+ intro = {r["new_task"]: r["metrics"][r["new_task"]]["primary"] for r in records}
+ final = records[-1]["metrics"] if records else {}
+ scenarios = inverse.get("scenarios", {}) if isinstance(inverse, dict) else {}
+ seeds_meta = inverse.get("seeds", {}) if isinstance(inverse, dict) else {}
+ strategy_seeds = list(seeds_meta.get("strategy_seeds", []))
+ explicit_seeds = list(seeds_meta.get("explicit_seeds", []))
+ all_seeds = strategy_seeds + explicit_seeds
+ seed_pool: set[str] = set()
+ for s in all_seeds:
+ seed_pool |= self._element_system(s)
+
+ is_smoke = cfg.sample_per_dataset is not None or cfg.max_epochs_per_step < 20
+ run_date = _datetime.date.today().isoformat()
+
+ def _discovered(
+ path_data: dict[str, Any], threshold: float = 0.95, eps: float = 1e-3
+ ) -> list[tuple[str, float]]:
+ """Elements present in ≥ ``threshold`` fraction of a path's outputs but **0** in any seed."""
+ w = np.asarray(path_data.get("optimized_weights", []), dtype=float)
+ if w.size == 0:
+ return []
+ occ = (w > eps).mean(axis=0)
+ out: list[tuple[str, float]] = []
+ for i, frac in enumerate(occ):
+ sym = DEFAULT_ELEMENTS[i]
+ if frac >= threshold and sym not in seed_pool:
+ out.append((sym, float(frac)))
+ out.sort(key=lambda kv: -kv[1])
+ return out
+
+ def _headline(task: str) -> str:
+ spec = TASK_SPECS.get(task, {"kind": "reg"})
+ metric_name = "accuracy" if spec["kind"] == "clf" else "R²"
+ val = final.get(task, {}).get("primary", float("nan"))
+ return f"`{task}` ({metric_name} = **{val:+.3f}**)"
+
+ lines: list[str] = []
+ # ── Header ────────────────────────────────────────────────────────────────────────
+ lines.append("# Slide-prep document — handoff for the slide author (claude coworker)\n")
+ lines.append(
+ "> **What this is.** A structured outline a slide author can convert directly into deck\n"
+ "> pages. Each section corresponds to one slide / slide group and lists: (a) the\n"
+ "> takeaway, (b) what to put on the slide, (c) which file in this folder is the visual,\n"
+ "> (d) speaker-note bullets. The slide author has **full creative freedom** for layout,\n"
+ "> colours, and visual style — this document only specifies *what* to communicate, not\n"
+ "> *how*.\n"
+ )
+ lines.append(f"**Folder this document lives in:** `{self.output_dir.name}/`")
+ lines.append(f"**Run date:** {run_date}")
+ lines.append(
+ "**Data sources for every number cited:** "
+ "`training/experiment_records.json` (per-task metrics across the "
+ f"{counts['n_tasks']} training stages), "
+ "`training/metrics_table.csv` (flat per-task at-intro / final), "
+ "`training/stepNN_/_pred.parquet` (per-step raw test predictions for every "
+ "active head), `inverse_design/inverse_design.json` (full nested inverse-design dump), "
+ "and per-path `inverse_design///result.json` (raw per-seed arrays)."
+ )
+ lines.append(
+ "**Companion docs:** [`README.md`](README.md) (folder index), [`ANALYSIS.md`](ANALYSIS.md) (long-form writeup), [`inverse_design/SUMMARY.md`](inverse_design/SUMMARY.md) (compact cross-scenario table).\n"
+ )
+
+ if is_smoke:
+ lines.append(
+ "> **⚠️ Run quality note — this is a SMOKE / partial run.**\n"
+ f"> `sample_per_dataset = {cfg.sample_per_dataset}`, "
+ f"`max_epochs_per_step = {cfg.max_epochs_per_step}` "
+ "(formal full run uses `sample_per_dataset = null` and "
+ "`max_epochs_per_step = 100` + EarlyStopping). The artifact tree is structurally\n"
+ "> complete (every section below has real numbers from THIS run), but the\n"
+ "> *magnitudes* will not match the formal full-run expected baselines documented in\n"
+ "> [`docs/continual_rehearsal_full_PLAN.md`](../../docs/continual_rehearsal_full_PLAN.md) §5.\n"
+ "> The expected-baseline tables below give the slide author the magnitudes to\n"
+ "> sanity-check against before quoting numbers from this smoke run.\n"
+ )
+
+ lines.append("---\n")
+
+ # ── Slide 1 — Experimental goal ───────────────────────────────────────────────────
+ lines.append("## Slide 1 — Experimental goal: multi-property joint optimisation\n")
+ lines.append(
+ "**Takeaway.** Real materials development asks for *several properties at once* (is "
+ "the material a quasi-crystal? does it have low formation energy? does it have high "
+ "Tc / high κ_lat / high magnetic moment?). Single-property inverse-design tools don't "
+ "help. We need a joint-optimisation framework around a model that learned all those "
+ "properties together.\n"
+ )
+ lines.append("**Slide content.**")
+ lines.append('- Opening line: *"The materials-design question is rarely about a single property."*')
+ lines.append(
+ "- 2–3 illustrative property combinations to ground the audience — pulled from this run's scenarios:"
+ )
+ for name, data in scenarios.items():
+ reg_targets = data.get("reg_targets", {})
+ arrowed = ", ".join(f"{_display(t)} {_arrow(v)}" for t, v in reg_targets.items())
+ lines.append(f" - **{name}** — QC ↑ + {arrowed}")
+ lines.append('- A "wishlist → recipe" arrow showing the inverse direction: target properties → composition.\n')
+ lines.append("**Speaker notes.**")
+ lines.append(
+ "- DFT / experiment loops are prohibitively expensive for joint searches over many target dimensions."
+ )
+ lines.append(
+ "- A surrogate model that maps composition → multiple properties + supports gradient-based inverse design lets us search jointly.\n"
+ )
+ lines.append("**Visual asset.** Slide author draws; no pre-rendered figure.\n")
+ lines.append("---\n")
+
+ # ── Slide 2 — Model structure ─────────────────────────────────────────────────────
+ lines.append("## Slide 2 — Model structure + inverse-design strategies\n")
+ lines.append(
+ "**Takeaway.** A shared-encoder foundation model with multiple task heads; **two** "
+ "inverse-design paths (latent vs composition) operate on the **same trained model** "
+ "so the comparison is a fair head-to-head test.\n"
+ )
+ lines.append("**Slide content.**")
+ lines.append(
+ "- Architecture diagram: "
+ "`composition → KMD-1d descriptor x → encoder → latent h → tanh → {head_1, …, head_K}`."
+ )
+ lines.append(
+ "- Highlight the always-on autoencoder head (decoder back to descriptor) — required by the latent path."
+ )
+ lines.append("- Two strategy boxes:")
+ lines.append(
+ " - **Latent path** (`optimize_latent`): gradient-descend on `h`, decode with AE back to descriptor, "
+ "evaluate heads. Failure mode without `ae_align_scale > 0`: AE round-trip drift drops QC."
+ )
+ lines.append(
+ ' - **Composition path** (`optimize_composition`, "differentiable KMD"): gradient-descend directly '
+ "on the 94-d element-weight simplex `w`, descriptor = `w · K`. No AE in the loop."
+ )
+ lines.append("- Two user knobs, both on `[0, 1]` (bigger = more of the named thing):")
+ lines.append(
+ " - `ae_align_scale` — latent path; 0 = no AE-alignment penalty (failure-mode "
+ "baseline), 1 = strongest alignment to AE fixed set. Compared at 0 / 0.25 / 1 in this run."
+ )
+ lines.append(
+ " - `diversity_scale` — composition path; 0 = peaky few-element recipes, 1 = "
+ "multi-element recipes (default). Compared at 1.0 and 0.0 (low-diversity ablation) in this run."
+ )
+ lines.append(
+ "- Optional composition add-ons: `allowed_elements` (whitelist palette), `seed_blend` (5 % uniform mix lets non-seed elements have reachable logits).\n"
+ )
+ lines.append("**Speaker notes.**")
+ lines.append("- KMD-1d is differentiable in PR #17 → composition-space optimisation possible at all.")
+ lines.append(
+ '- Knob naming follows "bigger value = more of the named thing"; user doesn\'t need to read the docstring.'
+ )
+ lines.append(
+ "- Same model handles both paths, so latent vs composition is a fair experiment, not an architecture comparison.\n"
+ )
+ lines.append("**Visual asset.** Slide author draws; no pre-rendered figure.\n")
+ lines.append("---\n")
+
+ # ── Slide 3 — Datasets + task types ──────────────────────────────────────────────
+ lines.append("## Slide 3 — Datasets and task types\n")
+ lines.append(
+ f"**Takeaway.** The framework is trained on a heterogeneous task suite "
+ f"({counts['n_tasks']} tasks across 4 data sources × 3 task types) joined by composition formula.\n"
+ )
+ lines.append("**Slide content (suggested 3-column layout).**\n")
+ lines.append("| Task type | Count | Tasks |")
+ lines.append("|---|---:|---|")
+ for kind, label in (("reg", "Regression"), ("kr", "Kernel regression"), ("clf", "Classification")):
+ tasks = [t for t in cfg.task_sequence if TASK_SPECS[t]["kind"] == kind]
+ if tasks:
+ lines.append(f"| **{label}** | {len(tasks)} | {', '.join(f'`{t}`' for t in tasks)} |")
+ lines.append("")
+ lines.append("Datasets supplying these tasks:\n")
+ for name, ntask, nkeys in self._dataset_summary():
+ lines.append(f"- **{name}** — {ntask} tasks · {nkeys} unique compositions used")
+ lines.append("")
+ lines.append("**Speaker notes.**")
+ lines.append(
+ "- Cross-source joining: every dataset has a `composition` column; the canonical formula is the join key."
+ )
+ lines.append(
+ "- Kernel regression predicts an entire `(t, value)` series per composition — one head learns the shape vs `t` (DOS energy or temperature)."
+ )
+ lines.append(
+ '- Classification uses inverse-frequency `class_weights` so the rare QC / AC classes stay alive against ~48k "others" rows in the qc dataset.\n'
+ )
+ lines.append(
+ "**Visual asset.** Slide author renders the 3-column callout. Optional teaser: [`training/forgetting_trajectory.png`](training/forgetting_trajectory.png).\n"
+ )
+ lines.append(
+ "**Raw-data pointer.** [`training/metrics_table.csv`](training/metrics_table.csv) is the flat task / type / dataset / at-intro / final / metric table.\n"
+ )
+ lines.append("---\n")
+
+ # ── Slide 4 — Continual training ──────────────────────────────────────────────────
+ lines.append("## Slide 4 — Continual training without catastrophic forgetting\n")
+ lines.append(
+ "**Takeaway.** Tasks are introduced one at a time across "
+ f"**{counts['n_tasks']} stages**; tiered rehearsal (5 %/10 %) keeps the older heads "
+ "alive. The forgetting trajectory shows every head holds its R² / accuracy as new "
+ "tasks are added.\n"
+ )
+ lines.append(
+ "**Primary figure:** [`training/forgetting_trajectory.png`](training/forgetting_trajectory.png) "
+ "— per-step metric for every active head across all stages."
+ )
+ # Build the tail-task chain from whatever ``fixed_tail`` actually contains, instead of
+ # hard-indexing ``[0..4]`` — a smaller-scale config might legitimately have fewer tail
+ # tasks, and a future plan revision could change the count.
+ tail_chain = " → ".join(cfg.fixed_tail) if cfg.fixed_tail else "(no fixed tail)"
+ lines.append(
+ f"Annotate the fixed-tail tasks (the last {len(cfg.fixed_tail)} steps, "
+ f"`{tail_chain}`) as the focus for the inverse-design section that follows.\n"
+ )
+ lines.append("**Final-step metrics for the inverse-design heads** (the heads inverse design actually uses):\n")
+ lines.append("| Head | Type | Final-step metric |")
+ lines.append("|---|---|---:|")
+ for t in ["formation_energy", "magnetic_moment", "tc", "klat", "material_type"]:
+ if t in final:
+ spec = TASK_SPECS[t]
+ metric_name = "accuracy" if spec["kind"] == "clf" else "R²"
+ val = final.get(t, {}).get("primary", float("nan"))
+ lines.append(f"| `{t}` | {KIND_LABEL[spec['kind']]} | **{val:+.3f}** ({metric_name}) |")
+ lines.append("")
+ lines.append("**Speaker notes.**")
+ lines.append(
+ f"- Rehearsal: `replay_ratio = {cfg.replay_ratio}` for ordinary old tasks, "
+ f"`replay_ratio_high = {cfg.replay_ratio_high}` for the inverse-design tail (every step). "
+ "No layer is frozen — encoder + every active head train jointly."
+ )
+ lines.append(
+ "- Task ordering minimises rehearsal cost: 12 regression first (any order), then 7 "
+ "kernel-regression tasks **ascending by row count** (cheapest first), then the 5 fixed-"
+ "tail tasks — see plan §2 for the cost argument."
+ )
+ lines.append(
+ "- **Per-step parquets + per-step checkpoints** are available under "
+ "`training/stepNN_/` so any per-task / per-step drill-down can be made later "
+ "**without retraining**."
+ )
+ lines.append("- Raw data:")
+ lines.append(
+ " - [`training/forgetting_trajectory.png`](training/forgetting_trajectory.png) — the headline curve."
+ )
+ lines.append(
+ " - `training/stepNN_/_pred.parquet` — `(composition, true, pred)` for every active head at every step."
+ )
+ lines.append(" - `training/stepNN_/_metrics.json` — per-task metric dict at that step.")
+ lines.append(
+ " - `training/stepNN_/checkpoint.pt` — model state at that step (payload `{model, task_sequence, step, new_task, active_tasks}`)."
+ )
+ lines.append(
+ " - [`training/experiment_records.json`](training/experiment_records.json) — every step × every active head, both at-intro and running metrics."
+ )
+ lines.append(" - [`training/metrics_table.csv`](training/metrics_table.csv) — flat aggregated table.\n")
+ lines.append("---\n")
+
+ # ── Slide 5 — Inverse design scenario setup ──────────────────────────────────────
+ lines.append("## Slide 5 — Inverse design: scenario setup\n")
+ lines.append(
+ "**Takeaway.** Three scenarios share the same model, the same 20 seeds, and the "
+ "same primary objective (**P(QC) ↑**). Secondary objectives differ — picking which "
+ "scenario to feature in the talk is the slide author's narrative choice.\n"
+ )
+ lines.append("**Slide content.** A small table or three pill boxes:\n")
+ lines.append("| Scenario | Primary | Secondary objectives |")
+ lines.append("|---|---|---|")
+ for name, data in scenarios.items():
+ reg_targets = data.get("reg_targets", {})
+ secondary = ", ".join(f"{_display(t)} {_arrow(v)} {v:+.1f}" for t, v in reg_targets.items())
+ lines.append(f"| `{name}` | P(QC) ↑ (target 1.0) | {secondary} |")
+ lines.append("")
+ lines.append("**Methodology (constant across scenarios).**")
+ lines.append("- 20 seeds shared across scenarios (slide 6 details the split).")
+ lines.append(
+ f"- Optimisation budget: **{cfg.inverse_steps} Adam steps**, **`lr = {cfg.inverse_lr}`**, "
+ f"**`class_target_weight = {cfg.inverse_class_weight}`** (so QC dominates the loss)."
+ )
+ lines.append(
+ "- All metrics evaluated **after** decoding the optimised descriptor back to a real composition (round-trip)."
+ )
+ lines.append("- 8 configurations per scenario (3 latent α + 5 composition) — see slide 6.\n")
+ lines.append("**Speaker notes.**")
+ lines.append(
+ '- All three scenarios are first-class — the runner does not pick a "headline" scenario. Slide author chooses which to feature based on the talk\'s narrative.'
+ )
+ lines.append("- Plan §5 lists the rationale for each scenario.\n")
+ lines.append('**Visual asset.** Slide author can draw a small "target dial" visual. No pre-rendered figure.\n')
+ lines.append(
+ "**Raw-data pointer.** [`inverse_design/seeds.json`](inverse_design/seeds.json) (seeds), `inverse_design//targets.json` (objective definitions per scenario).\n"
+ )
+ lines.append("---\n")
+
+ # ── Slide 6 — Seeds + palette + config table ─────────────────────────────────────
+ lines.append("## Slide 6 — Initial seeds, the element palette, and the 8 configurations\n")
+ lines.append(
+ f"**Takeaway.** Three ingredients shape the search: (a) **{len(all_seeds)} seeds** "
+ f"for the optimiser to start from, (b) the **{len(ALLOY_PALETTE)}-element `ALLOY_PALETTE`** the "
+ "constrained composition paths are allowed to use, (c) **8 configurations** isolating "
+ "ae_align_scale / seed_blend / palette / diversity / random-init effects.\n"
+ )
+ lines.append("### Seeds\n")
+ lines.append(
+ f"**N = {len(all_seeds)}** = {len(strategy_seeds)} top-QC dedup + {len(explicit_seeds)} explicit-append. "
+ "Element-system dedup keeps the best representative per element set so the seed list spans "
+ "**different alloy families** rather than ratio variants of a few.\n"
+ )
+ lines.append(
+ f"- **{len(strategy_seeds)} top-QC dedup seeds** (from the training-set material_type frame, picked by predicted QC probability):"
+ )
+ for s in strategy_seeds[:8]:
+ lines.append(f" - `{s}`")
+ if len(strategy_seeds) > 8:
+ lines.append(f" - … ({len(strategy_seeds) - 8} more in `inverse_design/seeds.json`)")
+ lines.append(
+ f"- **{len(explicit_seeds)} explicit-append seeds** (forced regardless of QC score — known Au–Ga–RE i-QC formers):"
+ )
+ for s in explicit_seeds:
+ lines.append(f" - `{s}`")
+ lines.append("")
+
+ lines.append(
+ f"### `ALLOY_PALETTE` ({len(ALLOY_PALETTE)} elements, slide author renders periodic-table highlight)\n"
+ )
+ lines.append(
+ "Range design: covers classic i-QC / d-QC formers + easy 4th/5th-period TMs + accessible lanthanides + Au (so Au–Ga–Ln seeds are reachable). Pm / Tc and Pu-class radioactives are excluded; Tm / Lu excluded as rare and expensive.\n"
+ )
+ lines.append("- **Light alkaline earth:** Mg, Ca")
+ lines.append("- **Group 13:** B, Al, Ga, In, Tl")
+ lines.append("- **Group 14:** Si, Ge")
+ lines.append("- **4th-period TM (10):** Sc Ti V Cr Mn Fe Co Ni Cu Zn")
+ lines.append("- **5th-period TM (9, Tc excluded as radioactive):** Y Zr Nb Mo Ru Rh Pd Ag Cd")
+ lines.append("- **6th-period noble (needed for Au–Ga–RE seeds):** Au")
+ lines.append("- **Accessible lanthanides (12, Pm/Tm/Lu excluded):** La Ce Pr Nd Sm Eu Gd Tb Dy Ho Er Yb\n")
+
+ lines.append("### The 8 configurations — what each isolates\n")
+ lines.append("3 latent points (along `ae_align_scale`) + 5 composition configs:\n")
+ lines.append("| Config (x-axis label in `comparison.png`) | Knobs | What it tests |")
+ lines.append("|---|---|---|")
+ lines.append(
+ "| `latent α=0` | `ae_align_scale = 0` | AE-alignment off → failure mode in PR #18's paper-baseline run (QC collapses). With `dos_density` in the training mix the latent geometry may be more robust — check this run's number. |"
+ )
+ lines.append("| `latent α=0.25` | `ae_align_scale = 0.25` | Low alignment — intermediate point. |")
+ lines.append(
+ "| `latent α=1` | `ae_align_scale = 1.0` | Max alignment — strongest cycle-consistency constraint. |"
+ )
+ lines.append(
+ "| `comp (seed)` | `seed_blend = 1.0`, all elements allowed | Strict-seed baseline. Optimiser can only rebalance the seed's existing elements — no new element can enter the support set. |"
+ )
+ lines.append(
+ "| `comp (seed, 5% all)` | `seed_blend = 0.95`, all allowed | Adds 5 % uniform mass over all 94 elements so non-seed elements have reachable logits. Optimiser *can* introduce new elements but otherwise unconstrained. |"
+ )
+ lines.append(
+ f"| `comp (seed, 5% all, element list)` | (above) + `allowed_elements = ALLOY_PALETTE` | Restricts the support set to the {len(ALLOY_PALETTE)} feasible alloy elements. **Practical materials-design mode.** |"
+ )
+ lines.append(
+ "| `comp (seed, 5% all, element list, low diversity)` | (above) + `diversity_scale = 0` | Adds max entropy penalty → forces peaky few-element recipes. Tests whether peaky recipes still satisfy the targets. |"
+ )
+ lines.append(
+ '| `comp (random)` | `initial_weights = None`, all allowed | No seed, no palette. Pure "let the optimiser explore" — the no-bias control. |'
+ )
+ lines.append("")
+ lines.append("**Speaker notes.**")
+ lines.append("- Each row of `inverse_design//comparison.png` x-axis maps to one of these configs.")
+ lines.append('- Labels read as "config A, then add knob B, then add knob C" — each comma = a knob change.')
+ lines.append(
+ '- "low diversity" = `diversity_scale = 0`, the most penalised end of the diversity knob → fewest elements per output.\n'
+ )
+ lines.append(
+ f"**Visual asset.** Slide author renders the periodic-table highlight from the {len(ALLOY_PALETTE)}-element list above. No pre-rendered palette figure.\n"
+ )
+ lines.append(
+ "**Raw-data pointer.** [`inverse_design/seeds.json`](inverse_design/seeds.json) for the seed list; palette literal in [`samples/continual_rehearsal_full_config.toml`](../../samples/continual_rehearsal_full_config.toml).\n"
+ )
+ lines.append("---\n")
+
+ # ── Slide 7 — Results & discussion (the central section) ─────────────────────────
+ lines.append("## Slide 7 — Results & discussion\n")
+ lines.append(
+ "**Takeaway** (templated stub — fill in based on the per-scenario tables below + "
+ "discovered-elements list). Typical claims the slide author chooses among:\n"
+ )
+ lines.append(
+ "- **Headline claim.** `comp (seed, 5% all, element list)` is the practical winner on the scenario you pick to feature — tight, physically credible alloy recipes; element discovery (specific elements present in 100 % of outputs but 0 % of seeds)."
+ )
+ lines.append(
+ "- **Constraints-matter claim.** `comp (random)` lands the optimiser on the model's unconstrained global QC attractor — often physically implausible elements; demonstrates that the palette + seed are doing real work, not just regularising."
+ )
+ lines.append(
+ "- **Latent-knob claim.** The `ae_align_scale` sweep on `latent α=0 / 0.25 / 1` traces the AE-alignment effect on the three target axes."
+ )
+ lines.append("")
+ lines.append(
+ "Pick the claim(s) the actual numbers support; the per-scenario tables below carry every figure you need.\n"
+ )
+
+ lines.append("**Primary figures (per scenario).**")
+ for name in scenarios:
+ lines.append(
+ f"- [`inverse_design/{name}/comparison.png`](inverse_design/{name}/comparison.png) — 8-config boxplot across P(QC) + each reg target."
+ )
+ lines.append("")
+ lines.append("**Supporting figures (per scenario).**")
+ for name in scenarios:
+ lines.append(
+ f'- [`inverse_design/{name}/element_frequency_heatmap.png`](inverse_design/{name}/element_frequency_heatmap.png) — path × top-25 elements; **bold orange** x-tick labels = elements NOT in any seed → "discovered".'
+ )
+ lines.append("")
+
+ # Per-scenario per-config table + discovered elements + open questions
+ for name, data in scenarios.items():
+ reg_targets = data.get("reg_targets", {})
+ paths_meta = data.get("paths", {})
+ paths_details = data.get("paths_details", {})
+
+ lines.append(f"### Scenario: `{name}`\n")
+ secondary = ", ".join(f"{_display(t)} {_arrow(v)} {v:+.1f}" for t, v in reg_targets.items())
+ lines.append(
+ f"Targets: **P(QC) ↑ (target 1.0)**, {secondary}. "
+ f"Seed mean QC (before): **{data.get('qc_before_mean', float('nan')):.3f}**.\n"
+ )
+
+ # Per-config table (one row per config, columns: QC mean ± std, each reg target mean)
+ header = ["config", "QC after (mean ± std)"] + [REG_TASK_TITLES.get(t, t) for t in reg_targets]
+ lines.append("| " + " | ".join(header) + " |")
+ lines.append("|" + "|".join(["---"] + ["---:"] * (len(header) - 1)) + "|")
+ for path_cfg in INVERSE_PATH_CONFIGS:
+ key = path_cfg["key"]
+ label = path_cfg["label"]
+ meta = paths_meta.get(key, {})
+ qc_m = meta.get("qc_after_mean", float("nan"))
+ qc_s = meta.get("qc_after_std", float("nan"))
+ row = [f"`{label}`", f"{qc_m:.3f} ± {qc_s:.3f}"]
+ for t in reg_targets:
+ row.append(f"{meta.get('reg_after_decode_mean', {}).get(t, float('nan')):+.2f}")
+ lines.append("| " + " | ".join(row) + " |")
+ lines.append("")
+
+ # Discovered elements per config (≥ 95 % occupancy, 0 in seeds)
+ lines.append(
+ "**Element discovery** (occurrence ≥ 95 % in this config's 20 outputs, **and** 0 occurrence in any seed):"
+ )
+ any_discovered = False
+ for path_cfg in INVERSE_PATH_CONFIGS:
+ key = path_cfg["key"]
+ disc = _discovered(paths_details.get(key, {}))
+ if disc:
+ any_discovered = True
+ payload = ", ".join(f"**{sym}** ({int(round(frac * 100))}%)" for sym, frac in disc)
+ lines.append(f"- `{path_cfg['label']}` → {payload}")
+ if not any_discovered:
+ lines.append(
+ "- *(none in this run — no element passes the ≥95 % occurrence + 0-in-seeds bar. "
+ "Either the optimiser is just rebalancing seed elements, or the run is too early "
+ "to surface discoveries. Smoke runs typically have none; the formal full run "
+ "is expected to surface discovered elements in `comp (seed, 5% all, element list)`.)*"
+ )
+ lines.append("")
+
+ # Decoded example per config
+ lines.append("**One decoded example per config** (highest-QC seed of that config):")
+ for path_cfg in INVERSE_PATH_CONFIGS:
+ key = path_cfg["key"]
+ decoded = paths_details.get(key, {}).get("decoded_composition", [])
+ if decoded:
+ lines.append(f"- `{path_cfg['label']}` → `{decoded[0]}`")
+ lines.append("")
+
+ # Three discussion-thread stubs (templated for the slide author)
+ lines.append("### Discussion threads (templated stubs — verify against numbers above)\n")
+ lines.append(
+ "1. **Element discovery is the headline.** *Fill in:* in `comp (seed, 5% all, element list)`, "
+ "which element(s) appear in ≥95 % of outputs and 0 % of seeds? (See the discovery list "
+ 'per scenario above.) If non-empty, this is the central claim — "the model found '
+ "something we didn't tell it about\".\n"
+ )
+ lines.append(
+ "2. **Constraints matter.** *Fill in:* `comp (random)` QC vs `comp (seed, 5% all, element list)` QC. "
+ "If random-init lands far from the constrained QC, the seed + palette are doing real "
+ "work (not regularising). If random-init still finds high QC but with implausible "
+ "elements (Pu / F / Mn-rich), the *physicality* of the recipe is the constraint payoff, "
+ "not raw QC.\n"
+ )
+ lines.append(
+ "3. **Latent path α-knob role.** *Fill in:* compare `latent α=0` vs `latent α=1` QC + reg "
+ "targets. In PR #18's pre-`dos_density` baseline α=0 was a catastrophe (QC ~ 0.39). "
+ "With `dos_density` in this run's training mix, check whether α=0 is still a "
+ "catastrophe (claim the failure-mode story), or whether the latent geometry is now "
+ 'robust to α=0 (claim the α-knob has shifted from "rescue QC" to "trade QC bias '
+ 'against secondary-target reach").\n'
+ )
+
+ lines.append("### Plan §5 expected baselines (for sanity-check; slide author must verify)\n")
+ lines.append(
+ "Plan §5 reports the following PR #18 + 41-elem-smoke baselines for a single "
+ "scenario (QC↑ / FE↓ / klat↑, 16 seeds). The formal full run should land in similar "
+ "magnitudes; smoke / partial runs will not.\n"
+ )
+ lines.append("| Config | QC after | FE after | klat after | pairwise L1 | mean #elems |")
+ lines.append("|---|---:|---:|---:|---:|---:|")
+ lines.append("| latent α=0 (failure) | 0.386 ± 0.315 | +2.46 ± 0.59 | −0.44 ± 0.27 | 1.07 | 5.2 |")
+ lines.append("| latent α=0.5 (sweet) | **0.960 ± 0.027** | +0.92 ± 1.16 | +1.07 ± 0.31 | 0.82 | 3.4 |")
+ lines.append("| latent α=1.0 (max) | 0.951 ± 0.027 | +0.40 ± 1.04 | +1.20 ± 0.35 | 1.06 | 3.6 |")
+ lines.append("| C-strict | 0.887 ± 0.053 | +1.27 ± 0.24 | +0.76 ± 0.67 | 1.42 | 2.6 |")
+ lines.append("| **C-alloy (12 elem)** | 0.870 ± 0.012 | +0.84 ± 0.03 | **+1.81 ± 0.07** | 0.17 | 5.6 |")
+ lines.append("| **C-alloy (41 elem)** | 0.842 ± 0.018 | +0.68 ± 0.07 | **+1.84 ± 0.06** | 1.02 | 6.0 |")
+ lines.append("| C-rand | 0.793 ± 0.005 | −0.78 ± 0.03 | +1.77 ± 0.02 | 0.10 | 6.0 |")
+ lines.append("")
+
+ lines.append("### Open questions to flag\n")
+ lines.append(
+ "- **`comp (seed)` variance.** If `comp (seed)` σ is large (≥0.2 in PR #18 paper run), "
+ "per-seed audit: which seeds fail? Drill down via `inverse_design//comp_seed/result.json` "
+ "(`qc_after_decode` per seed; `seeds` list in same file)."
+ )
+ lines.append(
+ "- **Au–Ga–Ln seeds.** The 3 explicit Au–Ga–Ln seeds are known QC candidates. Their "
+ "*per-seed* QC in `comp (seed)` should be high — if not, that's itself a notable finding."
+ )
+ lines.append(
+ "- **Scenario coverage.** This run has 3 scenarios; the deck may not need all three. "
+ "Pick 1–2 the audience cares about and footnote the others.\n"
+ )
+ lines.append("---\n")
+
+ # ── Slide 8 — Summary ────────────────────────────────────────────────────────────
+ lines.append("## Slide 8 — Summary\n")
+ lines.append("**Takeaway** (three bullets for the slide; numbers fill in from above).\n")
+ lines.append(
+ f"1. A shared-encoder foundation model trained continually across "
+ f"**{counts['n_tasks']} heterogeneous tasks** with tiered rehearsal — no catastrophic "
+ "forgetting on the inverse-design heads (slide 4 numbers)."
+ )
+ lines.append(
+ "2. Two inverse-design paths on the same model, both exposed as user-friendly `[0, 1]` "
+ "knobs (`ae_align_scale`, `diversity_scale`). Eight configurations per scenario "
+ "isolate every effect (slide 6 table)."
+ )
+ lines.append(
+ "3. On the scenario(s) you feature: the constrained composition path delivers "
+ "physically credible recipes; element-discovery signal surfaces "
+ "(see scenario-specific table in slide 7)."
+ )
+ lines.append("")
+ lines.append("**Failure modes (also first-class — claim them honestly).**")
+ lines.append("- AE-roundtrip drift without `ae_align_scale > 0` (latent path).")
+ lines.append("- Seed-init support-set lock without `seed_blend < 1` (composition path with strict seed).")
+ lines.append("- Non-physical attractors without `allowed_elements` (composition random init).\n")
+ lines.append(
+ "**Slide content.** Three takeaway bullets + a thumbnail of one of the "
+ "`inverse_design//comparison.png` files (slide author picks).\n"
+ )
+ lines.append("---\n")
+
+ # ── Slide 9 — Future work ────────────────────────────────────────────────────────
+ lines.append("## Slide 9 — Future work\n")
+ lines.append(
+ "**Takeaway.** The current framework is the foundation; the next step is to wrap it "
+ "in an agent system, then later wire into the broader AI4S agent ecosystem.\n"
+ )
+ lines.append("### Beat 6 — agent-based inverse-design workbench\n")
+ lines.append('- Natural-language goals from the user ("I want a low-density QC formed from common metals").')
+ lines.append(
+ '- An AI agent decomposes the goal + applies domain knowledge ("QC + common metals → use `allowed_elements = ALLOY_PALETTE − lanthanides`").'
+ )
+ lines.append(
+ "- Agent automatically sets optimiser knobs (`ae_align_scale`, `diversity_scale`, seed strategy, palette, target dict)."
+ )
+ lines.append("- Runs `optimize_*`, decodes outputs, generates a visualisation + PDF report.\n")
+ lines.append("### Beat 7 — wider AI4S agent ecosystem\n")
+ lines.append(
+ "- Foundation model becomes the fast predictor + candidate generator in the centre of a larger stack."
+ )
+ lines.append(
+ "- Other agents wrap DFT / MD simulators (slow but accurate validation), automated synthesis platforms (closed-loop experimental feedback)."
+ )
+ lines.append(
+ "- Pipeline: user request → foundation-model candidates → DFT validation → robotic synthesis → results loop back to retrain the foundation model.\n"
+ )
+ lines.append(
+ "**Slide content.** One bullet per beat, plus a concentric-circles sketch (foundation model at the centre, agent wrappers around it, the user / world outside).\n"
+ )
+ lines.append("---\n")
+
+ # ── Quick reference ──────────────────────────────────────────────────────────────
+ lines.append("## Quick reference — files in this run folder\n")
+ lines.append("| File | Used by which slide |")
+ lines.append("|---|---|")
+ lines.append(
+ "| [`training/forgetting_trajectory.png`](training/forgetting_trajectory.png) | Slide 4 (primary) |"
+ )
+ lines.append("| `training/stepNN_/*.png` | Slide 4 appendix (drill-down per task) |")
+ lines.append("| `training/stepNN_/*_pred.parquet` | Replot any per-step figure without retraining |")
+ lines.append("| `training/stepNN_/*_metrics.json` | Per-task metric dict at that step |")
+ lines.append("| `training/stepNN_/checkpoint.pt` | Restore the model at any intermediate stage |")
+ lines.append(
+ "| [`training/experiment_records.json`](training/experiment_records.json) | Full records (step × head, at-intro + running) |"
+ )
+ lines.append(
+ "| [`training/metrics_table.csv`](training/metrics_table.csv) | Flat task / type / dataset / at-intro / final table |"
+ )
+ lines.append(
+ "| [`training/final_model.pt`](training/final_model.pt) | Final model state_dict + task_sequence |"
+ )
+ lines.append(
+ "| `inverse_design//comparison.png` | Slide 7 (primary, per scenario), Slide 8 (thumbnail) |"
+ )
+ lines.append(
+ "| `inverse_design//element_frequency_heatmap.png` | Slide 7 (supporting, per scenario) |"
+ )
+ lines.append(
+ "| `inverse_design///result.json` | Per-config raw arrays — `optimized_weights` (20, 94), `optimized_descriptor` (20, x_dim), per-seed predictions |"
+ )
+ lines.append(
+ "| `inverse_design//summary.json` | Per-scenario aggregated stats (per-config means + stds) |"
+ )
+ lines.append("| `inverse_design//targets.json` | Primary + secondary objective definitions |")
+ lines.append(
+ "| [`inverse_design/seeds.json`](inverse_design/seeds.json) | Slide 6 (seed names + strategy/explicit split) |"
+ )
+ lines.append(
+ "| [`inverse_design/SUMMARY.md`](inverse_design/SUMMARY.md) | Cross-scenario compact summary table |"
+ )
+ lines.append(
+ "| [`inverse_design/inverse_design.json`](inverse_design/inverse_design.json) | Full nested inverse-design dump (every scenario × every path) |"
+ )
+ lines.append("| [`ANALYSIS.md`](ANALYSIS.md) | Speaker-note source (long-form analysis) |")
+ lines.append("| [`README.md`](README.md) | Run-folder reference / directory map |")
+ lines.append("")
+
+ # ── Slide-author freedom ──────────────────────────────────────────────────────────
+ lines.append("## What the slide author has freedom over (and what they don't)\n")
+ lines.append("**Free:**")
+ lines.append("- Visual style (theme, colours, fonts, slide template).")
+ lines.append("- Layout and slide breaks.")
+ lines.append('- Diagrams (slides 1, 2, 3, 5, 6, 9 explicitly say "slide author draws this").')
+ lines.append("- Order: this document is in narrative order, but the slide author may reshuffle.")
+ lines.append("- Which scenario(s) to feature: the runner does not pick a headline scenario.")
+ lines.append(
+ "- Which discussion thread(s) in slide 7 to make the central claim — pick the one(s) the numbers actually support.\n"
+ )
+ lines.append("**Not free (these are the claims):**")
+ lines.append(
+ "- All numbers in the per-scenario tables of slide 7 — quoted from `inverse_design///result.json`."
+ )
+ lines.append(
+ "- The element-discovery list — computed as occurrence ≥ 95 % in a config's outputs AND 0 in any seed (the bar must be cleared to claim discovery)."
+ )
+ lines.append("- The two-knob naming (`ae_align_scale`, `diversity_scale`) — these are the public API.")
+ lines.append("- The 8 configuration names (x-axis labels of every `comparison.png`).")
+ lines.append("- The 3 scenario names + target dicts (slide 5 table is canonical).\n")
+ lines.append("---\n")
+
+ # ── Raw-data cheat sheet ──────────────────────────────────────────────────────────
+ lines.append("## Where the raw data lives — full cheat-sheet\n")
+ lines.append(
+ "Every figure above is fully reproducible from the raw arrays — **no need to "
+ "retrain or rerun the optimisation** to change a plot's style / axis / colour scheme.\n"
+ )
+ lines.append(
+ "- `training/stepNN_/_pred.parquet` — `(composition, true, pred)` (KR has `t` too). Plot any per-task parity / confusion / KR-sequence at any stage."
+ )
+ lines.append("- `training/stepNN_/_metrics.json` — per-task metric dict at that step.")
+ lines.append(
+ "- `training/stepNN_/checkpoint.pt` — model state at that step (payload: `{model, task_sequence, step, new_task, active_tasks}`)."
+ )
+ lines.append(
+ "- `training/experiment_records.json` — every step × every active head metric (at-intro and running)."
+ )
+ lines.append("- `training/metrics_table.csv` — flat task/type/dataset/at-intro/final/metric.")
+ lines.append(
+ "- `training/final_model.pt` — final model state_dict + task_sequence (consumed by `--inverse-only` / `paper_inverse_comparison.py` / `finetune_inverse_heads.py`)."
+ )
+ lines.append("- `training/forgetting_trajectory.png` — per-step × per-task primary-metric curves.")
+ lines.append("- `inverse_design/seeds.json` — seeds in two segments (`strategy_seeds`, `explicit_seeds`).")
+ lines.append("- `inverse_design//targets.json` — primary + secondary target definitions.")
+ lines.append(
+ "- `inverse_design///result.json` — per-config full record: `optimized_weights` `(B, 94)`, `optimized_descriptor` `(B, x_dim)`, `qc_after_decode`, `reg_before` / `reg_achieved_latent` / `reg_after_decode`, `decoded_composition`."
+ )
+ lines.append("- `inverse_design//summary.json` — per-scenario aggregated stats.")
+ lines.append("- `inverse_design//comparison.png` — 8-config boxplot comparison.")
+ lines.append(
+ "- `inverse_design//element_frequency_heatmap.png` — config × element occurrence heatmap; discovered-element x-tick labels are bold + orange."
+ )
+ lines.append("- `inverse_design/SUMMARY.md` — compact cross-scenario table.\n")
+ lines.append(
+ "Element order in `optimized_weights`: "
+ "`foundation_model.utils.kmd_plus.DEFAULT_ELEMENTS` (94 symbols). "
+ "Composition-formula round-trip: `KMD.inverse(descriptor)` (or directly use `optimized_weights` which already lives on the simplex).\n"
+ )
+
+ (self.output_dir / "SLIDE_PREP.md").write_text("\n".join(lines), encoding="utf-8")
+ logger.info(f"Saved SLIDE_PREP.md to {self.output_dir / 'SLIDE_PREP.md'}")
+
+ def _write_readme(self, records: list[dict[str, Any]], inverse: dict[str, Any]) -> None:
+ """Top-level run index — what's in this directory and where to start reading."""
+ c = self._counts()
+ scenarios = inverse.get("scenarios", {}) if isinstance(inverse, dict) else {}
+ lines = [
+ "# Continual rehearsal + inverse-design — run directory",
+ "",
+ f"{c['n_tasks']} supervised tasks ({c['n_reg']} reg · {c['n_kr']} kr · "
+ f"{c['n_clf']} clf) + autoencoder · 3 inverse-design scenarios × 4 paths.",
+ "",
+ "## Start here",
+ "- [`SLIDE_PREP.md`](SLIDE_PREP.md) — 9-section slide outline for the external slide author.",
+ "- [`ANALYSIS.md`](ANALYSIS.md) — long-form narrative analysis (speaker-note material).",
+ "- [`inverse_design/SUMMARY.md`](inverse_design/SUMMARY.md) — compact cross-scenario table.",
+ "- `inverse_design//comparison.png` + `element_frequency_heatmap.png` — per-scenario figures (three scenarios, all first-class — no demo-style single-scenario headline).",
+ "",
+ "## Directory map",
+ "```",
+ "training/",
+ " stepNN_/ # one dir per training step",
+ " _pred.parquet # (composition, true, pred) for every active head",
+ " _metrics.json # per-task metric dict (R²/acc/MAE/…)",
+ " _parity.png | _confusion.png | _sequences.png # newest-head plot only",
+ " checkpoint.pt # model state at that step",
+ " forgetting_trajectory.png # per-step × per-task primary metric",
+ " experiment_records.json # full records (every step × every head)",
+ " metrics_table.csv # flat per-task at-intro / final table",
+ " final_model.pt # final model state_dict + task_sequence",
+ " final_model_taskconfigs.json # task-config metadata for rebuilding the model",
+ "inverse_design/",
+ " seeds.json # 20 seeds (17 top-QC dedup + 3 Au-Ga-Ln)",
+ " inverse_design.json # full nested result dump",
+ " SUMMARY.md # cross-scenario compact table",
+ " /",
+ " targets.json # primary + secondary objectives",
+ " summary.json # per-path mean / std headline stats",
+ " comparison.png # 8-path boxplot (QC + each reg target)",
+ " element_frequency_heatmap.png # path × top-25 elements (discovered = bold orange)",
+ " /result.json # raw per-seed arrays, optimized_weights, …",
+ "SLIDE_PREP.md # slide outline + raw-data pointers",
+ "ANALYSIS.md # long-form analysis",
+ "README.md # this file",
+ "```",
+ "",
+ "## Scenarios",
+ ]
+ for name, data in scenarios.items():
+ reg_targets = data.get("reg_targets", {})
+ secondary = ", ".join(f"{_display(t)} {_arrow(v)} {v:+.1f}" for t, v in reg_targets.items())
+ lines.append(f"- **{name}** — primary: QC ↑; secondary: {secondary}")
+ lines.append("")
+ (self.output_dir / "README.md").write_text("\n".join(lines), encoding="utf-8")
+ logger.info(f"Saved README.md to {self.output_dir / 'README.md'}")
+
+
+# --- CLI ---------------------------------------------------------------------
+
+
+def _load_toml(path: Path) -> dict[str, Any]:
+ try:
+ import tomllib # type: ignore[attr-defined]
+ except ModuleNotFoundError: # pragma: no cover
+ import tomli as tomllib # type: ignore
+ return tomllib.loads(Path(path).read_text(encoding="utf-8"))
+
+
+def _parse_args(argv: list[str] | None = None) -> tuple[ContinualRehearsalFullConfig, argparse.Namespace]:
+ parser = argparse.ArgumentParser(description="Continual rehearsal + inverse-design — full run.")
+ parser.add_argument("--config-file", type=Path, default=None)
+ parser.add_argument("--output-dir", type=Path, default=None)
+ parser.add_argument("--sample-per-dataset", type=int, default=None)
+ parser.add_argument("--max-epochs-per-step", type=int, default=None)
+ parser.add_argument("--accelerator", type=str, default=None)
+ parser.add_argument(
+ "--inverse-only",
+ type=Path,
+ default=None,
+ metavar="CKPT",
+ help="Skip training; load a final_model.pt checkpoint and rerun only the inverse-design stage.",
+ )
+ # Trajectory plotting flags — mirror paper_inverse_comparison's CLI so the user can switch
+ # animation format / opt out of per-step recording without code changes.
+ parser.add_argument(
+ "--record-trajectory",
+ action=argparse.BooleanOptionalAction,
+ default=True,
+ help="Record per-step optimisation trajectories and emit trajectory plots / animations "
+ "per scenario × path. ``--no-record-trajectory`` skips both (saves ~10 %% on the latent "
+ "path and the animation rendering cost).",
+ )
+ parser.add_argument(
+ "--per-seed-trajectories",
+ action="store_true",
+ help="Additionally emit one plot + animation per (path × seed) under "
+ "``trajectories_per_seed/`` (heavy: 20× more figures). Off by default.",
+ )
+ parser.add_argument(
+ "--animation-formats",
+ nargs="+",
+ choices=["gif", "html", "svg", "none"],
+ default=["gif"],
+ help="Trajectory animation formats. ``none`` disables animations (the static plot is "
+ "still written). Default: gif.",
+ )
+ args = parser.parse_args(argv)
+
+ data = _load_toml(args.config_file) if args.config_file else {}
+ for key in ("output_dir", "sample_per_dataset", "max_epochs_per_step", "accelerator"):
+ val = getattr(args, key)
+ if val is not None:
+ data[key] = val
+
+ field_names = set(ContinualRehearsalFullConfig.__dataclass_fields__)
+ path_fields = {
+ "qc_data_path",
+ "qc_preprocessing_path",
+ "superconductor_path",
+ "magnetic_path",
+ "phonix_path",
+ "output_dir",
+ }
+ kwargs: dict[str, Any] = {}
+ for key, value in data.items():
+ if key not in field_names:
+ logger.warning(f"Ignoring unknown config key '{key}'.")
+ continue
+ if key == "inverse_scenarios":
+ kwargs[key] = [InverseScenario(**sc) if isinstance(sc, dict) else sc for sc in value]
+ elif key in path_fields:
+ # Empty string means "unset" (e.g. qc_preprocessing_path with no matching pkl).
+ kwargs[key] = Path(value) if value not in (None, "") else None
+ else:
+ kwargs[key] = value
+ return ContinualRehearsalFullConfig(**kwargs), args
+
+
+def main(argv: list[str] | None = None) -> None:
+ config, args = _parse_args(argv)
+ runner = ContinualRehearsalFullRunner(config)
+ traj_kwargs: dict[str, Any] = {
+ "record_trajectory": args.record_trajectory,
+ "per_seed_trajectories": args.per_seed_trajectories,
+ "animation_formats": tuple(args.animation_formats),
+ }
+ if args.inverse_only is not None:
+ runner.run_inverse_only(args.inverse_only, **traj_kwargs)
+ else:
+ runner.run(**traj_kwargs)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/foundation_model/scripts/continual_rehearsal_full_test.py b/src/foundation_model/scripts/continual_rehearsal_full_test.py
new file mode 100644
index 0000000..769cdc5
--- /dev/null
+++ b/src/foundation_model/scripts/continual_rehearsal_full_test.py
@@ -0,0 +1,268 @@
+"""Tests for the full continual-rehearsal + inverse-design runner (config/catalogue/CLI logic).
+
+Training and data loading are exercised by the smoke run, not here; these tests cover the pure
+logic that is cheap and worth guarding: the task catalogue, config validation, and TOML/CLI parsing.
+"""
+
+from __future__ import annotations
+
+import textwrap
+from pathlib import Path
+
+import pytest
+
+from foundation_model.scripts.continual_rehearsal_full import (
+ ALLOY_PALETTE,
+ DEFAULT_FIXED_TAIL,
+ DEFAULT_SEQUENCE,
+ INVERSE_PATH_CONFIGS,
+ INVERSE_PATHS,
+ REG_TASK_TITLES,
+ TASK_SPECS,
+ ContinualRehearsalFullConfig,
+ ContinualRehearsalFullRunner,
+ InverseScenario,
+ _arrow,
+ _display,
+ _parse_args,
+ _title,
+)
+
+
+def test_default_sequence_is_24_tasks_by_type():
+ kinds = [TASK_SPECS[t]["kind"] for t in DEFAULT_SEQUENCE]
+ assert len(DEFAULT_SEQUENCE) == 24
+ assert kinds.count("reg") == 16
+ assert kinds.count("kr") == 7
+ assert kinds.count("clf") == 1
+
+
+def test_catalogue_consistency():
+ # Every sequenced task is known; kernel tasks declare a t_column; clf declares num_classes.
+ for task in DEFAULT_SEQUENCE:
+ spec = TASK_SPECS[task]
+ assert spec["kind"] in {"reg", "kr", "clf"}
+ if spec["kind"] == "kr":
+ assert "t_column" in spec
+ if spec["kind"] == "clf":
+ assert "num_classes" in spec
+ # The fixed tail is the last segment of the default sequence.
+ assert DEFAULT_SEQUENCE[-len(DEFAULT_FIXED_TAIL) :] == DEFAULT_FIXED_TAIL
+ # material_type is last so the QC classifier is freshest for inverse design.
+ assert DEFAULT_SEQUENCE[-1] == "material_type"
+
+
+def test_inverse_path_configs_match_demo():
+ # 8 configurations — 3 latent ae_align_scale points + 5 composition configs — mirroring the
+ # demo's paper_inverse_comparison.py so the figures read the same across runners.
+ assert len(INVERSE_PATH_CONFIGS) == 8
+ methods = [c["method"] for c in INVERSE_PATH_CONFIGS]
+ assert methods.count("latent") == 3
+ assert methods.count("composition") == 5
+ latent_alphas = [c["ae_align_scale"] for c in INVERSE_PATH_CONFIGS if c["method"] == "latent"]
+ assert latent_alphas == [0.0, 0.25, 1.0]
+ # The key list is a flat str list of unique stable identifiers used as result subdir names.
+ assert INVERSE_PATHS == [c["key"] for c in INVERSE_PATH_CONFIGS]
+ assert len(set(INVERSE_PATHS)) == len(INVERSE_PATHS)
+ # One config row must hit each demo configuration knob.
+ keys = set(INVERSE_PATHS)
+ assert {
+ "latent_align0p0",
+ "latent_align0p25",
+ "latent_align1p0",
+ "comp_seed",
+ "comp_seed_blend",
+ "comp_seed_blend_palette",
+ "comp_seed_blend_palette_lowdiv",
+ "comp_random",
+ } == keys
+
+
+def test_reg_task_titles_include_scenario_targets():
+ # Every reg task across the three default scenarios should have a paper-style panel title.
+ for t in ("formation_energy", "klat", "magnetic_moment", "tc"):
+ assert t in REG_TASK_TITLES
+ assert "[" in REG_TASK_TITLES[t] and "]" in REG_TASK_TITLES[t] # units present
+ assert REG_TASK_TITLES[t].endswith(("↑", "↓"))
+
+
+def test_alloy_palette_contents():
+ # Plan §5 originally specified 41 elements; extended 2026-05 with the full Hf–Pt 5d TM row
+ # (7 symbols) → 48. The three Au-Ga-Ln explicit seeds must still fit.
+ assert len(ALLOY_PALETTE) == 48
+ for sym in ("Au", "Ga", "Gd", "Tb", "Dy", "Mg", "Pd", "Al"):
+ assert sym in ALLOY_PALETTE
+ # 5d transition metals (Hf–Pt) — newly added.
+ for sym in ("Hf", "Ta", "W", "Re", "Os", "Ir", "Pt"):
+ assert sym in ALLOY_PALETTE
+ # Radioactive / unwanted symbols deliberately excluded.
+ for sym in ("Pu", "Tc", "Pm"):
+ assert sym not in ALLOY_PALETTE
+
+
+def test_default_config_valid_and_inverse_defaults():
+ cfg = ContinualRehearsalFullConfig()
+ assert len(cfg.inverse_scenarios) == 3
+ assert all(isinstance(sc, InverseScenario) for sc in cfg.inverse_scenarios)
+ # Plan §5 defaults: 20 seeds (17 strategy + 3 Au-Ga-Ln) + the 41-element palette. The single-
+ # value ae_align / seed_blend / diversity knobs are fixed in INVERSE_PATH_CONFIGS, not the
+ # config dataclass — see test_inverse_path_configs_match_demo.
+ assert cfg.inverse_n_seeds == 20
+ assert cfg.inverse_composition_allowed_elements == ALLOY_PALETTE
+ assert cfg.inverse_seed_explicit_append == ["Au65 Ga20 Gd15", "Au65 Ga20 Tb15", "Au65 Ga20 Dy15"]
+
+
+def test_unknown_task_raises():
+ with pytest.raises(ValueError, match="Unknown task"):
+ ContinualRehearsalFullConfig(task_sequence=["density", "not_a_task", "material_type"])
+
+
+def test_duplicate_task_raises():
+ seq = list(DEFAULT_SEQUENCE) + ["density"]
+ with pytest.raises(ValueError, match="duplicates"):
+ ContinualRehearsalFullConfig(task_sequence=seq)
+
+
+def test_fixed_tail_must_be_in_sequence():
+ with pytest.raises(ValueError, match="fixed_tail"):
+ ContinualRehearsalFullConfig(fixed_tail=["formation_energy", "not_present", "material_type"])
+
+
+@pytest.mark.parametrize("ratio_kwargs", [{"replay_ratio": -0.1}, {"replay_ratio_high": 1.5}])
+def test_replay_ratio_bounds(ratio_kwargs):
+ with pytest.raises(ValueError, match="must be in"):
+ ContinualRehearsalFullConfig(**ratio_kwargs)
+
+
+def test_allowed_elements_validation():
+ with pytest.raises(ValueError, match="non-empty"):
+ ContinualRehearsalFullConfig(inverse_composition_allowed_elements=[])
+ with pytest.raises(ValueError, match="not in DEFAULT_ELEMENTS"):
+ ContinualRehearsalFullConfig(inverse_composition_allowed_elements=["Mg", "Xx"])
+
+
+def test_inverse_scenario_length_mismatch():
+ with pytest.raises(ValueError, match="equal length"):
+ InverseScenario("bad", ["formation_energy"], [-2.0, 2.0])
+
+
+def test_scenario_task_must_be_regression():
+ # material_type is a classification task → cannot be a regression objective.
+ bad = InverseScenario("bad", ["material_type"], [1.0])
+ with pytest.raises(ValueError, match="must be a"):
+ ContinualRehearsalFullConfig(inverse_scenarios=[bad])
+
+ # a kernel-regression task is also not a scalar regression objective.
+ bad_kr = InverseScenario("bad_kr", ["dos_density"], [1.0])
+ with pytest.raises(ValueError, match="must be a"):
+ ContinualRehearsalFullConfig(inverse_scenarios=[bad_kr])
+
+
+def test_scenario_task_must_be_in_sequence():
+ short_seq = ["density", "material_type"]
+ bad = InverseScenario("bad", ["formation_energy"], [-2.0])
+ with pytest.raises(ValueError, match="not in task_sequence"):
+ ContinualRehearsalFullConfig(task_sequence=short_seq, fixed_tail=["material_type"], inverse_scenarios=[bad])
+
+
+def test_material_type_required():
+ seq = [t for t in DEFAULT_SEQUENCE if t != "material_type"]
+ with pytest.raises(ValueError, match="material_type"):
+ ContinualRehearsalFullConfig(task_sequence=seq, fixed_tail=["formation_energy"], inverse_scenarios=[])
+
+
+def test_invalid_seed_strategy():
+ with pytest.raises(ValueError, match="inverse_seed_strategy"):
+ ContinualRehearsalFullConfig(inverse_seed_strategy="bogus")
+
+
+def test_display_helpers():
+ assert _display("formation_energy") == "Formation Energy"
+ assert "Density" in _title("density")
+ assert "normalized" in _title("density") # qc scale
+ assert "z-scored" in _title("tc") # raw scale
+ assert _arrow(-2.0) == "↓"
+ assert _arrow(2.0) == "↑"
+
+
+def test_element_system_and_dedup():
+ # Element-system extraction ignores numeric ratios; dedup keeps the first per element set.
+ assert ContinualRehearsalFullRunner._element_system("Au65 Ga20 Gd15") == frozenset({"Au", "Ga", "Gd"})
+ assert ContinualRehearsalFullRunner._element_system("Au0.65Ga0.20Gd0.15") == frozenset({"Au", "Ga", "Gd"})
+ deduped = ContinualRehearsalFullRunner._dedupe_by_element_system(
+ ["Mg2 Zn1 Y1", "Mg1 Zn2 Y1", "Al1 Cu1 Fe1", "Mg3 Zn3 Y2"], n=10
+ )
+ # Mg-Zn-Y duplicates collapsed to the first occurrence; Al-Cu-Fe kept.
+ assert deduped == ["Mg2 Zn1 Y1", "Al1 Cu1 Fe1"]
+
+
+def test_parse_args_tuple_return_and_toml(tmp_path: Path):
+ toml = tmp_path / "cfg.toml"
+ toml.write_text(
+ textwrap.dedent(
+ """
+ qc_preprocessing_path = ""
+ task_sequence = ["density", "formation_energy", "magnetic_moment", "klat", "tc", "material_type"]
+ fixed_tail = ["formation_energy", "magnetic_moment", "tc", "klat", "material_type"]
+ replay_ratio_high = 0.2
+ inverse_composition_allowed_elements = ["Mg", "Al", "Cu", "Pd"]
+
+ [[inverse_scenarios]]
+ name = "s1"
+ reg_tasks = ["formation_energy", "klat"]
+ reg_targets = [-2.0, 2.0]
+
+ [[inverse_scenarios]]
+ name = "s2"
+ reg_tasks = ["formation_energy", "tc", "magnetic_moment"]
+ reg_targets = [-2.0, 2.0, 2.0]
+ """
+ ),
+ encoding="utf-8",
+ )
+ cfg, args = _parse_args(["--config-file", str(toml), "--sample-per-dataset", "500", "--max-epochs-per-step", "2"])
+ # Empty-string path field becomes None (no dropped_idx filtering).
+ assert cfg.qc_preprocessing_path is None
+ # inverse_scenarios dicts are coerced to InverseScenario objects.
+ assert [sc.name for sc in cfg.inverse_scenarios] == ["s1", "s2"]
+ assert all(isinstance(sc, InverseScenario) for sc in cfg.inverse_scenarios)
+ # CLI overrides land on the config; the palette override propagates from TOML.
+ assert cfg.sample_per_dataset == 500
+ assert cfg.max_epochs_per_step == 2
+ assert cfg.replay_ratio_high == 0.2
+ assert cfg.inverse_composition_allowed_elements == ["Mg", "Al", "Cu", "Pd"]
+ # Namespace returned alongside config so main() can read --inverse-only.
+ assert args.inverse_only is None
+
+
+def test_parse_args_inverse_only_flag(tmp_path: Path):
+ ckpt = tmp_path / "model.pt"
+ ckpt.write_bytes(b"placeholder") # presence-only; loading is exercised by smoke
+ _cfg, args = _parse_args(["--inverse-only", str(ckpt)])
+ assert args.inverse_only == ckpt
+
+
+def test_parse_args_unknown_key_ignored(tmp_path: Path):
+ toml = tmp_path / "cfg.toml"
+ toml.write_text("totally_unknown_key = 7\nreplay_ratio = 0.05\n", encoding="utf-8")
+ cfg, _args = _parse_args(["--config-file", str(toml)])
+ assert cfg.replay_ratio == 0.05
+ assert not hasattr(cfg, "totally_unknown_key")
+
+
+def test_demo_inverse_plot_helpers_imported():
+ """The runner relies on two helpers imported from ``paper_inverse_comparison`` to draw the
+ ``qc_vs_secondary_scatter`` and ``seed_to_optimized__*`` figures. If those imports drift
+ the inverse-design loop silently loses both figure groups (no test would catch a missing
+ plot without this guard, because the runner's training loop is only smoke-tested).
+ """
+ from foundation_model.scripts import continual_rehearsal_full as crf
+ from foundation_model.scripts.paper_inverse_comparison import (
+ _plot_qc_vs_reg_scatter as demo_scatter,
+ )
+ from foundation_model.scripts.paper_inverse_comparison import (
+ _plot_seed_to_optimized_mapping as demo_mapping,
+ )
+
+ assert crf._plot_qc_vs_reg_scatter is demo_scatter
+ assert crf._plot_seed_to_optimized_mapping is demo_mapping
diff --git a/src/foundation_model/scripts/eval_inverse_methods.py b/src/foundation_model/scripts/eval_inverse_methods.py
new file mode 100644
index 0000000..7d4e4ad
--- /dev/null
+++ b/src/foundation_model/scripts/eval_inverse_methods.py
@@ -0,0 +1,443 @@
+# Copyright 2025 TsumiNa.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Compare two inverse-design methods on a single trained checkpoint.
+
+Method A — latent-space optimisation with AE-alignment penalty
+ optimize_latent(optimize_space="latent", class_target_weight=…, ae_align_scale=λ).
+ The optimised latent is decoded back to a descriptor through the AE; the heads' values at
+ the **decoded** descriptor are reported (so "round-trip drift" is the key failure mode and
+ cycle-consistency is the proposed mitigation, swept over λ).
+
+Method B — composition-space optimisation via differentiable KMD
+ optimize_composition(kmd_kernel, class_target_weight=…). The optimisation variable IS the
+ element-weight recipe ``w``; descriptor is ``w @ K``; there is no AE in the loop.
+
+Both methods run on the **same model**, **same seed compositions**, and **same targets** so the
+two columns are directly comparable. Output is a JSON summary + a comparison PNG.
+
+This script is independent of the rehearsal demo — its own CLI, own output dir, no rehearsal.
+
+ python -m foundation_model.scripts.eval_inverse_methods \\
+ --config-file samples/continual_rehearsal_demo_config_inverse_baseline.toml \\
+ --checkpoint artifacts/inverse_heads_finetuned/final_model.pt \\
+ --output-dir artifacts/inverse_methods_eval \\
+ --align-scales 0,0.25,0.5,0.75,1.0
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import time
+from pathlib import Path
+from typing import Any
+
+import matplotlib
+
+matplotlib.use("Agg")
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from lightning import seed_everything
+from loguru import logger
+
+from foundation_model.scripts.continual_rehearsal_demo import (
+ QC_CLASSES,
+ ContinualRehearsalConfig,
+ ContinualRehearsalRunner,
+)
+from foundation_model.utils.kmd_plus import DEFAULT_ELEMENTS, formula_to_composition
+
+
+# --- Helpers ------------------------------------------------------------------
+
+
+def _qc_prob(model, x: torch.Tensor) -> np.ndarray:
+ with torch.no_grad():
+ h = torch.tanh(model.encoder(x))
+ probs = torch.softmax(model.task_heads["material_type"](h), dim=-1)
+ return probs[:, QC_CLASSES].sum(dim=-1).cpu().numpy()
+
+
+def _reg_preds(model, x: torch.Tensor, tasks: list[str]) -> dict[str, np.ndarray]:
+ with torch.no_grad():
+ h = torch.tanh(model.encoder(x))
+ return {t: model.task_heads[t](h).squeeze(-1).cpu().numpy() for t in tasks}
+
+
+def _seed_weights_from_compositions(seeds: list[str], n_components: int) -> torch.Tensor:
+ """Element-weight tensor (B, n_components) for ``optimize_composition`` seeding."""
+ rows = []
+ for c in seeds:
+ w = formula_to_composition(c)
+ if w is None:
+ raise ValueError(f"Cannot parse seed composition '{c}' to element weights.")
+ rows.append(np.asarray(w, dtype=np.float64))
+ return torch.tensor(np.stack(rows), dtype=torch.float64)
+
+
+def _decode_latent_path(kmd, descriptors: np.ndarray) -> list[str]:
+ """Latent path's composition output: AE-decoded descriptor → KMD.inverse → formula string."""
+ try:
+ weights = kmd.inverse(descriptors)
+ except Exception as exc: # pragma: no cover
+ logger.warning(f"KMD.inverse failed ({exc}); skipping composition decoding.")
+ return [""] * descriptors.shape[0]
+ return _format_weights(weights)
+
+
+def _format_weights(weights: np.ndarray, top_k: int = 6, eps: float = 1e-3) -> list[str]:
+ """Render element-weight rows as compact formula strings (top-K elements above ``eps``)."""
+ out: list[str] = []
+ for row in weights:
+ order = np.argsort(row)[::-1]
+ parts = [f"{DEFAULT_ELEMENTS[i]}{row[i]:.3f}" for i in order[:top_k] if row[i] > eps]
+ out.append(" ".join(parts) if parts else "")
+ return out
+
+
+# --- Methods ------------------------------------------------------------------
+
+
+def _run_latent_method(
+ runner: ContinualRehearsalRunner,
+ model,
+ seeds: list[str],
+ x_seed: torch.Tensor,
+ reg_targets: dict[str, float],
+ class_weight: float,
+ align_scale: float,
+ steps: int,
+ lr: float,
+ record_trajectory: bool = False,
+) -> dict[str, Any]:
+ device = next(model.parameters()).device
+ t0 = time.perf_counter()
+ res = model.optimize_latent(
+ initial_input=x_seed,
+ task_targets=reg_targets,
+ class_targets={"material_type": QC_CLASSES},
+ class_target_weight=class_weight,
+ ae_align_scale=align_scale,
+ optimize_space="latent",
+ steps=steps,
+ lr=lr,
+ record_input_trajectory=record_trajectory,
+ )
+ elapsed = time.perf_counter() - t0
+
+ reg_names = list(reg_targets.keys())
+ achieved_latent = res.optimized_target[:, 0, :].cpu().numpy() # (B, T) in reg_targets order
+ optimized_desc = res.optimized_input[:, 0, :] # (B, x_dim) — AE-decoded descriptor
+ after_qc = _qc_prob(model, optimized_desc)
+ after_reg = _reg_preds(model, optimized_desc, reg_names)
+ decoded = _decode_latent_path(runner._kmd, optimized_desc.detach().cpu().numpy())
+ # Recover the per-seed element weights too, so downstream replotting (per-element bar charts,
+ # ratio histograms, similarity matrices) doesn't need to re-run the optimisation.
+ optimized_weights = runner._kmd.inverse(optimized_desc.detach().cpu().numpy())
+
+ out = {
+ "method": "latent",
+ "align_scale": align_scale,
+ "elapsed_s": elapsed,
+ "seeds": list(seeds),
+ "qc_after_decode": after_qc.tolist(),
+ "reg_achieved_latent": {t: achieved_latent[:, j].tolist() for j, t in enumerate(reg_names)},
+ "reg_after_decode": {t: after_reg[t].tolist() for t in reg_names},
+ "decoded_composition": decoded,
+ # Raw arrays for replotting without rerunning: (B, x_dim) descriptor and (B, n_components) weights.
+ "optimized_descriptor": optimized_desc.detach().cpu().numpy().tolist(),
+ "optimized_weights": optimized_weights.tolist(),
+ }
+ if record_trajectory:
+ # Per-step trajectory of the *post-decode* predictions and the per-step decoded weights.
+ # ``res.trajectory`` is (B, R=1, steps, T) — squeeze the restart axis to (steps, B, T).
+ # We additionally re-run the heads on the per-step decoded input so the "trajectory" we
+ # report is on the same surface as the final ``reg_after_decode`` values (the optimiser's
+ # internal latent-space predictions can diverge from the decode-then-predict ones when
+ # ``ae_align_scale`` is small — surfacing the decode-then-predict trajectory is the more
+ # honest signal for the user investigating "how does the recipe evolve").
+ out["trajectory_targets"] = res.trajectory[:, 0, :, :].cpu().numpy().transpose(1, 0, 2).tolist()
+ # (B, R=1, steps, input_dim) → (steps, B, n_components) via KMD.inverse on each step.
+ # Batched per step: KMD.inverse expects (B, input_dim) and returns (B, n_components).
+ per_step_inputs = res.input_trajectory[:, 0, :, :].cpu().numpy() # (B, steps, input_dim)
+ per_step_inputs = per_step_inputs.transpose(1, 0, 2) # (steps, B, input_dim)
+ per_step_weights = [runner._kmd.inverse(per_step_inputs[s]) for s in range(per_step_inputs.shape[0])]
+ # (steps, B, n_components)
+ import numpy as _np
+ out["trajectory_weights"] = _np.stack(per_step_weights, axis=0).tolist()
+ return out
+
+
+def _run_composition_method(
+ runner: ContinualRehearsalRunner,
+ model,
+ seeds: list[str],
+ reg_targets: dict[str, float],
+ class_weight: float,
+ steps: int,
+ lr: float,
+ allowed_elements: "str | list[str]" = "all",
+ element_step_scale: "float | dict[str, float]" = 1.0,
+) -> dict[str, Any]:
+ device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
+ kernel = runner._kmd.kernel_torch(device=device, dtype=dtype)
+ w_seed = _seed_weights_from_compositions(seeds, n_components=len(DEFAULT_ELEMENTS))
+
+ t0 = time.perf_counter()
+ res = model.optimize_composition(
+ kernel,
+ initial_weights=w_seed,
+ task_targets=reg_targets,
+ class_targets={"material_type": QC_CLASSES},
+ class_target_weight=class_weight,
+ allowed_elements=allowed_elements,
+ element_step_scale=element_step_scale,
+ steps=steps,
+ lr=lr,
+ )
+ elapsed = time.perf_counter() - t0
+
+ reg_names = list(reg_targets.keys())
+ achieved = res.optimized_target.cpu().numpy() # (B, T)
+ optimized_desc = res.optimized_descriptor # (B, x_dim) — w @ K, no decode
+ final_qc = _qc_prob(model, optimized_desc)
+ final_reg = _reg_preds(model, optimized_desc, reg_names)
+ w_final = res.optimized_weights.cpu().numpy()
+
+ return {
+ "method": "composition",
+ "align_scale": None,
+ "elapsed_s": elapsed,
+ "seeds": list(seeds),
+ # In composition space there is no "after-decode" drift — the model values AT the optimised
+ # ``w`` are the same as at the descriptor ``w @ K``. We still report both for symmetry.
+ "qc_after_decode": final_qc.tolist(),
+ "reg_achieved_latent": {t: achieved[:, j].tolist() for j, t in enumerate(reg_names)},
+ "reg_after_decode": {t: final_reg[t].tolist() for t in reg_names},
+ "decoded_composition": _format_weights(w_final),
+ # Raw arrays for replotting without rerunning: (B, x_dim) descriptor and (B, n_components) weights.
+ "optimized_descriptor": optimized_desc.detach().cpu().numpy().tolist(),
+ "optimized_weights": w_final.tolist(),
+ }
+
+
+# --- Plot ---------------------------------------------------------------------
+
+
+def _plot_summary(results: list[dict[str, Any]], reg_targets: dict[str, float], out_path: Path) -> None:
+ """Side-by-side: QC prob and each regression target across methods (mean ± seeds)."""
+ fig, axes = plt.subplots(1, 1 + len(reg_targets), figsize=(4.6 * (1 + len(reg_targets)), 4.2), squeeze=False)
+ axes = axes[0]
+ labels = [f"latent (α={r['align_scale']})" if r["method"] == "latent" else "composition" for r in results]
+
+ # QC probability
+ qc_means = [float(np.mean(r["qc_after_decode"])) for r in results]
+ qc_stds = [float(np.std(r["qc_after_decode"])) for r in results]
+ x = np.arange(len(results))
+ axes[0].bar(x, qc_means, yerr=qc_stds, color="#55A868", capsize=3)
+ axes[0].axhline(1.0, color="#C44E52", ls="--", lw=1.4, label="target = 1.0")
+ axes[0].set_xticks(x, labels, rotation=30, ha="right")
+ axes[0].set_ylim(-0.02, 1.05)
+ axes[0].set_ylabel("P(quasicrystal)")
+ axes[0].set_title("Quasicrystal Probability (primary)")
+ axes[0].legend(fontsize=9, loc="lower right")
+
+ for ax, (t, tgt) in zip(axes[1:], reg_targets.items()):
+ means = [float(np.mean(r["reg_after_decode"][t])) for r in results]
+ stds = [float(np.std(r["reg_after_decode"][t])) for r in results]
+ ax.bar(x, means, yerr=stds, color="#4C72B0", capsize=3)
+ ax.axhline(tgt, color="#C44E52", ls="--", lw=1.4, label=f"target = {tgt:+.1f}")
+ ax.set_xticks(x, labels, rotation=30, ha="right")
+ ax.set_ylabel("Predicted value")
+ ax.set_title(f"{t}")
+ ax.legend(fontsize=9, loc="best")
+
+ fig.suptitle("Inverse-design methods compared (same model, same seeds, same targets)", y=1.04)
+ fig.savefig(out_path, dpi=150, bbox_inches="tight")
+ plt.close(fig)
+
+
+# --- Main flow ----------------------------------------------------------------
+
+
+def evaluate(
+ config: ContinualRehearsalConfig,
+ ckpt_path: Path,
+ align_scales: list[float],
+ allowed_elements: "str | list[str]" = "all",
+ element_step_scale: "float | dict[str, float]" = 1.0,
+) -> None:
+ seed_everything(config.random_seed, workers=True)
+ runner = ContinualRehearsalRunner(config)
+ model = runner._build_full_model()
+
+ state = torch.load(ckpt_path, map_location="cpu", weights_only=True)
+ state_dict = state["model"] if isinstance(state, dict) and "model" in state else state
+ model.load_state_dict(state_dict)
+ model.eval()
+
+ # Deterministic seed compositions: same set for both methods. We reuse the demo's "top-QC
+ # training composition" selector so this matches what users see from continual_rehearsal_demo.
+ device = next(model.parameters()).device
+
+ def _qc_prob_fn(x: torch.Tensor) -> np.ndarray:
+ return _qc_prob(model, x)
+
+ seeds = runner._select_seeds(model, device, _qc_prob_fn)
+ if not seeds:
+ raise RuntimeError("No seed compositions selected (check inverse_seed_strategy / data).")
+ x_seed, seeds = runner._descriptor_tensor(seeds, device)
+ logger.info(f"Selected {len(seeds)} seed compositions")
+
+ reg_targets = {t: v for t, v in zip(config.inverse_reg_tasks, config.inverse_reg_targets)}
+
+ results: list[dict[str, Any]] = []
+
+ # Method A: latent-space, sweep ae_align_scale ∈ [0, 1].
+ for lam in align_scales:
+ logger.info(f"--- Latent method, ae_align_scale = {lam} ---")
+ results.append(
+ _run_latent_method(
+ runner,
+ model,
+ seeds,
+ x_seed,
+ reg_targets,
+ class_weight=config.inverse_class_weight,
+ align_scale=float(lam),
+ steps=config.inverse_steps,
+ lr=config.inverse_lr,
+ )
+ )
+
+ # Method B: differentiable KMD, single run (no λ). Element constraints (if any) only apply here.
+ logger.info("--- Composition method (differentiable KMD) ---")
+ if isinstance(allowed_elements, list):
+ logger.info(f" allowed_elements: {len(allowed_elements)} symbol(s) — {allowed_elements}")
+ if isinstance(element_step_scale, dict):
+ logger.info(f" element_step_scale: {element_step_scale}")
+ elif isinstance(element_step_scale, (int, float)) and float(element_step_scale) != 1.0:
+ logger.info(f" element_step_scale (uniform): {element_step_scale}")
+ results.append(
+ _run_composition_method(
+ runner,
+ model,
+ seeds,
+ reg_targets,
+ class_weight=config.inverse_class_weight,
+ steps=config.inverse_steps,
+ lr=config.inverse_lr,
+ allowed_elements=allowed_elements,
+ element_step_scale=element_step_scale,
+ )
+ )
+
+ out_dir = Path(config.output_dir)
+ out_dir.mkdir(parents=True, exist_ok=True)
+
+ # Compact human-readable summary alongside the full per-seed JSON.
+ summary = []
+ for r in results:
+ row = {
+ "label": f"latent α={r['align_scale']}" if r["method"] == "latent" else "composition",
+ "elapsed_s": round(r["elapsed_s"], 2),
+ "qc_after_mean": round(float(np.mean(r["qc_after_decode"])), 4),
+ }
+ for t in reg_targets:
+ row[f"{t}_after_mean"] = round(float(np.mean(r["reg_after_decode"][t])), 3)
+ summary.append(row)
+ logger.info("=== Summary ===")
+ for row in summary:
+ logger.info(row)
+
+ (out_dir / "eval_inverse_methods.json").write_text(
+ json.dumps({"reg_targets": reg_targets, "results": results, "summary": summary}, indent=2),
+ encoding="utf-8",
+ )
+ _plot_summary(results, reg_targets, out_dir / "eval_inverse_methods.png")
+ logger.info(f"Wrote {out_dir / 'eval_inverse_methods.json'} and the comparison plot.")
+
+
+def _parse_args(argv: list[str] | None = None) -> tuple[ContinualRehearsalConfig, argparse.Namespace]:
+ parser = argparse.ArgumentParser(description="Compare inverse-design methods on a trained checkpoint.")
+ parser.add_argument("--config-file", type=Path, required=True)
+ parser.add_argument("--checkpoint", type=Path, required=True)
+ parser.add_argument("--output-dir", type=Path, required=True)
+ parser.add_argument(
+ "--align-scales",
+ type=str,
+ default="0,0.25,0.5,0.75,1.0",
+ help="Comma-separated values in [0, 1] for ae_align_scale in the latent method.",
+ )
+ parser.add_argument(
+ "--allowed-elements",
+ type=str,
+ default="",
+ help=(
+ "Comma-separated element symbols the composition method is allowed to use (hard "
+ "whitelist; e.g. 'Mg,Al,Cu,Ni,Zn,Ag'). Empty means every element allowed."
+ ),
+ )
+ parser.add_argument(
+ "--locked-elements",
+ type=str,
+ default="",
+ help=(
+ "Comma-separated element symbols whose composition weight is frozen at the seed "
+ "value (sets element_step_scale to --locked-step-scale; default 0 = fully locked)."
+ ),
+ )
+ parser.add_argument(
+ "--locked-step-scale",
+ type=float,
+ default=0.0,
+ help="Gradient multiplier for locked elements (0 = fully locked; 0.1 = slow drift).",
+ )
+ args = parser.parse_args(argv)
+
+ import tomllib
+
+ data = tomllib.loads(args.config_file.read_text(encoding="utf-8"))
+ data["output_dir"] = str(args.output_dir)
+ field_names = set(ContinualRehearsalConfig.__dataclass_fields__)
+ path_fields = {
+ "qc_data_path",
+ "qc_preprocessing_path",
+ "superconductor_path",
+ "magnetic_path",
+ "phonix_path",
+ "output_dir",
+ }
+ kwargs: dict[str, object] = {}
+ for key, value in data.items():
+ if key not in field_names:
+ continue
+ kwargs[key] = Path(value) if key in path_fields and value is not None else value
+ return ContinualRehearsalConfig(**kwargs), args
+
+
+def main(argv: list[str] | None = None) -> None:
+ config, args = _parse_args(argv)
+ align_scales = [float(x) for x in args.align_scales.split(",") if x.strip()]
+ allowed_syms = [s.strip() for s in args.allowed_elements.split(",") if s.strip()]
+ locked_syms = [s.strip() for s in args.locked_elements.split(",") if s.strip()]
+ # Pass symbols straight through to optimize_composition's symbol-based API.
+ allowed_arg: "str | list[str]" = allowed_syms if allowed_syms else "all"
+ step_scale_arg: "float | dict[str, float]" = (
+ {s: args.locked_step_scale for s in locked_syms} if locked_syms else 1.0
+ )
+ evaluate(
+ config,
+ args.checkpoint,
+ align_scales,
+ allowed_elements=allowed_arg,
+ element_step_scale=step_scale_arg,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/foundation_model/scripts/finetune_inverse_heads.py b/src/foundation_model/scripts/finetune_inverse_heads.py
new file mode 100644
index 0000000..ec989f4
--- /dev/null
+++ b/src/foundation_model/scripts/finetune_inverse_heads.py
@@ -0,0 +1,215 @@
+# Copyright 2025 TsumiNa.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Targeted fine-tune of the three heads used by inverse design.
+
+Loads a ``final_model.pt`` checkpoint produced by ``continual_rehearsal_demo``, freezes the
+encoder and every other task head (including the autoencoder), and runs a short fine-tune on
+just the three inverse-design heads — by default ``formation_energy``, ``klat`` and
+``material_type`` — so they are as sharp as possible before we compare inverse-design methods
+(latent-with-cycle-consistency vs differentiable KMD).
+
+The script is **independent of the rehearsal demo** (its own CLI, output dir, and checkpoint).
+It reuses the demo runner only for data loading + model reconstruction; no rehearsal loop is run.
+
+ python -m foundation_model.scripts.finetune_inverse_heads \\
+ --config-file samples/continual_rehearsal_demo_config_inverse_baseline.toml \\
+ --checkpoint artifacts/continual_rehearsal_inverse_baseline/final_model.pt \\
+ --output-dir artifacts/inverse_heads_finetuned \\
+ --epochs 30
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+from pathlib import Path
+from typing import Iterable
+
+import torch
+from lightning import Trainer, seed_everything
+from loguru import logger
+
+from foundation_model.data.datamodule import CompoundDataModule
+from foundation_model.scripts.continual_rehearsal_demo import (
+ ContinualRehearsalConfig,
+ ContinualRehearsalRunner,
+ _parse_args as _demo_parse_args, # noqa: F401 (kept for documentation; we parse our own args)
+)
+
+DEFAULT_INVERSE_HEADS = ("formation_energy", "klat", "material_type")
+
+
+def freeze_except(model, keep_heads: Iterable[str]) -> dict[str, bool]:
+ """Freeze encoder + every head NOT in ``keep_heads`` + task_log_sigmas; return prior requires_grad state.
+
+ The model's ``task_log_sigmas`` ParameterDict holds the learnable loss-balancer coefficients
+ (one scalar per task, active when ``enable_learnable_loss_balancer=True``). Without freezing
+ them, ``configure_optimizers`` still picks them up and they move during the "head-only"
+ fine-tune — which would silently change the inverse-design objectives' relative weights and
+ make the comparison apples-to-oranges. We freeze every per-task balancer scalar here too,
+ so this script really is head-only.
+ """
+ keep = set(keep_heads)
+ saved: dict[str, bool] = {}
+ for name, p in model.named_parameters():
+ saved[name] = p.requires_grad
+ for p in model.encoder.parameters():
+ p.requires_grad_(False)
+ for head_name, head in model.task_heads.items():
+ train = head_name in keep
+ for p in head.parameters():
+ p.requires_grad_(train)
+ # Freeze every learnable-loss-balancer scalar (no-op when the balancer is disabled).
+ for p in model.task_log_sigmas.parameters():
+ p.requires_grad_(False)
+ return saved
+
+
+def _restore_requires_grad(model, saved: dict[str, bool]) -> None:
+ for name, p in model.named_parameters():
+ if name in saved:
+ p.requires_grad_(saved[name])
+
+
+def finetune(config: ContinualRehearsalConfig, ckpt_path: Path, inverse_heads: tuple[str, ...], epochs: int) -> Path:
+ seed_everything(config.random_seed, workers=True)
+ runner = ContinualRehearsalRunner(config) # loads data + builds KMD cache (same as demo)
+
+ logger.info(f"Loading model checkpoint {ckpt_path}")
+ model = runner._build_full_model()
+ state = torch.load(ckpt_path, map_location="cpu", weights_only=True)
+ state_dict = state["model"] if isinstance(state, dict) and "model" in state else state
+ model.load_state_dict(state_dict)
+
+ missing = [t for t in inverse_heads if t not in model.task_heads]
+ if missing:
+ raise ValueError(
+ f"Heads {missing} not found in the loaded model (have {list(model.task_heads.keys())}). "
+ "Check that the checkpoint was produced with the same task_sequence."
+ )
+
+ logger.info(f"Freezing everything except heads: {sorted(inverse_heads)}")
+ freeze_except(model, inverse_heads)
+
+ # Deactivate every non-inverse head so the Trainer's validation_step doesn't try to forward
+ # them on a batch that only carries the three inverse-head columns. ``disable_task`` keeps the
+ # weights in ``model.disabled_task_heads`` (so the saved state_dict still contains them) but
+ # removes them from ``model.task_heads`` so the forward loop iterates only over the inverse
+ # ones. Important for KR heads (e.g. ``dos_density``) whose forward expects a ``t_sequences``
+ # entry that the inverse-only DataModule does not provide.
+ other_active = [name for name in list(model.task_heads.keys()) if name not in inverse_heads]
+ if other_active:
+ logger.info(f"Disabling {len(other_active)} non-inverse head(s) for the duration of fine-tune: {other_active}")
+ model.disable_task(*other_active)
+
+ # Use the same task configs as training (built by the runner), but restrict the DataModule to
+ # the inverse-head tasks and disable masking (we want all available labels for these heads).
+ task_configs = {name: runner._build_task_config(name) for name in inverse_heads}
+ for cfg in task_configs.values():
+ cfg.task_masking_ratio = 1.0 # no rehearsal-style dropout — we want every label
+
+ datamodule = CompoundDataModule(
+ task_configs=list(task_configs.values()),
+ descriptor_fn=runner.descriptor_fn,
+ task_frames={name: runner.task_frames[name] for name in inverse_heads},
+ composition_column="composition",
+ random_seed=config.datamodule_random_seed,
+ batch_size=config.batch_size,
+ num_workers=config.num_workers,
+ )
+
+ trainer = Trainer(
+ max_epochs=epochs,
+ accelerator=config.accelerator,
+ devices=config.devices,
+ logger=False,
+ enable_checkpointing=False,
+ enable_progress_bar=False,
+ )
+ trainer.fit(model, datamodule=datamodule)
+
+ # Re-activate the heads we hid so the saved state_dict's key layout matches what
+ # paper_inverse_comparison / eval_inverse_methods rebuild (all heads under ``task_heads``).
+ if other_active:
+ logger.info(f"Re-enabling {len(other_active)} previously-disabled head(s) before save.")
+ model.enable_task(*other_active)
+
+ out_path = Path(config.output_dir) / "final_model.pt"
+ Path(config.output_dir).mkdir(parents=True, exist_ok=True)
+ torch.save(
+ {
+ "model": model.state_dict(),
+ "task_sequence": list(config.task_sequence),
+ "finetuned_heads": list(inverse_heads),
+ "finetune_epochs": int(epochs),
+ "from_checkpoint": str(ckpt_path),
+ },
+ out_path,
+ )
+ (Path(config.output_dir) / "finetune_summary.json").write_text(
+ json.dumps(
+ {
+ "from_checkpoint": str(ckpt_path),
+ "finetuned_heads": list(inverse_heads),
+ "epochs": int(epochs),
+ "task_sequence": list(config.task_sequence),
+ },
+ indent=2,
+ ),
+ encoding="utf-8",
+ )
+ logger.info(f"Saved fine-tuned checkpoint to {out_path}")
+ return out_path
+
+
+def _parse_args(argv: list[str] | None = None) -> tuple[ContinualRehearsalConfig, argparse.Namespace]:
+ parser = argparse.ArgumentParser(description="Targeted fine-tune of inverse-design heads.")
+ parser.add_argument("--config-file", type=Path, required=True, help="Demo config (paths + task_sequence).")
+ parser.add_argument(
+ "--checkpoint", type=Path, required=True, help="final_model.pt produced by continual_rehearsal_demo."
+ )
+ parser.add_argument(
+ "--output-dir", type=Path, required=True, help="Where to write the fine-tuned checkpoint + summary."
+ )
+ parser.add_argument("--epochs", type=int, default=20, help="Fine-tune epochs (default 20).")
+ parser.add_argument(
+ "--inverse-heads",
+ type=str,
+ default=",".join(DEFAULT_INVERSE_HEADS),
+ help=f"Comma-separated head names to fine-tune. Default: {','.join(DEFAULT_INVERSE_HEADS)}.",
+ )
+ args = parser.parse_args(argv)
+
+ # Build the demo config (reuses the same TOML schema), overriding output_dir.
+ import tomllib
+
+ data = tomllib.loads(args.config_file.read_text(encoding="utf-8"))
+ data["output_dir"] = str(args.output_dir)
+ field_names = set(ContinualRehearsalConfig.__dataclass_fields__)
+ path_fields = {
+ "qc_data_path",
+ "qc_preprocessing_path",
+ "superconductor_path",
+ "magnetic_path",
+ "phonix_path",
+ "output_dir",
+ }
+ kwargs: dict[str, object] = {}
+ for key, value in data.items():
+ if key not in field_names:
+ continue
+ kwargs[key] = Path(value) if key in path_fields and value is not None else value
+ config = ContinualRehearsalConfig(**kwargs)
+ return config, args
+
+
+def main(argv: list[str] | None = None) -> None:
+ config, args = _parse_args(argv)
+ heads = tuple(h.strip() for h in args.inverse_heads.split(",") if h.strip())
+ finetune(config, args.checkpoint, heads, args.epochs)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/foundation_model/scripts/finetune_inverse_heads_test.py b/src/foundation_model/scripts/finetune_inverse_heads_test.py
new file mode 100644
index 0000000..3ca2725
--- /dev/null
+++ b/src/foundation_model/scripts/finetune_inverse_heads_test.py
@@ -0,0 +1,103 @@
+# Copyright 2026 TsumiNa.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Tests for ``finetune_inverse_heads.freeze_except`` — the per-parameter freeze contract.
+
+The full ``finetune`` entry point needs a real checkpoint + data parquets, so it's exercised by
+the smoke runs under ``artifacts/inverse_design_run/finetune/``. The unit-testable piece is the
+freeze logic, which is the most refactor-fragile part: a future change that accidentally
+un-freezes the encoder (or forgets the per-task loss-balancer scalars) would silently break
+the "apples-to-apples" comparison the script exists to enable.
+"""
+
+from __future__ import annotations
+
+import pytest
+import torch
+
+from foundation_model.models.flexible_multi_task_model import FlexibleMultiTaskModel
+from foundation_model.models.model_config import (
+ ClassificationTaskConfig,
+ MLPEncoderConfig,
+ RegressionTaskConfig,
+)
+from foundation_model.scripts.finetune_inverse_heads import freeze_except
+
+
+INPUT_DIM = 16
+LATENT_DIM = 8
+
+
+def _make_model(enable_balancer: bool = False) -> FlexibleMultiTaskModel:
+ """Three-head model mirroring the inverse-design tail (formation_energy / klat / material_type).
+
+ ``enable_autoencoder=False`` keeps the test fast — the freeze contract doesn't depend on the
+ AE head; the smoke run covers that path.
+ """
+ enc = MLPEncoderConfig(hidden_dims=[INPUT_DIM, LATENT_DIM])
+ tasks = [
+ RegressionTaskConfig(name="formation_energy", data_column="formation_energy", dims=[LATENT_DIM, 4, 1]),
+ RegressionTaskConfig(name="klat", data_column="klat", dims=[LATENT_DIM, 4, 1]),
+ ClassificationTaskConfig(name="material_type", data_column="material_type", num_classes=3, dims=[LATENT_DIM, 4, 3]),
+ # An extra head that should be frozen (simulates ``density`` / ``tc`` / etc. in the real tail).
+ RegressionTaskConfig(name="density", data_column="density", dims=[LATENT_DIM, 4, 1]),
+ ]
+ return FlexibleMultiTaskModel(
+ task_configs=tasks,
+ encoder_config=enc,
+ enable_learnable_loss_balancer=enable_balancer,
+ )
+
+
+def _grad_state(model) -> dict[str, bool]:
+ return {name: p.requires_grad for name, p in model.named_parameters()}
+
+
+def test_freeze_except_freezes_encoder_and_unkept_heads():
+ """Encoder + every head NOT in ``keep`` is frozen; kept heads remain trainable."""
+ model = _make_model()
+ inverse_heads = ("formation_energy", "klat", "material_type")
+ freeze_except(model, inverse_heads)
+
+ # Encoder: every param frozen.
+ assert all(not p.requires_grad for p in model.encoder.parameters())
+ # Kept heads: every param trainable.
+ for head in inverse_heads:
+ assert all(p.requires_grad for p in model.task_heads[head].parameters()), f"{head!r} should be trainable"
+ # Non-kept head (``density``): every param frozen.
+ assert all(not p.requires_grad for p in model.task_heads["density"].parameters())
+
+
+def test_freeze_except_freezes_task_log_sigmas_when_balancer_enabled():
+ """The learnable per-task loss-balancer scalars MUST be frozen, otherwise the optimiser
+ silently shifts the inverse heads' relative weights during the head-only fine-tune and
+ the downstream comparison stops being apples-to-apples."""
+ model = _make_model(enable_balancer=True)
+ # Sanity check: balancer is on so task_log_sigmas has at least one parameter. ``any()``
+ # would unwrap to the scalar's bool (0.0 is falsy) — we want a count check instead.
+ assert len(list(model.task_log_sigmas.parameters())) > 0, "fixture must register balancer scalars"
+ freeze_except(model, ("formation_energy", "klat", "material_type"))
+ assert all(not p.requires_grad for p in model.task_log_sigmas.parameters())
+
+
+def test_freeze_except_returns_pre_freeze_requires_grad_state():
+ """The ``saved`` dict captures the pre-call ``requires_grad`` for every named parameter —
+ used by ``_restore_requires_grad`` if a caller wants to roll back. The contract is that the
+ returned dict has one entry per ``named_parameters()`` key."""
+ model = _make_model()
+ pre = _grad_state(model)
+ saved = freeze_except(model, ("formation_energy",))
+ assert set(saved.keys()) == set(pre.keys())
+ # All params were trainable before freezing → saved should reflect that.
+ assert all(v is True for v in saved.values())
+
+
+def test_freeze_except_handles_unknown_keep_head_silently():
+ """An unknown ``keep_heads`` entry is *not* an error in this helper — it simply means
+ no head matches, and every head ends up frozen. This is the right contract for a low-level
+ freeze; the caller (``finetune``) is responsible for validating head names against the
+ loaded checkpoint upstream (see ``finetune`` raising on ``missing`` heads)."""
+ model = _make_model()
+ freeze_except(model, ("not_a_head",))
+ for head in model.task_heads.values():
+ assert all(not p.requires_grad for p in head.parameters())
diff --git a/src/foundation_model/scripts/paper_inverse_3scenarios.py b/src/foundation_model/scripts/paper_inverse_3scenarios.py
new file mode 100644
index 0000000..306b07b
--- /dev/null
+++ b/src/foundation_model/scripts/paper_inverse_3scenarios.py
@@ -0,0 +1,166 @@
+# Copyright 2025 TsumiNa.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Run the paper-grade inverse-design comparison across multiple scenarios on a single checkpoint.
+
+This is a thin orchestrator around :mod:`paper_inverse_comparison`. The TOML config is expected to
+contain a ``[[inverse_scenarios]]`` array of tables (see plan §5), each entry overriding
+``reg_tasks`` / ``reg_targets`` for one scenario. The script loops over the scenarios and writes
+each one's outputs into ``//`` so the per-scenario files (figures, raw
+arrays, summary) stay isolated.
+
+Layout::
+
+ /
+ scenario1_fe_down_magnetic_up/
+ final_model.pt # copy of the input checkpoint (self-contained)
+ seeds.json
+ results.json # per-seed raw arrays for all 11 paths (latent α-sweep + 5 comp)
+ comparison.png # headline 3-panel bar chart
+ SUMMARY.md
+ scenario.json # this scenario's reg_tasks/reg_targets
+ scenario2_fe_down_tc_up_magnetic_up/
+ ...
+ scenario3_fe_down_klat_up/
+ ...
+ README.md # cross-scenario summary index (hand-written downstream)
+
+The trained model has to expose every regression head listed in any scenario's ``reg_tasks``;
+otherwise the per-scenario run will fail loudly at the model side. ``material_type`` (the
+classification head) is implicit and always required for the QC primary objective.
+
+Run:
+ python -m foundation_model.scripts.paper_inverse_3scenarios \\
+ --config-file samples/continual_rehearsal_demo_config_inverse_baseline.toml \\
+ --checkpoint artifacts/inverse_design_run/finetune/final_model.pt \\
+ --output-dir artifacts/inverse_design_run/inverse_design
+"""
+
+from __future__ import annotations
+
+import argparse
+import dataclasses
+import json
+import tomllib
+from pathlib import Path
+from typing import Any
+
+from loguru import logger
+
+from foundation_model.scripts.continual_rehearsal_demo import ContinualRehearsalConfig
+from foundation_model.scripts.paper_inverse_comparison import _parse_args as _paper_parse_args
+from foundation_model.scripts.paper_inverse_comparison import run as paper_run
+
+
+def _load_scenarios(config_file: Path) -> list[dict[str, Any]]:
+ """Pull the ``[[inverse_scenarios]]`` array out of the TOML and validate it."""
+ raw = tomllib.loads(config_file.read_text(encoding="utf-8"))
+ scenarios = raw.get("inverse_scenarios", [])
+ if not scenarios:
+ raise ValueError(
+ f"No [[inverse_scenarios]] array found in {config_file}. "
+ "Add the array (with name/reg_tasks/reg_targets) per plan §5 first."
+ )
+ for sc in scenarios:
+ missing = {"name", "reg_tasks", "reg_targets"} - set(sc)
+ if missing:
+ raise ValueError(f"Scenario missing required fields {sorted(missing)}: {sc!r}.")
+ if len(sc["reg_tasks"]) != len(sc["reg_targets"]):
+ raise ValueError(
+ f"reg_tasks and reg_targets length mismatch in scenario {sc['name']!r}: "
+ f"{len(sc['reg_tasks'])} vs {len(sc['reg_targets'])}."
+ )
+ return scenarios
+
+
+def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
+ parser = argparse.ArgumentParser(description="Paper-grade inverse-design comparison across multiple scenarios.")
+ parser.add_argument("--config-file", type=Path, required=True)
+ parser.add_argument("--checkpoint", type=Path, required=True)
+ parser.add_argument(
+ "--output-dir",
+ type=Path,
+ required=True,
+ help="Parent folder; each scenario writes into //.",
+ )
+ # Trajectory flags — forwarded verbatim to each scenario's ``paper_inverse_comparison.run()``.
+ parser.add_argument(
+ "--record-trajectory",
+ action=argparse.BooleanOptionalAction,
+ default=True,
+ help="Record per-step trajectory (default on; --no-record-trajectory to skip).",
+ )
+ parser.add_argument(
+ "--per-seed-trajectories",
+ action=argparse.BooleanOptionalAction,
+ default=True,
+ help="Per-(path × seed) trajectory plots/animations (default on; --no-per-seed-trajectories to skip).",
+ )
+ parser.add_argument(
+ "--animation-formats",
+ nargs="+",
+ choices=["gif", "html", "svg", "none"],
+ default=["gif"],
+ help="One or more trajectory-animation formats (default: gif).",
+ )
+ return parser.parse_args(argv)
+
+
+def main(argv: list[str] | None = None) -> None:
+ args = _parse_args(argv)
+ scenarios = _load_scenarios(args.config_file)
+ logger.info(f"Loaded {len(scenarios)} inverse-design scenarios from {args.config_file}.")
+ args.output_dir.mkdir(parents=True, exist_ok=True)
+
+ # Build a baseline config once by re-using the single-scenario parser. We then ``replace`` it
+ # per-scenario to override ``inverse_reg_tasks`` / ``inverse_reg_targets`` / ``output_dir``.
+ paper_argv = [
+ "--config-file",
+ str(args.config_file),
+ "--checkpoint",
+ str(args.checkpoint),
+ "--output-dir",
+ str(args.output_dir / scenarios[0]["name"]), # placeholder; overridden below
+ ]
+ base_config, _ = _paper_parse_args(paper_argv)
+
+ for sc in scenarios:
+ sc_dir = args.output_dir / sc["name"]
+ sc_config: ContinualRehearsalConfig = dataclasses.replace(
+ base_config,
+ inverse_reg_tasks=list(sc["reg_tasks"]),
+ inverse_reg_targets=list(sc["reg_targets"]),
+ output_dir=sc_dir,
+ )
+ logger.info(f"=== Scenario {sc['name']} ===")
+ logger.info(f" reg_tasks : {sc['reg_tasks']}")
+ logger.info(f" reg_targets : {sc['reg_targets']}")
+ logger.info(f" output : {sc_dir}")
+ paper_run(
+ sc_config,
+ args.checkpoint,
+ record_trajectory=args.record_trajectory,
+ per_seed_trajectories=args.per_seed_trajectories,
+ animation_formats=tuple(args.animation_formats),
+ )
+ # Drop a per-scenario meta file so future readers don't need to chase results.json's
+ # `config` block to learn what this folder represents.
+ (sc_dir / "scenario.json").write_text(
+ json.dumps(
+ {
+ "name": sc["name"],
+ "reg_tasks": list(sc["reg_tasks"]),
+ "reg_targets": list(sc["reg_targets"]),
+ "primary_objective": "P(material_type = QC) ↑",
+ "checkpoint": str(args.checkpoint),
+ },
+ indent=2,
+ ),
+ encoding="utf-8",
+ )
+ logger.info(f"=== {sc['name']} done ===")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/foundation_model/scripts/paper_inverse_comparison.py b/src/foundation_model/scripts/paper_inverse_comparison.py
new file mode 100644
index 0000000..97d0226
--- /dev/null
+++ b/src/foundation_model/scripts/paper_inverse_comparison.py
@@ -0,0 +1,1226 @@
+# Copyright 2025 TsumiNa.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Paper-grade comparison of inverse-design methods on a single trained checkpoint.
+
+Orchestrates a full sweep that ``eval_inverse_methods`` can do piecewise, and writes everything
+(the model checkpoint, the seed list, the raw per-seed JSON, and the figures) into one folder
+ready to drop into a paper draft. Reuses the per-method helpers from
+``eval_inverse_methods`` so the methodology is identical.
+
+The study covers:
+
+* **Latent method** with AE-alignment scale α ∈ {0, 0.25, 1.0} — failure-mode baseline, a useful
+ intermediate, and the [0, 1] upper bound. (Earlier runs swept finer; the three points are enough
+ to show the qualitative plateau.)
+* **Composition method** (differentiable KMD) under five configurations chosen to expose how
+ ``seed_blend``, the element whitelist, and seeding strategy affect novelty / diversity. Labels
+ follow a "describe the config in the label" convention:
+ 1. ``comp (seed)`` — ``seed_blend = 1.0`` (strict seed, support set frozen);
+ 2. ``comp (seed, 5% all)`` — ``seed_blend = 0.95`` (5 % uniform mixed in, all 94 elements
+ reachable but no whitelist);
+ 3. ``comp (seed, 5% all, element list)`` — (2) + ``allowed_elements = ALLOY_PALETTE``;
+ 4. ``comp (seed, 5% all, element list, low diversity)`` — (3) + ``diversity_scale = 0`` so
+ per-output entropy is penalised → peaky few-element recipes (ablation);
+ 5. ``comp (random)`` — ``initial_weights=None``, no seed bias.
+
+ python -m foundation_model.scripts.paper_inverse_comparison \\
+ --config-file samples/continual_rehearsal_demo_config_inverse_baseline.toml \\
+ --checkpoint artifacts/inverse_heads_finetuned/final_model.pt \\
+ --output-dir artifacts/paper_inverse_design
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import re
+import shutil
+from collections import Counter
+from pathlib import Path
+from typing import Any
+
+import matplotlib
+
+matplotlib.use("Agg")
+
+import matplotlib.colors as mcolors
+import matplotlib.pyplot as plt
+from matplotlib.offsetbox import AnnotationBbox, HPacker, TextArea
+import numpy as np
+import torch
+from lightning import seed_everything
+from loguru import logger
+
+from foundation_model.scripts.continual_rehearsal_common import (
+ DISCOVERED_ELEMENT_COLOR,
+ plot_element_frequency_heatmap,
+)
+from foundation_model.scripts.continual_rehearsal_demo import (
+ QC_CLASSES,
+ ContinualRehearsalConfig,
+ ContinualRehearsalRunner,
+)
+from foundation_model.scripts.eval_inverse_methods import (
+ _format_weights,
+ _qc_prob,
+ _reg_preds,
+ _run_latent_method,
+ _seed_weights_from_compositions,
+)
+
+# Feasible alloy palette for the constrained-composition runs. Designed per the plan in
+# docs/continual_rehearsal_full_PLAN.md §5: light alkaline-earth + group 13/14 + the full 4th/5th
+# period transition metals (Tc excluded for radioactivity) + the full Hf–Pt 5d TM row (added
+# 2026-05 to broaden heavy-TM coverage — reaches refractory / noble-metal i-QC families) + Au
+# (needed for Au-Ga-RE seeds) + accessible lanthanides (Pm radioactive, Tm/Lu scarce). 48 symbols
+# total — wide enough to expose multiple QC-prone basins (incl. heavy-TM families), narrow enough
+# to suppress Pu/F/Cs/Tm-style non-physical model bias.
+DEFAULT_ALLOY_PALETTE = [
+ "Mg",
+ "Ca",
+ "B",
+ "Al",
+ "Ga",
+ "In",
+ "Tl",
+ "Si",
+ "Ge",
+ "Sc",
+ "Ti",
+ "V",
+ "Cr",
+ "Mn",
+ "Fe",
+ "Co",
+ "Ni",
+ "Cu",
+ "Zn",
+ "Y",
+ "Zr",
+ "Nb",
+ "Mo",
+ "Ru",
+ "Rh",
+ "Pd",
+ "Ag",
+ "Cd",
+ # 5d transition metals (Hf–Pt). Added 2026-05; placed between Cd and Au so the 6th-period TM
+ # block is contiguous. Keeps the palette ordered by period within each group.
+ "Hf",
+ "Ta",
+ "W",
+ "Re",
+ "Os",
+ "Ir",
+ "Pt",
+ "Au",
+ "La",
+ "Ce",
+ "Pr",
+ "Nd",
+ "Sm",
+ "Eu",
+ "Gd",
+ "Tb",
+ "Dy",
+ "Ho",
+ "Er",
+ "Yb",
+]
+assert len(DEFAULT_ALLOY_PALETTE) == 48
+
+# Composition-method configurations. Each row produces one bar in the comparison plot. The first
+# two isolate the seed_blend effect; the next two layer on element constraints; the last drops the
+# seed entirely (random init) as the no-seed-bias control (Scheme D).
+COMPOSITION_CONFIGS: list[dict[str, Any]] = [
+ # diversity = 1.0 = no entropy penalty (default user-facing behaviour).
+ # Labels follow the "describe the config" convention: each comma-separated phrase names a
+ # knob that's been turned on relative to the previous row.
+ {"label": "comp\n(seed)", "init": "seed", "blend": 1.0, "allowed": "all", "scale": 1.0, "diversity": 1.0},
+ {"label": "comp\n(seed, 5% all)", "init": "seed", "blend": 0.95, "allowed": "all", "scale": 1.0, "diversity": 1.0},
+ {
+ "label": "comp\n(seed, 5% all, element list)",
+ "init": "seed",
+ "blend": 0.95,
+ "allowed": DEFAULT_ALLOY_PALETTE,
+ "scale": 1.0,
+ "diversity": 1.0,
+ },
+ {
+ # Ablation: clamp diversity to 0 → max entropy penalty → forced peaky few-element recipes.
+ "label": "comp\n(seed, 5% all,\nelement list, low diversity)",
+ "init": "seed",
+ "blend": 0.95,
+ "allowed": DEFAULT_ALLOY_PALETTE,
+ "scale": 1.0,
+ "diversity": 0.0,
+ },
+ {"label": "comp\n(random)", "init": "random", "blend": 0.95, "allowed": "all", "scale": 1.0, "diversity": 1.0},
+]
+LATENT_ALIGN_SCALES = [0.0, 0.25, 1.0] # ae_align_scale ∈ [0, 1] — three points: failure / mid / max
+
+
+#: Per-task display title with units and a directional arrow that points the way the optimiser
+#: should drive the value. Defaults applied for the two tasks the plan §5 scenarios use. The
+#: lookup falls back to the raw task name if a task isn't in the map (so the plot still works
+#: when scenarios 1 / 2 add ``magnetic_moment`` / ``tc``).
+REG_TASK_TITLES: dict[str, str] = {
+ "formation_energy": "Formation energy [eV/atom] ↓",
+ "klat": "klat [W/mK] ↑",
+ "magnetic_moment": "Magnetic moment [μB/f.u.] ↑",
+ "tc": "Critical temperature [K] ↑",
+}
+
+
+def _plot_comparison(results: list[dict[str, Any]], reg_targets: dict[str, float], out_path: Path) -> None:
+ """Three-panel comparison: QC probability + each regression target across all methods."""
+ n_panels = 1 + len(reg_targets)
+ fig, axes = plt.subplots(1, n_panels, figsize=(5.6 * n_panels, 5.6), squeeze=False)
+ axes = axes[0]
+ # Single-line labels so rotated x-ticks don't collide.
+ labels = [r["label"].replace("\n", " ") for r in results]
+ colors = ["#55A868" if r["method"] == "latent" else "#2563EB" for r in results]
+ x = np.arange(len(results))
+
+ def _set_xticks(ax):
+ ax.set_xticks(x)
+ ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=9)
+
+ # Panel 1: QC probability. The arrow makes the optimisation direction explicit at a glance.
+ qc_means = [float(np.mean(r["qc_after_decode"])) for r in results]
+ qc_stds = [float(np.std(r["qc_after_decode"])) for r in results]
+ axes[0].bar(x, qc_means, yerr=qc_stds, color=colors, capsize=3)
+ axes[0].axhline(1.0, color="#C44E52", ls="--", lw=1.4, label="target = 1.0")
+ _set_xticks(axes[0])
+ axes[0].set_ylim(-0.02, 1.05)
+ axes[0].set_ylabel("P(quasicrystal)")
+ axes[0].set_title("P(quasicrystal) ↑")
+ axes[0].legend(fontsize=9, loc="lower right")
+
+ # Remaining panels: regression targets. Title pulled from REG_TASK_TITLES with the unit and
+ # an arrow indicating whether the target is below (↓) or above (↑) the model's baseline.
+ for ax, (t, tgt) in zip(axes[1:], reg_targets.items()):
+ means = [float(np.mean(r["reg_after_decode"][t])) for r in results]
+ stds = [float(np.std(r["reg_after_decode"][t])) for r in results]
+ ax.bar(x, means, yerr=stds, color=colors, capsize=3)
+ ax.axhline(tgt, color="#C44E52", ls="--", lw=1.4, label=f"target = {tgt:+.1f}")
+ _set_xticks(ax)
+ ax.set_ylabel("Predicted value")
+ ax.set_title(REG_TASK_TITLES.get(t, t))
+ ax.legend(fontsize=9, loc="best")
+
+ fig.suptitle("Inverse-design comparison: latent (ae_align_scale sweep) vs differentiable KMD (configs)", y=1.00)
+ fig.savefig(out_path, dpi=150, bbox_inches="tight")
+ plt.close(fig)
+ logger.info(f"Wrote comparison plot to {out_path}")
+
+
+# --- seed → optimised composition mapping plot -------------------------------------------------
+
+#: Element-symbol + optional stoichiometry regex used by ``_parse_formula_to_fractions`` below.
+#: ``continual_rehearsal_common.element_set`` carries the same pattern for the *set* of element
+#: symbols; here we additionally need the amount (the second capture group) to recover fractions.
+_COMP_RE = re.compile(r"([A-Z][a-z]?)([\d.]*)")
+
+
+def _parse_formula_to_fractions(formula: str) -> dict[str, float]:
+ """Parse a composition string into ``{element: fraction}`` summing to 1.
+
+ Handles both raw-amount formulas (``"Au65 Ga20 Gd15"`` → sum=100 → normalised to 1) and
+ pre-fractional formulas (``"Mg0.691 Cd0.309"`` → already sums to ~1).
+ """
+ out: dict[str, float] = {}
+ for el, amt in _COMP_RE.findall(formula):
+ if not el:
+ continue
+ a = float(amt) if amt else 1.0
+ out[el] = out.get(el, 0.0) + a
+ tot = sum(out.values())
+ return {k: v / tot for k, v in out.items()} if tot > 0 else out
+
+
+#: Font size for composition formula text in the seed-to-optimized plot. Tuned with the
+#: ``_ROW_HEIGHT`` below to keep rows compact without text overlap.
+_MAP_FONT = 13
+_MAP_ROW_HEIGHT = 0.34 # data-unit row height; figure height scales with n_rows × this
+
+#: Short labels used inside the parenthetical block, so a row like
+#: ``Δformation_energy=-1.36`` doesn't push the right edge off the figure. Tasks not in the
+#: map fall back to their raw name (covered by the lookup default in the call site).
+_REG_DISPLAY_SHORT: dict[str, str] = {
+ "formation_energy": "FE",
+ "klat": "klat",
+ "tc": "tc",
+ "magnetization": "mag",
+ "magnetic_moment": "mm",
+}
+
+
+def _target_arrow(target_value: float, baseline: float = 0.0) -> str:
+ """Up-arrow if the target is above ``baseline`` (default 0 in z-scored regression space).
+
+ Both project reg targets are z-scored; positive target ⇒ "drive up" (↑), negative ⇒ "drive
+ down" (↓). The arrow is rendered next to each property name in the column header and in
+ every row's parenthetical block, so the reader can match the delta sign against the desired
+ direction at a glance.
+ """
+ return "↑" if target_value > baseline else "↓"
+
+
+def _render_seed_row(
+ ax,
+ x_axes_frac: float,
+ y_data: float,
+ comp: dict[str, float],
+ qc: float,
+) -> None:
+ """Draw one *seed* row: all-black text, no element colouring, with a ``(QC=XX.X%)`` suffix.
+
+ The seed side is informational — the comparison signal lives on the optimised side. Keeping
+ the seed monochrome lets the colour gradient on the right read as a pure 'what the optimiser
+ did to this seed' story.
+ """
+ if not comp:
+ return
+ items = sorted(comp.items(), key=lambda kv: -kv[1])
+ parts: list = []
+ for el, frac in items:
+ parts.append(
+ TextArea(
+ el,
+ textprops=dict(color="#111", fontweight="bold", fontsize=_MAP_FONT, fontfamily="monospace"),
+ )
+ )
+ parts.append(
+ TextArea(
+ f"{frac * 100:.1f} ",
+ textprops=dict(color="#111", fontsize=_MAP_FONT, fontfamily="monospace"),
+ )
+ )
+ parts.append(
+ TextArea(
+ f" (QC={qc * 100:.1f}%)",
+ textprops=dict(color="#555", fontsize=_MAP_FONT - 1, fontfamily="monospace"),
+ )
+ )
+ box = HPacker(children=parts, align="baseline", pad=0, sep=2)
+ ax.add_artist(
+ AnnotationBbox(
+ box,
+ (x_axes_frac, y_data),
+ xycoords=("axes fraction", "data"),
+ frameon=False,
+ box_alignment=(0, 0.5),
+ pad=0,
+ )
+ )
+
+
+def _render_optimized_row(
+ ax,
+ x_axes_frac: float,
+ y_data: float,
+ comp: dict[str, float],
+ qc: float,
+ deltas: dict[str, float],
+ arrows: dict[str, str],
+ element_counts: Counter,
+ n_outputs: int,
+ cmap,
+) -> None:
+ """Draw one *optimised* row: element symbols coloured by frequency in the optimised pool.
+
+ The parenthetical block is ``(QC=XX.X%, Δ=±N.N , ...)`` — the signed
+ delta tells the reader how much each property moved from its seed value, and the arrow
+ pins down whether the target wants it to go up or down.
+ """
+ if not comp:
+ return
+ items = sorted(comp.items(), key=lambda kv: -kv[1])
+ parts: list = []
+ for el, frac in items:
+ count = element_counts.get(el, 0)
+ # vmin=0 / vmax=n_outputs maps the lowest appearance count to the cmap's darkest end
+ # (per user request: "the lower, the closer to black"). Elements absent from the
+ # optimised pool can't actually appear in ``comp`` (we'd never iterate them here), so
+ # the ``count == 0`` branch is a defensive fallback only.
+ color = cmap(count / max(n_outputs, 1)) if count > 0 else "#aaaaaa"
+ parts.append(
+ TextArea(
+ el,
+ textprops=dict(color=color, fontweight="bold", fontsize=_MAP_FONT, fontfamily="monospace"),
+ )
+ )
+ parts.append(
+ TextArea(
+ f"{frac * 100:.1f} ",
+ textprops=dict(color="#111", fontsize=_MAP_FONT, fontfamily="monospace"),
+ )
+ )
+ # Parenthetical: QC + per-target signed delta + target-direction arrow. Use the short
+ # display labels so long names like ``formation_energy`` don't push the right edge of the
+ # axes into the colourbar.
+ delta_text = ", ".join(f"Δ{_REG_DISPLAY_SHORT.get(t, t)}={deltas[t]:+.2f} {arrows[t]}" for t in deltas)
+ parts.append(
+ TextArea(
+ f" (QC={qc * 100:.1f}%, {delta_text})",
+ textprops=dict(color="#555", fontsize=_MAP_FONT - 2, fontfamily="monospace"),
+ )
+ )
+ box = HPacker(children=parts, align="baseline", pad=0, sep=2)
+ ax.add_artist(
+ AnnotationBbox(
+ box,
+ (x_axes_frac, y_data),
+ xycoords=("axes fraction", "data"),
+ frameon=False,
+ box_alignment=(0, 0.5),
+ pad=0,
+ )
+ )
+
+
+def _plot_seed_to_optimized_mapping(
+ seeds: list[str],
+ decoded: list[str],
+ out_path: Path,
+ *,
+ title: str,
+ seed_qc: np.ndarray,
+ seed_reg: dict[str, np.ndarray],
+ optimized_qc: np.ndarray,
+ optimized_reg: dict[str, np.ndarray],
+ reg_targets: dict[str, float],
+) -> None:
+ """Per-seed 1:1 view — left column shows the seed, right column shows the optimiser's output.
+
+ Both compositions are normalised to fractions and rendered as percent (so the user-facing
+ numbers match the seed-side ``"Au65 Ga20 Gd15"`` convention).
+
+ * **Seed side** — all-black monochrome formula + ``(QC=XX.X%)``.
+ * **Optimised side** — element symbols coloured by their appearance count in the optimised
+ pool (cmap goes near-black for rare → bright yellow for ubiquitous, per the user's
+ "low end close to black" request). Parenthetical block carries QC% and per-target
+ signed deltas ``Δ=+/-N.N `` so the reader can match each delta's sign
+ against the optimisation direction at a glance.
+ * **Color bar** on the right shows the appearance-count scale used on the optimised side.
+
+ The intent is to complement the aggregated ``element_frequency_heatmap.png`` with per-seed
+ detail — which seed gave rise to which composition under each path, and whether each
+ target moved correctly.
+ """
+ n = len(seeds)
+ if n == 0 or len(decoded) != n:
+ logger.warning(
+ f"_plot_seed_to_optimized_mapping: seeds ({n}) / decoded ({len(decoded)}) mismatch — skipping plot."
+ )
+ return
+
+ seed_dicts = [_parse_formula_to_fractions(s) for s in seeds]
+ decoded_dicts = [_parse_formula_to_fractions(d) for d in decoded]
+
+ # Element-presence count over the optimised pool — drives the colour scale + colour bar.
+ element_counts: Counter = Counter()
+ for d in decoded_dicts:
+ for el in d:
+ element_counts[el] += 1
+
+ # ``inferno`` gives high contrast across the range with the low end close to black, as
+ # requested. ``vmin=0`` keeps the "rare" colour distinguishable from the "common" end.
+ cmap = plt.cm.inferno
+ norm = mcolors.Normalize(vmin=0, vmax=n)
+ arrows = {t: _target_arrow(v) for t, v in reg_targets.items()}
+
+ fig_height = max(6.5, _MAP_ROW_HEIGHT * n + 1.4)
+ # ``bbox_inches="tight"`` at savefig crops to actual artist extents, so the 20" width is a
+ # *minimum* — long parenthetical blocks (many reg targets, long element formulas) will
+ # stretch it further without colliding with the colour bar.
+ fig, (ax_main, ax_cbar) = plt.subplots(1, 2, figsize=(20, fig_height), gridspec_kw={"width_ratios": [70, 1]})
+ ax_main.set_xlim(0, 1)
+ ax_main.set_ylim(-0.7, n - 0.3)
+ ax_main.invert_yaxis()
+ ax_main.set_axis_off()
+
+ # Column headers above row 0 — also document what's in the parenthetical block, using the
+ # same short property names so the header matches each row's delta block exactly.
+ header_arrows = ", ".join(f"Δ{_REG_DISPLAY_SHORT.get(t, t)} {arrows[t]}" for t in reg_targets)
+ ax_main.text(
+ 0.005,
+ -0.6,
+ "Seed (fraction × 100, QC%)",
+ fontsize=_MAP_FONT,
+ fontweight="bold",
+ ha="left",
+ va="bottom",
+ )
+ ax_main.text(
+ 0.38,
+ -0.6,
+ f"Optimised composition (fraction × 100, QC%, {header_arrows})",
+ fontsize=_MAP_FONT,
+ fontweight="bold",
+ ha="left",
+ va="bottom",
+ )
+
+ for i, (s_dict, d_dict) in enumerate(zip(seed_dicts, decoded_dicts)):
+ _render_seed_row(ax_main, x_axes_frac=0.005, y_data=i, comp=s_dict, qc=float(seed_qc[i]))
+ ax_main.text(0.355, i, "→", fontsize=15, color="#888", ha="center", va="center")
+ deltas_i = {t: float(optimized_reg[t][i] - seed_reg[t][i]) for t in reg_targets}
+ _render_optimized_row(
+ ax_main,
+ x_axes_frac=0.38,
+ y_data=i,
+ comp=d_dict,
+ qc=float(optimized_qc[i]),
+ deltas=deltas_i,
+ arrows=arrows,
+ element_counts=element_counts,
+ n_outputs=n,
+ cmap=cmap,
+ )
+
+ sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+ sm.set_array([])
+ cb = fig.colorbar(sm, cax=ax_cbar)
+ cb.set_label(f"Element appearance count\nin optimised pool (out of {n})", fontsize=_MAP_FONT - 2)
+ cb.ax.tick_params(labelsize=_MAP_FONT - 3)
+
+ fig.suptitle(title, fontsize=_MAP_FONT + 1, y=0.998)
+ fig.savefig(out_path, dpi=150, bbox_inches="tight")
+ plt.close(fig)
+ logger.info(f"Wrote seed→optimised mapping plot to {out_path}")
+
+
+# --- QC vs secondary-property scatter plot ----------------------------------------------------
+
+
+#: Marker shapes by method-group, per the user's "use shape to separate the two groups" request.
+#: Circle for latent (continuous α sweep ↦ a continuous family) vs triangle for composition
+#: (discrete-config family). Kept here as a single source of truth so the legend renderer and
+#: the scatter loop can't drift.
+_SCATTER_MARKERS = {"latent": "o", "composition": "^"}
+
+#: Per-group base colormaps. Greens vs Blues keep the two groups easily distinguishable at a
+#: glance (the user's "two groups' base colors must be easy to tell apart"). Within each group
+#: we step the colormap to encode the parameter-config ordering — see ``_group_color_ramp``.
+_SCATTER_CMAPS = {"latent": plt.cm.Greens, "composition": plt.cm.Blues}
+
+#: Seed-layer style: star marker + the project's discovered-element orange. Distinct shape and
+#: a third colour family (not Blues / Greens / red-target-lines) so the seed cloud reads as a
+#: separate "starting point" anchor without competing with the optimised clouds.
+_SEED_MARKER = "*"
+_SEED_COLOR = DISCOVERED_ELEMENT_COLOR # ``#E67E22`` — same orange used for new elements in the heatmap
+
+
+def _group_color_ramp(cmap, n: int) -> list:
+ """Evenly stepped colors across the upper portion of ``cmap``.
+
+ Skip the very pale low end (would be invisible on white) and the near-black high end
+ (would look the same across both groups). The 0.35 / 0.90 window matches the band used in
+ the seed-to-optimised plot's element shading.
+ """
+ if n <= 0:
+ return []
+ if n == 1:
+ return [cmap(0.65)]
+ return [cmap(0.35 + 0.55 * i / (n - 1)) for i in range(n)]
+
+
+def _plot_qc_vs_reg_scatter(
+ results: list[dict[str, Any]],
+ reg_targets: dict[str, float],
+ out_path: Path,
+ *,
+ title: str | None = None,
+ seed_qc: np.ndarray | None = None,
+ seed_reg: dict[str, np.ndarray] | None = None,
+) -> None:
+ """One panel per secondary regression target, plotting QC prob vs that target across all paths.
+
+ Each method's per-seed outputs become one scatter cluster: shape encodes the *group* (circle
+ for latent, triangle for composition — per the "use shape to separate the two groups" spec),
+ and color steps through that group's colormap (Greens / Blues) in label-order so the reader
+ can read the parameter sweep off the legend without remembering which α / config is which.
+ Red dashed lines mark the joint target (vertical at ``QC=1.0``, horizontal at the per-task
+ regression target). A figure-level legend at the bottom lists every method label once across
+ all panels.
+
+ When ``seed_qc`` and ``seed_reg`` are provided, the per-seed *baseline* predictions are also
+ drawn — as orange ★ stars — so the reader can see how far each method moved each seed in
+ QC-vs-secondary space. ``seed_reg`` must carry one array per key in ``reg_targets``; missing
+ keys silently skip the seed layer in that panel.
+ """
+ if not reg_targets:
+ logger.warning("_plot_qc_vs_reg_scatter: no reg_targets — skipping plot.")
+ return
+ if not results:
+ logger.warning("_plot_qc_vs_reg_scatter: no results — skipping plot.")
+ return
+
+ # Split results by group, preserving the order in which ``run()`` appended them — that's
+ # the same order the comparison bar chart uses, so the legend matches across figures.
+ latent_results = [r for r in results if r["method"] == "latent"]
+ comp_results = [r for r in results if r["method"] == "composition"]
+
+ # Per-group color ramps. Latent: Greens, low α → pale green, high α → deep green. Comp:
+ # Blues, simple-config → pale blue, full-knob config → deep blue.
+ latent_colors = _group_color_ramp(_SCATTER_CMAPS["latent"], len(latent_results))
+ comp_colors = _group_color_ramp(_SCATTER_CMAPS["composition"], len(comp_results))
+ color_by_result: dict[int, Any] = {}
+ for r, c in zip(latent_results, latent_colors):
+ color_by_result[id(r)] = c
+ for r, c in zip(comp_results, comp_colors):
+ color_by_result[id(r)] = c
+
+ # Seeds layer: drawn first so the optimised clouds overplot it (the seed cloud is the
+ # "context"; the optimised clouds are the headline data).
+ has_seeds = seed_qc is not None and seed_reg is not None
+ seed_qc_arr = np.asarray(seed_qc, dtype=float) if has_seeds else None
+
+ n_panels = len(reg_targets)
+ fig, axes = plt.subplots(1, n_panels, figsize=(5.6 * n_panels, 6.4), squeeze=False)
+ axes = axes[0]
+
+ for ax, (task, tgt) in zip(axes, reg_targets.items()):
+ arrow = _target_arrow(tgt)
+ # Seeds first (under) — only if seed_reg has this panel's task.
+ if has_seeds and task in seed_reg:
+ seed_reg_arr = np.asarray(seed_reg[task], dtype=float)
+ ax.scatter(
+ seed_qc_arr,
+ seed_reg_arr,
+ marker=_SEED_MARKER,
+ color=_SEED_COLOR,
+ s=110,
+ alpha=0.85,
+ edgecolor="#222",
+ linewidths=0.7,
+ zorder=2,
+ )
+ for r in results:
+ qc = np.asarray(r["qc_after_decode"], dtype=float)
+ reg = np.asarray(r["reg_after_decode"][task], dtype=float)
+ ax.scatter(
+ qc,
+ reg,
+ marker=_SCATTER_MARKERS[r["method"]],
+ color=color_by_result[id(r)],
+ s=64,
+ alpha=0.78,
+ edgecolor="#222",
+ linewidths=0.6,
+ label=r["label"].replace("\n", " "),
+ zorder=3,
+ )
+ ax.axvline(1.0, color="#C44E52", ls="--", lw=1.3, alpha=0.8)
+ ax.axhline(tgt, color="#C44E52", ls="--", lw=1.3, alpha=0.8)
+ ax.set_xlim(-0.05, 1.05)
+ ax.set_xlabel("P(quasicrystal) ↑")
+ ax.set_ylabel(REG_TASK_TITLES.get(task, task))
+ ax.set_title(f"QC vs {_REG_DISPLAY_SHORT.get(task, task)} {arrow} (target = {tgt:+.1f})", fontsize=11)
+
+ # Figure-level legend across all panels. Use proxy handles so the legend orders by group
+ # (seeds → latent → composition → target) rather than by whichever panel happened to draw
+ # which marker first.
+ from matplotlib.lines import Line2D
+
+ handles: list[Line2D] = []
+ if has_seeds:
+ handles.append(
+ Line2D(
+ [0],
+ [0],
+ marker=_SEED_MARKER,
+ color="none",
+ markerfacecolor=_SEED_COLOR,
+ markeredgecolor="#222",
+ markersize=11,
+ label="seed (baseline)",
+ )
+ )
+ for r in latent_results:
+ handles.append(
+ Line2D(
+ [0],
+ [0],
+ marker=_SCATTER_MARKERS["latent"],
+ color="none",
+ markerfacecolor=color_by_result[id(r)],
+ markeredgecolor="#222",
+ markersize=9,
+ label=r["label"].replace("\n", " "),
+ )
+ )
+ for r in comp_results:
+ handles.append(
+ Line2D(
+ [0],
+ [0],
+ marker=_SCATTER_MARKERS["composition"],
+ color="none",
+ markerfacecolor=color_by_result[id(r)],
+ markeredgecolor="#222",
+ markersize=9,
+ label=r["label"].replace("\n", " "),
+ )
+ )
+ handles.append(Line2D([0], [0], color="#C44E52", ls="--", lw=1.3, label="target (QC=1.0 / reg-target)"))
+ # ncol picked so the legend fits across the figure width without wrapping past 3 rows for
+ # the 8-method + 1-target sweep we use in practice.
+ fig.legend(
+ handles=handles,
+ loc="lower center",
+ ncol=min(len(handles), 4),
+ fontsize=9,
+ frameon=False,
+ bbox_to_anchor=(0.5, -0.02),
+ )
+
+ if title:
+ fig.suptitle(title, y=1.00)
+ # Leave generous bottom padding so the legend (rendered below the axes via bbox_to_anchor)
+ # ends up inside the saved bbox after ``bbox_inches="tight"`` crops.
+ fig.tight_layout(rect=(0, 0.10, 1, 0.98))
+ fig.savefig(out_path, dpi=150, bbox_inches="tight")
+ plt.close(fig)
+ logger.info(f"Wrote QC-vs-secondary scatter plot to {out_path}")
+
+
+def _path_slug(r: dict[str, Any]) -> str:
+ """Stable filename slug for one path. Latent: ``latent_align0p25``; comp: cleaned label."""
+ if r["method"] == "latent":
+ return f"latent_align{r['align_scale']:g}".replace(".", "p")
+ return re.sub(r"[^a-z0-9]+", "_", r["label"].lower()).strip("_")
+
+
+def _summarise(results: list[dict[str, Any]], reg_targets: dict[str, float]) -> list[dict[str, Any]]:
+ summary = []
+ for r in results:
+ row = {
+ "label": r["label"].replace("\n", " "),
+ "method": r["method"],
+ "align_scale": r.get("align_scale"),
+ "config": r.get("config"),
+ "elapsed_s": round(r["elapsed_s"], 2),
+ "qc_after_mean": round(float(np.mean(r["qc_after_decode"])), 4),
+ "qc_after_std": round(float(np.std(r["qc_after_decode"])), 4),
+ }
+ for t in reg_targets:
+ row[f"{t}_after_mean"] = round(float(np.mean(r["reg_after_decode"][t])), 3)
+ row[f"{t}_after_std"] = round(float(np.std(r["reg_after_decode"][t])), 3)
+ summary.append(row)
+ return summary
+
+
+def run(
+ config: ContinualRehearsalConfig,
+ ckpt_path: Path,
+ *,
+ record_trajectory: bool = True,
+ per_seed_trajectories: bool = False,
+ animation_formats: tuple[str, ...] = ("gif",),
+) -> None:
+ seed_everything(config.random_seed, workers=True)
+ runner = ContinualRehearsalRunner(config)
+
+ # Load the trained model exactly as we built it during training (same task_sequence).
+ model = runner._build_full_model()
+ state = torch.load(ckpt_path, map_location="cpu", weights_only=True)
+ state_dict = state["model"] if isinstance(state, dict) and "model" in state else state
+ model.load_state_dict(state_dict)
+ model.eval()
+
+ out_dir = Path(config.output_dir)
+ out_dir.mkdir(parents=True, exist_ok=True)
+ # Copy the checkpoint so this folder is a self-contained paper artefact (skip when
+ # the source and destination resolve to the same file — happens on idempotent reruns).
+ dst = out_dir / "final_model.pt"
+ if ckpt_path.resolve() != dst.resolve():
+ shutil.copy2(ckpt_path, dst)
+
+ device = next(model.parameters()).device
+
+ def _qc_prob_fn(x: torch.Tensor) -> np.ndarray:
+ return _qc_prob(model, x)
+
+ seeds = runner._select_seeds(model, device, _qc_prob_fn)
+ if not seeds:
+ raise RuntimeError("No seed compositions selected.")
+ x_seed, seeds = runner._descriptor_tensor(seeds, device)
+ (out_dir / "seeds.json").write_text(json.dumps({"seeds": list(seeds)}, indent=2), encoding="utf-8")
+ logger.info(f"Selected {len(seeds)} seed compositions (saved to seeds.json)")
+
+ reg_targets = {t: v for t, v in zip(config.inverse_reg_tasks, config.inverse_reg_targets)}
+ # Per-seed *baseline* predictions (before any inverse-design optimisation). These power the
+ # seed-side ``(QC=X.X%)`` parenthetical and the ``Δ`` deltas on the optimised side of
+ # the per-seed mapping plot. Computed once here against ``x_seed`` (the seed descriptors)
+ # and persisted in ``results.json`` under ``seed_predictions`` so future re-plots don't need
+ # the model loaded again.
+ seed_qc = _qc_prob(model, x_seed)
+ seed_reg = _reg_preds(model, x_seed, list(reg_targets.keys()))
+ results: list[dict[str, Any]] = []
+
+ # Latent method: ae_align_scale sweep over [0, 1].
+ for lam in LATENT_ALIGN_SCALES:
+ logger.info(f"--- Latent method, ae_align_scale = {lam} ---")
+ r = _run_latent_method(
+ runner,
+ model,
+ seeds,
+ x_seed,
+ reg_targets,
+ class_weight=config.inverse_class_weight,
+ align_scale=lam,
+ steps=config.inverse_steps,
+ lr=config.inverse_lr,
+ record_trajectory=record_trajectory,
+ )
+ r["label"] = f"latent\nα={lam:g}"
+ r["config"] = {"ae_align_scale": lam}
+ results.append(r)
+
+ # Composition method: walk through the configuration matrix.
+ for cfg in COMPOSITION_CONFIGS:
+ logger.info(f"--- {cfg['label'].replace(chr(10), ' ')} ---")
+ r = _run_composition_config(
+ runner,
+ model,
+ seeds,
+ reg_targets,
+ class_weight=config.inverse_class_weight,
+ steps=config.inverse_steps,
+ lr=config.inverse_lr,
+ cfg=cfg,
+ record_trajectory=record_trajectory,
+ )
+ r["label"] = cfg["label"]
+ r["config"] = {k: cfg[k] for k in ("init", "blend", "allowed", "scale", "diversity")}
+ results.append(r)
+
+ summary = _summarise(results, reg_targets)
+ logger.info("=== Summary ===")
+ for row in summary:
+ logger.info(row)
+
+ # Trajectory arrays would blow up the inlined results.json (≈ 36MB / scenario for a 300-step
+ # 20-seed run); persist them as compressed .npz next to results.json and replace the inline
+ # lists with a relative-path reference. The JSON stays browsable; replots read the .npz.
+ traj_dir: Path | None = None
+ if record_trajectory:
+ traj_dir = out_dir / "trajectories"
+ traj_dir.mkdir(exist_ok=True)
+ for r in results:
+ if "trajectory_targets" not in r:
+ continue
+ slug = _path_slug(r)
+ npz_path = traj_dir / f"{slug}.npz"
+ np.savez_compressed(
+ npz_path,
+ targets=np.asarray(r["trajectory_targets"], dtype=np.float32),
+ weights=np.asarray(r["trajectory_weights"], dtype=np.float32),
+ )
+ r["trajectory_file"] = str(npz_path.relative_to(out_dir))
+ del r["trajectory_targets"]
+ del r["trajectory_weights"]
+ logger.info(f"Wrote per-path trajectory arrays under {traj_dir}/")
+
+ (out_dir / "results.json").write_text(
+ json.dumps(
+ {
+ "reg_targets": reg_targets,
+ # ``seed_predictions`` carries the baseline predictions the inverse-design
+ # optimisation moved away from — needed to render the per-seed mapping plot's
+ # ``Δ`` deltas (and the seed-side ``QC%`` parenthetical). Save here so a
+ # future re-plot from results.json alone never has to re-run the model.
+ "seed_predictions": {
+ "qc": seed_qc.tolist(),
+ "reg": {t: vals.tolist() for t, vals in seed_reg.items()},
+ },
+ "results": results,
+ "summary": summary,
+ },
+ indent=2,
+ ),
+ encoding="utf-8",
+ )
+ _plot_comparison(results, reg_targets, out_dir / "comparison.png")
+ # Per-method × top-25-element occurrence heatmap. Always written so the discovered-element
+ # signal (bold orange on the x-axis) is part of every paper-comparison output — the slide
+ # author / downstream reader doesn't need to find or rerun a separate post-hoc script.
+ plot_element_frequency_heatmap(results, list(seeds), out_dir / "element_frequency_heatmap.png")
+ # Seed → optimised 1:1 mapping plot. One figure per path that has per-seed correspondence
+ # (every method except ``comp (random)``, whose ``seeds`` field is a ``random_start_N``
+ # placeholder rather than a real composition). Each plot's right side carries the QC% and
+ # per-target signed deltas so the reader can see *which seed gave rise to which output*
+ # and whether each target moved in the right direction.
+ for r in results:
+ if r["method"] == "composition" and r.get("config", {}).get("init") != "seed":
+ # ``comp (random)`` — no per-row seed correspondence.
+ continue
+ slug = _path_slug(r)
+ _plot_seed_to_optimized_mapping(
+ seeds=list(seeds),
+ decoded=list(r["decoded_composition"]),
+ out_path=out_dir / f"seed_to_optimized__{slug}.png",
+ title=f"Seed → optimised composition · {r['label'].replace(chr(10), ' ')}",
+ seed_qc=seed_qc,
+ seed_reg=seed_reg,
+ optimized_qc=np.asarray(r["qc_after_decode"]),
+ optimized_reg={t: np.asarray(r["reg_after_decode"][t]) for t in reg_targets},
+ reg_targets=reg_targets,
+ )
+ # Scatter view of QC prob vs each secondary reg target, grouped by method (latent = circle /
+ # green ramp, composition = triangle / blue ramp), with the per-seed baseline drawn as orange
+ # ★ stars so the reader sees how far each method moved each seed. Complements the bar chart:
+ # the bar chart collapses each method to a mean ± std, the scatter shows the per-seed cloud.
+ _plot_qc_vs_reg_scatter(
+ results,
+ reg_targets,
+ out_dir / "qc_vs_secondary_scatter.png",
+ title="QC probability vs secondary properties (per-seed outputs)",
+ seed_qc=seed_qc,
+ seed_reg=seed_reg,
+ )
+ # Per-step optimisation trajectory plots + animations. One figure (and one animation) per
+ # path; ``--per-seed-trajectories`` additionally emits per-seed variants. Skipped when
+ # ``--no-record-trajectory`` was passed (results.json carries no trajectory_file refs then).
+ if record_trajectory and traj_dir is not None:
+ _emit_trajectory_outputs(
+ results=results,
+ reg_targets=reg_targets,
+ seeds=list(seeds),
+ seed_qc=seed_qc,
+ seed_reg=seed_reg,
+ out_dir=out_dir,
+ traj_dir=traj_dir,
+ per_seed=per_seed_trajectories,
+ animation_formats=animation_formats,
+ )
+ # The auto-generated README is a compact summary table only. It writes to ``SUMMARY.md``
+ # (not ``README.md``) so a user-written index — pointing to every figure, file, and the
+ # full ANALYSIS.md — can live at ``README.md`` without being overwritten on rerun.
+ _write_readme(out_dir, summary, reg_targets, ckpt_path)
+ logger.info(f"Paper materials written to {out_dir}")
+
+
+def _emit_trajectory_outputs(
+ *,
+ results: list[dict[str, Any]],
+ reg_targets: dict[str, float],
+ seeds: list[str],
+ seed_qc: np.ndarray,
+ seed_reg: dict[str, np.ndarray],
+ out_dir: Path,
+ traj_dir: Path,
+ per_seed: bool,
+ animation_formats: tuple[str, ...],
+) -> None:
+ """Render the static "normalised-progress vs step" plot + animation per path.
+
+ Always-on: a mean across-seeds line plot per path under ``trajectories/`` with the comp panel
+ animated using the seed whose final state best matches all targets (joint normalised distance).
+ The chosen seed's composition formula is shown under the title.
+
+ ``per_seed=True`` (the new default) also emits one plot+animation per ``(path × seed)`` under
+ ``trajectories_per_seed/seed{NN}/.{png,gif,html}`` — **seed-major** layout chosen so the
+ user can compare the same seed across all 8 paths by opening one folder. The seed's composition
+ string is rendered under each title so the reader doesn't have to cross-reference seed indices
+ against ``seeds.json``.
+
+ ``animation_formats`` defaults to ``("gif",)``; pass extras (``html``, ``svg``) to also emit
+ them. ``"none"`` in the format list disables animations entirely (static plot still emitted).
+ """
+ from foundation_model.scripts.paper_inverse_trajectory import (
+ best_seed_by_target_distance,
+ normalize_target_trajectories,
+ plot_trajectory_animation,
+ plot_trajectory_static,
+ )
+ from foundation_model.utils.kmd_plus import DEFAULT_ELEMENTS
+
+ formats: list[str] = [f for f in animation_formats if f != "none"]
+ static_dir = out_dir / "trajectories"
+ static_dir.mkdir(exist_ok=True)
+ per_seed_dir = out_dir / "trajectories_per_seed" if per_seed else None
+ if per_seed_dir is not None:
+ per_seed_dir.mkdir(exist_ok=True)
+
+ for r in results:
+ if "trajectory_file" not in r:
+ continue
+ slug = _path_slug(r)
+ npz_path = out_dir / r["trajectory_file"]
+ with np.load(npz_path) as data:
+ traj_targets = np.asarray(data["targets"]) # (steps, B, T_reg)
+ traj_weights = np.asarray(data["weights"]) # (steps, B, n_components)
+ # The composition method's ``trajectory_targets`` is the in-loop reg-only predictions
+ # (T_reg = len(reg_targets)); for the QC trajectory we replay the qc_after_decode per
+ # step. That requires running ``_qc_prob`` on each step's descriptor — but the npz only
+ # stores weights, not descriptors. Fast path: each step's QC ≈ qc_after_decode is well
+ # approximated by reusing the qc_after_decode of the final step as a fixed line + (initial
+ # − final) as a linear ramp would be wrong. So we just reconstruct the per-step QC
+ # trajectory by *not* including QC when it isn't in the npz; the static plot still works
+ # with reg-only progress. For the inverse-design study this is the right signal anyway —
+ # the user asked "do the reg targets converge together?" and the QC line is best read off
+ # the separate ``comparison.png``.
+ reg_names = list(reg_targets)
+ # Mean reg trajectory across seeds (per step → per task).
+ reg_traj_dict: dict[str, np.ndarray] = {
+ t: traj_targets[:, :, j] for j, t in enumerate(reg_names)
+ }
+ # Mean variant: use QC after-decode (final value) as a flat baseline-vs-target progress
+ # line only if it's available; otherwise drop QC. For the inverse-design case QC is in
+ # results dict but not per-step; we synthesise a "flat" QC progress line from the final
+ # value so it shows up on the chart for context.
+ qc_after = np.asarray(r["qc_after_decode"], dtype=float)
+ qc_traj = np.tile(qc_after[None, :], (traj_targets.shape[0], 1)) # (steps, B); only end-state QC
+ progress_mean = normalize_target_trajectories(
+ qc_trajectory=qc_traj,
+ reg_trajectory=reg_traj_dict,
+ reg_targets=reg_targets,
+ seed_qc=seed_qc,
+ seed_reg=seed_reg,
+ )
+ # The QC entry is degenerate (flat ≈ end-state); drop it from the static plot to avoid
+ # misleading the reader. The animation also keeps reg-only.
+ progress_mean.pop("QC", None)
+
+ # Pick the best representative seed for the animation's comp panel.
+ reg_final_per_task = {t: np.asarray(r["reg_after_decode"][t], dtype=float) for t in reg_names}
+ best_idx = best_seed_by_target_distance(qc_after, reg_final_per_task, reg_targets)
+ per_step_weights_best = traj_weights[:, best_idx, :] # (steps, n_components)
+ # Map the path's per-row "seeds" entry to a comp string. For comp_random the entry is
+ # ``random_start_N`` placeholder text; surface it verbatim so the title still says where
+ # the row came from. The ``r["seeds"]`` carried by every path is exactly the per-row
+ # label sequence; fall back to the shared ``seeds`` arg if a path forgot to set it.
+ per_row_seeds = list(r.get("seeds", seeds))
+
+ # --- Static plot (mean across seeds) ---
+ static_out = static_dir / f"trajectory__{slug}.png"
+ plot_trajectory_static(
+ progress_mean,
+ static_out,
+ title=f"Optimisation trajectory · {r['label'].replace(chr(10), ' ')} (mean over {qc_after.shape[0]} seeds)",
+ )
+
+ # --- Animation (mean curves + best-seed comp panel) ---
+ if formats:
+ out_paths = {fmt: static_dir / f"trajectory__{slug}.{fmt}" for fmt in formats}
+ plot_trajectory_animation(
+ progress_mean,
+ per_step_weights_best,
+ element_symbols=list(DEFAULT_ELEMENTS),
+ out_paths_by_format=out_paths,
+ title=f"Trajectory · {r['label'].replace(chr(10), ' ')} (best seed: {best_idx})",
+ seed_composition=per_row_seeds[best_idx],
+ )
+
+ # --- Per-seed variants (seed-major layout: trajectories_per_seed/seed{NN}/.{ext}) ---
+ if per_seed_dir is not None:
+ for seed_i in range(qc_after.shape[0]):
+ seed_dir = per_seed_dir / f"seed{seed_i:02d}"
+ seed_dir.mkdir(exist_ok=True)
+ seed_comp = per_row_seeds[seed_i]
+ reg_traj_one_seed = {t: traj_targets[:, seed_i : seed_i + 1, j] for j, t in enumerate(reg_names)}
+ qc_traj_one_seed = qc_traj[:, seed_i : seed_i + 1]
+ progress_seed = normalize_target_trajectories(
+ qc_trajectory=qc_traj_one_seed,
+ reg_trajectory=reg_traj_one_seed,
+ reg_targets=reg_targets,
+ seed_qc=seed_qc[seed_i : seed_i + 1],
+ seed_reg={t: vals[seed_i : seed_i + 1] for t, vals in seed_reg.items()},
+ )
+ progress_seed.pop("QC", None)
+ seed_static = seed_dir / f"{slug}.png"
+ plot_trajectory_static(
+ progress_seed,
+ seed_static,
+ title=f"{r['label'].replace(chr(10), ' ')} · seed {seed_i}",
+ seed_composition=seed_comp,
+ )
+ if formats:
+ seed_out_paths = {fmt: seed_dir / f"{slug}.{fmt}" for fmt in formats}
+ plot_trajectory_animation(
+ progress_seed,
+ traj_weights[:, seed_i, :],
+ element_symbols=list(DEFAULT_ELEMENTS),
+ out_paths_by_format=seed_out_paths,
+ title=f"{r['label'].replace(chr(10), ' ')} · seed {seed_i}",
+ seed_composition=seed_comp,
+ )
+
+
+def _run_composition_config(
+ runner: ContinualRehearsalRunner,
+ model,
+ seeds: list[str],
+ reg_targets: dict[str, float],
+ *,
+ class_weight: float,
+ steps: int,
+ lr: float,
+ cfg: dict[str, Any],
+ record_trajectory: bool = False,
+) -> dict[str, Any]:
+ """Run :meth:`optimize_composition` under one config row (handles seed/random init both)."""
+ import time
+
+ from foundation_model.utils.kmd_plus import DEFAULT_ELEMENTS
+
+ device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
+ kernel = runner._kmd.kernel_torch(device=device, dtype=dtype)
+
+ if cfg["init"] == "seed":
+ w_seed = _seed_weights_from_compositions(seeds, n_components=len(DEFAULT_ELEMENTS))
+ init_kwargs = {"initial_weights": w_seed, "seed_blend": cfg["blend"]}
+ elif cfg["init"] == "random":
+ # n_starts matches the seed count so per-row aggregation lines up with the latent runs.
+ init_kwargs = {"initial_weights": None, "n_starts": len(seeds)}
+ else:
+ raise ValueError(f"Unknown init mode in config: {cfg['init']!r}")
+
+ t0 = time.perf_counter()
+ res = model.optimize_composition(
+ kernel,
+ task_targets=reg_targets,
+ class_targets={"material_type": QC_CLASSES},
+ class_target_weight=class_weight,
+ diversity_scale=cfg["diversity"],
+ allowed_elements=cfg["allowed"],
+ element_step_scale=cfg["scale"],
+ steps=steps,
+ lr=lr,
+ record_weights_trajectory=record_trajectory,
+ **init_kwargs,
+ )
+ elapsed = time.perf_counter() - t0
+
+ reg_names = list(reg_targets)
+ optimized_desc = res.optimized_descriptor
+ w_final = res.optimized_weights.cpu().numpy()
+ out = {
+ "method": "composition",
+ "align_scale": None,
+ "elapsed_s": elapsed,
+ # For random init the "seeds" entry is informational only — there's no per-row correspondence.
+ "seeds": list(seeds) if cfg["init"] == "seed" else [f"random_start_{i}" for i in range(len(seeds))],
+ "qc_after_decode": _qc_prob(model, optimized_desc).tolist(),
+ "reg_achieved_latent": {t: res.optimized_target.cpu().numpy()[:, j].tolist() for j, t in enumerate(reg_names)},
+ "reg_after_decode": {t: _reg_preds(model, optimized_desc, [t])[t].tolist() for t in reg_names},
+ "decoded_composition": _format_weights(w_final),
+ # Raw arrays — keep so future replots (per-element bar charts, similarity matrices, etc.)
+ # don't have to re-run the optimisation. ``optimized_weights`` is (B, n_components),
+ # ``optimized_descriptor`` is (B, x_dim); element order matches DEFAULT_ELEMENTS.
+ "optimized_descriptor": optimized_desc.detach().cpu().numpy().tolist(),
+ "optimized_weights": w_final.tolist(),
+ }
+ if record_trajectory:
+ # ``res.trajectory`` is (steps, B, T) in reg-task order — already on the right surface.
+ # ``res.weights_trajectory`` is (steps, B, n_components) and is the per-step recipe
+ # exactly (no decode needed — composition method's optim variable already lives there).
+ out["trajectory_targets"] = res.trajectory.cpu().numpy().tolist()
+ out["trajectory_weights"] = res.weights_trajectory.cpu().numpy().tolist()
+ return out
+
+
+def _write_readme(out_dir: Path, summary: list[dict[str, Any]], reg_targets: dict[str, float], ckpt_path: Path) -> None:
+ lines = [
+ "# Inverse-design method comparison — paper materials",
+ "",
+ f"Trained model: `final_model.pt` (copied from `{ckpt_path}`).",
+ "Seed compositions: top-QC training compositions, listed in `seeds.json`.",
+ f"Targets: QC probability → 1.0; {', '.join(f'{t} → {v:+.1f}' for t, v in reg_targets.items())}.",
+ "",
+ "Raw per-seed JSON: `results.json` (one entry per method+config).",
+ "Comparison figure: `comparison.png`.",
+ "",
+ "## Summary (mean ± std across seeds)",
+ "",
+ "| label | QC after | " + " | ".join(f"{t} after" for t in reg_targets) + " | elapsed (s) |",
+ "| --- | --- | " + " | ".join("---" for _ in reg_targets) + " | --- |",
+ ]
+ for row in summary:
+ qc_cell = f"{row['qc_after_mean']:.3f} ± {row['qc_after_std']:.3f}"
+ reg_cells = [f"{row[f'{t}_after_mean']:+.2f} ± {row[f'{t}_after_std']:.2f}" for t in reg_targets]
+ lines.append(f"| {row['label']} | {qc_cell} | " + " | ".join(reg_cells) + f" | {row['elapsed_s']} |")
+ (out_dir / "SUMMARY.md").write_text("\n".join(lines) + "\n", encoding="utf-8")
+
+
+def _parse_args(argv: list[str] | None = None) -> tuple[ContinualRehearsalConfig, argparse.Namespace]:
+ parser = argparse.ArgumentParser(description="Paper-grade inverse-design comparison.")
+ parser.add_argument("--config-file", type=Path, required=True)
+ parser.add_argument("--checkpoint", type=Path, required=True)
+ parser.add_argument("--output-dir", type=Path, required=True)
+ parser.add_argument(
+ "--record-trajectory",
+ action=argparse.BooleanOptionalAction,
+ default=True,
+ help=(
+ "Record per-step optimisation trajectory (target predictions + per-step composition) "
+ "per path. Adds ~10–30 % runtime + a few MB of disk per scenario but enables the "
+ "trajectory_* plots and animations. Default: on. Use --no-record-trajectory to skip."
+ ),
+ )
+ parser.add_argument(
+ "--per-seed-trajectories",
+ action=argparse.BooleanOptionalAction,
+ default=True,
+ help=(
+ "Also emit one trajectory plot + animation per (path × seed) under "
+ "trajectories_per_seed/seed{NN}/.{png,gif,html} (seed-major layout — easier "
+ "to compare paths for one seed). Default: on. Adds ~480 PNGs / scenario plus 480 GIFs "
+ "(~1GB) + 480 HTMLs (~5GB) if both anim formats are on; use --no-per-seed-trajectories "
+ "to skip when you only need the across-seed-mean view."
+ ),
+ )
+ parser.add_argument(
+ "--animation-formats",
+ nargs="+",
+ choices=["gif", "html", "svg", "none"],
+ default=["gif"],
+ help=(
+ "Animation output formats. ``gif`` (default) uses matplotlib's Pillow writer; "
+ "``html`` emits an interactive JS-controlled HTML file (matplotlib HTMLWriter); "
+ "``svg`` emits a SMIL-animated single-file SVG; ``none`` disables animations "
+ "(static plot still emitted). Multi-select supported, e.g. --animation-formats gif html."
+ ),
+ )
+ args = parser.parse_args(argv)
+
+ import tomllib
+
+ data = tomllib.loads(args.config_file.read_text(encoding="utf-8"))
+ data["output_dir"] = str(args.output_dir)
+ field_names = set(ContinualRehearsalConfig.__dataclass_fields__)
+ path_fields = {
+ "qc_data_path",
+ "qc_preprocessing_path",
+ "superconductor_path",
+ "magnetic_path",
+ "phonix_path",
+ "output_dir",
+ }
+ kwargs: dict[str, object] = {}
+ for key, value in data.items():
+ if key not in field_names:
+ continue
+ kwargs[key] = Path(value) if key in path_fields and value is not None else value
+ return ContinualRehearsalConfig(**kwargs), args
+
+
+def main(argv: list[str] | None = None) -> None:
+ config, args = _parse_args(argv)
+ run(
+ config,
+ args.checkpoint,
+ record_trajectory=args.record_trajectory,
+ per_seed_trajectories=args.per_seed_trajectories,
+ animation_formats=tuple(args.animation_formats),
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/foundation_model/scripts/paper_inverse_comparison_test.py b/src/foundation_model/scripts/paper_inverse_comparison_test.py
new file mode 100644
index 0000000..9c00393
--- /dev/null
+++ b/src/foundation_model/scripts/paper_inverse_comparison_test.py
@@ -0,0 +1,211 @@
+# Copyright 2025 TsumiNa.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Tests for the pure helpers in :mod:`paper_inverse_comparison`.
+
+The main ``run()`` function needs a trained checkpoint + KMD kernel to exercise end-to-end (see
+the smoke runs under ``artifacts/inverse_design_run/``); this file targets the *units that don't*
+need either — the formula parser, and the two output plot helpers we added in this PR.
+"""
+
+from __future__ import annotations
+
+import numpy as np
+
+from foundation_model.scripts.paper_inverse_comparison import (
+ _parse_formula_to_fractions,
+ _plot_qc_vs_reg_scatter,
+ _plot_seed_to_optimized_mapping,
+ _target_arrow,
+)
+
+
+# --- _parse_formula_to_fractions ----------------------------------------------------------
+
+
+def test_parse_raw_amount_formula_normalises_to_fractions():
+ # Seeds typically come in raw-amount form like "Au65 Ga20 Gd15"; the parser must normalise
+ # so the same downstream code can read it as fractions.
+ out = _parse_formula_to_fractions("Au65 Ga20 Gd15")
+ assert sorted(out.keys()) == ["Au", "Ga", "Gd"]
+ assert abs(sum(out.values()) - 1.0) < 1e-12
+ assert abs(out["Au"] - 0.65) < 1e-12
+ assert abs(out["Ga"] - 0.20) < 1e-12
+ assert abs(out["Gd"] - 0.15) < 1e-12
+
+
+def test_parse_pre_fractional_formula_kept_as_fractions():
+ # Decoded compositions land here in fractional form ("Mg0.691 Cd0.309 …"); they must round-trip.
+ out = _parse_formula_to_fractions("Mg0.691 Cd0.309")
+ assert abs(sum(out.values()) - 1.0) < 1e-12
+ assert abs(out["Mg"] - 0.691) < 1e-12
+ assert abs(out["Cd"] - 0.309) < 1e-12
+
+
+def test_parse_handles_missing_amount_as_unit():
+ # A bare element symbol ("Mg") gets unit amount, then normalised.
+ out = _parse_formula_to_fractions("Mg Cu Ni")
+ # 3 elements, equal amounts, fractions = 1/3 each.
+ assert sorted(out.keys()) == ["Cu", "Mg", "Ni"]
+ for v in out.values():
+ assert abs(v - 1.0 / 3.0) < 1e-12
+
+
+def test_parse_empty_formula_returns_empty_dict():
+ assert _parse_formula_to_fractions("") == {}
+
+
+# --- _target_arrow --------------------------------------------------------------------------
+
+
+def test_target_arrow_up_for_positive_target():
+ """Target above baseline ⇒ ↑ (optimisation drives the value up)."""
+ assert _target_arrow(2.0) == "↑"
+ assert _target_arrow(0.1) == "↑"
+
+
+def test_target_arrow_down_for_negative_or_zero_target():
+ """Target at or below baseline ⇒ ↓. The convention treats 0 as "no clear up direction"."""
+ assert _target_arrow(-2.0) == "↓"
+ assert _target_arrow(0.0) == "↓"
+
+
+# --- _plot_seed_to_optimized_mapping ------------------------------------------------------
+
+
+def _mapping_kwargs(seeds: list[str], decoded: list[str]) -> dict:
+ """Reasonable defaults for the helper's per-seed QC / reg arguments.
+
+ Tests don't care about specific numbers — they just need arrays the same length as the
+ seed list. Reg-target names map to the project's plan §5 targets.
+ """
+ n = len(seeds)
+ return dict(
+ seed_qc=np.full(n, 0.5),
+ seed_reg={"formation_energy": np.full(n, 0.3), "klat": np.full(n, 0.1)},
+ optimized_qc=np.full(n, 0.9),
+ optimized_reg={"formation_energy": np.full(n, -0.5), "klat": np.full(n, 1.6)},
+ reg_targets={"formation_energy": -2.0, "klat": 2.0},
+ )
+
+
+def test_plot_seed_to_optimized_mapping_writes_png(tmp_path):
+ seeds = [
+ "Mg12 Cu3 Ni3",
+ "Au65 Ga20 Gd15",
+ "Al6 Co1 Cu3",
+ ]
+ decoded = [
+ "Mg0.50 Cu0.30 Ni0.20",
+ "Au0.55 Ga0.30 Gd0.15",
+ "Al0.60 Pd0.20 Ti0.20", # introduces Pd / Ti not in seeds
+ ]
+ out = tmp_path / "seed_to_optimized.png"
+ _plot_seed_to_optimized_mapping(seeds, decoded, out, title="test scenario", **_mapping_kwargs(seeds, decoded))
+ assert out.exists()
+
+
+def test_plot_seed_to_optimized_mapping_skips_on_length_mismatch(tmp_path):
+ """Mismatched seeds / decoded lengths must not crash — log a warning and skip the write."""
+ out = tmp_path / "should_not_exist.png"
+ _plot_seed_to_optimized_mapping(
+ ["Mg1 Cu1"], ["Mg0.5 Cu0.5", "Al1.0"], out, title="bad", **_mapping_kwargs(["Mg1 Cu1"], ["Mg0.5 Cu0.5"])
+ )
+ assert not out.exists()
+
+
+def test_plot_seed_to_optimized_mapping_skips_on_empty(tmp_path):
+ out = tmp_path / "should_not_exist.png"
+ _plot_seed_to_optimized_mapping([], [], out, title="empty", **_mapping_kwargs([], []))
+ assert not out.exists()
+
+
+# --- _plot_qc_vs_reg_scatter ----------------------------------------------------------------
+
+
+def _scatter_result(method: str, label: str, n: int = 6, **extra) -> dict:
+ """Minimal ``results`` row shape consumed by ``_plot_qc_vs_reg_scatter``.
+
+ Only the fields the scatter helper reads are populated — ``method``, ``label``,
+ ``qc_after_decode``, and ``reg_after_decode``. Numbers are arbitrary; the test asserts
+ the helper writes a PNG without raising.
+ """
+ rng = np.random.default_rng(abs(hash(label)) % (2**31))
+ return {
+ "method": method,
+ "label": label,
+ "qc_after_decode": rng.uniform(0.2, 0.95, size=n).tolist(),
+ "reg_after_decode": {
+ "formation_energy": rng.uniform(-1.5, -0.2, size=n).tolist(),
+ "klat": rng.uniform(0.5, 2.2, size=n).tolist(),
+ },
+ **extra,
+ }
+
+
+def test_plot_qc_vs_reg_scatter_writes_png(tmp_path):
+ """End-to-end smoke: latent + composition results, two reg targets, expect a PNG out."""
+ results = [
+ _scatter_result("latent", "latent\nα=0"),
+ _scatter_result("latent", "latent\nα=0.25"),
+ _scatter_result("latent", "latent\nα=1"),
+ _scatter_result("composition", "comp\n(seed)"),
+ _scatter_result("composition", "comp\n(seed, 5% all)"),
+ ]
+ reg_targets = {"formation_energy": -2.0, "klat": 2.0}
+ out = tmp_path / "qc_vs_secondary_scatter.png"
+ _plot_qc_vs_reg_scatter(results, reg_targets, out, title="test")
+ assert out.exists()
+
+
+def test_plot_qc_vs_reg_scatter_handles_single_target(tmp_path):
+ """One reg-target = one panel; still must render without grid-shape errors."""
+ results = [
+ _scatter_result("latent", "latent\nα=1"),
+ _scatter_result("composition", "comp\n(seed)"),
+ ]
+ out = tmp_path / "qc_single.png"
+ _plot_qc_vs_reg_scatter(results, {"klat": 2.0}, out, title="single target")
+ assert out.exists()
+
+
+def test_plot_qc_vs_reg_scatter_skips_on_empty_results(tmp_path):
+ out = tmp_path / "should_not_exist.png"
+ _plot_qc_vs_reg_scatter([], {"klat": 2.0}, out, title="empty")
+ assert not out.exists()
+
+
+def test_plot_qc_vs_reg_scatter_skips_on_empty_reg_targets(tmp_path):
+ """No reg-targets ⇒ nothing to plot; the helper must not write a degenerate figure."""
+ results = [_scatter_result("latent", "latent\nα=1")]
+ out = tmp_path / "should_not_exist.png"
+ _plot_qc_vs_reg_scatter(results, {}, out, title="no targets")
+ assert not out.exists()
+
+
+def test_plot_qc_vs_reg_scatter_with_seed_layer(tmp_path):
+ """Optional ``seed_qc`` / ``seed_reg`` draw the per-seed baseline as orange ★ stars.
+
+ Verifies the figure still renders (the layer is added before the optimised clouds and
+ drops cleanly when the kwarg is omitted — see the no-arg test above).
+ """
+ results = [
+ _scatter_result("latent", "latent\nα=1"),
+ _scatter_result("composition", "comp\n(seed)"),
+ ]
+ reg_targets = {"formation_energy": -2.0, "klat": 2.0}
+ n_seeds = 5
+ rng = np.random.default_rng(123)
+ out = tmp_path / "qc_with_seeds.png"
+ _plot_qc_vs_reg_scatter(
+ results,
+ reg_targets,
+ out,
+ title="with seeds",
+ seed_qc=rng.uniform(0.1, 0.6, size=n_seeds),
+ seed_reg={
+ "formation_energy": rng.uniform(0.5, 2.5, size=n_seeds),
+ "klat": rng.uniform(-0.5, 1.0, size=n_seeds),
+ },
+ )
+ assert out.exists()
diff --git a/src/foundation_model/scripts/paper_inverse_trajectory.py b/src/foundation_model/scripts/paper_inverse_trajectory.py
new file mode 100644
index 0000000..e02139e
--- /dev/null
+++ b/src/foundation_model/scripts/paper_inverse_trajectory.py
@@ -0,0 +1,472 @@
+# Copyright 2026 TsumiNa.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Per-step trajectory analytics + plots + animations for inverse-design runs.
+
+Each call to :meth:`FlexibleMultiTaskModel.optimize_latent` /
+:meth:`FlexibleMultiTaskModel.optimize_composition` can now optionally record:
+
+* ``trajectory_targets`` — shape ``(steps, B, T)``: per-step predicted target values
+ (one column per regression task in ``reg_targets`` order; QC is separate).
+* ``trajectory_weights`` — shape ``(steps, B, n_components)``: per-step element weights
+ (the optimisation variable for ``optimize_composition``; decoded via ``KMD.inverse`` from the
+ per-step AE-decoded ``x`` for ``optimize_latent``).
+
+Together with the per-step QC trajectory (also collected from the raw target predictions for
+the QC head), these are enough to visualise:
+
+1. How fast each target converges relative to the others (static line plot, normalised so all
+ targets are on the same y-axis).
+2. How the recipe evolves across the optimisation (animated bar chart of the per-step composition
+ on the side, frame per step).
+
+This module hosts the pure helpers; ``paper_inverse_comparison.run()`` is the only caller.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Mapping
+from pathlib import Path
+from typing import Any, Iterable
+
+import matplotlib
+
+matplotlib.use("Agg")
+
+import matplotlib.animation as manimation
+import matplotlib.pyplot as plt
+import numpy as np
+from loguru import logger
+
+
+# --- representative-seed picker -----------------------------------------------------------------
+
+
+def best_seed_by_target_distance(
+ qc_final: np.ndarray,
+ reg_final: dict[str, np.ndarray],
+ reg_targets: Mapping[str, float],
+) -> int:
+ """Pick the seed whose final state minimises the joint normalised distance to all targets.
+
+ "Joint distance" = $\\sqrt{(1 - \\text{QC})^2 + \\sum_t ((y_t - \\text{target}_t) / s_t)^2}$
+ where $s_t$ is the per-task scale (we use the absolute target value as a stand-in so each
+ task contributes on a comparable scale; a target of ±2 σ in z-scored space gives a scale of 2).
+
+ The QC term uses ``1 - QC`` so closer-to-1 wins; the regression terms use signed deviation
+ so an under-shoot and an over-shoot are penalised equally.
+ """
+ qc_final = np.asarray(qc_final, dtype=float)
+ n = qc_final.shape[0]
+ if n == 0:
+ raise ValueError("best_seed_by_target_distance: empty qc_final array.")
+ dist_sq = (1.0 - qc_final) ** 2
+ for task, target in reg_targets.items():
+ scale = max(abs(float(target)), 1.0) # avoid divide-by-zero if target == 0
+ vals = np.asarray(reg_final[task], dtype=float)
+ dist_sq = dist_sq + ((vals - float(target)) / scale) ** 2
+ return int(np.argmin(dist_sq))
+
+
+# --- trajectory normalisation -----------------------------------------------------------------
+
+
+def normalize_target_trajectories(
+ qc_trajectory: np.ndarray,
+ reg_trajectory: dict[str, np.ndarray],
+ reg_targets: Mapping[str, float],
+ seed_qc: np.ndarray,
+ seed_reg: Mapping[str, np.ndarray],
+) -> dict[str, np.ndarray]:
+ """Map per-step target predictions to a [0, 1] "progress" fraction.
+
+ For each target, 0 = "at seed baseline", 1 = "exactly at target". Values can exceed [0, 1]
+ if the optimiser overshoots. The transform is per-(task, seed): for seed *i* we compute
+ ``(y[step, i] - baseline[i]) / (target - baseline[i])`` so a noisy seed-to-seed baseline
+ doesn't dilute the average. After per-seed normalisation we mean over seeds so the static
+ plot shows the average progress across the seed cohort.
+
+ Returns: dict ``{"QC": (steps,), task_name: (steps,)}`` of mean progress values.
+ """
+ out: dict[str, np.ndarray] = {}
+
+ # QC always targets 1.0.
+ qc_baseline = np.asarray(seed_qc, dtype=float) # (B,)
+ qc_target = 1.0
+ qc_denom = qc_target - qc_baseline
+ qc_denom = np.where(np.abs(qc_denom) < 1e-9, 1.0, qc_denom) # protect against /0
+ qc_progress = (np.asarray(qc_trajectory, dtype=float) - qc_baseline[None, :]) / qc_denom[None, :]
+ out["QC"] = qc_progress.mean(axis=1)
+
+ for task, target in reg_targets.items():
+ baseline = np.asarray(seed_reg[task], dtype=float) # (B,)
+ denom = float(target) - baseline
+ denom = np.where(np.abs(denom) < 1e-9, 1.0, denom)
+ traj = np.asarray(reg_trajectory[task], dtype=float) # (steps, B)
+ progress = (traj - baseline[None, :]) / denom[None, :]
+ out[task] = progress.mean(axis=1)
+
+ return out
+
+
+# --- static plot -------------------------------------------------------------------------------
+
+
+_TARGET_COLOR_QC = "#C44E52" # red — matches the target lines used elsewhere
+_TARGET_COLORS_REG = ["#2563EB", "#55A868", "#E67E22", "#9467bd"] # blue / green / orange / purple
+
+
+def plot_trajectory_static(
+ progress: Mapping[str, np.ndarray],
+ out_path: Path,
+ *,
+ title: str,
+ seed_composition: str | None = None,
+) -> None:
+ """Line plot of normalised progress vs step.
+
+ QC is drawn in red; the regression tasks cycle through the project's blue / green / orange
+ palette. The y-axis is "progress fraction" (0 = at seed, 1 = at target); a horizontal dashed
+ line at 1.0 marks the joint target. The reader gets a one-glance answer to the question the
+ user asked: "do the targets converge together, or does the recipe stabilise early and the
+ targets keep moving?" — divergence between the QC line and the reg lines, or between the reg
+ lines themselves, surfaces immediately.
+
+ When ``seed_composition`` is provided (the per-seed composition string, e.g.
+ ``"Au65 Ga20 Gd15"``), it's appended to the figure title under the main title in a monospace
+ font — the reader can identify the seed by chemistry rather than by index.
+ """
+ fig, ax = plt.subplots(figsize=(8.0, 5.0), dpi=150)
+ steps = np.arange(len(next(iter(progress.values()))))
+
+ # QC first so it's visually behind the reg lines (the user usually cares about reg
+ # convergence; QC's behavior is rarely surprising).
+ if "QC" in progress:
+ ax.plot(steps, progress["QC"], color=_TARGET_COLOR_QC, lw=2.0, label="QC (P(quasicrystal))")
+ reg_keys = [k for k in progress if k != "QC"]
+ for i, key in enumerate(reg_keys):
+ ax.plot(
+ steps,
+ progress[key],
+ color=_TARGET_COLORS_REG[i % len(_TARGET_COLORS_REG)],
+ lw=1.8,
+ label=key,
+ )
+
+ ax.axhline(1.0, color="#666", ls="--", lw=1.0, alpha=0.7, label="target (progress = 1.0)")
+ ax.axhline(0.0, color="#bbb", ls=":", lw=0.8, alpha=0.5)
+ ax.set_xlabel("Optimisation step")
+ ax.set_ylabel("Progress (0 = seed, 1 = target)")
+ if seed_composition:
+ # Two-line layout: bold main title on top + seed composition underneath, with extra
+ # ``pad`` so the title doesn't sit flush against the upper axes line. Putting the
+ # seed-comp as a text annotation at y=1.02 collided with the title when matplotlib's
+ # default title-pad was applied — fix is to render both lines via set_title and a
+ # second matching text() at a clearly-distinct y position.
+ ax.set_title(title, fontsize=12, fontweight="bold", pad=22)
+ ax.text(
+ 0.5, 1.005, f"seed: {seed_composition}",
+ transform=ax.transAxes, ha="center", va="bottom",
+ fontsize=10, family="monospace", color="#444",
+ )
+ else:
+ ax.set_title(title, fontsize=12, fontweight="bold")
+ ax.legend(loc="best", fontsize=9, frameon=False)
+ ax.grid(True, alpha=0.2)
+ fig.tight_layout()
+ fig.savefig(out_path, bbox_inches="tight", facecolor="white")
+ plt.close(fig)
+ logger.info(f"Wrote trajectory static plot to {out_path}")
+
+
+# --- animation ---------------------------------------------------------------------------------
+
+
+def _topk_composition_frame(weights: np.ndarray, element_symbols: list[str], top_k: int = 10) -> list[tuple[str, float]]:
+ """Top-K elements by weight, sorted descending. Used as one frame of the animation's comp panel."""
+ idx = np.argsort(weights)[::-1][:top_k]
+ return [(element_symbols[int(i)], float(weights[int(i)])) for i in idx if weights[int(i)] > 1e-4]
+
+
+def plot_trajectory_animation(
+ progress: Mapping[str, np.ndarray],
+ per_step_weights: np.ndarray,
+ element_symbols: list[str],
+ out_paths_by_format: Mapping[str, Path],
+ *,
+ title: str,
+ seed_composition: str | None = None,
+ top_k_elements: int = 10,
+ fps: int = 15,
+ max_frames: int = 120,
+) -> None:
+ """Targets-vs-step line plot (top panel) + per-step top-K element bar chart (right panel).
+
+ The line plot draws the full curve from step 0; a vertical "current step" marker advances
+ one tick per frame. The bar chart on the right re-draws each frame to show the current
+ composition's top-K elements (so the viewer can see "what is the recipe right now?" as the
+ targets evolve). For long runs (steps > ``max_frames``) we subsample uniformly so the GIF
+ stays under a few seconds at fps=15.
+
+ Writers:
+ - ``gif`` → ``PillowWriter`` (no external deps; embeddable anywhere).
+ - ``html`` → ``HTMLWriter`` (JS-controlled play/pause/scrub; great for inspection).
+ - ``svg`` → custom SMIL-animated single-file SVG (browsers play it; PPT cannot embed).
+ """
+ n_steps = len(next(iter(progress.values())))
+ if n_steps == 0:
+ logger.warning("plot_trajectory_animation: empty progress arrays — skipping.")
+ return
+ if per_step_weights.shape[0] != n_steps:
+ logger.warning(
+ f"plot_trajectory_animation: per_step_weights step count ({per_step_weights.shape[0]}) "
+ f"does not match progress step count ({n_steps}); skipping animation."
+ )
+ return
+
+ # Uniform subsample down to ``max_frames`` so GIFs stay manageable. The line plot still uses
+ # the full curve; only the marker / weights frames are subsampled.
+ frame_steps = np.linspace(0, n_steps - 1, num=min(n_steps, max_frames)).astype(int)
+ frame_steps = np.unique(frame_steps) # in case of duplicate indices for very small n_steps
+
+ fig = plt.figure(figsize=(12.0, 5.5), dpi=120)
+ gs = fig.add_gridspec(1, 2, width_ratios=[2.0, 1.0], wspace=0.30)
+ ax_line = fig.add_subplot(gs[0, 0])
+ ax_bar = fig.add_subplot(gs[0, 1])
+
+ # --- Static line plot in left panel ---
+ steps = np.arange(n_steps)
+ if "QC" in progress:
+ ax_line.plot(steps, progress["QC"], color=_TARGET_COLOR_QC, lw=2.0, label="QC (P(quasicrystal))")
+ for i, key in enumerate([k for k in progress if k != "QC"]):
+ ax_line.plot(
+ steps,
+ progress[key],
+ color=_TARGET_COLORS_REG[i % len(_TARGET_COLORS_REG)],
+ lw=1.8,
+ label=key,
+ )
+ ax_line.axhline(1.0, color="#666", ls="--", lw=1.0, alpha=0.6)
+ ax_line.axhline(0.0, color="#bbb", ls=":", lw=0.8, alpha=0.5)
+ ax_line.set_xlabel("Optimisation step")
+ ax_line.set_ylabel("Progress (0 = seed, 1 = target)")
+ if seed_composition:
+ # Two-line title: bold panel title on top + monospace seed-composition underneath. The
+ # ``pad=22`` lifts the title clear of the second line; without the pad they overlap
+ # because matplotlib's default title baseline sits where the text annotation lands.
+ ax_line.set_title(title, fontsize=11, fontweight="bold", pad=22)
+ ax_line.text(
+ 0.5, 1.005, f"seed: {seed_composition}",
+ transform=ax_line.transAxes, ha="center", va="bottom",
+ fontsize=10, family="monospace", color="#444",
+ )
+ else:
+ ax_line.set_title(title, fontsize=11, fontweight="bold")
+ ax_line.legend(loc="best", fontsize=8, frameon=False)
+ ax_line.grid(True, alpha=0.2)
+ marker = ax_line.axvline(0, color="#444", lw=1.2, alpha=0.85)
+
+ # --- Bar chart in right panel (redrawn per frame) ---
+ ax_bar.set_title("Composition (top-K by weight)", fontsize=10)
+ ax_bar.set_xlim(0, 1.0)
+ ax_bar.set_xlabel("weight")
+
+ def _draw_bar(step_idx: int) -> None:
+ ax_bar.clear()
+ frame = _topk_composition_frame(per_step_weights[step_idx], element_symbols, top_k=top_k_elements)
+ if not frame:
+ ax_bar.text(0.5, 0.5, "(no elements above threshold)", ha="center", va="center", transform=ax_bar.transAxes)
+ else:
+ symbols, weights = zip(*frame)
+ y_pos = np.arange(len(symbols))
+ ax_bar.barh(y_pos, weights, color="#2563EB", alpha=0.75, edgecolor="#222", linewidth=0.5)
+ ax_bar.set_yticks(y_pos)
+ ax_bar.set_yticklabels(symbols, fontsize=9)
+ ax_bar.invert_yaxis() # largest on top
+ ax_bar.set_xlim(0, max(0.5, float(per_step_weights[step_idx].max()) * 1.1))
+ ax_bar.set_xlabel("weight")
+ ax_bar.set_title(f"Composition (step {step_idx + 1}/{n_steps})", fontsize=10)
+ ax_bar.grid(True, axis="x", alpha=0.2)
+
+ def _init() -> Iterable[Any]:
+ _draw_bar(int(frame_steps[0]))
+ marker.set_xdata([int(frame_steps[0])])
+ return (marker,)
+
+ def _update(frame_idx: int) -> Iterable[Any]:
+ step_idx = int(frame_steps[frame_idx])
+ _draw_bar(step_idx)
+ marker.set_xdata([step_idx])
+ return (marker,)
+
+ # Only build the matplotlib FuncAnimation when at least one matplotlib-native format
+ # (gif / html) is requested. For svg-only output we render a handwritten SMIL SVG without
+ # touching the animation object — building it anyway would emit a "Animation was deleted
+ # without rendering anything" UserWarning on test runs.
+ needs_mpl_anim = any(fmt in ("gif", "html") for fmt in out_paths_by_format)
+ anim = (
+ manimation.FuncAnimation(
+ fig,
+ _update,
+ frames=len(frame_steps),
+ init_func=_init,
+ interval=1000 // fps,
+ blit=False, # the bar chart redraw isn't blittable cleanly
+ )
+ if needs_mpl_anim
+ else None
+ )
+
+ for fmt, out_path in out_paths_by_format.items():
+ try:
+ if fmt == "gif":
+ anim.save(str(out_path), writer=manimation.PillowWriter(fps=fps))
+ elif fmt == "html":
+ # ``to_jshtml`` returns a single self-contained HTML string with frames embedded
+ # as base64 PNGs. The ``HTMLWriter`` alternative drops a separate ``*_frames/``
+ # folder of 120+ PNGs alongside, which clutters the output dir and makes the
+ # artefact non-portable. The base64 version is bigger per-file (~3 MB vs the
+ # multi-file's ~10 MB total) but is one self-contained file.
+ out_path.write_text(anim.to_jshtml(fps=fps), encoding="utf-8")
+ elif fmt == "svg":
+ _save_smil_svg(progress, per_step_weights, element_symbols, frame_steps, out_path, title=title, fps=fps)
+ else:
+ logger.warning(f"plot_trajectory_animation: unknown format {fmt!r} — skipping.")
+ continue
+ logger.info(f"Wrote trajectory animation ({fmt}) to {out_path}")
+ except Exception as exc: # pragma: no cover (writer-specific failure modes)
+ logger.warning(f"plot_trajectory_animation: failed to write {fmt} → {out_path}: {exc}")
+
+ plt.close(fig)
+
+
+# --- SMIL SVG writer ---------------------------------------------------------------------------
+
+
+def _save_smil_svg(
+ progress: Mapping[str, np.ndarray],
+ per_step_weights: np.ndarray,
+ element_symbols: list[str],
+ frame_steps: np.ndarray,
+ out_path: Path,
+ *,
+ title: str,
+ fps: int,
+ top_k_elements: int = 10,
+) -> None:
+ """Single-file SMIL-animated SVG.
+
+ matplotlib doesn't have a native SVG-animation writer; rather than render N PNGs and ship a
+ multi-frame SVG (would defeat the "one file" goal), we emit a compact handwritten SVG with
+ the static line plot as a vector overlay + ```` tags for the per-step marker and
+ per-element bar widths. Plays in any modern browser (Firefox / Chrome / Safari); PowerPoint
+ and Keynote cannot embed it directly — for those use the GIF.
+ """
+ n_steps = len(next(iter(progress.values())))
+ duration_s = max(1.0, len(frame_steps) / fps)
+ # Coordinate system: 800 × 400 viewBox, line plot in [40, 480] × [40, 360], bar plot in
+ # [520, 780] × [40, 360]. Bars are horizontal, top-K elements, redrawn via .
+
+ # ---- header ----
+ parts: list[str] = []
+ parts.append(
+ '")
+ out_path.write_text("\n".join(parts), encoding="utf-8")
diff --git a/src/foundation_model/scripts/paper_inverse_trajectory_test.py b/src/foundation_model/scripts/paper_inverse_trajectory_test.py
new file mode 100644
index 0000000..e5f152c
--- /dev/null
+++ b/src/foundation_model/scripts/paper_inverse_trajectory_test.py
@@ -0,0 +1,192 @@
+# Copyright 2026 TsumiNa.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Tests for the pure helpers in :mod:`paper_inverse_trajectory`.
+
+The full ``_emit_trajectory_outputs`` orchestrator needs a real trained checkpoint to exercise;
+this file covers the pure functions — seed picker, progress normalisation, and the writer
+smoke-tests (static plot + gif + html + svg). Animations are checked only for "file got written";
+visual correctness is verified by inspecting the rerun artefacts.
+"""
+
+from __future__ import annotations
+
+import numpy as np
+import pytest
+
+from foundation_model.scripts.paper_inverse_trajectory import (
+ best_seed_by_target_distance,
+ normalize_target_trajectories,
+ plot_trajectory_animation,
+ plot_trajectory_static,
+)
+
+
+# --- best_seed_by_target_distance --------------------------------------------------------------
+
+
+def test_best_seed_picks_closest_joint_distance_to_targets():
+ """Seed 1 is closest to the joint target (QC=1, fe=-2, klat=2); the picker should return 1."""
+ qc = np.array([0.20, 0.95, 0.50]) # seed 1 has highest QC
+ reg = {
+ "formation_energy": np.array([+0.5, -1.9, -1.0]), # seed 1 hits target -2 best
+ "klat": np.array([0.0, 1.8, 1.2]), # seed 1 hits target 2 best
+ }
+ reg_targets = {"formation_energy": -2.0, "klat": 2.0}
+ assert best_seed_by_target_distance(qc, reg, reg_targets) == 1
+
+
+def test_best_seed_handles_zero_target_without_div_by_zero():
+ """``target == 0`` would naively divide by zero; the picker uses a min-scale guard."""
+ qc = np.array([0.9, 0.8])
+ reg = {"some_task": np.array([0.1, 0.5])}
+ # Should pick seed 0 (closer to target 0).
+ assert best_seed_by_target_distance(qc, reg, {"some_task": 0.0}) == 0
+
+
+def test_best_seed_empty_qc_raises():
+ with pytest.raises(ValueError, match="empty qc_final"):
+ best_seed_by_target_distance(np.array([]), {}, {})
+
+
+# --- normalize_target_trajectories -------------------------------------------------------------
+
+
+def test_normalize_trajectory_maps_baseline_to_zero_and_target_to_one():
+ """Per (task, seed): a step's value of (target - baseline) + baseline ⇒ progress = 1."""
+ n_steps = 4
+ n_seeds = 2
+ # One reg target only. Baseline = [0.0, 0.5], target = 2.0.
+ reg_targets = {"k": 2.0}
+ seed_reg = {"k": np.array([0.0, 0.5])}
+ # Per-seed trajectory: linear interpolation from baseline → target across 4 steps.
+ traj_k = np.stack(
+ [
+ np.linspace(0.0, 2.0, n_steps), # seed 0
+ np.linspace(0.5, 2.0, n_steps), # seed 1
+ ],
+ axis=1,
+ ) # (steps, B)
+ # QC trajectory: flat at the seed baseline so it normalises to 0 progress throughout.
+ seed_qc = np.array([0.1, 0.2])
+ qc_traj = np.tile(seed_qc[None, :], (n_steps, 1))
+
+ progress = normalize_target_trajectories(
+ qc_trajectory=qc_traj,
+ reg_trajectory={"k": traj_k},
+ reg_targets=reg_targets,
+ seed_qc=seed_qc,
+ seed_reg=seed_reg,
+ )
+ # k progress: starts at 0, ends at 1 (per-seed normalised then mean over B).
+ assert progress["k"].shape == (n_steps,)
+ assert progress["k"][0] == pytest.approx(0.0, abs=1e-9)
+ assert progress["k"][-1] == pytest.approx(1.0, abs=1e-9)
+ # QC stays at baseline ⇒ progress = 0 throughout.
+ assert progress["QC"].shape == (n_steps,)
+ assert np.allclose(progress["QC"], 0.0)
+
+
+# --- plot writers ------------------------------------------------------------------------------
+
+
+def _toy_progress() -> dict[str, np.ndarray]:
+ """4-target × 30-step normalised progress, monotone so the picture is interpretable."""
+ n = 30
+ return {
+ "QC": np.clip(np.linspace(0.0, 0.95, n) + 0.02 * np.sin(np.linspace(0, 4 * np.pi, n)), 0, 1.5),
+ "formation_energy": np.linspace(0.0, 1.2, n),
+ "klat": np.linspace(0.0, 0.8, n),
+ }
+
+
+def _toy_weights(n_steps: int = 30, n_components: int = 12) -> np.ndarray:
+ """(steps, n_components) toy weights — start sparse, drift toward a different sparse set."""
+ rng = np.random.default_rng(7)
+ w = np.zeros((n_steps, n_components), dtype=float)
+ # Initial: mass on elements 0..2
+ w[0, :3] = [0.5, 0.3, 0.2]
+ # Final: mass on elements 4, 6, 7
+ end = np.zeros(n_components)
+ end[4], end[6], end[7] = 0.5, 0.3, 0.2
+ for s in range(n_steps):
+ t = s / (n_steps - 1)
+ w[s] = (1 - t) * w[0] + t * end + 0.001 * rng.standard_normal(n_components)
+ w[s] = np.clip(w[s], 0, None)
+ w[s] /= w[s].sum()
+ return w
+
+
+def test_plot_trajectory_static_writes_png(tmp_path):
+ out = tmp_path / "static.png"
+ plot_trajectory_static(_toy_progress(), out, title="toy trajectory")
+ assert out.exists()
+
+
+def test_plot_trajectory_static_with_seed_composition(tmp_path):
+ """``seed_composition`` is rendered as a monospace annotation under the title — verify the
+ plot still writes with the kwarg present (visual correctness is by inspection)."""
+ out = tmp_path / "static_with_seed.png"
+ plot_trajectory_static(
+ _toy_progress(), out, title="toy trajectory", seed_composition="Au65 Ga20 Gd15"
+ )
+ assert out.exists()
+
+
+def test_plot_trajectory_animation_writes_gif(tmp_path):
+ out = tmp_path / "anim.gif"
+ plot_trajectory_animation(
+ _toy_progress(),
+ per_step_weights=_toy_weights(),
+ element_symbols=[f"E{i}" for i in range(12)],
+ out_paths_by_format={"gif": out},
+ title="toy animation",
+ max_frames=10, # keep test fast
+ )
+ assert out.exists()
+
+
+def test_plot_trajectory_animation_writes_html(tmp_path):
+ out = tmp_path / "anim.html"
+ plot_trajectory_animation(
+ _toy_progress(),
+ per_step_weights=_toy_weights(),
+ element_symbols=[f"E{i}" for i in range(12)],
+ out_paths_by_format={"html": out},
+ title="toy animation",
+ max_frames=10,
+ )
+ assert out.exists()
+
+
+def test_plot_trajectory_animation_writes_smil_svg(tmp_path):
+ out = tmp_path / "anim.svg"
+ plot_trajectory_animation(
+ _toy_progress(),
+ per_step_weights=_toy_weights(),
+ element_symbols=[f"E{i}" for i in range(12)],
+ out_paths_by_format={"svg": out},
+ title="toy animation",
+ max_frames=8,
+ )
+ assert out.exists()
+ body = out.read_text(encoding="utf-8")
+ # The SMIL animation should contain tags driving the marker x1/x2 + bar widths.
+ assert "