diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index d8e05d1..3d37bbd 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -2,304 +2,301 @@ ``` foundation_model/ -├── src/ -│ └── foundation_model/ # Main Python package -│ ├── models/ # Neural network models and components -│ │ ├── components/ # Reusable model parts (encoders, fusion, SSL) -│ │ └── task_head/ # Task-specific prediction heads (regression, classification, sequence) -│ ├── data/ # Data handling (Dataset, DataModule, splitter) -│ ├── utils/ # Utility functions (plotting, training helpers) -│ ├── configs/ # Configuration models -│ └── scripts/ # Execution scripts (e.g., train.py) - -├── data/ # Placeholder for larger, persistent datasets (e.g., raw data) +├── src/foundation_model/ # Main Python package +│ ├── models/ # Neural network models +│ │ ├── components/ # Reusable encoder + utility blocks +│ │ │ ├── fc_layers.py # LinearBlock / LinearLayer +│ │ │ └── foundation_encoder.py # MLP / Transformer backbones +│ │ ├── task_head/ # Task-specific prediction heads +│ │ │ ├── regression.py +│ │ │ ├── classification.py +│ │ │ ├── kernel_regression.py +│ │ │ └── autoencoder.py # Reconstructs x from h_task; powers optimize_latent +│ │ ├── flexible_multi_task_model.py +│ │ └── model_config.py # EncoderConfig + per-task config dataclasses +│ ├── data/ # CompoundDataModule + per-task data sources + splitter +│ ├── utils/ # KMD + plotting / training helpers +│ └── scripts/ # Entry points (see below) +│ ├── train.py # fm-trainer (LightningCLI) +│ ├── continual_rehearsal_demo.py # demo runner (training + inverse design) +│ ├── continual_rehearsal_full.py # formal runner (11- or 24-task + 3 scenarios) +│ ├── continual_rehearsal_common.py # shared dump / plot helpers +│ ├── finetune_inverse_heads.py # head-only fine-tune of inverse heads +│ ├── eval_inverse_methods.py # piecewise latent-vs-composition eval +│ ├── paper_inverse_comparison.py # single-scenario paper-grade sweep +│ └── paper_inverse_3scenarios.py # 3-scenario orchestrator │ -├── results/ # Default output directory for models, logs, figures +├── data/ # Persistent datasets +├── artifacts/ # Run outputs (gitignored) +├── samples/ # TOML / YAML config templates +├── docs/ # Plan + algorithm reference + summary +├── notebooks/ # Experiments / analysis │ -├── notebooks/ # Jupyter notebooks for experiments, analysis, and visualization -│ └── experiments/ # Older experimental notebooks -│ -├── samples/ # Example configurations, data, and helper scripts -│ ├── cli_examples/ # Shell script examples for CLI usage -│ ├── fake_data/ # Small fake datasets for testing -│ ├── generated_configs/ # Example generated YAML configurations -│ └── helper_tools/ # Utility scripts for data/config generation -│ -├── .gitignore -├── .python-version -├── ARCHITECTURE.md # Detailed model architecture documentation -├── CHANGES.md # Changelog -├── pyproject.toml # Project metadata and dependencies -├── README.md # This file -└── uv.lock # uv lock file +├── ARCHITECTURE.md # This file +├── CHANGES.md # Changelog +├── CLAUDE.md / AGENTS.md # Repo-level coding guidelines +├── README.md # Top-level overview + quickstart +├── pyproject.toml # Dependencies + fm-trainer entry point +└── uv.lock ``` -# Model Architecture Documentation +# Model architecture -This document provides a detailed overview of the `FlexibleMultiTaskModel` architecture, its components, and data flow. +`FlexibleMultiTaskModel` ([src/foundation_model/models/flexible_multi_task_model.py](src/foundation_model/models/flexible_multi_task_model.py)) +is a single-encoder, multi-head supervised model. Composition descriptors enter the encoder, +get `tanh`'d at the model level, and feed every active task head. -## Detailed Architecture Diagram - -The following diagram illustrates the comprehensive structure of the `FlexibleMultiTaskModel`, including support for multi-modal inputs (formula and structure), various task heads (regression, classification, sequence), and internal data pathways. +## Diagram ```mermaid graph TD - %% ---------------- Legend ---------------- - subgraph Legend["Tensor Shape Legend"] + subgraph Legend["Tensor-shape legend"] direction LR - Legend_B["B: Batch size"] - Legend_L["L: Sequence length"] - Legend_D["D: Feature dimension"] + B["B: batch size"] + L["L: sequence length"] + D["D: feature dim"] end - %% ---------------- Inputs ---------------- - subgraph InputLayer["Input Layer"] - X_formula["x_formula (B, D_in_formula)"] - X_structure["x_structure (B, D_in_structure)"] - Task_Sequence_Data_Batch["task_sequence_data_batch (Dict[task_name, Tensor(B,L,1)])"] + %% ---------- Inputs ---------- + subgraph InputLayer["Input layer"] + X_formula["x_formula (B, input_dim)"] + Task_Seq_Data["task_sequence_data_batch
Dict[task_name, Tensor(B, L, 1)]
(KernelRegression heads only)"] end - %% -------- Foundation Encoder -------- + %% ---------- Foundation encoder ---------- subgraph FoundationEncoderModule["FoundationEncoder (self.encoder)"] direction TB - - FormulaEncoder["Configurable Shared Encoder
(MLP or Transformer)
(self.encoder.shared)"] - Aggregation["Token Aggregation
([CLS] or Mean Pool)"] - DepositBlock["Deposit Block (Linear + Tanh)
(self.encoder.deposit)
D_latent → D_deposit"] - - X_formula --> FormulaEncoder - FormulaEncoder -- "Token embeddings (B, L, D_model)" --> Aggregation - Aggregation -- "h_latent (B, D_latent)" --> DepositBlock - Aggregation -.-> H_Latent_Output_Point["h_latent"] - H_Latent_Output_Point --> DepositBlock + SharedEncoder["Configurable Shared Encoder
(MLPEncoderConfig or TransformerEncoderConfig)
self.encoder.shared"] + Aggregation["Token aggregation
([CLS] or mean pool — Transformer only)"] + X_formula --> SharedEncoder + SharedEncoder -- "Token embeddings (B, L, D_model)
or h_latent (B, latent_dim)" --> Aggregation + Aggregation -- "h_latent (B, latent_dim)" --> H_Latent["h_latent"] end - %% Junctions to heads - DepositBlock -- "h_task (B, D_deposit)" --> AttrTaskHeadsJunction{"To Attribute / Classification Heads"} - DepositBlock -- "h_task (B, D_deposit)" --> SeqTaskHeadsJunction{"To Sequence Heads"} + %% ---------- Model-level tanh ---------- + H_Latent --> TANH["torch.tanh
(model-level, applied in FlexibleMultiTaskModel.forward)"] + TANH -- "h_task (B, latent_dim)" --> HeadsJunction{"To every active task head"} - %% ---------------- Task Heads ---------------- - subgraph TaskHeadsModule["Task Heads"] + %% ---------- Task heads ---------- + subgraph TaskHeadsModule["Task heads (self.task_heads)"] direction TB - - %% Attribute / Classification - subgraph AttrClassHeads["Attribute / Classification Heads"] - direction LR - RegHead["RegressionHead: task_A
(MLP from D_deposit)"] - ClassHead["ClassificationHead: task_B
(MLP from D_deposit)"] - end - - %% Sequence heads - subgraph SeqHeads["Sequence Heads"] - direction LR - SeqHeadRNN["SequenceRNNHead: task_C
(Uses h_task + task_sequence_data_C)"] - SeqHeadTransformer["SequenceTransformerHead: task_D
(Uses h_task + task_sequence_data_D)"] - end + RegHead["RegressionHead
MLP from latent_dim"] + ClassHead["ClassificationHead
MLP + softmax, optional per-class weights"] + KRHead["KernelRegressionHead
(takes h_task + t-sequence)"] + AEHead["AutoEncoderHead
(reconstructs x_formula from h_task;
required for optimize_latent's latent space)"] end - - AttrTaskHeadsJunction --> RegHead - AttrTaskHeadsJunction --> ClassHead - - SeqTaskHeadsJunction --> SeqHeadRNN - Task_Sequence_Data_Batch -- "task_sequence_data_C" --> SeqHeadRNN - SeqTaskHeadsJunction --> SeqHeadTransformer - Task_Sequence_Data_Batch -- "task_sequence_data_D" --> SeqHeadTransformer - - %% ---------------- Outputs ---------------- - RegHead -- "pred_A (B, D_out_A)" --> OutputLayer["Model Outputs (Dictionary)"] - ClassHead -- "pred_B (B, D_out_B)" --> OutputLayer - SeqHeadRNN -- "pred_C (B, L, D_out_C)" --> OutputLayer - SeqHeadTransformer -- "pred_D (B, L, D_out_D)" --> OutputLayer - - %% ----------- Style definitions ----------- - classDef input fill:#E0EFFF,stroke:#5C9DFF,stroke-width:2px,color:#000; - classDef foundation fill:#DFF0D8,stroke:#77B55A,stroke-width:2px,color:#000; - classDef fusion fill:#D9EDF7,stroke:#6BADCF,stroke-width:2px,color:#000; - classDef taskhead fill:#FCF8E3,stroke:#F0AD4E,stroke-width:2px,color:#000; - classDef seqtaskhead fill:#F2DEDE,stroke:#D9534F,stroke-width:2px,color:#000; - classDef output fill:#EAEAEA,stroke:#888888,stroke-width:2px,color:#000; - classDef junction fill:#FFFFFF,stroke:#AAAAAA,stroke-width:1px,color:#000,shape:circle; - classDef point fill:#FFFFFF,stroke:#AAAAAA,stroke-width:1px,color:#000; + + HeadsJunction --> RegHead + HeadsJunction --> ClassHead + HeadsJunction --> KRHead + Task_Seq_Data -- "t-sequence for KR task" --> KRHead + HeadsJunction --> AEHead + + %% ---------- Outputs ---------- + RegHead -- "pred (B, D_out)" --> Outputs["Outputs (Dict[str, Tensor])"] + ClassHead -- "logits (B, num_classes)" --> Outputs + KRHead -- "pred (B, L, 1)" --> Outputs + AEHead -- "x̂ (B, input_dim)" --> Outputs + + %% ---------- Styles ---------- + classDef input fill:#E0EFFF,stroke:#5C9DFF,stroke-width:2px,color:#000; + classDef foundation fill:#DFF0D8,stroke:#77B55A,stroke-width:2px,color:#000; + classDef tanh fill:#D9EDF7,stroke:#6BADCF,stroke-width:2px,color:#000; + classDef taskhead fill:#FCF8E3,stroke:#F0AD4E,stroke-width:2px,color:#000; + classDef kr fill:#F2DEDE,stroke:#D9534F,stroke-width:2px,color:#000; + classDef ae fill:#EFE0F7,stroke:#9067C6,stroke-width:2px,color:#000; + classDef output fill:#EAEAEA,stroke:#888888,stroke-width:2px,color:#000; + classDef junction fill:#FFFFFF,stroke:#AAAAAA,stroke-width:1px,color:#000,shape:circle; classDef legend_style fill:#f9f9f9,stroke:#ccc,stroke-width:1px,color:#333; - %% ---------- Class assignments ---------- - class Legend_B,Legend_L,Legend_D legend_style - class X_formula,X_structure,Task_Sequence_Data_Batch input - class FormulaEncoder,Aggregation,DepositBlock foundation - class Fusion fusion - class H_Latent_Output_Point point - class AttrTaskHeadsJunction,SeqTaskHeadsJunction junction + class B,L,D legend_style + class X_formula,Task_Seq_Data input + class SharedEncoder,Aggregation,H_Latent foundation + class TANH tanh + class HeadsJunction junction class RegHead,ClassHead taskhead - class SeqHeadRNN,SeqHeadTransformer seqtaskhead - class OutputLayer output + class KRHead kr + class AEHead ae + class Outputs output ``` -## Component Explanations +## Component explanations -### 1. Input Layer -The model can accept several types of inputs: -- **`x_formula`**: Tensor representing formula-based features (e.g., chemical composition, elemental descriptors). Shape: `(BatchSize, D_in_formula)`. This is the primary input. -- **`task_sequence_data_batch`** (Optional): A dictionary where keys are sequence task names and values are tensors representing sequence input data (e.g., temperatures, time steps) for those tasks. Shape of each tensor: `(BatchSize, SequenceLength, NumFeaturesPerPoint)` (typically `(B,L,1)`). +### 1. Input layer +- **`x_formula`** — composition descriptors, shape `(B, input_dim)`. Typically the output of a + `descriptor_fn` (see `data/composition_sources.py`) cached per unique composition. +- **`task_sequence_data_batch`** *(KernelRegression heads only)* — `Dict[task_name, Tensor(B,L,1)]` + carrying the sequence x-axis (e.g. energies for DOS, temperatures for ZT) the KR head consumes. ### 2. Foundation Encoder (`self.encoder`) -This is the core shared part of the model. It processes formula descriptors with a configurable backbone and produces task-ready representations. The behavior is driven by `encoder_config`, which declares its mode with the `EncoderType` enum (`encoder_config.type`) defined in `model_config.py`. +A `FoundationEncoder` wrapping either an MLP or a Transformer backbone (mode chosen by +`encoder_config.type`): -- **`shared` (Configurable Backbone)**: Projects `x_formula` into a latent space. - - **MLP Mode** (`MLPEncoderConfig`): Applies the feed-forward stack defined by `hidden_dims`, optional normalization, and residual settings. The final hidden size becomes `latent_dim`. - - **Transformer Mode** (`TransformerEncoderConfig`): Treats each scalar feature as a token, learns per-token embeddings, and runs a stack of Transformer encoder blocks. Token outputs are aggregated through either a learnable `[CLS]` token or mean pooling depending on `use_cls_token`. The aggregated representation becomes `h_latent`. -- **`deposit` (Linear + Tanh)**: Processes `h_latent`. - - Input: `h_latent` (dimension defined by the chosen encoder’s `latent_dim`). - - Output: `h_task` (task-specific input representation, dimension `D_deposit`). `D_deposit` is typically the input dimension expected by the first non-sequence task head. +- **MLP mode** — `MLPEncoderConfig(hidden_dims=[input_dim, …, latent_dim])` runs a + `LinearBlock` (Linear + optional BatchNorm1d + LeakyReLU, optional residuals). `hidden_dims[0]` + is the input dim; `hidden_dims[-1]` is the latent dim. +- **Transformer mode** — `TransformerEncoderConfig(d_model=…, num_layers=…, nhead=…)` treats + each scalar feature as a token, learns per-token embeddings, runs Transformer encoder blocks, + and aggregates via either a learnable `[CLS]` token or mean pooling. `latent_dim = d_model`. -The output `h_task` (from the `deposit` layer) serves as the primary contextual input for ALL task heads (Attribute, Classification, and Sequence). The `h_latent` representation is the intermediate output within the `FoundationEncoder` before the `deposit` layer, whether it originates from the final MLP layer or the Transformer aggregation. +The encoder's output is a raw `h_latent` of shape `(B, latent_dim)` — there is **no** deposit +layer. The Tanh activation is applied *at the model level*, see below. -### 3. Task Heads (`self.task_heads`) -This is an `nn.ModuleDict` containing individual prediction heads for each configured task. +### 3. Model-level Tanh +`FlexibleMultiTaskModel.forward` applies `torch.tanh(self.encoder(x))` once and reuses the +resulting `h_task` for every task head and for `optimize_latent` / `optimize_composition`. This +keeps the head-input distribution bounded and lets the AutoEncoder head learn a stable +reconstruction target for inverse design. -- **General Input**: - - All task heads (Attribute Regression, Classification, and Sequence Prediction) receive `h_task` (output of the `deposit` block) as their primary input. - - Sequence Prediction heads additionally receive their specific sequence data (e.g., temperature points, time steps) from `task_sequence_data_batch['task_name']`. +### 4. Task heads (`self.task_heads`) +An `nn.ModuleDict`. All heads consume `h_task` of shape `(B, latent_dim)`. -- **`RegressionHead`**: - - Typically an MLP defined by `config.dims` (e.g., `[D_deposit, hidden_dim, 1]`). - - Outputs a continuous value (or vector) for each sample. Shape: `(BatchSize, D_out_regression)`. +| Head | Config | Output | +|---|---|---| +| `RegressionHead` | `RegressionTaskConfig(dims=[latent_dim, …, 1])` | `(B, D_out)` | +| `ClassificationHead` | `ClassificationTaskConfig(num_classes=K, class_weights=[…]?)` | logits `(B, K)`; optional per-class loss weights for imbalanced labels (PR #18) | +| `KernelRegressionHead` | `KernelRegressionTaskConfig(x_dim=…, t_dim=…)` | `(B, L, 1)` (one value per t-point) | +| `AutoEncoderHead` | enabled by `FlexibleMultiTaskModel(enable_autoencoder=True)` | `x̂ (B, input_dim)` — reconstruction of the original descriptor; **required for `optimize_latent(optimize_space="latent")`** | -- **`ClassificationHead`**: - - Typically an MLP defined by `config.dims` (e.g., `[D_deposit, hidden_dim, num_classes]`). - - Outputs logits for each class. Shape: `(BatchSize, NumClasses)`. +`disabled_task_heads` holds heads taken offline mid-run (e.g. by `model.disable_task(...)` during +the head-only fine-tune in `finetune_inverse_heads`), preserving their weights in the state-dict. -- **Sequence Heads (e.g., `SequenceRNNHead`, `SequenceTransformerHead`, `SequenceTCNFiLMHead`)**: - - These heads have more complex internal architectures (RNNs, Transformers, TCNs). - - They combine the contextual vector (`h_task`) with the input sequence points (`task_sequence_data_batch['task_name']`). - - Output a sequence of predictions. Shape: `(BatchSize, SequenceLength, D_out_sequence_point)`. +### 5. Model outputs +`forward` returns a `Dict[str, Tensor]` keyed by task name. `predict_step` further unwraps each +head's output via the head's own `predict` method (so e.g. classification gives both `*_logits` +and `*_probabilities`). -### 4. Model Outputs -The `forward` method of `FlexibleMultiTaskModel` returns a dictionary. -- Keys: Task names as defined in `task_configs`. -- Values: The corresponding prediction tensors from each enabled task head. +## Data flow + dimensionality summary -During `predict_step`, the output dictionary keys are further processed by each head's `predict` method, often resulting in keys like `task_name_value` or `task_name_probabilities`. +| Stage | Shape | +|---|---| +| `x_formula` | `(B, input_dim)` | +| After encoder (`h_latent`) | `(B, latent_dim)` | +| After model-level tanh (`h_task`) | `(B, latent_dim)` — feeds every head | +| Regression / Classification / AutoEncoder output | `(B, D_out)` | +| KernelRegression output | `(B, L, 1)` | -## Data Flow and Dimensionality Summary +## Loss calculation and weighting -- **Input (`x_formula`)**: `(B, shared_block_dims[0])` -- **After Formula/Shared Encoder**: `h_latent` or `h_formula` is `(B, shared_block_dims[-1])` -- **After Structure Encoder (if applicable)**: `h_structure` is `(B, struct_block_dims[-1])` (must be same as `shared_block_dims[-1]`) -- **After Fusion (if applicable)**: `h_fused` is `(B, shared_block_dims[-1])` -- **After Deposit Layer**: `h_task` is `(B, D_deposit)` -- **Regression/Classification Head Output**: `(B, task_specific_output_dim)` -- **Sequence Head Output**: `(B, SequenceLength, task_specific_output_dim_per_point)` +### 1. Raw task losses +Each head computes its own loss $\mathcal{L}_t$: -This structure allows for flexible combination of shared representations with task-specific processing. +- **Regression** — MSE (often on z-scored targets). +- **Classification** — cross-entropy with optional per-class weights (`class_weights` on the + task config). When `class_weights=None`, the head registers a buffer of ones so the + `state_dict` shape is stable across with/without configurations. +- **Kernel regression** — sequence-wise MSE. +- **AutoEncoder** — reconstruction MSE between `h_task` and the round-trip + `tanh(encoder(decoder(h_task)))`. -## Loss Calculation and Weighting +### 2. Optional learnable uncertainty (Kendall et al. CVPR 2018) +When `enable_learnable_loss_balancer=True`, the model registers $\log\sigma_t$ per task in +`model.task_log_sigmas` (a `ParameterDict`) and scales each contribution as: -The `FlexibleMultiTaskModel` employs a sophisticated strategy for calculating and weighting losses from multiple supervised tasks to enable stable and effective multi-task learning. This section details the approach, including the use of learnable uncertainty weighting. +$$ \mathcal{L}'_{t} = \tfrac{1}{2}\,w_t\,\exp(-2\log\sigma_t)\,\mathcal{L}_t + \log\sigma_t $$ -### 1. Raw Task Losses -Each individual task head (e.g., `RegressionHead`, `ClassificationHead`, `SequenceRNNHead`) computes its own "raw" loss ($\mathcal{L}_t$). This is typically a standard loss function appropriate for the task type: -- **Regression Tasks**: Mean Squared Error (MSE) is common, often calculated on target values that may have been pre-scaled by a `target_scaler` (e.g., `StandardScaler`) for numerical stability. -- **Classification Tasks**: Cross-Entropy Loss is typical. -- **Sequence Tasks**: Depends on the nature of the sequence; could be MSE per time step or another sequence-appropriate loss, also potentially on scaled targets. +with $w_t$ = `loss_weight` from the task config (default 1.0). The $\log\sigma_t$ term +regularises against $\sigma_t \to 0$. -Self-supervised tasks (MFM, contrastive, cross-reconstruction) also compute their respective raw losses. +### 3. Total loss +$$ \mathcal{L}_{\text{train}} = \sum_{t} \mathcal{L}'_{t} $$ -### 2. Learnable Uncertainty Weighting for Supervised Tasks +When the balancer is disabled (default), each term reduces to $w_t \cdot \mathcal{L}_t$. -To address challenges with balancing tasks that may have different loss scales or learning difficulties, the model implements learnable uncertainty weighting for supervised tasks, inspired by the work of Kendall, Gal, and Cipolla, "Multi-task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics," CVPR 2018. - -#### Conceptual Basis: Homoscedastic Uncertainty -Homoscedastic uncertainty refers to task-dependent uncertainty that is constant for all input samples of a given task but varies between tasks. The model learns these task-specific uncertainties ($\sigma_t$) and uses them to automatically balance the contribution of each task's loss. +```mermaid +graph TD + subgraph OverallLoss["Total Training Loss (train_final_loss)"] + direction TB + Sum["Σ task contributions"]:::output + T1["Task 1 final"]:::taskhead + T2["Task 2 final"]:::taskhead + TN["…"]:::taskhead + + T1_raw["L₁ (raw)"]:::rawloss --> Op1["× ½·w₁·exp(−2logσ₁)"]:::operation --> Op1r["+ logσ₁"]:::operation --> T1 + T2_raw["L₂ (raw)"]:::rawloss --> Op2["× ½·w₂·exp(−2logσ₂)"]:::operation --> Op2r["+ logσ₂"]:::operation --> T2 + TN_raw["…"]:::rawloss --> TN + + T1 --> Sum + T2 --> Sum + TN --> Sum + end -#### Probabilistic Formulation -For a regression task $t$, modeling the likelihood $p(y_t | f_t(\mathbf{x}), \sigma_t^2)$ as a Gaussian $\mathcal{N}(y_t | f_t(\mathbf{x}), \sigma_t^2)$, the negative log-likelihood (NLL) to be minimized is proportional to: -$$ \mathcal{L}'_t = \frac{1}{2\sigma_t^2} \mathcal{L}_t + \log \sigma_t $$ -where $\mathcal{L}_t = (y_t - f_t(\mathbf{x}))^2$ is the raw squared error. A similar formulation applies to classification tasks. + Bal["task_log_sigmas (Parameter)"]:::inputsrc -.-> Op1 + Bal -.-> Op1r + Bal -.-> Op2 + Bal -.-> Op2r + LW1["loss_weight w₁ (config)"]:::inputsrc -.-> Op1 + LW2["loss_weight w₂ (config)"]:::inputsrc -.-> Op2 -#### Practical Implementation -The model learns $\log \sigma_t$ for each supervised task $t$, stored in `model.task_log_sigmas`. With an optional per-task scalar `loss_weight = w_t`, the final loss component becomes: -$$ \mathcal{L}'_{t, \text{final}} = \frac{w_t \cdot \exp(-2 \log \sigma_t)}{2} \mathcal{L}_t + \log \sigma_t $$ -Where: -- $\mathcal{L}_t$: The raw, unweighted loss for task $t$. -- $\log \sigma_t$: The learnable log uncertainty for task $t$. -- $\exp(-2 \log \sigma_t)$: Equivalent to $1/\sigma_t^2$ (precision). If $\mathcal{L}_t$ is large (task is hard/noisy), $\log \sigma_t$ increases, down-weighting $\mathcal{L}_t$. -- $w_t$: User-provided scalar (defaults to 1.0) that scales task $t$'s contribution. -- The $\log \sigma_t$ term regularizes, preventing $\sigma_t$ from collapsing. + classDef output fill:#EAEAEA,stroke:#888888,stroke-width:2px,color:#000; + classDef taskhead fill:#FCF8E3,stroke:#F0AD4E,stroke-width:2px,color:#000; + classDef rawloss fill:#FFF3CD,stroke:#FFC107,stroke-width:1px,color:#000; + classDef operation fill:#E1F5FE,stroke:#0288D1,stroke-width:1px,color:#000; + classDef inputsrc fill:#E8EAF6,stroke:#3F51B5,stroke-width:1px,color:#000; +``` -### 3. Total Loss for Optimization -The total loss optimized during training (`train_final_loss`) is: -$$ \text{train\_final\_loss} = \sum_{t \in \text{supervised}} \mathcal{L}'_{t, \text{final}} + \sum_{s \in \text{auxiliary}} \mathcal{L}'_{s, \text{final}} $$ +### 4. Validation +The same formulation is reused with the learned $\log\sigma_t$ frozen. `val_final_loss` is the +default monitor for `ModelCheckpoint` / `EarlyStopping`. -When the uncertainty balancer is disabled, each supervised term simplifies to $w_t \cdot \mathcal{L}_t$. +## Inverse design (added in PR #18) -Any auxiliary/self-supervised heads contribute via their own modules; if none are configured this reduces to the supervised sum above. +The same `FlexibleMultiTaskModel` exposes two gradient-based inverse-design methods on a +trained checkpoint. Both share a regression-MSE + classification-cross-entropy backbone; only +the third loss term and the optimisation variable differ. -### 4. Validation Loss -During validation, the same weighting formulation is applied using the learned $\log \sigma_t$ values (without updating them). The primary metric for callbacks (e.g., `ModelCheckpoint`, `EarlyStopping`) is `val_final_loss`. +| Method | Optimisation variable | Method-specific loss term | Recipe directly available? | +|---|---|---|---| +| `optimize_latent(optimize_space="latent")` | $h$ (latent) | $\alpha \cdot \lVert h - \tanh(E(D(h))) \rVert^2$ — AE-alignment | no — needs AE decode then a `KMD.inverse` | +| `optimize_composition` | $\theta$, with $w = \text{softmax}(\theta)$ | $(1-d)\,H(w)$ — per-output entropy / peakiness | yes — $w$ is the recipe | -### Loss Calculation Flow Diagram +User-facing knobs (all on `[0, 1]` where applicable): -The following diagram illustrates the combination of different loss components: +- `ae_align_scale` (`optimize_latent`) — AE manifold alignment; sweet spot ≈ 0.5. +- `diversity_scale` (`optimize_composition`) — per-output element diversity; 1.0 = no penalty. +- `seed_blend` — fraction of seed kept at the start, rest is uniform over the whitelist (lets new + elements enter the recipe). +- `allowed_elements` — hard whitelist over element symbols. +- `element_step_scale` — per-element gradient scaling; `0` hard-locks an element to its seed value. +- `class_target_weight` — weight on the classification objective vs. the regression targets. ```mermaid graph TD - subgraph OverallLoss["Total Training Loss (train_final_loss)"] + subgraph Latent["optimize_latent (latent space)"] direction TB - SumLosses["Sum All Contributions"]:::output - - %% ---------- Supervised tasks ---------- - subgraph SupervisedLosses["Supervised Tasks Contribution"] - direction TB - SumSupervised["Sum Task Components"]:::output - Task1_Final["Task 1: Final Component"]:::taskhead - Task2_Final["Task 2: Final Component"]:::taskhead - TaskN_Final["Task N: Final Component"]:::taskhead - - Task1_Raw["Raw Loss (L₁)"]:::rawloss --> Op1_Scale["Scale by 0.5 · w₁ · e−2logσ₁"]:::operation - Op1_Scale --> Op1_AddReg["Add logσ₁"]:::operation - Op1_AddReg --> Task1_Final - - Task2_Raw["Raw Loss (L₂)"]:::rawloss --> Op2_Scale["Scale by 0.5 · w₂ · e−2logσ₂"]:::operation - Op2_Scale --> Op2_AddReg["Add logσ₂"]:::operation - Op2_AddReg --> Task2_Final - - TaskN_Raw["..."]:::rawloss --> TaskN_Final - - Task1_Final --> SumSupervised - Task2_Final --> SumSupervised - TaskN_Final --> SumSupervised - end - SumSupervised --> SumLosses + Seed1["Seed x_seed"] + Enc1["encoder + tanh"] + H["h (latent — the optimisation variable)"] + AE["AE round-trip:
D(h) → x̂ → tanh(E(x̂)) = h'"] + Heads1["Task heads (reg + cls)"] + AdamL["Adam updates h ← ∇_h L
L = reg_MSE + w_cls·(−log P(QC)) + α·‖h − h'‖²"] + + Seed1 --> Enc1 --> H + H --> Heads1 + H -. round-trip .-> AE + AE -. "h' (return arrow weighted by α)" .-> H + AdamL -.-> H end - %% ---------- Inputs ---------- - subgraph InputsToLossCalc["Inputs to Loss Calculation"] - L1_Head["Task 1 Head"]:::taskhead --> Task1_Raw - L2_Head["Task 2 Head"]:::taskhead --> Task2_Raw - LN_Head["..."]:::taskhead --> TaskN_Raw - - Learnable_LogSigmas["Learnable: task_log_sigmas (logσ_t)"]:::inputsrc -.-> Op1_Scale - Learnable_LogSigmas -.-> Op1_AddReg - Learnable_LogSigmas -.-> Op2_Scale - Learnable_LogSigmas -.-> Op2_AddReg - LossWeight1["Config: loss_weight (w₁)"]:::inputsrc -.-> Op1_Scale - LossWeight2["Config: loss_weight (w₂)"]:::inputsrc -.-> Op2_Scale + subgraph Comp["optimize_composition (differentiable KMD)"] + direction TB + Theta["logits θ (optimisation variable)"] + WSoft["softmax → w (simplex; the recipe)"] + KMD["x = w · K (KMD transform)"] + Enc2["encoder + tanh"] + Heads2["Task heads (reg + cls)"] + AdamC["Adam updates θ ← ∇_θ L
L = reg_MSE + w_cls·(−log P(QC)) + (1−d)·H(w)"] + + Theta --> WSoft --> KMD --> Enc2 --> Heads2 + AdamC -.-> Theta end - %% ---------- Style definitions ---------- - classDef output fill:#EAEAEA,stroke:#888888,stroke-width:2px,color:#000; - classDef taskhead fill:#FCF8E3,stroke:#F0AD4E,stroke-width:2px,color:#000; - classDef rawloss fill:#FFF3CD,stroke:#FFC107,stroke-width:1px,color:#000; - classDef operation fill:#E1F5FE,stroke:#0288D1,stroke-width:1px,color:#000; - classDef inputsrc fill:#E8EAF6,stroke:#3F51B5,stroke-width:1px,color:#000; - - %% ---------- Class assignments ---------- - class Task1_Final,Task2_Final,TaskN_Final taskhead - class SumLosses,SumSupervised output - class L1_Head,L2_Head,LN_Head taskhead - class Learnable_LogSigmas,LossWeight1,LossWeight2 inputsrc - class Task1_Raw,Task2_Raw,TaskN_Raw rawloss - class Op1_Scale,Op1_AddReg,Op2_Scale,Op2_AddReg operation + classDef latentClass fill:#DFF0D8,stroke:#55A868,stroke-width:2px,color:#000; + classDef compClass fill:#E0EFFF,stroke:#2563EB,stroke-width:2px,color:#000; + class Seed1,Enc1,H,AE,Heads1,AdamL latentClass + class Theta,WSoft,KMD,Enc2,Heads2,AdamC compClass ``` -This adaptive weighting scheme allows the model to dynamically balance the influence of different tasks based on their learned uncertainties, promoting more robust multi-task training. +For the full per-term design intent and the recommended use of each knob, see +[docs/inverse_design_algorithms.md](docs/inverse_design_algorithms.md). For the 3-scenario +study and headline takeaways, see [docs/qc_inverse_design_summary.md](docs/qc_inverse_design_summary.md). diff --git a/ARCHITECTURE_CLEANUP_FINAL.md b/ARCHITECTURE_CLEANUP_FINAL.md deleted file mode 100644 index 7b4b36c..0000000 --- a/ARCHITECTURE_CLEANUP_FINAL.md +++ /dev/null @@ -1,259 +0,0 @@ -# 简化架构清理 - 最终报告 - -## 🎯 任务目标 - -移除 deposit layer 后,彻底清理代码中所有过时的引用、文档和命名。 - -## ✅ 完成的所有修复 - -### 1. 修复代码 Bug:移除 `encoder.deposit` 引用 - -**文件**: [src/foundation_model/models/flexible_multi_task_model.py:522-524](src/foundation_model/models/flexible_multi_task_model.py#L522-L524) - -**问题**: `FoundationEncoder` 已移除 `deposit` 属性,但代码仍在引用 → 导致 `AttributeError` - -**修复**: -```diff - if self.freeze_shared_encoder: - for p in self.encoder.shared.parameters(): - p.requires_grad_(False) -- for p in self.encoder.deposit.parameters(): -- p.requires_grad_(False) -``` - ---- - -### 2. 更新 `_TransformerBackbone` 文档 - -**文件**: [src/foundation_model/models/components/foundation_encoder.py](src/foundation_model/models/components/foundation_encoder.py) - -#### 2.1 Class Docstring (Lines 34-41) - -```diff -- When ``use_cls_token`` is enabled the downstream ``deposit`` layer only sees -+ When ``use_cls_token`` is enabled the downstream task heads only see - the hidden state of the classifier token. - ... -- Disabling the ``[CLS]`` token switches to mean pooling, which exposes the -- aggregated hidden states of all tokens directly to the deposit layer and -+ Disabling the ``[CLS]`` token switches to mean pooling, which exposes the -+ aggregated hidden states of all tokens directly to the task heads and - distributes gradients evenly across the sequence. -``` - -#### 2.2 Forward Method Comments (Lines 133, 140) - -```diff -- # Gradients from the downstream deposit layer flow into the `[CLS]` token -+ # Gradients from the downstream task heads flow into the `[CLS]` token - -- # Mean pooling exposes every contextualised feature token to the deposit layer -+ # Mean pooling exposes every contextualised feature token to the task heads -``` - ---- - -### 3. 更新 `FlexibleMultiTaskModel` 文档 - -**文件**: [src/foundation_model/models/flexible_multi_task_model.py](src/foundation_model/models/flexible_multi_task_model.py) - -#### 3.1 Usage Scenarios (Line 84) - -```diff -- 4. Continual Learning: Support model updates via deposit layer design -+ 4. Continual Learning: Support model updates via modular architecture -``` - -#### 3.2 Parameter Documentation (Lines 90-91) - -```diff - task_configs : list[...] -- ...Regression and classification task heads receive the deposit -- layer output, while KernelRegression task heads receive both -- deposit layer output and sequence points. -+ ...Regression and classification task heads receive Tanh-activated -+ latent representations, while KernelRegression task heads receive both -+ latent representations and sequence points. -``` - -#### 3.3 shared_block_optimizer Documentation (Line 97) - -```diff - shared_block_optimizer : OptimizerConfig | None -- Optimizer configuration for the shared foundation encoder and deposit layer. -+ Optimizer configuration for the shared foundation encoder. -``` - -#### 3.4 Method Parameter Documentation (Line 1144) - -```diff - h_task : torch.Tensor -- Task representations from deposit layer, shape (B, D) -+ Tanh-activated latent representations, shape (B, D) -``` - ---- - -### 4. 重命名 `deposit_dim` → `latent_dim` - -**文件**: [src/foundation_model/models/flexible_multi_task_model.py](src/foundation_model/models/flexible_multi_task_model.py) - -**原因**: `deposit_dim` 名称已不准确,简化架构中不再有 deposit layer - -#### 4.1 定义处 (Line 135) - -```diff -- self.deposit_dim = self.encoder_config.latent_dim -+ # Dimension of latent representation (input to task heads after Tanh activation) -+ self.latent_dim = self.encoder_config.latent_dim -``` - -#### 4.2 使用处 (Line 242) - -```diff -- expected_input_dim = self.deposit_dim -+ expected_input_dim = self.latent_dim -``` - ---- - -## 📊 架构演变 - -### 演变历史 - -**原始架构(已废弃)**: -``` -X → encoder.shared → latent → encoder.deposit(Linear + Tanh) → task_heads - ↑ - 可学习的变换 -``` - -**统一 Tanh 架构(当前)**: -``` -X → encoder.shared → latent → torch.tanh() → task_heads - ↑ - 在 FlexibleMultiTaskModel.forward() 统一应用 -``` - -### 关键差异 - -| 方面 | 旧架构 | 新架构 | -|------|--------|--------| -| **Tanh 位置** | encoder.deposit 内部 | FlexibleMultiTaskModel.forward() | -| **额外变换** | Linear(latent_dim, deposit_dim) | 无 | -| **task heads 输入** | deposit Linear 变换后的表示 | 直接的 Tanh(latent) | -| **梯度流** | 通过 deposit Linear 层 | 直接通过 Tanh | -| **优化性能** | 受限(2.5 分) | 更强(5.0 分) | - ---- - -## 🔍 验证结果 - -### 代码引用检查 - -```bash -# ✅ encoder 中无 deposit 引用 -$ grep "deposit" src/foundation_model/models/components/foundation_encoder.py -# (无输出) - -# ✅ model 中无 deposit_dim 引用 -$ grep "deposit_dim" src/foundation_model/models/flexible_multi_task_model.py -# (无输出) - -# ✅ model 中无 "deposit layer" 文档引用 -$ grep "deposit layer" src/foundation_model/models/flexible_multi_task_model.py -# (无输出) -``` - -### 架构验证 - -可运行 [verify_current_architecture.py](verify_current_architecture.py) 验证: - -```bash -python3 verify_current_architecture.py -``` - -预期输出: -``` -✓ Encoder has NO deposit layer -✓ Tanh applied uniformly in FlexibleMultiTaskModel.forward() -✓ Both input and latent space optimization work correctly -``` - ---- - -## 📈 性能提升分析 - -### 实测数据(来自 notebook) - -| 指标 | 旧架构(有 deposit Linear) | 新架构(简化) | 提升 | -|------|---------------------------|---------------|------| -| 最终分数 | 2.5 | 5.0 | **+100%** | -| 优化曲线 | 不光滑 | 光滑 | ✓ | -| 收敛性 | 受限 | 更快 | ✓ | - -### 原因分析 - -1. **梯度流增强** - - 旧:梯度 → deposit Linear → 衰减 - - 新:梯度 → Tanh → 直接传播 - -2. **优化空间更自由** - - 旧:受 Linear 层权重约束 - - 新:在完整 latent 空间优化 - -3. **更少的参数** - - 旧:encoder + deposit Linear + task heads - - 新:encoder + task heads - ---- - -## ✅ 清理清单 - -- [x] 修复 `encoder.deposit` 代码引用(会导致 AttributeError) -- [x] 更新 `_TransformerBackbone` 所有文档引用 -- [x] 更新 `FlexibleMultiTaskModel` 所有文档引用 -- [x] 重命名 `deposit_dim` → `latent_dim` -- [x] 验证无残留引用 -- [x] 创建验证脚本 -- [x] 更新相关文档 - ---- - -## 📝 相关文件 - -### 核心代码 -- [src/foundation_model/models/components/foundation_encoder.py](src/foundation_model/models/components/foundation_encoder.py) -- [src/foundation_model/models/flexible_multi_task_model.py](src/foundation_model/models/flexible_multi_task_model.py) - -### 验证脚本 -- [verify_current_architecture.py](verify_current_architecture.py) -- [compare_input_vs_latent.py](compare_input_vs_latent.py) -- [test_unified_tanh.py](test_unified_tanh.py) - -### 文档 -- [UNIFIED_TANH_ARCHITECTURE.md](UNIFIED_TANH_ARCHITECTURE.md) -- [FIX_SUMMARY.md](FIX_SUMMARY.md) -- [SIMPLIFIED_ARCHITECTURE_CLEANUP.md](SIMPLIFIED_ARCHITECTURE_CLEANUP.md) -- 本文档 - ---- - -## 🎉 结论 - -**简化架构清理已完成!** - -所有过时的引用、文档和命名都已更新,代码库现在完全反映了新的简化架构: - -1. ✅ 无代码 bug(移除了错误的 `encoder.deposit` 引用) -2. ✅ 文档准确(所有引用更新为 "task heads" 和 "latent representations") -3. ✅ 命名清晰(`deposit_dim` → `latent_dim`) -4. ✅ 架构一致(所有地方统一使用 Tanh(latent)) -5. ✅ 性能提升(优化分数翻倍) - -新架构更简洁、更强大、更易理解! - ---- - -**日期**: 2025-11-25 -**修复**: Claude Code Assistant diff --git a/PLAN_inverse_design_integration.md b/PLAN_inverse_design_integration.md deleted file mode 100644 index 81fe90b..0000000 --- a/PLAN_inverse_design_integration.md +++ /dev/null @@ -1,236 +0,0 @@ -# Plan: Built-in AutoEncoder Head Support - -## Background - -`FlexibleMultiTaskModel.optimize_latent()` is already implemented. The latent-space exploration -workflow (post-training inverse design) is an **independent system** out of scope here. - -This plan covers only the training-time AE support changes. - ---- - -## Confirmed Findings - -| Question | Answer | -|----------|--------| -| `LinearBlock(output_active=None)` supported? | **Yes** — `if self.output_active:` already handles `None` → linear pass-through | -| AE task name | `"__reconstruction__"` (hardcoded everywhere) | -| `ae_task_name` parameter on `optimize_latent` | **Remove** — hardcode `"__reconstruction__"` | -| `AutoEncoderTaskConfig` | **Remove from public API**; replace with private `_AEConfig` | -| AE `loss_weight` | Fixed `1.0` | -| `autoencoder_nonnegative=True` activation | `Softplus`; `False` → linear (`output_active=None`) | - ---- - -## Design Summary - -Two new parameters on `FlexibleMultiTaskModel`: - -```python -enable_autoencoder: bool = False -autoencoder_nonnegative: bool = False -``` - -When `enable_autoencoder=True`, the model auto-creates an AE head mirroring the encoder dims. -No user-facing config class is exposed. The AE head is registered under `"__reconstruction__"` -in the internal `task_heads` dict so the existing training loop handles its loss automatically. - ---- - -## Dim Derivation — Mirror of Encoder - -``` -MLPEncoderConfig.hidden_dims = [input_dim, h1, …, latent_dim] - → AE dims = reversed(hidden_dims) e.g. [latent_dim, …, h1, input_dim] - -TransformerEncoderConfig: has .latent_dim and .input_dim - → AE dims = [latent_dim, input_dim] (single linear projection) -``` - -Both encoder config classes already expose `.input_dim` and `.latent_dim`. - ---- - -## Implementation Steps - -### Step 1 — Remove `AutoEncoderTaskConfig`; add private `_AEConfig` - -**File**: `src/foundation_model/models/model_config.py` - -- Delete `AutoEncoderTaskConfig`. -- Remove it from `TaskConfigType` union. -- Add a private (non-exported) dataclass `_AEConfig` with only what the training loop needs: - -```python -@dataclass -class _AEConfig: - """Internal config for the auto-created reconstruction head. Not part of public API.""" - name: str = "__reconstruction__" - type: TaskType = TaskType.AUTOENCODER - dims: List[int] = field(default_factory=list) # populated by model at init - nonnegative: bool = False - norm: bool = True - residual: bool = False - loss_weight: float = 1.0 - enabled: bool = True - data_column: str = "__autoencoder__" # existing DataModule sentinel -``` - -`TaskType.AUTOENCODER` enum value is **kept** — DataModule and Dataset still use it to skip -external data loading for AE tasks. - -`TaskConfigType` becomes: - -```python -TaskConfigType = RegressionTaskConfig | ClassificationTaskConfig | KernelRegressionTaskConfig -``` - ---- - -### Step 2 — Update `AutoEncoderHead` - -**File**: `src/foundation_model/models/task_head/autoencoder.py` - -- Replace `AutoEncoderTaskConfig` import with `_AEConfig`. -- Replace hardcoded `Sigmoid` with: - -```python -output_act = torch.nn.Softplus() if config.nonnegative else None -self.net = LinearBlock( - [d_in] + head_internal_dims[:-1], - normalization=config.norm, - residual=config.residual, - dim_output_layer=head_internal_dims[-1], - output_active=output_act, -) -``` - -`output_active=None` is already handled by `LinearBlock` (linear pass-through confirmed). - ---- - -### Step 3 — Update `FlexibleMultiTaskModel` - -**File**: `src/foundation_model/models/flexible_multi_task_model.py` - -#### 3a — New `__init__` parameters - -```python -def __init__( - self, - task_configs: Sequence[RegressionTaskConfig | ClassificationTaskConfig | KernelRegressionTaskConfig], - *, - encoder_config: ..., - freeze_shared_encoder: bool = False, - shared_block_optimizer: OptimizerConfig | None = None, - enable_learnable_loss_balancer: bool = False, - allow_all_missing_in_batch: bool = True, - enable_autoencoder: bool = False, # NEW - autoencoder_nonnegative: bool = False, # NEW -): -``` - -`AutoEncoderTaskConfig` is removed from the `task_configs` type hint and validation. - -#### 3b — Auto-create AE config in `__init__` (after encoder init, before `_init_task_heads`) - -```python -self._ae_config: _AEConfig | None = None -if enable_autoencoder: - dims = self._derive_ae_dims(self.encoder_config) - self._ae_config = _AEConfig(dims=dims, nonnegative=autoencoder_nonnegative) - # Append to internal task_configs so training loop and DataModule see it - self.task_configs.append(self._ae_config) - self.task_configs_map[self._ae_config.name] = self._ae_config -``` - -#### 3c — Static helper `_derive_ae_dims` - -```python -@staticmethod -def _derive_ae_dims(encoder_config: BaseEncoderConfig) -> list[int]: - if isinstance(encoder_config, MLPEncoderConfig): - return list(reversed(encoder_config.hidden_dims)) - # TransformerEncoderConfig - return [encoder_config.latent_dim, encoder_config.input_dim] -``` - -#### 3d — Remove `ae_task_name` from `optimize_latent` - -In `optimize_latent`, replace every reference to `ae_task_name` with the constant -`"__reconstruction__"`. Update the validation block: - -```python -# optimize_space == "latent" -AE_TASK_NAME = "__reconstruction__" -if AE_TASK_NAME not in self.task_heads: - raise ValueError( - "optimize_space='latent' requires enable_autoencoder=True on this model." - ) -``` - -Remove the `ae_task_name` parameter from the method signature and all docstring references. - -#### 3e — Remove `AutoEncoderTaskConfig` from all type-hints and assertions - -Grep targets in `flexible_multi_task_model.py`: -- `__init__` task_configs type annotation (line 112) -- `add_task` type annotation (line 404) -- `isinstance(..., AutoEncoderTaskConfig)` assertions (lines 424, 458) -- Import line 43 - -Replace `isinstance(config_item, AutoEncoderTaskConfig)` with -`isinstance(config_item, _AEConfig)` or `config_item.type == TaskType.AUTOENCODER`. - ---- - -### Step 4 — Update DataModule and Dataset - -**Files**: `datamodule.py`, `dataset.py` - -- Remove `AutoEncoderTaskConfig` import; the `TaskType.AUTOENCODER` check is sufficient. -- `TaskConfig` type alias in `datamodule.py` drops `AutoEncoderTaskConfig`: - -```python -TaskConfig = RegressionTaskConfig | ClassificationTaskConfig | KernelRegressionTaskConfig -``` - -No logic changes — the existing `cfg.type != TaskType.AUTOENCODER` guards already work with -`_AEConfig` because `_AEConfig.type = TaskType.AUTOENCODER`. - ---- - -### Step 5 — Tests - -**`flexible_multi_task_model_test.py`** — new group `TestAutoEncoder`: - -| Test | Checks | -|------|--------| -| `test_enable_autoencoder_mlp` | AE head created; dims = reversed `hidden_dims`; forward runs; AE loss in training metrics | -| `test_enable_autoencoder_transformer` | dims = `[latent_dim, input_dim]` | -| `test_nonnegative_output` | `autoencoder_nonnegative=True` → all output values ≥ 0 | -| `test_linear_output` | `autoencoder_nonnegative=False` → output can be negative | -| `test_no_autoencoder_default` | `enable_autoencoder=False` (default) → `"__reconstruction__"` not in `task_heads` | -| `test_optimize_latent_requires_ae` | `optimize_space="latent"` without AE → `ValueError` | -| `test_optimize_latent_with_ae` | `optimize_space="latent"` with `enable_autoencoder=True` → runs correctly | - -**`task_head/autoencoder_test.py`** (new): - -| Test | Checks | -|------|--------| -| `test_softplus_output` | `nonnegative=True` → output ≥ 0 for arbitrary input | -| `test_linear_output` | `nonnegative=False` → output can be negative | - ---- - -## Files Touched - -| File | Change | -|------|--------| -| `models/model_config.py` | Remove `AutoEncoderTaskConfig`; add private `_AEConfig`; update `TaskConfigType` | -| `models/task_head/autoencoder.py` | Swap import to `_AEConfig`; replace `Sigmoid` with `nonnegative`-driven activation | -| `models/flexible_multi_task_model.py` | Add `enable_autoencoder` + `autoencoder_nonnegative`; `_derive_ae_dims`; remove `ae_task_name` from `optimize_latent`; clean up `AutoEncoderTaskConfig` references | -| `data/datamodule.py` | Remove `AutoEncoderTaskConfig` import; update `TaskConfig` alias | -| `data/dataset.py` | Remove `AutoEncoderTaskConfig` import if present | -| `models/flexible_multi_task_model_test.py` | Add `TestAutoEncoder` group | -| `models/task_head/autoencoder_test.py` | New test file | diff --git a/README.md b/README.md index f30f3cb..14affb7 100644 --- a/README.md +++ b/README.md @@ -1,166 +1,197 @@ # Foundation Model for Material Properties -A multi-task learning model for predicting various material properties. +A multi-task learning model for predicting material properties from composition descriptors, with +gradient-based inverse design on top of the trained checkpoint. ## Model Architecture -The `FlexibleMultiTaskModel` is designed with a modular and extensible architecture. At its core, it features: +The `FlexibleMultiTaskModel` is a modular multi-task regressor + classifier built around a shared +encoder. At the model level: -1. A **Foundation Encoder** that processes input features (formula-based, and optionally structure-based) to generate shared representations. This encoder includes mechanisms for multi-modal fusion if structural data is provided. -2. A **Tanh Activation** that is uniformly applied to latent representations at the model level, providing bounded outputs to task heads. -3. A collection of **Task-specific Heads** that take Tanh-activated latent representations from the foundation encoder to make predictions for various tasks, such as: - * Regression (e.g., predicting band gap) - * Classification (e.g., predicting material stability) - * Sequence Prediction (e.g., predicting density of states curves) - -Below is a high-level overview of the architecture: +1. A **Foundation Encoder** (MLP or Transformer) maps composition descriptors → a `latent_dim` + representation. +2. A **`torch.tanh`** at the model level provides bounded inputs (`h_task`) to the task heads. +3. A collection of **task-specific heads**: + - **Regression** — scalar / vector targets (e.g. formation energy, klat). + - **Classification** — discrete labels (e.g. material type), with optional per-class loss weights. + - **Kernel Regression** — per-composition property-vs-`t` sequences (e.g. DOS density vs energy, + power factor vs temperature). + - **AutoEncoder** — reconstructs the input descriptor from `h_task`; required for the + latent-space inverse-design path (see "Inverse design" below). ```mermaid graph TD - %% ---------- Inputs (同一级) ---------- + %% ---------- Inputs ---------- subgraph InputsLayer["Inputs"] direction TB - GeneralInputs["Formula / Structure
(x_formula, x_structure*)
*optional"] - SequenceDataInputs["Sequence Data
(task_sequence_* data)
*optional"] + X["x_formula (B, input_dim)"] + T["Sequence x-axis
(per-task, kernel regression only)"] end %% ---------- Foundation encoder ---------- - FE["Foundation Encoder
(Shared MLP, Fusion*, Deposit)
*optional"] + FE["Foundation Encoder
(MLP or Transformer)"] + TANH["tanh (model-level)"] %% ---------- Task heads ---------- - NonSeqHeads["Regression / Classification Heads"] - SeqHeads["Sequence Heads"] + REG["Regression Head(s)"] + CLF["Classification Head(s)"] + KR["KernelRegression Head(s)"] + AE["AutoEncoder Head
(optional — enables
latent-space inverse design)"] %% ---------- Edges ---------- - GeneralInputs --> FE - FE -- "h_task (for Reg/Class)" --> NonSeqHeads - FE -- "h_task (for Seq)" --> SeqHeads - SequenceDataInputs --> SeqHeads - NonSeqHeads --> Outputs["Outputs (Dictionary)"] - SeqHeads --> Outputs + X --> FE -- "h_latent (B, latent_dim)" --> TANH + TANH -- "h_task (B, latent_dim)" --> REG + TANH -- "h_task" --> CLF + TANH -- "h_task" --> KR + T --> KR + TANH -- "h_task" --> AE + REG --> O["Outputs (Dict[str, Tensor])"] + CLF --> O + KR --> O + AE --> O %% ---------- Styles ---------- classDef io fill:#E0EFFF,stroke:#5C9DFF,stroke-width:2px,color:#000; classDef main fill:#DFF0D8,stroke:#77B55A,stroke-width:2px,color:#000; classDef heads fill:#FCF8E3,stroke:#F0AD4E,stroke-width:2px,color:#000; - - %% ---------- Class assignments ---------- - class GeneralInputs,SequenceDataInputs io - class FE main - class NonSeqHeads,SeqHeads heads - class Outputs io + class X,T io + class FE,TANH main + class REG,CLF,KR,AE heads + class O io ``` -For a more detailed diagram and in-depth explanation of each component, data flow, and dimensionality, please refer to the [**Model Architecture Documentation (ARCHITECTURE.md)**](ARCHITECTURE.md). +For the detailed forward / loss / inverse-design diagrams, see +[**ARCHITECTURE.md**](ARCHITECTURE.md). ## Installation -1. Clone the repository: ```bash -git clone https://github.com/yourusername/foundation_model.git +git clone https://github.com/TsumiNa/foundation_model.git cd foundation_model -``` - -2. Install the package using uv: -```bash uv sync --frozen --all-groups ``` -This will install all dependencies as defined in the pyproject.toml and uv.lock files, including both production and development dependencies, and ensure exact version matching. This method is preferred for reproducible installations. +This installs all dependencies pinned by `uv.lock` (production + dev) for reproducibility. +To add a new dependency: `uv add ` (runtime) or `uv add --dev ` (dev). +## Usage -If you need to add additional dependencies, use: -```bash -uv add -# or for development dependencies -uv add --dev -``` +There are two parallel entry points: -## Usage +1. **`fm-trainer`** (PyTorch Lightning CLI, defined in [pyproject.toml](pyproject.toml) and backed + by [`scripts/train.py`](src/foundation_model/scripts/train.py)) — YAML-driven supervised + training of `FlexibleMultiTaskModel` on `CompoundDataModule`. +2. **`continual_rehearsal_demo` / `continual_rehearsal_full`** — TOML-driven multi-task continual + rehearsal runners that train a sequence of tasks with small replay, then run gradient-based + inverse design on the trained checkpoint. -The primary way to use this model is through the `train.py` script, which leverages PyTorch Lightning's `CLI`. This allows for flexible configuration via YAML files and command-line overrides. +### Training (YAML / LightningCLI) -### Training +```bash +fm-trainer fit --config path/to/config.yaml [--trainer.max_epochs=50] +``` -To train the model, you will typically use a command like: +or equivalently: ```bash -# From the project root directory -python -m foundation_model.scripts.train --config path/to/your/config.yaml [OTHER_CLI_OVERRIDES] +python -m foundation_model.scripts.train fit --config path/to/config.yaml ``` -Or, if you are in `src/foundation_model/scripts/`: + +`fit` / `validate` / `test` / `predict` are the standard LightningCLI subcommands. Any field +under `model.init_args.*`, `data.init_args.*`, `trainer.*` can be overridden from the command +line. See [`samples/`](samples/) for templates. + +### Continual rehearsal + inverse design (TOML) + ```bash -python train.py --config path/to/your/config.yaml [OTHER_CLI_OVERRIDES] +# Demo runner — small multi-task rehearsal, saves final_model.pt, optionally runs inverse design. +python -m foundation_model.scripts.continual_rehearsal_demo \ + --config-file samples/continual_rehearsal_demo_config_inverse_baseline.toml + +# Skip training, re-run only the inverse-design stage on an existing checkpoint. +python -m foundation_model.scripts.continual_rehearsal_demo \ + --config-file samples/continual_rehearsal_demo_config_inverse_baseline.toml \ + --inverse-only artifacts/inverse_design_run/training/final_model.pt ``` -- Replace `path/to/your/config.yaml` with the path to your experiment's configuration file. -- `[OTHER_CLI_OVERRIDES]` can be used to override specific parameters within your YAML file (e.g., `--trainer.max_epochs=50`). +See the [Inverse design](#inverse-design) section below for the full pipeline. ### Configuration -Model configuration is primarily handled through YAML files. These files define the model architecture (`FlexibleMultiTaskModel`), data loading (`CompoundDataModule`), PyTorch Lightning trainer settings, and any callbacks. - -You can find examples of configuration files in the `samples/generated_configs/` directory (e.g., `generated_model_config.yaml`) and more specific model component configurations in `configs/model_configs/` (e.g., `base_model.yaml`). +Both entry points read configuration as structured objects: -For detailed examples of different configurations (such as pre-training, fine-tuning, using specific model components like different sequence heads) and how to effectively use command-line overrides, please refer to the **## Quick Examples** section below. +- The **YAML / LightningCLI** path uses `init_args` blocks that map 1:1 onto each class's + `__init__` parameters (model, datamodule, trainer, callbacks). +- The **TOML** path uses a single `ContinualRehearsalConfig` dataclass; unknown keys are silently + ignored so the same TOML can drive both `continual_rehearsal_demo` and the downstream + `paper_inverse_comparison` script. ## Features -- Multi‑task learning for material property prediction -- **Dual‑modality support**: formula descriptors **+** optional structure descriptors -- **Pre‑training & downstream in one model** - - Pre‑train losses: contrastive, cross‑reconstruction, masked‑feature, property supervision - - `--pretrain` flag toggles extra losses; same architecture used for fine‑tune -- **Flexible sequence heads**: `rnn`, `vec`, `transformer`, `tcn`, `hybrid` (Flash‑Attention inside) -- **Encoder control**: `--freeze_encoder` to lock shared layers -- Handles missing values via masking & modality dropout -- Comprehensive logging and visualization tools -- Configurable data splitting strategies -- Early stopping and model checkpointing +- **Multi-task** regression + classification + kernel regression on a shared encoder. +- **Learnable per-task uncertainty** loss balancer (Kendall et al. CVPR 2018) — optional, per + `enable_learnable_loss_balancer`. See the "Loss Weighting Strategy" section below. +- **Per-class classification weights** (`ClassificationTaskConfig.class_weights`) — keeps minority + classes alive in imbalanced supervised tasks (e.g. the QC material-type head). +- **Task add / remove at runtime** — `model.add_task(cfg)` / `model.remove_tasks("name")` for + continual-learning-style task sequences. +- **Optional AutoEncoder head** (`enable_autoencoder=True`) — reconstructs the input descriptor + from `h_task`; required for `optimize_latent(optimize_space="latent")`. +- **Gradient-based inverse design** — two paths on a trained checkpoint: + - `model.optimize_latent(...)` — descends on `h` with an AE-alignment penalty + (`ae_align_scale ∈ [0, 1]`) that keeps the optimised latent on the AE manifold. + - `model.optimize_composition(...)` — differentiable KMD: descends on element-weight logits + directly, with optional element whitelist (`allowed_elements`), per-element step scaling + (`element_step_scale`), seed-vs-uniform mix (`seed_blend`), and per-output entropy penalty + (`diversity_scale ∈ [0, 1]`). +- **Continual rehearsal** training scripts (`continual_rehearsal_demo` / `..._full`) with small + replay, per-step checkpoints + parquet predictions, and a fully-automated paper-grade output + folder (figures + JSON + SUMMARY.md per inverse-design scenario). ### Loss Weighting Strategy -To train the `FlexibleMultiTaskModel` on supervised tasks with different loss scales, we rely on a learnable uncertainty term inspired by [Kendall, Gal, and Cipolla (CVPR 2018)](https://doi.org/10.1109/CVPR.2018.00781): +For supervised multi-task training, the model uses a learnable uncertainty term (Kendall, Gal, +and Cipolla, [CVPR 2018](https://doi.org/10.1109/CVPR.2018.00781)): -1. **Task heads produce raw losses.** Each supervised task $t$ supplies the head-specific loss $\mathcal{L}_t$ (e.g., MSE or cross-entropy). -2. **Per-task static scaling.** Each task configuration exposes `loss_weight` (default `1.0`) to scale that task’s raw loss before further combination. -3. **Optional learnable uncertainty.** When `enable_learnable_loss_balancer` is `True`, the model maintains a per-task parameter $\log \sigma_t` and scales the contribution as $\mathcal{L}'_{t} = \tfrac{1}{2}\,\texttt{loss\_weight}_t\,\exp(-2 \log \sigma_t)\,\mathcal{L}_t + \log \sigma_t`. This lets the model down-weight noisier objectives while respecting explicit task priorities. -4. **Fallback when disabled.** If the balancer is disabled or a task does not expose $\log \sigma_t`, the contribution becomes $\mathcal{L}'_{t} = \texttt{loss\_weight}_t \cdot \mathcal{L}_t`. -5. **Total loss.** The overall objective is the sum of all task contributions. +1. **Raw losses** — each task head supplies $\mathcal{L}_t$ (MSE / cross-entropy / sequence loss). +2. **Per-task static scaling** — each task config exposes `loss_weight` (default `1.0`) to scale + the raw loss before combination. +3. **Optional learnable uncertainty** — when `enable_learnable_loss_balancer=True`, the model + maintains $\log\sigma_t$ per task and scales the contribution as + $\mathcal{L}'_t = \tfrac{1}{2}\,w_t\,\exp(-2\log\sigma_t)\,\mathcal{L}_t + \log\sigma_t$. +4. **Fallback** — when disabled, each contribution reduces to $w_t \cdot \mathcal{L}_t$. +5. **Total loss** — sum of all task contributions. -See [ARCHITECTURE.md](ARCHITECTURE.md#loss-calculation-and-weighting) for a deeper walk-through of the loss pipeline and implementation hooks. +See [ARCHITECTURE.md § Loss Calculation](ARCHITECTURE.md#loss-calculation-and-weighting) for the +walk-through. ## Data Handling -- Supports multiple material properties -- Handles missing values through masking -- Configurable data splitting ratios -- Property-specific sampling fractions +- Per-task data files joined by a shared **composition** column. +- Missing values masked rather than dropped (per-task masks in `y_dict`). +- Configurable train/val/test splits, descriptor caching, per-task `task_masking_ratio` for + scaling-law experiments. -### Input Data: composition-keyed per-task sources +### Input data — composition-keyed per-task sources -`CompoundDataModule` is **composition-keyed**: each task owns its own data file(s), joined to -the others by a shared **composition** column. There is no monolithic attributes file — adding -a new property task means adding one file plus one task config. Descriptors are computed on -demand from the union of compositions via a user-supplied `descriptor_fn` (results are cached -per unique composition). +`CompoundDataModule` is composition-keyed: each task owns its own data file(s), joined to the +others by a shared **composition** column. There is no monolithic attributes file — adding a new +property task means adding one file plus one task config. Descriptors are computed on demand from +the union of compositions via a user-supplied `descriptor_fn` (results are cached per unique +composition). -**DataModule wiring:** +**DataModule wiring** (YAML): ```yaml data: class_path: foundation_model.data.datamodule.CompoundDataModule init_args: - # Computes descriptors from compositions. PrecomputedDescriptorSource looks them up from a - # composition-indexed file; supply your own callable to compute them instead. descriptor_fn: class_path: foundation_model.data.composition_sources.PrecomputedDescriptorSource init_args: path: "data/descriptors.parquet" - composition_column: null # null => use the file's index as the composition key - composition_column: "composition" # the join key shared across all task files - # default_data_files: "data/all_targets.parquet" # optional shared fallback for tasks - # # that don't declare their own data_files + composition_column: null # null => use the file's index as the composition key + composition_column: "composition" val_split: 0.1 test_split: 0.1 random_seed: 42 @@ -177,7 +208,7 @@ data: | `composition_column` | Per-task override of the global composition column | | `split_column` | Optional in-file `train` / `val` / `test` labels (default `"split"`) | | `task_masking_ratio` | Optional keep-ratio applied to this task's valid training samples | -| `predict_idx` | Composition subset to predict: a literal `train`/`val`/`test`/`all` or an explicit list | +| `predict_idx` | Composition subset to predict: `train`/`val`/`test`/`all` or an explicit list | ```yaml # In model.init_args.task_configs (linked into the datamodule automatically): @@ -185,9 +216,6 @@ data: type: REGRESSION data_files: "data/band_gap.parquet" data_column: "Band gap" - # split_column: "split" # optional - # task_masking_ratio: 0.9 # optional - # predict_idx: "test" # optional - name: dos type: KernelRegression data_files: "data/dos.parquet" @@ -198,311 +226,145 @@ data: **Splitting.** A single composition-level train/val/test split is derived by overlaying every task file's `split` column (precedence `test > val > train`; conflicts warn). Compositions without a label fall back to a representation-aware random split (`MultiTaskSplitter`) that -prioritizes rare tasks to improve their val/test representation and preserves the overall -val/test proportions (it does not guarantee every tiny task appears in every split). -`test_all=True` assigns everything to test. +prioritises rare tasks. `test_all=True` assigns everything to test. **Prediction.** Each task's `predict_idx` selects a composition subset; the predict set is their -union, exposed as `datamodule.predict_compositions` and attached as the output index by -`PredictionDataFrameWriter` (single-process runs). - -**Important considerations:** -* **Exact column names**: `data_column` / `t_column` / `composition_column` must match the - source columns exactly. The composition key may be a column or the file's index name. -* **List-valued cells**: sequences / multi-dim targets stored in CSV must be strings parseable - by `ast.literal_eval`, e.g. `"[1.0, 2.5, 3.0]"`. -* **Missing data**: compositions absent from a task's file (or with NaN targets) are **masked - out** for that task rather than dropped; placeholders fill `y_dict`. -* **Missing descriptors**: compositions for which `descriptor_fn` produces no valid descriptor - are dropped from all splits (with a warning). - -## Quick Examples - -The `train.py` script utilizes PyTorch Lightning's `CLI` ([see official documentation](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html)). This allows for comprehensive configuration of the model (`FlexibleMultiTaskModel`) and data module (`CompoundDataModule`) through YAML files, with parameters passed directly to their `__init__` methods via an `init_args` block. You can also override these YAML settings using command-line arguments. +union, exposed as `datamodule.predict_compositions`. -You can also adjust tasks programmatically. For example, to swap in two new heads after loading a checkpoint: +**Important.** Composition keys must match exactly across files; list-valued cells in CSV must be +strings parseable by `ast.literal_eval` (e.g. `"[1.0, 2.5, 3.0]"`); missing data is masked +per-task; compositions without a valid descriptor are dropped with a warning. -```python -model.remove_tasks("old_regression") -model.add_task(new_reg_cfg, new_cls_cfg) # accepts multiple configs in one call -``` - -It's recommended to start with a base YAML configuration (e.g., `samples/generated_configs/generated_model_config.yaml` or `configs/model_configs/base_model.yaml` adapted to the `init_args` structure) and then customize it. - -**Command-Line Overrides:** -To override a parameter, you specify its full path. For example: -* `--model.init_args.shared_block_optimizer.freeze_parameters=True` -* `--trainer.max_epochs=50` - -**Note:** Low-Rank Adaptation (LoRA) support has been removed from the codebase. Any legacy configuration keys such as `lora_rank` or `lora_enabled` are currently ignored by the model. - -##### Example 1 – Supervised training run +## Quick Examples -This example runs standard supervised training. +### Example 1 — Supervised training ```bash -python -m foundation_model.scripts.train --config path/to/your/config.yaml \ - --trainer.max_epochs 60 +fm-trainer fit --config path/to/config.yaml --trainer.max_epochs=60 ``` -*Corresponding YAML snippet (`config.yaml`):* + ```yaml +seed_everything: 42 model: class_path: foundation_model.models.FlexibleMultiTaskModel init_args: - # ... other shared_block_dims ... + encoder_config: + type: mlp + hidden_dims: [128, 256, 128] # first = input_dim, last = latent_dim + norm: true task_configs: - - name: example_task_1 + - name: example_task type: REGRESSION dims: [128, 64, 1] data_column: my_property - loss_weight: 0.8 # Optional per-task scaling (defaults to 1.0) - # - name: another_task - # ... - # loss_weight: 1.0 + loss_weight: 0.8 +data: + class_path: foundation_model.data.datamodule.CompoundDataModule + init_args: + descriptor_fn: + class_path: foundation_model.data.composition_sources.PrecomputedDescriptorSource + init_args: { path: "data/descriptors.parquet", composition_column: null } + composition_column: "composition" + batch_size: 64 trainer: max_epochs: 60 ``` -##### Example 2 – Fine-tune only heads (encoder frozen) - -This example demonstrates fine-tuning where the main encoder is frozen. This is achieved by setting `freeze_parameters: true` in the `shared_block_optimizer` configuration. A sequence task (e.g., 'temp_curve') uses an RNN head. +### Example 2 — Freeze the encoder, fine-tune only task heads ```bash -# Assumes config.yaml is set for fine-tuning and includes a sequence task configured with subtype "rnn". - -python -m foundation_model.scripts.train --config path/to/your/config.yaml \ - --model.init_args.shared_block_optimizer.freeze_parameters=True -``` -*YAML snippet (`config.yaml`):* -```yaml -# In your config.yaml -# ... -model: - class_path: foundation_model.models.FlexibleMultiTaskModel - init_args: - # ... - shared_block_optimizer: - # ... - freeze_parameters: true # This freezes the shared encoder - task_configs: - - name: "temp_curve" # Example sequence task - type: "SEQUENCE" - subtype: "rnn" - # ... other settings for temp_curve ... - # ... other tasks ... -# ... +fm-trainer fit --config path/to/config.yaml \ + --model.init_args.shared_block_optimizer.freeze_parameters=True ``` -##### Example 3 – Full fine-tune, Transformer sequence head +`shared_block_optimizer.freeze_parameters` is the model-level knob that locks all encoder +parameters. Use this for head-only fine-tuning on a pre-trained checkpoint. -Full fine-tune: encoder is not frozen (`freeze_parameters: false`). A sequence task uses a Transformer head, configured in YAML. +For a more surgical freeze (encoder + every head NOT in a chosen list + the per-task loss +balancer scalars) see [`scripts/finetune_inverse_heads.py`](src/foundation_model/scripts/finetune_inverse_heads.py). -```bash -# Assumes config.yaml is set for fine-tuning. -# The relevant sequence task should be configured with subtype "transformer" in YAML. +### Example 3 — Transformer encoder -python -m foundation_model.scripts.train --config path/to/your/transformer_encoder.yaml \ - --model.init_args.shared_block_optimizer.freeze_parameters=False -``` -*YAML snippet (`transformer_encoder.yaml`):* ```yaml -# In transformer_encoder.yaml -# ... model: - class_path: foundation_model.models.FlexibleMultiTaskModel init_args: - # ... - shared_block_dims: [128, 256] # Input dimension -> fallback latent dimension encoder_config: type: transformer + input_dim: 128 d_model: 256 num_layers: 4 nhead: 4 dropout: 0.1 use_cls_token: true apply_layer_norm: true - shared_block_optimizer: - # ... - freeze_parameters: false # Encoder is trainable - task_configs: - - name: "temp_dos_transformer" # Example sequence task - type: "SEQUENCE" - subtype: "transformer" # Key: Use Transformer head - d_in: 256 # Input dimension (Tanh-activated latent from encoder) - d_model: 256 # Transformer d_model for the head - nhead: 4 # Transformer nhead - # ... other transformer parameters (num_encoder_layers, dim_feedforward, etc.) - # ... other settings for this task ... - # ... other tasks ... -# ... ``` -> ℹ️ **How the Transformer encoder trains tokens** -> -> * With ``use_cls_token: true`` the task heads consume the contextualised -> ``[CLS]`` embedding. Even though the other feature tokens are not pooled -> explicitly, they still receive gradients through the attention connections to -> the classifier query because their keys and values inform every ``[CLS]`` -> update. -> * Setting ``use_cls_token: false`` switches to mean pooling so every token is -> exposed directly to the supervised loss without relying on masked pre-training; -> gradients are distributed evenly across the sequence length. -> * Both aggregation modes therefore keep all feature tokens in play for -> supervised objectives, and you can choose the variant that best matches your -> task assumptions. - -##### Example 4 – Partial fine-tune (encoder unlocked, specific sequence head) +Both `[CLS]` and mean-pooling aggregations keep every feature token in play for the supervised +loss (gradients reach all tokens through self-attention). -Similar to full fine-tune (encoder trainable). A sequence task uses a 'vector' head, configured in YAML. +### Example 4 — Scaling-law experiment via `task_masking_ratio` -```bash -# Assumes config.yaml is set for fine-tuning. -# The relevant sequence task should be configured with subtype "vec" in YAML. +Each task's `task_masking_ratio` controls the fraction of its valid training samples used (`1.0` += all, `0.5` = half). Re-run training with `task_A.task_masking_ratio` set to `1.0`, `0.5`, +`0.2` in turn and record the final `val_task_A_*` loss — as the ratio drops, validation loss for +that task rises (the scaling-law signal) while other tasks are unaffected. -python -m foundation_model.scripts.train --config path/to/your/vec_head_config.yaml \ - --model.init_args.shared_block_optimizer.freeze_parameters=False -``` -*YAML snippet (`vec_head_config.yaml`):* ```yaml -# In vec_head_config.yaml -# ... -model: - class_path: foundation_model.models.FlexibleMultiTaskModel - init_args: - # ... - shared_block_optimizer: - # ... - freeze_parameters: false # Encoder is trainable - task_configs: - - name: "temp_dos_vector" # Example sequence task - type: "SEQUENCE" - subtype: "vec" # Key: Use fixed vector output head - d_in: 512 # Input dimension (Tanh-activated latent from encoder) - seq_len: 256 # Desired output sequence length for the vector - # ... other vec head parameters ... - # ... other settings for this task ... -# ... +task_configs: + - name: task_A + type: REGRESSION + data_files: "examples/data/task_A.csv" + data_column: "target_A" + dims: [256, 64, 1] + task_masking_ratio: 1.0 # vary this to study the scaling law ``` -These examples should provide a more accurate reflection of how to use `train.py` with your `LightningCLI` setup. - -### Training with Local Data and YAML Configuration (Scaling Law Demo) - -This section demonstrates training `FlexibleMultiTaskModel` from local files with a YAML -config, and how to explore scaling laws by varying a task's data via its per-task -`task_masking_ratio`. Each task owns its own file, joined to the descriptors by a -**composition** column. - -**1. Prepare local data files:** - -* `examples/data/descriptors.csv` — composition-indexed descriptor features: - ```csv - composition,comp_feat_1,comp_feat_2 - mat_1,0.1,0.5 - mat_2,0.2,0.6 - mat_3,0.3,0.7 - mat_4,0.4,0.8 - mat_5,0.5,0.9 - mat_6,0.15,0.55 - mat_7,0.25,0.65 - mat_8,0.35,0.75 - mat_9,0.45,0.85 - mat_10,0.55,0.95 - ``` - -* `examples/data/task_A.csv` — a regression task's own file (composition + target + split): - ```csv - composition,target_A,split - mat_1,1.0,train - mat_2,2.0,train - mat_3,3.0,train - mat_4,1.5,train - mat_5,2.5,train - mat_6,3.5,train - mat_7,4.0,val - mat_8,4.5,val - mat_9,5.0,test - mat_10,5.5,test - ``` - -* `examples/data/task_dos.csv` — a kernel-regression task with sequence target + x-axis. - List-valued cells are strings parseable by `ast.literal_eval`: - ```csv - composition,dos_y,dos_x,split - mat_1,"[0.1,0.2,0.3]","[10,20,30]",train - mat_2,"[0.4,0.5,0.6]","[10,20,30]",train - mat_9,"[1.2,1.3,1.4]","[10,20,30]",test - mat_10,"[1.3,1.4,1.5]","[10,20,30]",test - ``` - Compositions absent from a task's file (e.g. `mat_3` for `task_dos`) are simply masked - out for that task — no need to align files by hand. - -**2. Create the YAML configuration (`examples/configs/demo_scaling_law.yaml`):** -```yaml -seed_everything: 42 +## Inverse design -model: - class_path: foundation_model.models.flexible_multi_task_model.FlexibleMultiTaskModel - init_args: - encoder_config: - type: mlp - hidden_dims: [2, 128, 256] # hidden_dims[0] == input feature count; [-1] == latent_dim - norm: true - task_configs: - - name: "task_A" - type: REGRESSION - data_files: "examples/data/task_A.csv" - data_column: "target_A" - dims: [256, 64, 1] # [latent_dim, hidden, output] - task_masking_ratio: 1.0 # vary this to study the scaling law - optimizer: { lr: 0.001, scheduler_type: "None" } - - name: "dos" - type: KernelRegression - data_files: "examples/data/task_dos.csv" - data_column: "dos_y" - t_column: "dos_x" - x_dim: [256, 64] - t_dim: [16, 8] - optimizer: { lr: 0.001, scheduler_type: "None" } +After training, the same `FlexibleMultiTaskModel` exposes two gradient-based inverse-design +entry points on the model: -data: - class_path: foundation_model.data.datamodule.CompoundDataModule - init_args: - descriptor_fn: - class_path: foundation_model.data.composition_sources.PrecomputedDescriptorSource - init_args: - path: "examples/data/descriptors.csv" - composition_column: "composition" - composition_column: "composition" - task_configs: ${model.init_args.task_configs} # linked from the model - batch_size: 2 - num_workers: 0 - # val_split / test_split / random_seed apply only to compositions lacking a split label +| Method | Optimisation variable | Output is the recipe? | Method-specific knob | +|---|---|---|---| +| `optimize_latent(optimize_space="latent")` | the latent $h$ | no — needs AE decode | `ae_align_scale ∈ [0, 1]` (default 0.5; pulls $h$ onto the AE manifold) | +| `optimize_composition` | element-weight logits $\theta$, with $w = \text{softmax}(\theta)$ | yes — $w$ is the recipe | `diversity_scale ∈ [0, 1]` (default 1.0; per-output entropy penalty) | -trainer: - default_root_dir: "results/logs/scaling_law_demo" - max_epochs: 20 - accelerator: "cpu" - devices: 1 - logger: - - class_path: lightning.pytorch.loggers.CSVLogger - init_args: { save_dir: "${trainer.default_root_dir}", name: "" } -``` +Both methods share the same regression-MSE + classification-cross-entropy backbone; only the +third loss term and the optimisation variable differ. **Reference:** +[docs/inverse_design_algorithms.md](docs/inverse_design_algorithms.md). + +### End-to-end pipeline (PR #18) -**3. Run training:** +`continual_rehearsal_demo` / `continual_rehearsal_full` train an 11-task or 24-task multi-task +model with small replay, then run inverse design on the trained checkpoint: ```bash -fm-trainer fit --config examples/configs/demo_scaling_law.yaml +# 1. Baseline continual rehearsal — saves training/final_model.pt under the output dir. +python -m foundation_model.scripts.continual_rehearsal_demo \ + --config-file samples/continual_rehearsal_demo_config_inverse_baseline.toml + +# 2. Targeted retrain of the three inverse-design heads on top of the checkpoint. +python -m foundation_model.scripts.finetune_inverse_heads \ + --config-file samples/continual_rehearsal_demo_config_inverse_baseline.toml \ + --checkpoint artifacts/inverse_design_run/training/final_model.pt \ + --output-dir artifacts/inverse_design_run/finetune + +# 3. Per-scenario sweep — 3 scenarios × 8 paths (latent α-sweep + 5 composition configs). +python -m foundation_model.scripts.paper_inverse_3scenarios \ + --config-file samples/continual_rehearsal_demo_config_inverse_baseline.toml \ + --checkpoint artifacts/inverse_design_run/finetune/final_model.pt \ + --output-dir artifacts/inverse_design_run/inverse_design ``` -**4. Demonstrating the scaling law via `task_masking_ratio`:** +Each scenario folder ends up with `comparison.png` (bar chart), `element_frequency_heatmap.png` +(per-method × top-K elements with newly-discovered elements highlighted), +`qc_vs_secondary_scatter.png` (per-seed cloud with the seed-baseline layer), and 7× +`seed_to_optimized__*.png` (per-path 1:1 mapping), plus `results.json` + `SUMMARY.md`. -Each task's `task_masking_ratio` controls the fraction of its *valid* (non-NaN) training -samples used (`1.0` = all, `0.5` = half, …), simulating different dataset sizes per task. -Re-run training with `task_A`'s `task_masking_ratio` set to `1.0`, then `0.5`, then `0.2`, and -record the final `val_task_A_*` loss each time. As the ratio drops, the validation loss for -`task_A` generally rises — the expected scaling-law behavior — while other tasks are unaffected. +For the headline messages from the 3-scenario sweep (multi-objective optimisation, element +discovery, comparison of the two paths, conflicting-objective trade-offs), see +[docs/qc_inverse_design_summary.md](docs/qc_inverse_design_summary.md). ## Update History -Update history has been moved to [CHANGES.md](CHANGES.md). +See [CHANGES.md](CHANGES.md). diff --git a/docs/continual_rehearsal_full_PLAN.md b/docs/continual_rehearsal_full_PLAN.md new file mode 100644 index 0000000..04c469d --- /dev/null +++ b/docs/continual_rehearsal_full_PLAN.md @@ -0,0 +1,644 @@ +# Continual Rehearsal — 正式训练 Plan Memo + +> 状态:**评估方法已敲定(§5),实现停在 CPU smoke**(GPU 被占用 + 等 PR #18 合并) +> 路线决策:先走**缩小版**(`sample_per_dataset` 上限 + 减少 epoch 上限),全量留到论文最终复现阶段 +> 日期:2026-05-23 · 分支:`refine-demo-plots`(与 PR #18 同分支) +> 流程蓝本:`run_continual_rehearsal_demo.sh` / `scripts/continual_rehearsal_demo.py` + +--- + +## Handoff — 给接手 PR #18 的 agent + +PR #18 合并时本 workstream(`continual_rehearsal_full`)将随 #18 一并打磨入库。本文档是单一信息源;所有决策与背景都在下方各节。 + +### 当前工作树(未 commit) +本 workstream 的全部产物,都在 working tree 里、未 commit: + +| 路径 | 状态 | 备注 | +|---|---|---| +| `docs/continual_rehearsal_full_PLAN.md` | untracked | **本文件** | +| `src/foundation_model/scripts/continual_rehearsal_full.py` | untracked | 主脚本(24-task 目录、分级 rehearsal、不冻结、EarlyStopping、逐步 pred dump、checkpoint、3 剧本 inverse)。**rebase 时去掉旧版的 pptx+html 产出代码** — 按 §6 改为输出 `SLIDE_PREP.md` + 标准图集,slide 作者外部做 deck | +| `src/foundation_model/scripts/continual_rehearsal_full_test.py` | untracked | 16 tests 通过(catalogue / config / parser) | +| `samples/continual_rehearsal_full_config.toml` | untracked | 默认配置,已含**新版 task_sequence**(12 reg → 7 kr 升序 → 5 tail) | +| `run_continual_rehearsal_full.sh` | untracked | 仿 demo wrapper,日期戳输出 | +| `pyproject.toml` / `uv.lock` | modified | `uv add python-pptx`(runtime dep)— **2026-05-23 决定不再做 PPT 自动生成,这个 dep 现在没有 consumer,rebase 时清掉** | + +`artifacts/continual_rehearsal_full_smoke/` 是 CPU smoke 产物(gitignored,可丢弃)。 + +### 用户已确认的决策(来自本次会话) +1. **任务顺序**(§2):12 regression(自由序)→ 7 kernel regression **按非空行升序** → 5 固定 tail(`formation_energy → magnetic_moment → tc → klat → material_type`)。已同步到 config。 +2. **分级 rehearsal**(§3):固定 tail 作为旧 task 被回放时 `replay_ratio_high=0.10`,其余旧 task `replay_ratio=0.05`。 +3. **不冻结任何层**:每步 encoder + 所有已激活 head 联合训练,仅靠 rehearsal mask 实现增量。 +4. **训练规模**:全量数据 + `max_epochs_per_step=100` + EarlyStopping(`val_final_loss`, patience=8),MPS。 +5. **Inverse design — 强制使用 PR #18 的两条新路径**(§5):旧无约束 `optimize_latent` (`α=0`) **弃用**;每个剧本必须用 (a) `optimize_latent(ae_align_scale=0.5)`([0,1] 范围,默认 0.5)和 (b) `optimize_composition(...)` differentiable KMD 跑 3 个 composition 配置对照。三个剧本目标不变(§5 表)。 +6. **评估细节已敲定**(§5):基于 PR #18 的 `paper_inverse_comparison.py` 实测结论 —— `ae_align_scale=0.5`(empirical sweet spot);composition 路径跑 strict-seed / alloy-palette+blended / random-init 三个配置;种子改为 **17 top-QC 去重 + 3 个固定 Au–Ga–Ln**;alloy palette 见 §5 元素清单。两个用户旋钮 `ae_align_scale` / `diversity_scale` 都在 `[0, 1]` 上,符合直觉(见 §5)。 +7. **PPT 规范**(§6):16:9 · 白底 · 主色 `#2563EB` · 至多两辅助色(`#55A868` / `#C44E52`)· 11 页结构(含 catagoly 短分析与"即插即用 downstream"收尾页)。 +8. **PR #18 依赖与 rebase 步骤**:§11。当前 runner 里 `_inverse_design` 是**占位实现**(沿用 demo 老式 latent-only + `KMD.inverse` 解码),**rebase 时必须替换**。 + +### 已完成 / 验证过 +- `ruff format` / `ruff check` / `mypy src/.../continual_rehearsal_full.py` 全绿。 +- `pytest` 新增 16 tests 通过。 +- **CPU smoke**(`--sample-per-dataset 800 --max-epochs-per-step 2 --accelerator cpu`)端到端 OK,全部交付物产出。注意 smoke 跑的是**旧顺序**(已被本次更新后的新序覆盖),rebase 后建议再跑一次 smoke。 + +### Rebase 后要做的事(详见 §11) +1. 让 `_inverse_design` 改用双路径(latent+λ 与 `optimize_composition`)。 +2. 删除旧无约束 latent 调用与对应配置默认。 +3. **§6 不再生成 PPT / HTML deck**(2026-05-23 修订)。改为:runner 跑完后产出 `SLIDE_PREP.md`(结构化大纲)+ 标准图集 + raw arrays,slide 作者外部完成 deck。如果 `_write_pptx` / `_write_report_html` 还在 runner 里,保留为兼容性占位即可,不要再扩展。 +4. 验证从 demo 模块 import 的 helper 名仍可用(`_apply_plot_style` / `_PALETTE` / `_SCATTER_COLOR` / `_REPORT_TEMPLATE` / `_as_float_array` / `_composition_key` / `_init_kernels`)。 +5. smoke 重跑 → GPU 空闲后启动全量 MPS 正式 run(命令在 §10)。 + +--- + +## 0a. Narrative arc — 论文 / 项目对外叙事链 + +**写作时按这条链组织正文与 slides**: + +1. **问题提出 — 多属性联合优化是材料开发刚需** + 材料设计 = 在 *很多个* 属性约束下找配方(QC 类别 + 形成能 + 热导率 + 磁矩 + …),传统正向 DFT/实验循环成本极高。 + +2. **方案 — 持续学习构建一个 downstream 友好的 foundation model** + 共享 encoder + 多任务头 + rehearsal 增量训练;外部数据形态只要是 + "composition + property(或 category)" 即可一行 task config 接入; + 即插即用 downstream。 + +3. **示例 — Quasicrystal discovery** + 以 QC 形成 + 低形成能 + 高热导率/高磁矩等剧本作 case study; + 展示 model 在三个目标上的反向设计能力。 + +4. **面向实际可用性 — 不只展示 best number,更展示约束的必要** + - latent 路径有 AE-roundtrip 失败模式,靠 `ae_align_scale` 修复; + - composition 路径有 seed-init 锁支撑集问题,靠 `seed_blend` 修复; + - 无约束 composition(含 random init)虽能找到全局最优 QC 点,但落在 Pu/F/Mn 等 **不可合成元素**上 → **架构的搜索能力强,但反过来证明 alloy palette / 领域知识约束的不可或缺**; + - 旋钮命名都按用户直觉(`[0, 1]` + 名字朝向 = 大小语义),文档说明背后算法。 + +5. **系统性分析潜在问题** + 除头条结果外,专门给出**失败模式 + 偏置 + ablation** 三类分析,让读者理解何时该用哪个旋钮、何时该退回 strict-seed、何时该信任 model 的元素发现。 + +6. **下一步 — agent 化的 inverse design 工作台** + 计划围绕本 foundation model 搭建轻量 agent:用户描述目标(自然语言)→ + AI 分解 + 结合领域知识 → 自动设定 `ae_align_scale` / `diversity_scale` / + palette 等优化超参 → 自动跑 `optimize_*` → 给出可视化 result + 报告 PDF。 + +7. **更远期 — AI4S agent 群的一部分** + 把数值模拟 (DFT/MD) + 自动实验 / 表征装置作为额外 agent 接入; + foundation model 在这个 stack 里扮演**快速预测 + 候选生成**的中枢。 + +这条链同时是 §6 的 PPT 大纲,也是论文 Introduction / Discussion / Future +Work 的骨架。**slides 与 ANALYSIS.md 最终输出全部用英文撰写**。 + +--- + +## 0. 目标 + +在 **一个共享 encoder** 上做 continual(增量)多任务学习 + rehearsal 回放,覆盖 4 个无机数据集、全部 task 类型;训练完成后用同一个最终模型跑 **3 个独立的 inverse-design 剧本**。每个阶段的**原始数据 + plot + 每步 checkpoint** 全部落盘到**统一的 run 目录**(training / finetune / inverse_design 是同一父文件夹下的子目录),最后产出 **`SLIDE_PREP.md`(结构化大纲)+ 标准图集 + raw arrays**,slide 作者外部完成 deck。 + +按上一轮确认的 4 个决策执行: +1. 最后固定顺序为 **5 个 task**(重复的 Magnetic moment 是笔误)。 +2. **全量数据 + 早停**(`sample_per_dataset=null`,`max_epochs_per_step=100` 作上限,`EarlyStopping` 监控 `val_final_loss`)。 +3. **新建专用脚本 + 配置**,复用 demo 的 helper,不改动 demo。 +4. **不再自动生成 PPT / HTML deck**(2026-05-23 决定,详见 §6)。runner 跑完后只产出 `SLIDE_PREP.md` + 图 + raw arrays。 + +--- + +## 1. Task 目录(共 24 个监督 task + 常驻 autoencoder) + +非空行数为各任务在对应数据集中可用样本(已核实)。kernel 列均为 `(值序列, T/K 序列)` 且长度一致。 + +### 1a. Regression — 16 个 + +| task 名 | 数据集 | 列 | 非空行数 | +|---|---|---|---| +| density | qc | `Density (normalized)` | ~49034 | +| efermi | qc | `Efermi (normalized)` | ~49034 | +| final_energy | qc | `Final energy per atom (normalized)` | ~49034 | +| **formation_energy** | qc | `Formation energy per atom (normalized)` | ~49034 | +| total_magnetization | qc | `Total magnetization (normalized)` | ~49034 | +| volume | qc | `Volume (normalized)` | ~49034 | +| dielectric_total | qc | `Dielectric total (normalized)` | (子集,介电仅部分材料有) | +| dielectric_ionic | qc | `Dielectric ionic (normalized)` | (子集) | +| dielectric_electronic | qc | `Dielectric electronic (normalized)` | (子集) | +| kp | phonix | `kp[W/mK]` | 6714 | +| **klat** | phonix | `klat[W/mK]` | 6714 | +| **tc** | superconductor | `Transition temperature[K]` | 10465 | +| **magnetic_moment** | magnetic | `Magnetic moment[μB/f.u.]` | 1222 | +| magnetization | magnetic | `Magnetization[A·m²/mol]` | (子集) | +| curie | magnetic | `Curie temperature[K]` | (子集) | +| neel | magnetic | `Neel temperature[K]` | (子集) | + +非-qc 的 raw 回归列(tc/klat/magnetic_moment/magnetization/curie/neel)沿用 demo 处理:`log1p → 用 train 行统计做 z-score → clip 到 ±5`,避免长尾。qc 列已是 normalized,直接用。 + +### 1b. Kernel Regression — 7 个 + +| task 名 | 值列 | t 列 | 非空行数 | +|---|---|---|---| +| dos_density | `DOS density (normalized)` | `DOS energy` | 10321 | +| electrical_resistivity | `Electrical resistivity (normalized)` | `Electrical resistivity (T/K)` | 7334 | +| power_factor | `Power factor (normalized)` | `Power factor (T/K)` | 5223 | +| seebeck | `Seebeck coefficient (normalized)` | `Seebeck coefficient (T/K)` | 11722 | +| thermal_conductivity | `Thermal conductivity (normalized)` | `Thermal conductivity (T/K)` | 6158 | +| zt | `ZT (normalized)` | `ZT (T/K)` | 4971 | +| magnetic_susceptibility | `Magnetic susceptibility (normalized)` | `Magnetic susceptibility (T/K)` | **98 ⚠️** | + +⚠️ `magnetic_susceptibility` 只有 98 行,test/val 后可用样本极少,R² 可能不稳定 —— 仍按要求纳入,但报告里会标注「low-data」。 + +### 1c. Classification — 1 个 + +| task 名 | 列 | 类别 | +|---|---|---| +| **material_type** | `Material type (label)` | 5 类合并为 3 类:AC=DAC+IAC, QC=DQC+IQC, others | + +QC 极不平衡(IQC 213 / IAC 126 / DQC 15 / DAC 13 / others 48667),沿用 demo 的 **inverse-frequency class weights**。 + +### 1d. 数据集来源与规模 + +| 数据集 | 文件 | 行数 | 提供的 task | +|---|---|---|---| +| qc_ac_te_mp (DOS/material) | `qc_ac_te_mp_dos_reformat_20260515.pd.parquet` | 49034 | 9 reg + 7 kr + 1 clf = 17 | +| phonix-db | `phonix-db-filtered_20260425.parquet` | 6714 | 2 reg (kp, klat) | +| NEMAD superconductor | `NEMAD_superconductor_20260425.parquet` | 10465 | 1 reg (tc) | +| NEMAD magnetic | `NEMAD_magnetic_20260419.parquet` | 20271 | 4 reg (magnetic_moment, magnetization, curie, neel) | + +> 注:新 qc 文件**没有**配套 preprocessing pkl(仅有 20250615 版),故 `qc_preprocessing_path=null`,跳过 `dropped_idx` 过滤。各数据集按 composition formula join;qc 用自带 `split` 列(train 34322 / val 7355 / test 7357),其余数据集随机 70/15/15 split。 + +--- + +## 2. 训练顺序(continual 增量) + +分三段以最小化重复训练开销: + +1. **12 个 regression(非 tail)**:顺序自由,按数据集分组一种确定排列保可复现。 +2. **7 个 kernel regression**:**按非空样本数升序**。理由:每个新 task 在 intro 时按 100% mask 跑,之后按 5% mask 回放。把**小**数据集摆前面 → intro 时全量也很便宜;后续每步只回放小数据的 5%。把**大**数据集摆后面 → intro 时一次全量,之后剩余步数少(回放次数也少)。kernel regression 训练单步耗时显著,按这个序能把"100% 全量 + 5%·(N−k) 回放"这项总成本压到最小。 +3. **5 个固定 tail**:`formation_energy → magnetic_moment → tc → klat → material_type`,保证 inverse-design 用到的头(尤其 QC 分类器)最末最新。 + +kernel 数据规模(非空行数,已核实): +`magnetic_susceptibility 98 < zt 4971 < power_factor 5223 < thermal_conductivity 6158 < electrical_resistivity 7334 < dos_density 10321 < seebeck 11722`。 + +完整最终顺序(12 reg + 7 kr 升序 + 5 tail): +``` +# 12 regression (any order, grouped by dataset) +density, efermi, final_energy, total_magnetization, volume, +dielectric_total, dielectric_ionic, dielectric_electronic, # 8 qc reg +magnetization, curie, neel, # 3 magnetic (non-tail) +kp, # 1 phonix (non-tail) +# 7 kernel regression, ascending by non-null row count (cheapest first) +magnetic_susceptibility, zt, power_factor, thermal_conductivity, +electrical_resistivity, dos_density, seebeck, +# 5 fixed tail +formation_energy, magnetic_moment, tc, klat, material_type, +``` + +### Continual rehearsal 机制(沿用 demo + 本次调整) +- AE 头**全程常驻**。 +- **不冻结任何层**:每步 `configure_optimizers` 给 encoder + 所有已激活 task head 各建优化器,联合训练(`freeze_shared_encoder=False`、各 task `freeze_parameters=False`)。增量仅靠 rehearsal mask 实现,非冻结。每步重建 Trainer ⇒ 优化器动量每步重置。 +- 每步用 `model.add_task()` 增加一个新头;新 task `task_masking_ratio=1.0`。**旧 task 回放比例分级**: + - 固定末 5(formation_energy / magnetic_moment / tc / klat / material_type)作为旧 task 被回放时用 **`replay_ratio_high=0.10`**; + - 其余旧 task 用 **`replay_ratio=0.05`**。 +- mask 在每步构建训练集时**抽样一次**(不每 epoch 重抽)。 +- 每步在**固定 test split** 上评估**所有已激活头**,记录 forgetting 轨迹。 + +--- + +## 3. 训练配置 + +| 项 | 值 | 说明 | +|---|---|---| +| `sample_per_dataset` | `null` | 全量数据 | +| `max_epochs_per_step` | `100` | 上限 | +| **EarlyStopping** | monitor=`val_final_loss`, patience≈8, mode=min | 通常提前收敛(**对 demo 的新增**) | +| `accelerator` | `mps` | Mac GPU(CUDA 不可用) | +| `batch_size` | 256 | | +| `n_grids` | 8 | KMD-1d 描述子,可逆 | +| `latent_dim` / `encoder_hidden` | 128 / 256 | | +| `head_lr` / `encoder_lr` | 5e-3 | | +| kernel: `n_kernel`/`kr_lr`/`kr_decay` | 15 / 5e-4 / 5e-5 | | +| `replay_ratio` | 0.05 | 一般旧 task 回放比例 | +| `replay_ratio_high` | 0.10 | 固定末 5 task 作为旧 task 时的回放比例 | +| `random_seed` / `datamodule_random_seed` | 2025 / 42 | 可复现 | + +--- + +## 4. 每阶段落盘的「原始数据 + plot」 + +输出目录:`artifacts/continual_rehearsal_full_/` + +``` +step01_density/ + density_pred.parquet # 新增:test 集 true/pred 原始数组 + density_parity.png +step02_efermi/ … +… +stepNN_material_type/ + material_type_pred.parquet # true/pred 标签 + material_type_confusion.png + <每个已激活 task 的 *_pred.parquet 也在该步落盘,便于看 forgetting 的原始数> +forgetting_trajectory.png +experiment_records.json # 每步 × 每 task 的 metric(at-intro / running) +metrics_table.csv # 新增:扁平化指标表(task, type, dataset, at_intro, final, metric) +final_model.ckpt # 新增:最终模型 checkpoint +final_model_taskconfigs.json # 新增:重建模型所需 task 配置 +inverse_design/ + scenario1_*/ scenario2_*/ scenario3_*/ # 见 §5 +report.html # 自包含 HTML slide deck(沿用 demo) +summary.pptx # 新增:python-pptx +summary.md # 新增:文字 summary doc +``` + +**对 demo 的关键扩展**:除现有「仅新 task 出图」外,每步对**所有已激活 task** dump test 集 `(composition, true, pred)` 为 parquet(kernel task 额外存 t 序列),这样 forgetting 既有曲线也有原始数。 + +--- + +## 5. Inverse design — 3 个独立剧本 + +训练**只跑一次**,最终模型存盘后,对**同一模型**依次跑 3 个剧本,**主目标统一为 QC 概率 ↑**。三个剧本的「目标定义」**保持不变**: + +| 剧本 | 主目标 | 副目标(reg task → target) | 输出子目录 | +|---|---|---|---| +| 1 | QC↑ | formation_energy −2.0;magnetic_moment +2.0 | `scenario1_fe_down_moment_up/` | +| 2 | QC↑ | formation_energy −2.0;tc +2.0;magnetic_moment +2.0 | `scenario2_fe_tc_moment/` | +| 3 | QC↑ | formation_energy −2.0;klat +2.0 | `scenario3_fe_down_klat_up/` | + +### 用户旋钮命名(重要 — 都在 `[0, 1]` 上,直觉对齐) + +PR #18 review 阶段把两个原本"看名字猜不到方向"的反向设计旋钮重新命名 + 限值到 `[0, 1]`: + +| API 旋钮 | 空间 | 0 的含义 | 1 的含义 | 默认 | 内部数学 | +|---|---|---|---|---|---| +| `ae_align_scale` | latent | **不约束**(AE-align 罚项关闭,即 #18 之前的"无约束 latent"失败模式)| **最强约束**(强制 latent 落到 decode/encode 不动子集)| **0.5**(#18 实测 sweet spot)| 在 loss 上加 `α · ‖h − encode(decode(h))‖²` | +| `diversity_scale` | composition | **最强 peaky 惩罚**(强制每个解只用极少元素)| **不约束**(每个解可以用任意多元素,自由)| **1.0**(无惩罚 = 用户默认期望)| 在 loss 上加 `(1 − d) · H(w)`,`H` 是 Shannon entropy | + +两者都是"越大 = 用户角度名字所指的属性越强"——`ae_align_scale=1` 越向 AE 对齐;`diversity_scale=1` 越自由多元素。命名意义直观,不需要看代码就能用。论文里也按这套写。 + +### 优化路径 — **基于 PR #18 实测的双路径** + +旧的无约束 `optimize_latent`(`ae_align_scale=0`)已被证明问题很多(AE round-trip 是瓶颈,#18 实测 QC 0.97→0.35),**不再用作主路径**。每个剧本对**同一组种子 + 同一组目标**,跑下面 **1 个 latent 配置 + 3 个 composition 配置 = 4 条路径**对照: + +| ID | 路径 | 关键参数 | 在 #18 中的作用 | 在 plan 中的作用 | +|---|---|---|---|---| +| L | `optimize_latent` | `ae_align_scale = 0.5`, `optimize_space="latent"` | #18 paper run(16 seeds,剧本 = QC↑/FE↓/klat↑)实测 α=0.5 时 QC=0.96±0.027,FE=+0.92,klat=+1.07,是 [0, 1] 上的 sweet spot;α=0 时 QC 崩到 0.39。 | latent 路径的代表 | +| C-strict | `optimize_composition` | `seed_blend = 1.0`,无 `allowed_elements`,`diversity_scale = 1.0` | **baseline**:复刻"只调种子比例、不引入新元素"。#18 paper run 实测:QC=0.887±0.053,FE=+1.27,klat=+0.76;0/16 越出种子池;解平均 2.6 个元素(成分微调,元素族不变)。 | 锚定 strict-seed 基线 | +| **C-alloy** | `optimize_composition` | `seed_blend = 0.95`,`allowed_elements = ALLOY_PALETTE`(见下),`diversity_scale = 1.0` | **推荐**:#18 paper run 用 **12 元素** alloy palette 实测得 QC=0.870±0.012,FE=+0.84,klat=+1.81;100% 输出落在 Mg–Pd–Al 真实准晶族(Pd 不在任何 seed 里 → **元素发现**);pairwise L1=0.17(收敛紧致)。本 plan 的 **41 元素** palette 已 smoke 实测(见下「预期基线」表):QC 接近,但 pairwise L1 跳到 1.02 — 优化器同时落到 *两簇*(Mg–Ni–Sc–Ga–Al–Ge 与 Al–Pd–Sm/Sc–Ti),表明白名单越宽 → 元素发现越多元,论文可推"模型识别出多个 QC-prone 元素族"的故事。 | **paper 头条结果** | +| C-rand | `optimize_composition` | `initial_weights=None`, `n_starts = N_SEEDS`, `diversity_scale = 1.0` | **对照**:完全脱离种子,揭示模型预测面上的"全局吸引子"(#18 实测是 Ti/Pu/F/Mn — 模型偏置;含 Pu 这种不可合成元素,物理不现实) | 验证 alloy palette 约束的必要性 | + +### "无约束探索能力"专项 ablation — blended-unconstrained vs random-init + +PR #18 paper run 里同时跑了 `composition (blended seed, unconstrained)` 与 `composition (random init)`,两者**只差初值**(一个从种子混 5% uniform 出发,一个从纯随机出发),其他所有约束(无 palette、无 element_step_scale、`diversity_scale=1.0`)全相同。实测: + +| 配置 | QC | FE | klat | top 5 元素 | pairwise L1 | +|---|---:|---:|---:|---|---:| +| blended seed, unconstrained | 0.792 ± 0.022 | −0.68 ± 0.20 | +1.77 ± 0.03 | Ti(16), Pu(11), F(10), S(9), Mn(9) | 0.76 | +| random init, unconstrained | 0.793 ± 0.005 | −0.78 ± 0.03 | +1.77 ± 0.02 | Ti(16), Pu(16), Mn(16), F(16), Zr(10) | 0.10 | + +**这是个系统性发现,不只是冗余信息**: + +1. **同一吸引子**:QC / FE / klat 几乎完全一致,top 元素也都是 Ti/Pu/F/Mn — 两条路径殊途同归。 +2. **强大的搜索能力**:现有 encoder 在 *无约束* 情况下,**无论从哪里出发,都能高保真地找到模型内部所能表达的"最优 QC"点**。这是论文里要展示的**架构性能强项**。 +3. **同时凸显 constraint 的重要性**:这个最优点含 Pu(不可合成)、F(fluoride 不形成 QC)、Mn(在 Mn-rich 系外不易稳定 QC)—— **模型偏置的产物,物理不现实**。无约束的强搜索能力 ↔ 没有约束就误导 — 两面性正好佐证 alloy palette 这类领域知识约束的必要性。 + +所以本 plan **保留 random-init 作为正式对照路径**(C-rand),并在每个剧本的报告里**与 C-alloy 并列对比**: +- 主目标达成(QC):C-rand 0.79 vs C-alloy 0.87,差距合理小; +- 副目标(FE/klat):C-rand 数值更接近 target 边界(无约束自由发挥),但落点在不可合成元素上 → 失去工程价值; +- C-alloy:略损 QC + 副目标向 target 的逼近不如 C-rand 那么"激进",但**100% 落在可合成元素族**,论文头条价值。 + +(其他被尝试但移出主路径的配置:`composition (alloy + peaky, diversity_scale=0)` 在 #18 实测 pairwise L1 从 0.17 跌到 0.01,QC=0.85 几乎不变,输出 16/16 趋同到同一个 Al–Pd–Mg 峰 → 是 *peakiness 旋钮* 不是 diversity 间多样性。如有需要可在论文附录以 ablation 形式呈现。) + +### 预期基线(来自 PR #18 paper run + 41-elem smoke) + +剧本 = `QC↑ / FE↓(target −2) / klat↑(target +2)`,16 seeds。两组都是同一个 checkpoint,差别只在 `allowed_elements`: + +| 路径 | QC after | FE after | klat after | pairwise L1 | mean #elems | top-5 elements | +|---|---:|---:|---:|---:|---:|---| +| latent α=0 (failure) | 0.386 ± 0.315 | +2.46 ± 0.59 | −0.44 ± 0.27 | 1.07 | 5.2 | Na, Mg, Ca, Li, Tm | +| latent α=0.5 (sweet) | 0.960 ± 0.027 | +0.92 ± 1.16 | +1.07 ± 0.31 | 0.82 | 3.4 | Mn, Na, Ca, Mg, Yb | +| latent α=1.0 (max) | 0.951 ± 0.027 | +0.40 ± 1.04 | +1.20 ± 0.35 | 1.06 | 3.6 | Mn, Na, Ca, Mg, Ti | +| C-strict | 0.887 ± 0.053 | +1.27 ± 0.24 | +0.76 ± 0.67 | 1.42 | 2.6 | Mg, Zn, Cu, Al, Ni | +| **C-alloy (12 elem)** | 0.870 ± 0.012 | +0.84 ± 0.03 | +1.81 ± 0.07 | 0.17 | 5.6 | **Al, Pd, Mg, Ga, Ni** | +| **C-alloy (41 elem)** | 0.842 ± 0.018 | +0.68 ± 0.07 | +1.84 ± 0.06 | 1.02 | 6.0 | **Ti, Pd, B, Mg, Ga** | +| C-rand | 0.793 ± 0.005 | −0.78 ± 0.03 | +1.77 ± 0.02 | 0.10 | 6.0 | F, Pu, Mn, Ti, Zr | + +注:本表用作**新 runner 跑通后的健全性检查**——剧本 3(FE↓ + klat↑)的全量训练结果应在以上数量级附近;偏差过大需要查 (a) seed 选择是否含 17+3、(b) `ae_align_scale` 是否传对(0.5 是 sweet spot)、(c) `seed_blend` 是否被覆盖、(d) palette 是否裁错。 + +**41-elem 关键观察**(决定论文叙事): +1. **不再单族塌缩**:pairwise L1 从 0.17 → 1.02,元素发现的多样性显著上升;论文头条从单一"Mg–Pd–Al"扩为"模型识别多个 QC-prone basin"。 +2. **Pd 持续被发现**:14/16 输出含 Pd,但 Pd 不在任何 seed → 强**元素发现**信号("出现率 ≫ 0%、seed 命中率 = 0%")。论文用这个口径作主要 evidence。 +3. **lanthanide 进入解**:Sm 出现在多个输出(Al–Pd–Sm 团簇),扩展到了 Au–Ga–RE 之外的 RE 体系。Au–Ga–Ln 三个 seed 在剧本 1/2 的表现要单独报告。 +4. **strict-seed 与 12-/41-elem palette 结果一致**(QC=0.887/0.888,元素分布几乎相同)——证明 strict-seed 路径对 palette 不敏感(seed 元素早就在任何合理 palette 里),可继续作为不变基线。 + +### 种子(每个剧本共用 — **17 + 3**) + +总 N = 20。 + +- **17 个 top-QC 去重种子**:在 material_type **测试集(test split)** 中按模型预测 QC 概率排序,按**元素系**(element symbols set,忽略比例)去重,每个元素系保留最高的代表,取前 17 个。代码已在 PR #18 `_select_seeds` 中实现 `_dedupe_by_element_system`。 + - **为什么用测试集而不是训练集**:训练集组成模型在持续学习中已经见过,"top-QC" 排序里有 memorization 成分;测试集是 hold-out,QC 排序是模型真正的预测 → 这些 seed 才是**真候选**而不是训练数据的回放。`inverse_seed_split = "test"` 是正式 run 的默认;只有复刻 demo / paper baseline 时才回退到 `"train"`。 +- **3 个固定 Au–Ga–Ln 配方**(强制追加,无论 QC 预测值如何): + - `Au65 Ga20 Gd15` + - `Au65 Ga20 Tb15` + - `Au65 Ga20 Dy15` + + 这三组是已知或推测的 i-QC 形成体系(Au-Ga-RE 家族),用来检验模型在"明确属于实验已实现/接近实现的 Au–Ga 重稀土"区域是否仍给出合理的 QC 概率。如果模型把这 3 个 seed 的 QC 拉得不高,本身就是一个值得在论文里说的发现。 + +`_select_seeds` 改造要点(rebase 时实现): +1. 新增 `inverse_seed_explicit_append: list[str]`(默认 `[]`),追加种子; +2. 改用 `Composition(s).formula` 做归一化避免 `Au65Ga20Gd15` / `Au0.65Ga0.20Gd0.15` 不一致; +3. 通过 `descriptor_fn` 校验追加种子的描述子可计算(不可计算的 fail-fast,给出明确错误); +4. 输出 `seeds.json` 区分 `top_qc_seeds` 与 `explicit_seeds` 两段。 + +### 元素清单(`ALLOY_PALETTE`,**41 个**) + +`composition (alloy)` 路径的 `allowed_elements` 白名单。范围设计原则:覆盖常见准晶元素 + 易于实验的 4/5 周期过渡金属 + 部分易得镧系,**剔除放射性元素**与极冷门难合成的稀有元素。 + +| 类别 | 元素 | 数 | +|---|---|---| +| 轻碱土 | `Mg`, `Ca` | 2 | +| Group 13 | `B`, `Al`, `Ga`, `In`, `Tl` | 5 | +| Group 14 | `Si`, `Ge` | 2 | +| 4th-period TM(Sc–Zn 全) | `Sc`, `Ti`, `V`, `Cr`, `Mn`, `Fe`, `Co`, `Ni`, `Cu`, `Zn` | 10 | +| 5th-period TM(Y–Cd,去 Tc 放射性) | `Y`, `Zr`, `Nb`, `Mo`, `Ru`, `Rh`, `Pd`, `Ag`, `Cd` | 9 | +| 6th-period noble(用于 Au–Ga–Ln seed) | `Au` | 1 | +| 易得镧系(去 Pm 放射性、Tm/Lu 稀贵) | `La`, `Ce`, `Pr`, `Nd`, `Sm`, `Eu`, `Gd`, `Tb`, `Dy`, `Ho`, `Er`, `Yb` | 12 | + +合计 41 个,落到 config 里写成: + +```toml +composition_allowed_elements = [ + "Mg", "Ca", + "B", "Al", "Ga", "In", "Tl", + "Si", "Ge", + "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn", + "Y", "Zr", "Nb", "Mo", "Ru", "Rh", "Pd", "Ag", "Cd", + "Au", + "La", "Ce", "Pr", "Nd", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Yb", +] +``` + +这套白名单同时覆盖: +- 经典 i-QC 三元体系(Mg–Zn–RE、Al–Mn、Al–Cu–Fe、Zn–Mg–RE、Ti–Zr–Ni 等); +- d-QC 体系(Al–Ni–Co、Al–Cu–Co 等); +- 重稀土 RE-stabilized 体系(Au–Ga–RE、Mg–Zn–RE)所需的 Au; +- Si/Ge/B 等族 13/14 元素,便于 Mg–Si–Ge 这类边缘体系; +- 3 个追加种子的全部元素(Au/Ga/Gd/Tb/Dy)。 + +### 评估指标(每剧本 × 每路径) + +不止报 QC round-trip 概率。对每条路径在 20 个 seed 上输出: + +| 指标 | 形式 | 说明 | +|---|---|---| +| `qc_after` | mean ± std | 主目标,softmax 后 QC 类(merged)的概率 | +| `_after` | mean ± std | 每个副目标 reg task 的解码后预测值 | +| `dist_to_seed_l1` | mean ± std | 每个解 vs 自己的 seed(latent / strict / alloy)或最近 seed(random)的 L1 距离 | +| **`pairwise_l1`** | scalar | **20 个解两两之间** L1 距离的平均(94 维元素权重单纯形上)。**定义**:对 N=20 个解的所有 C(20,2)=190 对 `(w_i, w_j)`,取 `mean Σ_k |w_i[k] − w_j[k]|`。值域 [0, 2]:0 = 20 个解完全一样;2 = 完全正交。**intra-method 多样性**——同一路径给出的候选库本身有多分散。实测参考:strict seed 1.42(每个 seed 都被推到不同方向,最分散);alloy palette 12-elem 0.17(16 个 Mg-Pd-Al 微变体,紧致);alloy + `diversity_scale=0` 0.01(全部塌成同一个峰)。 | +| `unique_element_systems` | int / N | 20 个解中不同元素集合的数量 | +| `out_of_seed_pool` | int / N | 解的元素超出种子元素池(17+3 合并)的样本数 | +| `mean_n_elements` | float | 每个解的非零元素数平均 | +| `top_elements` | list[(symbol, count)] | 出现在最多解里的前 8 个元素 | +| **`discovered_elements`** | list[(symbol, hit_rate)] | **元素发现专用**:出现率 ≥ 50% **且** 在 20 个 seed 中出现次数 = 0 的元素。这是论文里"模型发现了 X"的硬证据信号(#18 paper run 里 Pd 是 16/16 出现、0/16 seed → hit_rate = 100% 的发现元素)。**该字段为空意味着该路径只是种子比例微调,不是元素发现**。 | +| `elapsed_s` | scalar | 单次 `optimize_*` 调用耗时 | + +`discovered_elements` + `dist_to_seed_l1` + `out_of_seed_pool` 联合回答论文核心问题:"这是元素发现还是种子比例微调?" + +**Raw arrays(必存)**:除上述聚合指标外,`results.json` 还必须包含每路径的 `optimized_weights`(形状 `(B, n_components)`,元素顺序与 `DEFAULT_ELEMENTS` 一致)和 `optimized_descriptor`(形状 `(B, x_dim)`)。这两份原始数组是日后调整图表方案(per-element bar chart、相似度矩阵、ratio 直方图等)的来源——**不用重跑实验**。已在 `paper_inverse_comparison.py` / `eval_inverse_methods.py` 的两个 runner 中加好。模型权重 `final_model.pt` + seeds + targets + 原始数组 = 论文素材的最小可重现集合。 + +### Smoke check(正式 run 前必须通过) + +在 GPU 启动 24-step 训练 + 3 剧本之前,**必须**先跑一次 smoke 验证 §5 的 4 条路径都能产出**合理的数量级**,免得训练几小时后才发现 inverse-design 配置错了。复用 `paper_inverse_comparison.py` + 现成 `artifacts/paper_inverse_design/final_model.pt` checkpoint,用 17+3 的新 seed 方案跑一次: + +```bash +# 1. 在临时输出目录跑 4 路径对比(不污染 artifacts/paper_inverse_design/) +python -m foundation_model.scripts.paper_inverse_comparison \ + --config-file samples/continual_rehearsal_demo_config_inverse_baseline.toml \ + --checkpoint artifacts/paper_inverse_design/final_model.pt \ + --output-dir artifacts/smoke_inverse_4path +# 2. 比对 §5「预期基线」表 — 数量级偏差 < 0.1 即通过;偏差大查 seed/参数。 +``` + +注意 `paper_inverse_comparison.py` 目前用的是 12-elem palette + 16 seeds。smoke check 时如想直接用本 plan 的 17+3 seed + 41-elem palette,临时改三处: +- 把 `DEFAULT_ALLOY_PALETTE` 改成本 plan 的 41-elem 列表; +- `_select_seeds` 在 `ContinualRehearsalRunner` 上的调用前后追加 `Au65Ga20Gd15` / `Tb15` / `Dy15`; +- 用 `_dedupe_by_element_system` 取前 17 个 top-QC 后追加这 3 个 → 共 20。 + +正式 runner(`continual_rehearsal_full.py`)落地后再把这些写成正式 config 字段,smoke check 用 runner 自己的 `--inverse-only` 模式即可。 + +### Per-scenario 成功判据(人工核对,不卡 CI) + +每个剧本完成后,论文中要主张"实验有效",至少其中一条必须成立: + +1. **C-alloy 路径**在该剧本主目标 QC ≥ 0.80 **且** 至少一个副 reg target 命中目标方向(`(pred − target) · sign(target − seed_mean)` 在合理范围内); +2. **C-alloy** 的 `discovered_elements` 非空(典型预期:Pd 或某 5th-period TM 被发现); +3. **L (latent)** 的 QC > C-strict 的 QC 至少 0.05 — 否则说明 cycle 罚项也帮不上忙,本剧本对模型来说"无解",论文需诚实标记。 + +若三个剧本里有两个不满足任一条 → 检查训练是否欠拟合(forgetting trajectory 看 tail 5 task 的 final R² / accuracy)或 inverse-design 超参(`inverse_class_weight`、`inverse_steps`、`inverse_lr`)。 + +### 输出落盘 + +``` +inverse_design/ + scenario1_fe_down_moment_up/ + seeds.json # 17 top-QC + 3 explicit, 分两段 + targets.json # 主+副目标定义 + latent_lambda1/ + results.json # 每 seed 一行:qc/reg/decoded_composition + metrics.json # 上表所有聚合指标 + decoded.txt # 人读组成清单(KMD.inverse 解码) + summary.png # 3-panel: QC + 每个 reg target,bar + error + comp_strict_seed/ + ...(同上结构) + comp_alloy_blended/ # **headline** + ... + comp_random_init/ + ... + comparison.png # 4 路径 × 3 panels(QC / FE / 副 reg)并列对比,与 paper 主图同款式 + comparison_diversity.png # 4 路径 × 3 panels(pairwise L1 / out-of-seed / mean_n_elements) + scenario2_fe_tc_moment/ + ... + scenario3_fe_down_klat_up/ + ... + README.md # 三个剧本的 takeaway 一页摘要 +``` + +`comparison.png` 与 `comparison_diversity.png` 都直接复用 `paper_inverse_comparison.py` 的绘图风格(`#2563EB` for composition, `#55A868` for latent, `#C44E52` for target line;x-tick rotation 45)。 + +### 实现路径(rebase 时改写 `_inverse_design`) + +旧的 `_inverse_design` 是占位实现,按下面顺序重写: + +```python +PATHS = [ + ("latent_align0p5", "latent", {"ae_align_scale": 0.5}), + ("comp_strict_seed", "composition", {"seed_blend": 1.0}), + ("comp_alloy_blended", "composition", { + "seed_blend": 0.95, + "allowed_elements": ALLOY_PALETTE, + }), + ("comp_random_init", "composition", {"initial_weights": None, "n_starts": 20}), +] + +for scenario in cfg.inverse_scenarios: + seeds = _select_seeds(...) # 17 top-QC 去重 + 3 explicit + for path_label, mode, extra_kwargs in PATHS: + result = _run_path(model, seeds, scenario.targets, mode, extra_kwargs) + _dump_path_results(result, scenario_dir / path_label) + _plot_comparison(scenario_dir) + _plot_comparison_diversity(scenario_dir) +``` + +参考 `paper_inverse_comparison.py` 的 `_run_latent_method` / `_run_composition_config` 实现一对一复用。 + +--- + +## 6. 交付物 — **`SLIDE_PREP.md`** + 配套图与原始数据(不再生成 PPT / HTML deck) + +**重要变更(2026-05-23 决定)**:之前计划用 `python-pptx` 直接产出 `summary.pptx` + 自包含 +`report.html`。实测下来这两份自动产物**视觉质量不足以拿去给团队/会议用**——layout 死板、 +排版/颜色都需要二次调整。所以本工作流**不再尝试自动生成 PPT / HTML deck**,改为生成一份 +**结构化的 slide-prep markdown**(`SLIDE_PREP.md`,目前已经存在 `artifacts/inverse_design_run/` +作为 preview 模板)和必要的图 + 原始数据,让 *外部 slide 作者* (claude coworker / 人) 自由 +排版做最终 deck。 + +### 落盘内容 + +| 文件 | 内容 | +|---|---| +| `SLIDE_PREP.md` | 9 节 slide 大纲(每节 = 1 张幻灯片或一组),每节写明 takeaway / 要素 / 引用哪张图 / speaker notes | +| `ANALYSIS.md` | 长文分析,speaker notes 的素材库(slide 作者按需引用) | +| `README.md` | 整个 run 目录的索引:top-level 目录结构、每个文件干什么 | +| `comparison.png` | 头条图(QC↑ / Formation energy ↓ / klat ↑ 三联条形图)— 标题含单位与方向箭头 | +| `element_frequency_heatmap.png` | 每个 method × top-25 元素出现次数;**新元素(不在任何 seed 里)的 x 轴 label 加粗 + 下划线** 表示"由优化器发现"| +| `training/forgetting_trajectory.png` | 持续学习的 forgetting 图(per-step × per-task metric)| +| `training/stepNN_/_pred.parquet` | 每步对**每个** active head 的 `(composition, true, pred)` raw — 后续重绘图都不需要重训 | +| `training/stepNN_/_metrics.json` | 同步的 per-task metric(R² / accuracy / MAE / samples)| +| `training/stepNN_/checkpoint.pt` | **每步训练完保存的 model state_dict** — 任意中间阶段可以恢复 | +| `training/final_model.pt` | 训练结束时的最终模型 | +| `finetune/final_model.pt` | 三个 inverse head 微调后的模型 | +| `inverse_design/results.json` | 每条 path × per-seed 原始数组(`optimized_weights` 20×94,`optimized_descriptor` 20×256,预测值,metric)— 重画图不需要重跑优化 | +| `inverse_design/seeds.json` | 20 个 seed(17 top-QC dedup + 3 explicit Au-Ga-Ln)| +| `inverse_design/comparison.png` | 头条图(与上面同一文件,由 `paper_inverse_comparison.py` 输出)| +| `inverse_design/element_frequency_heatmap.png` | 同 above | +| `inverse_design/SUMMARY.md` | auto-generated compact summary 表(每次 paper_inverse_comparison rerun 都会覆盖此文件)| + +### slide 内容大纲(slide 作者参考 `SLIDE_PREP.md` 实现) + +`SLIDE_PREP.md` 列了 9 个 section(≈ 9–11 张幻灯片): + +1. **Experimental goal** — 多属性联合优化是材料开发刚需 +2. **Model structure + inverse-design strategies** — shared encoder + 两条 inverse path 的对比;两个用户旋钮 `ae_align_scale` / `diversity_scale` +3. **Datasets and task types** — 三栏(reg / kr / clf)+ 4 个数据源 +4. **Continual training without catastrophic forgetting** — `training/forgetting_trajectory.png` +5. **Inverse design: scenario setup** — QC↑ + FE↓ + klat↑(plan §5 三个剧本里跑了剧本 3) +6. **Initial seeds + element palette** — 20 seeds(17+3)+ 41-elem `ALLOY_PALETTE`(**用周期表 highlight 形式**)+ 5 个 composition 配置的设计意图 +7. **Results + discussion** — `comparison.png` + `element_frequency_heatmap.png` + 每个 method 的 one-line takeaway + 元素发现叙事(Ti, Pd 100% 在输出/0 seed) +8. **Summary** — 三条 bullet + 头条图缩略 +9. **Future work** — agent-based inverse-design workbench;接入 AI4S agent 群(参考 §0a beats 6–7) + +### 实施备注 + +- 旧版 `_write_pptx` / `_write_report_html` **不要再扩展**。如果还在 runner 里,rebase 时 + 保留为兼容性占位即可,但 plan 不再把它们计为交付物。`SLIDE_PREP.md` 是新的真正交付物。 +- 配色 / 字体只在 *绘图脚本* 里需要约束(`#2563EB` for composition, `#55A868` for latent, + `#C44E52` for target line);slide 模板的视觉风格由 slide 作者完全自由决定。 +- raw arrays 写盘(per-step parquet + ckpt + paper-run `results.json` 全套)是硬要求 —— + 保证日后任何调整图表方案都**不需要重训**。 + +--- + +## 7. 代码改动清单(新建,不动 demo) + +- **新增** `samples/continual_rehearsal_full_config.toml`:全部路径、24-task `task_sequence`、§3 配置、3 个 inverse 剧本(用一个 `[[inverse_scenarios]]` 列表表达)。 +- **新增** `scripts/continual_rehearsal_full.py`: + - 扩充 `TASK_SPECS`(+12 个新 task)与 `TASK_DISPLAY`(中英文友好名)。 + - 复用 demo 的 `descriptor_fn` / KMD / 评估 / 绘图 helper(import 或抽到共享模块)。 + - 训练循环加 **EarlyStopping**(需 val dataloader,CompoundDataModule 已提供)。 + - 每步 dump 所有激活 task 的 `*_pred.parquet`。 + - 训练后 **保存 `final_model.ckpt`** 与 task 配置。 + - inverse design 改为 **遍历 `inverse_scenarios` 列表**,对同一模型跑 3 次,分目录落盘。 + - 新增 `summary.pptx`(python-pptx)+ `summary.md` 生成,保留 `report.html`。 +- **新增** `run_continual_rehearsal_full.sh`:仿 `run_continual_rehearsal_demo.sh`,默认配置 + 日期戳输出目录。 +- 共享逻辑若从 demo 抽取,会保证 demo 行为不变(仅 import,不改语义);并补/跑相关 `*_test.py`。 +- `uv add python-pptx`,更新 `uv.lock`。 + +--- + +## 8. 风险与备注 + +- **耗时**:24 step × 全量数据,即便 MPS + 早停也可能数小时。建议后台跑并定期回看。 +- **magnetic_susceptibility(98 行)**、部分 dielectric / magnetic 列为子集 —— 个别 task R² 可能偏低或不稳定,报告会标注。 +- **MPS 兼容性**:极少数算子在 MPS 上可能缺失;若报错,回退 `accelerator=cpu`(更慢)。 +- raw 回归 z-score 用 train 行统计,避免泄漏;clip ±5。 +- inverse-design 解码用 KMD.inverse(可逆描述子),可能出现 `` 边缘情况(已有 warning 兜底)。 + +--- + +## 9. 执行步骤(确认后) + +1. `uv add python-pptx` 并 sync。 +2. 写 `scripts/continual_rehearsal_full.py` + `samples/continual_rehearsal_full_config.toml` + `run_continual_rehearsal_full.sh`,补测试。 +3. `ruff format && ruff check && mypy src` + 跑相关 `pytest`。 +4. **小规模 smoke**(`--sample-per-dataset 800 --max-epochs-per-step 2`)验证端到端不报错、产物齐全。 +5. 启动**正式全量 run**(后台),完成后核对 forgetting / 5 个目标 task 指标 / 3 个 inverse 剧本 / PPT+MD。 + +--- + +## 10. 执行状态(2026-05-23) + +- ✅ `uv add python-pptx`(runtime dep,已写入 `uv.lock`)。 +- ✅ 新增 `src/foundation_model/scripts/continual_rehearsal_full.py` + `_test.py`(16 tests 通过)、 + `samples/continual_rehearsal_full_config.toml`、`run_continual_rehearsal_full.sh`。 +- ✅ `ruff format` / `ruff check` / `mypy`(new module)全绿。 +- ✅ **CPU smoke**(`--sample-per-dataset 800 --max-epochs-per-step 2 --accelerator cpu`)端到端通过: + 24 个 step、每步全 task `*_pred.parquet`、3 个 inverse 剧本、`final_model.ckpt`、`metrics_table.csv`、 + `forgetting_trajectory.png`、`report.html`、`summary.md`、9 页 `summary.pptx` 全部产出。产物在 + `artifacts/continual_rehearsal_full_smoke/`(可丢弃)。**注意**:smoke 跑的是「旧顺序」(19 free + 5 tail),正式 run 用本次更新后的「12 reg + 7 kr 升序 + 5 tail」新序。 +- ⏸ **正式全量 run 待启动** —— GPU 被另一训练任务占用,且 PR #18 待合并。两者就绪后执行: + + ```bash + ./run_continual_rehearsal_full.sh # 默认 config,MPS,全量数据,输出带日期戳 + ``` + + (会写到 `artifacts/continual_rehearsal_full_/`;建议后台运行。) + +> 修复记录:TOML 中 `[[inverse_scenarios]]` array-of-tables 必须置于文件末尾,否则其后的顶层标量键 +> 会被并入最后一个 scenario 表(已在配置中调整顺序并加注释)。 + +--- + +## 11. PR #18 依赖与 rebase 计划 + +PR#18 在 #17(differentiable KMD upstream)之上落了若干与本工作流相关的改动:算法(cycle-consistency +latent + differentiable composition)会改变 inverse-design 的可选 backend;配色 / plot 风格会经 demo +模块自动透传到本 runner(因为我 `import` 的就是这些 helper)。#18 PR body 明确说 `continual_rehearsal_full.py` 工作流**不在** #18 范围。 + +### #18 引入的可被复用 / 必须感知的部分 + +| 项 | 类型 | 对本 runner 的影响 | +|---|---|---| +| `ae_align_scale` on `optimize_latent` | **必须用** | 给 latent 路径加 AE-alignment 罚项 `α · ‖h − encode(decode(h))‖²`;α=0 = 无约束 = 失败模式(QC→0.39)。**[0, 1] 范围**,默认 0.5(#18 实测 sweet spot)。命名演变:`cycle_consistency_weight` → `ae_cycle_weight` → 最终 `ae_align_scale`(更直觉)。 | +| `optimize_composition`(differentiable KMD) | **必须用** | 94 维 element-weight 单纯形上微分优化;本 plan 跑 3 个配置(strict-seed / alloy-blended / random-init),详见 §5 表。 | +| `seed_blend` (new in #18 fix-up) | **必须用** | composition 路径的核心旋钮:`1.0` = 锁定支撑集(baseline),`0.95` = 允许优化器引入新元素(alloy 路径用此值)。 | +| `diversity_scale` on `optimize_composition` | 可选 | **[0, 1]** 范围,1 = 不约束(默认,最 diverse 多元素),0 = 强惩罚(最 peaky 少元素)。命名演变:`sparsity_weight` → `entropy_weight` → 最终 `diversity_scale`(更直觉)。本 plan 默认 1.0,不主用;仅在论文附录展示 0.0 的 peaky 模式 ablation。 | +| `element_step_scale` hard-lock (#18 PR review fix) | 可选 | `0.0` 现在真正锁定权重(不只是 logit gradient)。本 plan 不主用,但保留作为"锚定 seed 比例 + 只允许新元素进入"的高级手段。 | +| `_dedupe_by_element_system` in `_select_seeds` | **必须用** | top-QC 排序后按元素系去重;本 plan 取前 17 个,再追加 3 个显式 Au–Ga–Ln seed(§5)。 | +| `class_weights` always-registered buffer (#18 PR review fix) | 流程 | state_dict 跨配置 strict-load 不再失败;我们的 `final_model.ckpt` 加载将更稳健。 | +| `material_type` 3-class 合并 + 类权 + plot 通刷(`#2563EB` scatter、widened forgetting、row-normalized confusion、dpi=150) | demo 内部 | 已 `import` 这些 helper、复用同样 merge map;rebase 后视觉自动对齐。需 verify:`_MATERIAL_TYPE_MERGE` / `MATERIAL_TYPE_CLASSES` / `MATERIAL_TYPE_DISPLAY_ORDER` / `_SCATTER_COLOR` 命名是否仍可 import。 | +| `--inverse-only ` (demo 端) | 流程 | demo 跑过后可只重跑 inverse;rebase 时给 `continual_rehearsal_full.py` 加同样的 `--inverse-only` + `--checkpoint`。 | +| `final_model.pt` 强制保存(demo 端) | 流程 | 我们已存 `final_model.ckpt`;可改名为 `final_model.pt` 对齐 demo。 | +| `paper_inverse_comparison.py` | 参考实现 | `_run_latent_method` / `_run_composition_config` 是 §5 双路径的直接母版,rebase 时**复用其函数**,不重复实现。 | + +### Rebase 步骤(#18 合并到 master 之后) + +1. `git fetch origin && git rebase origin/master`;解冲突主要在 demo 的 helper / 配色常量。 +2. **验证 imports 依然成立**:`_apply_plot_style` / `_PALETTE` / `_SCATTER_COLOR` / `_REPORT_TEMPLATE` / `_as_float_array` / `_composition_key` / `_init_kernels`;并新增 `QC_CLASSES` / `_dedupe_by_element_system` from `continual_rehearsal_demo`。 +3. **接入 PR#18 算法(按 §5 表)**: + - `ContinualRehearsalFullConfig` 新增字段: + - `inverse_ae_align_scale: float = 0.5`(latent 路径,[0, 1],0.5 是 #18 sweet spot); + - `inverse_seed_explicit_append: list[str]`(显式追加 seed,**默认即 §5 三个 Au–Ga–Ln**); + - `inverse_n_top_qc_seeds: int = 17`(top-QC 去重后取前 N); + - `inverse_composition_alloy_palette: list[str]`(默认即 §5 的 41 元素清单); + - `inverse_composition_seed_blend: float = 0.95`。 + - `_select_seeds` 改为 "17 top-QC 去重 + 3 显式" 两段拼接,并校验显式 seed 的 descriptor 可计算;输出 `seeds.json` 区分 `top_qc_seeds` / `explicit_seeds`。 + - 重写 `_inverse_design`:对每个 `scenario` 跑 §5 的 4 条路径,复用 `paper_inverse_comparison.py` 的两个 runner 函数(`_run_latent_method` / `_run_composition_config`)。 + - 落盘按 §5 目录结构;`comparison.png` 与 `comparison_diversity.png` 直接复用 paper 风格。 +4. **改写 `_write_pptx`** 为 §6 的 11 页结构、主色 + ≤2 辅助色(提取 `_pptx_palette` 常量)。 +5. **task_sequence 已在 config 中按新序更新**(本次同步),smoke 再跑一次确保通过。 +6. **lint / type / test / smoke**:`ruff format && ruff check && mypy src && pytest src/foundation_model/scripts/continual_rehearsal_full_test.py`,再 + `./run_continual_rehearsal_full.sh ... --sample-per-dataset 800 --max-epochs-per-step 2 --accelerator cpu` 端到端 smoke。 +7. **缩小版正式 run**:受时间成本约束,先跑缩小规模(`--sample-per-dataset 5000`、`--max-epochs-per-step 30` 量级,具体值在 rebase 后根据 smoke 时长定)。全量 run 留到论文最终复现阶段。 +8. **GPU 空闲后**启动缩小版 MPS 正式 run。 diff --git a/docs/inverse_design_algorithms.md b/docs/inverse_design_algorithms.md new file mode 100644 index 0000000..fea443b --- /dev/null +++ b/docs/inverse_design_algorithms.md @@ -0,0 +1,125 @@ +# Inverse-design algorithms — loss & design intent + +Reference for the two inverse-design routines in +[`flexible_multi_task_model.py`](../src/foundation_model/models/flexible_multi_task_model.py): +`optimize_latent` (latent-space gradient descent) and `optimize_composition` (differentiable +KMD). Written as a one-stop sheet so the loss formulas, what each term is *for*, and the +user-facing knobs are all in one place — ready to drop into a slide deck or the paper. + +## A. Latent-space optimisation (`optimize_latent`, `optimize_space="latent"`) + +### Optimisation variable + +**latent vector $h$** (the encoder output). One $h$ per seed; each runs independent gradient +descent. + +### Loss + +$$ +\mathcal{L}_{\text{latent}}(h) \;=\; \underbrace{\sum_{t \in \mathcal{T}_{\text{reg}}} \lambda_t \,\bigl\lVert \hat y_t(h) - \text{target}_t \bigr\rVert^2}_{\text{(1) regression term}} +\;+\;\underbrace{w_{\text{cls}} \cdot \bigl(-\log P\bigl(c = \text{QC} \mid h\bigr)\bigr)}_{\text{(2) classification term}} +\;+\;\underbrace{\alpha \cdot \bigl\lVert h - \tanh\bigl(E(D(h))\bigr) \bigr\rVert^2}_{\text{(3) AE-alignment term}} +$$ + +with + +- $\hat y_t(h)$ = prediction of the $t$-th regression head on $h$; +- $P(c = \text{QC} \mid h)$ = softmax probability of the quasicrystal class out of the QC + classification head on $h$; +- $D(\cdot)$ = AE decoder (latent → input space $\hat x$); $E(\cdot)$ = encoder (input → + latent, with the trailing tanh); +- $\lambda_t$ = the regression task's internal weight (a scalar fixed at training time). + +### What each term is for + +| Term | Design intent | +|---|---| +| **(1) regression term** | Push the latent to a place where every regression head hits its `target_t` (MSE in z-scored space). | +| **(2) classification term** | Push the latent to the region where the QC head emits high $P(c = \text{QC})$. $-\log P$ is the cross-entropy against the target class. `w_cls` sets classification priority relative to regression (use $> 1$ when QC is the primary objective and the regression targets are secondary). | +| **(3) AE-alignment term** | **The crux of this method.** Freely optimised $h$ tends to drift off the AE-learned manifold → decoded $\hat x$ becomes unphysical → the reported composition can't be trusted. This term pulls $h$ toward $\tanh(E(D(h)))$, i.e. the fixed-point of one decode→encode round-trip. $\alpha = 0$ turns the term off (the pre-PR #18 failure mode: QC dropped 0.97 → 0.35); $\alpha = 1$ over-constrains ($h$ effectively locked onto the AE manifold, target attainment drops); **empirical sweet spot $\approx 0.5$**. | + +### Main tunable parameters + +| Parameter | Range | Default | Meaning | +|---|---|---|---| +| `ae_align_scale` (= $\alpha$) | $[0, 1]$ | 0.5 | AE-manifold alignment strength (see (3)). | +| `class_target_weight` (= $w_{\text{cls}}$) | $> 0$ | 1.0 | Classification weight relative to regression. | +| `steps`, `lr` | — | 200, 0.1 | Adam optimisation budget. | +| `num_restarts`, `perturbation_std` | — | 1, 0.0 | Independent restarts with Gaussian jitter on the seed. | + +--- + +## B. Differentiable KMD composition optimisation (`optimize_composition`) + +### Optimisation variable + +**logits $\theta \in \mathbb{R}^n$**, with $n$ = element-table size (default 94, from KMD's +`DEFAULT_ELEMENTS`). The softmax gives the element-weight simplex `w = softmax(θ)` (each row +non-negative, sums to 1). + +### Forward pass + +$$ +w = \text{softmax}(\theta) \;\to\; x = w \cdot K \;\to\; \tilde h = \tanh(E(x)) \;\to\; \text{heads} +$$ + +where $K \in \mathbb{R}^{n \times d_x}$ is the precomputed KMD kernel and $x$ is the +descriptor vector. **`w` itself is the recipe you would report** — no AE decode step. + +### Loss + +$$ +\mathcal{L}_{\text{comp}}(\theta) \;=\; \underbrace{\sum_{t \in \mathcal{T}_{\text{reg}}} \lambda_t \,\bigl\lVert \hat y_t(w) - \text{target}_t \bigr\rVert^2}_{\text{(1) regression term}} +\;+\;\underbrace{w_{\text{cls}} \cdot \bigl(-\log P\bigl(c = \text{QC} \mid w\bigr)\bigr)}_{\text{(2) classification term}} +\;+\;\underbrace{(1 - d) \cdot H(w)}_{\text{(3) entropy / peakiness term}} +$$ + +with + +- $H(w) = -\sum_i w_i \log w_i$ — the per-output-row Shannon entropy; +- $\hat y_t(w)$ and $P(c = \text{QC} \mid w)$ both come from the forward pass above; +- $d$ = `diversity_scale` $\in [0, 1]$. + +### Constraints (not in the loss, but enforced in the implementation) + +| Constraint | How it's enforced | Design intent | +|---|---|---| +| **simplex** | `w = softmax(θ)` | Automatically keeps `w` a valid recipe (non-negative, sums to 1). | +| **`allowed_elements` whitelist** | Masks the logits of disallowed elements to $-\infty$ before every softmax step. | Restrict the search to physically realisable elements (e.g. `ALLOY_PALETTE`, 41 symbols), suppressing model biases toward Pu / F / Cs / etc. | +| **`element_step_scale` soft-freeze / hard-lock** | Soft: multiply the element's logit gradient by the scale before each Adam step. Hard (value = 0): rewrite the softmax output to paste seed values back at locked positions and renormalise unlocked positions over the remaining mass. | Let the user pin certain elements to their seed values ("keep the Au-Ga-RE skeleton; you may only change the rare-earth ratios"). | +| **`seed_blend` mixture** | $w_0 \leftarrow \text{seed\_blend} \cdot \text{seed} + (1 - \text{seed\_blend}) \cdot \text{uniform}_{\text{allowed}}$ | Don't start from a 100 % seed (5 % uniform mass lifts every allowed element's logit from $\log(10^{-12}) \approx -27.6$ to $\log(0.05 / \lvert\text{allowed}\rvert) \approx -7.6$, so Adam can introduce new elements within a few hundred steps — this is the **element-discovery** mechanism). | + +### What each loss term is for + +| Term | Design intent | +|---|---| +| **(1) regression term** | Same as latent — push predictions toward `target_t` via MSE. | +| **(2) classification term** | Same as latent — maximise $P(c = \text{QC})$. | +| **(3) entropy / peakiness term** | **The crux of this method.** Larger $H(w)$ ⇒ flatter `w` ⇒ each solution uses more elements; smaller $H(w)$ ⇒ peakier `w` ⇒ a few elements dominate each solution. $(1 - d)$ is the penalty weight: $d = 1$ turns it off (default — the optimiser uses as many elements as the main objective wants); $d = 0$ is the strongest penalty (forced peaky → binary/ternary recipes, useful as an ablation). **Important**: this is a *per-output-complexity* knob, **not** a between-output diversity knob. Whether the $B$ outputs differ from each other is decided by the loss landscape, not by $d$. | + +### Main tunable parameters + +| Parameter | Range | Default | Meaning | +|---|---|---|---| +| `diversity_scale` (= $d$) | $[0, 1]$ | 1.0 | Per-output element diversity (see (3)). | +| `class_target_weight` (= $w_{\text{cls}}$) | $> 0$ | 1.0 | Classification weight relative to regression. | +| `seed_blend` | $[0, 1]$ | 0.95 | Fraction of seed kept at the start (the rest is uniform, so new elements can enter). | +| `allowed_elements` | symbol list or `"all"` | `"all"` | Element whitelist (hard constraint). | +| `element_step_scale` | float or `{symbol: float}` | 1.0 | Per-element step scaling; `0` = hard-lock to the seed value. | +| `steps`, `lr` | — | 300, 0.05 | Adam optimisation budget over the logits. | + +--- + +## Side-by-side summary + +| | Latent | Composition | +|---|---|---| +| **Optimisation variable** | $h$ (latent vector) | $\theta$, with $w = \text{softmax}(\theta)$ (element-weight simplex) | +| **Where the reported recipe comes from** | $w_{\text{report}}$ inferred from $D(h)$ (an extra AE-decode step) | $w$ itself is the report | +| **Method-specific loss term** | $\alpha \cdot \lVert h - \tanh(E(D(h))) \rVert^2$ (keeps $h$ on the AE manifold) | $(1 - d) \cdot H(w)$ (controls per-solution peakiness) | +| **Failure mode** | $\alpha = 0$: $h$ drifts off the manifold, decoded recipe unphysical (QC 0.97 → 0.35). | `seed_blend = 1.0`: the seed's support set is frozen — no new elements can ever appear. | +| **Method-specific knobs** | `ae_align_scale` | `diversity_scale`, `seed_blend`, `allowed_elements`, `element_step_scale` | + +The shared backbone — (1) regression MSE + (2) classification cross-entropy — is **identical** +between the two methods. They differ *only* in the third loss term and in which variable is +being optimised. diff --git a/docs/inverse_design_extension_notes.md b/docs/inverse_design_extension_notes.md new file mode 100644 index 0000000..a900404 --- /dev/null +++ b/docs/inverse_design_extension_notes.md @@ -0,0 +1,154 @@ +# Inverse-design code map + extension notes + +Snapshot of the inverse-design surface as of PR #18, written so a future session +can extend it (e.g. "exactly K elements", "fix Au at 65 %", "min-weight floor") +without having to reverse-engineer the design. See +[inverse_design_algorithms.md](inverse_design_algorithms.md) for the *math*; this +doc is the *code map*. + +## The two entry points + +| Method | Where | Optimisation variable | Method-specific loss term | +|---|---|---|---| +| `optimize_latent` | [flexible_multi_task_model.py:1735](../src/foundation_model/models/flexible_multi_task_model.py#L1735) | latent `h` | `α · ‖h − tanh(E(D(h)))‖²` (AE-alignment) | +| `optimize_composition` | [flexible_multi_task_model.py:2227](../src/foundation_model/models/flexible_multi_task_model.py#L2227) | element-weight logits `θ`, with `w = softmax(θ)` | `(1 − d) · H(w)` (entropy penalty) | + +Both share: regression-MSE + classification-cross-entropy backbone, opt-in +`record_*_trajectory` flag for per-step capture (used by +[paper_inverse_trajectory.py](../src/foundation_model/scripts/paper_inverse_trajectory.py)). + +## What's already there on the composition path + +User-facing kwargs (validated in `optimize_composition`'s argument block at lines +~2393–2465): + +| Kwarg | Range | What it does | Implementation | +|---|---|---|---| +| `task_targets` | `{task: value}` | MSE target per regression head | inner loop | +| `class_targets` + `class_target_weight` | `{task: class_idx}`, `> 0` | maximise softmax prob of given class | inner loop | +| `diversity_scale` | `[0, 1]`, default 1.0 | 0 = peaky few-element; 1 = no penalty | `(1 − d) · H(w)` added to loss | +| `seed_blend` | `[0, 1]`, default 0.95 | how much seed kept vs uniform-over-allowed at init | `w₀ ← s·seed + (1−s)·uniform` | +| `allowed_elements` | `"all"` or symbol list | hard whitelist | logit mask to `-inf` | +| `element_step_scale` | `float` or `{symbol: float}` | soft per-element gradient scale; `0` = hard-lock to seed value | grad multiplied per element; **lock implemented in `_w_from_logits`** (line 2576) — paste seed values back over softmax + renormalise unlocked positions | + +## The single point of leverage: `_w_from_logits` inside `optimize_composition` + +[flexible_multi_task_model.py:2576-2591](../src/foundation_model/models/flexible_multi_task_model.py#L2576-L2591) + +```python +def _w_from_logits(lg: torch.Tensor) -> torch.Tensor: + """Softmax over logits; mask disallowed elements; hard-lock the chosen ones at seed.""" + w = softmax_with_mask(lg, elem_mask) # whitelist + if locked_mask is None: + return w + # rewrite locked positions to seed values + renormalise unlocked positions to fill 1 − Σ_locked + ... +``` + +**Every simplex-projection / hardening rule belongs here.** It runs once per step, +on every (B × n_components) row, and the gradient flows correctly through any +differentiable rewriting (the existing lock branch is `.detach()`-constant, so +its gradient is 0; new differentiable steps would let gradient flow naturally). +Adding a new constraint = (a) accept a new kwarg in the signature, (b) validate +it in the arg-block, (c) compute any per-step state once before the loop, (d) +apply it inside `_w_from_logits`. + +## Three extensions the user has flagged + +### A. "Specify number of elements" — top-K mass constraint + +**Use case**: "give me exactly 3-element recipes" / "at most K elements". + +**Suggested API**: +```python +optimize_composition(..., max_elements: int | None = None) +``` + +**Implementation sketch** (inside `_w_from_logits`): +```python +if max_elements is not None and max_elements < n_components: + # Top-K hardening: keep the K largest weights per row, zero the rest, renormalise. + topk_vals, topk_idx = w.topk(max_elements, dim=-1) + mask = torch.zeros_like(w).scatter_(-1, topk_idx, 1.0) + w = w * mask + w = w / w.sum(dim=-1, keepdim=True).clamp(min=1e-12) +``` + +Notes: +- `topk` returns the K largest indices — this is non-differentiable at the "K-th + vs (K+1)-th" boundary, but the gradient through the K kept values is correct. + In practice, with `diversity_scale < 1` to drive peakiness *before* the hard + cutoff, the boundary doesn't oscillate. +- Validate `1 ≤ max_elements ≤ n_components` in the arg-block. +- Tests to add: pattern after + [test_optimize_composition_element_step_scale_locks_symbols](../src/foundation_model/models/flexible_multi_task_model_test.py) + — assert that `(w > 1e-6).sum(dim=-1) <= max_elements` for every row of the + output. + +### B. "Fix Au at exactly 65 %" — explicit fixed-amount API + +**Use case**: chemistry-driven prior says "I want exactly 65 % Au, 20 % Ga, +optimiser picks the remaining 15 % freely". + +**Already half-possible** via `element_step_scale = {"Au": 0.0, "Ga": 0.0}` + +seed has those amounts. But it requires constructing a seed; cleaner standalone +API: + +```python +optimize_composition(..., fixed_amounts: Mapping[str, float] | None = None) +``` + +**Implementation sketch**: +- Validate that `sum(fixed_amounts.values()) < 1.0` (need free mass) and each + symbol resolves in `DEFAULT_ELEMENTS`. +- Compute `fixed_w0: (n_components,)` with those positions set, zeros elsewhere. +- Reuse the existing `locked_mask` / `locked_w0` infrastructure + ([line 2557-2574](../src/foundation_model/models/flexible_multi_task_model.py#L2557-L2574)) + — basically: set `locked_mask = (fixed_w0 > 0)` and `locked_w0 = fixed_w0` for + every row in the batch, skip the "needs `initial_weights`" requirement that + the `element_step_scale=0` branch has. +- The existing `_w_from_logits` already does the right paste + renormalise; no + change needed there. + +Tests: assert `w[:, fixed_idx] ≈ fixed_amount` exactly after every step. + +### C. Min-weight floor / "if you use Au, use ≥ 10 %" + +**Use case**: avoid trace-amount appearances (`Pt = 0.5 %`) that are not +synthesisable. + +**Suggested API**: `min_nonzero_weight: float = 0.0`. After top-K (B) or simplex +projection, zero out any weight below the floor, renormalise. + +Implementation goes in the same `_w_from_logits` block, after any top-K / +locking. Same test pattern. + +## What lives where (for the future agent) + +| Concern | Location | +|---|---| +| Method docstring (the user-facing contract) | `optimize_composition` docstring, lines 2243–2370 | +| Kwarg validation | arg-block lines 2393–2465 (mirror the pattern: per-kwarg validation block + a `*_arg` local prepared for the inner loop) | +| One-time setup (locked indices, scaled steps, …) | lines 2536–2574 (before the `for _ in range(steps)` loop) | +| Per-step constraint application | `_w_from_logits` (line 2576) — single point | +| Loss term additions (entropy etc.) | inner loop, line ~2614 (`if diversity_scale < 1.0:`) | +| Per-step trajectory recording | already wired via `record_weights_trajectory` (line 2603 + 2622) — new constraints automatically reflected in the trajectory because we record post-`_w_from_logits` weights | +| Tests | [flexible_multi_task_model_test.py](../src/foundation_model/models/flexible_multi_task_model_test.py) — search `test_optimize_composition_*` (38 existing tests cover the current surface) | + +The latent path (`optimize_latent`, line 1735) is more rigid: the optimisation +variable is `h`, not a simplex, so the same constraints don't translate +naturally. Most new constraint features will only make sense for +`optimize_composition` — call this out in any new kwarg's docstring. + +## Pre-merge checklist + +When extending: keep the surgical-edits pattern. The existing PR already added a +lot; future extensions should be one-kwarg-per-PR with the validation + +`_w_from_logits` change + at least one test that pins the contract end-to-end +(input kwarg → output `w` rows satisfy the constraint). + +Reference tests to mimic: +- `test_optimize_composition_element_step_scale_locks_symbols` — + contract test for an existing constraint kwarg. +- `test_optimize_composition_runs_and_returns_simplex_weights` — smoke test + that the simplex is preserved (rows sum to 1, all ≥ 0). diff --git a/docs/qc_inverse_design_summary.md b/docs/qc_inverse_design_summary.md new file mode 100644 index 0000000..d817cd9 --- /dev/null +++ b/docs/qc_inverse_design_summary.md @@ -0,0 +1,182 @@ +# QC inverse-design study — summary + +One-page summary of the messages the +[continual-rehearsal + inverse-design pipeline](continual_rehearsal_full_PLAN.md) carries. +Written so each bullet maps to either a slide or a paragraph of the paper. + +## Headline messages + +### 1. A multi-task foundation model + gradient-based inverse design is an effective recipe for multi-objective materials optimisation. + +* The same model is trained once on **11 supervised tasks** (7 regression + 1 kernel + regression + 3 inverse-design tail tasks: `formation_energy`, `klat`, `material_type`). + Continual rehearsal (small replay) keeps earlier tasks from collapsing while new ones land — + what we get out is a single descriptor → latent representation that all heads share. +* On top of that one checkpoint we run **three multi-objective scenarios** with no retraining: + (a) FE↓ + magnetisation↑, (b) FE↓ + Tc↑ + magnetisation↑, (c) FE↓ + klat↑. All three are + drivenby gradient descent against a single composite loss: + $\text{MSE-to-target} + w_{\text{cls}} \cdot (-\log P(\text{QC}))$. +* The takeaway: once the encoder + heads are good enough, *adding a new joint objective is just + another `task_targets` entry*. No new training, no new data, no per-objective bespoke model. + +### 2. For QC, the differentiable-KMD composition path gives more controllable and chemically meaningful results than latent-space optimisation. + +* **Latent path** (`optimize_latent`, $\alpha = 0.5$ sweet spot) hits high $P(\text{QC})$ + (~0.92 across scenarios) but produces *predictions in latent space*; the reported recipe has + to be back-decoded through the AE, which still costs target attainment on the secondary + regression objectives. Without the AE-alignment term ($\alpha = 0$), the latent drifts off + the manifold and QC collapses (post-decode 0.97 → 0.35 in PR #18 measurements) — the term is + doing real work. +* **Composition path** (`optimize_composition`) optimises the element-weight simplex directly: + $w = \text{softmax}(\theta) \to x = w \cdot K \to \text{heads}$. The optimised `w` *is the + reported recipe* — no AE round-trip, no fidelity loss between "what the optimiser sees" and + "what gets written down". +* On the headline `comp (seed, 5% all, element list)` configuration the composition path lands + at QC ≈ 0.85 — slightly below latent's 0.92 — but the trade-off buys (a) outputs that are + *valid alloy recipes by construction* (simplex + element whitelist), (b) chemistry-consistent + outputs that cluster around real QC-prone families (Al-Pd-Pt, Mg-Pd-Al, Au-Ga-RE), and + (c) a per-knob control surface that materials scientists can actually reason about + (`allowed_elements`, `element_step_scale`, `seed_blend`). + +## Other points worth keeping in the summary + +### 3. The two methods are complementary, not competing. + +* Latent finds the model's *internal* attractors — it answers "what region of representation + space does the model think is QC-like, regardless of physical realisability". This is + scientifically useful as a diagnostic (it surfaces model biases like the "Ti/Pu/F/Mn" + attractor seen in `comp (random)`). +* Composition is the *recipe generator* — what you'd hand to a synthesist. It's the path + reported as the "paper-headline" output, with the latent runs kept as the baseline / + failure-mode control. +* Use them together: latent shows where the model thinks QC lives; composition shows what to + actually make. + +### 4. The model demonstrably learns chemistry beyond the seed set — "element discovery" is real. + +* Seed set = 20 compositions (17 top-QC element-system-dedup from training + 3 explicit + Au-Ga-RE i-QC formers). Crucially, **Pt, Pd, Re, Hf, Ta are *not* in any seed**. +* After running the constrained composition path with the 48-element `ALLOY_PALETTE`: + * **Pt** is picked up in **6/20 outputs (scenario 1: FE↓ + mag↑)** and **7/20 outputs + (scenario 3: FE↓ + klat↑)**, as part of an Al-Pd-Pt ternary that the model converges to + repeatedly. Pt is not seeded — the optimiser introduced it via the `seed_blend = 0.95` + mechanism (5 % uniform mass over the whitelist) and the gradient signal recognised it as + QC-favourable. + * **Pd** also appears in many scenario-1/3 outputs (not in any seed either) — the Mg-Pd-Al + family was the headline finding in PR #18's smaller-palette run too. + * **Hf, Ta** are picked up occasionally by latent in scenario 3. +* These are not random noise insertions: the same element families show up consistently across + seeds and across scenarios with related objectives, which is consistent with the model having + learned the underlying chemistry of QC-prone compositions, not memorised individual seeds. + +### 5. The user-facing knobs are intuitive and on a $[0, 1]$ scale. + +* `ae_align_scale` $\in [0, 1]$ (latent): 0 = no manifold constraint (fails: $h$ drifts off the + AE-learned region); 1 = strict constraint (over-tight, hurts target attainment); 0.5 is the + empirical sweet spot. +* `diversity_scale` $\in [0, 1]$ (composition): 1 = no entropy penalty (default, lets the + optimiser pick as many elements as the objective wants); 0 = peaky few-element recipes + (forces binary / ternary, useful as an ablation). +* `seed_blend` $\in [0, 1]$ (composition): 0.95 default = keep 95 % seed, mix 5 % uniform over + the allowed elements at the start so the optimiser can actually *introduce new elements* + (this is the element-discovery enabler). +* The point: no need to read the implementation to use these. The knob name predicts the + direction. + +### 6. The 3 scenarios stress-test conflicting objectives. + +* Each scenario combines QC↑ (always primary) with 1–2 regression targets that the model has + *no a-priori reason* to expect can co-exist with QC. FE↓ is the most aggressive ask (drives + toward thermodynamically stable phases, often in tension with the metastable QC family); + klat↑ is also non-obvious for amorphous-leaning compositions. +* The fact that the composition path lands at QC ≈ 0.85 *and* meets the secondary targets + (scenario 3: FE close to 0, klat ≈ 1.6 / target 2.0) on average shows the model isn't simply + collapsing to a single trivial "high-QC" point — it's negotiating the trade-off. +* The 8-path × 3-scenario × 20-seed sweep (480 optimisation runs total) gives enough data to + read the trade-off as a Pareto-like front (the `qc_vs_secondary_scatter.png` figure). + +### 7. The pipeline is end-to-end automated and reproducible. + +* One run produces, for each scenario: + * `comparison.png` — 3-panel bar chart with QC + each reg target across all 8 paths. + * `element_frequency_heatmap.png` — 8 paths × top-25 elements; newly-discovered elements + (not in any seed) are bold-orange on the x-axis. + * `qc_vs_secondary_scatter.png` — per-seed cloud, latent = ○ Greens / composition = △ Blues, + with red dashed target lines. + * `seed_to_optimized__*.png` × 7 — per-method 1:1 mapping (seed → optimised composition) with + per-row `(QC%, ΔFE, Δklat, …)` deltas. + * `trajectories/.{png,gif,html}` per path — **mean-across-seeds** normalised + per-step target curves (static `.png`), and the same curves animated alongside a per-step + element bar chart of the best representative seed (default `.gif`, self-contained interactive + `.html` on request). Raw per-step arrays `(steps, B, T)` for targets and `(steps, B, n_components)` + for weights persisted as `trajectories/.npz` so re-plots don't need to rerun the + optimisation. + * `trajectories_per_seed/seed{NN}/.{png,gif,html}` — **per-(path × seed)** plots + and animations under a seed-major layout (one folder per seed, all 8 paths inside). Each + title carries the seed's composition formula in monospace. Default on; opt out with + `--no-per-seed-trajectories`. + * `results.json` + `SUMMARY.md` — raw arrays and a markdown table. +* Configs, seeds, and the trained checkpoint are all saved per run, so any figure can be + regenerated from `results.json` alone (no re-running the optimisation needed for re-plots). +* The orchestrator (`paper_inverse_3scenarios`) writes the three scenarios into sibling + subfolders so the full study is one directory. + +### 8. Per-step optimisation trajectories explain why the same seed → different scenarios → different recipes. + +Each path's per-step `(targets, weights)` are now persisted as +`/trajectories/.npz`; the corresponding `trajectory__*.png` / +`.gif` / `.html` plots normalise each target to "progress" (0 = seed baseline, +1 = at target) and overlay all targets on one axis. The headline finding is +that **secondary targets converge on very different time-scales**, and the +fastest-converging one locks in the recipe early: + +| Scenario | Path | Per-target trajectory (300 steps) | What it tells you | +|---|---|---|---| +| 3: FE↓ + klat↑ | `comp (seed, 5% all, element list)` | **klat overshoots** to progress ≈ 1.5 by step ~100 then plateaus; **FE crawls** to ~0.32 across all 300 steps | klat dominates the gradient early; once a klat-favourable recipe is locked, the remaining 200 steps only nudge FE in the residual subspace | +| 1: FE↓ + mag↑ | same path | **FE** crawls to ~0.26; **magnetisation** essentially flat at ~0.01 | the model can't find compositions that increase magnetisation without dropping QC — magnetisation is a *stuck* target on this manifold | +| 2: FE↓ + tc↑ + mag↑ | same path | **FE and tc** rise together to ~0.22 by step ~200 (coupled); **magnetisation** plateaus at ~0.08 | when two targets pull on similar element directions they couple cleanly; the orthogonal one (mag) again barely moves | + +Three consequences for interpreting the per-scenario heatmaps: + +* "Same seed, different scenario, different recipe" is not optimisation noise — + it's the *dominant target* taking over the gradient in the first ~50–100 steps + and steering the composition into a different chemistry basin. The trajectory + plot lets you see this happening in real time (left-panel target curves + + right-panel evolving element bars in the GIF / HTML). +* Most paths have flatlined by step ~150–200, so the configured `inverse_steps = + 300` is enough headroom; further steps would mainly refine the slow tail. The + bottleneck is not training time, it's the magnetisation-style "model can't + reach this target from any QC-prone basin" failure mode. +* The klat overshoot (progress > 1.0) is honest signal: the optimiser keeps + pushing klat past the target because the joint loss is still falling on the + other axes. Reading the `seed_to_optimized__*.png` per-row Δreg values gives + the absolute (not relative) numbers if "did it actually overshoot" matters + for the application. + +The "best-per-target representative seed" used in the GIF / HTML's composition +panel is picked by `paper_inverse_trajectory.best_seed_by_target_distance` +(minimises the joint normalised distance to QC = 1 and every reg target). To +see all 20 seeds individually instead of the mean, rerun with +`--per-seed-trajectories`. + +### 9. Constraints and honest limitations. + +* The 48-element `ALLOY_PALETTE` is a *chemistry-aware whitelist*, not a synthesisability + predictor. The optimiser will still happily propose Al-Pd-Pt at a ratio nobody has yet + reported as quasicrystalline — the model's confidence ≠ experimental confirmation. +* The single-task regression heads are trained on z-scored targets, so "FE = −2" means + "2 σ below the dataset mean", not "−2 eV/atom" directly. The summary numbers are best read + as *relative* improvements over the seed baseline (the per-seed `ΔFE` in + `seed_to_optimized__*.png` is the cleanest view of that). +* The latent path's "α = 0 failure" baseline is *deliberately included* in the comparison + figure so the AE-alignment term's contribution is visible — readers occasionally interpret + the α=0 results as the method's overall performance; they're meant to be the *control*, not + the recommendation. + +## Where to go for detail + +* **Method math + per-term design intent**: [docs/inverse_design_algorithms.md](inverse_design_algorithms.md) +* **Plan and rationale for the 3 scenarios + alloy palette**: [docs/continual_rehearsal_full_PLAN.md](continual_rehearsal_full_PLAN.md) +* **Per-scenario outputs**: `artifacts/inverse_design_run/inverse_design/scenario{1,2,3}_*/` + (gitignored — regenerate with `paper_inverse_3scenarios`). +* **Implementation**: [`paper_inverse_comparison.py`](../src/foundation_model/scripts/paper_inverse_comparison.py) (single-scenario runner) → [`paper_inverse_3scenarios.py`](../src/foundation_model/scripts/paper_inverse_3scenarios.py) (orchestrator). diff --git a/docs/trajectory_integration.md b/docs/trajectory_integration.md new file mode 100644 index 0000000..6c2b016 --- /dev/null +++ b/docs/trajectory_integration.md @@ -0,0 +1,327 @@ +# Wiring the trajectory plotting module into another runner + +Short integration note for any runner that calls `model.optimize_latent` / +`model.optimize_composition` and wants the per-step trajectory artefacts +([`paper_inverse_trajectory`](../src/foundation_model/scripts/paper_inverse_trajectory.py)). +The reference wiring lives in +[`paper_inverse_comparison.run()`](../src/foundation_model/scripts/paper_inverse_comparison.py) +— copy that pattern. + +## Where this module lives (and why) + +| File | Role | +|---|---| +| [`paper_inverse_trajectory.py`](../src/foundation_model/scripts/paper_inverse_trajectory.py) | **NEW standalone module** — all trajectory helpers live here (`best_seed_by_target_distance`, `normalize_target_trajectories`, `plot_trajectory_static`, `plot_trajectory_animation`). | +| [`paper_inverse_comparison.py`](../src/foundation_model/scripts/paper_inverse_comparison.py) | Calls the helpers from a private orchestrator `_emit_trajectory_outputs()` (search for it). **This is the reference wiring.** | +| [`continual_rehearsal_common.py`](../src/foundation_model/scripts/continual_rehearsal_common.py) | **Untouched.** It hosts training-loop helpers shared between the two training runners. Trajectory plotting is an analysis-time concern, not a training concern. | +| [`continual_rehearsal_demo.py`](../src/foundation_model/scripts/continual_rehearsal_demo.py) / [`continual_rehearsal_full.py`](../src/foundation_model/scripts/continual_rehearsal_full.py) | **Untouched** by the trajectory feature. To opt in, `import` from `paper_inverse_trajectory` directly — same pattern these runners already use to import `_plot_qc_vs_reg_scatter` / `_plot_seed_to_optimized_mapping` from `paper_inverse_comparison` (see e.g. [`continual_rehearsal_full.py:100-101`](../src/foundation_model/scripts/continual_rehearsal_full.py#L100-L101)). | + +**Rationale**: the `paper_inverse_*` files form the post-training analysis +layer; `continual_rehearsal_common.py` holds the training-time shared helpers. +A single consumer doesn't justify promoting to `common`. If a second consumer +materialises (and the wiring is genuinely shared, not just the plot helpers), +the wiring itself — not the plotters — can graduate to `common` later. + +## What this module produces, per path + +| File | Content | +|---|---| +| `trajectories/.npz` | `targets`: `(steps, B, T_reg)` per-step regression predictions. `weights`: `(steps, B, n_components)` per-step element weights. | +| `trajectories/trajectory__.png` | Static **mean-across-seeds** line plot, x = step, y = normalised progress (0 = seed, 1 = target), all reg targets on one axis. | +| `trajectories/trajectory__.{gif,html,svg}` | Same line + per-step top-K composition bar chart of the **best representative seed**. The chosen seed's composition formula is rendered under the title. Format-controlled by `animation_formats`. | +| `trajectories_per_seed/seed{NN}/.{png,gif,html,svg}` | **Per-(path × seed)** plots/animations under a **seed-major** layout — one folder per seed, with all 8 paths inside. This is the layout you want for "compare how the same seed behaved across paths" workflow. Each title carries the seed's composition formula in monospace under the bold main title. Default on; pass `--no-per-seed-trajectories` to skip (480 PNG + 480 GIF + 480 HTML / scenario when both animation formats are enabled). | + +The npz file is the **single source of truth** — both plots and any later +replot read from it; no need to rerun the optimisation. + +## 3 hook-up steps + +### Step 1 — turn recording on at the model call + +`optimize_composition` and `optimize_latent` each take an opt-in flag (default +`False`, zero cost when off): + +```python +res = model.optimize_composition( + kmd_kernel, task_targets=reg_targets, + # … existing args … + record_weights_trajectory=True, # ← was the only new line +) +# res.weights_trajectory: (steps, B, n_components) — None if flag was False + +res = model.optimize_latent( + initial_input=x_seed, task_targets=reg_targets, + # … existing args … + record_input_trajectory=True, # ← was the only new line +) +# res.input_trajectory: (B, R, steps, input_dim) — None if flag was False +# For latent, decode to weights via runner._kmd.inverse(per_step_inputs[s]) per step. +``` + +The latent flag stores the AE-decoded per-step input; `KMD.inverse` then gives +the per-step element weights (one extra QP solve per step × seed, ~10 % overhead). + +### Step 2 — persist as compressed npz + +Inlining `(steps=300, B=20, n_components=94)` into `results.json` balloons it +to ~36 MB / scenario. Persist alongside instead: + +```python +import numpy as np +traj_dir = out_dir / "trajectories" +traj_dir.mkdir(exist_ok=True) +np.savez_compressed( + traj_dir / f"{slug}.npz", + targets=res.trajectory.cpu().numpy(), # composition: (steps, B, T) + weights=res.weights_trajectory.cpu().numpy(), # composition: (steps, B, n_components) +) +# For latent, ``targets`` is res.trajectory[:, 0, :, :].permute(1, 0, 2) → (steps, B, T) +# and ``weights`` is the per-step KMD.inverse stack → (steps, B, n_components). +``` + +The composition path's slug helper is +[`paper_inverse_comparison._path_slug(r)`](../src/foundation_model/scripts/paper_inverse_comparison.py) +— reuse it so filenames match the existing convention (`latent_align0p25`, +`comp_seed_5_all_element_list`, …). + +### Step 3 — render the figures + +One helper call per path; the module handles both axes (the line plot on +mean-across-seeds, the comp panel on the best representative seed): + +```python +from foundation_model.scripts.paper_inverse_trajectory import ( + best_seed_by_target_distance, normalize_target_trajectories, + plot_trajectory_static, plot_trajectory_animation, +) +from foundation_model.utils.kmd_plus import DEFAULT_ELEMENTS + +# 1. Normalise per-step targets to progress fractions (0 = seed, 1 = target): +progress = normalize_target_trajectories( + qc_trajectory=np.tile(qc_after_decode[None, :], (steps, 1)), # see Note A below + reg_trajectory={t: traj_targets[:, :, j] for j, t in enumerate(reg_names)}, + reg_targets=reg_targets, + seed_qc=before_qc, seed_reg=before_reg, +) +progress.pop("QC", None) # we don't have per-step QC; drop the flat synthesised line + +# 2. Pick the representative seed for the animation's comp panel: +best_idx = best_seed_by_target_distance(qc_after_decode, reg_after_decode, reg_targets) + +# 3. Static + animated: +plot_trajectory_static(progress, out_dir / "trajectory.png", title="…") +plot_trajectory_animation( + progress, + per_step_weights=traj_weights[:, best_idx, :], # (steps, n_components) + element_symbols=list(DEFAULT_ELEMENTS), + out_paths_by_format={"gif": out_dir / "trajectory.gif", + "html": out_dir / "trajectory.html"}, # any of gif/html/svg + title="…", +) +``` + +**Note A** — per-step QC: the model's `optimize_*.trajectory` only records the +reg-target predictions, not the QC head's per-step probability. We synthesise a +flat QC line from the end-state `qc_after_decode` so `normalize_target_trajectories` +has something to return, then drop `progress["QC"]` from the plot. If you need the +real per-step QC curve, post-process the per-step weights yourself: + +```python +qc_traj = np.stack( + [_qc_prob(model, torch.tensor(traj_weights[s] @ kmd_kernel_np, dtype=...)) + for s in range(traj_weights.shape[0])] +) # (steps, B) +``` + +That's an extra `B × steps` forward pass — cheap for composition path; for the +latent path it's redundant because the predicts are already on the decoded x. + +## Worked example — `continual_rehearsal_full.py` + +The runner's existing inverse-design layout (a `paths: dict[str, dict[str, +Any]]` per scenario, populated by `_run_latent_path` / `_run_composition_path`, +then plotted in one shot via the existing +`_plot_inverse_scenario` + `_element_frequency_heatmap` + +`_plot_qc_vs_reg_scatter` block) is exactly the right shape — just three +edits: + +### Edit A — `_run_latent_path` (around [continual_rehearsal_full.py:1405](../src/foundation_model/scripts/continual_rehearsal_full.py#L1405)) + +```python +def _run_latent_path(self, model, x_seed, seeds, reg_targets, path_dir, *, + ae_align_scale, label, _qc_prob_fn, _reg_preds_fn, + record_trajectory: bool = False): # ← new arg + # … existing setup … + res = model.optimize_latent( + # … existing args … + record_input_trajectory=record_trajectory, # ← new line + ) + # … existing post-processing populates ``result`` dict … + + if record_trajectory and res.input_trajectory is not None: + # (B, R=1, steps, input_dim) → (steps, B, input_dim) via permute+squeeze + per_step_inputs = res.input_trajectory[:, 0, :, :].cpu().numpy().transpose(1, 0, 2) + per_step_weights = np.stack( + [self._kmd.inverse(per_step_inputs[s]) for s in range(per_step_inputs.shape[0])] + ) # (steps, B, n_components) — one QP per step × seed (~10% overhead) + # ``res.trajectory`` is (B, R=1, steps, T) — squeeze restart → (steps, B, T) + result["trajectory_targets"] = res.trajectory[:, 0, :, :].cpu().numpy().transpose(1, 0, 2) + result["trajectory_weights"] = per_step_weights + return result +``` + +### Edit B — `_run_composition_path` (around [continual_rehearsal_full.py:1465](../src/foundation_model/scripts/continual_rehearsal_full.py#L1465)) + +```python +def _run_composition_path(self, model, kmd_kernel, w_seed, seeds, reg_targets, + path_dir, *, init, blend, allowed, diversity, label, + _qc_prob_fn, _reg_preds_fn, + record_trajectory: bool = False): # ← new arg + # … existing setup … + res = model.optimize_composition( + kmd_kernel, task_targets=reg_targets, + # … existing args … + record_weights_trajectory=record_trajectory, # ← new line + ) + # … existing post-processing populates ``result`` dict … + + if record_trajectory and res.weights_trajectory is not None: + # Composition path: trajectories are already on the right surface, no decoding needed. + result["trajectory_targets"] = res.trajectory.cpu().numpy() # (steps, B, T) + result["trajectory_weights"] = res.weights_trajectory.cpu().numpy() # (steps, B, n_components) + return result +``` + +### Edit C — scenario loop (the `paths` dict block around [continual_rehearsal_full.py:1230](../src/foundation_model/scripts/continual_rehearsal_full.py#L1230)) + +After the existing `_plot_qc_vs_reg_scatter` / `_plot_seed_to_optimized_mapping` +calls, persist the trajectory arrays and emit the new figures: + +```python +from foundation_model.scripts.paper_inverse_comparison import _path_slug +from foundation_model.scripts.paper_inverse_trajectory import ( + best_seed_by_target_distance, normalize_target_trajectories, + plot_trajectory_static, plot_trajectory_animation, +) +from foundation_model.utils.kmd_plus import DEFAULT_ELEMENTS + +if record_trajectory: + traj_dir = sc_dir / "trajectories" + traj_dir.mkdir(exist_ok=True) + per_seed_dir = sc_dir / "trajectories_per_seed" if per_seed_trajectories else None + if per_seed_dir is not None: + per_seed_dir.mkdir(exist_ok=True) + + for path_key, p in paths.items(): + if "trajectory_targets" not in p: + continue + slug = _path_slug({"method": p["method"], "label": p["label"], + "align_scale": p.get("ae_align_scale")}) + np.savez_compressed( + traj_dir / f"{slug}.npz", + targets=p["trajectory_targets"].astype(np.float32), + weights=p["trajectory_weights"].astype(np.float32), + ) + + # --- shared data --- + reg_names = list(reg_targets) + traj_targets = p["trajectory_targets"] # (steps, B, T) + traj_weights = p["trajectory_weights"] # (steps, B, n_components) + qc_after = np.asarray(p["qc_after_decode"], dtype=float) + per_row_seeds = list(p.get("seeds", seeds)) # composition strings per row + + # --- mean across-seeds plot/animation --- + reg_traj = {t: traj_targets[:, :, j] for j, t in enumerate(reg_names)} + qc_traj = np.tile(qc_after[None, :], (traj_targets.shape[0], 1)) + progress_mean = normalize_target_trajectories( + qc_trajectory=qc_traj, reg_trajectory=reg_traj, reg_targets=reg_targets, + seed_qc=before_qc, seed_reg=before_reg, + ) + progress_mean.pop("QC", None) + best_idx = best_seed_by_target_distance( + qc_after, {t: np.asarray(p["reg_after_decode"][t]) for t in reg_names}, + reg_targets, + ) + plot_trajectory_static(progress_mean, traj_dir / f"trajectory__{slug}.png", + title=f"Trajectory · {p['label']} (mean over {qc_after.shape[0]} seeds)") + if animation_formats and animation_formats != ("none",): + out_paths = {fmt: traj_dir / f"trajectory__{slug}.{fmt}" for fmt in animation_formats if fmt != "none"} + plot_trajectory_animation( + progress_mean, traj_weights[:, best_idx, :], list(DEFAULT_ELEMENTS), + out_paths_by_format=out_paths, + title=f"Trajectory · {p['label']} (best seed: {best_idx})", + seed_composition=per_row_seeds[best_idx], # ← shows comp under title + ) + + # --- per-seed plot/animation (seed-major layout) --- + if per_seed_dir is not None: + for seed_i in range(qc_after.shape[0]): + seed_dir = per_seed_dir / f"seed{seed_i:02d}" + seed_dir.mkdir(exist_ok=True) + progress_seed = normalize_target_trajectories( + qc_trajectory=qc_traj[:, seed_i:seed_i+1], + reg_trajectory={t: traj_targets[:, seed_i:seed_i+1, j] for j, t in enumerate(reg_names)}, + reg_targets=reg_targets, + seed_qc=before_qc[seed_i:seed_i+1], + seed_reg={t: v[seed_i:seed_i+1] for t, v in before_reg.items()}, + ) + progress_seed.pop("QC", None) + plot_trajectory_static( + progress_seed, seed_dir / f"{slug}.png", + title=f"{p['label']} · seed {seed_i}", + seed_composition=per_row_seeds[seed_i], + ) + if animation_formats and animation_formats != ("none",): + plot_trajectory_animation( + progress_seed, traj_weights[:, seed_i, :], list(DEFAULT_ELEMENTS), + out_paths_by_format={fmt: seed_dir / f"{slug}.{fmt}" + for fmt in animation_formats if fmt != "none"}, + title=f"{p['label']} · seed {seed_i}", + seed_composition=per_row_seeds[seed_i], + ) + + # Free memory before the next path — the trajectories are now on disk. + del p["trajectory_targets"], p["trajectory_weights"] + p["trajectory_file"] = str((traj_dir / f"{slug}.npz").relative_to(sc_dir)) +``` + +`record_trajectory`, `per_seed_trajectories`, and `animation_formats` come from +the CLI flags below; thread them down from `_parse_args` → the inverse-design +entry method that owns the scenario loop. The `before_qc` / `before_reg` +arrays are already computed in that same loop for the existing scatter plot, +so no extra forward passes. + +## Reference wiring + +The full pattern is in +[`paper_inverse_comparison.run()`](../src/foundation_model/scripts/paper_inverse_comparison.py) +(search `_emit_trajectory_outputs`). It also handles the +`--per-seed-trajectories` flag (one plot + animation per `(path × seed)` instead +of the across-seed mean) — same helpers, looped per seed. + +## CLI flags to forward + +If your runner has its own CLI, mirror these three on it (or read them from the +existing config): + +```python +parser.add_argument("--record-trajectory", action=argparse.BooleanOptionalAction, default=True) +parser.add_argument("--per-seed-trajectories", action=argparse.BooleanOptionalAction, default=True) +parser.add_argument( + "--animation-formats", nargs="+", + choices=["gif", "html", "svg", "none"], default=["gif"], +) +``` + +Pass them through to the runner's inverse-design loop so users can switch +formats without code changes. + +### Per-seed title convention + +Per-seed plots show the seed's composition in monospace under the bold main +title (e.g. `seed: Au65 Ga20 Gd15`). The helpers do this automatically when +the optional `seed_composition: str` kwarg is passed to +`plot_trajectory_static` / `plot_trajectory_animation`. Pass `r["seeds"][i]` +(the per-row seed label from the path runner; for `comp (random)` it's the +`random_start_N` placeholder string). The mean plot does the same for the +"best representative seed" picked by `best_seed_by_target_distance`. diff --git a/run_continual_rehearsal_full.sh b/run_continual_rehearsal_full.sh new file mode 100755 index 0000000..7bfba82 --- /dev/null +++ b/run_continual_rehearsal_full.sh @@ -0,0 +1,68 @@ +#!/usr/bin/env bash +# Convenience wrapper for the full / formal continual multi-task rehearsal + inverse-design run. +# Usage: +# ./run_continual_rehearsal_full.sh [CONFIG_PATH] [-- additional CLI args...] +# +# If CONFIG_PATH is omitted, the default full config in samples/ is used. + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="${SCRIPT_DIR}" + +DEFAULT_CONFIG="${REPO_ROOT}/samples/continual_rehearsal_full_config.toml" + +CONFIG_FILE="${1:-${DEFAULT_CONFIG}}" +shift || true +EXTRA_ARGS=() +if [[ $# -gt 0 ]]; then + EXTRA_ARGS=("$@") +fi + +if [[ ! -f "${CONFIG_FILE}" ]]; then + echo "Config file not found: ${CONFIG_FILE}" >&2 + exit 1 +fi + +DATE_SUFFIX="$(date +"%y%m%d")" + +function has_flag() { + local flag="$1" + for arg in "${EXTRA_ARGS[@]+"${EXTRA_ARGS[@]}"}"; do + if [[ "${arg}" == "${flag}" || "${arg}" == ${flag}=* ]]; then + return 0 + fi + done + return 1 +} + +# Append a date suffix to the config's output_dir so repeated runs don't clobber prior artifacts. +OUTPUT_OVERRIDE=() +if ! has_flag "--output-dir"; then + OUTPUT_BASE="$(python3 - "$CONFIG_FILE" <<'PY' +import sys +from pathlib import Path + +try: + import tomllib # type: ignore[attr-defined] +except ModuleNotFoundError: # pragma: no cover + try: + import tomli as tomllib # type: ignore + except ModuleNotFoundError: + print("", end="") + sys.exit(0) + +loaded = tomllib.loads(Path(sys.argv[1]).read_text(encoding="utf-8")) +value = loaded.get("output_dir") if isinstance(loaded, dict) else None +if isinstance(value, str) and value.strip(): + print(value.strip(), end="") +PY +)" + if [[ -z "${OUTPUT_BASE}" ]]; then + OUTPUT_BASE="${REPO_ROOT}/artifacts/continual_rehearsal_full" + fi + OUTPUT_BASE="${OUTPUT_BASE%/}" + OUTPUT_OVERRIDE=(--output-dir "${OUTPUT_BASE}_${DATE_SUFFIX}") +fi + +python3 -m foundation_model.scripts.continual_rehearsal_full --config-file "${CONFIG_FILE}" "${OUTPUT_OVERRIDE[@]+"${OUTPUT_OVERRIDE[@]}"}" "${EXTRA_ARGS[@]+"${EXTRA_ARGS[@]}"}" diff --git a/samples/continual_rehearsal_demo_config.toml b/samples/continual_rehearsal_demo_config.toml index 729efda..d4ab2a5 100644 --- a/samples/continual_rehearsal_demo_config.toml +++ b/samples/continual_rehearsal_demo_config.toml @@ -18,7 +18,9 @@ magnetic_path = "data/NEMAD_magnetic_20260419.parquet" phonix_path = "data/phonix-db-filtered_20260425.parquet" output_dir = "artifacts/continual_rehearsal" -task_sequence = ["density", "formation_energy", "dos_density", "power_factor", "material_type", "tc", "pressure", "curie", "magnetization", "neel", "kp", "klat"] +# Last three fixed as formation_energy -> klat -> material_type (the inverse-design heads), so the +# QC classifier in particular is freshest when inverse design runs; the first nine order is free. +task_sequence = ["density", "dos_density", "power_factor", "tc", "pressure", "curie", "magnetization", "neel", "kp", "formation_energy", "klat", "material_type"] replay_ratio = 0.05 # sample_per_dataset = 8000 # uncomment to cap rows per dataset for a faster run @@ -37,8 +39,13 @@ kr_decay = 5e-5 inverse_n_seeds = 16 inverse_steps = 300 inverse_lr = 0.05 -inverse_reg_tasks = ["density", "formation_energy"] -inverse_reg_targets = [1.5, -1.5] +inverse_class_weight = 5.0 # QC probability is the primary objective +inverse_reg_tasks = ["formation_energy", "klat"] +inverse_reg_targets = [-2.0, 2.0] # secondary: low formation energy, high klat +# Seed (starting latent) selection: "top_qc" | "random" | "explicit". +inverse_seed_strategy = "top_qc" +inverse_seed_split = "train" # pool for top_qc / random: train | val | test | all +# inverse_seed_compositions = ["Al65 Cu23 Fe12", "Ho9 Mg34 Zn57"] # used when strategy = "explicit" random_seed = 2025 datamodule_random_seed = 42 diff --git a/samples/continual_rehearsal_demo_config_inverse_baseline.toml b/samples/continual_rehearsal_demo_config_inverse_baseline.toml new file mode 100644 index 0000000..59b3891 --- /dev/null +++ b/samples/continual_rehearsal_demo_config_inverse_baseline.toml @@ -0,0 +1,81 @@ +# Preview baseline for the inverse-design pipeline (per docs/continual_rehearsal_full_PLAN.md). +# +# Compared to the previous "no KR" baseline this version adds ONE kernel-regression task — +# ``dos_density`` — right before the inverse-design tail. Rationale: a KR task in the training +# mix gives the encoder broader inductive coverage of property-vs-T behaviour without paying the +# cost of the full 7 KR tasks. The last 3 tasks remain formation_energy → klat → material_type +# so the inverse heads stay freshest at the end of the continual sequence. +# +# Inverse-design defaults match plan §5: 17 top-QC dedup seeds + 3 explicit Au-Ga-Ln formers +# (Au65Ga20{Gd,Tb,Dy}15), giving N=20 seeds per scenario. ``ae_align_scale=0.5`` is the empirical +# sweet spot from PR #18. final_model.pt is saved so the paper_inverse_comparison + finetune +# scripts can iterate without retraining. +# +# ./run_continual_rehearsal_demo.sh samples/continual_rehearsal_demo_config_inverse_baseline.toml + +qc_data_path = "data/qc_ac_te_mp_dos_reformat_20250615_enforce_quaternary_test.pd.parquet" +qc_preprocessing_path = "data/preprocessing_objects_20250615.pkl.z" +superconductor_path = "data/NEMAD_superconductor_20260425.parquet" +magnetic_path = "data/NEMAD_magnetic_20260419.parquet" +phonix_path = "data/phonix-db-filtered_20260425.parquet" +# Unified pipeline output: one parent folder holds training + finetune + inverse-design. +# This script (continual_rehearsal_demo) writes the rehearsal stage into the ``training/`` +# subfolder; the downstream scripts (finetune_inverse_heads, paper_inverse_comparison) write +# into the sibling ``finetune/`` and ``inverse_design/`` subfolders so all artefacts for one +# pipeline run live under a single directory. +output_dir = "artifacts/inverse_design_run/training" + +# 11 tasks: 7 reg + 1 KR (dos_density) + 3 tail (formation_energy → klat → material_type). +task_sequence = ["density", "tc", "pressure", "curie", "magnetization", "neel", "kp", "dos_density", "formation_energy", "klat", "material_type"] +replay_ratio = 0.05 + +max_epochs_per_step = 20 +batch_size = 256 +n_grids = 8 +latent_dim = 128 +encoder_hidden = 256 +head_hidden_dim = 64 +head_lr = 0.005 +encoder_lr = 0.005 + +# Inverse design (per plan §5). n_seeds = 17 strategy + 3 explicit = 20 total. +inverse_n_seeds = 20 +inverse_steps = 300 +inverse_lr = 0.05 +inverse_class_weight = 5.0 +inverse_ae_align_scale = 0.5 # [0, 1]; default sweet spot from PR #18 +inverse_reg_tasks = ["formation_energy", "klat"] +inverse_reg_targets = [-2.0, 2.0] +inverse_seed_strategy = "top_qc" +inverse_seed_split = "train" +inverse_seed_explicit_append = ["Au65 Ga20 Gd15", "Au65 Ga20 Tb15", "Au65 Ga20 Dy15"] + +random_seed = 2025 +datamodule_random_seed = 42 +accelerator = "auto" +devices = 1 +num_workers = 0 + +# Three inverse-design scenarios (per docs/continual_rehearsal_full_PLAN.md §5). Consumed by the +# orchestrator ``paper_inverse_3scenarios.py``. The plan uses ``magnetic_moment`` as the magnetic +# target, but the current 11-task baseline trained on ``magnetization`` (a sibling NEMAD-magnetic +# column) — they encode the same "more-magnetic" intent in z-scored space, and the substitution +# is the only way to run the 3 scenarios *without* a base-model retrain. Documented in the +# top-level ANALYSIS.md of the run folder. +# +# NOTE: TOML array-of-tables must be at the file END — any top-level scalar after a [[...]] +# header would be absorbed into that table. Keep below this line empty of scalars. +[[inverse_scenarios]] +name = "scenario1_fe_down_magnetic_up" +reg_tasks = ["formation_energy", "magnetization"] +reg_targets = [-2.0, 2.0] + +[[inverse_scenarios]] +name = "scenario2_fe_down_tc_up_magnetic_up" +reg_tasks = ["formation_energy", "tc", "magnetization"] +reg_targets = [-2.0, 2.0, 2.0] + +[[inverse_scenarios]] +name = "scenario3_fe_down_klat_up" +reg_tasks = ["formation_energy", "klat"] +reg_targets = [-2.0, 2.0] diff --git a/samples/continual_rehearsal_demo_config_smoke.toml b/samples/continual_rehearsal_demo_config_smoke.toml index 132f001..03dde5a 100644 --- a/samples/continual_rehearsal_demo_config_smoke.toml +++ b/samples/continual_rehearsal_demo_config_smoke.toml @@ -14,11 +14,11 @@ magnetic_path = "data/NEMAD_magnetic_20260419.parquet" phonix_path = "data/phonix-db-filtered_20260425.parquet" output_dir = "artifacts/continual_rehearsal_smoke" -task_sequence = ["density", "formation_energy", "dos_density", "power_factor", "material_type", "tc", "pressure", "curie", "magnetization", "neel", "kp", "klat"] +task_sequence = ["density", "dos_density", "power_factor", "tc", "pressure", "curie", "magnetization", "neel", "kp", "formation_energy", "klat", "material_type"] replay_ratio = 0.05 sample_per_dataset = 500 -max_epochs_per_step = 1 +max_epochs_per_step = 5 batch_size = 256 n_grids = 8 latent_dim = 128 @@ -29,6 +29,9 @@ n_kernel = 15 inverse_n_seeds = 8 inverse_steps = 50 inverse_lr = 0.05 +inverse_class_weight = 5.0 +inverse_reg_tasks = ["formation_energy", "klat"] +inverse_reg_targets = [-2.0, 2.0] random_seed = 2025 datamodule_random_seed = 42 diff --git a/samples/continual_rehearsal_full_config.toml b/samples/continual_rehearsal_full_config.toml new file mode 100644 index 0000000..822b0e2 --- /dev/null +++ b/samples/continual_rehearsal_full_config.toml @@ -0,0 +1,123 @@ +# Continual multi-task rehearsal + inverse-design — FULL / formal run. +# +# 24 supervised tasks over 4 inorganic datasets + always-on autoencoder: +# qc_ac_te_mp : 9 regression + 7 kernel regression + 1 classification +# phonix-db : 2 regression (kp, klat) +# NEMAD sc : 1 regression (tc) +# NEMAD mag : 4 regression (magnetic_moment, magnetization, curie, neel) +# +# Tasks are added incrementally. The fixed tail (formation_energy, magnetic_moment, tc, klat, +# material_type) is trained last and, when later replayed as an old task, keeps replay_ratio_high +# (10%) of its labels; every other learned task keeps replay_ratio (5%). No layers are frozen — +# the shared encoder + all active heads train jointly each step. EarlyStopping on val_final_loss +# means max_epochs_per_step is only a ceiling. The same final model is then optimized toward 3 +# independent inverse-design scenarios (QC probability primary). +# +# ./run_continual_rehearsal_full.sh samples/continual_rehearsal_full_config.toml +# +# sample_per_dataset = null uses every row (formal run); set an integer to cap per dataset (smoke). + +qc_data_path = "data/qc_ac_te_mp_dos_reformat_20260515.pd.parquet" +qc_preprocessing_path = "" # no matching pkl for 20260515 → skip dropped_idx +superconductor_path = "data/NEMAD_superconductor_20260425.parquet" +magnetic_path = "data/NEMAD_magnetic_20260419.parquet" +phonix_path = "data/phonix-db-filtered_20260425.parquet" +output_dir = "artifacts/continual_rehearsal_full" + +# Three-segment order to minimise total replay cost: +# 1) 12 regression (any order; grouped by dataset for readability) +# 2) 7 kernel regression ascending by non-null row count — kr training is expensive per row, +# so small kr's are introduced earlier (cheap at 100% mask) and then replayed cheaply (5% +# of a small set), while big kr's land late so they're replayed for fewer subsequent steps. +# 3) 5 fixed tail (inverse-design heads, kept freshest; material_type last → QC clf newest). +task_sequence = [ + # --- 12 regression (any order) --- + "density", "efermi", "final_energy", "total_magnetization", "volume", + "dielectric_total", "dielectric_ionic", "dielectric_electronic", # 8 qc reg + "magnetization", "curie", "neel", # 3 magnetic (non-tail) + "kp", # 1 phonix (non-tail) + # --- 7 kernel regression, ascending by non-null row count --- + # magnetic_susceptibility 98 → zt 4971 → power_factor 5223 → thermal_conductivity 6158 + # → electrical_resistivity 7334 → dos_density 10321 → seebeck 11722 + "magnetic_susceptibility", "zt", "power_factor", "thermal_conductivity", + "electrical_resistivity", "dos_density", "seebeck", + # --- 5 fixed tail (inverse-design heads, freshest at the end) --- + "formation_energy", "magnetic_moment", "tc", "klat", "material_type", +] +fixed_tail = ["formation_energy", "magnetic_moment", "tc", "klat", "material_type"] +replay_ratio = 0.05 +replay_ratio_high = 0.10 +# sample_per_dataset = 12000 # uncomment to cap rows per dataset for a faster run + +max_epochs_per_step = 100 +early_stop_patience = 8 +early_stop_min_delta = 1e-4 +batch_size = 256 +n_grids = 8 +latent_dim = 128 +encoder_hidden = 256 +head_hidden_dim = 64 +head_lr = 0.005 +encoder_lr = 0.005 +n_kernel = 15 +kr_lr = 5e-4 +kr_decay = 5e-5 + +# Inverse design (mirrors the PR #18 demo's paper_inverse_comparison.py). +# Each scenario walks 8 configurations on the same seeds: +# 3 latent rows — ae_align_scale ∈ {0, 0.25, 1} (failure / mid / max alignment) +# 5 composition rows — strict seed / blend / blend+palette / blend+palette+low diversity / random +# Every per-config knob (ae_align_scale, seed_blend, diversity_scale) is fixed in +# ``INVERSE_PATH_CONFIGS`` at the module level so the ablation is stable across runs; only the +# palette is overridable here for the rows that whitelist elements. +inverse_n_seeds = 20 # 17 top-QC dedup + 3 explicit Au-Ga-Ln +inverse_steps = 300 +inverse_lr = 0.05 +inverse_class_weight = 5.0 # QC probability is the primary objective +inverse_seed_strategy = "top_qc" +# Use the **test** split for seed selection: the model has seen the train compositions during +# training, so its top-QC ranking on train is part memorisation; test compositions are held out, +# so the ranking is a genuine prediction → seeds are real novel QC candidates, not training data +# the model already saw. (The demo / paper run defaulted to "train" because it was a self- +# contained baseline; the formal full run wants held-out candidates.) +inverse_seed_split = "test" +# Three Au-Ga-Ln formers appended to top-QC seeds (strategy budget reduced by 3). +inverse_seed_explicit_append = ["Au65 Ga20 Gd15", "Au65 Ga20 Tb15", "Au65 Ga20 Dy15"] +# 48-element alloy palette (plan §5, extended 2026-05 with the full Hf–Pt 5d TM row) — restricts +# the C-alloy composition path. Covers classic i-QC / d-QC formers, group 13/14 enablers, the +# full 4th/5th-period TMs (Tc excluded), the full 6th-period TMs (Hf–Au), and easy lanthanides +# (Pm/Tm/Lu excluded). Keep aligned with ``ALLOY_PALETTE`` in ``continual_rehearsal_full.py``. +inverse_composition_allowed_elements = [ + "Mg", "Ca", + "B", "Al", "Ga", "In", "Tl", + "Si", "Ge", + "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn", + "Y", "Zr", "Nb", "Mo", "Ru", "Rh", "Pd", "Ag", "Cd", + "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", + "La", "Ce", "Pr", "Nd", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Yb", +] + +random_seed = 2025 +datamodule_random_seed = 42 +accelerator = "mps" +devices = 1 +num_workers = 0 + +# Three independent inverse-design scenarios; primary objective (QC ↑) is implicit for all. +# reg_targets are in normalized / z-scored space: -2.0 ≈ low, +2.0 ≈ high. +# NOTE: array-of-tables must come LAST in the file — any top-level key after a [[...]] header +# would be absorbed into that table by TOML rules. +[[inverse_scenarios]] +name = "scenario1_fe_down_moment_up" +reg_tasks = ["formation_energy", "magnetic_moment"] +reg_targets = [-2.0, 2.0] + +[[inverse_scenarios]] +name = "scenario2_fe_tc_moment" +reg_tasks = ["formation_energy", "tc", "magnetic_moment"] +reg_targets = [-2.0, 2.0, 2.0] + +[[inverse_scenarios]] +name = "scenario3_fe_down_klat_up" +reg_tasks = ["formation_energy", "klat"] +reg_targets = [-2.0, 2.0] diff --git a/src/foundation_model/data/composition_sources.py b/src/foundation_model/data/composition_sources.py index b8a176f..19e3afa 100644 --- a/src/foundation_model/data/composition_sources.py +++ b/src/foundation_model/data/composition_sources.py @@ -35,23 +35,19 @@ _SPLIT_PRECEDENCE: dict[str, int] = {"train": 1, "val": 2, "test": 3} VALID_SPLIT_LABELS = frozenset(_SPLIT_PRECEDENCE) -# Decimal places used when rendering element amounts in a canonical composition key. Six is -# enough for typical fractional stoichiometries while collapsing float-representation noise. -_COMPOSITION_AMOUNT_DECIMALS = 6 - - -def normalize_composition(value: object, *, decimals: int = _COMPOSITION_AMOUNT_DECIMALS) -> str | None: - """Canonical, float-amount composition key shared across every data source. +def normalize_composition(value: object) -> str | None: + """Canonical composition key shared across every data source. Different files spell the same composition differently — a pymatgen ``Composition`` / element-amount ``dict`` (the qc dataset) versus a formula string (NEMAD / phonix), and - ``"Fe3O2"`` versus ``"Fe3.0O2.0"``. This maps any of them through pymatgen to a single - canonical string so the composition-keyed DataModule can join them by exact match. + ``"Fe3O2"`` versus ``"Fe3.0O2.0"``. Routing all of them through pymatgen and returning the + (non-reduced) ``Composition.formula`` yields a single readable canonical string — pymatgen + already normalizes element order and integer-vs-decimal amounts — so the composition-keyed + DataModule joins heterogeneous sources by exact match. Compositions that pymatgen considers + equal collapse to the same key, which is exactly the duplicate the DataModule then keeps once. - The amounts are **not reduced** (``Fe2O3`` ≠ ``Fe4O6``) because some descriptors aggregate - by sum rather than by mean, so the absolute stoichiometry must be preserved. Every amount is - rendered as a fixed-precision float and elements are sorted by symbol, making the key - invariant to integer-vs-decimal spelling and to element ordering. + The amounts are **not reduced** (``Fe2O3`` ≠ ``Fe4O6``) because some descriptors aggregate by + sum rather than by mean, so the absolute stoichiometry must be preserved. Parameters ---------- @@ -59,13 +55,11 @@ def normalize_composition(value: object, *, decimals: int = _COMPOSITION_AMOUNT_ A formula string, a pymatgen ``Composition``, or an element→amount mapping. Mapping entries that are ``None`` or non-positive are dropped (the qc ``composition`` column stores every element with mostly-``None`` amounts). - decimals : int, optional - Decimal places for each amount. Defaults to six. Returns ------- str | None - e.g. ``"Fe2.000000 O3.000000"``; ``None`` if the input is empty or unparseable. + e.g. ``"Fe2 O3"``; ``None`` if the input is empty or unparseable. """ from pymatgen.core.composition import Composition # local import; pymatgen is heavy @@ -86,10 +80,9 @@ def normalize_composition(value: object, *, decimals: int = _COMPOSITION_AMOUNT_ comp = Composition(text) except Exception: return None - amounts = comp.get_el_amt_dict() - if not amounts: + if len(comp) == 0: return None - return " ".join(f"{el}{amounts[el]:.{decimals}f}" for el in sorted(amounts)) + return comp.formula CompositionNormalizer = Callable[[object], str | None] diff --git a/src/foundation_model/data/composition_sources_test.py b/src/foundation_model/data/composition_sources_test.py index 96002a7..e09781b 100644 --- a/src/foundation_model/data/composition_sources_test.py +++ b/src/foundation_model/data/composition_sources_test.py @@ -24,11 +24,11 @@ # --- normalize_composition -------------------------------------------------- -def test_normalize_composition_float_and_order_invariant(): +def test_normalize_composition_formula_and_order_invariant(): # Integer vs decimal spelling and element ordering all collapse to one canonical key. assert normalize_composition("Fe3O2") == normalize_composition("Fe3.0O2.0") assert normalize_composition("Fe2O3") == normalize_composition("O3Fe2") - assert normalize_composition("Fe2O3") == "Fe2.000000 O3.000000" + assert normalize_composition("Fe2O3") == "Fe2 O3" # readable pymatgen .formula # Amounts are NOT reduced: absolute stoichiometry is preserved. assert normalize_composition("Fe2O3") != normalize_composition("Fe4O6") @@ -36,7 +36,7 @@ def test_normalize_composition_float_and_order_invariant(): def test_normalize_composition_accepts_mapping_dropping_none(): # The qc 'composition' column stores every element, mostly None. sparse = {"Fe": 2.0, "O": 3.0, "Na": None, "Cl": 0.0} - assert normalize_composition(sparse) == "Fe2.000000 O3.000000" + assert normalize_composition(sparse) == "Fe2 O3" def test_normalize_composition_invalid_returns_none(): diff --git a/src/foundation_model/models/flexible_multi_task_model.py b/src/foundation_model/models/flexible_multi_task_model.py index 1a612fe..cfdf4af 100644 --- a/src/foundation_model/models/flexible_multi_task_model.py +++ b/src/foundation_model/models/flexible_multi_task_model.py @@ -56,16 +56,25 @@ from .task_head.kernel_regression import KernelRegressionHead from .task_head.regression import RegressionHead -# Named tuple for optimization results +# Named tuple for optimization results. ``input_trajectory`` is None unless the caller passes +# ``record_input_trajectory=True`` to :meth:`optimize_latent` (gated because storing it costs +# O(B·R·steps·input_dim) memory and per-step latent-→-input decodes); when present it has shape +# ``(B, R, steps, input_dim)`` — used by the inverse-design trajectory animations to decode the +# per-step composition without rerunning the optimisation. OptimizationResult = namedtuple( - "OptimizationResult", ["optimized_input", "optimized_target", "initial_score", "trajectory"] + "OptimizationResult", + ["optimized_input", "optimized_target", "initial_score", "trajectory", "input_trajectory"], + defaults=[None], ) # Composition-space optimization (gradient descent over element weights w ∈ simplex). The optimised # w *is* the recipe (no AE-decode round-trip), so it is reported alongside the descriptor x = w @ K. +# ``weights_trajectory`` is None unless the caller passes ``record_weights_trajectory=True`` to +# :meth:`optimize_composition`; when present it has shape ``(steps, B, n_components)``. CompositionOptimizationResult = namedtuple( "CompositionOptimizationResult", - ["optimized_weights", "optimized_descriptor", "optimized_target", "initial_score", "trajectory"], + ["optimized_weights", "optimized_descriptor", "optimized_target", "initial_score", "trajectory", "weights_trajectory"], + defaults=[None], ) @@ -1735,7 +1744,10 @@ def optimize_latent( target_value: torch.Tensor | float | None = None, task_targets: Mapping[str, torch.Tensor | float] | None = None, class_targets: Mapping[str, int | Sequence[int]] | None = None, + class_target_weight: float = 1.0, + ae_align_scale: float = 0.5, optimize_space: str = "input", + record_input_trajectory: bool = False, ) -> OptimizationResult: """ Optimize inputs to drive one or multiple regression heads toward targets or extremes. @@ -1772,6 +1784,25 @@ def optimize_latent( Classification objectives: maps a classification task name to the class index (or indices) whose combined probability should be *maximized*. Adds a ``-log P(target classes)`` term to the objective and may be combined with ``task_targets``. + class_target_weight : float, optional + Multiplier on each classification objective term relative to the regression terms. + Use ``> 1`` to make class probability the primary objective and regression targets + secondary. Default ``1.0``. + ae_align_scale : float, optional + Latent-space optimization only. How hard to pull the optimised latent ``h`` toward the + AE's decode/encode fixed set, on a [0, 1] scale. + + * ``0.0``: **no alignment penalty** — pure unconstrained latent optimisation. This was + shown in PR #18 to fail badly (QC drops from ~0.97 to ~0.35 after the decode/encode + round-trip); recorded for completeness as a failure-mode baseline. + * ``1.0``: **strong alignment penalty** — keeps ``h`` close to ``encode(decode(h))``, + i.e. on the AE's stable manifold. Over-constraining tends to reduce target achievement. + * ``0.5`` (default): the empirical sweet spot from PR #18 experiments. + + Implementation detail (skip if not curious): the loss gets a + ``ae_align_scale · ‖tanh(encoder(AE.decode(h))) − h‖²`` term added. Operates in + **latent space**; orthogonal to :meth:`optimize_composition`'s ``diversity_scale`` + which lives in composition space. optimize_space : str, optional ``"input"`` or ``"latent"``. Default ``"input"``. @@ -1818,6 +1849,8 @@ def optimize_latent( if class_targets is not None: if not isinstance(class_targets, Mapping) or len(class_targets) == 0: raise ValueError("class_targets must be a non-empty mapping of task_name -> class index/indices") + if class_target_weight <= 0: + raise ValueError(f"class_target_weight must be > 0, got {class_target_weight}") class_target_map = {} for name, classes in class_targets.items(): if name not in self.task_heads: @@ -1838,6 +1871,9 @@ def optimize_latent( ) class_target_map[name] = idxs + if not 0.0 <= ae_align_scale <= 1.0: + raise ValueError(f"ae_align_scale must be in [0, 1], got {ae_align_scale}.") + # Legacy single-task path (mode / target_value) only when no target maps are given if target_tasks is None and class_target_map is None: if task_name is None or task_name not in self.task_heads: @@ -1866,9 +1902,16 @@ def optimize_latent( if num_restarts < 1: raise ValueError(f"num_restarts must be >= 1, got {num_restarts}") - # Store original training state + # Store original training state. We also snapshot every parameter's ``requires_grad`` + # because the optimisation only differentiates through ``optim_input`` / ``optim_latent`` + # — leaving ``requires_grad=True`` on the model parameters would let ``loss.backward()`` + # populate stale ``.grad`` tensors on the encoder / heads. Mirrors the same pattern used + # by :meth:`optimize_composition` so a later ``model.fit(...)`` works as expected. was_training = self.training + saved_req_grad: list[tuple[torch.nn.Parameter, bool]] = [(p, p.requires_grad) for p in self.parameters()] self.eval() + for p, _ in saved_req_grad: + p.requires_grad_(False) device = next(self.parameters()).device if initial_input is None: @@ -1935,6 +1978,10 @@ def _class_loss_terms(h_task: torch.Tensor) -> list[torch.Tensor]: optimized_inputs: list[torch.Tensor] = [] optimized_targets: list[torch.Tensor] = [] trajectories: list[torch.Tensor] = [] + # When ``record_input_trajectory=True`` we snapshot the per-step input every iteration + # (input-space: ``optim_input`` directly; latent-space: ``AE.decode(tanh(h))``). Stored on + # CPU to keep GPU memory flat on long trajectories. One per restart, stacked at the end. + input_trajectories: list[torch.Tensor] = [] initial_scores_list: list[torch.Tensor] = [] for restart_idx in range(num_restarts): @@ -1964,6 +2011,7 @@ def _class_loss_terms(h_task: torch.Tensor) -> list[torch.Tensor]: # Optimization loop step_traj: list[torch.Tensor] = [] + step_input_traj: list[torch.Tensor] = [] sign = 1.0 if mode == "max" else -1.0 for step in range(steps): @@ -1995,7 +2043,7 @@ def _class_loss_terms(h_task: torch.Tensor) -> list[torch.Tensor]: expanded_target = expanded_target.expand(pred.shape) loss_terms.append(F.mse_loss(pred, expanded_target)) - loss_terms.extend(_class_loss_terms(h_task)) + loss_terms.extend(class_target_weight * term for term in _class_loss_terms(h_task)) per_task_values_tensor = _stack_scores(per_task_values) # (B, T) if loss_terms: @@ -2012,6 +2060,9 @@ def _class_loss_terms(h_task: torch.Tensor) -> list[torch.Tensor]: # Record history step_traj.append(score_for_history) + if record_input_trajectory: + # Input-space optim variable IS the input — just snapshot it. + step_input_traj.append(optim_input.detach().cpu()) # Get final optimized values with torch.no_grad(): @@ -2028,6 +2079,8 @@ def _class_loss_terms(h_task: torch.Tensor) -> list[torch.Tensor]: optimized_targets.append(per_task_final_tensor) # (B, T) traj_tensor = torch.stack(step_traj, dim=0) # (steps, B, T) trajectories.append(traj_tensor) + if record_input_trajectory: + input_trajectories.append(torch.stack(step_input_traj, dim=0)) # (steps, B, D) else: # optimize_space == "latent" # Latent space optimization: encode X -> optimize latent -> decode via AE @@ -2058,6 +2111,7 @@ def _class_loss_terms(h_task: torch.Tensor) -> list[torch.Tensor]: # Optimization loop step_traj: list[torch.Tensor] = [] + step_input_traj: list[torch.Tensor] = [] sign = 1.0 if mode == "max" else -1.0 for step in range(steps): @@ -2091,7 +2145,13 @@ def _class_loss_terms(h_task: torch.Tensor) -> list[torch.Tensor]: expanded_target = expanded_target.expand(pred.shape) loss_terms.append(F.mse_loss(pred, expanded_target)) - loss_terms.extend(_class_loss_terms(h_task)) + loss_terms.extend(class_target_weight * term for term in _class_loss_terms(h_task)) + if ae_align_scale > 0: + # Pull the optimised latent toward what the AE faithfully reconstructs: + # decode it to a descriptor, re-encode, and penalise the drift in h_task. + # The user-facing knob is [0, 1] with 0 = no penalty / 1 = strong penalty. + re_h_task = torch.tanh(self.encoder(self.task_heads[_AE_TASK](h_task))) + loss_terms.append(ae_align_scale * F.mse_loss(re_h_task, h_task)) per_task_values_tensor = _stack_scores(per_task_values) # (B, T) if loss_terms: @@ -2108,6 +2168,12 @@ def _class_loss_terms(h_task: torch.Tensor) -> list[torch.Tensor]: # Record history step_traj.append(score_for_history) + if record_input_trajectory: + # Latent-space optim: decode the current h via the AE head to recover the + # per-step input. ``no_grad`` keeps this from polluting the optim graph. + with torch.no_grad(): + step_input = self.task_heads[_AE_TASK](torch.tanh(optim_latent)) + step_input_traj.append(step_input.detach().cpu()) # Get final optimized values and reconstruct via AE with torch.no_grad(): @@ -2126,9 +2192,16 @@ def _class_loss_terms(h_task: torch.Tensor) -> list[torch.Tensor]: optimized_targets.append(per_task_final_tensor) # (B, T) traj_tensor = torch.stack(step_traj, dim=0) # (steps, B, T) trajectories.append(traj_tensor) + if record_input_trajectory: + input_trajectories.append(torch.stack(step_input_traj, dim=0)) # (steps, B, D) - # Restore training state + # Restore training state + per-parameter ``requires_grad``. Without the latter, every + # encoder / head parameter would be left frozen for any later ``.fit()`` in the same + # Python session — the symptom is "training silently stops moving the encoder" which + # is annoying to bisect. self.train(was_training) + for p, prev in saved_req_grad: + p.requires_grad_(prev) # Stack outputs opt_input_tensor = torch.stack(optimized_inputs, dim=1) # (B, R, D) @@ -2138,11 +2211,17 @@ def _class_loss_terms(h_task: torch.Tensor) -> list[torch.Tensor]: initial_score_tensor = torch.stack(initial_scores_list, dim=0) # (R, B, T) initial_score_tensor = initial_score_tensor.permute(1, 0, 2) # (B, R, T) + input_traj_tensor: torch.Tensor | None = None + if record_input_trajectory and input_trajectories: + input_traj_tensor = torch.stack(input_trajectories, dim=0) # (R, steps, B, D) + input_traj_tensor = input_traj_tensor.permute(2, 0, 1, 3) # (B, R, steps, D) + return OptimizationResult( optimized_input=opt_input_tensor, optimized_target=opt_target_tensor, initial_score=initial_score_tensor, trajectory=traj_tensor, + input_trajectory=input_traj_tensor, ) def optimize_composition( @@ -2154,9 +2233,13 @@ def optimize_composition( task_targets: Mapping[str, torch.Tensor | float] | None = None, class_targets: Mapping[str, int | Sequence[int]] | None = None, class_target_weight: float = 1.0, - sparsity_weight: float = 0.0, + diversity_scale: float = 1.0, + allowed_elements: str | list[str] = "all", + element_step_scale: float | Mapping[str, float] = 1.0, + seed_blend: float = 0.95, steps: int = 300, lr: float = 0.05, + record_weights_trajectory: bool = False, ) -> CompositionOptimizationResult: """Gradient-based inverse design in **composition space**. @@ -2186,9 +2269,64 @@ def optimize_composition( Same semantics as :meth:`optimize_latent`. Regression targets are matched by MSE; classification objectives add ``-log P(target classes)`` (scaled by ``class_target_weight``). - sparsity_weight : float, optional - Adds a negative-entropy term ``λ · H(w)`` to the loss, pushing ``w`` toward few-element - mixtures. Default 0 (no sparsity pressure). + diversity_scale : float, optional + How spread-out the per-output element mixture is allowed to be, on a [0, 1] scale. + Bigger = more diverse / multi-element per output. + + * ``1.0`` (default): **no penalty** on having many elements — the optimiser is free + to land on a many-element recipe if the main objective likes it. + * ``0.0``: **strong penalty** on having many elements — the optimiser is pushed + toward peaky few-element recipes (e.g. binary alloys). + * ``0.5`` etc.: linearly interpolates between the two. + + The point is to give users a simple [0, 1] knob without needing to know the underlying + math. **Implementation detail** (skip if not curious): the loss gets a + ``(1 − diversity_scale) · H(w)`` term added, where ``H(w) = −Σ w_i log w_i`` is the + Shannon entropy of the per-row weight vector. ``diversity_scale = 1`` zeros that + coefficient (no penalty); ``diversity_scale = 0`` applies the full entropy penalty. + + Important: this is a **per-output complexity** knob, not a diversity-*between*-outputs + knob. Increasing it lets each of the ``B`` outputs individually use more elements; + whether the ``B`` outputs are different from each other (pairwise L1) depends on the + optimisation landscape, not on this knob. + allowed_elements : str | list[str], optional + Element whitelist for the optimisation. ``"all"`` (default) imposes no constraint. + A non-empty list of element symbols (e.g. ``["Mg", "Al", "Cu", "Ni"]``) restricts the + optimisation to those elements only — disallowed elements are forced to ``w = 0`` at + every step (their logits are masked to ``-inf`` inside the softmax), so no gradient + ever lifts them. Symbols are resolved against + :data:`~foundation_model.utils.kmd_plus.DEFAULT_ELEMENTS`; the kernel must therefore + have ``n_components == len(DEFAULT_ELEMENTS)`` when symbols are used. + element_step_scale : float | Mapping[str, float], optional + Per-element constraint on how fast each element's weight can move during optimisation. + A scalar applies uniformly to every element (default ``1.0`` = no constraint). A + symbol→float mapping overrides specific elements while leaving the rest at ``1.0``. + + Two regimes with different mechanics: + + * **Hard lock (value = 0):** ``{"Mg": 0.0, "Al": 0.0}`` pins those elements' weights + at their un-blended ``initial_weights`` values for the entire optimisation. The + implementation rewrites the softmax output to paste seed values back at locked + positions and renormalises the unlocked positions over the remaining + ``1 − Σ_locked seed`` mass — so the locked weights truly do not drift, even when + other (unlocked) logits move. Requires ``initial_weights`` (no seed → nothing to + lock to) and the locked elements must be in ``allowed_elements`` if a whitelist + is set. + * **Soft constraint (0 < value < 1):** the element's logit gradient is multiplied by + the scale before each Adam step, slowing (but not freezing) its drift. ``0.1`` lets + an element move at 10 % of the normal speed. The softmax denominator still couples + it to the rest of the row, so this is a soft preference, not a hard guarantee. + + Symbols are resolved against ``DEFAULT_ELEMENTS`` (kernel alignment required, as above). + seed_blend : float, optional + How much of the (per-row) seed prior to keep when ``initial_weights`` is given; + ``w0 ← seed_blend · seed + (1 − seed_blend) · uniform_over_allowed``. Default ``0.95`` + (5 % uniform mass spread over the allowed elements). The blend lifts non-seed-element + logits from ``log(1e-12) ≈ −27.6`` (effectively unreachable by Adam in a few hundred + steps) to ``log(0.05 / |allowed|) ≈ −7.6``, so the optimiser can introduce new elements + when they help the objective. Set to ``1.0`` to reproduce the strict seed-only behaviour + (no new elements can enter the support set); ``0.0`` makes the seed irrelevant and + starts from uniform. Ignored when ``initial_weights is None``. steps : int Adam optimisation steps. Default 300. lr : float @@ -2254,8 +2392,69 @@ def optimize_composition( if target_tasks is None and class_target_map is None: raise ValueError("Provide at least one of task_targets / class_targets.") - if sparsity_weight < 0: - raise ValueError(f"sparsity_weight must be >= 0, got {sparsity_weight}") + if not 0.0 <= diversity_scale <= 1.0: + raise ValueError(f"diversity_scale must be in [0, 1], got {diversity_scale}.") + if not 0.0 <= seed_blend <= 1.0: + raise ValueError(f"seed_blend must be in [0, 1], got {seed_blend}") + + # --- Per-element constraints (symbol-based) ----------------------------------------------- + # ``allowed_elements`` is a hard whitelist; ``element_step_scale`` is a soft per-element + # learning-rate multiplier (0 = frozen). Symbol-based inputs are resolved against the + # bundled :data:`DEFAULT_ELEMENTS` registry — see argument docs above. + from foundation_model.utils.kmd_plus import DEFAULT_ELEMENTS # local import; small list + + elem_mask_arg: torch.Tensor | None = None + if isinstance(allowed_elements, str): + if allowed_elements != "all": + raise ValueError(f"allowed_elements as a string must be 'all'; got {allowed_elements!r}.") + # "all": no constraint, leave elem_mask_arg as None. + elif isinstance(allowed_elements, (list, tuple)): + if len(allowed_elements) == 0: + raise ValueError("allowed_elements list must be non-empty.") + sym_to_idx = {s: i for i, s in enumerate(DEFAULT_ELEMENTS)} + bad = [s for s in allowed_elements if s not in sym_to_idx] + if bad: + raise ValueError(f"Unknown element symbol(s) in allowed_elements: {bad}.") + if n_components != len(DEFAULT_ELEMENTS): + raise ValueError( + f"allowed_elements as element symbols requires the kernel to align with " + f"DEFAULT_ELEMENTS (n_components={n_components}, expected {len(DEFAULT_ELEMENTS)})." + ) + elem_mask_arg = torch.zeros(n_components, dtype=torch.bool) + for sym in allowed_elements: + elem_mask_arg[sym_to_idx[sym]] = True + else: + raise TypeError( + f"allowed_elements must be 'all' or a non-empty list of element symbols; got {type(allowed_elements).__name__}." + ) + + step_scale_arg: torch.Tensor | None = None + if isinstance(element_step_scale, (int, float)) and not isinstance(element_step_scale, bool): + if element_step_scale < 0: + raise ValueError(f"element_step_scale must be >= 0; got {element_step_scale}.") + if float(element_step_scale) != 1.0: + step_scale_arg = torch.full((n_components,), float(element_step_scale)) + # else: 1.0 means "no scaling"; keep step_scale_arg = None for the fast path. + elif isinstance(element_step_scale, Mapping): + sym_to_idx = {s: i for i, s in enumerate(DEFAULT_ELEMENTS)} + bad = [s for s in element_step_scale if s not in sym_to_idx] + if bad: + raise ValueError(f"Unknown element symbol(s) in element_step_scale: {bad}.") + if any(float(v) < 0 for v in element_step_scale.values()): + raise ValueError("element_step_scale values must be >= 0.") + if n_components != len(DEFAULT_ELEMENTS): + raise ValueError( + f"element_step_scale as a symbol dict requires the kernel to align with " + f"DEFAULT_ELEMENTS (n_components={n_components}, expected {len(DEFAULT_ELEMENTS)})." + ) + step_scale_arg = torch.ones(n_components) + for sym, val in element_step_scale.items(): + step_scale_arg[sym_to_idx[sym]] = float(val) + else: + raise TypeError( + f"element_step_scale must be a non-negative float or a mapping of " + f"element_symbol → float; got {type(element_step_scale).__name__}." + ) # --- Validate the seed (BEFORE touching model state, so a bad input doesn't leave the # model in eval() / with params switched off). --------------------------------------- @@ -2289,16 +2488,40 @@ def optimize_composition( kmd_kernel = kmd_kernel.to(device=device, dtype=dtype) # --- Build logits over n_components --------------------------------------------------- + # We additionally capture the *un-blended* normalised seed (``w0_seed``) — the + # locked-element hard-lock below uses these values, not the post-blend ones, so a + # user who writes ``element_step_scale={"Mg": 0.0}`` with ``initial_weights`` placing + # Mg at 0.30 sees Mg held at exactly 0.30 (not the slightly blended 0.286). + w0_seed: torch.Tensor | None = None if initial_weights is None: # Use the caller's existing global RNG state — don't reseed here (would defeat # the intended diversity across repeated calls and would leak state outward). logits = torch.randn(n_starts, n_components, device=device, dtype=dtype) * 0.5 + if elem_mask_arg is not None: + # Push disallowed elements to a deep negative logit so softmax mask works + # consistently for both the random and seeded branches (the per-step mask + # below also enforces this; we mirror it here for the t=0 score). + logits = logits.masked_fill(~elem_mask_arg.to(device=device), -1e9) else: w0 = initial_weights.to(device=device, dtype=dtype) - # Normalise to the simplex (callers may pass un-normalised positive weights); - # log gives logits whose softmax recovers the row. A tiny floor only avoids - # log(0) for legitimate zero entries (sparse element-presence seeds). w0 = w0 / w0.sum(dim=-1, keepdim=True) + w0_seed = w0.detach().clone() # un-blended; used as the lock reference below + # Blend in a uniform prior so non-seed-element logits are reachable by Adam. + # Without this, log(0) → −∞ (clamped to log(1e-12) ≈ −27.6); the softmax Jacobian + # is proportional to w_i, so the per-step gradient on those logits is ≈ 1e-12 and + # Adam cannot lift them within a few hundred steps — the support set is frozen to + # the seed's nonzero elements. ``seed_blend < 1`` spreads a small uniform mass + # over the allowed elements so every reachable element starts at a workable logit. + if seed_blend < 1.0: + if elem_mask_arg is not None: + uniform_row = elem_mask_arg.to(device=device, dtype=dtype) + uniform_row = uniform_row / uniform_row.sum() + else: + uniform_row = torch.full((n_components,), 1.0 / n_components, device=device, dtype=dtype) + w0 = seed_blend * w0 + (1.0 - seed_blend) * uniform_row + w0 = w0 / w0.sum(dim=-1, keepdim=True) + # Tiny floor only to avoid log(0) when an element is both disallowed AND not in + # the uniform support (i.e. seed_blend == 1.0 with sparse seeds). logits = torch.log(w0.clamp(min=1e-12)).detach().clone() logits = logits.requires_grad_(True) optimizer = optim.Adam([logits], lr=lr) @@ -2316,6 +2539,57 @@ def optimize_composition( for name, idxs in class_target_map.items() } + # Move the element-constraint tensors onto the right device (validated above). + elem_mask = elem_mask_arg.to(device=device) if elem_mask_arg is not None else None + step_scale = step_scale_arg.to(device=device, dtype=dtype) if step_scale_arg is not None else None + + # --- Hard-lock setup for elements with step_scale == 0 ---------------------------------- + # Zeroing ``logit_i.grad`` keeps that logit constant but does NOT keep ``w_i`` constant, + # because softmax renormalises across all logits — when other (unlocked) logits move, the + # softmax denominator changes and so does the locked weight. To truly honour the docstring + # promise "freezes those elements at their seed values", we (a) detect locked indices, (b) + # capture their per-row seed weights, and (c) inside ``_w_from_logits`` paste those seed + # values back over the softmax output and renormalise the unlocked positions to fill the + # remaining ``1 − Σ locked_w`` mass per row. The gradient through the locked indices is + # automatically zero (the lock branch uses a constant), so we no longer need the + # ``step_scale.mul_`` zeroing for them — but we leave that path active for the genuinely + # soft case ``0 < step_scale < 1``. + locked_mask: torch.Tensor | None = None + locked_w0: torch.Tensor | None = None + if step_scale is not None: + locked_idx_mask = step_scale == 0 + if locked_idx_mask.any(): + if w0_seed is None: + raise ValueError( + "element_step_scale = 0 (hard lock) requires initial_weights — there's no " + "per-row seed to lock to when initial_weights=None." + ) + if elem_mask is not None and (~elem_mask[locked_idx_mask]).any(): + raise ValueError( + "Locked elements (element_step_scale = 0) must also be in allowed_elements; " + "locking a disallowed element is contradictory." + ) + locked_mask = locked_idx_mask # (n_components,) bool, on device + # (B, n_components): seed values at locked positions, 0 elsewhere — constant. + locked_w0 = (w0_seed * locked_mask.to(dtype)).detach() + + def _w_from_logits(lg: torch.Tensor) -> torch.Tensor: + """Softmax over logits; mask disallowed elements; hard-lock the chosen ones at seed.""" + if elem_mask is not None: + lg = lg.masked_fill(~elem_mask, float("-inf")) + w = torch.softmax(lg, dim=-1) + if locked_mask is None: + return w + # Locked rows hold their seed values; unlocked rows are renormalised to fill the + # remaining mass ``1 − Σ_locked seed``. Differentiable: the lock branch is a constant + # so its gradient is 0; the unlocked branch's gradient flows through the rescale. + free_mask_f = (~locked_mask).to(w.dtype) # (n_components,) + w_unlocked = w * free_mask_f # zero at locked positions + # type: ignore[union-attr] — locked_w0 is set together with locked_mask above. + free_mass = (1.0 - locked_w0.sum(dim=-1, keepdim=True)).clamp(min=0.0) + w_unlocked = w_unlocked / w_unlocked.sum(dim=-1, keepdim=True).clamp(min=1e-12) * free_mass + return w_unlocked + locked_w0 + def _heads_forward(h_task: torch.Tensor) -> tuple[list[torch.Tensor], list[torch.Tensor]]: """Run regression heads, return (per-task predictions, loss terms).""" preds, terms = [], [] @@ -2343,7 +2617,7 @@ def _stack(values: list[torch.Tensor], B: int) -> torch.Tensor: # --- Record initial scores -------------------------------------------------------------- with torch.no_grad(): - w0_tensor = torch.softmax(logits, dim=-1) + w0_tensor = _w_from_logits(logits) h0 = torch.tanh(self.encoder(w0_tensor @ kmd_kernel)) initial_preds, _ = _heads_forward(h0) initial_score = _stack([p.detach() for p in initial_preds], logits.shape[0]) @@ -2352,29 +2626,51 @@ def _stack(values: list[torch.Tensor], B: int) -> torch.Tensor: # With every model parameter at ``requires_grad=False``, ``loss.backward()`` populates # gradient only on ``logits`` — no stale grads accumulate on encoder/heads. trajectory: list[torch.Tensor] = [] + weights_trajectory: list[torch.Tensor] = [] if record_weights_trajectory else [] for _ in range(steps): optimizer.zero_grad() - w = torch.softmax(logits, dim=-1) + w = _w_from_logits(logits) x = w @ kmd_kernel h_task = torch.tanh(self.encoder(x)) preds, terms = _heads_forward(h_task) - if sparsity_weight > 0: - # Negative entropy of w (minimise entropy → push w toward few-element mixtures). + if diversity_scale < 1.0: + # The penalty strength is (1 − diversity_scale): user sees a [0, 1] knob + # where 1 means "no penalty / most diverse" and 0 means "max penalty / most + # peaky". The internal term is `(1 − diversity_scale) · H(w)` added to loss. entropy = -(w * w.clamp(min=1e-12).log()).sum(dim=-1).mean() - terms.append(sparsity_weight * entropy) + terms.append((1.0 - diversity_scale) * entropy) loss = torch.stack(terms).mean() loss.backward() + if step_scale is not None and logits.grad is not None: + # Soft per-element constraint: scale each element's logit gradient (0 = frozen). + logits.grad.mul_(step_scale) optimizer.step() trajectory.append(_stack([p.detach() for p in preds], logits.shape[0])) + if record_weights_trajectory: + # Snapshot the post-step weights (after the softmax+hard-lock applied next iter + # would re-clean them, but the user wants what the *current* recipe looks like). + # Stored on CPU to keep GPU memory flat for long trajectories on large B. + with torch.no_grad(): + weights_trajectory.append(_w_from_logits(logits).detach().cpu()) # --- Final state ------------------------------------------------------------------------ with torch.no_grad(): - w_final = torch.softmax(logits, dim=-1) + w_final = _w_from_logits(logits) x_final = w_final @ kmd_kernel h_final = torch.tanh(self.encoder(x_final)) final_preds, _ = _heads_forward(h_final) final_target = _stack([p.detach() for p in final_preds], logits.shape[0]) + weights_traj_tensor: torch.Tensor | None = None + if record_weights_trajectory: + # (steps, B, n_components). Same empty-steps fallback as ``trajectory`` so the + # downstream code can rely on the shape contract without a None branch. + weights_traj_tensor = ( + torch.stack(weights_trajectory, dim=0) + if weights_trajectory + else torch.empty((0, logits.shape[0], n_components), dtype=torch.float32) + ) + return CompositionOptimizationResult( optimized_weights=w_final.detach(), optimized_descriptor=x_final.detach(), @@ -2384,6 +2680,7 @@ def _stack(values: list[torch.Tensor], B: int) -> torch.Tensor: trajectory=torch.stack(trajectory, dim=0) if trajectory else torch.empty((0, logits.shape[0], n_reg_tracked), device=device, dtype=dtype), + weights_trajectory=weights_traj_tensor, ) finally: if was_training: diff --git a/src/foundation_model/models/flexible_multi_task_model_test.py b/src/foundation_model/models/flexible_multi_task_model_test.py index 6512c48..18a450c 100644 --- a/src/foundation_model/models/flexible_multi_task_model_test.py +++ b/src/foundation_model/models/flexible_multi_task_model_test.py @@ -981,6 +981,91 @@ def test_optimize_latent_class_targets_only_no_regression(): assert res.optimized_target.shape == (4, 1, 0) # no regression tasks tracked +def test_optimize_latent_ae_align_validates_range(): + """ae_align_scale lives in [0, 1] — out-of-range values are rejected.""" + model = _make_reg_clf_model() + with pytest.raises(ValueError, match=r"ae_align_scale must be in \[0, 1\]"): + model.optimize_latent( + initial_input=torch.randn(2, INPUT_DIM), + task_targets={"prop": 1.0}, + optimize_space="latent", + ae_align_scale=-0.1, + ) + with pytest.raises(ValueError, match=r"ae_align_scale must be in \[0, 1\]"): + model.optimize_latent( + initial_input=torch.randn(2, INPUT_DIM), + task_targets={"prop": 1.0}, + optimize_space="latent", + ae_align_scale=1.5, + ) + + +def test_optimize_latent_ae_align_runs_in_latent_space(): + torch.manual_seed(0) + model = _make_reg_clf_model() # enable_autoencoder=True, so AE head is available + x = torch.randn(4, INPUT_DIM) + res = model.optimize_latent( + initial_input=x, + task_targets={"prop": 1.0}, + class_targets={"cls": [1]}, + class_target_weight=3.0, + ae_align_scale=0.5, # default empirical sweet spot + optimize_space="latent", + steps=10, + ) + assert res.optimized_input.shape == (4, 1, INPUT_DIM) + assert res.optimized_target.shape == (4, 1, 1) + + +def test_optimize_latent_class_target_weight_rejects_nonpositive(): + model = _make_reg_clf_model() + with pytest.raises(ValueError, match="class_target_weight must be > 0"): + model.optimize_latent( + initial_input=torch.randn(2, INPUT_DIM), + class_targets={"cls": [1]}, + class_target_weight=0.0, + optimize_space="input", + ) + + +def test_optimize_latent_class_target_weight_runs_with_combined_objectives(): + torch.manual_seed(0) + model = _make_reg_clf_model() + x = torch.randn(4, INPUT_DIM) + res = model.optimize_latent( + initial_input=x, + task_targets={"prop": 1.0}, + class_targets={"cls": [1]}, + class_target_weight=5.0, # class probability is the primary objective + optimize_space="input", + steps=10, + ) + assert res.optimized_input.shape == (4, 1, INPUT_DIM) + assert res.optimized_target.shape == (4, 1, 1) # one regression task tracked + + +def test_optimize_latent_restores_requires_grad_after_call(): + """Regression test for the requires_grad leak: optimize_latent must leave every model + parameter's ``requires_grad`` flag as it was before the call. Previously only ``training`` + mode was restored, so subsequent ``model.fit(...)`` calls silently froze the encoder / + heads and "training stopped moving the weights" became annoying to bisect. + """ + torch.manual_seed(0) + model = _make_reg_clf_model() + # Snapshot whatever pattern the caller had (all True by default, but the test should hold + # for any non-trivial pattern too). + expected = [p.requires_grad for p in model.parameters()] + model.optimize_latent( + initial_input=torch.randn(3, INPUT_DIM), + task_targets={"prop": 1.0}, + class_targets={"cls": [1]}, + optimize_space="input", + steps=5, + ) + actual = [p.requires_grad for p in model.parameters()] + assert actual == expected + + # --- optimize_composition (differentiable KMD) -------------------------------- @@ -1110,6 +1195,342 @@ def test_optimize_composition_restores_model_state_on_error(): assert [p.requires_grad for p in model.parameters()] == before_req_grad +def _build_aligned_model_and_kernel(): + """Helper for symbol-based tests: a tiny model + kernel whose first dim == len(DEFAULT_ELEMENTS). + + Symbol-based ``allowed_elements`` / ``element_step_scale`` require the kernel to align with + the bundled element registry. The kernel is random (matmul correctness is irrelevant here); + we just need the right shape so the symbol→index mapping is unambiguous. + """ + from foundation_model.utils.kmd_plus import DEFAULT_ELEMENTS + + n_components = len(DEFAULT_ELEMENTS) + enc = MLPEncoderConfig(hidden_dims=[INPUT_DIM, 16, LATENT_DIM]) + tasks = [ + RegressionTaskConfig(name="prop", data_column="prop", dims=[LATENT_DIM, 8, 1]), + ClassificationTaskConfig(name="cls", data_column="cls", num_classes=3, dims=[LATENT_DIM, 8, 3]), + ] + model = FlexibleMultiTaskModel(task_configs=tasks, encoder_config=enc, enable_autoencoder=True) + kernel = torch.randn(n_components, INPUT_DIM) + return model, kernel, DEFAULT_ELEMENTS + + +def test_optimize_composition_allowed_elements_symbol_whitelist(): + """A list of element symbols restricts w to those elements; the rest stay at exactly 0.""" + torch.manual_seed(0) + model, kernel, elements = _build_aligned_model_and_kernel() + whitelist = ["Mg", "Al", "Cu", "Ni"] + res = model.optimize_composition( + kernel, + task_targets={"prop": 1.0}, + class_targets={"cls": [1]}, + class_target_weight=3.0, + n_starts=3, + allowed_elements=whitelist, + steps=15, + lr=0.2, + ) + w = res.optimized_weights + allowed_idx = [elements.index(s) for s in whitelist] + forbidden_idx = [i for i in range(len(elements)) if i not in allowed_idx] + assert torch.all(w[:, forbidden_idx] == 0) + assert torch.allclose(w[:, allowed_idx].sum(dim=-1), torch.ones(3), atol=1e-5) + + +def test_optimize_composition_allowed_elements_default_all(): + """The default ``allowed_elements='all'`` imposes no constraint.""" + torch.manual_seed(0) + model = _make_reg_clf_model() + kernel = torch.randn(6, INPUT_DIM) # any kernel size works when no symbols are used + res = model.optimize_composition(kernel, task_targets={"prop": 0.5}, n_starts=2, steps=5) + # All columns can carry weight; nothing should be forced to zero by the default. + assert (res.optimized_weights > 0).all() + + +def test_optimize_composition_allowed_elements_validation(): + model, kernel, _ = _build_aligned_model_and_kernel() + # "all" is the only acceptable string. + with pytest.raises(ValueError, match="must be 'all'"): + model.optimize_composition(kernel, task_targets={"prop": 0.0}, allowed_elements="everything", steps=2) + # Empty list rejected. + with pytest.raises(ValueError, match="non-empty"): + model.optimize_composition(kernel, task_targets={"prop": 0.0}, allowed_elements=[], steps=2) + # Unknown symbol rejected. + with pytest.raises(ValueError, match="Unknown element symbol"): + model.optimize_composition(kernel, task_targets={"prop": 0.0}, allowed_elements=["Mg", "NotAnElement"], steps=2) + # Wrong type rejected. + with pytest.raises(TypeError, match="non-empty list"): + model.optimize_composition(kernel, task_targets={"prop": 0.0}, allowed_elements=42, steps=2) # type: ignore[arg-type] + # Symbols with a non-aligned kernel rejected. + small_kernel = torch.randn(6, INPUT_DIM) + with pytest.raises(ValueError, match="align with DEFAULT_ELEMENTS"): + model.optimize_composition( + small_kernel, task_targets={"prop": 0.0}, allowed_elements=["Mg", "Al"], n_starts=2, steps=2 + ) + + +def test_optimize_composition_element_step_scale_locks_symbols(): + """A symbol→0.0 mapping freezes those elements' weights at their **absolute** seed values. + + The previous version of this test only checked that the locked elements' ratio stayed at 1.0 + (which holds even if both drift together, since their logits move in lockstep). That doesn't + actually verify "frozen": with the bare gradient-zeroing implementation, ``w[Mg]`` drifts + because the softmax denominator changes whenever other (unlocked) logits move. This test + now asserts each locked element holds its **un-blended seed value** to within float tolerance. + """ + torch.manual_seed(0) + model, kernel, elements = _build_aligned_model_and_kernel() + + # Seed: asymmetric mass on 4 specific symbols, zero on the rest. The asymmetry matters — + # equal-mass locks would survive ratio-only checks even if both drift together. + locked_syms = ["Mg", "Al"] + free_syms = ["Cu", "Ni"] + init_w = torch.zeros(1, len(elements)) + init_w[0, elements.index("Mg")] = 0.30 + init_w[0, elements.index("Al")] = 0.20 + init_w[0, elements.index("Cu")] = 0.30 + init_w[0, elements.index("Ni")] = 0.20 + + res = model.optimize_composition( + kernel, + task_targets={"prop": 5.0}, + initial_weights=init_w, + element_step_scale={s: 0.0 for s in locked_syms}, + steps=80, + lr=0.5, # large enough that any drift in locked weights would show up + ) + w = res.optimized_weights[0] + mg, al = elements.index("Mg"), elements.index("Al") + assert torch.isclose(w[mg], torch.tensor(0.30, dtype=w.dtype), atol=1e-4) + assert torch.isclose(w[al], torch.tensor(0.20, dtype=w.dtype), atol=1e-4) + # And the unlocked elements share the remaining 0.50 mass. + free_total = w.sum() - w[mg] - w[al] + assert torch.isclose(free_total, torch.tensor(0.50, dtype=w.dtype), atol=1e-4) + + +def test_optimize_composition_element_step_scale_locks_with_unlocked_drift(): + """Locked elements stay at seed even while unlocked elements actually move.""" + torch.manual_seed(0) + model, kernel, elements = _build_aligned_model_and_kernel() + init_w = torch.zeros(1, len(elements)) + init_w[0, elements.index("Mg")] = 0.40 # locked + init_w[0, elements.index("Cu")] = 0.30 # free + init_w[0, elements.index("Ni")] = 0.30 # free + + res = model.optimize_composition( + kernel, + task_targets={"prop": 5.0}, + initial_weights=init_w, + element_step_scale={"Mg": 0.0}, + steps=80, + lr=0.5, + ) + w = res.optimized_weights[0] + # Mg held exactly. + assert torch.isclose(w[elements.index("Mg")], torch.tensor(0.40, dtype=w.dtype), atol=1e-4) + # The unlocked elements ended up in different ratios than they started (proves they moved). + cu0, ni0 = init_w[0, elements.index("Cu")], init_w[0, elements.index("Ni")] + cu_f, ni_f = w[elements.index("Cu")], w[elements.index("Ni")] + assert not torch.isclose(cu_f / ni_f, cu0 / ni0, atol=1e-3), "unlocked weights didn't actually move" + # And the unlocked mass equals 1 - locked mass. + assert torch.isclose(w.sum() - w[elements.index("Mg")], torch.tensor(0.60, dtype=w.dtype), atol=1e-4) + + +def test_optimize_composition_element_step_scale_lock_requires_initial_weights(): + """A hard lock with random init is rejected (no seed to lock to).""" + model, kernel, _ = _build_aligned_model_and_kernel() + with pytest.raises(ValueError, match="hard lock.*initial_weights"): + model.optimize_composition( + kernel, + task_targets={"prop": 0.0}, + element_step_scale={"Mg": 0.0}, + n_starts=2, + steps=2, + ) + + +def test_optimize_composition_element_step_scale_lock_must_be_allowed(): + """Locking an element that's not in allowed_elements is contradictory and rejected.""" + model, kernel, elements = _build_aligned_model_and_kernel() + init_w = torch.zeros(1, len(elements)) + init_w[0, elements.index("Mg")] = 1.0 + with pytest.raises(ValueError, match="must also be in allowed_elements"): + model.optimize_composition( + kernel, + task_targets={"prop": 0.0}, + initial_weights=init_w, + allowed_elements=["Al", "Cu"], + element_step_scale={"Mg": 0.0}, + steps=2, + ) + + +def test_optimize_composition_element_step_scale_uniform_scalar(): + """A scalar element_step_scale=0 freezes every element at the seed (uniform behaviour).""" + torch.manual_seed(0) + model = _make_reg_clf_model() + kernel = torch.randn(6, INPUT_DIM) + init_w = torch.tensor([[0.2, 0.2, 0.2, 0.2, 0.1, 0.1]]) + res = model.optimize_composition( + kernel, + task_targets={"prop": 5.0}, + initial_weights=init_w, + element_step_scale=0.0, # everything frozen + seed_blend=1.0, # strict seed → no uniform mixing, so w should match init_w exactly + steps=30, + lr=0.5, + ) + # With every element frozen and equal seed proportions kept, w should match init_w (normalised). + assert torch.allclose(res.optimized_weights, init_w, atol=1e-5) + + +def test_optimize_composition_element_step_scale_validation(): + model, kernel, _ = _build_aligned_model_and_kernel() + # Negative scalar rejected. + with pytest.raises(ValueError, match=">= 0"): + model.optimize_composition(kernel, task_targets={"prop": 0.0}, element_step_scale=-0.5, steps=2) + # Unknown symbol rejected. + with pytest.raises(ValueError, match="Unknown element symbol"): + model.optimize_composition( + kernel, task_targets={"prop": 0.0}, element_step_scale={"Mg": 0.5, "NotAnElement": 0.0}, steps=2 + ) + # Negative value in mapping rejected. + with pytest.raises(ValueError, match="values must be >= 0"): + model.optimize_composition( + kernel, task_targets={"prop": 0.0}, element_step_scale={"Mg": 0.5, "Al": -0.1}, steps=2 + ) + # Wrong type rejected. + with pytest.raises(TypeError, match="non-negative float or a mapping"): + model.optimize_composition( + kernel, + task_targets={"prop": 0.0}, + element_step_scale=[1.0, 1.0], + steps=2, # type: ignore[arg-type] + ) + # Symbol dict with a non-aligned kernel rejected. + small_kernel = torch.randn(6, INPUT_DIM) + with pytest.raises(ValueError, match="align with DEFAULT_ELEMENTS"): + model.optimize_composition( + small_kernel, task_targets={"prop": 0.0}, element_step_scale={"Mg": 0.0}, n_starts=2, steps=2 + ) + + +def test_optimize_composition_seed_blend_validates_range(): + """seed_blend must be in [0, 1].""" + model, kernel, elements = _build_aligned_model_and_kernel() + w = torch.zeros(1, len(elements)) + w[0, 0] = 1.0 + with pytest.raises(ValueError, match=r"seed_blend must be in \[0, 1\]"): + model.optimize_composition(kernel, initial_weights=w, task_targets={"prop": 0.0}, seed_blend=-0.1, steps=2) + with pytest.raises(ValueError, match=r"seed_blend must be in \[0, 1\]"): + model.optimize_composition(kernel, initial_weights=w, task_targets={"prop": 0.0}, seed_blend=1.5, steps=2) + + +def test_optimize_composition_seed_blend_strict_freezes_support_set(): + """seed_blend=1.0 reproduces the old strict-seed behaviour: non-seed elements stay ~0.""" + torch.manual_seed(0) + model, kernel, elements = _build_aligned_model_and_kernel() + + # Seed places all mass on Mg + Al; with seed_blend=1.0 every other element starts at logit + # log(1e-12) ≈ −27.6 and can't escape in a handful of steps. + init_w = torch.zeros(1, len(elements)) + init_w[0, elements.index("Mg")] = 0.6 + init_w[0, elements.index("Al")] = 0.4 + + res = model.optimize_composition( + kernel, + initial_weights=init_w, + task_targets={"prop": 5.0}, + seed_blend=1.0, + steps=40, + lr=0.1, + ) + w = res.optimized_weights[0] + seed_mass = w[elements.index("Mg")] + w[elements.index("Al")] + # Strict seed: non-seed elements never recruited — essentially all mass stays on Mg+Al. + assert seed_mass > 0.999 + + +def test_optimize_composition_seed_blend_allows_new_elements(): + """seed_blend<1.0 lifts non-seed logits enough that Adam can recruit new elements.""" + torch.manual_seed(0) + model, kernel, elements = _build_aligned_model_and_kernel() + + init_w = torch.zeros(1, len(elements)) + init_w[0, elements.index("Mg")] = 0.6 + init_w[0, elements.index("Al")] = 0.4 + + res = model.optimize_composition( + kernel, + initial_weights=init_w, + task_targets={"prop": 5.0}, + seed_blend=0.5, # heavy blend so the test is robust to model init + steps=80, + lr=0.2, + ) + w = res.optimized_weights[0] + non_seed = sum(w[i].item() for i, s in enumerate(elements) if s not in {"Mg", "Al"}) + # Some non-seed mass should accumulate (the toy model has no specific preference, so we + # only require the floor to be measurably above zero — the strict-seed test above shows + # the same setup gives ~0 when seed_blend=1.0). + assert non_seed > 0.05 + + +def test_optimize_composition_random_init_uses_n_starts(): + """initial_weights=None falls back to n_starts random simplex points; allowed_elements still binds.""" + torch.manual_seed(0) + model, kernel, elements = _build_aligned_model_and_kernel() + allowed = ["Mg", "Al", "Cu", "Ni"] + res = model.optimize_composition( + kernel, + task_targets={"prop": 1.0}, + n_starts=5, + allowed_elements=allowed, + steps=5, + ) + assert res.optimized_weights.shape == (5, len(elements)) + # Disallowed elements stay at exactly zero (mask is applied at every step). + disallowed = [i for i, s in enumerate(elements) if s not in allowed] + assert torch.allclose(res.optimized_weights[:, disallowed], torch.zeros_like(res.optimized_weights[:, disallowed])) + + +def test_optimize_composition_diversity_scale_validates_range(): + """diversity_scale lives in [0, 1] — out-of-range values are rejected.""" + model, kernel, _ = _build_aligned_model_and_kernel() + with pytest.raises(ValueError, match=r"diversity_scale must be in \[0, 1\]"): + model.optimize_composition(kernel, task_targets={"prop": 0.0}, diversity_scale=-0.1, n_starts=2, steps=2) + with pytest.raises(ValueError, match=r"diversity_scale must be in \[0, 1\]"): + model.optimize_composition(kernel, task_targets={"prop": 0.0}, diversity_scale=1.5, n_starts=2, steps=2) + + +def test_optimize_composition_diversity_scale_endpoints_run(): + """Both endpoints (0 = max penalty, 1 = no penalty default) run cleanly and stay on the simplex.""" + torch.manual_seed(0) + model, kernel, _ = _build_aligned_model_and_kernel() + for scale in (0.0, 0.5, 1.0): + res = model.optimize_composition(kernel, task_targets={"prop": 1.0}, n_starts=3, diversity_scale=scale, steps=5) + assert res.optimized_weights.shape[0] == 3 + assert torch.allclose(res.optimized_weights.sum(dim=-1), torch.ones(3), atol=1e-5) + + +def test_optimize_composition_diversity_scale_direction(): + """diversity_scale=1 (no penalty) keeps a higher per-output entropy than diversity_scale=0 (max penalty).""" + torch.manual_seed(0) + model, kernel, _ = _build_aligned_model_and_kernel() + res_peaky = model.optimize_composition( + kernel, task_targets={"prop": 1.0}, n_starts=4, diversity_scale=0.0, steps=60, lr=0.2 + ) + torch.manual_seed(0) + res_spread = model.optimize_composition( + kernel, task_targets={"prop": 1.0}, n_starts=4, diversity_scale=1.0, steps=60, lr=0.2 + ) + + def _mean_entropy(w): + return float(-(w * w.clamp(min=1e-12).log()).sum(dim=-1).mean()) + + assert _mean_entropy(res_spread.optimized_weights) > _mean_entropy(res_peaky.optimized_weights) + + def test_optimize_composition_uses_kmd_kernel_torch(): """End-to-end: a real KMD's kernel_torch flows into optimize_composition.""" from foundation_model.utils.kmd_plus import KMD diff --git a/src/foundation_model/models/model_config.py b/src/foundation_model/models/model_config.py index 2a85c3e..582d271 100644 --- a/src/foundation_model/models/model_config.py +++ b/src/foundation_model/models/model_config.py @@ -76,7 +76,7 @@ class TransformerEncoderConfig(BaseEncoderConfig): """Configuration for the transformer foundation encoder. ``use_cls_token`` determines how the encoder aggregates feature tokens - before passing them into the deposit layer: enabling it selects the + before the model-level ``tanh`` and the task heads: enabling it selects the contextualised ``[CLS]`` embedding, while disabling it applies mean pooling over all tokens. In both cases gradients still reach every feature token via the self-attention blocks. @@ -278,6 +278,9 @@ class ClassificationTaskConfig(BaseTaskConfig): type: TaskType = TaskType.CLASSIFICATION # Overrides Base.type, provides default, remains positional norm: bool = True # New positional argument with default residual: bool = False # New positional argument with default + # Optional per-class weights for the cross-entropy loss (length == num_classes). Use to + # counter class imbalance so a dominant class doesn't collapse predictions onto itself. + class_weights: Optional[List[float]] = field(default=None, kw_only=True) @dataclass diff --git a/src/foundation_model/models/task_head/classification.py b/src/foundation_model/models/task_head/classification.py index d295911..4bf02fd 100644 --- a/src/foundation_model/models/task_head/classification.py +++ b/src/foundation_model/models/task_head/classification.py @@ -61,6 +61,24 @@ def __init__(self, config: ClassificationTaskConfig): # Changed signature self.num_classes = num_classes + # Per-class loss weights. We **always** register a real tensor buffer so the state_dict + # key ``class_weights`` is present regardless of whether per-class weights were configured + # — without this, a checkpoint saved with a configured head couldn't strict-load into one + # built without weights (or vice versa). When weights aren't configured we register ones, + # which is the identity for both ``F.cross_entropy(..., weight=w)`` and the per-sample + # reduction below, so the unweighted behaviour is unchanged. + class_weights = getattr(config, "class_weights", None) + if class_weights is not None: + weights = torch.as_tensor(class_weights, dtype=torch.float) + if weights.numel() != num_classes: + raise ValueError(f"class_weights length ({weights.numel()}) must equal num_classes ({num_classes}).") + else: + weights = torch.ones(num_classes, dtype=torch.float) + self.register_buffer("class_weights", weights) + # Keep a flag so callers / code paths that branch on "did the user actually pass weights?" + # don't have to compare against ones. Not part of state_dict. + self._has_class_weights = class_weights is not None + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: """ Forward pass of the classification head. @@ -147,7 +165,9 @@ def compute_loss( # 4. Individual sample losses # Use mask mechanism only (no ignore_index) for unified missing data handling # Missing data placeholders (-100) in targets won't affect loss due to mask filtering - losses = F.cross_entropy(pred, final_target_for_loss, reduction="none") # losses is (B,) + losses = F.cross_entropy( + pred, final_target_for_loss, weight=self.class_weights, reduction="none" + ) # losses is (B,) masked_losses = losses * mask_1d # Apply 1D mask, result is (B,) # 5. Total loss - simple division without defensive clamp diff --git a/src/foundation_model/models/task_head/classification_test.py b/src/foundation_model/models/task_head/classification_test.py new file mode 100644 index 0000000..7bc5976 --- /dev/null +++ b/src/foundation_model/models/task_head/classification_test.py @@ -0,0 +1,59 @@ +# Copyright 2025 TsumiNa. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import torch.nn.functional as F + +from foundation_model.models.model_config import ClassificationTaskConfig +from foundation_model.models.task_head.classification import ClassificationHead + + +def _head(class_weights=None, num_classes=3): + cfg = ClassificationTaskConfig( + name="cls", data_column="cls", dims=[8, 16, num_classes], num_classes=num_classes, class_weights=class_weights + ) + return ClassificationHead(cfg) + + +def test_class_weights_none_matches_unweighted_cross_entropy(): + head = _head(class_weights=None) + pred = torch.randn(5, 3) + target = torch.tensor([0, 1, 2, 1, 0]) + loss = head.compute_loss(pred, target) + expected = F.cross_entropy(pred, target) + assert torch.allclose(loss, expected) + + +def test_class_weights_applied_in_loss(): + weights = [1.0, 5.0, 0.2] + head = _head(class_weights=weights) + pred = torch.randn(5, 3) + target = torch.tensor([0, 1, 2, 1, 0]) + loss = head.compute_loss(pred, target) + # The head averages weighted per-sample losses by sample count (its masking convention), + # not by the sum of weights (F.cross_entropy's default "mean"). + per_sample = F.cross_entropy(pred, target, weight=torch.tensor(weights), reduction="none") + expected = per_sample.sum() / target.numel() + assert torch.allclose(loss, expected) + # The weights buffer follows the module (saved/moved with it). + assert "class_weights" in dict(head.named_buffers()) + + +def test_class_weights_length_must_match_num_classes(): + with pytest.raises(ValueError, match="class_weights length"): + _head(class_weights=[1.0, 2.0], num_classes=3) + + +def test_class_weights_state_dict_key_present_when_unset(): + """Whether class_weights is configured or not, the ``class_weights`` buffer key must exist + in the state_dict — so a checkpoint saved with weights can strict-load into a head built + without them (and vice versa). Without ``register_buffer("class_weights", None)`` the key + only appears when weights are set, which breaks cross-config checkpoint compatibility.""" + head_unweighted = _head(class_weights=None) + head_weighted = _head(class_weights=[1.0, 2.0, 0.5]) + assert "class_weights" in head_unweighted.state_dict() + assert "class_weights" in head_weighted.state_dict() + # And strict-loading across configs works in both directions (the missing/present None case). + head_unweighted.load_state_dict(head_weighted.state_dict(), strict=True) + head_weighted.load_state_dict(head_unweighted.state_dict(), strict=True) diff --git a/src/foundation_model/scripts/continual_rehearsal_common.py b/src/foundation_model/scripts/continual_rehearsal_common.py new file mode 100644 index 0000000..b2dc4ae --- /dev/null +++ b/src/foundation_model/scripts/continual_rehearsal_common.py @@ -0,0 +1,415 @@ +# Copyright 2025 TsumiNa. +# SPDX-License-Identifier: Apache-2.0 + +""" +Shared evaluation-dump + plotting helpers used by both continual-rehearsal runners. + +:mod:`continual_rehearsal_demo` (educational, single scenario) and +:mod:`continual_rehearsal_full` (formal, three scenarios) previously carried near-identical +copies of these functions as bound methods on their respective ``Runner`` classes. The +duplication caused at least one drift incident (the ``_plot_kr_sequences`` ``NameError`` +on empty ``comps`` was fixed in ``full`` first; ``demo`` carried the broken copy for +several PRs). Centralising the pure helpers here prevents future drift. + +What's in scope here: + +* **Constants** the two runners share (the project's plot palette and the merged + material_type 3-class ordering). +* **Pure dumpers** — `(composition, true, pred)` parquet + per-task ``_metrics.json`` + emitted at every step. No model / runner state needed. +* **Pure plotters** — parity scatter, confusion matrix, kernel-regression sequences. + Each takes a per-task ``title`` argument so the runner-specific task display vocabulary + (``TASK_DISPLAY`` / ``_title()`` / ``_display()``) stays in its home file. + +What's NOT in scope here: + +* Anything that needs ``Runner`` state (data caches, ``TASK_SPECS``, model parameters). +* The forgetting trajectory plot (uses per-runner ``_task_colors``). +* The inverse-design plotters (different per runner — single-scenario vs eight-path). +""" + +from __future__ import annotations + +import json +import re +from collections import Counter +from pathlib import Path +from typing import Any + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from loguru import logger +from sklearn.metrics import r2_score # type: ignore[import-untyped] + +# --- Shared constants ---------------------------------------------------------------------- + +#: Single blue used for every regression parity scatter and KR-prediction line — keeps the +#: meaning of "predicted vs ideal" colour-consistent across regression and kernel-regression +#: panels. PR #18 settled on this exact tone. +SCATTER_COLOR = "#2563EB" + +#: Orange used to highlight inverse-design "discovered" elements (element symbols that appear in +#: at least one optimised composition but not in any of the seeds). Used both by the element- +#: frequency heatmap (x-tick label) and by the seed-vs-optimised mapping plot (legend / arrow tip) +#: so the colour meaning is consistent across figures. +DISCOVERED_ELEMENT_COLOR = "#E67E22" + +#: The merged material_type label set (5 fine classes → 3). The order here is the *canonical* +#: index order (so ``MATERIAL_TYPE_CLASSES[0] == "AC"`` means merged class 0 is AC, etc.). +MATERIAL_TYPE_CLASSES: tuple[str, ...] = ("AC", "QC", "others") + +#: Display order for the confusion-matrix axes. Bottom-left → top-right diagonal places the +#: minority QC class in the upper-right corner, mirroring the canonical "others → AC → QC" +#: progression the project standardised on in PR #18. +MATERIAL_TYPE_DISPLAY_ORDER: tuple[str, ...] = ("others", "AC", "QC") + +#: Element-symbol regex: capital + optional lowercase, paired with an optional stoichiometry +#: suffix. Same pattern both runners' seed parsing uses, kept centrally so the heatmap and any +#: future seed-parsing utility can't drift. +_COMP_RE = re.compile(r"([A-Z][a-z]?)([\d.]*)") + + +def element_set(formula: str) -> frozenset[str]: + """Set of element symbols in a composition string, ignoring stoichiometry. + + Handles both human-friendly forms (``"Mg2 Zn1 Y1"``) and the project's KMD-decoded form + (``"Al0.473 Cu0.130 Fe0.109 …"``). Whitespace is irrelevant. + """ + return frozenset(el for el, _ in _COMP_RE.findall(formula) if el) + + +# --- Pure dumpers -------------------------------------------------------------------------- + + +def dump_predictions( + task_name: str, + step_dir: Path, + *, + comps: list[str], + true: np.ndarray, + pred: np.ndarray, +) -> None: + """Persist ``(composition, true, pred)`` for a regression or classification task as parquet. + + Single row per test sample. The trio is enough for downstream re-plotting (parity scatter + for regression, confusion matrix for classification) without re-running the model. + """ + pd.DataFrame({"composition": comps, "true": true, "pred": pred}).to_parquet(step_dir / f"{task_name}_pred.parquet") + + +def dump_kr_predictions( + task_name: str, + step_dir: Path, + *, + comps: list[str], + t_list: list[np.ndarray], + true_parts: list[np.ndarray], + pred: np.ndarray, +) -> None: + """Persist kernel-regression test predictions in long form: one row per ``(composition, t)``. + + The flat ``pred`` array carries every composition's values back-to-back; we re-split it + using each composition's ``true_parts`` length so the long-form table is fully reconstructible. + """ + rows: list[dict[str, object]] = [] + offset = 0 + for comp, t_arr, y_true in zip(comps, t_list, true_parts): + n = int(y_true.size) + for k in range(n): + rows.append( + { + "composition": comp, + "t": float(t_arr[k]), + "true": float(y_true[k]), + "pred": float(pred[offset + k]), + } + ) + offset += n + pd.DataFrame(rows).to_parquet(step_dir / f"{task_name}_pred.parquet") + + +def dump_metrics(task_name: str, step_dir: Path, metric: dict[str, float]) -> None: + """Drop the per-task metric dict next to the parquet, for quick human / scripted inspection.""" + (step_dir / f"{task_name}_metrics.json").write_text(json.dumps(metric, indent=2), encoding="utf-8") + + +# --- Pure plotters ------------------------------------------------------------------------- + + +def plot_parity( + true: np.ndarray, + pred: np.ndarray, + task_name: str, + r2: float, + step_dir: Path, + *, + title: str, +) -> None: + """Regression parity scatter (true vs predicted) with ideal-line and an R² annotation.""" + fig, ax = plt.subplots(figsize=(5, 5)) + # Uniform colour/alpha for every regression parity scatter — set in PR #18. + ax.scatter(true, pred, s=14, alpha=0.55, color=SCATTER_COLOR, edgecolor="none") + lo, hi = float(min(true.min(), pred.min())), float(max(true.max(), pred.max())) + ax.plot([lo, hi], [lo, hi], color="#444444", ls="--", lw=1.2, label="ideal") + ax.set_xlabel("True") + ax.set_ylabel("Predicted") + ax.set_title(title) + ax.text( + 0.04, + 0.96, + f"R² = {r2:.3f}\nn = {len(true)}", + transform=ax.transAxes, + ha="left", + va="top", + fontsize=10, + bbox=dict(boxstyle="round,pad=0.4", facecolor="white", edgecolor="#d0d0d0", alpha=0.9), + ) + ax.legend(loc="lower right") + fig.savefig(step_dir / f"{task_name}_parity.png") + plt.close(fig) + + +def plot_confusion( + true: np.ndarray, + pred: np.ndarray, + task_name: str, + acc: float, + step_dir: Path, + num_classes: int, + *, + title: str, + special_material_type: bool = False, +) -> None: + """Row-normalised confusion matrix. + + When ``special_material_type`` is set (the merged 3-class material_type task), axes are + reordered to ``MATERIAL_TYPE_DISPLAY_ORDER`` so the recall diagonal runs bottom-left → + top-right with the minority QC class in the upper-right corner. + """ + counts = np.zeros((num_classes, num_classes), dtype=int) + for t, p in zip(true, pred): + if 0 <= t < num_classes and 0 <= p < num_classes: + counts[t, p] += 1 + # Display order + bottom-left origin (PR #18 standardisation). + if special_material_type: + labels = list(MATERIAL_TYPE_DISPLAY_ORDER[:num_classes]) + perm = [MATERIAL_TYPE_CLASSES.index(lbl) for lbl in labels] + else: + labels = [str(i) for i in range(num_classes)] + perm = list(range(num_classes)) + counts = counts[np.ix_(perm, perm)] + # Colour by row-normalised fraction (recall) so a dominant class doesn't leave every other + # row invisible. Annotate each cell with both the fraction and the raw count. + row_sums = counts.sum(axis=1, keepdims=True) + row_frac = np.divide(counts, row_sums, out=np.zeros(counts.shape, dtype=float), where=row_sums > 0) + fig, ax = plt.subplots(figsize=(5.6, 5.2)) + im = ax.imshow(row_frac, cmap="Blues", vmin=0.0, vmax=1.0, origin="lower") + fig.colorbar(im, ax=ax, label="row-normalized fraction (recall)", fraction=0.046, pad=0.04) + ax.set_xticks(range(num_classes), labels, rotation=45, ha="right") + ax.set_yticks(range(num_classes), labels) + for i in range(num_classes): + for j in range(num_classes): + if counts[i, j]: + ax.text( + j, + i, + f"{row_frac[i, j] * 100:.0f}%\n{counts[i, j]}", + ha="center", + va="center", + fontsize=8, + color="white" if row_frac[i, j] > 0.5 else "#333333", + ) + ax.grid(False) + ax.set_xlabel("Predicted") + ax.set_ylabel("True") + ax.set_title(title) + ax.text( + 0.5, + -0.22, + f"accuracy = {acc:.3f} · n = {int(counts.sum())}", + transform=ax.transAxes, + ha="center", + va="top", + fontsize=10, + ) + fig.savefig(step_dir / f"{task_name}_confusion.png") + plt.close(fig) + + +def plot_kr_sequences( + comps: list[str], + t_list: list, # list of torch.Tensor — kept as Any to avoid importing torch here + true_parts: list[np.ndarray], + pred: np.ndarray, + task_name: str, + step_dir: Path, + *, + title: str, +) -> None: + """Per-composition KR sequence panels — up to 3 panels, each with its own R² annotation. + + Empty ``comps`` (no test samples for the task at this step — possible on very small KR + datasets like ``magnetic_susceptibility``) used to silently break here: ``min(3, 0) == 0`` + skipped the loop, then ``fig.legend([line_true, line_pred], …)`` raised ``NameError`` on + unbound names. Now we short-circuit with a warning and return without writing a PNG. + """ + k = min(3, len(comps)) + if k == 0: + logger.warning(f"plot_kr_sequences: no compositions for '{task_name}' — skipping plot.") + return + fig, axes = plt.subplots(1, k, figsize=(4.2 * k, 3.7), squeeze=False) + offset = 0 + line_true = line_pred = None + for i in range(k): + ax = axes[0][i] + n = true_parts[i].size + t = t_list[i].cpu().numpy() + true_i = np.asarray(true_parts[i]) + pred_i = pred[offset : offset + n] + order = np.argsort(t) # left-to-right curve + (line_true,) = ax.plot(t[order], true_i[order], color="#444444", lw=1.8, label="True") + # Same blue as the regression parity scatter — keeps "Predicted" colour consistent + # across regression / kernel-regression panels. + (line_pred,) = ax.plot(t[order], pred_i[order], color=SCATTER_COLOR, lw=1.6, ls="--", label="Predicted") + ax.set_xlabel("t") + if i == 0: + ax.set_ylabel("Value") + r2_i = float(r2_score(true_i, pred_i)) if n >= 2 and float(np.var(true_i)) > 0 else float("nan") + ax.text( + 0.96, + 0.96, + f"R² = {r2_i:.3f}", + transform=ax.transAxes, + ha="right", + va="top", + fontsize=9, + bbox=dict(boxstyle="round,pad=0.4", facecolor="white", edgecolor="#d0d0d0", alpha=0.9), + ) + ax.set_title(comps[i], fontsize=9) + offset += n + if line_true is not None and line_pred is not None: + fig.legend( + [line_true, line_pred], + ["True", "Predicted"], + loc="lower left", + ncol=2, + bbox_to_anchor=(0.0, 1.10), + bbox_transform=axes[0][0].transAxes, + ) + fig.suptitle(title, y=1.24) + fig.savefig(step_dir / f"{task_name}_sequences.png") + plt.close(fig) + + +def plot_element_frequency_heatmap( + methods: list[dict[str, Any]], + seeds: list[str], + out_path: Path, + *, + top_k: int = 25, +) -> None: + """Per-method × top-K-element occurrence heatmap. + + For each method we count how many of its decoded recipes contain each element (i.e. its + symbol appears anywhere in the formatted ``decoded_composition`` string). The top ``top_k`` + elements globally are shown as columns; methods are rows. Elements absent from every seed + in ``seeds`` are highlighted on the x-axis as **bold orange** — the inverse-design + *element-discovery* signal. Underline is omitted (visually noisy under tight rotated + labels); bold + a distinct colour is enough. + + Parameters + ---------- + methods + One dict per row of the heatmap. Each dict must carry: + + * ``label`` — the human-readable y-tick label + (e.g. ``"latent α=0.25"``, ``"comp (seed, 5% all)"``). ``\\n`` is collapsed to a + single space so labels stay on one line. + * ``decoded_composition`` — list of formatted composition strings (``"Al0.473 Cu0.13 …"`` + or ``"Mg2 Zn1 Y1"``); typically one entry per seed for that method. + seeds + The seed composition list used by every method. Drives the "in any seed?" check that + marks elements as discovered (no underline; just bold + ``DISCOVERED_ELEMENT_COLOR``). + out_path + Full path to the PNG to write. + top_k + How many columns (elements) to show, ranked globally across methods (default 25). + """ + n = len(methods) + if n == 0: + logger.warning("plot_element_frequency_heatmap: no methods supplied; skipping.") + return + labels = [m["label"].replace("\n", " ") for m in methods] + + # Seed element multiplicity — used to decide which elements are "discovered" (0 in seeds). + seed_cnt: Counter[str] = Counter() + for s in seeds: + for el in element_set(s): + seed_cnt[el] += 1 + + # Per-method element-presence counts. + per_method: list[Counter[str]] = [] + for m in methods: + c: Counter[str] = Counter() + for d in m.get("decoded_composition", []) or []: + for el in element_set(d): + c[el] += 1 + per_method.append(c) + + # Globally top elements (rank by sum-of-top-8-per-method so single-method blow-ups don't + # dominate — matches the ranking the standalone post-hoc script used). + global_cnt: Counter[str] = Counter() + for c in per_method: + for el, k in c.most_common(8): + global_cnt[el] += k + top_elems = [e for e, _ in global_cnt.most_common(top_k)] + if not top_elems: + logger.warning(f"plot_element_frequency_heatmap: no elements found across methods; skipping {out_path}.") + return + + n_per_method = len(methods[0].get("decoded_composition") or []) or len(seeds) or 1 + mat = np.zeros((n, len(top_elems)), dtype=int) + for i, c in enumerate(per_method): + for j, el in enumerate(top_elems): + mat[i, j] = c[el] + + fig, ax = plt.subplots(figsize=(13, 6)) + im = ax.imshow(mat, aspect="auto", cmap="Blues", vmin=0, vmax=n_per_method) + ax.set_xticks(range(len(top_elems))) + ax.set_xticklabels(top_elems, fontsize=11) + ax.set_yticks(range(n)) + ax.set_yticklabels(labels, fontsize=9) + ax.set_title( + f"Element appearance counts per method (top {len(top_elems)})\n" + f"Bold orange element symbols = NOT in any of the {len(seeds)} seeds (introduced by the optimiser)", + fontsize=11, + pad=12, + ) + # Bold + orange for discovered elements; everything else stays in the default style. + for tick_label, el in zip(ax.get_xticklabels(), top_elems): + if seed_cnt[el] == 0: + tick_label.set_fontweight("bold") + tick_label.set_color(DISCOVERED_ELEMENT_COLOR) + # Cell annotations. + for i in range(n): + for j in range(len(top_elems)): + if mat[i, j]: + ax.text( + j, + i, + str(mat[i, j]), + ha="center", + va="center", + fontsize=8, + color="white" if mat[i, j] > n_per_method * 0.5 else "#333", + ) + cbar = fig.colorbar(im, ax=ax, fraction=0.03, pad=0.01) + cbar.set_label(f"appearance count (out of {n_per_method} outputs)") + # The shared plot style sets ``axes.grid = True`` globally, which on an ``imshow`` heatmap + # draws grid lines through every cell centre (major ticks coincide with cell centres). Turn + # the grid off here so the cells stay clean. + ax.grid(False) + fig.tight_layout() + fig.savefig(out_path, dpi=150, bbox_inches="tight") + plt.close(fig) diff --git a/src/foundation_model/scripts/continual_rehearsal_common_test.py b/src/foundation_model/scripts/continual_rehearsal_common_test.py new file mode 100644 index 0000000..cabbfde --- /dev/null +++ b/src/foundation_model/scripts/continual_rehearsal_common_test.py @@ -0,0 +1,177 @@ +# Copyright 2025 TsumiNa. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the shared dump / plot helpers in :mod:`continual_rehearsal_common`. + +The runners are end-to-end-tested via smoke runs; these tests pin the pure-function behaviour +of the helpers they share, including the edge cases that motivated factoring them out. +""" + +from __future__ import annotations + +import json + +import numpy as np +import pandas as pd + +from foundation_model.scripts.continual_rehearsal_common import ( + MATERIAL_TYPE_CLASSES, + MATERIAL_TYPE_DISPLAY_ORDER, + SCATTER_COLOR, + dump_kr_predictions, + dump_metrics, + dump_predictions, + plot_confusion, + plot_kr_sequences, + plot_parity, +) + + +# --- shared constants --- + + +def test_scatter_colour_is_the_project_blue(): + """The blue must match the project palette; ``paper_inverse_comparison`` and slide deck + reference this exact hex. Changing it without coordinating breaks the slide colour story.""" + assert SCATTER_COLOR == "#2563EB" + + +def test_material_type_canonical_and_display_orders_are_consistent_three_classes(): + assert sorted(MATERIAL_TYPE_CLASSES) == sorted(MATERIAL_TYPE_DISPLAY_ORDER) + assert len(MATERIAL_TYPE_CLASSES) == 3 + + +# --- dumpers --- + + +def test_dump_predictions_writes_parquet_with_expected_columns(tmp_path): + step_dir = tmp_path / "step01_density" + step_dir.mkdir() + dump_predictions( + "density", + step_dir, + comps=["Mg1 Cu1", "Al1 Fe1"], + true=np.array([0.1, 0.2]), + pred=np.array([0.15, 0.18]), + ) + out = step_dir / "density_pred.parquet" + assert out.exists() + df = pd.read_parquet(out) + assert list(df.columns) == ["composition", "true", "pred"] + assert len(df) == 2 + assert df["composition"].tolist() == ["Mg1 Cu1", "Al1 Fe1"] + + +def test_dump_kr_predictions_long_form_round_trips(tmp_path): + """KR predictions are stored long-form (one row per (composition, t)); the flat ``pred`` + array is correctly re-split using ``true_parts`` lengths.""" + step_dir = tmp_path / "step08_dos_density" + step_dir.mkdir() + comps = ["Mg1 Cu1", "Al1 Fe1"] + t_list = [np.array([0.0, 1.0]), np.array([0.0, 1.0, 2.0])] # different lengths per comp + true_parts = [np.array([10.0, 11.0]), np.array([20.0, 21.0, 22.0])] + pred = np.array([10.5, 11.5, 20.5, 21.5, 22.5]) # flat across both comps + + dump_kr_predictions("dos_density", step_dir, comps=comps, t_list=t_list, true_parts=true_parts, pred=pred) + out = step_dir / "dos_density_pred.parquet" + assert out.exists() + df = pd.read_parquet(out) + # 2 + 3 = 5 rows total + assert len(df) == 5 + # Per-composition slices recover the right pred values from the long-form table. + mg = df[df["composition"] == "Mg1 Cu1"].sort_values("t") + assert mg["pred"].tolist() == [10.5, 11.5] + al = df[df["composition"] == "Al1 Fe1"].sort_values("t") + assert al["pred"].tolist() == [20.5, 21.5, 22.5] + + +def test_dump_metrics_writes_indented_json(tmp_path): + step_dir = tmp_path / "step01_density" + step_dir.mkdir() + dump_metrics("density", step_dir, {"r2": 0.95, "mae": 0.1, "samples": 100, "primary": 0.95}) + out = step_dir / "density_metrics.json" + assert out.exists() + body = json.loads(out.read_text()) + assert body == {"r2": 0.95, "mae": 0.1, "samples": 100, "primary": 0.95} + + +# --- plots --- + + +def test_plot_parity_writes_png(tmp_path): + step_dir = tmp_path / "step01_density" + step_dir.mkdir() + true = np.linspace(0.0, 1.0, 50) + pred = true + np.random.default_rng(0).normal(0, 0.05, 50) + plot_parity(true, pred, "density", r2=0.95, step_dir=step_dir, title="Density (normalized)") + assert (step_dir / "density_parity.png").exists() + + +def test_plot_confusion_writes_png_for_generic_and_material_type(tmp_path): + step_dir = tmp_path / "step11_material_type" + step_dir.mkdir() + rng = np.random.default_rng(0) + true = rng.integers(0, 3, size=100) + pred = rng.integers(0, 3, size=100) + + plot_confusion( + true, + pred, + "material_type", + acc=0.5, + step_dir=step_dir, + num_classes=3, + title="Material type", + special_material_type=True, + ) + assert (step_dir / "material_type_confusion.png").exists() + + plot_confusion( + true, + pred, + "another_clf", + acc=0.5, + step_dir=step_dir, + num_classes=3, + title="Another classifier", + special_material_type=False, + ) + assert (step_dir / "another_clf_confusion.png").exists() + + +def test_plot_kr_sequences_returns_silently_on_empty_comps(tmp_path): + """The PR #18 regression that motivated the refactor: empty ``comps`` used to crash with + ``NameError`` inside ``fig.legend``. Now it returns early without writing anything.""" + step_dir = tmp_path / "step08_dos_density" + step_dir.mkdir() + plot_kr_sequences( + comps=[], + t_list=[], + true_parts=[], + pred=np.array([]), + task_name="dos_density", + step_dir=step_dir, + title="DOS density", + ) + assert not (step_dir / "dos_density_sequences.png").exists() + + +def test_plot_kr_sequences_renders_panels_when_data_present(tmp_path): + """Single-composition smoke: a sequence panel is rendered without raising.""" + import torch + + step_dir = tmp_path / "step08_dos_density" + step_dir.mkdir() + t = torch.linspace(0.0, 1.0, 8) + true_part = np.linspace(0.0, 1.0, 8) + pred = np.linspace(0.05, 0.95, 8) + plot_kr_sequences( + comps=["Mg1 Cu1"], + t_list=[t], + true_parts=[true_part], + pred=pred, + task_name="dos_density", + step_dir=step_dir, + title="DOS density", + ) + assert (step_dir / "dos_density_sequences.png").exists() diff --git a/src/foundation_model/scripts/continual_rehearsal_demo.py b/src/foundation_model/scripts/continual_rehearsal_demo.py index a777970..12a57ac 100644 --- a/src/foundation_model/scripts/continual_rehearsal_demo.py +++ b/src/foundation_model/scripts/continual_rehearsal_demo.py @@ -23,8 +23,9 @@ Every step evaluates *all* active heads on the fixed test split and plots the new head plus the per-task forgetting trajectory. -After all tasks are learned, an **inverse-design** stage optimizes the latent space toward a -condition (2 regression targets + increased quasicrystal probability) and decodes the optimized +After all tasks are learned, an **inverse-design** stage seeds from the highest-QC training +compositions and optimizes the latent to **raise quasicrystal probability** (primary) with low +formation energy and high lattice thermal conductivity (secondary), then decodes the optimized KMD descriptor back to a composition via ``KMD.inverse``. Run: @@ -38,6 +39,7 @@ import ast import base64 import json +import re from dataclasses import dataclass, field from pathlib import Path from typing import Any @@ -65,6 +67,24 @@ OptimizerConfig, RegressionTaskConfig, ) + +# Shared evaluation / plot helpers, used by both demo and full runners. Live in a sibling module +# so the two runners can't drift again (the ``_plot_kr_sequences`` ``NameError`` regression that +# motivated this refactor only existed because demo and full each carried their own copy). +# ``MATERIAL_TYPE_CLASSES`` / ``MATERIAL_TYPE_DISPLAY_ORDER`` / ``_SCATTER_COLOR`` are re-exported +# from this module for backward compatibility — ``continual_rehearsal_full`` and other callers +# previously did ``from continual_rehearsal_demo import _SCATTER_COLOR``. +from foundation_model.scripts.continual_rehearsal_common import ( # noqa: F401 (re-exports) + MATERIAL_TYPE_CLASSES, + MATERIAL_TYPE_DISPLAY_ORDER, + SCATTER_COLOR as _SCATTER_COLOR, + dump_kr_predictions, + dump_metrics, + dump_predictions, + plot_confusion, + plot_kr_sequences, + plot_parity, +) from foundation_model.utils.kmd_plus import DEFAULT_ELEMENTS, KMD, element_features, formula_to_composition # --- Task catalogue ---------------------------------------------------------- @@ -80,7 +100,7 @@ "column": "Power factor (normalized)", "t_column": "Power factor (T/K)", }, - "material_type": {"source": "qc", "kind": "clf", "column": "Material type (label)", "num_classes": 5}, + "material_type": {"source": "qc", "kind": "clf", "column": "Material type (label)", "num_classes": 3}, "tc": {"source": "superconductor", "kind": "reg", "column": "Transition temperature[K]"}, "pressure": {"source": "superconductor", "kind": "reg", "column": "Pressure[GPa]"}, "curie": {"source": "magnetic", "kind": "reg", "column": "Curie temperature[K]"}, @@ -92,9 +112,110 @@ # Raw (non-qc) regression targets span orders of magnitude (thermal conductivity, magnetization); # they are log1p-compressed, z-scored, then clipped to tame heavy tails. _RAW_TARGET_CLIP = 5.0 -DEFAULT_SEQUENCE = list(TASK_SPECS.keys()) -# Quasicrystal classes for the material_type label encoder (DAC=0, DQC=1, IAC=2, IQC=3, others=4). -QC_CLASSES = [1, 3] +# The first nine tasks may be added in any order; the last three are fixed as +# formation_energy → klat → material_type so the inverse-design heads (and especially the QC +# classifier) are the freshest at the end, when inverse design runs. +DEFAULT_SEQUENCE = [ + "density", + "dos_density", + "power_factor", + "tc", + "pressure", + "curie", + "magnetization", + "neel", + "kp", + "formation_energy", + "klat", + "material_type", +] +# The raw encoder has 5 classes (DAC=0, DQC=1, IAC=2, IQC=3, others=4). They are too imbalanced +# and finely split to learn, so we merge the approximant/quasicrystal pairs into 3 classes: +# AC = DAC + IAC, QC = DQC + IQC, others. (index == merged class id) +_MATERIAL_TYPE_MERGE = {0: 0, 2: 0, 1: 1, 3: 1, 4: 2} +# ``MATERIAL_TYPE_CLASSES`` (canonical index order) and ``MATERIAL_TYPE_DISPLAY_ORDER`` (bottom- +# left → top-right confusion-matrix order) now live in ``continual_rehearsal_common`` and are +# re-exported through this module's import block above so existing ``from … import`` paths +# still work for callers (notably continual_rehearsal_full). +# Quasicrystal class index (merged) used as the inverse-design classification objective. +QC_CLASSES = [1] + +# --- Presentation ------------------------------------------------------------- +# Human-readable, properly capitalized task names for every plot title / axis / table cell. +TASK_DISPLAY: dict[str, str] = { + "density": "Density", + "formation_energy": "Formation Energy", + "dos_density": "DOS Density", + "power_factor": "Power Factor", + "material_type": "Material Type", + "tc": "Critical Temperature (Tc)", + "pressure": "Pressure", + "curie": "Curie Temperature", + "magnetization": "Magnetization", + "neel": "Néel Temperature", + "kp": "Phonon Conductivity (κₚ)", + "klat": "Lattice Conductivity (κ_lat)", +} +# A 12-colour qualitative palette (Seaborn "deep" + extras) so every task keeps one stable colour +# across all figures — no default-cycle collisions when 12 tasks share a legend. +_PALETTE = [ + "#4C72B0", + "#DD8452", + "#55A868", + "#C44E52", + "#8172B3", + "#937860", + "#DA8BC3", + "#8C8C8C", + "#CCB974", + "#64B5CD", + "#E377C2", + "#17BECF", +] +# Single colour for every regression parity scatter (per-task colours stay for the line plots). +# Defined in ``continual_rehearsal_common.SCATTER_COLOR`` and re-exported above as +# ``_SCATTER_COLOR`` so any caller doing ``from continual_rehearsal_demo import _SCATTER_COLOR`` +# (notably continual_rehearsal_full) keeps working. + + +def _display(task: str) -> str: + """Pretty, capitalized task name for plots/tables.""" + return TASK_DISPLAY.get(task, task.replace("_", " ").title()) + + +def _scale_label(task: str) -> str: + """Plotted target scale (every target is preprocessed, so there is no raw physical unit).""" + return "normalized" if TASK_SPECS[task]["source"] == "qc" else "log1p, z-scored" + + +def _title(task: str) -> str: + """Plot title: property name + the scale it is plotted in (metrics go inside the axes).""" + return f"{_display(task)} ({_scale_label(task)})" + + +def _apply_plot_style() -> None: + """One white-background, consistent matplotlib look for every figure in the demo.""" + plt.rcParams.update( + { + "figure.facecolor": "white", + "axes.facecolor": "white", + "savefig.facecolor": "white", + "savefig.bbox": "tight", + "figure.dpi": 130, + "savefig.dpi": 150, + "font.size": 11, + "axes.titlesize": 13, + "axes.titleweight": "semibold", + "axes.labelsize": 11, + "axes.spines.top": False, + "axes.spines.right": False, + "axes.grid": True, + "grid.alpha": 0.25, + "grid.linestyle": "-", + "legend.frameon": False, + "lines.linewidth": 1.6, + } + ) @dataclass @@ -126,12 +247,31 @@ class ContinualRehearsalConfig: kr_lr: float = 5e-4 kr_decay: float = 5e-5 - # Inverse-design stage + # Inverse-design stage: primary objective = raise QC probability; secondary = low formation + # energy + high lattice thermal conductivity. Seeds are the highest-QC training compositions. inverse_n_seeds: int = 16 inverse_steps: int = 300 inverse_lr: float = 0.05 - inverse_reg_tasks: list[str] = field(default_factory=lambda: ["density", "formation_energy"]) - inverse_reg_targets: list[float] = field(default_factory=lambda: [1.5, -1.5]) + inverse_class_weight: float = 5.0 # weight of the QC objective relative to the regression ones + # Cycle-consistency: pulls the optimized latent toward what the AE can faithfully reconstruct, + # so after-decode predictions stay close to in-latent values. 0 = off; 0.1–1.0 typical. + # ae_align_scale for the latent inverse-design path: [0, 1], 0 = no alignment penalty (the + # failure-mode baseline shown in PR #18), 1 = strong alignment, 0.5 = empirical sweet spot. + inverse_ae_align_scale: float = 0.5 + inverse_reg_tasks: list[str] = field(default_factory=lambda: ["formation_energy", "klat"]) + inverse_reg_targets: list[float] = field(default_factory=lambda: [-2.0, 2.0]) # low f.e., high klat + # How the optimization's starting latents are seeded: + # "top_qc" – the inverse_seed_split compositions the model scores highest on QC probability; + # "random" – a random sample from inverse_seed_split; + # "explicit" – the exact compositions listed in inverse_seed_compositions. + inverse_seed_strategy: str = "top_qc" + inverse_seed_split: str = "train" # split to draw seeds from ("train"/"val"/"test"/"all") + inverse_seed_compositions: list[str] = field(default_factory=list) # used when strategy == "explicit" + # Compositions appended to the strategy-selected seeds regardless of QC ranking. Each is + # required to have a computable descriptor (we fail-fast on those that don't). The output + # ``seeds.json`` records the explicit-append entries separately from the strategy-selected + # ones — see ``_select_seeds`` below. + inverse_seed_explicit_append: list[str] = field(default_factory=list) random_seed: int = 2025 datamodule_random_seed: int = 42 @@ -146,6 +286,24 @@ def __post_init__(self) -> None: raise ValueError("replay_ratio must be in [0, 1] (0 = no rehearsal).") if len(self.inverse_reg_tasks) != len(self.inverse_reg_targets): raise ValueError("inverse_reg_tasks and inverse_reg_targets must have equal length.") + if self.inverse_seed_strategy not in {"top_qc", "random", "explicit"}: + raise ValueError("inverse_seed_strategy must be 'top_qc', 'random', or 'explicit'.") + if self.inverse_seed_split not in {"train", "val", "test", "all"}: + raise ValueError("inverse_seed_split must be 'train', 'val', 'test', or 'all'.") + if self.inverse_seed_strategy == "explicit" and not self.inverse_seed_compositions: + raise ValueError("inverse_seed_strategy='explicit' requires inverse_seed_compositions.") + # ``n_seeds <= 0`` is silently broken — ``_select_seeds`` returns only the + # explicit-append entries (sometimes zero of them), and downstream code crashes much + # later with confusing shape errors. Fail loudly at config-load time instead. + if self.inverse_n_seeds <= 0: + raise ValueError(f"inverse_n_seeds must be > 0, got {self.inverse_n_seeds}.") + # ``ae_align_scale ∉ [0, 1]`` would eventually be rejected by ``optimize_latent`` at + # runtime; catching it at the config layer points the user at the TOML, not at a + # backtrace inside the model. + if not 0.0 <= self.inverse_ae_align_scale <= 1.0: + raise ValueError( + f"inverse_ae_align_scale must be in [0, 1], got {self.inverse_ae_align_scale}." + ) def _as_float_array(cell: Any) -> np.ndarray: @@ -178,6 +336,9 @@ def __init__(self, config: ContinualRehearsalConfig): self.config = config self.output_dir = Path(config.output_dir) self.output_dir.mkdir(parents=True, exist_ok=True) + _apply_plot_style() + # Stable colour per task (by position in the configured sequence) across every figure. + self._task_colors = {name: _PALETTE[i % len(_PALETTE)] for i, name in enumerate(config.task_sequence)} # KMD-1d featurizer over the bundled element features (invertible: descriptor -> composition). self._kmd = KMD(element_features.values, method="1d", n_grids=config.n_grids, sigma="auto", scale=True) self.x_dim = int(self._kmd.transform(np.eye(1, len(DEFAULT_ELEMENTS))).shape[1]) @@ -205,7 +366,12 @@ def _load_data(self) -> None: for name, df in sources.items(): df = df.copy() if cfg.sample_per_dataset is not None and cfg.sample_per_dataset < len(df): - df = df.iloc[rng.choice(len(df), size=cfg.sample_per_dataset, replace=False)] + if name == "qc" and "Material type (label)" in df.columns: + # Stratify: keep every minority (non-"others") material-type row so the rare + # AC/QC classes survive the cap, then fill the rest with random "others". + df = self._stratified_qc_sample(df, cfg.sample_per_dataset, rng) + else: + df = df.iloc[rng.choice(len(df), size=cfg.sample_per_dataset, replace=False)] comp_col = "composition" if name != "qc" else "composition" df["__key__"] = [_composition_key(v) for v in df[comp_col]] df = df.dropna(subset=["__key__"]).drop_duplicates(subset="__key__", keep="first").set_index("__key__") @@ -227,6 +393,9 @@ def _load_data(self) -> None: raise KeyError(f"Task '{task_name}': column '{col}' missing in {spec['source']} data.") frame = pd.DataFrame(index=df.index) values = df[col] + if task_name == "material_type": + # Merge the 5 fine labels into AC / QC / others (see _MATERIAL_TYPE_MERGE). + values = values.map(_MATERIAL_TYPE_MERGE) if spec["source"] != "qc" and spec["kind"] == "reg": # log1p compresses the orders-of-magnitude range, then z-score + clip tails. # Scaling stats come from *train* rows only to avoid leaking val/test distribution. @@ -254,6 +423,32 @@ def _load_qc(self) -> pd.DataFrame: df = df.loc[~df.index.isin(dropped)] return df + @staticmethod + def _stratified_qc_sample(df: pd.DataFrame, cap: int, rng: np.random.Generator) -> pd.DataFrame: + """Cap qc rows while keeping every minority (non-"others") material-type row.""" + labels = df["Material type (label)"] + minority = df[labels != 4] # DAC/DQC/IAC/IQC (others == 4) + others = df[labels == 4] + n_others = max(cap - len(minority), 0) + if n_others < len(others): + others = others.iloc[rng.choice(len(others), size=n_others, replace=False)] + out = pd.concat([minority, others]) + if len(out) > cap: # minorities alone exceed the cap (unlikely): subsample uniformly + out = out.iloc[rng.choice(len(out), size=cap, replace=False)] + return out + + def _class_weights(self, task_name: str) -> list[float]: + """Balanced (inverse-frequency) class weights from the train split, so a dominant class + doesn't collapse predictions onto itself.""" + spec = TASK_SPECS[task_name] + frame = self.task_frames[task_name] + num_classes = int(spec["num_classes"]) + train = frame.loc[frame["split"] == "train", spec["column"]].dropna().astype(int) + counts = np.bincount(train, minlength=num_classes).astype(float) + counts[counts == 0] = 1.0 # avoid divide-by-zero for an absent class + weights = counts.sum() / (num_classes * counts) # sklearn "balanced" scheme + return weights.tolist() + def descriptor_fn(self, compositions: list[str]) -> pd.DataFrame: """KMD-1d descriptors for composition keys (computed once per unique key, cached).""" uncached = [c for c in dict.fromkeys(compositions) if c not in self._desc_cache] @@ -297,6 +492,7 @@ def _build_task_config(self, task_name: str): data_column=spec["column"], dims=[ld, hd, 32], num_classes=spec["num_classes"], + class_weights=self._class_weights(task_name), # counter the others-class imbalance optimizer=OptimizerConfig(lr=cfg.head_lr, weight_decay=1e-5), ) train_t = self._collect_train_t(task_name) @@ -327,18 +523,29 @@ def _collect_train_t(self, task_name: str) -> np.ndarray: # ------------------------------------------------------------------ run - def run(self) -> None: + def _build_empty_model(self) -> FlexibleMultiTaskModel: cfg = self.config - seed_everything(cfg.random_seed, workers=True) - encoder_config = MLPEncoderConfig(hidden_dims=[self.x_dim, cfg.encoder_hidden, cfg.latent_dim]) - model = FlexibleMultiTaskModel( + return FlexibleMultiTaskModel( task_configs=[], encoder_config=encoder_config, enable_autoencoder=True, shared_block_optimizer=OptimizerConfig(lr=cfg.encoder_lr, weight_decay=1e-2), ) + def _build_full_model(self) -> FlexibleMultiTaskModel: + """Recreate the final post-training model (all tasks added in order) so a saved + state_dict can be loaded for inverse-only runs.""" + model = self._build_empty_model() + for task_name in self.config.task_sequence: + model.add_task(self._build_task_config(task_name)) + return model + + def run(self) -> None: + cfg = self.config + seed_everything(cfg.random_seed, workers=True) + model = self._build_empty_model() + task_configs: dict[str, Any] = {} metric_history: dict[str, list[tuple[int, float]]] = {name: [] for name in cfg.task_sequence} records: list[dict[str, Any]] = [] @@ -386,18 +593,54 @@ def run(self) -> None: metric = self._evaluate_task(model, name, step_dir, is_new=(name == task_name), test_keys=test_keys) step_metrics[name] = metric metric_history[name].append((step + 1, metric["primary"])) + # Persist a checkpoint after each step so the model state at any intermediate stage + # can be recovered (useful for "what did the encoder look like just after task K was + # introduced?" analyses, and for downstream restart without retraining the prefix). + step_ckpt = step_dir / "checkpoint.pt" + torch.save( + { + "model": model.state_dict(), + "task_sequence": list(cfg.task_sequence), + "step": step + 1, + "new_task": task_name, + "active_tasks": list(active), + }, + step_ckpt, + ) records.append({"step": step + 1, "new_task": task_name, "metrics": step_metrics}) summary = ", ".join(f"{k}={v['primary']:.3f}" for k, v in step_metrics.items()) - logger.info(f"Step {step + 1}: {summary}") + logger.info(f"Step {step + 1}: {summary} (ckpt: {step_ckpt.relative_to(self.output_dir)})") self._plot_forgetting(metric_history) (self.output_dir / "experiment_records.json").write_text(json.dumps(records, indent=2), encoding="utf-8") + # Persist the final model so inverse-design experiments can be re-run without retraining. + ckpt_path = self.output_dir / "final_model.pt" + torch.save({"model": model.state_dict(), "task_sequence": list(cfg.task_sequence)}, ckpt_path) + logger.info(f"Saved final model checkpoint to {ckpt_path}") + inverse = self._inverse_design(model) (self.output_dir / "inverse_design.json").write_text(json.dumps(inverse, indent=2), encoding="utf-8") self._write_report_html(records, inverse) - logger.info(f"Done. Outputs in {self.output_dir}") + + def run_inverse_only(self, ckpt_path: Path) -> None: + """Skip training; load a saved checkpoint and run only the inverse-design stage. + + Use this to iterate on the inverse-design objective (e.g. ``inverse_ae_align_scale``) without + repeating the multi-hour training. Data loading + descriptor computation still happen, but + no Trainer.fit calls. + """ + logger.info(f"=== Inverse-only mode: loading model checkpoint {ckpt_path} ===") + seed_everything(self.config.random_seed, workers=True) + model = self._build_full_model() + state = torch.load(ckpt_path, map_location="cpu", weights_only=True) + state_dict = state["model"] if isinstance(state, dict) and "model" in state else state + model.load_state_dict(state_dict) + model.eval() + inverse = self._inverse_design(model) + (self.output_dir / "inverse_design.json").write_text(json.dumps(inverse, indent=2), encoding="utf-8") + logger.info(f"Inverse-only done. Outputs in {self.output_dir}") # ------------------------------------------------------------------ eval @@ -419,6 +662,14 @@ def _descriptor_tensor(self, comps: list[str], device) -> tuple[torch.Tensor, li return torch.tensor(desc.loc[comps].values, dtype=torch.float32, device=device), comps def _evaluate_task(self, model, task_name, step_dir, *, is_new, test_keys=None) -> dict[str, float]: + """Evaluate ``task_name`` on the held-out test split and persist (predictions + metrics). + + At every step we now save the raw `(composition, true, pred)` for **every active head** + so the plots can be redrawn later from the parquet without re-running training. Plots + themselves still go via the per-task ``_plot_*`` helpers when ``is_new`` so the per-step + directory stays focused on the new task; old-task parity / confusion / KR plots can be + regenerated downstream from the parquet if desired. + """ spec = TASK_SPECS[task_name] kind = spec["kind"] model.eval() @@ -446,7 +697,9 @@ def _evaluate_task(self, model, task_name, step_dir, *, is_new, test_keys=None) "primary": r2, } if is_new: - self._plot_parity(true, pred, task_name, r2, step_dir) + plot_parity(true, pred, task_name, r2, step_dir, title=_title(task_name)) + dump_predictions(task_name, step_dir, comps=list(comps), true=true, pred=pred) + dump_metrics(task_name, step_dir, metric) return metric logits = head(h) pred = logits.argmax(dim=-1).cpu().numpy() @@ -459,7 +712,18 @@ def _evaluate_task(self, model, task_name, step_dir, *, is_new, test_keys=None) "primary": acc, } if is_new: - self._plot_confusion(true, pred, task_name, acc, step_dir, spec["num_classes"]) + plot_confusion( + true, + pred, + task_name, + acc, + step_dir, + spec["num_classes"], + title=_display(task_name), + special_material_type=(task_name == "material_type"), + ) + dump_predictions(task_name, step_dir, comps=list(comps), true=true, pred=pred) + dump_metrics(task_name, step_dir, metric) return metric # kernel regression @@ -490,9 +754,25 @@ def _evaluate_task(self, model, task_name, step_dir, *, is_new, test_keys=None) "primary": r2, } if is_new: - self._plot_kr_sequences(keep, t_list, true_parts, pred, task_name, r2, step_dir) + plot_kr_sequences(keep, t_list, true_parts, pred, task_name, step_dir, title=_title(task_name)) + # For KR tasks the parquet carries the t and y series per composition so the curves + # are fully reconstructible without rerunning the encoder. + dump_kr_predictions( + task_name, + step_dir, + comps=list(keep), + t_list=[t.cpu().numpy() for t in t_list], + true_parts=true_parts, + pred=pred, + ) + dump_metrics(task_name, step_dir, metric) return metric + # --- per-step persistence helpers -------------------------------------------------------- + # ``dump_predictions`` / ``dump_kr_predictions`` / ``dump_metrics`` now live in + # :mod:`continual_rehearsal_common` and are imported at the top of this file; the bound-method + # versions used to sit here but were verbatim duplicates of full's copies and caused drift. + # ------------------------------------------------------------------ inverse design def _inverse_design(self, model) -> dict[str, Any]: @@ -501,13 +781,6 @@ def _inverse_design(self, model) -> dict[str, Any]: device = next(model.parameters()).device model.eval() - # Seed from qc test compositions (material_type is defined there). - seeds = self._test_rows("material_type")[: cfg.inverse_n_seeds] - x_seed, seeds = self._descriptor_tensor(seeds, device) - if not seeds: - logger.warning("No seeds available for inverse design.") - return {} - reg_targets = {t: v for t, v in zip(cfg.inverse_reg_tasks, cfg.inverse_reg_targets)} def _qc_prob(x: torch.Tensor) -> np.ndarray: @@ -521,6 +794,14 @@ def _reg_preds(x: torch.Tensor) -> dict[str, np.ndarray]: h = torch.tanh(model.encoder(x)) return {t: model.task_heads[t](h).squeeze(-1).cpu().numpy() for t in reg_targets} + # Seed the optimization (strategy configurable: top_qc / random / explicit), then push the + # latents toward QC (primary) with low formation energy / high klat (secondary). + seeds = self._select_seeds(model, device, _qc_prob) + x_seed, seeds = self._descriptor_tensor(seeds, device) + if not seeds: + logger.warning("No seeds available for inverse design.") + return {} + before_qc = _qc_prob(x_seed) before_reg = _reg_preds(x_seed) @@ -528,6 +809,8 @@ def _reg_preds(x: torch.Tensor) -> dict[str, np.ndarray]: initial_input=x_seed, task_targets=reg_targets, class_targets={"material_type": QC_CLASSES}, + class_target_weight=cfg.inverse_class_weight, # QC probability is the primary objective + ae_align_scale=cfg.inverse_ae_align_scale, # keep optimised latent on the AE manifold optimize_space="latent", steps=cfg.inverse_steps, lr=cfg.inverse_lr, @@ -566,6 +849,108 @@ def _reg_preds(x: torch.Tensor) -> dict[str, np.ndarray]: logger.info(f"Inverse design QC prob (round-trip): {before_qc.mean():.3f} -> {after_qc.mean():.3f}") return {"reg_targets": reg_targets, "qc_classes": QC_CLASSES, "n_seeds": len(seeds), "records": records} + @staticmethod + def _element_system(composition: str) -> frozenset[str]: + """Element symbols (no amounts) in a composition string — used for system-level dedup.""" + return frozenset(re.findall(r"[A-Z][a-z]?", composition)) + + def _select_seeds(self, model, device, qc_prob_fn) -> list[str]: + """Inverse-design seed compositions, per the configured strategy (top_qc / random / explicit). + + Seeds are deduplicated by **element system** (the set of element symbols, ignoring ratios) + — keeping only the best-scoring representative for each element set. Without this, the + top-QC list tends to collapse into many near-duplicates of the same alloy family (e.g. + Mg-Al-Cu in slightly different ratios), which both wastes seed budget and is misleading + when reporting the diversity of inverse-design outputs. + + If ``inverse_seed_explicit_append`` is non-empty, those compositions are added on top of + the strategy-selected seeds (after the same element-system dedup). The strategy budget is + reduced by the number of appended seeds, so the total length equals ``inverse_n_seeds``. + Appended compositions whose descriptor cannot be computed are rejected fail-fast. + """ + cfg = self.config + n = cfg.inverse_n_seeds + + # Pre-validate the explicit-append seeds (if any) so we can fail fast on bad input. + appended: list[str] = [] + for raw in cfg.inverse_seed_explicit_append: + norm = normalize_composition(raw) or str(raw) + if norm not in self._desc_cache and self.descriptor_fn([norm]).empty: + raise ValueError( + f"inverse_seed_explicit_append entry {raw!r} has no computable descriptor " + "(check the formula and that all elements are in DEFAULT_ELEMENTS)." + ) + appended.append(norm) + # Dedup the appended list itself by element system (in case the user listed near-duplicates). + appended = self._dedupe_by_element_system(appended, len(appended)) + n_strategy = max(0, n - len(appended)) + + def _finalise(strategy_seeds: list[str]) -> list[str]: + """Combine strategy seeds + explicit-append, skipping any duplicate element systems.""" + return self._merge_strategy_and_explicit(strategy_seeds, appended, n_strategy) + + if cfg.inverse_seed_strategy == "explicit": + seeds = [normalize_composition(c) or str(c) for c in cfg.inverse_seed_compositions] + seeds = [c for c in seeds if c in self._desc_cache or not self.descriptor_fn([c]).empty] + return _finalise(self._dedupe_by_element_system(seeds, n_strategy)) + + # Candidate pool: the chosen split of the material_type frame, with a valid descriptor. + frame = self.task_frames["material_type"] + index = ( + frame.index if cfg.inverse_seed_split == "all" else frame.index[frame["split"] == cfg.inverse_seed_split] + ) + pool = [c for c in index if c in self._desc_cache or not self.descriptor_fn([c]).empty] + if not pool: + return appended # nothing in the pool — fall back to just the explicit appends + + if cfg.inverse_seed_strategy == "random": + rng = np.random.default_rng(cfg.random_seed) + # Shuffle the whole pool, then dedupe by element system to keep ``n`` unique families. + shuffled = [pool[i] for i in rng.permutation(len(pool))] + return _finalise(self._dedupe_by_element_system(shuffled, n_strategy)) + + # "top_qc": highest predicted QC probability — dedup keeps the best representative + # per element set, so 16 seeds means 16 distinct alloy families (not 16 ratio variants + # of three families). + x, pool = self._descriptor_tensor(pool, device) + probs = qc_prob_fn(x) + ranked = [pool[i] for i in np.argsort(probs)[::-1]] + return _finalise(self._dedupe_by_element_system(ranked, n_strategy)) + + @classmethod + def _dedupe_by_element_system(cls, candidates: list[str], n: int) -> list[str]: + """Walk ``candidates`` in order, keep the first occurrence of each element set, cap at ``n``.""" + seen: set[frozenset[str]] = set() + out: list[str] = [] + for comp in candidates: + key = cls._element_system(comp) + if not key or key in seen: + continue + seen.add(key) + out.append(comp) + if len(out) >= n: + break + return out + + @classmethod + def _merge_strategy_and_explicit( + cls, + strategy_seeds: list[str], + appended: list[str], + n_strategy: int, + ) -> list[str]: + """Combine strategy-selected seeds with explicit-append seeds, deduping by element-system. + + Strategy seeds whose element-system collides with any appended seed are dropped, then the + list is truncated to ``n_strategy`` so the final total length is ``n_strategy + len(appended)``. + Appended seeds always survive (they were already deduped against themselves upstream). + Extracted from ``_select_seeds._finalise`` so the dedup contract is unit-testable without + the full runner. + """ + seen_keys = {cls._element_system(c) for c in appended} + kept_strategy = [c for c in strategy_seeds if cls._element_system(c) not in seen_keys] + return kept_strategy[:n_strategy] + appended + def _decode_compositions(self, descriptors: np.ndarray) -> list[str]: """KMD.inverse: descriptor -> element weights -> compact formula string.""" try: @@ -582,87 +967,90 @@ def _decode_compositions(self, descriptors: np.ndarray) -> list[str]: # ------------------------------------------------------------------ plots - def _plot_parity(self, true, pred, task_name, r2, step_dir): - fig, ax = plt.subplots(figsize=(5, 5), dpi=130) - ax.scatter(true, pred, s=8, alpha=0.4, edgecolor="none") - lo, hi = float(min(true.min(), pred.min())), float(max(true.max(), pred.max())) - ax.plot([lo, hi], [lo, hi], "r--", lw=1) - ax.set_xlabel("true") - ax.set_ylabel("pred") - ax.set_title(f"{task_name} (new) — R²={r2:.3f}, n={len(true)}") - fig.tight_layout() - fig.savefig(step_dir / f"{task_name}_parity.png") - plt.close(fig) - - def _plot_confusion(self, true, pred, task_name, acc, step_dir, num_classes): - cm = np.zeros((num_classes, num_classes), dtype=int) - for t, p in zip(true, pred): - if 0 <= t < num_classes and 0 <= p < num_classes: - cm[t, p] += 1 - fig, ax = plt.subplots(figsize=(5, 4.5), dpi=130) - im = ax.imshow(cm, cmap="Blues") - fig.colorbar(im, ax=ax) - ax.set_xlabel("pred") - ax.set_ylabel("true") - ax.set_title(f"{task_name} (new) — acc={acc:.3f}, n={int(cm.sum())}") - fig.tight_layout() - fig.savefig(step_dir / f"{task_name}_confusion.png") - plt.close(fig) - - def _plot_kr_sequences(self, comps, t_list, true_parts, pred, task_name, r2, step_dir): - fig, ax = plt.subplots(figsize=(6, 4), dpi=130) - offset = 0 - for i in range(min(3, len(comps))): - n = true_parts[i].size - t = t_list[i].cpu().numpy() - ax.plot(t, true_parts[i], lw=1.2, alpha=0.8, label=f"true #{i}") - ax.plot(t, pred[offset : offset + n], lw=1.0, ls="--", alpha=0.8, label=f"pred #{i}") - offset += n - ax.set_xlabel("t") - ax.set_ylabel("value (norm)") - ax.set_title(f"{task_name} (new) — R²={r2:.3f}") - ax.legend(fontsize=7, ncol=2) - fig.tight_layout() - fig.savefig(step_dir / f"{task_name}_sequences.png") - plt.close(fig) + # ------------------------------------------------------------------ plots + # ``plot_parity`` / ``plot_confusion`` / ``plot_kr_sequences`` now live in + # :mod:`continual_rehearsal_common`. They used to be bound methods here, but every line + # was a verbatim copy of full's version — the duplication caused PR #18's K=0 + # ``NameError`` to ship in demo for several PRs before being noticed. The runner-specific + # plots that DO need ``self`` state (``_plot_forgetting`` below, ``_plot_inverse_design``) + # stay as bound methods. def _plot_forgetting(self, metric_history): - fig, ax = plt.subplots(figsize=(8, 5), dpi=130) + # Wide enough to spread many steps; legend sits outside so it scales to dozens of tasks. + n_tasks = sum(1 for pts in metric_history.values() if pts) + fig, ax = plt.subplots(figsize=(13, max(5.5, 0.32 * n_tasks + 3))) + all_steps: set[int] = set() for task_name, points in metric_history.items(): if not points: continue steps = [s for s, _ in points] vals = [v for _, v in points] - ax.plot(steps, vals, marker="o", label=task_name) - ax.set_xlabel("finetuning step") - ax.set_ylabel("primary metric (R² / accuracy)") - ax.set_title("Per-task performance vs continual finetuning step") - ax.grid(True, alpha=0.3) - ax.legend(fontsize=8, ncol=2) - fig.tight_layout() + all_steps.update(steps) + is_clf = TASK_SPECS[task_name]["kind"] == "clf" + ax.plot( + steps, + vals, + marker="s" if is_clf else "o", + ms=5, + ls="--" if is_clf else "-", + color=self._task_colors.get(task_name, "#888888"), + label=_display(task_name) + (" · accuracy" if is_clf else ""), + ) + if all_steps: + ax.set_xticks(sorted(all_steps)) + ax.set_xlabel("Continual finetuning step (a new task is added at each step)") + ax.set_ylabel("Primary metric · R² (regression) / accuracy (classification)") + ax.set_title("Per-task performance across continual finetuning") + ncol = 1 if n_tasks <= 20 else 2 + ax.legend(fontsize=8, ncol=ncol, loc="upper left", bbox_to_anchor=(1.01, 1.0), borderaxespad=0.0) fig.savefig(self.output_dir / "forgetting_trajectory.png") plt.close(fig) logger.info(f"Saved forgetting trajectory to {self.output_dir / 'forgetting_trajectory.png'}") def _plot_inverse_design(self, before_qc, after_qc, before_reg, reg_latent, after_reg, reg_targets): - n_panels = 1 + len(reg_targets) - fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 4), dpi=130) - axes = np.atleast_1d(axes) - idx = np.arange(len(before_qc)) - axes[0].bar(idx - 0.2, before_qc, width=0.4, label="before") - axes[0].bar(idx + 0.2, after_qc, width=0.4, label="after (decode)") - axes[0].set_title("Quasicrystal probability") - axes[0].set_xlabel("seed") - axes[0].legend(fontsize=8) - for ax, (t, tgt) in zip(axes[1:], reg_targets.items()): - ax.bar(idx - 0.25, before_reg[t], width=0.25, label="before") - ax.bar(idx, reg_latent[t], width=0.25, label="achieved (latent)") - ax.bar(idx + 0.25, after_reg[t], width=0.25, label="after (decode)") - ax.axhline(tgt, color="r", ls="--", lw=1, label=f"target={tgt}") - ax.set_title(f"{t} prediction") - ax.set_xlabel("seed") - ax.legend(fontsize=7) - fig.tight_layout() + """Parallel-coordinates per objective. Primary: QC probability (seed → optimized/decoded), + which should rise toward 1. Secondary: the regression targets (seed → optimized-in-latent → + decoded round-trip), each toward its target line.""" + reg_names = list(reg_targets) + n_seeds = len(before_qc) + n_panels = 1 + len(reg_names) + fig, axes = plt.subplots(1, n_panels, figsize=(4.6 * n_panels, 4.2), squeeze=False) + axes = axes[0] + + # Primary objective: quasicrystal probability, seed → decoded round-trip. + axq = axes[0] + for i in range(n_seeds): + axq.plot([0, 1], [before_qc[i], after_qc[i]], color="#55A868", alpha=0.4, lw=1.0, marker="o", ms=3) + axq.axhline(1.0, color="#C44E52", ls="--", lw=1.6, label="target = 1.0") + axq.set_xticks([0, 1], ["seed", "optimized\n(decoded)"]) + axq.set_xlim(-0.3, 1.3) + axq.set_ylim(-0.02, 1.02) + axq.set_ylabel("P(quasicrystal)") + axq.set_title("Quasicrystal Probability ↑", fontsize=12) + axq.legend(loc="best", fontsize=9) + + # Secondary objectives: regression targets. + stages = ["seed\nprediction", "optimized\n(latent)", "decoded\n(round-trip)"] + for ax, t in zip(axes[1:], reg_names): + color = self._task_colors.get(t, _PALETTE[0]) + for i in range(n_seeds): + ax.plot( + [0, 1, 2], + [before_reg[t][i], reg_latent[t][i], after_reg[t][i]], + color=color, + alpha=0.35, + lw=1.0, + marker="o", + ms=3, + ) + ax.axhline(reg_targets[t], color="#C44E52", ls="--", lw=1.6, label=f"target = {reg_targets[t]:+.1f}") + ax.set_xticks([0, 1, 2], stages) + ax.set_xlim(-0.3, 2.3) + ax.set_ylabel("Predicted value") + ax.set_title(f"{_display(t)} {'↓' if reg_targets[t] < 0 else '↑'}", fontsize=12) + ax.legend(loc="best", fontsize=9) + + fig.suptitle("Inverse design — primary: raise QC probability · secondary: low f.e., high κ_lat", y=1.03) fig.savefig(self.output_dir / "inverse_design.png") plt.close(fig) logger.info(f"Saved inverse-design plot to {self.output_dir / 'inverse_design.png'}") @@ -686,7 +1074,7 @@ def _write_report_html(self, records: list[dict[str, Any]], inverse: dict[str, A spec = TASK_SPECS[task] metric_name = "acc" if spec["kind"] == "clf" else "R²" rows.append( - f"{task}{kind_label[spec['kind']]}{spec['source']}" + f"{_display(task)}{kind_label[spec['kind']]}{spec['source']}" f"{intro.get(task, float('nan')):+.3f}" f"{final.get(task, {}).get('primary', float('nan')):+.3f}{metric_name}" ) @@ -703,7 +1091,7 @@ def _write_report_html(self, records: list[dict[str, Any]], inverse: dict[str, A img = self._img_b64(f"step{i:02d}_{task}/{task}_{suffix}.png") if img: examples.append( - f'
{task} ({kind_label[kind]})
' + f'
{_display(task)} ({kind_label[kind]})
' ) seen.add(kind) @@ -718,7 +1106,7 @@ def _mean(field: str, sub: str) -> float: return float(np.mean(vals)) if vals else float("nan") inv_lines = "".join( - f"
  • {t}: {_mean('reg_before', t):+.2f} → {_mean('reg_achieved_latent', t):+.2f} " + f"
  • {_display(t)}: {_mean('reg_before', t):+.2f} → {_mean('reg_achieved_latent', t):+.2f} " f"(target {reg_targets[t]:+.1f})
  • " for t in reg_targets ) @@ -747,7 +1135,7 @@ def slide(body: str) -> str: "
  • Descriptor: invertible KMD-1d, computed on the fly (descriptor → composition via KMD.inverse).
  • " "
  • Continual finetuning: tasks added one at a time; AE head always on.
  • " f"
  • Rehearsal: learned tasks keep only {self.config.replay_ratio:.0%} of their training targets per step.
  • " - "
  • Inverse design: optimize the latent toward regression targets + quasicrystal probability, then decode a composition.
  • " + "
  • Inverse design: from the highest-QC training compositions, optimize the latent to raise quasicrystal probability (primary) with low formation energy & high κ_lat (secondary), then decode a composition.
  • " "" ), slide( @@ -765,17 +1153,18 @@ def slide(body: str) -> str: slide( "

    Inverse design

    " + (f"" if inv_img else "") - + "

    Latent optimization reached targets

      " + + "

      Primary — quasicrystal probability

      " + + f"

      mean P(QC) over seeds: {qc_before:.3f} → {qc_after:.3f} (round-trip)

      " + + "

      Secondary — regression targets (in latent)

        " + inv_lines - + f"

      Quasicrystal probability (round-trip): {qc_before:.3f} → {qc_after:.3f}

      " - + "

      Decoded compositions (KMD.inverse)

        " + + "

      Decoded compositions (KMD.inverse)

        " + decoded + "
    " ), slide( "

    Takeaways

      " "
    • One shared encoder serves regression, kernel regression, classification & reconstruction across 4 inorganic datasets.
    • " - "
    • 5% rehearsal keeps well-learned tasks (density, formation energy, material type) near their peak while new heads are added.
    • " + "
    • 5% rehearsal keeps well-learned tasks (Density, Formation Energy) near their peak while new heads are added.
    • " "
    • Latent-space optimization with regression + classification conditions hits the targets and decodes back to real compositions via the invertible KMD descriptor.
    • " "
    " ), @@ -841,13 +1230,20 @@ def _load_toml(path: Path) -> dict[str, Any]: return tomllib.loads(Path(path).read_text(encoding="utf-8")) -def _parse_args(argv: list[str] | None = None) -> ContinualRehearsalConfig: +def _parse_args(argv: list[str] | None = None) -> tuple[ContinualRehearsalConfig, argparse.Namespace]: parser = argparse.ArgumentParser(description="Continual rehearsal + inverse-design demo.") parser.add_argument("--config-file", type=Path, default=None) parser.add_argument("--output-dir", type=Path, default=None) parser.add_argument("--sample-per-dataset", type=int, default=None) parser.add_argument("--max-epochs-per-step", type=int, default=None) parser.add_argument("--accelerator", type=str, default=None) + parser.add_argument( + "--inverse-only", + type=Path, + default=None, + metavar="CKPT", + help="Skip training; load a final_model.pt checkpoint and run only the inverse-design stage.", + ) args = parser.parse_args(argv) data = _load_toml(args.config_file) if args.config_file else {} @@ -871,11 +1267,16 @@ def _parse_args(argv: list[str] | None = None) -> ContinualRehearsalConfig: logger.warning(f"Ignoring unknown config key '{key}'.") continue kwargs[key] = Path(value) if key in path_fields and value is not None else value - return ContinualRehearsalConfig(**kwargs) + return ContinualRehearsalConfig(**kwargs), args def main(argv: list[str] | None = None) -> None: - ContinualRehearsalRunner(_parse_args(argv)).run() + config, args = _parse_args(argv) + runner = ContinualRehearsalRunner(config) + if args.inverse_only is not None: + runner.run_inverse_only(args.inverse_only) + else: + runner.run() if __name__ == "__main__": diff --git a/src/foundation_model/scripts/continual_rehearsal_demo_test.py b/src/foundation_model/scripts/continual_rehearsal_demo_test.py new file mode 100644 index 0000000..4d33591 --- /dev/null +++ b/src/foundation_model/scripts/continual_rehearsal_demo_test.py @@ -0,0 +1,242 @@ +# Copyright 2025 TsumiNa. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the configuration / pure helpers in :mod:`continual_rehearsal_demo`. + +The runner's training loop is exercised end-to-end by smoke runs (it needs real parquet +data + a GPU/MPS device), so this file targets the *units that don't need either*: + +* ``ContinualRehearsalConfig`` validation in ``__post_init__``. +* The element-system seed dedup / explicit-append logic. +* The ``_plot_kr_sequences`` regression (the function used to raise ``NameError`` when + ``comps`` was empty — see the PR #18 code review). +* The material-type 5→3 class merge map shape. +""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pytest + +from foundation_model.scripts.continual_rehearsal_common import plot_kr_sequences +from foundation_model.scripts.continual_rehearsal_demo import ( + DEFAULT_SEQUENCE, + MATERIAL_TYPE_CLASSES, + MATERIAL_TYPE_DISPLAY_ORDER, + QC_CLASSES, + TASK_SPECS, + ContinualRehearsalConfig, + ContinualRehearsalRunner, + _MATERIAL_TYPE_MERGE, +) + + +# --- ContinualRehearsalConfig --------------------------------------------------------------- + + +def _base_kwargs(**overrides): + """Minimal valid config kwargs — only the fields without sane defaults need to be filled in. + + Paths are dummies; the validators inside ``__post_init__`` don't touch the filesystem. + """ + defaults = { + "qc_data_path": Path("/tmp/qc.parquet"), + "qc_preprocessing_path": None, + "superconductor_path": Path("/tmp/sc.parquet"), + "magnetic_path": Path("/tmp/mag.parquet"), + "phonix_path": Path("/tmp/ph.parquet"), + "output_dir": Path("/tmp/out"), + "task_sequence": list(DEFAULT_SEQUENCE), + } + defaults.update(overrides) + return defaults + + +def test_config_default_post_init_accepts_default_sequence(): + cfg = ContinualRehearsalConfig(**_base_kwargs()) + assert cfg.task_sequence == list(DEFAULT_SEQUENCE) + # Every task in the default sequence is registered in TASK_SPECS — would otherwise raise. + assert set(cfg.task_sequence) <= set(TASK_SPECS) + + +def test_config_rejects_unknown_task(): + with pytest.raises(ValueError, match="Unknown task"): + ContinualRehearsalConfig(**_base_kwargs(task_sequence=["density", "this_task_does_not_exist"])) + + +def test_config_rejects_bad_replay_ratio(): + with pytest.raises(ValueError, match="replay_ratio must be in"): + ContinualRehearsalConfig(**_base_kwargs(replay_ratio=-0.1)) + with pytest.raises(ValueError, match="replay_ratio must be in"): + ContinualRehearsalConfig(**_base_kwargs(replay_ratio=1.5)) + + +def test_config_rejects_reg_task_target_length_mismatch(): + with pytest.raises(ValueError, match="inverse_reg_tasks and inverse_reg_targets"): + ContinualRehearsalConfig( + **_base_kwargs(inverse_reg_tasks=["formation_energy", "klat"], inverse_reg_targets=[-2.0]) + ) + + +def test_config_rejects_unknown_seed_strategy(): + with pytest.raises(ValueError, match="inverse_seed_strategy must be"): + ContinualRehearsalConfig(**_base_kwargs(inverse_seed_strategy="oracle")) + + +def test_config_explicit_strategy_requires_compositions(): + with pytest.raises(ValueError, match="requires inverse_seed_compositions"): + ContinualRehearsalConfig(**_base_kwargs(inverse_seed_strategy="explicit", inverse_seed_compositions=[])) + + +def test_config_rejects_nonpositive_n_seeds(): + """``inverse_n_seeds <= 0`` would silently return only the explicit-append entries; fail + loudly at config-load time so the misuse points at the TOML, not at a downstream shape error.""" + with pytest.raises(ValueError, match="inverse_n_seeds must be > 0"): + ContinualRehearsalConfig(**_base_kwargs(inverse_n_seeds=0)) + with pytest.raises(ValueError, match="inverse_n_seeds must be > 0"): + ContinualRehearsalConfig(**_base_kwargs(inverse_n_seeds=-3)) + + +def test_config_rejects_ae_align_scale_out_of_range(): + """``ae_align_scale ∉ [0, 1]`` is rejected by the model at runtime; we catch it earlier so + the error message points at the TOML field rather than a deep model backtrace.""" + with pytest.raises(ValueError, match="inverse_ae_align_scale must be in"): + ContinualRehearsalConfig(**_base_kwargs(inverse_ae_align_scale=-0.1)) + with pytest.raises(ValueError, match="inverse_ae_align_scale must be in"): + ContinualRehearsalConfig(**_base_kwargs(inverse_ae_align_scale=1.5)) + + +# --- material-type 5→3 merge map ------------------------------------------------------------ + + +def test_material_type_merge_covers_all_5_classes_and_3_targets(): + # Source labels are 0..4 (5 classes); merged labels are 0..2 (3 classes: AC / QC / others). + assert set(_MATERIAL_TYPE_MERGE.keys()) == {0, 1, 2, 3, 4} + assert set(_MATERIAL_TYPE_MERGE.values()) == {0, 1, 2} + # QC label index must agree with QC_CLASSES. + assert QC_CLASSES == [_MATERIAL_TYPE_MERGE[1]] == [_MATERIAL_TYPE_MERGE[3]] + + +def test_material_type_class_names_and_display_order_consistent(): + # 3 merged classes, both lists carry exactly those names. + assert len(MATERIAL_TYPE_CLASSES) == 3 + assert sorted(MATERIAL_TYPE_CLASSES) == sorted(MATERIAL_TYPE_DISPLAY_ORDER) + + +# --- element-system dedup (classmethod, no runner state needed) ------------------------------ + + +def test_dedupe_by_element_system_keeps_first_per_set(): + # First occurrence per element-set wins. Mg-Al-Cu appears twice; only the first survives. + candidates = [ + "Mg12 Cu3 Ni3", # {Mg, Cu, Ni} + "Mg2 Cu1 Ni1", # {Mg, Cu, Ni} ← duplicate set, dropped + "Y8.7 Mg34.6 Zn56.8", # {Y, Mg, Zn} + "Y1 Mg1 Zn1", # {Y, Mg, Zn} ← duplicate set, dropped + "Au65 Ga20 Gd15", # {Au, Ga, Gd} + ] + out = ContinualRehearsalRunner._dedupe_by_element_system(candidates, n=10) + assert out == ["Mg12 Cu3 Ni3", "Y8.7 Mg34.6 Zn56.8", "Au65 Ga20 Gd15"] + + +def test_dedupe_by_element_system_respects_n_cap(): + candidates = [ + "Mg1", # {Mg} + "Al1", # {Al} + "Cu1", # {Cu} + "Ni1", # {Ni} + ] + out = ContinualRehearsalRunner._dedupe_by_element_system(candidates, n=2) + assert out == ["Mg1", "Al1"] + + +def test_dedupe_by_element_system_ignores_empty_strings(): + out = ContinualRehearsalRunner._dedupe_by_element_system(["", "Mg1", " ", "Al1"], n=5) + assert out == ["Mg1", "Al1"] + + +def test_merge_strategy_and_explicit_drops_strategy_seeds_sharing_element_system(): + """When an explicit-append seed (Au-Ga-Gd) shares an element-system with a strategy seed, + the *strategy* seed is dropped — the explicit-append wins because it's the user's deliberate + pick. Mirrors ``_select_seeds._finalise``'s contract end-to-end.""" + strategy = [ + "Mg12 Cu3 Ni3", # {Mg, Cu, Ni} — kept + "Au70 Ga20 Gd10", # {Au, Ga, Gd} — *dropped*, overlaps the explicit append + "Y8 Mg34 Zn58", # {Y, Mg, Zn} — kept + "Al6 Co1 Cu3", # {Al, Co, Cu} — kept + ] + appended = ["Au65 Ga20 Gd15"] # {Au, Ga, Gd} + out = ContinualRehearsalRunner._merge_strategy_and_explicit(strategy, appended, n_strategy=3) + assert out == ["Mg12 Cu3 Ni3", "Y8 Mg34 Zn58", "Al6 Co1 Cu3", "Au65 Ga20 Gd15"] + + +def test_merge_strategy_and_explicit_caps_strategy_after_dedup(): + """``n_strategy`` is the post-dedup cap on the strategy portion. Total output length is + ``n_strategy + len(appended)`` — the appended entries are always preserved.""" + strategy = ["Mg1 Cu1", "Al1 Fe1", "Zn1 Cd1"] + appended = ["Au1 Ga1"] + out = ContinualRehearsalRunner._merge_strategy_and_explicit(strategy, appended, n_strategy=2) + assert out == ["Mg1 Cu1", "Al1 Fe1", "Au1 Ga1"] + + +def test_merge_strategy_and_explicit_handles_empty_appended(): + """No explicit-append entries ⇒ just truncates the (already-deduped) strategy list.""" + out = ContinualRehearsalRunner._merge_strategy_and_explicit( + ["Mg1 Cu1", "Al1 Fe1", "Zn1 Cd1"], [], n_strategy=2 + ) + assert out == ["Mg1 Cu1", "Al1 Fe1"] + + +def test_element_system_extracts_symbols_ignoring_amounts(): + # Static-method shape: returns a frozenset of element symbols, no stoichiometry leaks through. + es = ContinualRehearsalRunner._element_system("Au65 Ga20 Gd15") + assert es == frozenset({"Au", "Ga", "Gd"}) + # Multi-digit / float amounts handled the same way. + es = ContinualRehearsalRunner._element_system("Mg36.3 Al32 Zn31.7") + assert es == frozenset({"Mg", "Al", "Zn"}) + + +# --- plot_kr_sequences empty-comps regression (P1 bug from PR #18 code review) ------------- +# The function is now in ``continual_rehearsal_common`` (PR #18 refactor); pre-refactor it lived +# as a bound method on each runner and the empty-comps NameError silently shipped on the demo +# side for several PRs. These tests pin the post-refactor behaviour from both call sites. + + +def test_plot_kr_sequences_handles_empty_comps_without_crashing(tmp_path): + """Empty ``comps`` used to raise ``NameError: line_true`` from ``fig.legend(...)``. Now it + logs a warning and returns early; no file is written.""" + out_dir = tmp_path / "step01_density" + out_dir.mkdir() + plot_kr_sequences( + comps=[], + t_list=[], + true_parts=[], + pred=np.array([]), + task_name="dos_density", + step_dir=out_dir, + title="DOS density", + ) + assert not (out_dir / "dos_density_sequences.png").exists() + + +def test_plot_kr_sequences_renders_when_comps_nonempty(tmp_path): + """Smoke: one composition's sequence renders a PNG with no errors.""" + import torch + + out_dir = tmp_path / "step01_density" + out_dir.mkdir() + t = torch.linspace(0.0, 1.0, 8) + true_part = np.linspace(0.0, 1.0, 8) + pred = np.linspace(0.05, 0.95, 8) + plot_kr_sequences( + comps=["Mg1 Cu1"], + t_list=[t], + true_parts=[true_part], + pred=pred, + task_name="dos_density", + step_dir=out_dir, + title="DOS density", + ) + assert (out_dir / "dos_density_sequences.png").exists() diff --git a/src/foundation_model/scripts/continual_rehearsal_full.py b/src/foundation_model/scripts/continual_rehearsal_full.py new file mode 100644 index 0000000..648b556 --- /dev/null +++ b/src/foundation_model/scripts/continual_rehearsal_full.py @@ -0,0 +1,2855 @@ +# Copyright 2025 TsumiNa. +# SPDX-License-Identifier: Apache-2.0 + +""" +Continual multi-task rehearsal + inverse-design — **full / formal** run. + +A larger, "formal training" sibling of :mod:`continual_rehearsal_demo`. It covers the complete +inorganic task catalogue (24 supervised tasks + always-on autoencoder) over four datasets and, +relative to the demo, adds: + +* **Tiered rehearsal** — a configurable high-replay set (the inverse-design-relevant tail tasks, + e.g. formation_energy / magnetic_moment / tc / klat / material_type) keeps ``replay_ratio_high`` + of its labels when replayed as an *old* task, while every other learned task keeps ``replay_ratio``. +* **EarlyStopping** on ``val_final_loss`` (full data ⇒ ``max_epochs_per_step`` is just a ceiling). +* **Per-stage raw artifacts** — at every step, every active head's test ``(composition, true, pred)`` + is dumped to parquet (kernel heads additionally store the ``t`` series), alongside a per-task + ``_metrics.json`` and a per-step ``checkpoint.pt`` (model state + active-task metadata). + Everything lives under ``training/stepNN_/`` so any intermediate stage can be revisited. +* **Final checkpoint** — ``training/final_model.pt`` + ``training/final_model_taskconfigs.json``. +* **Multiple inverse-design scenarios** — the same final model is optimized through **eight + PR #18 paths per scenario** (3 latent ``ae_align_scale`` sweep points + 5 composition configs: + strict seed / blended seed / alloy palette / alloy + low diversity / random init), with + results, an 8-path comparison plot, an element-frequency heatmap (discovered elements + highlighted in bold orange), and `targets.json` written to ``inverse_design//``. +* **Slide-prep deliverables (no auto PPT / HTML)** — the runner emits ``SLIDE_PREP.md`` (9-section + outline + raw-data pointers), ``ANALYSIS.md`` (long-form English narrative), ``README.md`` + (directory index), and per-scenario ``comparison.png`` / ``element_frequency_heatmap.png`` + inside ``inverse_design//``. The three scenarios are first-class results — the runner + does **not** promote any single scenario as the headline (that was a demo-only convention). + The slide author builds the deck externally; every figure is reproducible from the raw arrays + without retraining. + +No layers are frozen: every step jointly trains the shared encoder + all active task heads +(``freeze_shared_encoder=False``, per-task ``freeze_parameters=False``). The "continual" behaviour +comes purely from the rehearsal mask, not from freezing. + +Run: + ./run_continual_rehearsal_full.sh samples/continual_rehearsal_full_config.toml + python -m foundation_model.scripts.continual_rehearsal_full --config-file +""" + +from __future__ import annotations + +import argparse +import datetime as _datetime +import json +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import matplotlib + +matplotlib.use("Agg") # headless + +import joblib # type: ignore[import-untyped] +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch +from lightning import Trainer, seed_everything +from lightning.pytorch.callbacks import Callback, EarlyStopping +from loguru import logger +from sklearn.metrics import accuracy_score, f1_score, mean_absolute_error, r2_score # type: ignore[import-untyped] +from torch.utils.data import DataLoader + +from foundation_model.data.composition_sources import normalize_composition +from foundation_model.data.datamodule import CompoundDataModule +from foundation_model.models.flexible_multi_task_model import FlexibleMultiTaskModel +from foundation_model.models.model_config import ( + ClassificationTaskConfig, + KernelRegressionTaskConfig, + MLPEncoderConfig, + OptimizerConfig, + RegressionTaskConfig, +) +from foundation_model.utils.kmd_plus import DEFAULT_ELEMENTS, KMD, element_features, formula_to_composition + +# Shared dump/plot helpers live in the common module. The material_type constants and the +# scatter colour are consumed *inside* the common functions, so this file no longer needs to +# import them directly — they used to be imported here because the bound-method plot helpers +# inlined them. +from foundation_model.scripts.continual_rehearsal_common import ( + dump_kr_predictions, + dump_metrics, + dump_predictions, + plot_confusion, + plot_element_frequency_heatmap, + plot_kr_sequences, + plot_parity, +) +from foundation_model.scripts.continual_rehearsal_demo import ( + _PALETTE, + _apply_plot_style, + _as_float_array, + _composition_key, + _init_kernels, +) +from foundation_model.scripts.paper_inverse_comparison import ( + _emit_trajectory_outputs, + _path_slug, + _plot_qc_vs_reg_scatter, + _plot_seed_to_optimized_mapping, +) + +# --- Task catalogue ---------------------------------------------------------- +# source: dataset the task's targets come from. qc columns are pre-normalized; raw NEMAD/phonix +# regression columns are log1p + z-scored (train-only stats) + clipped at load time. +TASK_SPECS: dict[str, dict[str, Any]] = { + # --- qc: regression (9) --- + "density": {"source": "qc", "kind": "reg", "column": "Density (normalized)"}, + "efermi": {"source": "qc", "kind": "reg", "column": "Efermi (normalized)"}, + "final_energy": {"source": "qc", "kind": "reg", "column": "Final energy per atom (normalized)"}, + "formation_energy": {"source": "qc", "kind": "reg", "column": "Formation energy per atom (normalized)"}, + "total_magnetization": {"source": "qc", "kind": "reg", "column": "Total magnetization (normalized)"}, + "volume": {"source": "qc", "kind": "reg", "column": "Volume (normalized)"}, + "dielectric_total": {"source": "qc", "kind": "reg", "column": "Dielectric total (normalized)"}, + "dielectric_ionic": {"source": "qc", "kind": "reg", "column": "Dielectric ionic (normalized)"}, + "dielectric_electronic": {"source": "qc", "kind": "reg", "column": "Dielectric electronic (normalized)"}, + # --- qc: kernel regression (7) --- + "dos_density": {"source": "qc", "kind": "kr", "column": "DOS density (normalized)", "t_column": "DOS energy"}, + "electrical_resistivity": { + "source": "qc", + "kind": "kr", + "column": "Electrical resistivity (normalized)", + "t_column": "Electrical resistivity (T/K)", + }, + "power_factor": { + "source": "qc", + "kind": "kr", + "column": "Power factor (normalized)", + "t_column": "Power factor (T/K)", + }, + "seebeck": { + "source": "qc", + "kind": "kr", + "column": "Seebeck coefficient (normalized)", + "t_column": "Seebeck coefficient (T/K)", + }, + "thermal_conductivity": { + "source": "qc", + "kind": "kr", + "column": "Thermal conductivity (normalized)", + "t_column": "Thermal conductivity (T/K)", + }, + "zt": {"source": "qc", "kind": "kr", "column": "ZT (normalized)", "t_column": "ZT (T/K)"}, + "magnetic_susceptibility": { + "source": "qc", + "kind": "kr", + "column": "Magnetic susceptibility (normalized)", + "t_column": "Magnetic susceptibility (T/K)", + }, + # --- qc: classification (1) --- + "material_type": {"source": "qc", "kind": "clf", "column": "Material type (label)", "num_classes": 3}, + # --- phonix-db: regression (2) --- + "kp": {"source": "phonix", "kind": "reg", "column": "kp[W/mK]"}, + "klat": {"source": "phonix", "kind": "reg", "column": "klat[W/mK]"}, + # --- NEMAD superconductor: regression (1) --- + "tc": {"source": "superconductor", "kind": "reg", "column": "Transition temperature[K]"}, + # --- NEMAD magnetic: regression (4) --- + "magnetic_moment": {"source": "magnetic", "kind": "reg", "column": "Magnetic moment[μB/f.u.]"}, + "magnetization": {"source": "magnetic", "kind": "reg", "column": "Magnetization[A·m²/mol]"}, + "curie": {"source": "magnetic", "kind": "reg", "column": "Curie temperature[K]"}, + "neel": {"source": "magnetic", "kind": "reg", "column": "Neel temperature[K]"}, +} + +# Raw (non-qc) regression targets span orders of magnitude; log1p-compress, z-score, clip tails. +_RAW_TARGET_CLIP = 5.0 + +# Default 24-task sequence: 19 free-order tasks, then the fixed inverse-design tail (kept freshest). +DEFAULT_SEQUENCE = [ + # qc regression (free) + "density", + "efermi", + "final_energy", + "total_magnetization", + "volume", + "dielectric_total", + "dielectric_ionic", + "dielectric_electronic", + # qc kernel regression (free) + "dos_density", + "electrical_resistivity", + "power_factor", + "seebeck", + "thermal_conductivity", + "zt", + "magnetic_susceptibility", + # magnetic + phonix (free) + "magnetization", + "curie", + "neel", + "kp", + # fixed tail (inverse-design heads, freshest at the end) + "formation_energy", + "magnetic_moment", + "tc", + "klat", + "material_type", +] +# The inverse-design-relevant tail: kept at the higher replay ratio when replayed as an old task. +DEFAULT_FIXED_TAIL = ["formation_energy", "magnetic_moment", "tc", "klat", "material_type"] + +# 5 fine labels merged into AC / QC / others (index == merged class id). +# ``MATERIAL_TYPE_CLASSES`` / ``MATERIAL_TYPE_DISPLAY_ORDER`` now live in +# :mod:`continual_rehearsal_common` and are imported above; the runner-specific merge map stays. +_MATERIAL_TYPE_MERGE = {0: 0, 2: 0, 1: 1, 3: 1, 4: 2} +QC_CLASSES = [1] # merged quasicrystal class index — inverse-design classification objective. + +# --- Presentation ------------------------------------------------------------- +TASK_DISPLAY: dict[str, str] = { + "density": "Density", + "efermi": "E_Fermi", + "final_energy": "Final Energy / atom", + "formation_energy": "Formation Energy", + "total_magnetization": "Total Magnetization", + "volume": "Volume", + "dielectric_total": "Dielectric (total)", + "dielectric_ionic": "Dielectric (ionic)", + "dielectric_electronic": "Dielectric (electronic)", + "dos_density": "DOS Density", + "electrical_resistivity": "Electrical Resistivity", + "power_factor": "Power Factor", + "seebeck": "Seebeck Coefficient", + "thermal_conductivity": "Thermal Conductivity", + "zt": "ZT", + "magnetic_susceptibility": "Magnetic Susceptibility", + "material_type": "Material Type", + "kp": "Phonon Conductivity (κₚ)", + "klat": "Lattice Conductivity (κ_lat)", + "tc": "Critical Temperature (Tc)", + "magnetic_moment": "Magnetic Moment", + "magnetization": "Magnetization", + "curie": "Curie Temperature", + "neel": "Néel Temperature", +} +SOURCE_DISPLAY = { + "qc": "qc_ac_te_mp", + "phonix": "phonix-db", + "superconductor": "NEMAD superconductor", + "magnetic": "NEMAD magnetic", +} +KIND_LABEL = {"reg": "regression", "kr": "kernel regression", "clf": "classification"} + +# --- Inverse design — paths + element constraints ---------------------------- +# 48-element alloy palette for the composition-space ``C-alloy`` path (plan §5, extended). Covers +# classic i-QC / d-QC formers (Mg–Zn–RE, Al–Mn, Al–Cu–Fe, Al–Ni–Co, Au–Ga–RE …), the Sc–Zn +# 4th-period TMs, the Y–Cd 5th-period TMs (Tc excluded for radioactivity), the full Hf–Pt 5d TM +# row (added 2026-05 — broadens the heavy-TM coverage for the composition search and lets the +# optimiser reach refractory / noble-metal i-QC families like Hf–Pd / Ta–Ni / Ir-based phases), +# Au (Au–Ga–Ln seeds need it), group 13/14 enablers (B/Al/Ga/In/Tl, Si/Ge), and the 12 easy +# lanthanides. Pm/Tc are radioactive; Tm/Lu are scarce. The three explicit-append Au–Ga–Ln seeds +# (Gd/Tb/Dy) all fit in this palette. +ALLOY_PALETTE: list[str] = [ + "Mg", + "Ca", + "B", + "Al", + "Ga", + "In", + "Tl", + "Si", + "Ge", + "Sc", + "Ti", + "V", + "Cr", + "Mn", + "Fe", + "Co", + "Ni", + "Cu", + "Zn", + "Y", + "Zr", + "Nb", + "Mo", + "Ru", + "Rh", + "Pd", + "Ag", + "Cd", + # 5d transition metals (Hf–Pt). Added 2026-05 to extend the previous 41-element palette; + # placed between Cd (end of 5th-period TMs) and Au so the 6th-period TM block is contiguous. + "Hf", + "Ta", + "W", + "Re", + "Os", + "Ir", + "Pt", + "Au", + "La", + "Ce", + "Pr", + "Nd", + "Sm", + "Eu", + "Gd", + "Tb", + "Dy", + "Ho", + "Er", + "Yb", +] + +# Inverse-design comparison configurations, one row per box in ``comparison.png``. Mirrors the +# PR #18 demo's ``paper_inverse_comparison.py``: a 3-point ``ae_align_scale`` sweep on the latent +# side (failure α=0 / mid α=0.25 / max α=1.0) plus five composition configurations that layer +# blend, palette and diversity-scale knobs against a random-init control. The ``allowed`` field +# uses the sentinel ``"__palette__"`` to refer to ``config.inverse_composition_allowed_elements`` +# (the 48-element ``ALLOY_PALETTE`` by default); every other field is fixed at the module level so +# the comparison is a stable plan-§5 ablation across runs. +_PALETTE_SENTINEL = "__palette__" +INVERSE_PATH_CONFIGS: list[dict[str, Any]] = [ + {"key": "latent_align0p0", "label": "latent α=0", "method": "latent", "ae_align_scale": 0.0}, + {"key": "latent_align0p25", "label": "latent α=0.25", "method": "latent", "ae_align_scale": 0.25}, + {"key": "latent_align1p0", "label": "latent α=1", "method": "latent", "ae_align_scale": 1.0}, + { + "key": "comp_seed", + "label": "comp (seed)", + "method": "composition", + "init": "seed", + "blend": 1.0, + "allowed": "all", + "diversity": 1.0, + }, + { + "key": "comp_seed_blend", + "label": "comp (seed, 5% all)", + "method": "composition", + "init": "seed", + "blend": 0.95, + "allowed": "all", + "diversity": 1.0, + }, + { + "key": "comp_seed_blend_palette", + "label": "comp (seed, 5% all, element list)", + "method": "composition", + "init": "seed", + "blend": 0.95, + "allowed": _PALETTE_SENTINEL, + "diversity": 1.0, + }, + { + # Ablation: clamp diversity to 0 → max entropy penalty → forced peaky few-element recipes. + "key": "comp_seed_blend_palette_lowdiv", + "label": "comp (seed, 5% all, element list, low diversity)", + "method": "composition", + "init": "seed", + "blend": 0.95, + "allowed": _PALETTE_SENTINEL, + "diversity": 0.0, + }, + { + "key": "comp_random", + "label": "comp (random)", + "method": "composition", + "init": "random", + "blend": 0.95, + "allowed": "all", + "diversity": 1.0, + }, +] +INVERSE_PATHS: list[str] = [c["key"] for c in INVERSE_PATH_CONFIGS] +INVERSE_PATH_CONFIGS_BY_KEY: dict[str, dict[str, Any]] = {c["key"]: c for c in INVERSE_PATH_CONFIGS} + +# Per-regression-task panel title (units + arrow). Matches the demo's REG_TASK_TITLES so plots +# read the same across both runners. Falls back to the bare task name if a task isn't listed. +REG_TASK_TITLES: dict[str, str] = { + "formation_energy": "Formation energy [eV/atom] ↓", + "klat": "klat [W/mK] ↑", + "magnetic_moment": "Magnetic moment [μB/f.u.] ↑", + "tc": "Critical temperature [K] ↑", +} + + +def _seed_weights_from_compositions(seeds: list[str], n_components: int) -> torch.Tensor: + """Element-weight tensor ``(B, n_components)`` for seeding ``optimize_composition``. + + Order matches DEFAULT_ELEMENTS. Raises if any seed cannot be parsed — we fail fast rather than + silently dropping rows (callers rely on per-seed correspondence with the latent path). + """ + rows = [] + for c in seeds: + w = formula_to_composition(c) + if w is None: + raise ValueError(f"Cannot parse seed composition '{c}' to element weights.") + rows.append(np.asarray(w, dtype=np.float64)) + return torch.tensor(np.stack(rows), dtype=torch.float64) + + +def _format_weights(weights: np.ndarray, top_k: int = 6, eps: float = 1e-3) -> list[str]: + """Render element-weight rows as compact formula strings (top-K elements above ``eps``).""" + out: list[str] = [] + for row in weights: + order = np.argsort(row)[::-1] + parts = [f"{DEFAULT_ELEMENTS[i]}{row[i]:.3f}" for i in order[:top_k] if row[i] > eps] + out.append(" ".join(parts) if parts else "") + return out + + +def _display(task: str) -> str: + return TASK_DISPLAY.get(task, task.replace("_", " ").title()) + + +def _scale_label(task: str) -> str: + return "normalized" if TASK_SPECS[task]["source"] == "qc" else "log1p, z-scored" + + +def _title(task: str) -> str: + return f"{_display(task)} ({_scale_label(task)})" + + +def _arrow(value: float) -> str: + return "↓" if value < 0 else "↑" + + +@dataclass +class InverseScenario: + """One inverse-design objective set (primary = QC probability; secondary = regression targets).""" + + name: str + reg_tasks: list[str] + reg_targets: list[float] + + def __post_init__(self) -> None: + if len(self.reg_tasks) != len(self.reg_targets): + raise ValueError(f"Scenario '{self.name}': reg_tasks and reg_targets must have equal length.") + + +@dataclass +class ContinualRehearsalFullConfig: + """Configuration for the full continual rehearsal + inverse-design run.""" + + qc_data_path: Path = Path("data/qc_ac_te_mp_dos_reformat_20260515.pd.parquet") + qc_preprocessing_path: Path | None = None + superconductor_path: Path = Path("data/NEMAD_superconductor_20260425.parquet") + magnetic_path: Path = Path("data/NEMAD_magnetic_20260419.parquet") + phonix_path: Path = Path("data/phonix-db-filtered_20260425.parquet") + output_dir: Path = Path("artifacts/continual_rehearsal_full") + + task_sequence: list[str] = field(default_factory=lambda: list(DEFAULT_SEQUENCE)) + fixed_tail: list[str] = field(default_factory=lambda: list(DEFAULT_FIXED_TAIL)) + replay_ratio: float = 0.05 # ordinary old-task replay ratio + replay_ratio_high: float = 0.10 # replay ratio for fixed_tail tasks when replayed as old + sample_per_dataset: int | None = None # cap rows per dataset (for fast/smoke runs) + + max_epochs_per_step: int = 100 # ceiling; EarlyStopping usually stops sooner + early_stop_patience: int = 8 + early_stop_min_delta: float = 1e-4 + batch_size: int = 256 + num_workers: int = 0 + + n_grids: int = 8 + latent_dim: int = 128 + encoder_hidden: int = 256 + head_hidden_dim: int = 64 + head_lr: float = 5e-3 + encoder_lr: float = 5e-3 + n_kernel: int = 15 + kr_lr: float = 5e-4 + kr_decay: float = 5e-5 + + # Inverse design (shared across scenarios). Primary objective is QC probability ↑; each + # scenario runs the eight PR #18 paths (3 latent + 5 composition configs) — see plan §5. + inverse_n_seeds: int = 20 # 17 top-QC dedup + 3 explicit Au-Ga-Ln formers (plan §5) + inverse_steps: int = 300 + inverse_lr: float = 0.05 + inverse_class_weight: float = 5.0 + # 48-element ``ALLOY_PALETTE`` for the composition rows that whitelist elements. Configurable + # in case the slide author wants a wider or narrower palette; everything else (ae_align_scale + # sweep, seed_blend, diversity_scale) is fixed at the module level in ``INVERSE_PATH_CONFIGS`` + # so the comparison is a stable ablation across runs. + inverse_composition_allowed_elements: list[str] = field(default_factory=lambda: list(ALLOY_PALETTE)) + inverse_seed_strategy: str = "top_qc" # "top_qc" | "random" | "explicit" + # Held-out test split is the right default for the formal full run: the model has seen the + # train compositions during training, so its top-QC ranking there is part memorisation; test + # compositions are held out, so the ranking is a genuine prediction → seeds are real novel QC + # candidates. (Override to "train" only when reproducing the demo / paper baseline.) + inverse_seed_split: str = "test" # "train" | "val" | "test" | "all" + inverse_seed_compositions: list[str] = field(default_factory=list) + # Compositions appended to the strategy-selected seeds regardless of QC ranking. Each must + # have a computable descriptor (fail-fast in _select_seeds). The strategy budget is reduced + # by len(explicit_append) so total seeds == inverse_n_seeds. Defaults to the three Au-Ga-Ln + # i-QC formers used in plan §5 (Au65 Ga20 Gd/Tb/Dy15). + inverse_seed_explicit_append: list[str] = field( + default_factory=lambda: ["Au65 Ga20 Gd15", "Au65 Ga20 Tb15", "Au65 Ga20 Dy15"] + ) + inverse_scenarios: list[InverseScenario] = field( + default_factory=lambda: [ + InverseScenario("scenario1_fe_down_moment_up", ["formation_energy", "magnetic_moment"], [-2.0, 2.0]), + InverseScenario("scenario2_fe_tc_moment", ["formation_energy", "tc", "magnetic_moment"], [-2.0, 2.0, 2.0]), + InverseScenario("scenario3_fe_down_klat_up", ["formation_energy", "klat"], [-2.0, 2.0]), + ] + ) + + random_seed: int = 2025 + datamodule_random_seed: int = 42 + accelerator: str = "auto" + devices: int = 1 + + def __post_init__(self) -> None: + unknown = [t for t in self.task_sequence if t not in TASK_SPECS] + if unknown: + raise ValueError(f"Unknown task(s) {unknown}. Available: {sorted(TASK_SPECS)}") + if len(set(self.task_sequence)) != len(self.task_sequence): + raise ValueError("task_sequence contains duplicates.") + bad_tail = [t for t in self.fixed_tail if t not in self.task_sequence] + if bad_tail: + raise ValueError(f"fixed_tail tasks {bad_tail} are not in task_sequence.") + for ratio_name, ratio in (("replay_ratio", self.replay_ratio), ("replay_ratio_high", self.replay_ratio_high)): + if not 0.0 <= ratio <= 1.0: + raise ValueError(f"{ratio_name} must be in [0, 1].") + if not self.inverse_composition_allowed_elements: + raise ValueError("inverse_composition_allowed_elements must be non-empty.") + unknown_palette = [e for e in self.inverse_composition_allowed_elements if e not in DEFAULT_ELEMENTS] + if unknown_palette: + raise ValueError( + f"inverse_composition_allowed_elements contains symbols not in DEFAULT_ELEMENTS: {unknown_palette}" + ) + if self.inverse_seed_strategy not in {"top_qc", "random", "explicit"}: + raise ValueError("inverse_seed_strategy must be 'top_qc', 'random', or 'explicit'.") + if self.inverse_seed_split not in {"train", "val", "test", "all"}: + raise ValueError("inverse_seed_split must be 'train', 'val', 'test', or 'all'.") + if self.inverse_seed_strategy == "explicit" and not self.inverse_seed_compositions: + raise ValueError("inverse_seed_strategy='explicit' requires inverse_seed_compositions.") + # Every scenario's tasks must be regression tasks present in the sequence. + for sc in self.inverse_scenarios: + for t in sc.reg_tasks: + if t not in self.task_sequence: + raise ValueError(f"Scenario '{sc.name}': task '{t}' not in task_sequence.") + if TASK_SPECS[t]["kind"] != "reg": + raise ValueError(f"Scenario '{sc.name}': task '{t}' must be a (scalar) regression task.") + if "material_type" not in self.task_sequence: + raise ValueError("task_sequence must contain 'material_type' (QC classifier for inverse design).") + + +class _DropLastTrainCompoundDataModule(CompoundDataModule): + """``CompoundDataModule`` variant whose train loader sets ``drop_last=True``. + + PyTorch ``BatchNorm1d`` in training mode raises ``ValueError: Expected more than 1 value per + channel`` on a batch of size 1. With ``shuffle=True`` and ``drop_last=False`` (the upstream + default), any train subset whose size ``mod batch_size == 1`` will eventually feed that + single-row tail batch into the encoder's ``fc_layers`` BN and crash mid-epoch — exactly what + happened in the first attempted full-data MPS run (Step 1, ``density``). + + Dropping the final partial batch costs at most ``batch_size − 1`` rows per epoch (~256 / 35k + rows in the qc train split ≈ 0.7 %), which is well within the noise of the rehearsal mask. We + only touch the train loader; val / test / predict keep ``drop_last=False`` so every held-out + row is evaluated. ``_train_sampler`` (used only by the DDP path) is left untouched — we are + not using DDP here. + """ + + def train_dataloader(self): + base = super().train_dataloader() + if base is None: + return None + return DataLoader( + base.dataset, + batch_size=base.batch_size, + shuffle=True, + num_workers=base.num_workers, + pin_memory=base.pin_memory, + collate_fn=base.collate_fn, + drop_last=True, + ) + + +class ContinualRehearsalFullRunner: + def __init__(self, config: ContinualRehearsalFullConfig): + self.config = config + self.output_dir = Path(config.output_dir) + # Plan §6 layout: training/ for per-step artifacts (incl. final_model.pt and forgetting + # trajectory), inverse_design/ for the dual-path scenarios, slide-prep / analysis / readme + # at the top level. Subdirs are created lazily where needed. + self.training_dir = self.output_dir / "training" + self.inverse_root = self.output_dir / "inverse_design" + self.output_dir.mkdir(parents=True, exist_ok=True) + self.training_dir.mkdir(parents=True, exist_ok=True) + _apply_plot_style() + self._task_colors = {name: _PALETTE[i % len(_PALETTE)] for i, name in enumerate(config.task_sequence)} + self._kmd = KMD(element_features.values, method="1d", n_grids=config.n_grids, sigma="auto", scale=True) + self.x_dim = int(self._kmd.transform(np.eye(1, len(DEFAULT_ELEMENTS))).shape[1]) + self._desc_cache: dict[str, np.ndarray] = {} + self._load_data() + + # ------------------------------------------------------------------ data + + def _load_data(self) -> None: + cfg = self.config + rng = np.random.default_rng(cfg.datamodule_random_seed) + self.task_frames: dict[str, pd.DataFrame] = {} + split_by_key: dict[str, str] = {} + + sources = { + "qc": self._load_qc(), + "superconductor": pd.read_parquet(cfg.superconductor_path), + "magnetic": pd.read_parquet(cfg.magnetic_path), + "phonix": pd.read_parquet(cfg.phonix_path), + } + + keyed: dict[str, pd.DataFrame] = {} + for name, df in sources.items(): + df = df.copy() + if cfg.sample_per_dataset is not None and cfg.sample_per_dataset < len(df): + if name == "qc" and "Material type (label)" in df.columns: + df = self._stratified_qc_sample(df, cfg.sample_per_dataset, rng) + else: + df = df.iloc[rng.choice(len(df), size=cfg.sample_per_dataset, replace=False)] + df["__key__"] = [_composition_key(v) for v in df["composition"]] + df = df.dropna(subset=["__key__"]).drop_duplicates(subset="__key__", keep="first").set_index("__key__") + keyed[name] = df + if "split" in df.columns: + for k, s in df["split"].items(): + split_by_key.setdefault(str(k), str(s)) + else: + for k in df.index: + split_by_key.setdefault(str(k), rng.choice(["train", "val", "test"], p=[0.7, 0.15, 0.15])) + + for task_name in cfg.task_sequence: + spec = TASK_SPECS[task_name] + df = keyed[spec["source"]] + col = spec["column"] + if col not in df.columns: + raise KeyError(f"Task '{task_name}': column '{col}' missing in {spec['source']} data.") + frame = pd.DataFrame(index=df.index) + values = df[col] + if task_name == "material_type": + values = values.map(_MATERIAL_TYPE_MERGE) + if spec["source"] != "qc" and spec["kind"] == "reg": + v = np.log1p(df[col].astype(float).clip(lower=0.0)) + is_train = np.array([split_by_key.get(str(k)) == "train" for k in df.index]) + ref = v[is_train] if is_train.any() else v + mean = float(ref.mean()) + std = float(ref.std(ddof=0)) or 1.0 + values = ((v - mean) / std).clip(-_RAW_TARGET_CLIP, _RAW_TARGET_CLIP) + frame[col] = values + if spec["kind"] == "kr": + frame[spec["t_column"]] = df[spec["t_column"]] + frame["split"] = [split_by_key.get(str(k), "train") for k in frame.index] + self.task_frames[task_name] = frame + + self.split_by_key = split_by_key + n_keys = len(set().union(*[set(f.index) for f in self.task_frames.values()])) + logger.info(f"Built {len(self.task_frames)} task frames over {n_keys} unique compositions; x_dim={self.x_dim}.") + + def _load_qc(self) -> pd.DataFrame: + cfg = self.config + df = pd.read_parquet(cfg.qc_data_path) + if cfg.qc_preprocessing_path is not None and Path(cfg.qc_preprocessing_path).exists(): + dropped = joblib.load(cfg.qc_preprocessing_path).get("dropped_idx", []) + df = df.loc[~df.index.isin(dropped)] + return df + + @staticmethod + def _stratified_qc_sample(df: pd.DataFrame, cap: int, rng: np.random.Generator) -> pd.DataFrame: + """Cap qc rows while keeping every minority (non-"others") material-type row.""" + labels = df["Material type (label)"] + minority = df[labels != 4] + others = df[labels == 4] + n_others = max(cap - len(minority), 0) + if n_others < len(others): + others = others.iloc[rng.choice(len(others), size=n_others, replace=False)] + out = pd.concat([minority, others]) + if len(out) > cap: + out = out.iloc[rng.choice(len(out), size=cap, replace=False)] + return out + + def _class_weights(self, task_name: str) -> list[float]: + spec = TASK_SPECS[task_name] + frame = self.task_frames[task_name] + num_classes = int(spec["num_classes"]) + train = frame.loc[frame["split"] == "train", spec["column"]].dropna().astype(int) + counts = np.bincount(train, minlength=num_classes).astype(float) + counts[counts == 0] = 1.0 + weights = counts.sum() / (num_classes * counts) + return weights.tolist() + + def descriptor_fn(self, compositions: list[str]) -> pd.DataFrame: + uncached = [c for c in dict.fromkeys(compositions) if c not in self._desc_cache] + if uncached: + weights = np.zeros((len(uncached), len(DEFAULT_ELEMENTS)), dtype=float) + valid: list[str] = [] + for key in uncached: + try: + w = formula_to_composition(key) + except Exception: + w = None + if w is None or float(w.sum()) <= 0: + continue + weights[len(valid)] = w + valid.append(key) + if valid: + desc = self._kmd.transform(weights[: len(valid)]) + for j, key in enumerate(valid): + self._desc_cache[key] = desc[j] + present = [c for c in compositions if c in self._desc_cache] + if not present: + return pd.DataFrame() + return pd.DataFrame(np.stack([self._desc_cache[c] for c in present]), index=present) + + # ------------------------------------------------------------------ configs + + def _build_task_config(self, task_name: str): + cfg = self.config + spec = TASK_SPECS[task_name] + ld, hd = cfg.latent_dim, cfg.head_hidden_dim + if spec["kind"] == "reg": + return RegressionTaskConfig( + name=task_name, + data_column=spec["column"], + dims=[ld, hd, 1], + optimizer=OptimizerConfig(lr=cfg.head_lr, weight_decay=1e-5), + ) + if spec["kind"] == "clf": + return ClassificationTaskConfig( + name=task_name, + data_column=spec["column"], + dims=[ld, hd, 32], + num_classes=spec["num_classes"], + class_weights=self._class_weights(task_name), + optimizer=OptimizerConfig(lr=cfg.head_lr, weight_decay=1e-5), + ) + train_t = self._collect_train_t(task_name) + centers, sigmas = _init_kernels(train_t, cfg.n_kernel) + return KernelRegressionTaskConfig( + name=task_name, + data_column=spec["column"], + t_column=spec["t_column"], + x_dim=[ld, 128, 64], + t_dim=[16, 8], + kernel_num_centers=cfg.n_kernel, + kernel_centers_init=centers or None, + kernel_sigmas_init=sigmas or None, + kernel_learnable_centers=True, + kernel_learnable_sigmas=True, + enable_mu3=False, + optimizer=OptimizerConfig(lr=cfg.kr_lr, weight_decay=cfg.kr_decay), + ) + + def _collect_train_t(self, task_name: str) -> np.ndarray: + spec = TASK_SPECS[task_name] + frame = self.task_frames[task_name] + mask = frame[spec["column"]].notna() & (frame["split"] == "train") + cells = frame.loc[mask, spec["t_column"]].dropna() + if cells.empty: + return np.array([]) + return np.concatenate([_as_float_array(c) for c in cells]) + + # ------------------------------------------------------------------ run + + def _build_empty_model(self) -> FlexibleMultiTaskModel: + """The bare model used as the starting point for both ``run`` and ``run_inverse_only``.""" + cfg = self.config + encoder_config = MLPEncoderConfig(hidden_dims=[self.x_dim, cfg.encoder_hidden, cfg.latent_dim]) + return FlexibleMultiTaskModel( + task_configs=[], + encoder_config=encoder_config, + enable_autoencoder=True, + shared_block_optimizer=OptimizerConfig(lr=cfg.encoder_lr, weight_decay=1e-2), + ) + + def _build_full_model(self) -> FlexibleMultiTaskModel: + """Rebuild the post-training model (all tasks added in sequence order) so a saved + ``final_model.pt`` ``state_dict`` can be loaded for inverse-only runs.""" + model = self._build_empty_model() + for task_name in self.config.task_sequence: + model.add_task(self._build_task_config(task_name)) + return model + + def run( + self, + *, + record_trajectory: bool = True, + per_seed_trajectories: bool = False, + animation_formats: tuple[str, ...] = ("gif",), + ) -> None: + cfg = self.config + seed_everything(cfg.random_seed, workers=True) + model = self._build_empty_model() + + task_configs: dict[str, Any] = {} + metric_history: dict[str, list[tuple[int, float]]] = {name: [] for name in cfg.task_sequence} + records: list[dict[str, Any]] = [] + fixed_tail = set(cfg.fixed_tail) + + for step, task_name in enumerate(cfg.task_sequence): + logger.info(f"=== Step {step + 1}/{len(cfg.task_sequence)}: add task '{task_name}' ===") + task_configs[task_name] = self._build_task_config(task_name) + model.add_task(task_configs[task_name]) + + active = cfg.task_sequence[: step + 1] + # New task fully active; old tasks replayed — fixed-tail tasks at the higher ratio. + for name in active: + if name == task_name: + ratio = 1.0 + elif name in fixed_tail: + ratio = cfg.replay_ratio_high + else: + ratio = cfg.replay_ratio + task_configs[name].task_masking_ratio = ratio + + datamodule = _DropLastTrainCompoundDataModule( + task_configs=[task_configs[name] for name in active], + descriptor_fn=self.descriptor_fn, + task_frames={name: self.task_frames[name] for name in active}, + composition_column="composition", + random_seed=cfg.datamodule_random_seed, + batch_size=cfg.batch_size, + num_workers=cfg.num_workers, + ) + callbacks: list[Callback] = [ + EarlyStopping( + monitor="val_final_loss", + mode="min", + patience=cfg.early_stop_patience, + min_delta=cfg.early_stop_min_delta, + ) + ] + trainer = Trainer( + max_epochs=cfg.max_epochs_per_step, + accelerator=cfg.accelerator, + devices=cfg.devices, + logger=False, + enable_checkpointing=False, + enable_progress_bar=False, + callbacks=callbacks, + ) + trainer.fit(model, datamodule=datamodule) + + test_keys: set[str] | None = None + if datamodule.split_series is not None: + resolved = datamodule.split_series + test_keys = set(resolved.index[resolved == "test"].astype(str)) + + step_dir = self.training_dir / f"step{step + 1:02d}_{task_name}" + step_dir.mkdir(parents=True, exist_ok=True) + step_metrics: dict[str, dict[str, float]] = {} + for name in active: + # Plot only the freshly-added head; dump raw (composition, true, pred) + per-task + # metrics.json for every active head so the forgetting trajectory is backed by + # raw data and per-task numbers at each stage. + metric = self._evaluate_task(model, name, step_dir, is_new=(name == task_name), test_keys=test_keys) + step_metrics[name] = metric + metric_history[name].append((step + 1, metric["primary"])) + # Per-step model checkpoint (mirrors the demo, PR #18). Lets analysts revisit any + # intermediate stage ("what did the encoder look like just after task K was added?") + # without retraining the prefix, and feeds downstream finetune scripts. + step_ckpt = step_dir / "checkpoint.pt" + torch.save( + { + "model": model.state_dict(), + "task_sequence": list(cfg.task_sequence), + "step": step + 1, + "new_task": task_name, + "active_tasks": list(active), + }, + step_ckpt, + ) + records.append( + {"step": step + 1, "new_task": task_name, "epochs_run": trainer.current_epoch, "metrics": step_metrics} + ) + summary = ", ".join(f"{k}={v['primary']:.3f}" for k, v in step_metrics.items()) + rel_ckpt = step_ckpt.relative_to(self.output_dir) + logger.info(f"Step {step + 1} ({trainer.current_epoch} epochs): {summary} (ckpt: {rel_ckpt})") + + self._plot_forgetting(metric_history) + (self.training_dir / "experiment_records.json").write_text(json.dumps(records, indent=2), encoding="utf-8") + self._write_metrics_table(records) + self._save_final_model(model, task_configs) + + inverse = self._inverse_design( + model, + record_trajectory=record_trajectory, + per_seed_trajectories=per_seed_trajectories, + animation_formats=animation_formats, + ) + (self.inverse_root / "inverse_design.json").write_text(json.dumps(inverse, indent=2), encoding="utf-8") + + # Slide-prep deliverables (plan §6) — no more PPT/HTML; the slide author works from + # SLIDE_PREP.md + the raw arrays + the standard image set. The three scenarios are + # treated as equal first-class results — no demo-style "headline scenario" promotion. + self._write_inverse_summary_md(inverse) + self._write_analysis_md(records, inverse) + self._write_slide_prep_md(records, inverse) + self._write_readme(records, inverse) + logger.info(f"Done. Outputs in {self.output_dir}") + + def _save_final_model(self, model, task_configs: dict[str, Any]) -> None: + # Schema matches the demo's ``final_model.pt`` (PR #18) so the same downstream consumers — + # ``paper_inverse_comparison.py`` / ``finetune_inverse_heads.py`` / ``--inverse-only`` — + # can ingest checkpoints from either runner without translation. + ckpt = self.training_dir / "final_model.pt" + torch.save({"model": model.state_dict(), "task_sequence": list(self.config.task_sequence)}, ckpt) + spec_dump = { + name: { + "kind": TASK_SPECS[name]["kind"], + "column": TASK_SPECS[name]["column"], + "source": TASK_SPECS[name]["source"], + } + for name in self.config.task_sequence + } + (self.training_dir / "final_model_taskconfigs.json").write_text( + json.dumps(spec_dump, indent=2), encoding="utf-8" + ) + logger.info(f"Saved final model checkpoint to {ckpt}") + + def run_inverse_only( + self, + ckpt_path: Path, + *, + record_trajectory: bool = True, + per_seed_trajectories: bool = False, + animation_formats: tuple[str, ...] = ("gif",), + ) -> None: + """Skip training; load a saved ``final_model.pt`` and run only the inverse-design stage. + + Use this to iterate on inverse-design knobs (seed split, palette, scenarios, …) without + repeating the multi-hour training. Data loading + descriptor computation still happen — + they're prerequisites for seed selection and the composition-path kernel — but no + ``Trainer.fit`` is called. + + After the inverse-design pass we also **refresh the slide-prep deliverables** + (``ANALYSIS.md`` / ``SLIDE_PREP.md`` / ``README.md``) by loading the previous run's + ``training/experiment_records.json`` — without that, those documents would still quote + the inverse-design numbers from the previous pass. The training-derived sections + (forgetting trajectory, headline-task R² / accuracy) come from ``records`` unchanged. + If the records file is missing (e.g. inverse-only against a checkpoint from a different + run that didn't expose it), the deliverables are skipped with a warning. + """ + logger.info(f"=== Inverse-only mode: loading model checkpoint {ckpt_path} ===") + seed_everything(self.config.random_seed, workers=True) + model = self._build_full_model() + state = torch.load(ckpt_path, map_location="cpu", weights_only=True) + state_dict = state["model"] if isinstance(state, dict) and "model" in state else state + model.load_state_dict(state_dict) + model.eval() + inverse = self._inverse_design( + model, + record_trajectory=record_trajectory, + per_seed_trajectories=per_seed_trajectories, + animation_formats=animation_formats, + ) + (self.inverse_root / "inverse_design.json").write_text(json.dumps(inverse, indent=2), encoding="utf-8") + self._write_inverse_summary_md(inverse) + + # Refresh the slide-prep deliverables so their inverse-design tables / seed lists match + # the values we just re-ran. The training records live next to the checkpoint. + records_path = self.training_dir / "experiment_records.json" + if records_path.exists(): + records = json.loads(records_path.read_text(encoding="utf-8")) + self._write_analysis_md(records, inverse) + self._write_slide_prep_md(records, inverse) + self._write_readme(records, inverse) + logger.info(f"Refreshed ANALYSIS.md / SLIDE_PREP.md / README.md from {records_path}") + else: + logger.warning( + f"{records_path} not found — keeping previous ANALYSIS.md / SLIDE_PREP.md / " + "README.md unchanged. Inverse-design numbers in those docs may be stale." + ) + logger.info(f"Inverse-only done. Outputs in {self.output_dir}") + + def _write_metrics_table(self, records: list[dict[str, Any]]) -> None: + final = records[-1]["metrics"] if records else {} + intro = {r["new_task"]: r["metrics"][r["new_task"]] for r in records} + rows = [] + for task in self.config.task_sequence: + spec = TASK_SPECS[task] + metric_name = "accuracy" if spec["kind"] == "clf" else "R2" + rows.append( + { + "task": task, + "display": _display(task), + "type": KIND_LABEL[spec["kind"]], + "dataset": SOURCE_DISPLAY[spec["source"]], + "metric": metric_name, + "at_intro": intro.get(task, {}).get("primary", float("nan")), + "final": final.get(task, {}).get("primary", float("nan")), + "final_mae": final.get(task, {}).get("mae", float("nan")), + "samples": final.get(task, {}).get("samples", 0), + } + ) + pd.DataFrame(rows).to_csv(self.training_dir / "metrics_table.csv", index=False) + + # ------------------------------------------------------------------ eval + + def _test_rows(self, task_name: str, test_keys: set[str] | None = None) -> list[str]: + spec = TASK_SPECS[task_name] + frame = self.task_frames[task_name] + mask = frame[spec["column"]].notna() + mask &= frame.index.isin(test_keys) if test_keys is not None else (frame["split"] == "test") + return list(frame.index[mask]) + + def _descriptor_tensor(self, comps: list[str], device) -> tuple[torch.Tensor, list[str]]: + desc = self.descriptor_fn(comps) + comps = [c for c in comps if c in desc.index] + return torch.tensor(desc.loc[comps].values, dtype=torch.float32, device=device), comps + + def _evaluate_task(self, model, task_name, step_dir, *, is_new, test_keys=None) -> dict[str, float]: + spec = TASK_SPECS[task_name] + kind = spec["kind"] + model.eval() + device = next(model.parameters()).device + comps = self._test_rows(task_name, test_keys) + if not comps: + return {"primary": float("nan"), "samples": 0} + frame = self.task_frames[task_name] + head = model.task_heads[task_name] + + with torch.no_grad(): + if kind in ("reg", "clf"): + x, comps = self._descriptor_tensor(comps, device) + if not comps: + return {"primary": float("nan"), "samples": 0} + h = torch.tanh(model.encoder(x)) + if kind == "reg": + pred = head(h).squeeze(-1).cpu().numpy() + true = frame.loc[comps, spec["column"]].astype(float).to_numpy() + r2 = float(r2_score(true, pred)) + metric = { + "r2": r2, + "mae": float(mean_absolute_error(true, pred)), + "samples": len(comps), + "primary": r2, + } + dump_predictions(task_name, step_dir, comps=list(comps), true=true, pred=pred) + dump_metrics(task_name, step_dir, metric) + if is_new: + plot_parity(true, pred, task_name, r2, step_dir, title=_title(task_name)) + return metric + logits = head(h) + pred = logits.argmax(dim=-1).cpu().numpy() + true = frame.loc[comps, spec["column"]].astype(int).to_numpy() + acc = float(accuracy_score(true, pred)) + metric = { + "accuracy": acc, + "macro_f1": float(f1_score(true, pred, average="macro", zero_division=0)), + "samples": len(comps), + "primary": acc, + } + dump_predictions(task_name, step_dir, comps=list(comps), true=true, pred=pred) + dump_metrics(task_name, step_dir, metric) + if is_new: + plot_confusion( + true, + pred, + task_name, + acc, + step_dir, + spec["num_classes"], + title=_display(task_name), + special_material_type=(task_name == "material_type"), + ) + return metric + + # kernel regression + keep, t_list, true_parts = [], [], [] + for comp in comps: + if comp not in self._desc_cache and self.descriptor_fn([comp]).empty: + continue + y_arr = _as_float_array(frame.at[comp, spec["column"]]) + t_arr = _as_float_array(frame.at[comp, spec["t_column"]]) + if y_arr.size == 0 or y_arr.size != t_arr.size: + continue + keep.append(comp) + t_list.append(torch.tensor(t_arr, dtype=torch.float32, device=device)) + true_parts.append(y_arr) + if not keep: + return {"primary": float("nan"), "samples": 0} + xk, _ = self._descriptor_tensor(keep, device) + h_k = torch.tanh(model.encoder(xk)) + expanded_h, expanded_t = model._expand_for_kernel_regression(h_k, t_list) + pred = head(expanded_h, t=expanded_t).squeeze(-1).cpu().numpy() + true = np.concatenate(true_parts) + r2 = float(r2_score(true, pred)) + metric = { + "r2": r2, + "mae": float(mean_absolute_error(true, pred)), + "samples": len(keep), + "points": int(true.size), + "primary": r2, + } + dump_kr_predictions( + task_name, + step_dir, + comps=keep, + t_list=[t.cpu().numpy() for t in t_list], + true_parts=true_parts, + pred=pred, + ) + dump_metrics(task_name, step_dir, metric) + if is_new: + plot_kr_sequences(keep, t_list, true_parts, pred, task_name, step_dir, title=_title(task_name)) + return metric + + # --- per-task artifact dump helpers -------------------------------------- + # ``dump_predictions`` / ``dump_kr_predictions`` / ``dump_metrics`` now live in + # :mod:`continual_rehearsal_common`; imported at the top of this file and called inline + # in ``_evaluate_task``. The bound-method versions were verbatim copies of demo's and + # caused drift (PR #18 code review). + + # ------------------------------------------------------------------ inverse design + + @staticmethod + def _element_system(composition: str) -> frozenset[str]: + """Element symbols (no amounts) in a composition string — used for system-level dedup.""" + return frozenset(re.findall(r"[A-Z][a-z]?", composition)) + + @classmethod + def _dedupe_by_element_system(cls, candidates: list[str], n: int) -> list[str]: + """Walk ``candidates`` in order, keep the first occurrence of each element set, cap at ``n``. + + Empty / malformed compositions (those that parse to an empty element-set) are silently + skipped so a bad row in the source dataframe doesn't blow up the seed picker — matches + the demo runner's behaviour at ``continual_rehearsal_demo._dedupe_by_element_system`` + (the two used to differ; aligning them prevents drift when this gets shared into + ``continual_rehearsal_common``). + """ + seen: set[frozenset[str]] = set() + out: list[str] = [] + for comp in candidates: + key = cls._element_system(comp) + if not key or key in seen: + continue + seen.add(key) + out.append(comp) + if len(out) >= n: + break + return out + + def _select_seeds(self, model, device, qc_prob_fn) -> dict[str, list[str]]: + """Pick seed compositions for inverse design (mirrors demo's PR #18 behaviour). + + Returns ``{"strategy_seeds": […], "explicit_seeds": […]}``. Element-system dedup keeps the + best representative per element set (so 17 strategy seeds = 17 distinct alloy families, + not 17 ratio variants of three). ``inverse_seed_explicit_append`` is fail-fast validated + (each appended composition must have a computable descriptor) and the strategy budget is + reduced by its length so the total length equals ``inverse_n_seeds``. + """ + cfg = self.config + n = cfg.inverse_n_seeds + + # Pre-validate the explicit-append seeds so we fail fast on bad input. + appended: list[str] = [] + for raw in cfg.inverse_seed_explicit_append: + norm = normalize_composition(raw) or str(raw) + if norm not in self._desc_cache and self.descriptor_fn([norm]).empty: + raise ValueError( + f"inverse_seed_explicit_append entry {raw!r} has no computable descriptor " + "(check the formula and that all elements are in DEFAULT_ELEMENTS)." + ) + appended.append(norm) + # Dedup the appended list itself (in case the user listed near-duplicates). + appended = self._dedupe_by_element_system(appended, len(appended)) + n_strategy = max(0, n - len(appended)) + + def _finalise(strategy_seeds: list[str]) -> dict[str, list[str]]: + """Combine strategy seeds + explicit-append, skipping any duplicate element systems.""" + seen_keys = {self._element_system(c) for c in appended} + kept_strategy = [c for c in strategy_seeds if self._element_system(c) not in seen_keys][:n_strategy] + return {"strategy_seeds": kept_strategy, "explicit_seeds": appended} + + if cfg.inverse_seed_strategy == "explicit": + seeds = [normalize_composition(c) or str(c) for c in cfg.inverse_seed_compositions] + seeds = [c for c in seeds if c in self._desc_cache or not self.descriptor_fn([c]).empty] + return _finalise(self._dedupe_by_element_system(seeds, n_strategy)) + + # Candidate pool: chosen split of the material_type frame, with a valid descriptor. + frame = self.task_frames["material_type"] + index = ( + frame.index if cfg.inverse_seed_split == "all" else frame.index[frame["split"] == cfg.inverse_seed_split] + ) + pool = [c for c in index if c in self._desc_cache or not self.descriptor_fn([c]).empty] + if not pool: + return {"strategy_seeds": [], "explicit_seeds": appended} + + if cfg.inverse_seed_strategy == "random": + rng = np.random.default_rng(cfg.random_seed) + shuffled = [pool[i] for i in rng.permutation(len(pool))] + return _finalise(self._dedupe_by_element_system(shuffled, n_strategy)) + + # "top_qc": highest predicted QC probability, then element-system dedup. + x, pool = self._descriptor_tensor(pool, device) + probs = qc_prob_fn(x) + ranked = [pool[i] for i in np.argsort(probs)[::-1]] + return _finalise(self._dedupe_by_element_system(ranked, n_strategy)) + + def _decode_compositions_from_descriptor(self, descriptors: np.ndarray) -> list[str]: + """Latent-path composition output: AE-decoded descriptor → KMD.inverse → formula string.""" + try: + weights = self._kmd.inverse(descriptors) + except Exception as exc: # pragma: no cover - QP edge cases + logger.warning(f"KMD.inverse failed ({exc}); skipping composition decoding.") + return [""] * descriptors.shape[0] + return _format_weights(weights) + + def _inverse_design( + self, + model, + *, + record_trajectory: bool = False, + per_seed_trajectories: bool = False, + animation_formats: tuple[str, ...] = ("gif",), + ) -> dict[str, Any]: + """Run the 8 inverse-design configurations against each scenario on the same seeds. + + The configurations are defined at module level in :data:`INVERSE_PATH_CONFIGS`, mirroring + the demo's ``paper_inverse_comparison.py``: + + * **latent** (3 rows): ``optimize_latent`` with ``ae_align_scale ∈ {0.0, 0.25, 1.0}`` + (failure / mid / max alignment). + * **composition** (5 rows): ``optimize_composition`` with seed_blend / palette / diversity + knobs swept — strict seed, blended seed, blended + palette, blended + palette + low + diversity, and random init (no seed) as the no-seed-bias control. + + Saves per-path JSON + plot under ``inverse_design///`` plus a per-scenario + ``summary.json`` aggregating headline stats, and a top-level ``seeds.json`` recording the + strategy- vs explicit-appended seed split. + + When ``record_trajectory`` is set we additionally emit per-step trajectory artefacts + (``trajectories/.npz`` + ``trajectories/trajectory__.{png,gif,…}``) per + scenario, using ``paper_inverse_comparison._emit_trajectory_outputs`` so the figures + match the demo verbatim. ``animation_formats`` controls the animation outputs; pass + ``("none",)`` to skip animations (the static plot still appears). ``per_seed_trajectories`` + additionally emits one plot+animation per ``(path × seed)``. + """ + cfg = self.config + device, dtype = next(model.parameters()).device, next(model.parameters()).dtype + model.eval() + inv_root = self.output_dir / "inverse_design" + inv_root.mkdir(parents=True, exist_ok=True) + + def _qc_prob(x: torch.Tensor) -> np.ndarray: + with torch.no_grad(): + h = torch.tanh(model.encoder(x)) + probs = torch.softmax(model.task_heads["material_type"](h), dim=-1) + return probs[:, QC_CLASSES].sum(dim=-1).cpu().numpy() + + def _reg_preds(x: torch.Tensor, tasks: list[str]) -> dict[str, np.ndarray]: + with torch.no_grad(): + h = torch.tanh(model.encoder(x)) + return {t: model.task_heads[t](h).squeeze(-1).cpu().numpy() for t in tasks} + + # Same seeds for every scenario, so all eight paths are directly comparable. + seed_split = self._select_seeds(model, device, _qc_prob) + seeds_all = seed_split["strategy_seeds"] + seed_split["explicit_seeds"] + if not seeds_all: + logger.warning("No seeds available for inverse design.") + return {} + x_seed, seeds = self._descriptor_tensor(seeds_all, device) + if not seeds: + logger.warning("No seeds have computable descriptors; aborting inverse design.") + return {} + + # Composition path shares: kernel + per-seed initial weight tensor (B, n_components). + kmd_kernel = self._kmd.kernel_torch(device=device, dtype=dtype) + w_seed = _seed_weights_from_compositions(seeds, n_components=len(DEFAULT_ELEMENTS)).to( + device=device, dtype=dtype + ) + + # Top-level seeds.json with the strategy / explicit split (single source of truth across + # all scenarios). Per-path subdirs record their own ``seeds`` field for completeness. + seeds_meta = { + "strategy": cfg.inverse_seed_strategy, + "strategy_split": cfg.inverse_seed_split, + "n_target": cfg.inverse_n_seeds, + "n_used": len(seeds), + "strategy_seeds": [c for c in seed_split["strategy_seeds"] if c in seeds], + "explicit_seeds": [c for c in seed_split["explicit_seeds"] if c in seeds], + "all_seeds_used": seeds, + } + (inv_root / "seeds.json").write_text(json.dumps(seeds_meta, indent=2), encoding="utf-8") + + # The shared ``plot_element_frequency_heatmap`` reads the seed list directly so it can + # mark x-tick labels that are absent from every seed as "discovered" — we no longer + # need to pre-compute a seed_element_pool here. + + out: dict[str, Any] = {"seeds": seeds_meta, "scenarios": {}} + for sc in cfg.inverse_scenarios: + logger.info(f"=== Inverse design [{sc.name}]: targets={dict(zip(sc.reg_tasks, sc.reg_targets))} ===") + sc_dir = inv_root / sc.name + sc_dir.mkdir(parents=True, exist_ok=True) + reg_targets = {t: v for t, v in zip(sc.reg_tasks, sc.reg_targets)} + + # Per-scenario targets.json (plan §5) — separate from results so a slide author can + # quote the objective without parsing the full result dump. + (sc_dir / "targets.json").write_text( + json.dumps( + { + "name": sc.name, + "primary": {"task": "material_type", "class_indices": QC_CLASSES, "direction": "max"}, + "secondary": [ + {"task": t, "target": v, "direction": "min" if v < 0 else "max"} + for t, v in reg_targets.items() + ], + }, + indent=2, + ), + encoding="utf-8", + ) + + before_qc = _qc_prob(x_seed) + before_reg = _reg_preds(x_seed, sc.reg_tasks) + + paths: dict[str, dict[str, Any]] = {} + for path_cfg in INVERSE_PATH_CONFIGS: + key = path_cfg["key"] + path_dir = sc_dir / key + if path_cfg["method"] == "latent": + paths[key] = self._run_latent_path( + model, + x_seed, + seeds, + reg_targets, + path_dir, + ae_align_scale=path_cfg["ae_align_scale"], + label=path_cfg["label"], + _qc_prob_fn=_qc_prob, + _reg_preds_fn=_reg_preds, + record_trajectory=record_trajectory, + ) + else: + # Composition row: resolve the palette sentinel and seed/random init. + allowed = ( + list(cfg.inverse_composition_allowed_elements) + if path_cfg["allowed"] == _PALETTE_SENTINEL + else path_cfg["allowed"] + ) + init = path_cfg["init"] + paths[key] = self._run_composition_path( + model, + kmd_kernel, + w_seed if init == "seed" else None, + seeds, + reg_targets, + path_dir, + init=init, + blend=path_cfg["blend"] if init == "seed" else None, + allowed=allowed, + diversity=path_cfg["diversity"], + label=path_cfg["label"], + _qc_prob_fn=_qc_prob, + _reg_preds_fn=_reg_preds, + record_trajectory=record_trajectory, + ) + + scenario_summary = { + "name": sc.name, + "reg_targets": reg_targets, + "n_seeds": len(seeds), + "qc_before_mean": float(before_qc.mean()), + "paths": { + path_name: { + "qc_after_mean": float(np.mean(p["qc_after_decode"])), + "qc_after_std": float(np.std(p["qc_after_decode"])), + "reg_after_decode_mean": {t: float(np.mean(p["reg_after_decode"][t])) for t in reg_targets}, + "reg_after_decode_std": {t: float(np.std(p["reg_after_decode"][t])) for t in reg_targets}, + } + for path_name, p in paths.items() + }, + } + (sc_dir / "summary.json").write_text(json.dumps(scenario_summary, indent=2), encoding="utf-8") + self._plot_inverse_scenario(sc, before_qc, before_reg, paths, reg_targets, sc_dir) + # Shared heatmap: pass per-path ``label`` + ``decoded_composition`` lists so the + # x-tick / colourbar / title styling matches the demo's paper_inverse_comparison. + heatmap_methods = [ + { + "label": INVERSE_PATH_CONFIGS_BY_KEY[key]["label"], + "decoded_composition": p.get("decoded_composition", []) or [], + } + for key, p in paths.items() + if key in INVERSE_PATH_CONFIGS_BY_KEY + ] + plot_element_frequency_heatmap(heatmap_methods, list(seeds), sc_dir / "element_frequency_heatmap.png") + + # ── per-scenario figures copied from the demo's ``paper_inverse_comparison.py`` ── + # The runner used to emit only the (boxplot) ``comparison.png`` and the + # ``element_frequency_heatmap.png``; the per-seed scatter and 1:1 mapping figures + # lived only in the demo. We import and call the demo's helpers directly so the + # two surfaces never drift on plot style or legend ordering. Inputs are built once + # per scenario from the same ``paths`` dict the existing plotters consume — no extra + # forward passes, no training touch-up. + results_for_demo_helpers = [ + { + "method": paths[c["key"]]["method"], + "label": paths[c["key"]]["label"], + "qc_after_decode": paths[c["key"]]["qc_after_decode"], + "reg_after_decode": paths[c["key"]]["reg_after_decode"], + # ``_plot_seed_to_optimized_mapping`` doesn't need these but the scatter + # helper's legend grouping reads ``method``; carry them anyway so a future + # change picking up ``align_scale`` doesn't break silently. + "align_scale": paths[c["key"]].get("ae_align_scale"), + "decoded_composition": paths[c["key"]].get("decoded_composition", []), + } + for c in INVERSE_PATH_CONFIGS + if c["key"] in paths + ] + _plot_qc_vs_reg_scatter( + results_for_demo_helpers, + reg_targets, + sc_dir / "qc_vs_secondary_scatter.png", + title=f"QC probability vs secondary properties · {sc.name}", + seed_qc=before_qc, + seed_reg=before_reg, + ) + # Per-path seed → optimised composition mapping. Skip ``comp_random`` (no per-row + # seed correspondence — its ``seeds`` field is a ``random_start_N`` placeholder). + for c in INVERSE_PATH_CONFIGS: + key = c["key"] + if key not in paths or key == "comp_random": + continue + p = paths[key] + decoded = p.get("decoded_composition", []) + if not decoded: + continue + _plot_seed_to_optimized_mapping( + seeds=list(seeds), + decoded=list(decoded), + out_path=sc_dir / f"seed_to_optimized__{key}.png", + title=f"Seed → optimised composition · {c['label']}", + seed_qc=before_qc, + seed_reg=before_reg, + optimized_qc=np.asarray(p["qc_after_decode"]), + optimized_reg={t: np.asarray(p["reg_after_decode"][t]) for t in reg_targets}, + reg_targets=reg_targets, + ) + + # ── trajectory persistence + figures ── + # When ``record_trajectory`` is on, every path's ``_run_*_path`` returned a result + # carrying ``trajectory_targets`` (steps, B, T) and ``trajectory_weights`` + # (steps, B, n_components). For a 300-step / B=20 / 94-component run those arrays + # together weigh ~3 MB per path × 8 paths × 3 scenarios ≈ 72 MB — too heavy to inline + # into ``inverse_design.json``. Persist as compressed npz next to each scenario's + # plots, then pop the inline arrays so the json stays browsable. Filenames use + # ``paper_inverse_comparison._path_slug`` so the demo's trajectory consumers can + # ingest these files directly. + if record_trajectory: + traj_dir = sc_dir / "trajectories" + traj_dir.mkdir(exist_ok=True) + results_for_traj: list[dict[str, Any]] = [] + for key, p in paths.items(): + if "trajectory_targets" not in p or "trajectory_weights" not in p: + continue + # ``_path_slug`` reads ``method``, ``label``, and (for latent) ``align_scale``. + # Our latent rows store ``ae_align_scale``; mirror it onto ``align_scale`` for + # the slug call (and so the demo's ``_emit_trajectory_outputs`` can group + # latents by α). + slug_record: dict[str, Any] = { + "method": p["method"], + "label": p["label"], + "align_scale": p.get("ae_align_scale"), + } + slug = _path_slug(slug_record) + npz_path = traj_dir / f"{slug}.npz" + np.savez_compressed( + npz_path, + targets=np.asarray(p["trajectory_targets"], dtype=np.float32), + weights=np.asarray(p["trajectory_weights"], dtype=np.float32), + ) + # Drop the huge arrays now that they live on disk; carry a reference in their + # place so ``inverse_design.json`` consumers can find them. + p.pop("trajectory_targets", None) + p.pop("trajectory_weights", None) + p["trajectory_file"] = str(npz_path.relative_to(sc_dir)) + # ``_emit_trajectory_outputs`` reads the npz via ``out_dir / r["trajectory_file"]``, + # so the result dict here has to use the *scenario-relative* path too. + results_for_traj.append( + { + **slug_record, + "qc_after_decode": p["qc_after_decode"], + "reg_after_decode": p["reg_after_decode"], + "trajectory_file": p["trajectory_file"], + } + ) + if results_for_traj: + _emit_trajectory_outputs( + results=results_for_traj, + reg_targets=reg_targets, + seed_qc=before_qc, + seed_reg=before_reg, + out_dir=sc_dir, + traj_dir=traj_dir, + per_seed=per_seed_trajectories, + animation_formats=animation_formats, + ) + + # Explicit guard: ``list and float`` was a clever but fragile non-empty check — + # an empty ``qc_after_decode`` (no successful seeds for a path) returned the empty + # list, which then crashed ``f"{...:.3f}"`` with ``TypeError`` on format. NaN keeps + # the join uniform and is the natural "no data" sentinel for downstream readers. + def _qc_mean(path_name: str) -> float: + qc = paths[path_name].get("qc_after_decode") or [] + return float(np.mean(qc)) if qc else float("nan") + + qc_summary = " · ".join(f"{name}={_qc_mean(name):.3f}" for name in INVERSE_PATHS) + logger.info(f"[{sc.name}] QC after-decode mean — {qc_summary}") + + out["scenarios"][sc.name] = {**scenario_summary, "paths_details": paths} + return out + + # --- inverse path runners ------------------------------------------------- + + def _run_latent_path( + self, + model, + x_seed: torch.Tensor, + seeds: list[str], + reg_targets: dict[str, float], + path_dir: Path, + *, + ae_align_scale: float, + label: str, + _qc_prob_fn, + _reg_preds_fn, + record_trajectory: bool = False, + ) -> dict[str, Any]: + """Latent-space optimisation with cycle-consistency at a fixed ``ae_align_scale``. + + When ``record_trajectory`` is set we (a) ask ``optimize_latent`` to keep its per-step + AE-decoded input, and (b) decode each step through ``KMD.inverse`` to recover the per-step + composition recipe — same trick the demo's ``_run_latent_method`` uses, so the trajectory + is on the same surface as the final ``reg_after_decode`` values. The huge ``(steps, B, *)`` + arrays land in ``result["trajectory_targets"]`` / ``result["trajectory_weights"]``; the + caller is responsible for persisting them as a compressed npz and popping them off the + result dict so they don't bloat ``inverse_design.json``. + """ + cfg = self.config + path_dir.mkdir(parents=True, exist_ok=True) + reg_names = list(reg_targets) + + before_qc = _qc_prob_fn(x_seed) + before_reg = _reg_preds_fn(x_seed, reg_names) + + res = model.optimize_latent( + initial_input=x_seed, + task_targets=reg_targets, + class_targets={"material_type": QC_CLASSES}, + class_target_weight=cfg.inverse_class_weight, + ae_align_scale=ae_align_scale, + optimize_space="latent", + steps=cfg.inverse_steps, + lr=cfg.inverse_lr, + record_input_trajectory=record_trajectory, + ) + achieved_latent = res.optimized_target[:, 0, :].cpu().numpy() + optimized_desc = res.optimized_input[:, 0, :] + optimized_desc_np = optimized_desc.detach().cpu().numpy() + after_qc = _qc_prob_fn(optimized_desc) + after_reg = _reg_preds_fn(optimized_desc, reg_names) + try: + optimized_weights = self._kmd.inverse(optimized_desc_np) + except Exception as exc: # pragma: no cover + logger.warning(f"KMD.inverse failed for latent path ({exc}); weights left empty.") + optimized_weights = np.zeros((optimized_desc_np.shape[0], len(DEFAULT_ELEMENTS))) + decoded = _format_weights(optimized_weights) + + result = { + "method": "latent", + "label": label, + "ae_align_scale": ae_align_scale, + "seeds": list(seeds), + "qc_before": before_qc.tolist(), + "qc_after_decode": after_qc.tolist(), + "reg_before": {t: before_reg[t].tolist() for t in reg_names}, + "reg_achieved_latent": {t: achieved_latent[:, j].tolist() for j, t in enumerate(reg_names)}, + "reg_after_decode": {t: after_reg[t].tolist() for t in reg_names}, + "decoded_composition": decoded, + "optimized_descriptor": optimized_desc_np.tolist(), + "optimized_weights": optimized_weights.tolist(), + } + # Trajectory arrays (kept out of result.json — caller persists them as a separate npz). + if record_trajectory and res.input_trajectory is not None and res.trajectory is not None: + # ``res.trajectory`` is (B, R=1, steps, T) — squeeze restart, permute to (steps, B, T). + result["trajectory_targets"] = res.trajectory[:, 0, :, :].cpu().numpy().transpose(1, 0, 2) + # ``res.input_trajectory`` is (B, R=1, steps, input_dim) → (steps, B, input_dim); + # ``KMD.inverse`` then maps each step's descriptor batch → (B, n_components). + per_step_inputs = res.input_trajectory[:, 0, :, :].cpu().numpy().transpose(1, 0, 2) + result["trajectory_weights"] = np.stack( + [self._kmd.inverse(per_step_inputs[s]) for s in range(per_step_inputs.shape[0])] + ) # (steps, B, n_components) — one QP solve per (step × seed), ~10 % overhead. + # Write result.json without the trajectory arrays (they live in the npz once persisted). + json_payload = {k: v for k, v in result.items() if k not in {"trajectory_targets", "trajectory_weights"}} + (path_dir / "result.json").write_text(json.dumps(json_payload, indent=2), encoding="utf-8") + return result + + def _run_composition_path( + self, + model, + kmd_kernel: torch.Tensor, + w_seed: torch.Tensor | None, + seeds: list[str], + reg_targets: dict[str, float], + path_dir: Path, + *, + init: str, + blend: float | None, + allowed: str | list[str], + diversity: float, + label: str, + _qc_prob_fn, + _reg_preds_fn, + record_trajectory: bool = False, + ) -> dict[str, Any]: + """Composition-space optimisation via differentiable KMD (``optimize_composition``). + + ``init="seed"`` uses ``w_seed`` + ``seed_blend``; ``init="random"`` ignores ``w_seed`` and + runs ``n_starts = len(seeds)`` so the per-row budget matches the latent run. + + When ``record_trajectory`` is set, the per-step weight + reg-target trajectories come + straight from ``optimize_composition`` (composition's optim variable already lives on the + right surface, so no per-step KMD.inverse is needed — unlike the latent path). + """ + cfg = self.config + path_dir.mkdir(parents=True, exist_ok=True) + reg_names = list(reg_targets) + + if init == "seed": + if w_seed is None: + raise ValueError("Composition path with init='seed' requires w_seed.") + init_kwargs: dict[str, Any] = {"initial_weights": w_seed, "seed_blend": blend} + elif init == "random": + init_kwargs = {"initial_weights": None, "n_starts": len(seeds)} + else: + raise ValueError(f"Unknown init mode in composition path: {init!r}") + + res = model.optimize_composition( + kmd_kernel, + task_targets=reg_targets, + class_targets={"material_type": QC_CLASSES}, + class_target_weight=cfg.inverse_class_weight, + diversity_scale=diversity, + allowed_elements=allowed, + steps=cfg.inverse_steps, + lr=cfg.inverse_lr, + record_weights_trajectory=record_trajectory, + **init_kwargs, + ) + # Composition's result tensors are 2D — ``(B, x_dim)`` / ``(B, n_components)`` / + # ``(B, T)`` — no restart axis, so no ``[:, 0, :]`` slicing (unlike ``optimize_latent``). + optimized_desc = res.optimized_descriptor # (B, x_dim) — w @ K, no AE round-trip + optimized_desc_np = optimized_desc.detach().cpu().numpy() + w_final = res.optimized_weights.detach().cpu().numpy() + achieved_latent = res.optimized_target.detach().cpu().numpy() # (B, T) + after_qc = _qc_prob_fn(optimized_desc) + after_reg = _reg_preds_fn(optimized_desc, reg_names) + decoded = _format_weights(w_final) + + # Random init has no per-row correspondence with the seed list — preserve the seed list + # only when the init was seeded; otherwise label the rows as random restarts. + seed_labels = list(seeds) if init == "seed" else [f"random_start_{i}" for i in range(len(seeds))] + + result = { + "method": "composition", + "label": label, + "init": init, + "seed_blend": blend, + "allowed_elements": allowed, + "diversity_scale": diversity, + "seeds": seed_labels, + "qc_after_decode": after_qc.tolist(), + "reg_achieved_latent": {t: achieved_latent[:, j].tolist() for j, t in enumerate(reg_names)}, + "reg_after_decode": {t: after_reg[t].tolist() for t in reg_names}, + "decoded_composition": decoded, + "optimized_descriptor": optimized_desc_np.tolist(), + "optimized_weights": w_final.tolist(), + } + # Trajectory arrays — same shape convention as the latent path so ``_emit_trajectory_outputs`` + # consumes both interchangeably. ``res.trajectory`` is already (steps, B, T) and + # ``res.weights_trajectory`` is already (steps, B, n_components) — no transpose / decode. + if record_trajectory and res.weights_trajectory is not None and res.trajectory is not None: + result["trajectory_targets"] = res.trajectory.cpu().numpy() + result["trajectory_weights"] = res.weights_trajectory.cpu().numpy() + json_payload = {k: v for k, v in result.items() if k not in {"trajectory_targets", "trajectory_weights"}} + (path_dir / "result.json").write_text(json.dumps(json_payload, indent=2), encoding="utf-8") + return result + + # ------------------------------------------------------------------ plots + # ``plot_parity`` / ``plot_confusion`` / ``plot_kr_sequences`` now live in + # :mod:`continual_rehearsal_common`; they were verbatim copies of demo's and caused PR + # #18's K=0 ``NameError`` to ship in demo for several PRs. The runner-specific plots + # below (``_plot_forgetting`` uses ``self._task_colors``; the inverse-design plotters use + # the 8-path layout) stay as bound methods. + + def _plot_forgetting(self, metric_history): + n_tasks = sum(1 for pts in metric_history.values() if pts) + fig, ax = plt.subplots(figsize=(14, max(5.5, 0.32 * n_tasks + 3))) + all_steps: set[int] = set() + for task_name, points in metric_history.items(): + if not points: + continue + steps = [s for s, _ in points] + vals = [v for _, v in points] + all_steps.update(steps) + is_clf = TASK_SPECS[task_name]["kind"] == "clf" + ax.plot( + steps, + vals, + marker="s" if is_clf else "o", + ms=5, + ls="--" if is_clf else "-", + color=self._task_colors.get(task_name, "#888888"), + label=_display(task_name) + (" · accuracy" if is_clf else ""), + ) + if all_steps: + ax.set_xticks(sorted(all_steps)) + ax.set_xlabel("Continual finetuning step (a new task is added at each step)") + ax.set_ylabel("Primary metric · R² (regression) / accuracy (classification)") + ax.set_title("Per-task performance across continual finetuning") + ncol = 1 if n_tasks <= 20 else 2 + ax.legend(fontsize=8, ncol=ncol, loc="upper left", bbox_to_anchor=(1.01, 1.0), borderaxespad=0.0) + out_path = self.training_dir / "forgetting_trajectory.png" + fig.savefig(out_path) + plt.close(fig) + logger.info(f"Saved forgetting trajectory to {out_path}") + + def _plot_inverse_scenario( + self, + scenario, + before_qc: np.ndarray, + before_reg: dict[str, np.ndarray], + paths: dict[str, dict[str, Any]], + reg_targets: dict[str, float], + sc_dir: Path, + ) -> None: + """Compare the 8 inverse-design configurations side-by-side on QC + each reg target. + + Mirrors the demo's ``paper_inverse_comparison.py`` plot — same suptitle, panel titles + (via ``REG_TASK_TITLES``), x-tick labels (``INVERSE_PATH_CONFIGS[*]["label"]``), and + two-tone colour code (green ``#55A868`` for latent rows, blue ``#2563EB`` for composition + rows). We keep our boxplot style (vs the demo's bar+errorbar) to surface the full per-seed + distribution. Per the user override, the QC panel title is ``"Probability (QC)"``. + """ + reg_names = list(reg_targets) + n_panels = 1 + len(reg_names) + fig, axes = plt.subplots(1, n_panels, figsize=(5.6 * n_panels, 5.6), squeeze=False) + axes = axes[0] + + configs_in_order = [c for c in INVERSE_PATH_CONFIGS if c["key"] in paths] + path_labels = [c["label"] for c in configs_in_order] + # Two-tone colour code, matching the demo. + face_colors = ["#55A868" if c["method"] == "latent" else "#2563EB" for c in configs_in_order] + x_pos = list(range(len(configs_in_order))) + + def _boxplot(ax, vals_per_path: list[list[float]]) -> None: + """Two-tone per-row boxplot. Box face matches the row's method colour at α=0.25.""" + bp = ax.boxplot( + vals_per_path, + positions=x_pos, + widths=0.6, + patch_artist=True, + medianprops=dict(color="#222222", lw=1.4), + flierprops=dict(marker="o", mec="none", ms=3, alpha=0.55), + ) + for patch, fc in zip(bp["boxes"], face_colors): + patch.set(facecolor=fc, alpha=0.25, edgecolor=fc) + for whisker, fc in zip(bp["whiskers"], [c for c in face_colors for _ in range(2)]): + whisker.set_color(fc) + for cap, fc in zip(bp["caps"], [c for c in face_colors for _ in range(2)]): + cap.set_color(fc) + for flier, fc in zip(bp["fliers"], face_colors): + flier.set(markerfacecolor=fc) + + def _set_xticks(ax) -> None: + ax.set_xticks(x_pos) + ax.set_xticklabels(path_labels, rotation=45, ha="right", fontsize=9) + + # Panel 1: QC probability. Title is the user-specified override "Probability (QC)"; + # ylabel + target line follow the demo. + axq = axes[0] + qc_vals = [list(paths[c["key"]]["qc_after_decode"]) for c in configs_in_order] + _boxplot(axq, qc_vals) + axq.axhline(1.0, color="#C44E52", ls="--", lw=1.4, label="target = 1.0") + _set_xticks(axq) + axq.set_ylim(-0.02, 1.05) + axq.set_ylabel("P(quasicrystal)") + axq.set_title("Probability (QC)") + axq.legend(fontsize=9, loc="lower right") + + # Remaining panels: per regression target. Title pulled from REG_TASK_TITLES with units + # and an arrow indicating whether the target is below (↓) or above (↑) the model's baseline. + for ax, (t, tgt) in zip(axes[1:], reg_targets.items()): + vals = [list(paths[c["key"]]["reg_after_decode"][t]) for c in configs_in_order] + _boxplot(ax, vals) + ax.axhline(tgt, color="#C44E52", ls="--", lw=1.4, label=f"target = {tgt:+.1f}") + _set_xticks(ax) + ax.set_ylabel("Predicted value") + ax.set_title(REG_TASK_TITLES.get(t, t)) + ax.legend(fontsize=9, loc="best") + + fig.suptitle( + "Inverse-design comparison: latent (ae_align_scale sweep) vs differentiable KMD (configs)", + y=1.00, + ) + out = sc_dir / "comparison.png" + fig.savefig(out, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved inverse-design comparison plot to {out}") + + # ------------------------------------------------------------------ slide-prep (plan §6) + + def _counts(self) -> dict[str, int]: + seq = self.config.task_sequence + return { + "n_tasks": len(seq), + "n_reg": sum(1 for t in seq if TASK_SPECS[t]["kind"] == "reg"), + "n_kr": sum(1 for t in seq if TASK_SPECS[t]["kind"] == "kr"), + "n_clf": sum(1 for t in seq if TASK_SPECS[t]["kind"] == "clf"), + } + + def _dataset_summary(self) -> list[tuple[str, int, int]]: + """(dataset display, #tasks, #unique compositions used) per source, in stable order.""" + rows = [] + for src in ("qc", "phonix", "superconductor", "magnetic"): + tasks = [t for t in self.config.task_sequence if TASK_SPECS[t]["source"] == src] + if not tasks: + continue + keys = set().union(*[set(self.task_frames[t].index) for t in tasks]) + rows.append((SOURCE_DISPLAY[src], len(tasks), len(keys))) + return rows + + def _final_target_metrics(self, records: list[dict[str, Any]]) -> dict[str, dict[str, float]]: + """Final-step metrics for the headline tasks the summary must report.""" + final = records[-1]["metrics"] if records else {} + headline = ["formation_energy", "magnetic_moment", "tc", "klat", "material_type"] + return {t: final.get(t, {}) for t in headline if t in self.config.task_sequence} + + # --- element-frequency heatmap ------------------------------------------ + # The runner used to carry its own bound-method heatmap that consumed the per-path + # ``optimized_weights`` directly; we now share ``plot_element_frequency_heatmap`` from + # ``continual_rehearsal_common`` with the demo runner (same x-tick discovered-element + # styling, same colourbar label, same title format). The shared helper reads from the + # already-decoded ``decoded_composition`` strings (already in ``paths[key]``), so we + # don't need ``DEFAULT_ELEMENTS`` order or an ``eps`` threshold here. + + # --- markdown writers (plan §6) ------------------------------------------- + + def _write_inverse_summary_md(self, inverse: dict[str, Any]) -> None: + """Compact cross-scenario summary (plan §6). + + Scenarios have **heterogeneous** regression-target sets (e.g. scenario2 has 3 reg targets + vs 2 for the others), so a single flat table would let later rows spill past the header. + We keep the cross-scenario table to **QC only** (the metric every scenario shares), and + emit a per-scenario reg-target block underneath. + """ + scenarios = inverse.get("scenarios", {}) if isinstance(inverse, dict) else {} + if not scenarios: + return + lines: list[str] = [ + "# Inverse design — compact cross-scenario summary\n", + "Auto-generated. The headline QC table aggregates across all scenarios; per-scenario " + "reg-target tables follow. Full per-seed arrays in " + "`inverse_design///result.json`.\n", + ] + + # Cross-scenario QC table — the one metric every scenario shares. + lines.append("## QC probability after decode\n") + lines.append("| scenario | path | QC mean | QC std |") + lines.append("|---|---|---:|---:|") + for name, data in scenarios.items(): + paths_meta = data.get("paths", {}) + for path_name in INVERSE_PATHS: + meta = paths_meta.get(path_name, {}) + qc_m = meta.get("qc_after_mean", float("nan")) + qc_s = meta.get("qc_after_std", float("nan")) + lines.append(f"| {name} | {path_name} | {qc_m:.3f} | {qc_s:.3f} |") + lines.append("") + + # Per-scenario regression targets (columns match that scenario's reg_targets). + for name, data in scenarios.items(): + reg_targets = data.get("reg_targets", {}) + paths_meta = data.get("paths", {}) + lines.append(f"## {name} — regression targets (after decode)\n") + secondary = " · ".join(f"{_display(t)} {_arrow(v)} {v:+.1f}" for t, v in reg_targets.items()) + lines.append(f"Targets: {secondary}\n") + header = ["path", *[_display(t) for t in reg_targets]] + lines.append("| " + " | ".join(header) + " |") + lines.append("|" + "|".join(["---"] * len(header)) + "|") + for path_name in INVERSE_PATHS: + meta = paths_meta.get(path_name, {}) + row = [path_name] + for t in reg_targets: + row.append(f"{meta.get('reg_after_decode_mean', {}).get(t, float('nan')):+.2f}") + lines.append("| " + " | ".join(row) + " |") + lines.append("") + + (self.inverse_root / "SUMMARY.md").write_text("\n".join(lines), encoding="utf-8") + logger.info(f"Saved inverse-design SUMMARY.md to {self.inverse_root / 'SUMMARY.md'}") + + def _write_analysis_md(self, records: list[dict[str, Any]], inverse: dict[str, Any]) -> None: + """Long-form analysis (English, plan §0a). Reads as speaker-notes feedstock for SLIDE_PREP.""" + c = self._counts() + intro = {r["new_task"]: r["metrics"][r["new_task"]]["primary"] for r in records} + final = records[-1]["metrics"] if records else {} + lines: list[str] = [] + lines.append("# Analysis — continual rehearsal + inverse design\n") + lines.append( + "Long-form narrative analysis of this run. The structured slide outline lives in\n" + "[`SLIDE_PREP.md`](SLIDE_PREP.md); the compact cross-scenario table lives in\n" + "[`inverse_design/SUMMARY.md`](inverse_design/SUMMARY.md). Numbers below are\n" + "regenerable from the raw arrays under `training/` and `inverse_design/`.\n", + ) + + lines.append("## Run scale\n") + lines.append( + f"- **{c['n_tasks']} supervised tasks**: {c['n_reg']} regression · " + f"{c['n_kr']} kernel regression · {c['n_clf']} classification, plus the always-on autoencoder.\n" + ) + lines.append("- Datasets (tasks · unique compositions used):") + for name, ntask, nkeys in self._dataset_summary(): + lines.append(f" - {name}: {ntask} · {nkeys}") + lines.append("") + + lines.append("## Continual learning — is there forgetting?\n") + drops = [] + for task in self.config.task_sequence: + i = intro.get(task) + f_v = final.get(task, {}).get("primary") + if i is not None and f_v is not None and np.isfinite(i) and np.isfinite(f_v): + drops.append((task, i, f_v, f_v - i)) + early = drops[: max(1, len(drops) // 2)] + mean_early_delta = float(np.mean([d for *_, d in early])) if early else float("nan") + verdict = "stable (no clear forgetting)" if mean_early_delta > -0.05 else "some forgetting" + lines.append( + f"Mean (final − at-intro) primary metric over the *earlier-trained half* is " + f"**{mean_early_delta:+.3f}** → **{verdict}**. The full per-step trajectory is in " + "`training/forgetting_trajectory.png`; per-task raw `(composition, true, pred)` for " + "every step is in `training/stepNN_/_pred.parquet` + `_metrics.json` " + "— rebuild any panel from those without retraining.\n" + ) + lines.append("| task | at intro | final | Δ |") + lines.append("|---|---:|---:|---:|") + for task, i, f_v, d in drops: + lines.append(f"| {_display(task)} | {i:+.3f} | {f_v:+.3f} | {d:+.3f} |") + lines.append("") + + lines.append("## Final model — headline targets (inverse-design heads)\n") + lines.append("| task | metric | value |") + lines.append("|---|---|---:|") + for task, m in self._final_target_metrics(records).items(): + spec = TASK_SPECS[task] + metric_name = "accuracy" if spec["kind"] == "clf" else "R²" + val = m.get("primary", float("nan")) + lines.append(f"| {_display(task)} | {metric_name} | {val:+.3f} |") + lines.append("") + + lines.append("## Inverse design — 3 scenarios × 4 paths\n") + lines.append( + "Each scenario shares the same 20 seeds (17 top-QC element-system-dedup + 3 explicit " + "Au-Ga-Ln). Path semantics: **latent** uses `optimize_latent(ae_align_scale=0.5)` " + "(PR #18 sweet spot); **composition_strict** locks the seed element support " + "(`seed_blend=1.0`); **composition_alloy** is the paper-headline path " + f"(`seed_blend≈0.95`, {len(ALLOY_PALETTE)}-element ALLOY_PALETTE — allows discovery of QC-prone " + "elements outside the seeds); **composition_random** ablates the seed entirely " + "(`n_starts=N`) to surface the model's global QC attractor — useful to motivate the " + "need for chemistry-constrained palettes when the global attractor falls on " + "unsynthesisable elements.\n" + ) + scenarios = inverse.get("scenarios", {}) if isinstance(inverse, dict) else {} + for name, data in scenarios.items(): + reg_targets = data.get("reg_targets", {}) + paths_meta = data.get("paths", {}) + paths_details = data.get("paths_details", {}) + secondary = ", ".join(f"{_display(t)} {_arrow(v)} {v:+.1f}" for t, v in reg_targets.items()) + lines.append(f"### {name}\n") + lines.append(f"- Secondary targets: {secondary}") + lines.append(f"- Seed mean QC (before): **{data.get('qc_before_mean', float('nan')):.3f}**") + lines.append("") + header_cells = ["path", "QC after (mean ± std)"] + [_display(t) for t in reg_targets] + lines.append("| " + " | ".join(header_cells) + " |") + lines.append("|" + "|".join(["---"] * len(header_cells)) + "|") + for path_name in INVERSE_PATHS: + meta = paths_meta.get(path_name, {}) + qc_m = meta.get("qc_after_mean", float("nan")) + qc_s = meta.get("qc_after_std", float("nan")) + row_cells = [path_name, f"{qc_m:.3f} ± {qc_s:.3f}"] + for t in reg_targets: + row_cells.append(f"{meta.get('reg_after_decode_mean', {}).get(t, float('nan')):+.2f}") + lines.append("| " + " | ".join(row_cells) + " |") + lines.append("") + lines.append("One decoded example per path:") + for path_name in INVERSE_PATHS: + decoded = paths_details.get(path_name, {}).get("decoded_composition", []) + if decoded: + lines.append(f"- **{path_name}**: `{decoded[0]}`") + lines.append("") + lines.append( + f"Element-discovery heatmap: `inverse_design/{name}/element_frequency_heatmap.png`. " + f"Side-by-side path comparison: `inverse_design/{name}/comparison.png`. " + f"Per-path raw arrays: `inverse_design/{name}//result.json` (keys `optimized_weights` " + "(B, 94), `optimized_descriptor` (B, x_dim), `qc_after_decode`, `reg_after_decode`).\n" + ) + + (self.output_dir / "ANALYSIS.md").write_text("\n".join(lines), encoding="utf-8") + logger.info(f"Saved Markdown analysis to {self.output_dir / 'ANALYSIS.md'}") + + def _write_slide_prep_md(self, records: list[dict[str, Any]], inverse: dict[str, Any]) -> None: + """9-slide structured handoff for the external slide author. + + Mirrors the polish level of the demo's ``inverse_design_run/SLIDE_PREP.md``: + every section names a takeaway, slide content, speaker notes, and the canonical figure + (with raw-data paths so the slide author can rebuild the figure if the auto-emitted one + doesn't fit their layout). Numbers are computed from this run's data; interpretation + threads are templated stubs the slide author fills in after sanity-checking against the + plan §5 expected baselines (also reproduced inline). + + When ``sample_per_dataset`` is set (i.e. this is a smoke / partial run rather than the + formal full run), a disclaimer is rendered at the top of the document; the numbers are + still real but the magnitudes will not match the plan §5 expected baselines. + """ + cfg = self.config + counts = self._counts() + intro = {r["new_task"]: r["metrics"][r["new_task"]]["primary"] for r in records} + final = records[-1]["metrics"] if records else {} + scenarios = inverse.get("scenarios", {}) if isinstance(inverse, dict) else {} + seeds_meta = inverse.get("seeds", {}) if isinstance(inverse, dict) else {} + strategy_seeds = list(seeds_meta.get("strategy_seeds", [])) + explicit_seeds = list(seeds_meta.get("explicit_seeds", [])) + all_seeds = strategy_seeds + explicit_seeds + seed_pool: set[str] = set() + for s in all_seeds: + seed_pool |= self._element_system(s) + + is_smoke = cfg.sample_per_dataset is not None or cfg.max_epochs_per_step < 20 + run_date = _datetime.date.today().isoformat() + + def _discovered( + path_data: dict[str, Any], threshold: float = 0.95, eps: float = 1e-3 + ) -> list[tuple[str, float]]: + """Elements present in ≥ ``threshold`` fraction of a path's outputs but **0** in any seed.""" + w = np.asarray(path_data.get("optimized_weights", []), dtype=float) + if w.size == 0: + return [] + occ = (w > eps).mean(axis=0) + out: list[tuple[str, float]] = [] + for i, frac in enumerate(occ): + sym = DEFAULT_ELEMENTS[i] + if frac >= threshold and sym not in seed_pool: + out.append((sym, float(frac))) + out.sort(key=lambda kv: -kv[1]) + return out + + def _headline(task: str) -> str: + spec = TASK_SPECS.get(task, {"kind": "reg"}) + metric_name = "accuracy" if spec["kind"] == "clf" else "R²" + val = final.get(task, {}).get("primary", float("nan")) + return f"`{task}` ({metric_name} = **{val:+.3f}**)" + + lines: list[str] = [] + # ── Header ──────────────────────────────────────────────────────────────────────── + lines.append("# Slide-prep document — handoff for the slide author (claude coworker)\n") + lines.append( + "> **What this is.** A structured outline a slide author can convert directly into deck\n" + "> pages. Each section corresponds to one slide / slide group and lists: (a) the\n" + "> takeaway, (b) what to put on the slide, (c) which file in this folder is the visual,\n" + "> (d) speaker-note bullets. The slide author has **full creative freedom** for layout,\n" + "> colours, and visual style — this document only specifies *what* to communicate, not\n" + "> *how*.\n" + ) + lines.append(f"**Folder this document lives in:** `{self.output_dir.name}/`") + lines.append(f"**Run date:** {run_date}") + lines.append( + "**Data sources for every number cited:** " + "`training/experiment_records.json` (per-task metrics across the " + f"{counts['n_tasks']} training stages), " + "`training/metrics_table.csv` (flat per-task at-intro / final), " + "`training/stepNN_/_pred.parquet` (per-step raw test predictions for every " + "active head), `inverse_design/inverse_design.json` (full nested inverse-design dump), " + "and per-path `inverse_design///result.json` (raw per-seed arrays)." + ) + lines.append( + "**Companion docs:** [`README.md`](README.md) (folder index), [`ANALYSIS.md`](ANALYSIS.md) (long-form writeup), [`inverse_design/SUMMARY.md`](inverse_design/SUMMARY.md) (compact cross-scenario table).\n" + ) + + if is_smoke: + lines.append( + "> **⚠️ Run quality note — this is a SMOKE / partial run.**\n" + f"> `sample_per_dataset = {cfg.sample_per_dataset}`, " + f"`max_epochs_per_step = {cfg.max_epochs_per_step}` " + "(formal full run uses `sample_per_dataset = null` and " + "`max_epochs_per_step = 100` + EarlyStopping). The artifact tree is structurally\n" + "> complete (every section below has real numbers from THIS run), but the\n" + "> *magnitudes* will not match the formal full-run expected baselines documented in\n" + "> [`docs/continual_rehearsal_full_PLAN.md`](../../docs/continual_rehearsal_full_PLAN.md) §5.\n" + "> The expected-baseline tables below give the slide author the magnitudes to\n" + "> sanity-check against before quoting numbers from this smoke run.\n" + ) + + lines.append("---\n") + + # ── Slide 1 — Experimental goal ─────────────────────────────────────────────────── + lines.append("## Slide 1 — Experimental goal: multi-property joint optimisation\n") + lines.append( + "**Takeaway.** Real materials development asks for *several properties at once* (is " + "the material a quasi-crystal? does it have low formation energy? does it have high " + "Tc / high κ_lat / high magnetic moment?). Single-property inverse-design tools don't " + "help. We need a joint-optimisation framework around a model that learned all those " + "properties together.\n" + ) + lines.append("**Slide content.**") + lines.append('- Opening line: *"The materials-design question is rarely about a single property."*') + lines.append( + "- 2–3 illustrative property combinations to ground the audience — pulled from this run's scenarios:" + ) + for name, data in scenarios.items(): + reg_targets = data.get("reg_targets", {}) + arrowed = ", ".join(f"{_display(t)} {_arrow(v)}" for t, v in reg_targets.items()) + lines.append(f" - **{name}** — QC ↑ + {arrowed}") + lines.append('- A "wishlist → recipe" arrow showing the inverse direction: target properties → composition.\n') + lines.append("**Speaker notes.**") + lines.append( + "- DFT / experiment loops are prohibitively expensive for joint searches over many target dimensions." + ) + lines.append( + "- A surrogate model that maps composition → multiple properties + supports gradient-based inverse design lets us search jointly.\n" + ) + lines.append("**Visual asset.** Slide author draws; no pre-rendered figure.\n") + lines.append("---\n") + + # ── Slide 2 — Model structure ───────────────────────────────────────────────────── + lines.append("## Slide 2 — Model structure + inverse-design strategies\n") + lines.append( + "**Takeaway.** A shared-encoder foundation model with multiple task heads; **two** " + "inverse-design paths (latent vs composition) operate on the **same trained model** " + "so the comparison is a fair head-to-head test.\n" + ) + lines.append("**Slide content.**") + lines.append( + "- Architecture diagram: " + "`composition → KMD-1d descriptor x → encoder → latent h → tanh → {head_1, …, head_K}`." + ) + lines.append( + "- Highlight the always-on autoencoder head (decoder back to descriptor) — required by the latent path." + ) + lines.append("- Two strategy boxes:") + lines.append( + " - **Latent path** (`optimize_latent`): gradient-descend on `h`, decode with AE back to descriptor, " + "evaluate heads. Failure mode without `ae_align_scale > 0`: AE round-trip drift drops QC." + ) + lines.append( + ' - **Composition path** (`optimize_composition`, "differentiable KMD"): gradient-descend directly ' + "on the 94-d element-weight simplex `w`, descriptor = `w · K`. No AE in the loop." + ) + lines.append("- Two user knobs, both on `[0, 1]` (bigger = more of the named thing):") + lines.append( + " - `ae_align_scale` — latent path; 0 = no AE-alignment penalty (failure-mode " + "baseline), 1 = strongest alignment to AE fixed set. Compared at 0 / 0.25 / 1 in this run." + ) + lines.append( + " - `diversity_scale` — composition path; 0 = peaky few-element recipes, 1 = " + "multi-element recipes (default). Compared at 1.0 and 0.0 (low-diversity ablation) in this run." + ) + lines.append( + "- Optional composition add-ons: `allowed_elements` (whitelist palette), `seed_blend` (5 % uniform mix lets non-seed elements have reachable logits).\n" + ) + lines.append("**Speaker notes.**") + lines.append("- KMD-1d is differentiable in PR #17 → composition-space optimisation possible at all.") + lines.append( + '- Knob naming follows "bigger value = more of the named thing"; user doesn\'t need to read the docstring.' + ) + lines.append( + "- Same model handles both paths, so latent vs composition is a fair experiment, not an architecture comparison.\n" + ) + lines.append("**Visual asset.** Slide author draws; no pre-rendered figure.\n") + lines.append("---\n") + + # ── Slide 3 — Datasets + task types ────────────────────────────────────────────── + lines.append("## Slide 3 — Datasets and task types\n") + lines.append( + f"**Takeaway.** The framework is trained on a heterogeneous task suite " + f"({counts['n_tasks']} tasks across 4 data sources × 3 task types) joined by composition formula.\n" + ) + lines.append("**Slide content (suggested 3-column layout).**\n") + lines.append("| Task type | Count | Tasks |") + lines.append("|---|---:|---|") + for kind, label in (("reg", "Regression"), ("kr", "Kernel regression"), ("clf", "Classification")): + tasks = [t for t in cfg.task_sequence if TASK_SPECS[t]["kind"] == kind] + if tasks: + lines.append(f"| **{label}** | {len(tasks)} | {', '.join(f'`{t}`' for t in tasks)} |") + lines.append("") + lines.append("Datasets supplying these tasks:\n") + for name, ntask, nkeys in self._dataset_summary(): + lines.append(f"- **{name}** — {ntask} tasks · {nkeys} unique compositions used") + lines.append("") + lines.append("**Speaker notes.**") + lines.append( + "- Cross-source joining: every dataset has a `composition` column; the canonical formula is the join key." + ) + lines.append( + "- Kernel regression predicts an entire `(t, value)` series per composition — one head learns the shape vs `t` (DOS energy or temperature)." + ) + lines.append( + '- Classification uses inverse-frequency `class_weights` so the rare QC / AC classes stay alive against ~48k "others" rows in the qc dataset.\n' + ) + lines.append( + "**Visual asset.** Slide author renders the 3-column callout. Optional teaser: [`training/forgetting_trajectory.png`](training/forgetting_trajectory.png).\n" + ) + lines.append( + "**Raw-data pointer.** [`training/metrics_table.csv`](training/metrics_table.csv) is the flat task / type / dataset / at-intro / final / metric table.\n" + ) + lines.append("---\n") + + # ── Slide 4 — Continual training ────────────────────────────────────────────────── + lines.append("## Slide 4 — Continual training without catastrophic forgetting\n") + lines.append( + "**Takeaway.** Tasks are introduced one at a time across " + f"**{counts['n_tasks']} stages**; tiered rehearsal (5 %/10 %) keeps the older heads " + "alive. The forgetting trajectory shows every head holds its R² / accuracy as new " + "tasks are added.\n" + ) + lines.append( + "**Primary figure:** [`training/forgetting_trajectory.png`](training/forgetting_trajectory.png) " + "— per-step metric for every active head across all stages." + ) + # Build the tail-task chain from whatever ``fixed_tail`` actually contains, instead of + # hard-indexing ``[0..4]`` — a smaller-scale config might legitimately have fewer tail + # tasks, and a future plan revision could change the count. + tail_chain = " → ".join(cfg.fixed_tail) if cfg.fixed_tail else "(no fixed tail)" + lines.append( + f"Annotate the fixed-tail tasks (the last {len(cfg.fixed_tail)} steps, " + f"`{tail_chain}`) as the focus for the inverse-design section that follows.\n" + ) + lines.append("**Final-step metrics for the inverse-design heads** (the heads inverse design actually uses):\n") + lines.append("| Head | Type | Final-step metric |") + lines.append("|---|---|---:|") + for t in ["formation_energy", "magnetic_moment", "tc", "klat", "material_type"]: + if t in final: + spec = TASK_SPECS[t] + metric_name = "accuracy" if spec["kind"] == "clf" else "R²" + val = final.get(t, {}).get("primary", float("nan")) + lines.append(f"| `{t}` | {KIND_LABEL[spec['kind']]} | **{val:+.3f}** ({metric_name}) |") + lines.append("") + lines.append("**Speaker notes.**") + lines.append( + f"- Rehearsal: `replay_ratio = {cfg.replay_ratio}` for ordinary old tasks, " + f"`replay_ratio_high = {cfg.replay_ratio_high}` for the inverse-design tail (every step). " + "No layer is frozen — encoder + every active head train jointly." + ) + lines.append( + "- Task ordering minimises rehearsal cost: 12 regression first (any order), then 7 " + "kernel-regression tasks **ascending by row count** (cheapest first), then the 5 fixed-" + "tail tasks — see plan §2 for the cost argument." + ) + lines.append( + "- **Per-step parquets + per-step checkpoints** are available under " + "`training/stepNN_/` so any per-task / per-step drill-down can be made later " + "**without retraining**." + ) + lines.append("- Raw data:") + lines.append( + " - [`training/forgetting_trajectory.png`](training/forgetting_trajectory.png) — the headline curve." + ) + lines.append( + " - `training/stepNN_/_pred.parquet` — `(composition, true, pred)` for every active head at every step." + ) + lines.append(" - `training/stepNN_/_metrics.json` — per-task metric dict at that step.") + lines.append( + " - `training/stepNN_/checkpoint.pt` — model state at that step (payload `{model, task_sequence, step, new_task, active_tasks}`)." + ) + lines.append( + " - [`training/experiment_records.json`](training/experiment_records.json) — every step × every active head, both at-intro and running metrics." + ) + lines.append(" - [`training/metrics_table.csv`](training/metrics_table.csv) — flat aggregated table.\n") + lines.append("---\n") + + # ── Slide 5 — Inverse design scenario setup ────────────────────────────────────── + lines.append("## Slide 5 — Inverse design: scenario setup\n") + lines.append( + "**Takeaway.** Three scenarios share the same model, the same 20 seeds, and the " + "same primary objective (**P(QC) ↑**). Secondary objectives differ — picking which " + "scenario to feature in the talk is the slide author's narrative choice.\n" + ) + lines.append("**Slide content.** A small table or three pill boxes:\n") + lines.append("| Scenario | Primary | Secondary objectives |") + lines.append("|---|---|---|") + for name, data in scenarios.items(): + reg_targets = data.get("reg_targets", {}) + secondary = ", ".join(f"{_display(t)} {_arrow(v)} {v:+.1f}" for t, v in reg_targets.items()) + lines.append(f"| `{name}` | P(QC) ↑ (target 1.0) | {secondary} |") + lines.append("") + lines.append("**Methodology (constant across scenarios).**") + lines.append("- 20 seeds shared across scenarios (slide 6 details the split).") + lines.append( + f"- Optimisation budget: **{cfg.inverse_steps} Adam steps**, **`lr = {cfg.inverse_lr}`**, " + f"**`class_target_weight = {cfg.inverse_class_weight}`** (so QC dominates the loss)." + ) + lines.append( + "- All metrics evaluated **after** decoding the optimised descriptor back to a real composition (round-trip)." + ) + lines.append("- 8 configurations per scenario (3 latent α + 5 composition) — see slide 6.\n") + lines.append("**Speaker notes.**") + lines.append( + '- All three scenarios are first-class — the runner does not pick a "headline" scenario. Slide author chooses which to feature based on the talk\'s narrative.' + ) + lines.append("- Plan §5 lists the rationale for each scenario.\n") + lines.append('**Visual asset.** Slide author can draw a small "target dial" visual. No pre-rendered figure.\n') + lines.append( + "**Raw-data pointer.** [`inverse_design/seeds.json`](inverse_design/seeds.json) (seeds), `inverse_design//targets.json` (objective definitions per scenario).\n" + ) + lines.append("---\n") + + # ── Slide 6 — Seeds + palette + config table ───────────────────────────────────── + lines.append("## Slide 6 — Initial seeds, the element palette, and the 8 configurations\n") + lines.append( + f"**Takeaway.** Three ingredients shape the search: (a) **{len(all_seeds)} seeds** " + f"for the optimiser to start from, (b) the **{len(ALLOY_PALETTE)}-element `ALLOY_PALETTE`** the " + "constrained composition paths are allowed to use, (c) **8 configurations** isolating " + "ae_align_scale / seed_blend / palette / diversity / random-init effects.\n" + ) + lines.append("### Seeds\n") + lines.append( + f"**N = {len(all_seeds)}** = {len(strategy_seeds)} top-QC dedup + {len(explicit_seeds)} explicit-append. " + "Element-system dedup keeps the best representative per element set so the seed list spans " + "**different alloy families** rather than ratio variants of a few.\n" + ) + lines.append( + f"- **{len(strategy_seeds)} top-QC dedup seeds** (from the training-set material_type frame, picked by predicted QC probability):" + ) + for s in strategy_seeds[:8]: + lines.append(f" - `{s}`") + if len(strategy_seeds) > 8: + lines.append(f" - … ({len(strategy_seeds) - 8} more in `inverse_design/seeds.json`)") + lines.append( + f"- **{len(explicit_seeds)} explicit-append seeds** (forced regardless of QC score — known Au–Ga–RE i-QC formers):" + ) + for s in explicit_seeds: + lines.append(f" - `{s}`") + lines.append("") + + lines.append( + f"### `ALLOY_PALETTE` ({len(ALLOY_PALETTE)} elements, slide author renders periodic-table highlight)\n" + ) + lines.append( + "Range design: covers classic i-QC / d-QC formers + easy 4th/5th-period TMs + accessible lanthanides + Au (so Au–Ga–Ln seeds are reachable). Pm / Tc and Pu-class radioactives are excluded; Tm / Lu excluded as rare and expensive.\n" + ) + lines.append("- **Light alkaline earth:** Mg, Ca") + lines.append("- **Group 13:** B, Al, Ga, In, Tl") + lines.append("- **Group 14:** Si, Ge") + lines.append("- **4th-period TM (10):** Sc Ti V Cr Mn Fe Co Ni Cu Zn") + lines.append("- **5th-period TM (9, Tc excluded as radioactive):** Y Zr Nb Mo Ru Rh Pd Ag Cd") + lines.append("- **6th-period noble (needed for Au–Ga–RE seeds):** Au") + lines.append("- **Accessible lanthanides (12, Pm/Tm/Lu excluded):** La Ce Pr Nd Sm Eu Gd Tb Dy Ho Er Yb\n") + + lines.append("### The 8 configurations — what each isolates\n") + lines.append("3 latent points (along `ae_align_scale`) + 5 composition configs:\n") + lines.append("| Config (x-axis label in `comparison.png`) | Knobs | What it tests |") + lines.append("|---|---|---|") + lines.append( + "| `latent α=0` | `ae_align_scale = 0` | AE-alignment off → failure mode in PR #18's paper-baseline run (QC collapses). With `dos_density` in the training mix the latent geometry may be more robust — check this run's number. |" + ) + lines.append("| `latent α=0.25` | `ae_align_scale = 0.25` | Low alignment — intermediate point. |") + lines.append( + "| `latent α=1` | `ae_align_scale = 1.0` | Max alignment — strongest cycle-consistency constraint. |" + ) + lines.append( + "| `comp (seed)` | `seed_blend = 1.0`, all elements allowed | Strict-seed baseline. Optimiser can only rebalance the seed's existing elements — no new element can enter the support set. |" + ) + lines.append( + "| `comp (seed, 5% all)` | `seed_blend = 0.95`, all allowed | Adds 5 % uniform mass over all 94 elements so non-seed elements have reachable logits. Optimiser *can* introduce new elements but otherwise unconstrained. |" + ) + lines.append( + f"| `comp (seed, 5% all, element list)` | (above) + `allowed_elements = ALLOY_PALETTE` | Restricts the support set to the {len(ALLOY_PALETTE)} feasible alloy elements. **Practical materials-design mode.** |" + ) + lines.append( + "| `comp (seed, 5% all, element list, low diversity)` | (above) + `diversity_scale = 0` | Adds max entropy penalty → forces peaky few-element recipes. Tests whether peaky recipes still satisfy the targets. |" + ) + lines.append( + '| `comp (random)` | `initial_weights = None`, all allowed | No seed, no palette. Pure "let the optimiser explore" — the no-bias control. |' + ) + lines.append("") + lines.append("**Speaker notes.**") + lines.append("- Each row of `inverse_design//comparison.png` x-axis maps to one of these configs.") + lines.append('- Labels read as "config A, then add knob B, then add knob C" — each comma = a knob change.') + lines.append( + '- "low diversity" = `diversity_scale = 0`, the most penalised end of the diversity knob → fewest elements per output.\n' + ) + lines.append( + f"**Visual asset.** Slide author renders the periodic-table highlight from the {len(ALLOY_PALETTE)}-element list above. No pre-rendered palette figure.\n" + ) + lines.append( + "**Raw-data pointer.** [`inverse_design/seeds.json`](inverse_design/seeds.json) for the seed list; palette literal in [`samples/continual_rehearsal_full_config.toml`](../../samples/continual_rehearsal_full_config.toml).\n" + ) + lines.append("---\n") + + # ── Slide 7 — Results & discussion (the central section) ───────────────────────── + lines.append("## Slide 7 — Results & discussion\n") + lines.append( + "**Takeaway** (templated stub — fill in based on the per-scenario tables below + " + "discovered-elements list). Typical claims the slide author chooses among:\n" + ) + lines.append( + "- **Headline claim.** `comp (seed, 5% all, element list)` is the practical winner on the scenario you pick to feature — tight, physically credible alloy recipes; element discovery (specific elements present in 100 % of outputs but 0 % of seeds)." + ) + lines.append( + "- **Constraints-matter claim.** `comp (random)` lands the optimiser on the model's unconstrained global QC attractor — often physically implausible elements; demonstrates that the palette + seed are doing real work, not just regularising." + ) + lines.append( + "- **Latent-knob claim.** The `ae_align_scale` sweep on `latent α=0 / 0.25 / 1` traces the AE-alignment effect on the three target axes." + ) + lines.append("") + lines.append( + "Pick the claim(s) the actual numbers support; the per-scenario tables below carry every figure you need.\n" + ) + + lines.append("**Primary figures (per scenario).**") + for name in scenarios: + lines.append( + f"- [`inverse_design/{name}/comparison.png`](inverse_design/{name}/comparison.png) — 8-config boxplot across P(QC) + each reg target." + ) + lines.append("") + lines.append("**Supporting figures (per scenario).**") + for name in scenarios: + lines.append( + f'- [`inverse_design/{name}/element_frequency_heatmap.png`](inverse_design/{name}/element_frequency_heatmap.png) — path × top-25 elements; **bold orange** x-tick labels = elements NOT in any seed → "discovered".' + ) + lines.append("") + + # Per-scenario per-config table + discovered elements + open questions + for name, data in scenarios.items(): + reg_targets = data.get("reg_targets", {}) + paths_meta = data.get("paths", {}) + paths_details = data.get("paths_details", {}) + + lines.append(f"### Scenario: `{name}`\n") + secondary = ", ".join(f"{_display(t)} {_arrow(v)} {v:+.1f}" for t, v in reg_targets.items()) + lines.append( + f"Targets: **P(QC) ↑ (target 1.0)**, {secondary}. " + f"Seed mean QC (before): **{data.get('qc_before_mean', float('nan')):.3f}**.\n" + ) + + # Per-config table (one row per config, columns: QC mean ± std, each reg target mean) + header = ["config", "QC after (mean ± std)"] + [REG_TASK_TITLES.get(t, t) for t in reg_targets] + lines.append("| " + " | ".join(header) + " |") + lines.append("|" + "|".join(["---"] + ["---:"] * (len(header) - 1)) + "|") + for path_cfg in INVERSE_PATH_CONFIGS: + key = path_cfg["key"] + label = path_cfg["label"] + meta = paths_meta.get(key, {}) + qc_m = meta.get("qc_after_mean", float("nan")) + qc_s = meta.get("qc_after_std", float("nan")) + row = [f"`{label}`", f"{qc_m:.3f} ± {qc_s:.3f}"] + for t in reg_targets: + row.append(f"{meta.get('reg_after_decode_mean', {}).get(t, float('nan')):+.2f}") + lines.append("| " + " | ".join(row) + " |") + lines.append("") + + # Discovered elements per config (≥ 95 % occupancy, 0 in seeds) + lines.append( + "**Element discovery** (occurrence ≥ 95 % in this config's 20 outputs, **and** 0 occurrence in any seed):" + ) + any_discovered = False + for path_cfg in INVERSE_PATH_CONFIGS: + key = path_cfg["key"] + disc = _discovered(paths_details.get(key, {})) + if disc: + any_discovered = True + payload = ", ".join(f"**{sym}** ({int(round(frac * 100))}%)" for sym, frac in disc) + lines.append(f"- `{path_cfg['label']}` → {payload}") + if not any_discovered: + lines.append( + "- *(none in this run — no element passes the ≥95 % occurrence + 0-in-seeds bar. " + "Either the optimiser is just rebalancing seed elements, or the run is too early " + "to surface discoveries. Smoke runs typically have none; the formal full run " + "is expected to surface discovered elements in `comp (seed, 5% all, element list)`.)*" + ) + lines.append("") + + # Decoded example per config + lines.append("**One decoded example per config** (highest-QC seed of that config):") + for path_cfg in INVERSE_PATH_CONFIGS: + key = path_cfg["key"] + decoded = paths_details.get(key, {}).get("decoded_composition", []) + if decoded: + lines.append(f"- `{path_cfg['label']}` → `{decoded[0]}`") + lines.append("") + + # Three discussion-thread stubs (templated for the slide author) + lines.append("### Discussion threads (templated stubs — verify against numbers above)\n") + lines.append( + "1. **Element discovery is the headline.** *Fill in:* in `comp (seed, 5% all, element list)`, " + "which element(s) appear in ≥95 % of outputs and 0 % of seeds? (See the discovery list " + 'per scenario above.) If non-empty, this is the central claim — "the model found ' + "something we didn't tell it about\".\n" + ) + lines.append( + "2. **Constraints matter.** *Fill in:* `comp (random)` QC vs `comp (seed, 5% all, element list)` QC. " + "If random-init lands far from the constrained QC, the seed + palette are doing real " + "work (not regularising). If random-init still finds high QC but with implausible " + "elements (Pu / F / Mn-rich), the *physicality* of the recipe is the constraint payoff, " + "not raw QC.\n" + ) + lines.append( + "3. **Latent path α-knob role.** *Fill in:* compare `latent α=0` vs `latent α=1` QC + reg " + "targets. In PR #18's pre-`dos_density` baseline α=0 was a catastrophe (QC ~ 0.39). " + "With `dos_density` in this run's training mix, check whether α=0 is still a " + "catastrophe (claim the failure-mode story), or whether the latent geometry is now " + 'robust to α=0 (claim the α-knob has shifted from "rescue QC" to "trade QC bias ' + 'against secondary-target reach").\n' + ) + + lines.append("### Plan §5 expected baselines (for sanity-check; slide author must verify)\n") + lines.append( + "Plan §5 reports the following PR #18 + 41-elem-smoke baselines for a single " + "scenario (QC↑ / FE↓ / klat↑, 16 seeds). The formal full run should land in similar " + "magnitudes; smoke / partial runs will not.\n" + ) + lines.append("| Config | QC after | FE after | klat after | pairwise L1 | mean #elems |") + lines.append("|---|---:|---:|---:|---:|---:|") + lines.append("| latent α=0 (failure) | 0.386 ± 0.315 | +2.46 ± 0.59 | −0.44 ± 0.27 | 1.07 | 5.2 |") + lines.append("| latent α=0.5 (sweet) | **0.960 ± 0.027** | +0.92 ± 1.16 | +1.07 ± 0.31 | 0.82 | 3.4 |") + lines.append("| latent α=1.0 (max) | 0.951 ± 0.027 | +0.40 ± 1.04 | +1.20 ± 0.35 | 1.06 | 3.6 |") + lines.append("| C-strict | 0.887 ± 0.053 | +1.27 ± 0.24 | +0.76 ± 0.67 | 1.42 | 2.6 |") + lines.append("| **C-alloy (12 elem)** | 0.870 ± 0.012 | +0.84 ± 0.03 | **+1.81 ± 0.07** | 0.17 | 5.6 |") + lines.append("| **C-alloy (41 elem)** | 0.842 ± 0.018 | +0.68 ± 0.07 | **+1.84 ± 0.06** | 1.02 | 6.0 |") + lines.append("| C-rand | 0.793 ± 0.005 | −0.78 ± 0.03 | +1.77 ± 0.02 | 0.10 | 6.0 |") + lines.append("") + + lines.append("### Open questions to flag\n") + lines.append( + "- **`comp (seed)` variance.** If `comp (seed)` σ is large (≥0.2 in PR #18 paper run), " + "per-seed audit: which seeds fail? Drill down via `inverse_design//comp_seed/result.json` " + "(`qc_after_decode` per seed; `seeds` list in same file)." + ) + lines.append( + "- **Au–Ga–Ln seeds.** The 3 explicit Au–Ga–Ln seeds are known QC candidates. Their " + "*per-seed* QC in `comp (seed)` should be high — if not, that's itself a notable finding." + ) + lines.append( + "- **Scenario coverage.** This run has 3 scenarios; the deck may not need all three. " + "Pick 1–2 the audience cares about and footnote the others.\n" + ) + lines.append("---\n") + + # ── Slide 8 — Summary ──────────────────────────────────────────────────────────── + lines.append("## Slide 8 — Summary\n") + lines.append("**Takeaway** (three bullets for the slide; numbers fill in from above).\n") + lines.append( + f"1. A shared-encoder foundation model trained continually across " + f"**{counts['n_tasks']} heterogeneous tasks** with tiered rehearsal — no catastrophic " + "forgetting on the inverse-design heads (slide 4 numbers)." + ) + lines.append( + "2. Two inverse-design paths on the same model, both exposed as user-friendly `[0, 1]` " + "knobs (`ae_align_scale`, `diversity_scale`). Eight configurations per scenario " + "isolate every effect (slide 6 table)." + ) + lines.append( + "3. On the scenario(s) you feature: the constrained composition path delivers " + "physically credible recipes; element-discovery signal surfaces " + "(see scenario-specific table in slide 7)." + ) + lines.append("") + lines.append("**Failure modes (also first-class — claim them honestly).**") + lines.append("- AE-roundtrip drift without `ae_align_scale > 0` (latent path).") + lines.append("- Seed-init support-set lock without `seed_blend < 1` (composition path with strict seed).") + lines.append("- Non-physical attractors without `allowed_elements` (composition random init).\n") + lines.append( + "**Slide content.** Three takeaway bullets + a thumbnail of one of the " + "`inverse_design//comparison.png` files (slide author picks).\n" + ) + lines.append("---\n") + + # ── Slide 9 — Future work ──────────────────────────────────────────────────────── + lines.append("## Slide 9 — Future work\n") + lines.append( + "**Takeaway.** The current framework is the foundation; the next step is to wrap it " + "in an agent system, then later wire into the broader AI4S agent ecosystem.\n" + ) + lines.append("### Beat 6 — agent-based inverse-design workbench\n") + lines.append('- Natural-language goals from the user ("I want a low-density QC formed from common metals").') + lines.append( + '- An AI agent decomposes the goal + applies domain knowledge ("QC + common metals → use `allowed_elements = ALLOY_PALETTE − lanthanides`").' + ) + lines.append( + "- Agent automatically sets optimiser knobs (`ae_align_scale`, `diversity_scale`, seed strategy, palette, target dict)." + ) + lines.append("- Runs `optimize_*`, decodes outputs, generates a visualisation + PDF report.\n") + lines.append("### Beat 7 — wider AI4S agent ecosystem\n") + lines.append( + "- Foundation model becomes the fast predictor + candidate generator in the centre of a larger stack." + ) + lines.append( + "- Other agents wrap DFT / MD simulators (slow but accurate validation), automated synthesis platforms (closed-loop experimental feedback)." + ) + lines.append( + "- Pipeline: user request → foundation-model candidates → DFT validation → robotic synthesis → results loop back to retrain the foundation model.\n" + ) + lines.append( + "**Slide content.** One bullet per beat, plus a concentric-circles sketch (foundation model at the centre, agent wrappers around it, the user / world outside).\n" + ) + lines.append("---\n") + + # ── Quick reference ────────────────────────────────────────────────────────────── + lines.append("## Quick reference — files in this run folder\n") + lines.append("| File | Used by which slide |") + lines.append("|---|---|") + lines.append( + "| [`training/forgetting_trajectory.png`](training/forgetting_trajectory.png) | Slide 4 (primary) |" + ) + lines.append("| `training/stepNN_/*.png` | Slide 4 appendix (drill-down per task) |") + lines.append("| `training/stepNN_/*_pred.parquet` | Replot any per-step figure without retraining |") + lines.append("| `training/stepNN_/*_metrics.json` | Per-task metric dict at that step |") + lines.append("| `training/stepNN_/checkpoint.pt` | Restore the model at any intermediate stage |") + lines.append( + "| [`training/experiment_records.json`](training/experiment_records.json) | Full records (step × head, at-intro + running) |" + ) + lines.append( + "| [`training/metrics_table.csv`](training/metrics_table.csv) | Flat task / type / dataset / at-intro / final table |" + ) + lines.append( + "| [`training/final_model.pt`](training/final_model.pt) | Final model state_dict + task_sequence |" + ) + lines.append( + "| `inverse_design//comparison.png` | Slide 7 (primary, per scenario), Slide 8 (thumbnail) |" + ) + lines.append( + "| `inverse_design//element_frequency_heatmap.png` | Slide 7 (supporting, per scenario) |" + ) + lines.append( + "| `inverse_design///result.json` | Per-config raw arrays — `optimized_weights` (20, 94), `optimized_descriptor` (20, x_dim), per-seed predictions |" + ) + lines.append( + "| `inverse_design//summary.json` | Per-scenario aggregated stats (per-config means + stds) |" + ) + lines.append("| `inverse_design//targets.json` | Primary + secondary objective definitions |") + lines.append( + "| [`inverse_design/seeds.json`](inverse_design/seeds.json) | Slide 6 (seed names + strategy/explicit split) |" + ) + lines.append( + "| [`inverse_design/SUMMARY.md`](inverse_design/SUMMARY.md) | Cross-scenario compact summary table |" + ) + lines.append( + "| [`inverse_design/inverse_design.json`](inverse_design/inverse_design.json) | Full nested inverse-design dump (every scenario × every path) |" + ) + lines.append("| [`ANALYSIS.md`](ANALYSIS.md) | Speaker-note source (long-form analysis) |") + lines.append("| [`README.md`](README.md) | Run-folder reference / directory map |") + lines.append("") + + # ── Slide-author freedom ────────────────────────────────────────────────────────── + lines.append("## What the slide author has freedom over (and what they don't)\n") + lines.append("**Free:**") + lines.append("- Visual style (theme, colours, fonts, slide template).") + lines.append("- Layout and slide breaks.") + lines.append('- Diagrams (slides 1, 2, 3, 5, 6, 9 explicitly say "slide author draws this").') + lines.append("- Order: this document is in narrative order, but the slide author may reshuffle.") + lines.append("- Which scenario(s) to feature: the runner does not pick a headline scenario.") + lines.append( + "- Which discussion thread(s) in slide 7 to make the central claim — pick the one(s) the numbers actually support.\n" + ) + lines.append("**Not free (these are the claims):**") + lines.append( + "- All numbers in the per-scenario tables of slide 7 — quoted from `inverse_design///result.json`." + ) + lines.append( + "- The element-discovery list — computed as occurrence ≥ 95 % in a config's outputs AND 0 in any seed (the bar must be cleared to claim discovery)." + ) + lines.append("- The two-knob naming (`ae_align_scale`, `diversity_scale`) — these are the public API.") + lines.append("- The 8 configuration names (x-axis labels of every `comparison.png`).") + lines.append("- The 3 scenario names + target dicts (slide 5 table is canonical).\n") + lines.append("---\n") + + # ── Raw-data cheat sheet ────────────────────────────────────────────────────────── + lines.append("## Where the raw data lives — full cheat-sheet\n") + lines.append( + "Every figure above is fully reproducible from the raw arrays — **no need to " + "retrain or rerun the optimisation** to change a plot's style / axis / colour scheme.\n" + ) + lines.append( + "- `training/stepNN_/_pred.parquet` — `(composition, true, pred)` (KR has `t` too). Plot any per-task parity / confusion / KR-sequence at any stage." + ) + lines.append("- `training/stepNN_/_metrics.json` — per-task metric dict at that step.") + lines.append( + "- `training/stepNN_/checkpoint.pt` — model state at that step (payload: `{model, task_sequence, step, new_task, active_tasks}`)." + ) + lines.append( + "- `training/experiment_records.json` — every step × every active head metric (at-intro and running)." + ) + lines.append("- `training/metrics_table.csv` — flat task/type/dataset/at-intro/final/metric.") + lines.append( + "- `training/final_model.pt` — final model state_dict + task_sequence (consumed by `--inverse-only` / `paper_inverse_comparison.py` / `finetune_inverse_heads.py`)." + ) + lines.append("- `training/forgetting_trajectory.png` — per-step × per-task primary-metric curves.") + lines.append("- `inverse_design/seeds.json` — seeds in two segments (`strategy_seeds`, `explicit_seeds`).") + lines.append("- `inverse_design//targets.json` — primary + secondary target definitions.") + lines.append( + "- `inverse_design///result.json` — per-config full record: `optimized_weights` `(B, 94)`, `optimized_descriptor` `(B, x_dim)`, `qc_after_decode`, `reg_before` / `reg_achieved_latent` / `reg_after_decode`, `decoded_composition`." + ) + lines.append("- `inverse_design//summary.json` — per-scenario aggregated stats.") + lines.append("- `inverse_design//comparison.png` — 8-config boxplot comparison.") + lines.append( + "- `inverse_design//element_frequency_heatmap.png` — config × element occurrence heatmap; discovered-element x-tick labels are bold + orange." + ) + lines.append("- `inverse_design/SUMMARY.md` — compact cross-scenario table.\n") + lines.append( + "Element order in `optimized_weights`: " + "`foundation_model.utils.kmd_plus.DEFAULT_ELEMENTS` (94 symbols). " + "Composition-formula round-trip: `KMD.inverse(descriptor)` (or directly use `optimized_weights` which already lives on the simplex).\n" + ) + + (self.output_dir / "SLIDE_PREP.md").write_text("\n".join(lines), encoding="utf-8") + logger.info(f"Saved SLIDE_PREP.md to {self.output_dir / 'SLIDE_PREP.md'}") + + def _write_readme(self, records: list[dict[str, Any]], inverse: dict[str, Any]) -> None: + """Top-level run index — what's in this directory and where to start reading.""" + c = self._counts() + scenarios = inverse.get("scenarios", {}) if isinstance(inverse, dict) else {} + lines = [ + "# Continual rehearsal + inverse-design — run directory", + "", + f"{c['n_tasks']} supervised tasks ({c['n_reg']} reg · {c['n_kr']} kr · " + f"{c['n_clf']} clf) + autoencoder · 3 inverse-design scenarios × 4 paths.", + "", + "## Start here", + "- [`SLIDE_PREP.md`](SLIDE_PREP.md) — 9-section slide outline for the external slide author.", + "- [`ANALYSIS.md`](ANALYSIS.md) — long-form narrative analysis (speaker-note material).", + "- [`inverse_design/SUMMARY.md`](inverse_design/SUMMARY.md) — compact cross-scenario table.", + "- `inverse_design//comparison.png` + `element_frequency_heatmap.png` — per-scenario figures (three scenarios, all first-class — no demo-style single-scenario headline).", + "", + "## Directory map", + "```", + "training/", + " stepNN_/ # one dir per training step", + " _pred.parquet # (composition, true, pred) for every active head", + " _metrics.json # per-task metric dict (R²/acc/MAE/…)", + " _parity.png | _confusion.png | _sequences.png # newest-head plot only", + " checkpoint.pt # model state at that step", + " forgetting_trajectory.png # per-step × per-task primary metric", + " experiment_records.json # full records (every step × every head)", + " metrics_table.csv # flat per-task at-intro / final table", + " final_model.pt # final model state_dict + task_sequence", + " final_model_taskconfigs.json # task-config metadata for rebuilding the model", + "inverse_design/", + " seeds.json # 20 seeds (17 top-QC dedup + 3 Au-Ga-Ln)", + " inverse_design.json # full nested result dump", + " SUMMARY.md # cross-scenario compact table", + " /", + " targets.json # primary + secondary objectives", + " summary.json # per-path mean / std headline stats", + " comparison.png # 8-path boxplot (QC + each reg target)", + " element_frequency_heatmap.png # path × top-25 elements (discovered = bold orange)", + " /result.json # raw per-seed arrays, optimized_weights, …", + "SLIDE_PREP.md # slide outline + raw-data pointers", + "ANALYSIS.md # long-form analysis", + "README.md # this file", + "```", + "", + "## Scenarios", + ] + for name, data in scenarios.items(): + reg_targets = data.get("reg_targets", {}) + secondary = ", ".join(f"{_display(t)} {_arrow(v)} {v:+.1f}" for t, v in reg_targets.items()) + lines.append(f"- **{name}** — primary: QC ↑; secondary: {secondary}") + lines.append("") + (self.output_dir / "README.md").write_text("\n".join(lines), encoding="utf-8") + logger.info(f"Saved README.md to {self.output_dir / 'README.md'}") + + +# --- CLI --------------------------------------------------------------------- + + +def _load_toml(path: Path) -> dict[str, Any]: + try: + import tomllib # type: ignore[attr-defined] + except ModuleNotFoundError: # pragma: no cover + import tomli as tomllib # type: ignore + return tomllib.loads(Path(path).read_text(encoding="utf-8")) + + +def _parse_args(argv: list[str] | None = None) -> tuple[ContinualRehearsalFullConfig, argparse.Namespace]: + parser = argparse.ArgumentParser(description="Continual rehearsal + inverse-design — full run.") + parser.add_argument("--config-file", type=Path, default=None) + parser.add_argument("--output-dir", type=Path, default=None) + parser.add_argument("--sample-per-dataset", type=int, default=None) + parser.add_argument("--max-epochs-per-step", type=int, default=None) + parser.add_argument("--accelerator", type=str, default=None) + parser.add_argument( + "--inverse-only", + type=Path, + default=None, + metavar="CKPT", + help="Skip training; load a final_model.pt checkpoint and rerun only the inverse-design stage.", + ) + # Trajectory plotting flags — mirror paper_inverse_comparison's CLI so the user can switch + # animation format / opt out of per-step recording without code changes. + parser.add_argument( + "--record-trajectory", + action=argparse.BooleanOptionalAction, + default=True, + help="Record per-step optimisation trajectories and emit trajectory plots / animations " + "per scenario × path. ``--no-record-trajectory`` skips both (saves ~10 %% on the latent " + "path and the animation rendering cost).", + ) + parser.add_argument( + "--per-seed-trajectories", + action="store_true", + help="Additionally emit one plot + animation per (path × seed) under " + "``trajectories_per_seed/`` (heavy: 20× more figures). Off by default.", + ) + parser.add_argument( + "--animation-formats", + nargs="+", + choices=["gif", "html", "svg", "none"], + default=["gif"], + help="Trajectory animation formats. ``none`` disables animations (the static plot is " + "still written). Default: gif.", + ) + args = parser.parse_args(argv) + + data = _load_toml(args.config_file) if args.config_file else {} + for key in ("output_dir", "sample_per_dataset", "max_epochs_per_step", "accelerator"): + val = getattr(args, key) + if val is not None: + data[key] = val + + field_names = set(ContinualRehearsalFullConfig.__dataclass_fields__) + path_fields = { + "qc_data_path", + "qc_preprocessing_path", + "superconductor_path", + "magnetic_path", + "phonix_path", + "output_dir", + } + kwargs: dict[str, Any] = {} + for key, value in data.items(): + if key not in field_names: + logger.warning(f"Ignoring unknown config key '{key}'.") + continue + if key == "inverse_scenarios": + kwargs[key] = [InverseScenario(**sc) if isinstance(sc, dict) else sc for sc in value] + elif key in path_fields: + # Empty string means "unset" (e.g. qc_preprocessing_path with no matching pkl). + kwargs[key] = Path(value) if value not in (None, "") else None + else: + kwargs[key] = value + return ContinualRehearsalFullConfig(**kwargs), args + + +def main(argv: list[str] | None = None) -> None: + config, args = _parse_args(argv) + runner = ContinualRehearsalFullRunner(config) + traj_kwargs: dict[str, Any] = { + "record_trajectory": args.record_trajectory, + "per_seed_trajectories": args.per_seed_trajectories, + "animation_formats": tuple(args.animation_formats), + } + if args.inverse_only is not None: + runner.run_inverse_only(args.inverse_only, **traj_kwargs) + else: + runner.run(**traj_kwargs) + + +if __name__ == "__main__": + main() diff --git a/src/foundation_model/scripts/continual_rehearsal_full_test.py b/src/foundation_model/scripts/continual_rehearsal_full_test.py new file mode 100644 index 0000000..769cdc5 --- /dev/null +++ b/src/foundation_model/scripts/continual_rehearsal_full_test.py @@ -0,0 +1,268 @@ +"""Tests for the full continual-rehearsal + inverse-design runner (config/catalogue/CLI logic). + +Training and data loading are exercised by the smoke run, not here; these tests cover the pure +logic that is cheap and worth guarding: the task catalogue, config validation, and TOML/CLI parsing. +""" + +from __future__ import annotations + +import textwrap +from pathlib import Path + +import pytest + +from foundation_model.scripts.continual_rehearsal_full import ( + ALLOY_PALETTE, + DEFAULT_FIXED_TAIL, + DEFAULT_SEQUENCE, + INVERSE_PATH_CONFIGS, + INVERSE_PATHS, + REG_TASK_TITLES, + TASK_SPECS, + ContinualRehearsalFullConfig, + ContinualRehearsalFullRunner, + InverseScenario, + _arrow, + _display, + _parse_args, + _title, +) + + +def test_default_sequence_is_24_tasks_by_type(): + kinds = [TASK_SPECS[t]["kind"] for t in DEFAULT_SEQUENCE] + assert len(DEFAULT_SEQUENCE) == 24 + assert kinds.count("reg") == 16 + assert kinds.count("kr") == 7 + assert kinds.count("clf") == 1 + + +def test_catalogue_consistency(): + # Every sequenced task is known; kernel tasks declare a t_column; clf declares num_classes. + for task in DEFAULT_SEQUENCE: + spec = TASK_SPECS[task] + assert spec["kind"] in {"reg", "kr", "clf"} + if spec["kind"] == "kr": + assert "t_column" in spec + if spec["kind"] == "clf": + assert "num_classes" in spec + # The fixed tail is the last segment of the default sequence. + assert DEFAULT_SEQUENCE[-len(DEFAULT_FIXED_TAIL) :] == DEFAULT_FIXED_TAIL + # material_type is last so the QC classifier is freshest for inverse design. + assert DEFAULT_SEQUENCE[-1] == "material_type" + + +def test_inverse_path_configs_match_demo(): + # 8 configurations — 3 latent ae_align_scale points + 5 composition configs — mirroring the + # demo's paper_inverse_comparison.py so the figures read the same across runners. + assert len(INVERSE_PATH_CONFIGS) == 8 + methods = [c["method"] for c in INVERSE_PATH_CONFIGS] + assert methods.count("latent") == 3 + assert methods.count("composition") == 5 + latent_alphas = [c["ae_align_scale"] for c in INVERSE_PATH_CONFIGS if c["method"] == "latent"] + assert latent_alphas == [0.0, 0.25, 1.0] + # The key list is a flat str list of unique stable identifiers used as result subdir names. + assert INVERSE_PATHS == [c["key"] for c in INVERSE_PATH_CONFIGS] + assert len(set(INVERSE_PATHS)) == len(INVERSE_PATHS) + # One config row must hit each demo configuration knob. + keys = set(INVERSE_PATHS) + assert { + "latent_align0p0", + "latent_align0p25", + "latent_align1p0", + "comp_seed", + "comp_seed_blend", + "comp_seed_blend_palette", + "comp_seed_blend_palette_lowdiv", + "comp_random", + } == keys + + +def test_reg_task_titles_include_scenario_targets(): + # Every reg task across the three default scenarios should have a paper-style panel title. + for t in ("formation_energy", "klat", "magnetic_moment", "tc"): + assert t in REG_TASK_TITLES + assert "[" in REG_TASK_TITLES[t] and "]" in REG_TASK_TITLES[t] # units present + assert REG_TASK_TITLES[t].endswith(("↑", "↓")) + + +def test_alloy_palette_contents(): + # Plan §5 originally specified 41 elements; extended 2026-05 with the full Hf–Pt 5d TM row + # (7 symbols) → 48. The three Au-Ga-Ln explicit seeds must still fit. + assert len(ALLOY_PALETTE) == 48 + for sym in ("Au", "Ga", "Gd", "Tb", "Dy", "Mg", "Pd", "Al"): + assert sym in ALLOY_PALETTE + # 5d transition metals (Hf–Pt) — newly added. + for sym in ("Hf", "Ta", "W", "Re", "Os", "Ir", "Pt"): + assert sym in ALLOY_PALETTE + # Radioactive / unwanted symbols deliberately excluded. + for sym in ("Pu", "Tc", "Pm"): + assert sym not in ALLOY_PALETTE + + +def test_default_config_valid_and_inverse_defaults(): + cfg = ContinualRehearsalFullConfig() + assert len(cfg.inverse_scenarios) == 3 + assert all(isinstance(sc, InverseScenario) for sc in cfg.inverse_scenarios) + # Plan §5 defaults: 20 seeds (17 strategy + 3 Au-Ga-Ln) + the 41-element palette. The single- + # value ae_align / seed_blend / diversity knobs are fixed in INVERSE_PATH_CONFIGS, not the + # config dataclass — see test_inverse_path_configs_match_demo. + assert cfg.inverse_n_seeds == 20 + assert cfg.inverse_composition_allowed_elements == ALLOY_PALETTE + assert cfg.inverse_seed_explicit_append == ["Au65 Ga20 Gd15", "Au65 Ga20 Tb15", "Au65 Ga20 Dy15"] + + +def test_unknown_task_raises(): + with pytest.raises(ValueError, match="Unknown task"): + ContinualRehearsalFullConfig(task_sequence=["density", "not_a_task", "material_type"]) + + +def test_duplicate_task_raises(): + seq = list(DEFAULT_SEQUENCE) + ["density"] + with pytest.raises(ValueError, match="duplicates"): + ContinualRehearsalFullConfig(task_sequence=seq) + + +def test_fixed_tail_must_be_in_sequence(): + with pytest.raises(ValueError, match="fixed_tail"): + ContinualRehearsalFullConfig(fixed_tail=["formation_energy", "not_present", "material_type"]) + + +@pytest.mark.parametrize("ratio_kwargs", [{"replay_ratio": -0.1}, {"replay_ratio_high": 1.5}]) +def test_replay_ratio_bounds(ratio_kwargs): + with pytest.raises(ValueError, match="must be in"): + ContinualRehearsalFullConfig(**ratio_kwargs) + + +def test_allowed_elements_validation(): + with pytest.raises(ValueError, match="non-empty"): + ContinualRehearsalFullConfig(inverse_composition_allowed_elements=[]) + with pytest.raises(ValueError, match="not in DEFAULT_ELEMENTS"): + ContinualRehearsalFullConfig(inverse_composition_allowed_elements=["Mg", "Xx"]) + + +def test_inverse_scenario_length_mismatch(): + with pytest.raises(ValueError, match="equal length"): + InverseScenario("bad", ["formation_energy"], [-2.0, 2.0]) + + +def test_scenario_task_must_be_regression(): + # material_type is a classification task → cannot be a regression objective. + bad = InverseScenario("bad", ["material_type"], [1.0]) + with pytest.raises(ValueError, match="must be a"): + ContinualRehearsalFullConfig(inverse_scenarios=[bad]) + + # a kernel-regression task is also not a scalar regression objective. + bad_kr = InverseScenario("bad_kr", ["dos_density"], [1.0]) + with pytest.raises(ValueError, match="must be a"): + ContinualRehearsalFullConfig(inverse_scenarios=[bad_kr]) + + +def test_scenario_task_must_be_in_sequence(): + short_seq = ["density", "material_type"] + bad = InverseScenario("bad", ["formation_energy"], [-2.0]) + with pytest.raises(ValueError, match="not in task_sequence"): + ContinualRehearsalFullConfig(task_sequence=short_seq, fixed_tail=["material_type"], inverse_scenarios=[bad]) + + +def test_material_type_required(): + seq = [t for t in DEFAULT_SEQUENCE if t != "material_type"] + with pytest.raises(ValueError, match="material_type"): + ContinualRehearsalFullConfig(task_sequence=seq, fixed_tail=["formation_energy"], inverse_scenarios=[]) + + +def test_invalid_seed_strategy(): + with pytest.raises(ValueError, match="inverse_seed_strategy"): + ContinualRehearsalFullConfig(inverse_seed_strategy="bogus") + + +def test_display_helpers(): + assert _display("formation_energy") == "Formation Energy" + assert "Density" in _title("density") + assert "normalized" in _title("density") # qc scale + assert "z-scored" in _title("tc") # raw scale + assert _arrow(-2.0) == "↓" + assert _arrow(2.0) == "↑" + + +def test_element_system_and_dedup(): + # Element-system extraction ignores numeric ratios; dedup keeps the first per element set. + assert ContinualRehearsalFullRunner._element_system("Au65 Ga20 Gd15") == frozenset({"Au", "Ga", "Gd"}) + assert ContinualRehearsalFullRunner._element_system("Au0.65Ga0.20Gd0.15") == frozenset({"Au", "Ga", "Gd"}) + deduped = ContinualRehearsalFullRunner._dedupe_by_element_system( + ["Mg2 Zn1 Y1", "Mg1 Zn2 Y1", "Al1 Cu1 Fe1", "Mg3 Zn3 Y2"], n=10 + ) + # Mg-Zn-Y duplicates collapsed to the first occurrence; Al-Cu-Fe kept. + assert deduped == ["Mg2 Zn1 Y1", "Al1 Cu1 Fe1"] + + +def test_parse_args_tuple_return_and_toml(tmp_path: Path): + toml = tmp_path / "cfg.toml" + toml.write_text( + textwrap.dedent( + """ + qc_preprocessing_path = "" + task_sequence = ["density", "formation_energy", "magnetic_moment", "klat", "tc", "material_type"] + fixed_tail = ["formation_energy", "magnetic_moment", "tc", "klat", "material_type"] + replay_ratio_high = 0.2 + inverse_composition_allowed_elements = ["Mg", "Al", "Cu", "Pd"] + + [[inverse_scenarios]] + name = "s1" + reg_tasks = ["formation_energy", "klat"] + reg_targets = [-2.0, 2.0] + + [[inverse_scenarios]] + name = "s2" + reg_tasks = ["formation_energy", "tc", "magnetic_moment"] + reg_targets = [-2.0, 2.0, 2.0] + """ + ), + encoding="utf-8", + ) + cfg, args = _parse_args(["--config-file", str(toml), "--sample-per-dataset", "500", "--max-epochs-per-step", "2"]) + # Empty-string path field becomes None (no dropped_idx filtering). + assert cfg.qc_preprocessing_path is None + # inverse_scenarios dicts are coerced to InverseScenario objects. + assert [sc.name for sc in cfg.inverse_scenarios] == ["s1", "s2"] + assert all(isinstance(sc, InverseScenario) for sc in cfg.inverse_scenarios) + # CLI overrides land on the config; the palette override propagates from TOML. + assert cfg.sample_per_dataset == 500 + assert cfg.max_epochs_per_step == 2 + assert cfg.replay_ratio_high == 0.2 + assert cfg.inverse_composition_allowed_elements == ["Mg", "Al", "Cu", "Pd"] + # Namespace returned alongside config so main() can read --inverse-only. + assert args.inverse_only is None + + +def test_parse_args_inverse_only_flag(tmp_path: Path): + ckpt = tmp_path / "model.pt" + ckpt.write_bytes(b"placeholder") # presence-only; loading is exercised by smoke + _cfg, args = _parse_args(["--inverse-only", str(ckpt)]) + assert args.inverse_only == ckpt + + +def test_parse_args_unknown_key_ignored(tmp_path: Path): + toml = tmp_path / "cfg.toml" + toml.write_text("totally_unknown_key = 7\nreplay_ratio = 0.05\n", encoding="utf-8") + cfg, _args = _parse_args(["--config-file", str(toml)]) + assert cfg.replay_ratio == 0.05 + assert not hasattr(cfg, "totally_unknown_key") + + +def test_demo_inverse_plot_helpers_imported(): + """The runner relies on two helpers imported from ``paper_inverse_comparison`` to draw the + ``qc_vs_secondary_scatter`` and ``seed_to_optimized__*`` figures. If those imports drift + the inverse-design loop silently loses both figure groups (no test would catch a missing + plot without this guard, because the runner's training loop is only smoke-tested). + """ + from foundation_model.scripts import continual_rehearsal_full as crf + from foundation_model.scripts.paper_inverse_comparison import ( + _plot_qc_vs_reg_scatter as demo_scatter, + ) + from foundation_model.scripts.paper_inverse_comparison import ( + _plot_seed_to_optimized_mapping as demo_mapping, + ) + + assert crf._plot_qc_vs_reg_scatter is demo_scatter + assert crf._plot_seed_to_optimized_mapping is demo_mapping diff --git a/src/foundation_model/scripts/eval_inverse_methods.py b/src/foundation_model/scripts/eval_inverse_methods.py new file mode 100644 index 0000000..7d4e4ad --- /dev/null +++ b/src/foundation_model/scripts/eval_inverse_methods.py @@ -0,0 +1,443 @@ +# Copyright 2025 TsumiNa. +# SPDX-License-Identifier: Apache-2.0 + +""" +Compare two inverse-design methods on a single trained checkpoint. + +Method A — latent-space optimisation with AE-alignment penalty + optimize_latent(optimize_space="latent", class_target_weight=…, ae_align_scale=λ). + The optimised latent is decoded back to a descriptor through the AE; the heads' values at + the **decoded** descriptor are reported (so "round-trip drift" is the key failure mode and + cycle-consistency is the proposed mitigation, swept over λ). + +Method B — composition-space optimisation via differentiable KMD + optimize_composition(kmd_kernel, class_target_weight=…). The optimisation variable IS the + element-weight recipe ``w``; descriptor is ``w @ K``; there is no AE in the loop. + +Both methods run on the **same model**, **same seed compositions**, and **same targets** so the +two columns are directly comparable. Output is a JSON summary + a comparison PNG. + +This script is independent of the rehearsal demo — its own CLI, own output dir, no rehearsal. + + python -m foundation_model.scripts.eval_inverse_methods \\ + --config-file samples/continual_rehearsal_demo_config_inverse_baseline.toml \\ + --checkpoint artifacts/inverse_heads_finetuned/final_model.pt \\ + --output-dir artifacts/inverse_methods_eval \\ + --align-scales 0,0.25,0.5,0.75,1.0 +""" + +from __future__ import annotations + +import argparse +import json +import time +from pathlib import Path +from typing import Any + +import matplotlib + +matplotlib.use("Agg") + +import matplotlib.pyplot as plt +import numpy as np +import torch +from lightning import seed_everything +from loguru import logger + +from foundation_model.scripts.continual_rehearsal_demo import ( + QC_CLASSES, + ContinualRehearsalConfig, + ContinualRehearsalRunner, +) +from foundation_model.utils.kmd_plus import DEFAULT_ELEMENTS, formula_to_composition + + +# --- Helpers ------------------------------------------------------------------ + + +def _qc_prob(model, x: torch.Tensor) -> np.ndarray: + with torch.no_grad(): + h = torch.tanh(model.encoder(x)) + probs = torch.softmax(model.task_heads["material_type"](h), dim=-1) + return probs[:, QC_CLASSES].sum(dim=-1).cpu().numpy() + + +def _reg_preds(model, x: torch.Tensor, tasks: list[str]) -> dict[str, np.ndarray]: + with torch.no_grad(): + h = torch.tanh(model.encoder(x)) + return {t: model.task_heads[t](h).squeeze(-1).cpu().numpy() for t in tasks} + + +def _seed_weights_from_compositions(seeds: list[str], n_components: int) -> torch.Tensor: + """Element-weight tensor (B, n_components) for ``optimize_composition`` seeding.""" + rows = [] + for c in seeds: + w = formula_to_composition(c) + if w is None: + raise ValueError(f"Cannot parse seed composition '{c}' to element weights.") + rows.append(np.asarray(w, dtype=np.float64)) + return torch.tensor(np.stack(rows), dtype=torch.float64) + + +def _decode_latent_path(kmd, descriptors: np.ndarray) -> list[str]: + """Latent path's composition output: AE-decoded descriptor → KMD.inverse → formula string.""" + try: + weights = kmd.inverse(descriptors) + except Exception as exc: # pragma: no cover + logger.warning(f"KMD.inverse failed ({exc}); skipping composition decoding.") + return [""] * descriptors.shape[0] + return _format_weights(weights) + + +def _format_weights(weights: np.ndarray, top_k: int = 6, eps: float = 1e-3) -> list[str]: + """Render element-weight rows as compact formula strings (top-K elements above ``eps``).""" + out: list[str] = [] + for row in weights: + order = np.argsort(row)[::-1] + parts = [f"{DEFAULT_ELEMENTS[i]}{row[i]:.3f}" for i in order[:top_k] if row[i] > eps] + out.append(" ".join(parts) if parts else "") + return out + + +# --- Methods ------------------------------------------------------------------ + + +def _run_latent_method( + runner: ContinualRehearsalRunner, + model, + seeds: list[str], + x_seed: torch.Tensor, + reg_targets: dict[str, float], + class_weight: float, + align_scale: float, + steps: int, + lr: float, + record_trajectory: bool = False, +) -> dict[str, Any]: + device = next(model.parameters()).device + t0 = time.perf_counter() + res = model.optimize_latent( + initial_input=x_seed, + task_targets=reg_targets, + class_targets={"material_type": QC_CLASSES}, + class_target_weight=class_weight, + ae_align_scale=align_scale, + optimize_space="latent", + steps=steps, + lr=lr, + record_input_trajectory=record_trajectory, + ) + elapsed = time.perf_counter() - t0 + + reg_names = list(reg_targets.keys()) + achieved_latent = res.optimized_target[:, 0, :].cpu().numpy() # (B, T) in reg_targets order + optimized_desc = res.optimized_input[:, 0, :] # (B, x_dim) — AE-decoded descriptor + after_qc = _qc_prob(model, optimized_desc) + after_reg = _reg_preds(model, optimized_desc, reg_names) + decoded = _decode_latent_path(runner._kmd, optimized_desc.detach().cpu().numpy()) + # Recover the per-seed element weights too, so downstream replotting (per-element bar charts, + # ratio histograms, similarity matrices) doesn't need to re-run the optimisation. + optimized_weights = runner._kmd.inverse(optimized_desc.detach().cpu().numpy()) + + out = { + "method": "latent", + "align_scale": align_scale, + "elapsed_s": elapsed, + "seeds": list(seeds), + "qc_after_decode": after_qc.tolist(), + "reg_achieved_latent": {t: achieved_latent[:, j].tolist() for j, t in enumerate(reg_names)}, + "reg_after_decode": {t: after_reg[t].tolist() for t in reg_names}, + "decoded_composition": decoded, + # Raw arrays for replotting without rerunning: (B, x_dim) descriptor and (B, n_components) weights. + "optimized_descriptor": optimized_desc.detach().cpu().numpy().tolist(), + "optimized_weights": optimized_weights.tolist(), + } + if record_trajectory: + # Per-step trajectory of the *post-decode* predictions and the per-step decoded weights. + # ``res.trajectory`` is (B, R=1, steps, T) — squeeze the restart axis to (steps, B, T). + # We additionally re-run the heads on the per-step decoded input so the "trajectory" we + # report is on the same surface as the final ``reg_after_decode`` values (the optimiser's + # internal latent-space predictions can diverge from the decode-then-predict ones when + # ``ae_align_scale`` is small — surfacing the decode-then-predict trajectory is the more + # honest signal for the user investigating "how does the recipe evolve"). + out["trajectory_targets"] = res.trajectory[:, 0, :, :].cpu().numpy().transpose(1, 0, 2).tolist() + # (B, R=1, steps, input_dim) → (steps, B, n_components) via KMD.inverse on each step. + # Batched per step: KMD.inverse expects (B, input_dim) and returns (B, n_components). + per_step_inputs = res.input_trajectory[:, 0, :, :].cpu().numpy() # (B, steps, input_dim) + per_step_inputs = per_step_inputs.transpose(1, 0, 2) # (steps, B, input_dim) + per_step_weights = [runner._kmd.inverse(per_step_inputs[s]) for s in range(per_step_inputs.shape[0])] + # (steps, B, n_components) + import numpy as _np + out["trajectory_weights"] = _np.stack(per_step_weights, axis=0).tolist() + return out + + +def _run_composition_method( + runner: ContinualRehearsalRunner, + model, + seeds: list[str], + reg_targets: dict[str, float], + class_weight: float, + steps: int, + lr: float, + allowed_elements: "str | list[str]" = "all", + element_step_scale: "float | dict[str, float]" = 1.0, +) -> dict[str, Any]: + device, dtype = next(model.parameters()).device, next(model.parameters()).dtype + kernel = runner._kmd.kernel_torch(device=device, dtype=dtype) + w_seed = _seed_weights_from_compositions(seeds, n_components=len(DEFAULT_ELEMENTS)) + + t0 = time.perf_counter() + res = model.optimize_composition( + kernel, + initial_weights=w_seed, + task_targets=reg_targets, + class_targets={"material_type": QC_CLASSES}, + class_target_weight=class_weight, + allowed_elements=allowed_elements, + element_step_scale=element_step_scale, + steps=steps, + lr=lr, + ) + elapsed = time.perf_counter() - t0 + + reg_names = list(reg_targets.keys()) + achieved = res.optimized_target.cpu().numpy() # (B, T) + optimized_desc = res.optimized_descriptor # (B, x_dim) — w @ K, no decode + final_qc = _qc_prob(model, optimized_desc) + final_reg = _reg_preds(model, optimized_desc, reg_names) + w_final = res.optimized_weights.cpu().numpy() + + return { + "method": "composition", + "align_scale": None, + "elapsed_s": elapsed, + "seeds": list(seeds), + # In composition space there is no "after-decode" drift — the model values AT the optimised + # ``w`` are the same as at the descriptor ``w @ K``. We still report both for symmetry. + "qc_after_decode": final_qc.tolist(), + "reg_achieved_latent": {t: achieved[:, j].tolist() for j, t in enumerate(reg_names)}, + "reg_after_decode": {t: final_reg[t].tolist() for t in reg_names}, + "decoded_composition": _format_weights(w_final), + # Raw arrays for replotting without rerunning: (B, x_dim) descriptor and (B, n_components) weights. + "optimized_descriptor": optimized_desc.detach().cpu().numpy().tolist(), + "optimized_weights": w_final.tolist(), + } + + +# --- Plot --------------------------------------------------------------------- + + +def _plot_summary(results: list[dict[str, Any]], reg_targets: dict[str, float], out_path: Path) -> None: + """Side-by-side: QC prob and each regression target across methods (mean ± seeds).""" + fig, axes = plt.subplots(1, 1 + len(reg_targets), figsize=(4.6 * (1 + len(reg_targets)), 4.2), squeeze=False) + axes = axes[0] + labels = [f"latent (α={r['align_scale']})" if r["method"] == "latent" else "composition" for r in results] + + # QC probability + qc_means = [float(np.mean(r["qc_after_decode"])) for r in results] + qc_stds = [float(np.std(r["qc_after_decode"])) for r in results] + x = np.arange(len(results)) + axes[0].bar(x, qc_means, yerr=qc_stds, color="#55A868", capsize=3) + axes[0].axhline(1.0, color="#C44E52", ls="--", lw=1.4, label="target = 1.0") + axes[0].set_xticks(x, labels, rotation=30, ha="right") + axes[0].set_ylim(-0.02, 1.05) + axes[0].set_ylabel("P(quasicrystal)") + axes[0].set_title("Quasicrystal Probability (primary)") + axes[0].legend(fontsize=9, loc="lower right") + + for ax, (t, tgt) in zip(axes[1:], reg_targets.items()): + means = [float(np.mean(r["reg_after_decode"][t])) for r in results] + stds = [float(np.std(r["reg_after_decode"][t])) for r in results] + ax.bar(x, means, yerr=stds, color="#4C72B0", capsize=3) + ax.axhline(tgt, color="#C44E52", ls="--", lw=1.4, label=f"target = {tgt:+.1f}") + ax.set_xticks(x, labels, rotation=30, ha="right") + ax.set_ylabel("Predicted value") + ax.set_title(f"{t}") + ax.legend(fontsize=9, loc="best") + + fig.suptitle("Inverse-design methods compared (same model, same seeds, same targets)", y=1.04) + fig.savefig(out_path, dpi=150, bbox_inches="tight") + plt.close(fig) + + +# --- Main flow ---------------------------------------------------------------- + + +def evaluate( + config: ContinualRehearsalConfig, + ckpt_path: Path, + align_scales: list[float], + allowed_elements: "str | list[str]" = "all", + element_step_scale: "float | dict[str, float]" = 1.0, +) -> None: + seed_everything(config.random_seed, workers=True) + runner = ContinualRehearsalRunner(config) + model = runner._build_full_model() + + state = torch.load(ckpt_path, map_location="cpu", weights_only=True) + state_dict = state["model"] if isinstance(state, dict) and "model" in state else state + model.load_state_dict(state_dict) + model.eval() + + # Deterministic seed compositions: same set for both methods. We reuse the demo's "top-QC + # training composition" selector so this matches what users see from continual_rehearsal_demo. + device = next(model.parameters()).device + + def _qc_prob_fn(x: torch.Tensor) -> np.ndarray: + return _qc_prob(model, x) + + seeds = runner._select_seeds(model, device, _qc_prob_fn) + if not seeds: + raise RuntimeError("No seed compositions selected (check inverse_seed_strategy / data).") + x_seed, seeds = runner._descriptor_tensor(seeds, device) + logger.info(f"Selected {len(seeds)} seed compositions") + + reg_targets = {t: v for t, v in zip(config.inverse_reg_tasks, config.inverse_reg_targets)} + + results: list[dict[str, Any]] = [] + + # Method A: latent-space, sweep ae_align_scale ∈ [0, 1]. + for lam in align_scales: + logger.info(f"--- Latent method, ae_align_scale = {lam} ---") + results.append( + _run_latent_method( + runner, + model, + seeds, + x_seed, + reg_targets, + class_weight=config.inverse_class_weight, + align_scale=float(lam), + steps=config.inverse_steps, + lr=config.inverse_lr, + ) + ) + + # Method B: differentiable KMD, single run (no λ). Element constraints (if any) only apply here. + logger.info("--- Composition method (differentiable KMD) ---") + if isinstance(allowed_elements, list): + logger.info(f" allowed_elements: {len(allowed_elements)} symbol(s) — {allowed_elements}") + if isinstance(element_step_scale, dict): + logger.info(f" element_step_scale: {element_step_scale}") + elif isinstance(element_step_scale, (int, float)) and float(element_step_scale) != 1.0: + logger.info(f" element_step_scale (uniform): {element_step_scale}") + results.append( + _run_composition_method( + runner, + model, + seeds, + reg_targets, + class_weight=config.inverse_class_weight, + steps=config.inverse_steps, + lr=config.inverse_lr, + allowed_elements=allowed_elements, + element_step_scale=element_step_scale, + ) + ) + + out_dir = Path(config.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + # Compact human-readable summary alongside the full per-seed JSON. + summary = [] + for r in results: + row = { + "label": f"latent α={r['align_scale']}" if r["method"] == "latent" else "composition", + "elapsed_s": round(r["elapsed_s"], 2), + "qc_after_mean": round(float(np.mean(r["qc_after_decode"])), 4), + } + for t in reg_targets: + row[f"{t}_after_mean"] = round(float(np.mean(r["reg_after_decode"][t])), 3) + summary.append(row) + logger.info("=== Summary ===") + for row in summary: + logger.info(row) + + (out_dir / "eval_inverse_methods.json").write_text( + json.dumps({"reg_targets": reg_targets, "results": results, "summary": summary}, indent=2), + encoding="utf-8", + ) + _plot_summary(results, reg_targets, out_dir / "eval_inverse_methods.png") + logger.info(f"Wrote {out_dir / 'eval_inverse_methods.json'} and the comparison plot.") + + +def _parse_args(argv: list[str] | None = None) -> tuple[ContinualRehearsalConfig, argparse.Namespace]: + parser = argparse.ArgumentParser(description="Compare inverse-design methods on a trained checkpoint.") + parser.add_argument("--config-file", type=Path, required=True) + parser.add_argument("--checkpoint", type=Path, required=True) + parser.add_argument("--output-dir", type=Path, required=True) + parser.add_argument( + "--align-scales", + type=str, + default="0,0.25,0.5,0.75,1.0", + help="Comma-separated values in [0, 1] for ae_align_scale in the latent method.", + ) + parser.add_argument( + "--allowed-elements", + type=str, + default="", + help=( + "Comma-separated element symbols the composition method is allowed to use (hard " + "whitelist; e.g. 'Mg,Al,Cu,Ni,Zn,Ag'). Empty means every element allowed." + ), + ) + parser.add_argument( + "--locked-elements", + type=str, + default="", + help=( + "Comma-separated element symbols whose composition weight is frozen at the seed " + "value (sets element_step_scale to --locked-step-scale; default 0 = fully locked)." + ), + ) + parser.add_argument( + "--locked-step-scale", + type=float, + default=0.0, + help="Gradient multiplier for locked elements (0 = fully locked; 0.1 = slow drift).", + ) + args = parser.parse_args(argv) + + import tomllib + + data = tomllib.loads(args.config_file.read_text(encoding="utf-8")) + data["output_dir"] = str(args.output_dir) + field_names = set(ContinualRehearsalConfig.__dataclass_fields__) + path_fields = { + "qc_data_path", + "qc_preprocessing_path", + "superconductor_path", + "magnetic_path", + "phonix_path", + "output_dir", + } + kwargs: dict[str, object] = {} + for key, value in data.items(): + if key not in field_names: + continue + kwargs[key] = Path(value) if key in path_fields and value is not None else value + return ContinualRehearsalConfig(**kwargs), args + + +def main(argv: list[str] | None = None) -> None: + config, args = _parse_args(argv) + align_scales = [float(x) for x in args.align_scales.split(",") if x.strip()] + allowed_syms = [s.strip() for s in args.allowed_elements.split(",") if s.strip()] + locked_syms = [s.strip() for s in args.locked_elements.split(",") if s.strip()] + # Pass symbols straight through to optimize_composition's symbol-based API. + allowed_arg: "str | list[str]" = allowed_syms if allowed_syms else "all" + step_scale_arg: "float | dict[str, float]" = ( + {s: args.locked_step_scale for s in locked_syms} if locked_syms else 1.0 + ) + evaluate( + config, + args.checkpoint, + align_scales, + allowed_elements=allowed_arg, + element_step_scale=step_scale_arg, + ) + + +if __name__ == "__main__": + main() diff --git a/src/foundation_model/scripts/finetune_inverse_heads.py b/src/foundation_model/scripts/finetune_inverse_heads.py new file mode 100644 index 0000000..ec989f4 --- /dev/null +++ b/src/foundation_model/scripts/finetune_inverse_heads.py @@ -0,0 +1,215 @@ +# Copyright 2025 TsumiNa. +# SPDX-License-Identifier: Apache-2.0 + +""" +Targeted fine-tune of the three heads used by inverse design. + +Loads a ``final_model.pt`` checkpoint produced by ``continual_rehearsal_demo``, freezes the +encoder and every other task head (including the autoencoder), and runs a short fine-tune on +just the three inverse-design heads — by default ``formation_energy``, ``klat`` and +``material_type`` — so they are as sharp as possible before we compare inverse-design methods +(latent-with-cycle-consistency vs differentiable KMD). + +The script is **independent of the rehearsal demo** (its own CLI, output dir, and checkpoint). +It reuses the demo runner only for data loading + model reconstruction; no rehearsal loop is run. + + python -m foundation_model.scripts.finetune_inverse_heads \\ + --config-file samples/continual_rehearsal_demo_config_inverse_baseline.toml \\ + --checkpoint artifacts/continual_rehearsal_inverse_baseline/final_model.pt \\ + --output-dir artifacts/inverse_heads_finetuned \\ + --epochs 30 +""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import Iterable + +import torch +from lightning import Trainer, seed_everything +from loguru import logger + +from foundation_model.data.datamodule import CompoundDataModule +from foundation_model.scripts.continual_rehearsal_demo import ( + ContinualRehearsalConfig, + ContinualRehearsalRunner, + _parse_args as _demo_parse_args, # noqa: F401 (kept for documentation; we parse our own args) +) + +DEFAULT_INVERSE_HEADS = ("formation_energy", "klat", "material_type") + + +def freeze_except(model, keep_heads: Iterable[str]) -> dict[str, bool]: + """Freeze encoder + every head NOT in ``keep_heads`` + task_log_sigmas; return prior requires_grad state. + + The model's ``task_log_sigmas`` ParameterDict holds the learnable loss-balancer coefficients + (one scalar per task, active when ``enable_learnable_loss_balancer=True``). Without freezing + them, ``configure_optimizers`` still picks them up and they move during the "head-only" + fine-tune — which would silently change the inverse-design objectives' relative weights and + make the comparison apples-to-oranges. We freeze every per-task balancer scalar here too, + so this script really is head-only. + """ + keep = set(keep_heads) + saved: dict[str, bool] = {} + for name, p in model.named_parameters(): + saved[name] = p.requires_grad + for p in model.encoder.parameters(): + p.requires_grad_(False) + for head_name, head in model.task_heads.items(): + train = head_name in keep + for p in head.parameters(): + p.requires_grad_(train) + # Freeze every learnable-loss-balancer scalar (no-op when the balancer is disabled). + for p in model.task_log_sigmas.parameters(): + p.requires_grad_(False) + return saved + + +def _restore_requires_grad(model, saved: dict[str, bool]) -> None: + for name, p in model.named_parameters(): + if name in saved: + p.requires_grad_(saved[name]) + + +def finetune(config: ContinualRehearsalConfig, ckpt_path: Path, inverse_heads: tuple[str, ...], epochs: int) -> Path: + seed_everything(config.random_seed, workers=True) + runner = ContinualRehearsalRunner(config) # loads data + builds KMD cache (same as demo) + + logger.info(f"Loading model checkpoint {ckpt_path}") + model = runner._build_full_model() + state = torch.load(ckpt_path, map_location="cpu", weights_only=True) + state_dict = state["model"] if isinstance(state, dict) and "model" in state else state + model.load_state_dict(state_dict) + + missing = [t for t in inverse_heads if t not in model.task_heads] + if missing: + raise ValueError( + f"Heads {missing} not found in the loaded model (have {list(model.task_heads.keys())}). " + "Check that the checkpoint was produced with the same task_sequence." + ) + + logger.info(f"Freezing everything except heads: {sorted(inverse_heads)}") + freeze_except(model, inverse_heads) + + # Deactivate every non-inverse head so the Trainer's validation_step doesn't try to forward + # them on a batch that only carries the three inverse-head columns. ``disable_task`` keeps the + # weights in ``model.disabled_task_heads`` (so the saved state_dict still contains them) but + # removes them from ``model.task_heads`` so the forward loop iterates only over the inverse + # ones. Important for KR heads (e.g. ``dos_density``) whose forward expects a ``t_sequences`` + # entry that the inverse-only DataModule does not provide. + other_active = [name for name in list(model.task_heads.keys()) if name not in inverse_heads] + if other_active: + logger.info(f"Disabling {len(other_active)} non-inverse head(s) for the duration of fine-tune: {other_active}") + model.disable_task(*other_active) + + # Use the same task configs as training (built by the runner), but restrict the DataModule to + # the inverse-head tasks and disable masking (we want all available labels for these heads). + task_configs = {name: runner._build_task_config(name) for name in inverse_heads} + for cfg in task_configs.values(): + cfg.task_masking_ratio = 1.0 # no rehearsal-style dropout — we want every label + + datamodule = CompoundDataModule( + task_configs=list(task_configs.values()), + descriptor_fn=runner.descriptor_fn, + task_frames={name: runner.task_frames[name] for name in inverse_heads}, + composition_column="composition", + random_seed=config.datamodule_random_seed, + batch_size=config.batch_size, + num_workers=config.num_workers, + ) + + trainer = Trainer( + max_epochs=epochs, + accelerator=config.accelerator, + devices=config.devices, + logger=False, + enable_checkpointing=False, + enable_progress_bar=False, + ) + trainer.fit(model, datamodule=datamodule) + + # Re-activate the heads we hid so the saved state_dict's key layout matches what + # paper_inverse_comparison / eval_inverse_methods rebuild (all heads under ``task_heads``). + if other_active: + logger.info(f"Re-enabling {len(other_active)} previously-disabled head(s) before save.") + model.enable_task(*other_active) + + out_path = Path(config.output_dir) / "final_model.pt" + Path(config.output_dir).mkdir(parents=True, exist_ok=True) + torch.save( + { + "model": model.state_dict(), + "task_sequence": list(config.task_sequence), + "finetuned_heads": list(inverse_heads), + "finetune_epochs": int(epochs), + "from_checkpoint": str(ckpt_path), + }, + out_path, + ) + (Path(config.output_dir) / "finetune_summary.json").write_text( + json.dumps( + { + "from_checkpoint": str(ckpt_path), + "finetuned_heads": list(inverse_heads), + "epochs": int(epochs), + "task_sequence": list(config.task_sequence), + }, + indent=2, + ), + encoding="utf-8", + ) + logger.info(f"Saved fine-tuned checkpoint to {out_path}") + return out_path + + +def _parse_args(argv: list[str] | None = None) -> tuple[ContinualRehearsalConfig, argparse.Namespace]: + parser = argparse.ArgumentParser(description="Targeted fine-tune of inverse-design heads.") + parser.add_argument("--config-file", type=Path, required=True, help="Demo config (paths + task_sequence).") + parser.add_argument( + "--checkpoint", type=Path, required=True, help="final_model.pt produced by continual_rehearsal_demo." + ) + parser.add_argument( + "--output-dir", type=Path, required=True, help="Where to write the fine-tuned checkpoint + summary." + ) + parser.add_argument("--epochs", type=int, default=20, help="Fine-tune epochs (default 20).") + parser.add_argument( + "--inverse-heads", + type=str, + default=",".join(DEFAULT_INVERSE_HEADS), + help=f"Comma-separated head names to fine-tune. Default: {','.join(DEFAULT_INVERSE_HEADS)}.", + ) + args = parser.parse_args(argv) + + # Build the demo config (reuses the same TOML schema), overriding output_dir. + import tomllib + + data = tomllib.loads(args.config_file.read_text(encoding="utf-8")) + data["output_dir"] = str(args.output_dir) + field_names = set(ContinualRehearsalConfig.__dataclass_fields__) + path_fields = { + "qc_data_path", + "qc_preprocessing_path", + "superconductor_path", + "magnetic_path", + "phonix_path", + "output_dir", + } + kwargs: dict[str, object] = {} + for key, value in data.items(): + if key not in field_names: + continue + kwargs[key] = Path(value) if key in path_fields and value is not None else value + config = ContinualRehearsalConfig(**kwargs) + return config, args + + +def main(argv: list[str] | None = None) -> None: + config, args = _parse_args(argv) + heads = tuple(h.strip() for h in args.inverse_heads.split(",") if h.strip()) + finetune(config, args.checkpoint, heads, args.epochs) + + +if __name__ == "__main__": + main() diff --git a/src/foundation_model/scripts/finetune_inverse_heads_test.py b/src/foundation_model/scripts/finetune_inverse_heads_test.py new file mode 100644 index 0000000..3ca2725 --- /dev/null +++ b/src/foundation_model/scripts/finetune_inverse_heads_test.py @@ -0,0 +1,103 @@ +# Copyright 2026 TsumiNa. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for ``finetune_inverse_heads.freeze_except`` — the per-parameter freeze contract. + +The full ``finetune`` entry point needs a real checkpoint + data parquets, so it's exercised by +the smoke runs under ``artifacts/inverse_design_run/finetune/``. The unit-testable piece is the +freeze logic, which is the most refactor-fragile part: a future change that accidentally +un-freezes the encoder (or forgets the per-task loss-balancer scalars) would silently break +the "apples-to-apples" comparison the script exists to enable. +""" + +from __future__ import annotations + +import pytest +import torch + +from foundation_model.models.flexible_multi_task_model import FlexibleMultiTaskModel +from foundation_model.models.model_config import ( + ClassificationTaskConfig, + MLPEncoderConfig, + RegressionTaskConfig, +) +from foundation_model.scripts.finetune_inverse_heads import freeze_except + + +INPUT_DIM = 16 +LATENT_DIM = 8 + + +def _make_model(enable_balancer: bool = False) -> FlexibleMultiTaskModel: + """Three-head model mirroring the inverse-design tail (formation_energy / klat / material_type). + + ``enable_autoencoder=False`` keeps the test fast — the freeze contract doesn't depend on the + AE head; the smoke run covers that path. + """ + enc = MLPEncoderConfig(hidden_dims=[INPUT_DIM, LATENT_DIM]) + tasks = [ + RegressionTaskConfig(name="formation_energy", data_column="formation_energy", dims=[LATENT_DIM, 4, 1]), + RegressionTaskConfig(name="klat", data_column="klat", dims=[LATENT_DIM, 4, 1]), + ClassificationTaskConfig(name="material_type", data_column="material_type", num_classes=3, dims=[LATENT_DIM, 4, 3]), + # An extra head that should be frozen (simulates ``density`` / ``tc`` / etc. in the real tail). + RegressionTaskConfig(name="density", data_column="density", dims=[LATENT_DIM, 4, 1]), + ] + return FlexibleMultiTaskModel( + task_configs=tasks, + encoder_config=enc, + enable_learnable_loss_balancer=enable_balancer, + ) + + +def _grad_state(model) -> dict[str, bool]: + return {name: p.requires_grad for name, p in model.named_parameters()} + + +def test_freeze_except_freezes_encoder_and_unkept_heads(): + """Encoder + every head NOT in ``keep`` is frozen; kept heads remain trainable.""" + model = _make_model() + inverse_heads = ("formation_energy", "klat", "material_type") + freeze_except(model, inverse_heads) + + # Encoder: every param frozen. + assert all(not p.requires_grad for p in model.encoder.parameters()) + # Kept heads: every param trainable. + for head in inverse_heads: + assert all(p.requires_grad for p in model.task_heads[head].parameters()), f"{head!r} should be trainable" + # Non-kept head (``density``): every param frozen. + assert all(not p.requires_grad for p in model.task_heads["density"].parameters()) + + +def test_freeze_except_freezes_task_log_sigmas_when_balancer_enabled(): + """The learnable per-task loss-balancer scalars MUST be frozen, otherwise the optimiser + silently shifts the inverse heads' relative weights during the head-only fine-tune and + the downstream comparison stops being apples-to-apples.""" + model = _make_model(enable_balancer=True) + # Sanity check: balancer is on so task_log_sigmas has at least one parameter. ``any()`` + # would unwrap to the scalar's bool (0.0 is falsy) — we want a count check instead. + assert len(list(model.task_log_sigmas.parameters())) > 0, "fixture must register balancer scalars" + freeze_except(model, ("formation_energy", "klat", "material_type")) + assert all(not p.requires_grad for p in model.task_log_sigmas.parameters()) + + +def test_freeze_except_returns_pre_freeze_requires_grad_state(): + """The ``saved`` dict captures the pre-call ``requires_grad`` for every named parameter — + used by ``_restore_requires_grad`` if a caller wants to roll back. The contract is that the + returned dict has one entry per ``named_parameters()`` key.""" + model = _make_model() + pre = _grad_state(model) + saved = freeze_except(model, ("formation_energy",)) + assert set(saved.keys()) == set(pre.keys()) + # All params were trainable before freezing → saved should reflect that. + assert all(v is True for v in saved.values()) + + +def test_freeze_except_handles_unknown_keep_head_silently(): + """An unknown ``keep_heads`` entry is *not* an error in this helper — it simply means + no head matches, and every head ends up frozen. This is the right contract for a low-level + freeze; the caller (``finetune``) is responsible for validating head names against the + loaded checkpoint upstream (see ``finetune`` raising on ``missing`` heads).""" + model = _make_model() + freeze_except(model, ("not_a_head",)) + for head in model.task_heads.values(): + assert all(not p.requires_grad for p in head.parameters()) diff --git a/src/foundation_model/scripts/paper_inverse_3scenarios.py b/src/foundation_model/scripts/paper_inverse_3scenarios.py new file mode 100644 index 0000000..306b07b --- /dev/null +++ b/src/foundation_model/scripts/paper_inverse_3scenarios.py @@ -0,0 +1,166 @@ +# Copyright 2025 TsumiNa. +# SPDX-License-Identifier: Apache-2.0 + +""" +Run the paper-grade inverse-design comparison across multiple scenarios on a single checkpoint. + +This is a thin orchestrator around :mod:`paper_inverse_comparison`. The TOML config is expected to +contain a ``[[inverse_scenarios]]`` array of tables (see plan §5), each entry overriding +``reg_tasks`` / ``reg_targets`` for one scenario. The script loops over the scenarios and writes +each one's outputs into ``//`` so the per-scenario files (figures, raw +arrays, summary) stay isolated. + +Layout:: + + / + scenario1_fe_down_magnetic_up/ + final_model.pt # copy of the input checkpoint (self-contained) + seeds.json + results.json # per-seed raw arrays for all 11 paths (latent α-sweep + 5 comp) + comparison.png # headline 3-panel bar chart + SUMMARY.md + scenario.json # this scenario's reg_tasks/reg_targets + scenario2_fe_down_tc_up_magnetic_up/ + ... + scenario3_fe_down_klat_up/ + ... + README.md # cross-scenario summary index (hand-written downstream) + +The trained model has to expose every regression head listed in any scenario's ``reg_tasks``; +otherwise the per-scenario run will fail loudly at the model side. ``material_type`` (the +classification head) is implicit and always required for the QC primary objective. + +Run: + python -m foundation_model.scripts.paper_inverse_3scenarios \\ + --config-file samples/continual_rehearsal_demo_config_inverse_baseline.toml \\ + --checkpoint artifacts/inverse_design_run/finetune/final_model.pt \\ + --output-dir artifacts/inverse_design_run/inverse_design +""" + +from __future__ import annotations + +import argparse +import dataclasses +import json +import tomllib +from pathlib import Path +from typing import Any + +from loguru import logger + +from foundation_model.scripts.continual_rehearsal_demo import ContinualRehearsalConfig +from foundation_model.scripts.paper_inverse_comparison import _parse_args as _paper_parse_args +from foundation_model.scripts.paper_inverse_comparison import run as paper_run + + +def _load_scenarios(config_file: Path) -> list[dict[str, Any]]: + """Pull the ``[[inverse_scenarios]]`` array out of the TOML and validate it.""" + raw = tomllib.loads(config_file.read_text(encoding="utf-8")) + scenarios = raw.get("inverse_scenarios", []) + if not scenarios: + raise ValueError( + f"No [[inverse_scenarios]] array found in {config_file}. " + "Add the array (with name/reg_tasks/reg_targets) per plan §5 first." + ) + for sc in scenarios: + missing = {"name", "reg_tasks", "reg_targets"} - set(sc) + if missing: + raise ValueError(f"Scenario missing required fields {sorted(missing)}: {sc!r}.") + if len(sc["reg_tasks"]) != len(sc["reg_targets"]): + raise ValueError( + f"reg_tasks and reg_targets length mismatch in scenario {sc['name']!r}: " + f"{len(sc['reg_tasks'])} vs {len(sc['reg_targets'])}." + ) + return scenarios + + +def _parse_args(argv: list[str] | None = None) -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Paper-grade inverse-design comparison across multiple scenarios.") + parser.add_argument("--config-file", type=Path, required=True) + parser.add_argument("--checkpoint", type=Path, required=True) + parser.add_argument( + "--output-dir", + type=Path, + required=True, + help="Parent folder; each scenario writes into //.", + ) + # Trajectory flags — forwarded verbatim to each scenario's ``paper_inverse_comparison.run()``. + parser.add_argument( + "--record-trajectory", + action=argparse.BooleanOptionalAction, + default=True, + help="Record per-step trajectory (default on; --no-record-trajectory to skip).", + ) + parser.add_argument( + "--per-seed-trajectories", + action=argparse.BooleanOptionalAction, + default=True, + help="Per-(path × seed) trajectory plots/animations (default on; --no-per-seed-trajectories to skip).", + ) + parser.add_argument( + "--animation-formats", + nargs="+", + choices=["gif", "html", "svg", "none"], + default=["gif"], + help="One or more trajectory-animation formats (default: gif).", + ) + return parser.parse_args(argv) + + +def main(argv: list[str] | None = None) -> None: + args = _parse_args(argv) + scenarios = _load_scenarios(args.config_file) + logger.info(f"Loaded {len(scenarios)} inverse-design scenarios from {args.config_file}.") + args.output_dir.mkdir(parents=True, exist_ok=True) + + # Build a baseline config once by re-using the single-scenario parser. We then ``replace`` it + # per-scenario to override ``inverse_reg_tasks`` / ``inverse_reg_targets`` / ``output_dir``. + paper_argv = [ + "--config-file", + str(args.config_file), + "--checkpoint", + str(args.checkpoint), + "--output-dir", + str(args.output_dir / scenarios[0]["name"]), # placeholder; overridden below + ] + base_config, _ = _paper_parse_args(paper_argv) + + for sc in scenarios: + sc_dir = args.output_dir / sc["name"] + sc_config: ContinualRehearsalConfig = dataclasses.replace( + base_config, + inverse_reg_tasks=list(sc["reg_tasks"]), + inverse_reg_targets=list(sc["reg_targets"]), + output_dir=sc_dir, + ) + logger.info(f"=== Scenario {sc['name']} ===") + logger.info(f" reg_tasks : {sc['reg_tasks']}") + logger.info(f" reg_targets : {sc['reg_targets']}") + logger.info(f" output : {sc_dir}") + paper_run( + sc_config, + args.checkpoint, + record_trajectory=args.record_trajectory, + per_seed_trajectories=args.per_seed_trajectories, + animation_formats=tuple(args.animation_formats), + ) + # Drop a per-scenario meta file so future readers don't need to chase results.json's + # `config` block to learn what this folder represents. + (sc_dir / "scenario.json").write_text( + json.dumps( + { + "name": sc["name"], + "reg_tasks": list(sc["reg_tasks"]), + "reg_targets": list(sc["reg_targets"]), + "primary_objective": "P(material_type = QC) ↑", + "checkpoint": str(args.checkpoint), + }, + indent=2, + ), + encoding="utf-8", + ) + logger.info(f"=== {sc['name']} done ===") + + +if __name__ == "__main__": + main() diff --git a/src/foundation_model/scripts/paper_inverse_comparison.py b/src/foundation_model/scripts/paper_inverse_comparison.py new file mode 100644 index 0000000..97d0226 --- /dev/null +++ b/src/foundation_model/scripts/paper_inverse_comparison.py @@ -0,0 +1,1226 @@ +# Copyright 2025 TsumiNa. +# SPDX-License-Identifier: Apache-2.0 + +""" +Paper-grade comparison of inverse-design methods on a single trained checkpoint. + +Orchestrates a full sweep that ``eval_inverse_methods`` can do piecewise, and writes everything +(the model checkpoint, the seed list, the raw per-seed JSON, and the figures) into one folder +ready to drop into a paper draft. Reuses the per-method helpers from +``eval_inverse_methods`` so the methodology is identical. + +The study covers: + +* **Latent method** with AE-alignment scale α ∈ {0, 0.25, 1.0} — failure-mode baseline, a useful + intermediate, and the [0, 1] upper bound. (Earlier runs swept finer; the three points are enough + to show the qualitative plateau.) +* **Composition method** (differentiable KMD) under five configurations chosen to expose how + ``seed_blend``, the element whitelist, and seeding strategy affect novelty / diversity. Labels + follow a "describe the config in the label" convention: + 1. ``comp (seed)`` — ``seed_blend = 1.0`` (strict seed, support set frozen); + 2. ``comp (seed, 5% all)`` — ``seed_blend = 0.95`` (5 % uniform mixed in, all 94 elements + reachable but no whitelist); + 3. ``comp (seed, 5% all, element list)`` — (2) + ``allowed_elements = ALLOY_PALETTE``; + 4. ``comp (seed, 5% all, element list, low diversity)`` — (3) + ``diversity_scale = 0`` so + per-output entropy is penalised → peaky few-element recipes (ablation); + 5. ``comp (random)`` — ``initial_weights=None``, no seed bias. + + python -m foundation_model.scripts.paper_inverse_comparison \\ + --config-file samples/continual_rehearsal_demo_config_inverse_baseline.toml \\ + --checkpoint artifacts/inverse_heads_finetuned/final_model.pt \\ + --output-dir artifacts/paper_inverse_design +""" + +from __future__ import annotations + +import argparse +import json +import re +import shutil +from collections import Counter +from pathlib import Path +from typing import Any + +import matplotlib + +matplotlib.use("Agg") + +import matplotlib.colors as mcolors +import matplotlib.pyplot as plt +from matplotlib.offsetbox import AnnotationBbox, HPacker, TextArea +import numpy as np +import torch +from lightning import seed_everything +from loguru import logger + +from foundation_model.scripts.continual_rehearsal_common import ( + DISCOVERED_ELEMENT_COLOR, + plot_element_frequency_heatmap, +) +from foundation_model.scripts.continual_rehearsal_demo import ( + QC_CLASSES, + ContinualRehearsalConfig, + ContinualRehearsalRunner, +) +from foundation_model.scripts.eval_inverse_methods import ( + _format_weights, + _qc_prob, + _reg_preds, + _run_latent_method, + _seed_weights_from_compositions, +) + +# Feasible alloy palette for the constrained-composition runs. Designed per the plan in +# docs/continual_rehearsal_full_PLAN.md §5: light alkaline-earth + group 13/14 + the full 4th/5th +# period transition metals (Tc excluded for radioactivity) + the full Hf–Pt 5d TM row (added +# 2026-05 to broaden heavy-TM coverage — reaches refractory / noble-metal i-QC families) + Au +# (needed for Au-Ga-RE seeds) + accessible lanthanides (Pm radioactive, Tm/Lu scarce). 48 symbols +# total — wide enough to expose multiple QC-prone basins (incl. heavy-TM families), narrow enough +# to suppress Pu/F/Cs/Tm-style non-physical model bias. +DEFAULT_ALLOY_PALETTE = [ + "Mg", + "Ca", + "B", + "Al", + "Ga", + "In", + "Tl", + "Si", + "Ge", + "Sc", + "Ti", + "V", + "Cr", + "Mn", + "Fe", + "Co", + "Ni", + "Cu", + "Zn", + "Y", + "Zr", + "Nb", + "Mo", + "Ru", + "Rh", + "Pd", + "Ag", + "Cd", + # 5d transition metals (Hf–Pt). Added 2026-05; placed between Cd and Au so the 6th-period TM + # block is contiguous. Keeps the palette ordered by period within each group. + "Hf", + "Ta", + "W", + "Re", + "Os", + "Ir", + "Pt", + "Au", + "La", + "Ce", + "Pr", + "Nd", + "Sm", + "Eu", + "Gd", + "Tb", + "Dy", + "Ho", + "Er", + "Yb", +] +assert len(DEFAULT_ALLOY_PALETTE) == 48 + +# Composition-method configurations. Each row produces one bar in the comparison plot. The first +# two isolate the seed_blend effect; the next two layer on element constraints; the last drops the +# seed entirely (random init) as the no-seed-bias control (Scheme D). +COMPOSITION_CONFIGS: list[dict[str, Any]] = [ + # diversity = 1.0 = no entropy penalty (default user-facing behaviour). + # Labels follow the "describe the config" convention: each comma-separated phrase names a + # knob that's been turned on relative to the previous row. + {"label": "comp\n(seed)", "init": "seed", "blend": 1.0, "allowed": "all", "scale": 1.0, "diversity": 1.0}, + {"label": "comp\n(seed, 5% all)", "init": "seed", "blend": 0.95, "allowed": "all", "scale": 1.0, "diversity": 1.0}, + { + "label": "comp\n(seed, 5% all, element list)", + "init": "seed", + "blend": 0.95, + "allowed": DEFAULT_ALLOY_PALETTE, + "scale": 1.0, + "diversity": 1.0, + }, + { + # Ablation: clamp diversity to 0 → max entropy penalty → forced peaky few-element recipes. + "label": "comp\n(seed, 5% all,\nelement list, low diversity)", + "init": "seed", + "blend": 0.95, + "allowed": DEFAULT_ALLOY_PALETTE, + "scale": 1.0, + "diversity": 0.0, + }, + {"label": "comp\n(random)", "init": "random", "blend": 0.95, "allowed": "all", "scale": 1.0, "diversity": 1.0}, +] +LATENT_ALIGN_SCALES = [0.0, 0.25, 1.0] # ae_align_scale ∈ [0, 1] — three points: failure / mid / max + + +#: Per-task display title with units and a directional arrow that points the way the optimiser +#: should drive the value. Defaults applied for the two tasks the plan §5 scenarios use. The +#: lookup falls back to the raw task name if a task isn't in the map (so the plot still works +#: when scenarios 1 / 2 add ``magnetic_moment`` / ``tc``). +REG_TASK_TITLES: dict[str, str] = { + "formation_energy": "Formation energy [eV/atom] ↓", + "klat": "klat [W/mK] ↑", + "magnetic_moment": "Magnetic moment [μB/f.u.] ↑", + "tc": "Critical temperature [K] ↑", +} + + +def _plot_comparison(results: list[dict[str, Any]], reg_targets: dict[str, float], out_path: Path) -> None: + """Three-panel comparison: QC probability + each regression target across all methods.""" + n_panels = 1 + len(reg_targets) + fig, axes = plt.subplots(1, n_panels, figsize=(5.6 * n_panels, 5.6), squeeze=False) + axes = axes[0] + # Single-line labels so rotated x-ticks don't collide. + labels = [r["label"].replace("\n", " ") for r in results] + colors = ["#55A868" if r["method"] == "latent" else "#2563EB" for r in results] + x = np.arange(len(results)) + + def _set_xticks(ax): + ax.set_xticks(x) + ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=9) + + # Panel 1: QC probability. The arrow makes the optimisation direction explicit at a glance. + qc_means = [float(np.mean(r["qc_after_decode"])) for r in results] + qc_stds = [float(np.std(r["qc_after_decode"])) for r in results] + axes[0].bar(x, qc_means, yerr=qc_stds, color=colors, capsize=3) + axes[0].axhline(1.0, color="#C44E52", ls="--", lw=1.4, label="target = 1.0") + _set_xticks(axes[0]) + axes[0].set_ylim(-0.02, 1.05) + axes[0].set_ylabel("P(quasicrystal)") + axes[0].set_title("P(quasicrystal) ↑") + axes[0].legend(fontsize=9, loc="lower right") + + # Remaining panels: regression targets. Title pulled from REG_TASK_TITLES with the unit and + # an arrow indicating whether the target is below (↓) or above (↑) the model's baseline. + for ax, (t, tgt) in zip(axes[1:], reg_targets.items()): + means = [float(np.mean(r["reg_after_decode"][t])) for r in results] + stds = [float(np.std(r["reg_after_decode"][t])) for r in results] + ax.bar(x, means, yerr=stds, color=colors, capsize=3) + ax.axhline(tgt, color="#C44E52", ls="--", lw=1.4, label=f"target = {tgt:+.1f}") + _set_xticks(ax) + ax.set_ylabel("Predicted value") + ax.set_title(REG_TASK_TITLES.get(t, t)) + ax.legend(fontsize=9, loc="best") + + fig.suptitle("Inverse-design comparison: latent (ae_align_scale sweep) vs differentiable KMD (configs)", y=1.00) + fig.savefig(out_path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Wrote comparison plot to {out_path}") + + +# --- seed → optimised composition mapping plot ------------------------------------------------- + +#: Element-symbol + optional stoichiometry regex used by ``_parse_formula_to_fractions`` below. +#: ``continual_rehearsal_common.element_set`` carries the same pattern for the *set* of element +#: symbols; here we additionally need the amount (the second capture group) to recover fractions. +_COMP_RE = re.compile(r"([A-Z][a-z]?)([\d.]*)") + + +def _parse_formula_to_fractions(formula: str) -> dict[str, float]: + """Parse a composition string into ``{element: fraction}`` summing to 1. + + Handles both raw-amount formulas (``"Au65 Ga20 Gd15"`` → sum=100 → normalised to 1) and + pre-fractional formulas (``"Mg0.691 Cd0.309"`` → already sums to ~1). + """ + out: dict[str, float] = {} + for el, amt in _COMP_RE.findall(formula): + if not el: + continue + a = float(amt) if amt else 1.0 + out[el] = out.get(el, 0.0) + a + tot = sum(out.values()) + return {k: v / tot for k, v in out.items()} if tot > 0 else out + + +#: Font size for composition formula text in the seed-to-optimized plot. Tuned with the +#: ``_ROW_HEIGHT`` below to keep rows compact without text overlap. +_MAP_FONT = 13 +_MAP_ROW_HEIGHT = 0.34 # data-unit row height; figure height scales with n_rows × this + +#: Short labels used inside the parenthetical block, so a row like +#: ``Δformation_energy=-1.36`` doesn't push the right edge off the figure. Tasks not in the +#: map fall back to their raw name (covered by the lookup default in the call site). +_REG_DISPLAY_SHORT: dict[str, str] = { + "formation_energy": "FE", + "klat": "klat", + "tc": "tc", + "magnetization": "mag", + "magnetic_moment": "mm", +} + + +def _target_arrow(target_value: float, baseline: float = 0.0) -> str: + """Up-arrow if the target is above ``baseline`` (default 0 in z-scored regression space). + + Both project reg targets are z-scored; positive target ⇒ "drive up" (↑), negative ⇒ "drive + down" (↓). The arrow is rendered next to each property name in the column header and in + every row's parenthetical block, so the reader can match the delta sign against the desired + direction at a glance. + """ + return "↑" if target_value > baseline else "↓" + + +def _render_seed_row( + ax, + x_axes_frac: float, + y_data: float, + comp: dict[str, float], + qc: float, +) -> None: + """Draw one *seed* row: all-black text, no element colouring, with a ``(QC=XX.X%)`` suffix. + + The seed side is informational — the comparison signal lives on the optimised side. Keeping + the seed monochrome lets the colour gradient on the right read as a pure 'what the optimiser + did to this seed' story. + """ + if not comp: + return + items = sorted(comp.items(), key=lambda kv: -kv[1]) + parts: list = [] + for el, frac in items: + parts.append( + TextArea( + el, + textprops=dict(color="#111", fontweight="bold", fontsize=_MAP_FONT, fontfamily="monospace"), + ) + ) + parts.append( + TextArea( + f"{frac * 100:.1f} ", + textprops=dict(color="#111", fontsize=_MAP_FONT, fontfamily="monospace"), + ) + ) + parts.append( + TextArea( + f" (QC={qc * 100:.1f}%)", + textprops=dict(color="#555", fontsize=_MAP_FONT - 1, fontfamily="monospace"), + ) + ) + box = HPacker(children=parts, align="baseline", pad=0, sep=2) + ax.add_artist( + AnnotationBbox( + box, + (x_axes_frac, y_data), + xycoords=("axes fraction", "data"), + frameon=False, + box_alignment=(0, 0.5), + pad=0, + ) + ) + + +def _render_optimized_row( + ax, + x_axes_frac: float, + y_data: float, + comp: dict[str, float], + qc: float, + deltas: dict[str, float], + arrows: dict[str, str], + element_counts: Counter, + n_outputs: int, + cmap, +) -> None: + """Draw one *optimised* row: element symbols coloured by frequency in the optimised pool. + + The parenthetical block is ``(QC=XX.X%, Δ=±N.N , ...)`` — the signed + delta tells the reader how much each property moved from its seed value, and the arrow + pins down whether the target wants it to go up or down. + """ + if not comp: + return + items = sorted(comp.items(), key=lambda kv: -kv[1]) + parts: list = [] + for el, frac in items: + count = element_counts.get(el, 0) + # vmin=0 / vmax=n_outputs maps the lowest appearance count to the cmap's darkest end + # (per user request: "the lower, the closer to black"). Elements absent from the + # optimised pool can't actually appear in ``comp`` (we'd never iterate them here), so + # the ``count == 0`` branch is a defensive fallback only. + color = cmap(count / max(n_outputs, 1)) if count > 0 else "#aaaaaa" + parts.append( + TextArea( + el, + textprops=dict(color=color, fontweight="bold", fontsize=_MAP_FONT, fontfamily="monospace"), + ) + ) + parts.append( + TextArea( + f"{frac * 100:.1f} ", + textprops=dict(color="#111", fontsize=_MAP_FONT, fontfamily="monospace"), + ) + ) + # Parenthetical: QC + per-target signed delta + target-direction arrow. Use the short + # display labels so long names like ``formation_energy`` don't push the right edge of the + # axes into the colourbar. + delta_text = ", ".join(f"Δ{_REG_DISPLAY_SHORT.get(t, t)}={deltas[t]:+.2f} {arrows[t]}" for t in deltas) + parts.append( + TextArea( + f" (QC={qc * 100:.1f}%, {delta_text})", + textprops=dict(color="#555", fontsize=_MAP_FONT - 2, fontfamily="monospace"), + ) + ) + box = HPacker(children=parts, align="baseline", pad=0, sep=2) + ax.add_artist( + AnnotationBbox( + box, + (x_axes_frac, y_data), + xycoords=("axes fraction", "data"), + frameon=False, + box_alignment=(0, 0.5), + pad=0, + ) + ) + + +def _plot_seed_to_optimized_mapping( + seeds: list[str], + decoded: list[str], + out_path: Path, + *, + title: str, + seed_qc: np.ndarray, + seed_reg: dict[str, np.ndarray], + optimized_qc: np.ndarray, + optimized_reg: dict[str, np.ndarray], + reg_targets: dict[str, float], +) -> None: + """Per-seed 1:1 view — left column shows the seed, right column shows the optimiser's output. + + Both compositions are normalised to fractions and rendered as percent (so the user-facing + numbers match the seed-side ``"Au65 Ga20 Gd15"`` convention). + + * **Seed side** — all-black monochrome formula + ``(QC=XX.X%)``. + * **Optimised side** — element symbols coloured by their appearance count in the optimised + pool (cmap goes near-black for rare → bright yellow for ubiquitous, per the user's + "low end close to black" request). Parenthetical block carries QC% and per-target + signed deltas ``Δ=+/-N.N `` so the reader can match each delta's sign + against the optimisation direction at a glance. + * **Color bar** on the right shows the appearance-count scale used on the optimised side. + + The intent is to complement the aggregated ``element_frequency_heatmap.png`` with per-seed + detail — which seed gave rise to which composition under each path, and whether each + target moved correctly. + """ + n = len(seeds) + if n == 0 or len(decoded) != n: + logger.warning( + f"_plot_seed_to_optimized_mapping: seeds ({n}) / decoded ({len(decoded)}) mismatch — skipping plot." + ) + return + + seed_dicts = [_parse_formula_to_fractions(s) for s in seeds] + decoded_dicts = [_parse_formula_to_fractions(d) for d in decoded] + + # Element-presence count over the optimised pool — drives the colour scale + colour bar. + element_counts: Counter = Counter() + for d in decoded_dicts: + for el in d: + element_counts[el] += 1 + + # ``inferno`` gives high contrast across the range with the low end close to black, as + # requested. ``vmin=0`` keeps the "rare" colour distinguishable from the "common" end. + cmap = plt.cm.inferno + norm = mcolors.Normalize(vmin=0, vmax=n) + arrows = {t: _target_arrow(v) for t, v in reg_targets.items()} + + fig_height = max(6.5, _MAP_ROW_HEIGHT * n + 1.4) + # ``bbox_inches="tight"`` at savefig crops to actual artist extents, so the 20" width is a + # *minimum* — long parenthetical blocks (many reg targets, long element formulas) will + # stretch it further without colliding with the colour bar. + fig, (ax_main, ax_cbar) = plt.subplots(1, 2, figsize=(20, fig_height), gridspec_kw={"width_ratios": [70, 1]}) + ax_main.set_xlim(0, 1) + ax_main.set_ylim(-0.7, n - 0.3) + ax_main.invert_yaxis() + ax_main.set_axis_off() + + # Column headers above row 0 — also document what's in the parenthetical block, using the + # same short property names so the header matches each row's delta block exactly. + header_arrows = ", ".join(f"Δ{_REG_DISPLAY_SHORT.get(t, t)} {arrows[t]}" for t in reg_targets) + ax_main.text( + 0.005, + -0.6, + "Seed (fraction × 100, QC%)", + fontsize=_MAP_FONT, + fontweight="bold", + ha="left", + va="bottom", + ) + ax_main.text( + 0.38, + -0.6, + f"Optimised composition (fraction × 100, QC%, {header_arrows})", + fontsize=_MAP_FONT, + fontweight="bold", + ha="left", + va="bottom", + ) + + for i, (s_dict, d_dict) in enumerate(zip(seed_dicts, decoded_dicts)): + _render_seed_row(ax_main, x_axes_frac=0.005, y_data=i, comp=s_dict, qc=float(seed_qc[i])) + ax_main.text(0.355, i, "→", fontsize=15, color="#888", ha="center", va="center") + deltas_i = {t: float(optimized_reg[t][i] - seed_reg[t][i]) for t in reg_targets} + _render_optimized_row( + ax_main, + x_axes_frac=0.38, + y_data=i, + comp=d_dict, + qc=float(optimized_qc[i]), + deltas=deltas_i, + arrows=arrows, + element_counts=element_counts, + n_outputs=n, + cmap=cmap, + ) + + sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm) + sm.set_array([]) + cb = fig.colorbar(sm, cax=ax_cbar) + cb.set_label(f"Element appearance count\nin optimised pool (out of {n})", fontsize=_MAP_FONT - 2) + cb.ax.tick_params(labelsize=_MAP_FONT - 3) + + fig.suptitle(title, fontsize=_MAP_FONT + 1, y=0.998) + fig.savefig(out_path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Wrote seed→optimised mapping plot to {out_path}") + + +# --- QC vs secondary-property scatter plot ---------------------------------------------------- + + +#: Marker shapes by method-group, per the user's "use shape to separate the two groups" request. +#: Circle for latent (continuous α sweep ↦ a continuous family) vs triangle for composition +#: (discrete-config family). Kept here as a single source of truth so the legend renderer and +#: the scatter loop can't drift. +_SCATTER_MARKERS = {"latent": "o", "composition": "^"} + +#: Per-group base colormaps. Greens vs Blues keep the two groups easily distinguishable at a +#: glance (the user's "two groups' base colors must be easy to tell apart"). Within each group +#: we step the colormap to encode the parameter-config ordering — see ``_group_color_ramp``. +_SCATTER_CMAPS = {"latent": plt.cm.Greens, "composition": plt.cm.Blues} + +#: Seed-layer style: star marker + the project's discovered-element orange. Distinct shape and +#: a third colour family (not Blues / Greens / red-target-lines) so the seed cloud reads as a +#: separate "starting point" anchor without competing with the optimised clouds. +_SEED_MARKER = "*" +_SEED_COLOR = DISCOVERED_ELEMENT_COLOR # ``#E67E22`` — same orange used for new elements in the heatmap + + +def _group_color_ramp(cmap, n: int) -> list: + """Evenly stepped colors across the upper portion of ``cmap``. + + Skip the very pale low end (would be invisible on white) and the near-black high end + (would look the same across both groups). The 0.35 / 0.90 window matches the band used in + the seed-to-optimised plot's element shading. + """ + if n <= 0: + return [] + if n == 1: + return [cmap(0.65)] + return [cmap(0.35 + 0.55 * i / (n - 1)) for i in range(n)] + + +def _plot_qc_vs_reg_scatter( + results: list[dict[str, Any]], + reg_targets: dict[str, float], + out_path: Path, + *, + title: str | None = None, + seed_qc: np.ndarray | None = None, + seed_reg: dict[str, np.ndarray] | None = None, +) -> None: + """One panel per secondary regression target, plotting QC prob vs that target across all paths. + + Each method's per-seed outputs become one scatter cluster: shape encodes the *group* (circle + for latent, triangle for composition — per the "use shape to separate the two groups" spec), + and color steps through that group's colormap (Greens / Blues) in label-order so the reader + can read the parameter sweep off the legend without remembering which α / config is which. + Red dashed lines mark the joint target (vertical at ``QC=1.0``, horizontal at the per-task + regression target). A figure-level legend at the bottom lists every method label once across + all panels. + + When ``seed_qc`` and ``seed_reg`` are provided, the per-seed *baseline* predictions are also + drawn — as orange ★ stars — so the reader can see how far each method moved each seed in + QC-vs-secondary space. ``seed_reg`` must carry one array per key in ``reg_targets``; missing + keys silently skip the seed layer in that panel. + """ + if not reg_targets: + logger.warning("_plot_qc_vs_reg_scatter: no reg_targets — skipping plot.") + return + if not results: + logger.warning("_plot_qc_vs_reg_scatter: no results — skipping plot.") + return + + # Split results by group, preserving the order in which ``run()`` appended them — that's + # the same order the comparison bar chart uses, so the legend matches across figures. + latent_results = [r for r in results if r["method"] == "latent"] + comp_results = [r for r in results if r["method"] == "composition"] + + # Per-group color ramps. Latent: Greens, low α → pale green, high α → deep green. Comp: + # Blues, simple-config → pale blue, full-knob config → deep blue. + latent_colors = _group_color_ramp(_SCATTER_CMAPS["latent"], len(latent_results)) + comp_colors = _group_color_ramp(_SCATTER_CMAPS["composition"], len(comp_results)) + color_by_result: dict[int, Any] = {} + for r, c in zip(latent_results, latent_colors): + color_by_result[id(r)] = c + for r, c in zip(comp_results, comp_colors): + color_by_result[id(r)] = c + + # Seeds layer: drawn first so the optimised clouds overplot it (the seed cloud is the + # "context"; the optimised clouds are the headline data). + has_seeds = seed_qc is not None and seed_reg is not None + seed_qc_arr = np.asarray(seed_qc, dtype=float) if has_seeds else None + + n_panels = len(reg_targets) + fig, axes = plt.subplots(1, n_panels, figsize=(5.6 * n_panels, 6.4), squeeze=False) + axes = axes[0] + + for ax, (task, tgt) in zip(axes, reg_targets.items()): + arrow = _target_arrow(tgt) + # Seeds first (under) — only if seed_reg has this panel's task. + if has_seeds and task in seed_reg: + seed_reg_arr = np.asarray(seed_reg[task], dtype=float) + ax.scatter( + seed_qc_arr, + seed_reg_arr, + marker=_SEED_MARKER, + color=_SEED_COLOR, + s=110, + alpha=0.85, + edgecolor="#222", + linewidths=0.7, + zorder=2, + ) + for r in results: + qc = np.asarray(r["qc_after_decode"], dtype=float) + reg = np.asarray(r["reg_after_decode"][task], dtype=float) + ax.scatter( + qc, + reg, + marker=_SCATTER_MARKERS[r["method"]], + color=color_by_result[id(r)], + s=64, + alpha=0.78, + edgecolor="#222", + linewidths=0.6, + label=r["label"].replace("\n", " "), + zorder=3, + ) + ax.axvline(1.0, color="#C44E52", ls="--", lw=1.3, alpha=0.8) + ax.axhline(tgt, color="#C44E52", ls="--", lw=1.3, alpha=0.8) + ax.set_xlim(-0.05, 1.05) + ax.set_xlabel("P(quasicrystal) ↑") + ax.set_ylabel(REG_TASK_TITLES.get(task, task)) + ax.set_title(f"QC vs {_REG_DISPLAY_SHORT.get(task, task)} {arrow} (target = {tgt:+.1f})", fontsize=11) + + # Figure-level legend across all panels. Use proxy handles so the legend orders by group + # (seeds → latent → composition → target) rather than by whichever panel happened to draw + # which marker first. + from matplotlib.lines import Line2D + + handles: list[Line2D] = [] + if has_seeds: + handles.append( + Line2D( + [0], + [0], + marker=_SEED_MARKER, + color="none", + markerfacecolor=_SEED_COLOR, + markeredgecolor="#222", + markersize=11, + label="seed (baseline)", + ) + ) + for r in latent_results: + handles.append( + Line2D( + [0], + [0], + marker=_SCATTER_MARKERS["latent"], + color="none", + markerfacecolor=color_by_result[id(r)], + markeredgecolor="#222", + markersize=9, + label=r["label"].replace("\n", " "), + ) + ) + for r in comp_results: + handles.append( + Line2D( + [0], + [0], + marker=_SCATTER_MARKERS["composition"], + color="none", + markerfacecolor=color_by_result[id(r)], + markeredgecolor="#222", + markersize=9, + label=r["label"].replace("\n", " "), + ) + ) + handles.append(Line2D([0], [0], color="#C44E52", ls="--", lw=1.3, label="target (QC=1.0 / reg-target)")) + # ncol picked so the legend fits across the figure width without wrapping past 3 rows for + # the 8-method + 1-target sweep we use in practice. + fig.legend( + handles=handles, + loc="lower center", + ncol=min(len(handles), 4), + fontsize=9, + frameon=False, + bbox_to_anchor=(0.5, -0.02), + ) + + if title: + fig.suptitle(title, y=1.00) + # Leave generous bottom padding so the legend (rendered below the axes via bbox_to_anchor) + # ends up inside the saved bbox after ``bbox_inches="tight"`` crops. + fig.tight_layout(rect=(0, 0.10, 1, 0.98)) + fig.savefig(out_path, dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Wrote QC-vs-secondary scatter plot to {out_path}") + + +def _path_slug(r: dict[str, Any]) -> str: + """Stable filename slug for one path. Latent: ``latent_align0p25``; comp: cleaned label.""" + if r["method"] == "latent": + return f"latent_align{r['align_scale']:g}".replace(".", "p") + return re.sub(r"[^a-z0-9]+", "_", r["label"].lower()).strip("_") + + +def _summarise(results: list[dict[str, Any]], reg_targets: dict[str, float]) -> list[dict[str, Any]]: + summary = [] + for r in results: + row = { + "label": r["label"].replace("\n", " "), + "method": r["method"], + "align_scale": r.get("align_scale"), + "config": r.get("config"), + "elapsed_s": round(r["elapsed_s"], 2), + "qc_after_mean": round(float(np.mean(r["qc_after_decode"])), 4), + "qc_after_std": round(float(np.std(r["qc_after_decode"])), 4), + } + for t in reg_targets: + row[f"{t}_after_mean"] = round(float(np.mean(r["reg_after_decode"][t])), 3) + row[f"{t}_after_std"] = round(float(np.std(r["reg_after_decode"][t])), 3) + summary.append(row) + return summary + + +def run( + config: ContinualRehearsalConfig, + ckpt_path: Path, + *, + record_trajectory: bool = True, + per_seed_trajectories: bool = False, + animation_formats: tuple[str, ...] = ("gif",), +) -> None: + seed_everything(config.random_seed, workers=True) + runner = ContinualRehearsalRunner(config) + + # Load the trained model exactly as we built it during training (same task_sequence). + model = runner._build_full_model() + state = torch.load(ckpt_path, map_location="cpu", weights_only=True) + state_dict = state["model"] if isinstance(state, dict) and "model" in state else state + model.load_state_dict(state_dict) + model.eval() + + out_dir = Path(config.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + # Copy the checkpoint so this folder is a self-contained paper artefact (skip when + # the source and destination resolve to the same file — happens on idempotent reruns). + dst = out_dir / "final_model.pt" + if ckpt_path.resolve() != dst.resolve(): + shutil.copy2(ckpt_path, dst) + + device = next(model.parameters()).device + + def _qc_prob_fn(x: torch.Tensor) -> np.ndarray: + return _qc_prob(model, x) + + seeds = runner._select_seeds(model, device, _qc_prob_fn) + if not seeds: + raise RuntimeError("No seed compositions selected.") + x_seed, seeds = runner._descriptor_tensor(seeds, device) + (out_dir / "seeds.json").write_text(json.dumps({"seeds": list(seeds)}, indent=2), encoding="utf-8") + logger.info(f"Selected {len(seeds)} seed compositions (saved to seeds.json)") + + reg_targets = {t: v for t, v in zip(config.inverse_reg_tasks, config.inverse_reg_targets)} + # Per-seed *baseline* predictions (before any inverse-design optimisation). These power the + # seed-side ``(QC=X.X%)`` parenthetical and the ``Δ`` deltas on the optimised side of + # the per-seed mapping plot. Computed once here against ``x_seed`` (the seed descriptors) + # and persisted in ``results.json`` under ``seed_predictions`` so future re-plots don't need + # the model loaded again. + seed_qc = _qc_prob(model, x_seed) + seed_reg = _reg_preds(model, x_seed, list(reg_targets.keys())) + results: list[dict[str, Any]] = [] + + # Latent method: ae_align_scale sweep over [0, 1]. + for lam in LATENT_ALIGN_SCALES: + logger.info(f"--- Latent method, ae_align_scale = {lam} ---") + r = _run_latent_method( + runner, + model, + seeds, + x_seed, + reg_targets, + class_weight=config.inverse_class_weight, + align_scale=lam, + steps=config.inverse_steps, + lr=config.inverse_lr, + record_trajectory=record_trajectory, + ) + r["label"] = f"latent\nα={lam:g}" + r["config"] = {"ae_align_scale": lam} + results.append(r) + + # Composition method: walk through the configuration matrix. + for cfg in COMPOSITION_CONFIGS: + logger.info(f"--- {cfg['label'].replace(chr(10), ' ')} ---") + r = _run_composition_config( + runner, + model, + seeds, + reg_targets, + class_weight=config.inverse_class_weight, + steps=config.inverse_steps, + lr=config.inverse_lr, + cfg=cfg, + record_trajectory=record_trajectory, + ) + r["label"] = cfg["label"] + r["config"] = {k: cfg[k] for k in ("init", "blend", "allowed", "scale", "diversity")} + results.append(r) + + summary = _summarise(results, reg_targets) + logger.info("=== Summary ===") + for row in summary: + logger.info(row) + + # Trajectory arrays would blow up the inlined results.json (≈ 36MB / scenario for a 300-step + # 20-seed run); persist them as compressed .npz next to results.json and replace the inline + # lists with a relative-path reference. The JSON stays browsable; replots read the .npz. + traj_dir: Path | None = None + if record_trajectory: + traj_dir = out_dir / "trajectories" + traj_dir.mkdir(exist_ok=True) + for r in results: + if "trajectory_targets" not in r: + continue + slug = _path_slug(r) + npz_path = traj_dir / f"{slug}.npz" + np.savez_compressed( + npz_path, + targets=np.asarray(r["trajectory_targets"], dtype=np.float32), + weights=np.asarray(r["trajectory_weights"], dtype=np.float32), + ) + r["trajectory_file"] = str(npz_path.relative_to(out_dir)) + del r["trajectory_targets"] + del r["trajectory_weights"] + logger.info(f"Wrote per-path trajectory arrays under {traj_dir}/") + + (out_dir / "results.json").write_text( + json.dumps( + { + "reg_targets": reg_targets, + # ``seed_predictions`` carries the baseline predictions the inverse-design + # optimisation moved away from — needed to render the per-seed mapping plot's + # ``Δ`` deltas (and the seed-side ``QC%`` parenthetical). Save here so a + # future re-plot from results.json alone never has to re-run the model. + "seed_predictions": { + "qc": seed_qc.tolist(), + "reg": {t: vals.tolist() for t, vals in seed_reg.items()}, + }, + "results": results, + "summary": summary, + }, + indent=2, + ), + encoding="utf-8", + ) + _plot_comparison(results, reg_targets, out_dir / "comparison.png") + # Per-method × top-25-element occurrence heatmap. Always written so the discovered-element + # signal (bold orange on the x-axis) is part of every paper-comparison output — the slide + # author / downstream reader doesn't need to find or rerun a separate post-hoc script. + plot_element_frequency_heatmap(results, list(seeds), out_dir / "element_frequency_heatmap.png") + # Seed → optimised 1:1 mapping plot. One figure per path that has per-seed correspondence + # (every method except ``comp (random)``, whose ``seeds`` field is a ``random_start_N`` + # placeholder rather than a real composition). Each plot's right side carries the QC% and + # per-target signed deltas so the reader can see *which seed gave rise to which output* + # and whether each target moved in the right direction. + for r in results: + if r["method"] == "composition" and r.get("config", {}).get("init") != "seed": + # ``comp (random)`` — no per-row seed correspondence. + continue + slug = _path_slug(r) + _plot_seed_to_optimized_mapping( + seeds=list(seeds), + decoded=list(r["decoded_composition"]), + out_path=out_dir / f"seed_to_optimized__{slug}.png", + title=f"Seed → optimised composition · {r['label'].replace(chr(10), ' ')}", + seed_qc=seed_qc, + seed_reg=seed_reg, + optimized_qc=np.asarray(r["qc_after_decode"]), + optimized_reg={t: np.asarray(r["reg_after_decode"][t]) for t in reg_targets}, + reg_targets=reg_targets, + ) + # Scatter view of QC prob vs each secondary reg target, grouped by method (latent = circle / + # green ramp, composition = triangle / blue ramp), with the per-seed baseline drawn as orange + # ★ stars so the reader sees how far each method moved each seed. Complements the bar chart: + # the bar chart collapses each method to a mean ± std, the scatter shows the per-seed cloud. + _plot_qc_vs_reg_scatter( + results, + reg_targets, + out_dir / "qc_vs_secondary_scatter.png", + title="QC probability vs secondary properties (per-seed outputs)", + seed_qc=seed_qc, + seed_reg=seed_reg, + ) + # Per-step optimisation trajectory plots + animations. One figure (and one animation) per + # path; ``--per-seed-trajectories`` additionally emits per-seed variants. Skipped when + # ``--no-record-trajectory`` was passed (results.json carries no trajectory_file refs then). + if record_trajectory and traj_dir is not None: + _emit_trajectory_outputs( + results=results, + reg_targets=reg_targets, + seeds=list(seeds), + seed_qc=seed_qc, + seed_reg=seed_reg, + out_dir=out_dir, + traj_dir=traj_dir, + per_seed=per_seed_trajectories, + animation_formats=animation_formats, + ) + # The auto-generated README is a compact summary table only. It writes to ``SUMMARY.md`` + # (not ``README.md``) so a user-written index — pointing to every figure, file, and the + # full ANALYSIS.md — can live at ``README.md`` without being overwritten on rerun. + _write_readme(out_dir, summary, reg_targets, ckpt_path) + logger.info(f"Paper materials written to {out_dir}") + + +def _emit_trajectory_outputs( + *, + results: list[dict[str, Any]], + reg_targets: dict[str, float], + seeds: list[str], + seed_qc: np.ndarray, + seed_reg: dict[str, np.ndarray], + out_dir: Path, + traj_dir: Path, + per_seed: bool, + animation_formats: tuple[str, ...], +) -> None: + """Render the static "normalised-progress vs step" plot + animation per path. + + Always-on: a mean across-seeds line plot per path under ``trajectories/`` with the comp panel + animated using the seed whose final state best matches all targets (joint normalised distance). + The chosen seed's composition formula is shown under the title. + + ``per_seed=True`` (the new default) also emits one plot+animation per ``(path × seed)`` under + ``trajectories_per_seed/seed{NN}/.{png,gif,html}`` — **seed-major** layout chosen so the + user can compare the same seed across all 8 paths by opening one folder. The seed's composition + string is rendered under each title so the reader doesn't have to cross-reference seed indices + against ``seeds.json``. + + ``animation_formats`` defaults to ``("gif",)``; pass extras (``html``, ``svg``) to also emit + them. ``"none"`` in the format list disables animations entirely (static plot still emitted). + """ + from foundation_model.scripts.paper_inverse_trajectory import ( + best_seed_by_target_distance, + normalize_target_trajectories, + plot_trajectory_animation, + plot_trajectory_static, + ) + from foundation_model.utils.kmd_plus import DEFAULT_ELEMENTS + + formats: list[str] = [f for f in animation_formats if f != "none"] + static_dir = out_dir / "trajectories" + static_dir.mkdir(exist_ok=True) + per_seed_dir = out_dir / "trajectories_per_seed" if per_seed else None + if per_seed_dir is not None: + per_seed_dir.mkdir(exist_ok=True) + + for r in results: + if "trajectory_file" not in r: + continue + slug = _path_slug(r) + npz_path = out_dir / r["trajectory_file"] + with np.load(npz_path) as data: + traj_targets = np.asarray(data["targets"]) # (steps, B, T_reg) + traj_weights = np.asarray(data["weights"]) # (steps, B, n_components) + # The composition method's ``trajectory_targets`` is the in-loop reg-only predictions + # (T_reg = len(reg_targets)); for the QC trajectory we replay the qc_after_decode per + # step. That requires running ``_qc_prob`` on each step's descriptor — but the npz only + # stores weights, not descriptors. Fast path: each step's QC ≈ qc_after_decode is well + # approximated by reusing the qc_after_decode of the final step as a fixed line + (initial + # − final) as a linear ramp would be wrong. So we just reconstruct the per-step QC + # trajectory by *not* including QC when it isn't in the npz; the static plot still works + # with reg-only progress. For the inverse-design study this is the right signal anyway — + # the user asked "do the reg targets converge together?" and the QC line is best read off + # the separate ``comparison.png``. + reg_names = list(reg_targets) + # Mean reg trajectory across seeds (per step → per task). + reg_traj_dict: dict[str, np.ndarray] = { + t: traj_targets[:, :, j] for j, t in enumerate(reg_names) + } + # Mean variant: use QC after-decode (final value) as a flat baseline-vs-target progress + # line only if it's available; otherwise drop QC. For the inverse-design case QC is in + # results dict but not per-step; we synthesise a "flat" QC progress line from the final + # value so it shows up on the chart for context. + qc_after = np.asarray(r["qc_after_decode"], dtype=float) + qc_traj = np.tile(qc_after[None, :], (traj_targets.shape[0], 1)) # (steps, B); only end-state QC + progress_mean = normalize_target_trajectories( + qc_trajectory=qc_traj, + reg_trajectory=reg_traj_dict, + reg_targets=reg_targets, + seed_qc=seed_qc, + seed_reg=seed_reg, + ) + # The QC entry is degenerate (flat ≈ end-state); drop it from the static plot to avoid + # misleading the reader. The animation also keeps reg-only. + progress_mean.pop("QC", None) + + # Pick the best representative seed for the animation's comp panel. + reg_final_per_task = {t: np.asarray(r["reg_after_decode"][t], dtype=float) for t in reg_names} + best_idx = best_seed_by_target_distance(qc_after, reg_final_per_task, reg_targets) + per_step_weights_best = traj_weights[:, best_idx, :] # (steps, n_components) + # Map the path's per-row "seeds" entry to a comp string. For comp_random the entry is + # ``random_start_N`` placeholder text; surface it verbatim so the title still says where + # the row came from. The ``r["seeds"]`` carried by every path is exactly the per-row + # label sequence; fall back to the shared ``seeds`` arg if a path forgot to set it. + per_row_seeds = list(r.get("seeds", seeds)) + + # --- Static plot (mean across seeds) --- + static_out = static_dir / f"trajectory__{slug}.png" + plot_trajectory_static( + progress_mean, + static_out, + title=f"Optimisation trajectory · {r['label'].replace(chr(10), ' ')} (mean over {qc_after.shape[0]} seeds)", + ) + + # --- Animation (mean curves + best-seed comp panel) --- + if formats: + out_paths = {fmt: static_dir / f"trajectory__{slug}.{fmt}" for fmt in formats} + plot_trajectory_animation( + progress_mean, + per_step_weights_best, + element_symbols=list(DEFAULT_ELEMENTS), + out_paths_by_format=out_paths, + title=f"Trajectory · {r['label'].replace(chr(10), ' ')} (best seed: {best_idx})", + seed_composition=per_row_seeds[best_idx], + ) + + # --- Per-seed variants (seed-major layout: trajectories_per_seed/seed{NN}/.{ext}) --- + if per_seed_dir is not None: + for seed_i in range(qc_after.shape[0]): + seed_dir = per_seed_dir / f"seed{seed_i:02d}" + seed_dir.mkdir(exist_ok=True) + seed_comp = per_row_seeds[seed_i] + reg_traj_one_seed = {t: traj_targets[:, seed_i : seed_i + 1, j] for j, t in enumerate(reg_names)} + qc_traj_one_seed = qc_traj[:, seed_i : seed_i + 1] + progress_seed = normalize_target_trajectories( + qc_trajectory=qc_traj_one_seed, + reg_trajectory=reg_traj_one_seed, + reg_targets=reg_targets, + seed_qc=seed_qc[seed_i : seed_i + 1], + seed_reg={t: vals[seed_i : seed_i + 1] for t, vals in seed_reg.items()}, + ) + progress_seed.pop("QC", None) + seed_static = seed_dir / f"{slug}.png" + plot_trajectory_static( + progress_seed, + seed_static, + title=f"{r['label'].replace(chr(10), ' ')} · seed {seed_i}", + seed_composition=seed_comp, + ) + if formats: + seed_out_paths = {fmt: seed_dir / f"{slug}.{fmt}" for fmt in formats} + plot_trajectory_animation( + progress_seed, + traj_weights[:, seed_i, :], + element_symbols=list(DEFAULT_ELEMENTS), + out_paths_by_format=seed_out_paths, + title=f"{r['label'].replace(chr(10), ' ')} · seed {seed_i}", + seed_composition=seed_comp, + ) + + +def _run_composition_config( + runner: ContinualRehearsalRunner, + model, + seeds: list[str], + reg_targets: dict[str, float], + *, + class_weight: float, + steps: int, + lr: float, + cfg: dict[str, Any], + record_trajectory: bool = False, +) -> dict[str, Any]: + """Run :meth:`optimize_composition` under one config row (handles seed/random init both).""" + import time + + from foundation_model.utils.kmd_plus import DEFAULT_ELEMENTS + + device, dtype = next(model.parameters()).device, next(model.parameters()).dtype + kernel = runner._kmd.kernel_torch(device=device, dtype=dtype) + + if cfg["init"] == "seed": + w_seed = _seed_weights_from_compositions(seeds, n_components=len(DEFAULT_ELEMENTS)) + init_kwargs = {"initial_weights": w_seed, "seed_blend": cfg["blend"]} + elif cfg["init"] == "random": + # n_starts matches the seed count so per-row aggregation lines up with the latent runs. + init_kwargs = {"initial_weights": None, "n_starts": len(seeds)} + else: + raise ValueError(f"Unknown init mode in config: {cfg['init']!r}") + + t0 = time.perf_counter() + res = model.optimize_composition( + kernel, + task_targets=reg_targets, + class_targets={"material_type": QC_CLASSES}, + class_target_weight=class_weight, + diversity_scale=cfg["diversity"], + allowed_elements=cfg["allowed"], + element_step_scale=cfg["scale"], + steps=steps, + lr=lr, + record_weights_trajectory=record_trajectory, + **init_kwargs, + ) + elapsed = time.perf_counter() - t0 + + reg_names = list(reg_targets) + optimized_desc = res.optimized_descriptor + w_final = res.optimized_weights.cpu().numpy() + out = { + "method": "composition", + "align_scale": None, + "elapsed_s": elapsed, + # For random init the "seeds" entry is informational only — there's no per-row correspondence. + "seeds": list(seeds) if cfg["init"] == "seed" else [f"random_start_{i}" for i in range(len(seeds))], + "qc_after_decode": _qc_prob(model, optimized_desc).tolist(), + "reg_achieved_latent": {t: res.optimized_target.cpu().numpy()[:, j].tolist() for j, t in enumerate(reg_names)}, + "reg_after_decode": {t: _reg_preds(model, optimized_desc, [t])[t].tolist() for t in reg_names}, + "decoded_composition": _format_weights(w_final), + # Raw arrays — keep so future replots (per-element bar charts, similarity matrices, etc.) + # don't have to re-run the optimisation. ``optimized_weights`` is (B, n_components), + # ``optimized_descriptor`` is (B, x_dim); element order matches DEFAULT_ELEMENTS. + "optimized_descriptor": optimized_desc.detach().cpu().numpy().tolist(), + "optimized_weights": w_final.tolist(), + } + if record_trajectory: + # ``res.trajectory`` is (steps, B, T) in reg-task order — already on the right surface. + # ``res.weights_trajectory`` is (steps, B, n_components) and is the per-step recipe + # exactly (no decode needed — composition method's optim variable already lives there). + out["trajectory_targets"] = res.trajectory.cpu().numpy().tolist() + out["trajectory_weights"] = res.weights_trajectory.cpu().numpy().tolist() + return out + + +def _write_readme(out_dir: Path, summary: list[dict[str, Any]], reg_targets: dict[str, float], ckpt_path: Path) -> None: + lines = [ + "# Inverse-design method comparison — paper materials", + "", + f"Trained model: `final_model.pt` (copied from `{ckpt_path}`).", + "Seed compositions: top-QC training compositions, listed in `seeds.json`.", + f"Targets: QC probability → 1.0; {', '.join(f'{t} → {v:+.1f}' for t, v in reg_targets.items())}.", + "", + "Raw per-seed JSON: `results.json` (one entry per method+config).", + "Comparison figure: `comparison.png`.", + "", + "## Summary (mean ± std across seeds)", + "", + "| label | QC after | " + " | ".join(f"{t} after" for t in reg_targets) + " | elapsed (s) |", + "| --- | --- | " + " | ".join("---" for _ in reg_targets) + " | --- |", + ] + for row in summary: + qc_cell = f"{row['qc_after_mean']:.3f} ± {row['qc_after_std']:.3f}" + reg_cells = [f"{row[f'{t}_after_mean']:+.2f} ± {row[f'{t}_after_std']:.2f}" for t in reg_targets] + lines.append(f"| {row['label']} | {qc_cell} | " + " | ".join(reg_cells) + f" | {row['elapsed_s']} |") + (out_dir / "SUMMARY.md").write_text("\n".join(lines) + "\n", encoding="utf-8") + + +def _parse_args(argv: list[str] | None = None) -> tuple[ContinualRehearsalConfig, argparse.Namespace]: + parser = argparse.ArgumentParser(description="Paper-grade inverse-design comparison.") + parser.add_argument("--config-file", type=Path, required=True) + parser.add_argument("--checkpoint", type=Path, required=True) + parser.add_argument("--output-dir", type=Path, required=True) + parser.add_argument( + "--record-trajectory", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "Record per-step optimisation trajectory (target predictions + per-step composition) " + "per path. Adds ~10–30 % runtime + a few MB of disk per scenario but enables the " + "trajectory_* plots and animations. Default: on. Use --no-record-trajectory to skip." + ), + ) + parser.add_argument( + "--per-seed-trajectories", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "Also emit one trajectory plot + animation per (path × seed) under " + "trajectories_per_seed/seed{NN}/.{png,gif,html} (seed-major layout — easier " + "to compare paths for one seed). Default: on. Adds ~480 PNGs / scenario plus 480 GIFs " + "(~1GB) + 480 HTMLs (~5GB) if both anim formats are on; use --no-per-seed-trajectories " + "to skip when you only need the across-seed-mean view." + ), + ) + parser.add_argument( + "--animation-formats", + nargs="+", + choices=["gif", "html", "svg", "none"], + default=["gif"], + help=( + "Animation output formats. ``gif`` (default) uses matplotlib's Pillow writer; " + "``html`` emits an interactive JS-controlled HTML file (matplotlib HTMLWriter); " + "``svg`` emits a SMIL-animated single-file SVG; ``none`` disables animations " + "(static plot still emitted). Multi-select supported, e.g. --animation-formats gif html." + ), + ) + args = parser.parse_args(argv) + + import tomllib + + data = tomllib.loads(args.config_file.read_text(encoding="utf-8")) + data["output_dir"] = str(args.output_dir) + field_names = set(ContinualRehearsalConfig.__dataclass_fields__) + path_fields = { + "qc_data_path", + "qc_preprocessing_path", + "superconductor_path", + "magnetic_path", + "phonix_path", + "output_dir", + } + kwargs: dict[str, object] = {} + for key, value in data.items(): + if key not in field_names: + continue + kwargs[key] = Path(value) if key in path_fields and value is not None else value + return ContinualRehearsalConfig(**kwargs), args + + +def main(argv: list[str] | None = None) -> None: + config, args = _parse_args(argv) + run( + config, + args.checkpoint, + record_trajectory=args.record_trajectory, + per_seed_trajectories=args.per_seed_trajectories, + animation_formats=tuple(args.animation_formats), + ) + + +if __name__ == "__main__": + main() diff --git a/src/foundation_model/scripts/paper_inverse_comparison_test.py b/src/foundation_model/scripts/paper_inverse_comparison_test.py new file mode 100644 index 0000000..9c00393 --- /dev/null +++ b/src/foundation_model/scripts/paper_inverse_comparison_test.py @@ -0,0 +1,211 @@ +# Copyright 2025 TsumiNa. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the pure helpers in :mod:`paper_inverse_comparison`. + +The main ``run()`` function needs a trained checkpoint + KMD kernel to exercise end-to-end (see +the smoke runs under ``artifacts/inverse_design_run/``); this file targets the *units that don't* +need either — the formula parser, and the two output plot helpers we added in this PR. +""" + +from __future__ import annotations + +import numpy as np + +from foundation_model.scripts.paper_inverse_comparison import ( + _parse_formula_to_fractions, + _plot_qc_vs_reg_scatter, + _plot_seed_to_optimized_mapping, + _target_arrow, +) + + +# --- _parse_formula_to_fractions ---------------------------------------------------------- + + +def test_parse_raw_amount_formula_normalises_to_fractions(): + # Seeds typically come in raw-amount form like "Au65 Ga20 Gd15"; the parser must normalise + # so the same downstream code can read it as fractions. + out = _parse_formula_to_fractions("Au65 Ga20 Gd15") + assert sorted(out.keys()) == ["Au", "Ga", "Gd"] + assert abs(sum(out.values()) - 1.0) < 1e-12 + assert abs(out["Au"] - 0.65) < 1e-12 + assert abs(out["Ga"] - 0.20) < 1e-12 + assert abs(out["Gd"] - 0.15) < 1e-12 + + +def test_parse_pre_fractional_formula_kept_as_fractions(): + # Decoded compositions land here in fractional form ("Mg0.691 Cd0.309 …"); they must round-trip. + out = _parse_formula_to_fractions("Mg0.691 Cd0.309") + assert abs(sum(out.values()) - 1.0) < 1e-12 + assert abs(out["Mg"] - 0.691) < 1e-12 + assert abs(out["Cd"] - 0.309) < 1e-12 + + +def test_parse_handles_missing_amount_as_unit(): + # A bare element symbol ("Mg") gets unit amount, then normalised. + out = _parse_formula_to_fractions("Mg Cu Ni") + # 3 elements, equal amounts, fractions = 1/3 each. + assert sorted(out.keys()) == ["Cu", "Mg", "Ni"] + for v in out.values(): + assert abs(v - 1.0 / 3.0) < 1e-12 + + +def test_parse_empty_formula_returns_empty_dict(): + assert _parse_formula_to_fractions("") == {} + + +# --- _target_arrow -------------------------------------------------------------------------- + + +def test_target_arrow_up_for_positive_target(): + """Target above baseline ⇒ ↑ (optimisation drives the value up).""" + assert _target_arrow(2.0) == "↑" + assert _target_arrow(0.1) == "↑" + + +def test_target_arrow_down_for_negative_or_zero_target(): + """Target at or below baseline ⇒ ↓. The convention treats 0 as "no clear up direction".""" + assert _target_arrow(-2.0) == "↓" + assert _target_arrow(0.0) == "↓" + + +# --- _plot_seed_to_optimized_mapping ------------------------------------------------------ + + +def _mapping_kwargs(seeds: list[str], decoded: list[str]) -> dict: + """Reasonable defaults for the helper's per-seed QC / reg arguments. + + Tests don't care about specific numbers — they just need arrays the same length as the + seed list. Reg-target names map to the project's plan §5 targets. + """ + n = len(seeds) + return dict( + seed_qc=np.full(n, 0.5), + seed_reg={"formation_energy": np.full(n, 0.3), "klat": np.full(n, 0.1)}, + optimized_qc=np.full(n, 0.9), + optimized_reg={"formation_energy": np.full(n, -0.5), "klat": np.full(n, 1.6)}, + reg_targets={"formation_energy": -2.0, "klat": 2.0}, + ) + + +def test_plot_seed_to_optimized_mapping_writes_png(tmp_path): + seeds = [ + "Mg12 Cu3 Ni3", + "Au65 Ga20 Gd15", + "Al6 Co1 Cu3", + ] + decoded = [ + "Mg0.50 Cu0.30 Ni0.20", + "Au0.55 Ga0.30 Gd0.15", + "Al0.60 Pd0.20 Ti0.20", # introduces Pd / Ti not in seeds + ] + out = tmp_path / "seed_to_optimized.png" + _plot_seed_to_optimized_mapping(seeds, decoded, out, title="test scenario", **_mapping_kwargs(seeds, decoded)) + assert out.exists() + + +def test_plot_seed_to_optimized_mapping_skips_on_length_mismatch(tmp_path): + """Mismatched seeds / decoded lengths must not crash — log a warning and skip the write.""" + out = tmp_path / "should_not_exist.png" + _plot_seed_to_optimized_mapping( + ["Mg1 Cu1"], ["Mg0.5 Cu0.5", "Al1.0"], out, title="bad", **_mapping_kwargs(["Mg1 Cu1"], ["Mg0.5 Cu0.5"]) + ) + assert not out.exists() + + +def test_plot_seed_to_optimized_mapping_skips_on_empty(tmp_path): + out = tmp_path / "should_not_exist.png" + _plot_seed_to_optimized_mapping([], [], out, title="empty", **_mapping_kwargs([], [])) + assert not out.exists() + + +# --- _plot_qc_vs_reg_scatter ---------------------------------------------------------------- + + +def _scatter_result(method: str, label: str, n: int = 6, **extra) -> dict: + """Minimal ``results`` row shape consumed by ``_plot_qc_vs_reg_scatter``. + + Only the fields the scatter helper reads are populated — ``method``, ``label``, + ``qc_after_decode``, and ``reg_after_decode``. Numbers are arbitrary; the test asserts + the helper writes a PNG without raising. + """ + rng = np.random.default_rng(abs(hash(label)) % (2**31)) + return { + "method": method, + "label": label, + "qc_after_decode": rng.uniform(0.2, 0.95, size=n).tolist(), + "reg_after_decode": { + "formation_energy": rng.uniform(-1.5, -0.2, size=n).tolist(), + "klat": rng.uniform(0.5, 2.2, size=n).tolist(), + }, + **extra, + } + + +def test_plot_qc_vs_reg_scatter_writes_png(tmp_path): + """End-to-end smoke: latent + composition results, two reg targets, expect a PNG out.""" + results = [ + _scatter_result("latent", "latent\nα=0"), + _scatter_result("latent", "latent\nα=0.25"), + _scatter_result("latent", "latent\nα=1"), + _scatter_result("composition", "comp\n(seed)"), + _scatter_result("composition", "comp\n(seed, 5% all)"), + ] + reg_targets = {"formation_energy": -2.0, "klat": 2.0} + out = tmp_path / "qc_vs_secondary_scatter.png" + _plot_qc_vs_reg_scatter(results, reg_targets, out, title="test") + assert out.exists() + + +def test_plot_qc_vs_reg_scatter_handles_single_target(tmp_path): + """One reg-target = one panel; still must render without grid-shape errors.""" + results = [ + _scatter_result("latent", "latent\nα=1"), + _scatter_result("composition", "comp\n(seed)"), + ] + out = tmp_path / "qc_single.png" + _plot_qc_vs_reg_scatter(results, {"klat": 2.0}, out, title="single target") + assert out.exists() + + +def test_plot_qc_vs_reg_scatter_skips_on_empty_results(tmp_path): + out = tmp_path / "should_not_exist.png" + _plot_qc_vs_reg_scatter([], {"klat": 2.0}, out, title="empty") + assert not out.exists() + + +def test_plot_qc_vs_reg_scatter_skips_on_empty_reg_targets(tmp_path): + """No reg-targets ⇒ nothing to plot; the helper must not write a degenerate figure.""" + results = [_scatter_result("latent", "latent\nα=1")] + out = tmp_path / "should_not_exist.png" + _plot_qc_vs_reg_scatter(results, {}, out, title="no targets") + assert not out.exists() + + +def test_plot_qc_vs_reg_scatter_with_seed_layer(tmp_path): + """Optional ``seed_qc`` / ``seed_reg`` draw the per-seed baseline as orange ★ stars. + + Verifies the figure still renders (the layer is added before the optimised clouds and + drops cleanly when the kwarg is omitted — see the no-arg test above). + """ + results = [ + _scatter_result("latent", "latent\nα=1"), + _scatter_result("composition", "comp\n(seed)"), + ] + reg_targets = {"formation_energy": -2.0, "klat": 2.0} + n_seeds = 5 + rng = np.random.default_rng(123) + out = tmp_path / "qc_with_seeds.png" + _plot_qc_vs_reg_scatter( + results, + reg_targets, + out, + title="with seeds", + seed_qc=rng.uniform(0.1, 0.6, size=n_seeds), + seed_reg={ + "formation_energy": rng.uniform(0.5, 2.5, size=n_seeds), + "klat": rng.uniform(-0.5, 1.0, size=n_seeds), + }, + ) + assert out.exists() diff --git a/src/foundation_model/scripts/paper_inverse_trajectory.py b/src/foundation_model/scripts/paper_inverse_trajectory.py new file mode 100644 index 0000000..e02139e --- /dev/null +++ b/src/foundation_model/scripts/paper_inverse_trajectory.py @@ -0,0 +1,472 @@ +# Copyright 2026 TsumiNa. +# SPDX-License-Identifier: Apache-2.0 + +"""Per-step trajectory analytics + plots + animations for inverse-design runs. + +Each call to :meth:`FlexibleMultiTaskModel.optimize_latent` / +:meth:`FlexibleMultiTaskModel.optimize_composition` can now optionally record: + +* ``trajectory_targets`` — shape ``(steps, B, T)``: per-step predicted target values + (one column per regression task in ``reg_targets`` order; QC is separate). +* ``trajectory_weights`` — shape ``(steps, B, n_components)``: per-step element weights + (the optimisation variable for ``optimize_composition``; decoded via ``KMD.inverse`` from the + per-step AE-decoded ``x`` for ``optimize_latent``). + +Together with the per-step QC trajectory (also collected from the raw target predictions for +the QC head), these are enough to visualise: + +1. How fast each target converges relative to the others (static line plot, normalised so all + targets are on the same y-axis). +2. How the recipe evolves across the optimisation (animated bar chart of the per-step composition + on the side, frame per step). + +This module hosts the pure helpers; ``paper_inverse_comparison.run()`` is the only caller. +""" + +from __future__ import annotations + +from collections.abc import Mapping +from pathlib import Path +from typing import Any, Iterable + +import matplotlib + +matplotlib.use("Agg") + +import matplotlib.animation as manimation +import matplotlib.pyplot as plt +import numpy as np +from loguru import logger + + +# --- representative-seed picker ----------------------------------------------------------------- + + +def best_seed_by_target_distance( + qc_final: np.ndarray, + reg_final: dict[str, np.ndarray], + reg_targets: Mapping[str, float], +) -> int: + """Pick the seed whose final state minimises the joint normalised distance to all targets. + + "Joint distance" = $\\sqrt{(1 - \\text{QC})^2 + \\sum_t ((y_t - \\text{target}_t) / s_t)^2}$ + where $s_t$ is the per-task scale (we use the absolute target value as a stand-in so each + task contributes on a comparable scale; a target of ±2 σ in z-scored space gives a scale of 2). + + The QC term uses ``1 - QC`` so closer-to-1 wins; the regression terms use signed deviation + so an under-shoot and an over-shoot are penalised equally. + """ + qc_final = np.asarray(qc_final, dtype=float) + n = qc_final.shape[0] + if n == 0: + raise ValueError("best_seed_by_target_distance: empty qc_final array.") + dist_sq = (1.0 - qc_final) ** 2 + for task, target in reg_targets.items(): + scale = max(abs(float(target)), 1.0) # avoid divide-by-zero if target == 0 + vals = np.asarray(reg_final[task], dtype=float) + dist_sq = dist_sq + ((vals - float(target)) / scale) ** 2 + return int(np.argmin(dist_sq)) + + +# --- trajectory normalisation ----------------------------------------------------------------- + + +def normalize_target_trajectories( + qc_trajectory: np.ndarray, + reg_trajectory: dict[str, np.ndarray], + reg_targets: Mapping[str, float], + seed_qc: np.ndarray, + seed_reg: Mapping[str, np.ndarray], +) -> dict[str, np.ndarray]: + """Map per-step target predictions to a [0, 1] "progress" fraction. + + For each target, 0 = "at seed baseline", 1 = "exactly at target". Values can exceed [0, 1] + if the optimiser overshoots. The transform is per-(task, seed): for seed *i* we compute + ``(y[step, i] - baseline[i]) / (target - baseline[i])`` so a noisy seed-to-seed baseline + doesn't dilute the average. After per-seed normalisation we mean over seeds so the static + plot shows the average progress across the seed cohort. + + Returns: dict ``{"QC": (steps,), task_name: (steps,)}`` of mean progress values. + """ + out: dict[str, np.ndarray] = {} + + # QC always targets 1.0. + qc_baseline = np.asarray(seed_qc, dtype=float) # (B,) + qc_target = 1.0 + qc_denom = qc_target - qc_baseline + qc_denom = np.where(np.abs(qc_denom) < 1e-9, 1.0, qc_denom) # protect against /0 + qc_progress = (np.asarray(qc_trajectory, dtype=float) - qc_baseline[None, :]) / qc_denom[None, :] + out["QC"] = qc_progress.mean(axis=1) + + for task, target in reg_targets.items(): + baseline = np.asarray(seed_reg[task], dtype=float) # (B,) + denom = float(target) - baseline + denom = np.where(np.abs(denom) < 1e-9, 1.0, denom) + traj = np.asarray(reg_trajectory[task], dtype=float) # (steps, B) + progress = (traj - baseline[None, :]) / denom[None, :] + out[task] = progress.mean(axis=1) + + return out + + +# --- static plot ------------------------------------------------------------------------------- + + +_TARGET_COLOR_QC = "#C44E52" # red — matches the target lines used elsewhere +_TARGET_COLORS_REG = ["#2563EB", "#55A868", "#E67E22", "#9467bd"] # blue / green / orange / purple + + +def plot_trajectory_static( + progress: Mapping[str, np.ndarray], + out_path: Path, + *, + title: str, + seed_composition: str | None = None, +) -> None: + """Line plot of normalised progress vs step. + + QC is drawn in red; the regression tasks cycle through the project's blue / green / orange + palette. The y-axis is "progress fraction" (0 = at seed, 1 = at target); a horizontal dashed + line at 1.0 marks the joint target. The reader gets a one-glance answer to the question the + user asked: "do the targets converge together, or does the recipe stabilise early and the + targets keep moving?" — divergence between the QC line and the reg lines, or between the reg + lines themselves, surfaces immediately. + + When ``seed_composition`` is provided (the per-seed composition string, e.g. + ``"Au65 Ga20 Gd15"``), it's appended to the figure title under the main title in a monospace + font — the reader can identify the seed by chemistry rather than by index. + """ + fig, ax = plt.subplots(figsize=(8.0, 5.0), dpi=150) + steps = np.arange(len(next(iter(progress.values())))) + + # QC first so it's visually behind the reg lines (the user usually cares about reg + # convergence; QC's behavior is rarely surprising). + if "QC" in progress: + ax.plot(steps, progress["QC"], color=_TARGET_COLOR_QC, lw=2.0, label="QC (P(quasicrystal))") + reg_keys = [k for k in progress if k != "QC"] + for i, key in enumerate(reg_keys): + ax.plot( + steps, + progress[key], + color=_TARGET_COLORS_REG[i % len(_TARGET_COLORS_REG)], + lw=1.8, + label=key, + ) + + ax.axhline(1.0, color="#666", ls="--", lw=1.0, alpha=0.7, label="target (progress = 1.0)") + ax.axhline(0.0, color="#bbb", ls=":", lw=0.8, alpha=0.5) + ax.set_xlabel("Optimisation step") + ax.set_ylabel("Progress (0 = seed, 1 = target)") + if seed_composition: + # Two-line layout: bold main title on top + seed composition underneath, with extra + # ``pad`` so the title doesn't sit flush against the upper axes line. Putting the + # seed-comp as a text annotation at y=1.02 collided with the title when matplotlib's + # default title-pad was applied — fix is to render both lines via set_title and a + # second matching text() at a clearly-distinct y position. + ax.set_title(title, fontsize=12, fontweight="bold", pad=22) + ax.text( + 0.5, 1.005, f"seed: {seed_composition}", + transform=ax.transAxes, ha="center", va="bottom", + fontsize=10, family="monospace", color="#444", + ) + else: + ax.set_title(title, fontsize=12, fontweight="bold") + ax.legend(loc="best", fontsize=9, frameon=False) + ax.grid(True, alpha=0.2) + fig.tight_layout() + fig.savefig(out_path, bbox_inches="tight", facecolor="white") + plt.close(fig) + logger.info(f"Wrote trajectory static plot to {out_path}") + + +# --- animation --------------------------------------------------------------------------------- + + +def _topk_composition_frame(weights: np.ndarray, element_symbols: list[str], top_k: int = 10) -> list[tuple[str, float]]: + """Top-K elements by weight, sorted descending. Used as one frame of the animation's comp panel.""" + idx = np.argsort(weights)[::-1][:top_k] + return [(element_symbols[int(i)], float(weights[int(i)])) for i in idx if weights[int(i)] > 1e-4] + + +def plot_trajectory_animation( + progress: Mapping[str, np.ndarray], + per_step_weights: np.ndarray, + element_symbols: list[str], + out_paths_by_format: Mapping[str, Path], + *, + title: str, + seed_composition: str | None = None, + top_k_elements: int = 10, + fps: int = 15, + max_frames: int = 120, +) -> None: + """Targets-vs-step line plot (top panel) + per-step top-K element bar chart (right panel). + + The line plot draws the full curve from step 0; a vertical "current step" marker advances + one tick per frame. The bar chart on the right re-draws each frame to show the current + composition's top-K elements (so the viewer can see "what is the recipe right now?" as the + targets evolve). For long runs (steps > ``max_frames``) we subsample uniformly so the GIF + stays under a few seconds at fps=15. + + Writers: + - ``gif`` → ``PillowWriter`` (no external deps; embeddable anywhere). + - ``html`` → ``HTMLWriter`` (JS-controlled play/pause/scrub; great for inspection). + - ``svg`` → custom SMIL-animated single-file SVG (browsers play it; PPT cannot embed). + """ + n_steps = len(next(iter(progress.values()))) + if n_steps == 0: + logger.warning("plot_trajectory_animation: empty progress arrays — skipping.") + return + if per_step_weights.shape[0] != n_steps: + logger.warning( + f"plot_trajectory_animation: per_step_weights step count ({per_step_weights.shape[0]}) " + f"does not match progress step count ({n_steps}); skipping animation." + ) + return + + # Uniform subsample down to ``max_frames`` so GIFs stay manageable. The line plot still uses + # the full curve; only the marker / weights frames are subsampled. + frame_steps = np.linspace(0, n_steps - 1, num=min(n_steps, max_frames)).astype(int) + frame_steps = np.unique(frame_steps) # in case of duplicate indices for very small n_steps + + fig = plt.figure(figsize=(12.0, 5.5), dpi=120) + gs = fig.add_gridspec(1, 2, width_ratios=[2.0, 1.0], wspace=0.30) + ax_line = fig.add_subplot(gs[0, 0]) + ax_bar = fig.add_subplot(gs[0, 1]) + + # --- Static line plot in left panel --- + steps = np.arange(n_steps) + if "QC" in progress: + ax_line.plot(steps, progress["QC"], color=_TARGET_COLOR_QC, lw=2.0, label="QC (P(quasicrystal))") + for i, key in enumerate([k for k in progress if k != "QC"]): + ax_line.plot( + steps, + progress[key], + color=_TARGET_COLORS_REG[i % len(_TARGET_COLORS_REG)], + lw=1.8, + label=key, + ) + ax_line.axhline(1.0, color="#666", ls="--", lw=1.0, alpha=0.6) + ax_line.axhline(0.0, color="#bbb", ls=":", lw=0.8, alpha=0.5) + ax_line.set_xlabel("Optimisation step") + ax_line.set_ylabel("Progress (0 = seed, 1 = target)") + if seed_composition: + # Two-line title: bold panel title on top + monospace seed-composition underneath. The + # ``pad=22`` lifts the title clear of the second line; without the pad they overlap + # because matplotlib's default title baseline sits where the text annotation lands. + ax_line.set_title(title, fontsize=11, fontweight="bold", pad=22) + ax_line.text( + 0.5, 1.005, f"seed: {seed_composition}", + transform=ax_line.transAxes, ha="center", va="bottom", + fontsize=10, family="monospace", color="#444", + ) + else: + ax_line.set_title(title, fontsize=11, fontweight="bold") + ax_line.legend(loc="best", fontsize=8, frameon=False) + ax_line.grid(True, alpha=0.2) + marker = ax_line.axvline(0, color="#444", lw=1.2, alpha=0.85) + + # --- Bar chart in right panel (redrawn per frame) --- + ax_bar.set_title("Composition (top-K by weight)", fontsize=10) + ax_bar.set_xlim(0, 1.0) + ax_bar.set_xlabel("weight") + + def _draw_bar(step_idx: int) -> None: + ax_bar.clear() + frame = _topk_composition_frame(per_step_weights[step_idx], element_symbols, top_k=top_k_elements) + if not frame: + ax_bar.text(0.5, 0.5, "(no elements above threshold)", ha="center", va="center", transform=ax_bar.transAxes) + else: + symbols, weights = zip(*frame) + y_pos = np.arange(len(symbols)) + ax_bar.barh(y_pos, weights, color="#2563EB", alpha=0.75, edgecolor="#222", linewidth=0.5) + ax_bar.set_yticks(y_pos) + ax_bar.set_yticklabels(symbols, fontsize=9) + ax_bar.invert_yaxis() # largest on top + ax_bar.set_xlim(0, max(0.5, float(per_step_weights[step_idx].max()) * 1.1)) + ax_bar.set_xlabel("weight") + ax_bar.set_title(f"Composition (step {step_idx + 1}/{n_steps})", fontsize=10) + ax_bar.grid(True, axis="x", alpha=0.2) + + def _init() -> Iterable[Any]: + _draw_bar(int(frame_steps[0])) + marker.set_xdata([int(frame_steps[0])]) + return (marker,) + + def _update(frame_idx: int) -> Iterable[Any]: + step_idx = int(frame_steps[frame_idx]) + _draw_bar(step_idx) + marker.set_xdata([step_idx]) + return (marker,) + + # Only build the matplotlib FuncAnimation when at least one matplotlib-native format + # (gif / html) is requested. For svg-only output we render a handwritten SMIL SVG without + # touching the animation object — building it anyway would emit a "Animation was deleted + # without rendering anything" UserWarning on test runs. + needs_mpl_anim = any(fmt in ("gif", "html") for fmt in out_paths_by_format) + anim = ( + manimation.FuncAnimation( + fig, + _update, + frames=len(frame_steps), + init_func=_init, + interval=1000 // fps, + blit=False, # the bar chart redraw isn't blittable cleanly + ) + if needs_mpl_anim + else None + ) + + for fmt, out_path in out_paths_by_format.items(): + try: + if fmt == "gif": + anim.save(str(out_path), writer=manimation.PillowWriter(fps=fps)) + elif fmt == "html": + # ``to_jshtml`` returns a single self-contained HTML string with frames embedded + # as base64 PNGs. The ``HTMLWriter`` alternative drops a separate ``*_frames/`` + # folder of 120+ PNGs alongside, which clutters the output dir and makes the + # artefact non-portable. The base64 version is bigger per-file (~3 MB vs the + # multi-file's ~10 MB total) but is one self-contained file. + out_path.write_text(anim.to_jshtml(fps=fps), encoding="utf-8") + elif fmt == "svg": + _save_smil_svg(progress, per_step_weights, element_symbols, frame_steps, out_path, title=title, fps=fps) + else: + logger.warning(f"plot_trajectory_animation: unknown format {fmt!r} — skipping.") + continue + logger.info(f"Wrote trajectory animation ({fmt}) to {out_path}") + except Exception as exc: # pragma: no cover (writer-specific failure modes) + logger.warning(f"plot_trajectory_animation: failed to write {fmt} → {out_path}: {exc}") + + plt.close(fig) + + +# --- SMIL SVG writer --------------------------------------------------------------------------- + + +def _save_smil_svg( + progress: Mapping[str, np.ndarray], + per_step_weights: np.ndarray, + element_symbols: list[str], + frame_steps: np.ndarray, + out_path: Path, + *, + title: str, + fps: int, + top_k_elements: int = 10, +) -> None: + """Single-file SMIL-animated SVG. + + matplotlib doesn't have a native SVG-animation writer; rather than render N PNGs and ship a + multi-frame SVG (would defeat the "one file" goal), we emit a compact handwritten SVG with + the static line plot as a vector overlay + ```` tags for the per-step marker and + per-element bar widths. Plays in any modern browser (Firefox / Chrome / Safari); PowerPoint + and Keynote cannot embed it directly — for those use the GIF. + """ + n_steps = len(next(iter(progress.values()))) + duration_s = max(1.0, len(frame_steps) / fps) + # Coordinate system: 800 × 400 viewBox, line plot in [40, 480] × [40, 360], bar plot in + # [520, 780] × [40, 360]. Bars are horizontal, top-K elements, redrawn via . + + # ---- header ---- + parts: list[str] = [] + parts.append( + '' + ) + parts.append(f'{title}') + parts.append(f'{title}') + + # ---- line plot (static) ---- + parts.append('') + parts.append('Optimisation step') + parts.append( + '' + "Progress (0 = seed, 1 = target)" + ) + + # Compute y-range across all curves so 0 and 1 are at fixed pixels. + all_vals = np.concatenate([np.asarray(v) for v in progress.values()]) + y_min, y_max = float(min(all_vals.min(), 0.0)), float(max(all_vals.max(), 1.0)) + y_pad = (y_max - y_min) * 0.05 + y_min -= y_pad + y_max += y_pad + + def _to_x(step_idx: int) -> float: + return 40 + (step_idx / max(n_steps - 1, 1)) * 440 + + def _to_y(val: float) -> float: + return 360 - (val - y_min) / (y_max - y_min) * 320 + + # Static gridlines + axis labels at 0 / 1. + y0, y1 = _to_y(0.0), _to_y(1.0) + parts.append(f'') + parts.append(f'') + parts.append(f'0') + parts.append(f'1') + + color_map = {"QC": _TARGET_COLOR_QC} + reg_keys = [k for k in progress if k != "QC"] + for i, key in enumerate(reg_keys): + color_map[key] = _TARGET_COLORS_REG[i % len(_TARGET_COLORS_REG)] + + # Polyline per target. + legend_y = 50 + for key, vals in progress.items(): + pts = " ".join(f"{_to_x(s):.1f},{_to_y(float(v)):.1f}" for s, v in enumerate(vals)) + color = color_map[key] + parts.append(f'') + parts.append(f'') + parts.append(f'{key}') + legend_y += 14 + + # ---- animated step marker (vertical line in line plot) ---- + x_values_str = ";".join(f"{_to_x(int(s)):.1f}" for s in frame_steps) + parts.append( + f'' + f' ' + f' ' + f"" + ) + + # ---- bar chart (top-K, animated per element) ---- + parts.append('') + parts.append('Composition (top-K, step animated)') + + # Use the union of top-K-per-frame elements across all frames so each bar is one stable row. + seen_idx: list[int] = [] + for s in frame_steps: + top = np.argsort(per_step_weights[int(s)])[::-1][:top_k_elements] + for idx in top: + if int(idx) not in seen_idx: + seen_idx.append(int(idx)) + # Cap at 2× top_k to keep the SVG tidy. + seen_idx = seen_idx[: 2 * top_k_elements] + n_rows = len(seen_idx) + bar_y_top = 50 + bar_height = min(20.0, 300.0 / max(n_rows, 1)) + bar_x_left = 560 + bar_max_w = 200 + + # Per-bar animation values (weight per frame, scaled). + for row_i, elem_idx in enumerate(seen_idx): + y_row = bar_y_top + row_i * bar_height + widths = [per_step_weights[int(s), elem_idx] for s in frame_steps] + w_str = ";".join(f"{max(0.0, float(w)) * bar_max_w:.1f}" for w in widths) + parts.append(f'{element_symbols[elem_idx]}') + parts.append( + f'' + f' ' + f"" + ) + + # Step counter at the bottom. + step_label_values = ";".join(f"step {int(s) + 1}/{n_steps}" for s in frame_steps) + parts.append( + f'' + f' step 1/{n_steps}' + f' ' + f"" + ) + + parts.append("") + out_path.write_text("\n".join(parts), encoding="utf-8") diff --git a/src/foundation_model/scripts/paper_inverse_trajectory_test.py b/src/foundation_model/scripts/paper_inverse_trajectory_test.py new file mode 100644 index 0000000..e5f152c --- /dev/null +++ b/src/foundation_model/scripts/paper_inverse_trajectory_test.py @@ -0,0 +1,192 @@ +# Copyright 2026 TsumiNa. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the pure helpers in :mod:`paper_inverse_trajectory`. + +The full ``_emit_trajectory_outputs`` orchestrator needs a real trained checkpoint to exercise; +this file covers the pure functions — seed picker, progress normalisation, and the writer +smoke-tests (static plot + gif + html + svg). Animations are checked only for "file got written"; +visual correctness is verified by inspecting the rerun artefacts. +""" + +from __future__ import annotations + +import numpy as np +import pytest + +from foundation_model.scripts.paper_inverse_trajectory import ( + best_seed_by_target_distance, + normalize_target_trajectories, + plot_trajectory_animation, + plot_trajectory_static, +) + + +# --- best_seed_by_target_distance -------------------------------------------------------------- + + +def test_best_seed_picks_closest_joint_distance_to_targets(): + """Seed 1 is closest to the joint target (QC=1, fe=-2, klat=2); the picker should return 1.""" + qc = np.array([0.20, 0.95, 0.50]) # seed 1 has highest QC + reg = { + "formation_energy": np.array([+0.5, -1.9, -1.0]), # seed 1 hits target -2 best + "klat": np.array([0.0, 1.8, 1.2]), # seed 1 hits target 2 best + } + reg_targets = {"formation_energy": -2.0, "klat": 2.0} + assert best_seed_by_target_distance(qc, reg, reg_targets) == 1 + + +def test_best_seed_handles_zero_target_without_div_by_zero(): + """``target == 0`` would naively divide by zero; the picker uses a min-scale guard.""" + qc = np.array([0.9, 0.8]) + reg = {"some_task": np.array([0.1, 0.5])} + # Should pick seed 0 (closer to target 0). + assert best_seed_by_target_distance(qc, reg, {"some_task": 0.0}) == 0 + + +def test_best_seed_empty_qc_raises(): + with pytest.raises(ValueError, match="empty qc_final"): + best_seed_by_target_distance(np.array([]), {}, {}) + + +# --- normalize_target_trajectories ------------------------------------------------------------- + + +def test_normalize_trajectory_maps_baseline_to_zero_and_target_to_one(): + """Per (task, seed): a step's value of (target - baseline) + baseline ⇒ progress = 1.""" + n_steps = 4 + n_seeds = 2 + # One reg target only. Baseline = [0.0, 0.5], target = 2.0. + reg_targets = {"k": 2.0} + seed_reg = {"k": np.array([0.0, 0.5])} + # Per-seed trajectory: linear interpolation from baseline → target across 4 steps. + traj_k = np.stack( + [ + np.linspace(0.0, 2.0, n_steps), # seed 0 + np.linspace(0.5, 2.0, n_steps), # seed 1 + ], + axis=1, + ) # (steps, B) + # QC trajectory: flat at the seed baseline so it normalises to 0 progress throughout. + seed_qc = np.array([0.1, 0.2]) + qc_traj = np.tile(seed_qc[None, :], (n_steps, 1)) + + progress = normalize_target_trajectories( + qc_trajectory=qc_traj, + reg_trajectory={"k": traj_k}, + reg_targets=reg_targets, + seed_qc=seed_qc, + seed_reg=seed_reg, + ) + # k progress: starts at 0, ends at 1 (per-seed normalised then mean over B). + assert progress["k"].shape == (n_steps,) + assert progress["k"][0] == pytest.approx(0.0, abs=1e-9) + assert progress["k"][-1] == pytest.approx(1.0, abs=1e-9) + # QC stays at baseline ⇒ progress = 0 throughout. + assert progress["QC"].shape == (n_steps,) + assert np.allclose(progress["QC"], 0.0) + + +# --- plot writers ------------------------------------------------------------------------------ + + +def _toy_progress() -> dict[str, np.ndarray]: + """4-target × 30-step normalised progress, monotone so the picture is interpretable.""" + n = 30 + return { + "QC": np.clip(np.linspace(0.0, 0.95, n) + 0.02 * np.sin(np.linspace(0, 4 * np.pi, n)), 0, 1.5), + "formation_energy": np.linspace(0.0, 1.2, n), + "klat": np.linspace(0.0, 0.8, n), + } + + +def _toy_weights(n_steps: int = 30, n_components: int = 12) -> np.ndarray: + """(steps, n_components) toy weights — start sparse, drift toward a different sparse set.""" + rng = np.random.default_rng(7) + w = np.zeros((n_steps, n_components), dtype=float) + # Initial: mass on elements 0..2 + w[0, :3] = [0.5, 0.3, 0.2] + # Final: mass on elements 4, 6, 7 + end = np.zeros(n_components) + end[4], end[6], end[7] = 0.5, 0.3, 0.2 + for s in range(n_steps): + t = s / (n_steps - 1) + w[s] = (1 - t) * w[0] + t * end + 0.001 * rng.standard_normal(n_components) + w[s] = np.clip(w[s], 0, None) + w[s] /= w[s].sum() + return w + + +def test_plot_trajectory_static_writes_png(tmp_path): + out = tmp_path / "static.png" + plot_trajectory_static(_toy_progress(), out, title="toy trajectory") + assert out.exists() + + +def test_plot_trajectory_static_with_seed_composition(tmp_path): + """``seed_composition`` is rendered as a monospace annotation under the title — verify the + plot still writes with the kwarg present (visual correctness is by inspection).""" + out = tmp_path / "static_with_seed.png" + plot_trajectory_static( + _toy_progress(), out, title="toy trajectory", seed_composition="Au65 Ga20 Gd15" + ) + assert out.exists() + + +def test_plot_trajectory_animation_writes_gif(tmp_path): + out = tmp_path / "anim.gif" + plot_trajectory_animation( + _toy_progress(), + per_step_weights=_toy_weights(), + element_symbols=[f"E{i}" for i in range(12)], + out_paths_by_format={"gif": out}, + title="toy animation", + max_frames=10, # keep test fast + ) + assert out.exists() + + +def test_plot_trajectory_animation_writes_html(tmp_path): + out = tmp_path / "anim.html" + plot_trajectory_animation( + _toy_progress(), + per_step_weights=_toy_weights(), + element_symbols=[f"E{i}" for i in range(12)], + out_paths_by_format={"html": out}, + title="toy animation", + max_frames=10, + ) + assert out.exists() + + +def test_plot_trajectory_animation_writes_smil_svg(tmp_path): + out = tmp_path / "anim.svg" + plot_trajectory_animation( + _toy_progress(), + per_step_weights=_toy_weights(), + element_symbols=[f"E{i}" for i in range(12)], + out_paths_by_format={"svg": out}, + title="toy animation", + max_frames=8, + ) + assert out.exists() + body = out.read_text(encoding="utf-8") + # The SMIL animation should contain tags driving the marker x1/x2 + bar widths. + assert "