diff --git a/.gitignore b/.gitignore index 54b1e14..bb3ab17 100644 --- a/.gitignore +++ b/.gitignore @@ -41,6 +41,18 @@ lightning_logs/ /notebooks/logs /artifacts +# Claude Code workspace +.claude/scheduled_tasks.lock +.claude/worktrees/ + +# Local evaluation outputs — the eval *.py scripts under logs/ are tracked +# for reproducibility, but their generated artefacts (raw logs, JSON, PNGs) +# are not. +logs/*.log +logs/*.json +logs/*.png +logs/__pycache__/ + # IDE .idea/ .vscode/ diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 3d37bbd..ca6b549 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -259,6 +259,14 @@ User-facing knobs (all on `[0, 1]` where applicable): - `allowed_elements` — hard whitelist over element symbols. - `element_step_scale` — per-element gradient scaling; `0` hard-locks an element to its seed value. - `class_target_weight` — weight on the classification objective vs. the regression targets. +- `max_elements` — cardinality cap "at most K elements per recipe", enforced by a + differentiable Plötz–Roth iterative soft top-K mask + final hard projection. +- `annealing_scale` ∈ `[0, 1]` (default 0.5) — single-knob softness of the K-hot annealing + schedule; maps to `τ_start = 25**scale`. Advanced override via the `annealing_schedule` dict. +- `fixed_amounts` — pin specific elements at user-given absolute amounts (e.g. + `{"Au": 0.65}`); reuses the lock-paste path, no `initial_weights` required. +- `min_nonzero_weight` — drop unlocked positions below this floor (per-row safe fallback + keeps the simplex valid). 
```mermaid graph TD diff --git a/README.md b/README.md index 14affb7..9c3a26f 100644 --- a/README.md +++ b/README.md @@ -329,6 +329,20 @@ entry points on the model: | `optimize_latent(optimize_space="latent")` | the latent $h$ | no — needs AE decode | `ae_align_scale ∈ [0, 1]` (default 0.5; pulls $h$ onto the AE manifold) | | `optimize_composition` | element-weight logits $\theta$, with $w = \text{softmax}(\theta)$ | yes — $w$ is the recipe | `diversity_scale ∈ [0, 1]` (default 1.0; per-output entropy penalty) | +`optimize_composition` further accepts an orthogonal constraint surface (full docstrings on +the method; design notes in +[docs/inverse_design_extension_notes.md](docs/inverse_design_extension_notes.md)): + +- `max_elements: int` — cardinality cap (at most K non-zero elements per recipe), enforced + through a differentiable iterative-softmax K-hot mask with a single `annealing_scale ∈ [0, 1]` + softness knob (default 0.5 = the calibrated safe choice). +- `fixed_amounts: {symbol: float}` — pin specific elements at user-given absolute amounts + (e.g. `{"Au": 0.65, "Ga": 0.20}`); the optimiser distributes the remaining mass freely. +- `min_nonzero_weight: float` — reject trace-amount appearances (e.g. drop anything below + 10 %), with safe-fallback so the simplex invariant is always preserved. + +All three compose orthogonally with each other and with `allowed_elements` / `element_step_scale`. + Both methods share the same regression-MSE + classification-cross-entropy backbone; only the third loss term and the optimisation variable differ. **Reference:** [docs/inverse_design_algorithms.md](docs/inverse_design_algorithms.md). diff --git a/docs/inverse_design_algorithms.md b/docs/inverse_design_algorithms.md index fea443b..04a2619 100644 --- a/docs/inverse_design_algorithms.md +++ b/docs/inverse_design_algorithms.md @@ -88,6 +88,9 @@ with | **`allowed_elements` whitelist** | Masks the logits of disallowed elements to $-\infty$ before every softmax step. 
| Restrict the search to physically realisable elements (e.g. `ALLOY_PALETTE`, 41 symbols), suppressing model biases toward Pu / F / Cs / etc. | | **`element_step_scale` soft-freeze / hard-lock** | Soft: multiply the element's logit gradient by the scale before each Adam step. Hard (value = 0): rewrite the softmax output to paste seed values back at locked positions and renormalise unlocked positions over the remaining mass. | Let the user pin certain elements to their seed values ("keep the Au-Ga-RE skeleton; you may only change the rare-earth ratios"). | | **`seed_blend` mixture** | $w_0 \leftarrow \text{seed\_blend} \cdot \text{seed} + (1 - \text{seed\_blend}) \cdot \text{uniform}_{\text{allowed}}$ | Don't start from a 100 % seed (5 % uniform mass lifts every allowed element's logit from $\log(10^{-12}) \approx -27.6$ to $\log(0.05 / \lvert\text{allowed}\rvert) \approx -7.6$, so Adam can introduce new elements within a few hundred steps — this is the **element-discovery** mechanism). | +| **`max_elements` cardinality cap** | Plötz–Roth iterative-softmax K-hot mask $m \in [0, 1]^n$ with $\sum_i m_i = K$, multiplied with `softmax(θ)` and renormalised; temperature $\tau$ annealed from $\tau_\text{start} = 25^{\text{annealing\_scale}}$ down to $\tau_\text{end} = 0.01$ (geometric by default). A hard top-K projection at the end guarantees exactly K-hot (subject to floor below). | Restrict recipes to **at most K non-zero elements** (e.g. "I want a 3-element alloy"). The annealing doubles as a continuation method — the soft τ early on lets the optimiser explore different K-subsets before committing. | +| **`fixed_amounts` user-pin** | Build $\text{locked\_w}_0$ with user-specified values at the named positions, zero elsewhere; reuse the existing lock-paste machinery (no `initial_weights` required since values are given directly). | Pin specific elements at user-given absolute amounts (e.g. 
`{"Au": 0.65, "Ga": 0.20}` — the optimiser distributes the remaining 0.15 mass across other allowed elements). | +| **`min_nonzero_weight` floor** | After lock-paste, zero unlocked positions with $0 < w < \text{floor}$ and renormalise the unlocked portion to fit the free mass; safe-fallback when dropping would empty a row (leave that row unfloored). | Reject trace-amount appearances (e.g. `Pt = 0.5 %`) that are not synthesisable — "if you use it, use ≥ 10 %". | ### What each loss term is for @@ -106,6 +109,11 @@ with | `seed_blend` | $[0, 1]$ | 0.95 | Fraction of seed kept at the start (the rest is uniform, so new elements can enter). | | `allowed_elements` | symbol list or `"all"` | `"all"` | Element whitelist (hard constraint). | | `element_step_scale` | float or `{symbol: float}` | 1.0 | Per-element step scaling; `0` = hard-lock to the seed value. | +| `max_elements` | `int` ∈ $[1, n]$ or `None` | `None` | Cardinality cap — at most K non-zero elements (differentiable soft top-K + final hard projection). | +| `annealing_scale` | $[0, 1]$ | 0.5 | Single-knob softness for the K-hot schedule; maps to $\tau_\text{start} = 25^{\text{scale}}$. | +| `annealing_schedule` | dict or `None` | `None` | Advanced piecewise override of the annealing schedule. | +| `fixed_amounts` | `{symbol: float}` or `None` | `None` | Pin elements at user-specified amounts (e.g. `{"Au": 0.65}`); needs $\sum < 1$. | +| `min_nonzero_weight` | $[0, 1]$ | 0.0 | Drop unlocked positions below this floor (and re-distribute mass). | | `steps`, `lr` | — | 300, 0.05 | Adam optimisation budget over the logits. 
| --- @@ -118,7 +126,7 @@ with | **Where the reported recipe comes from** | $w_{\text{report}}$ inferred from $D(h)$ (an extra AE-decode step) | $w$ itself is the report | | **Method-specific loss term** | $\alpha \cdot \lVert h - \tanh(E(D(h))) \rVert^2$ (keeps $h$ on the AE manifold) | $(1 - d) \cdot H(w)$ (controls per-solution peakiness) | | **Failure mode** | $\alpha = 0$: $h$ drifts off the manifold, decoded recipe unphysical (QC 0.97 → 0.35). | `seed_blend = 1.0`: the seed's support set is frozen — no new elements can ever appear. | -| **Method-specific knobs** | `ae_align_scale` | `diversity_scale`, `seed_blend`, `allowed_elements`, `element_step_scale` | +| **Method-specific knobs** | `ae_align_scale` | `diversity_scale`, `seed_blend`, `allowed_elements`, `element_step_scale`, `max_elements` + `annealing_scale` / `annealing_schedule`, `fixed_amounts`, `min_nonzero_weight` | The shared backbone — (1) regression MSE + (2) classification cross-entropy — is **identical** between the two methods. They differ *only* in the third loss term and in which variable is diff --git a/docs/inverse_design_extension_notes.md b/docs/inverse_design_extension_notes.md index a900404..6875d67 100644 --- a/docs/inverse_design_extension_notes.md +++ b/docs/inverse_design_extension_notes.md @@ -1,10 +1,8 @@ # Inverse-design code map + extension notes -Snapshot of the inverse-design surface as of PR #18, written so a future session -can extend it (e.g. "exactly K elements", "fix Au at 65 %", "min-weight floor") -without having to reverse-engineer the design. See +Code map for `optimize_composition`'s constraint surface. See [inverse_design_algorithms.md](inverse_design_algorithms.md) for the *math*; this -doc is the *code map*. +doc is the *code map* and is kept in sync with the implementation. 
## The two entry points @@ -17,10 +15,12 @@ Both share: regression-MSE + classification-cross-entropy backbone, opt-in `record_*_trajectory` flag for per-step capture (used by [paper_inverse_trajectory.py](../src/foundation_model/scripts/paper_inverse_trajectory.py)). -## What's already there on the composition path +## User-facing kwargs (composition path) -User-facing kwargs (validated in `optimize_composition`'s argument block at lines -~2393–2465): +Validated in `optimize_composition`'s argument block; the whole surface composes +orthogonally — any subset of A/B/C plus the existing knobs can be used together. + +### Existing knobs (PR #18) | Kwarg | Range | What it does | Implementation | |---|---|---|---| @@ -29,126 +29,118 @@ User-facing kwargs (validated in `optimize_composition`'s argument block at line | `diversity_scale` | `[0, 1]`, default 1.0 | 0 = peaky few-element; 1 = no penalty | `(1 − d) · H(w)` added to loss | | `seed_blend` | `[0, 1]`, default 0.95 | how much seed kept vs uniform-over-allowed at init | `w₀ ← s·seed + (1−s)·uniform` | | `allowed_elements` | `"all"` or symbol list | hard whitelist | logit mask to `-inf` | -| `element_step_scale` | `float` or `{symbol: float}` | soft per-element gradient scale; `0` = hard-lock to seed value | grad multiplied per element; **lock implemented in `_w_from_logits`** (line 2576) — paste seed values back over softmax + renormalise unlocked positions | - -## The single point of leverage: `_w_from_logits` inside `optimize_composition` - -[flexible_multi_task_model.py:2576-2591](../src/foundation_model/models/flexible_multi_task_model.py#L2576-L2591) - -```python -def _w_from_logits(lg: torch.Tensor) -> torch.Tensor: - """Softmax over logits; mask disallowed elements; hard-lock the chosen ones at seed.""" - w = softmax_with_mask(lg, elem_mask) # whitelist - if locked_mask is None: - return w - # rewrite locked positions to seed values + renormalise unlocked positions to fill 1 − Σ_locked - ... 
-``` - -**Every simplex-projection / hardening rule belongs here.** It runs once per step, -on every (B × n_components) row, and the gradient flows correctly through any -differentiable rewriting (the existing lock branch is `.detach()`-constant, so -its gradient is 0; new differentiable steps would let gradient flow naturally). -Adding a new constraint = (a) accept a new kwarg in the signature, (b) validate -it in the arg-block, (c) compute any per-step state once before the loop, (d) -apply it inside `_w_from_logits`. - -## Three extensions the user has flagged +| `element_step_scale` | `float` or `{symbol: float}` | soft per-element gradient scale; `0` = hard-lock to seed value | grad multiplied per element; lock implemented in `_w_from_logits` — paste seed values back over softmax + renormalise unlocked positions | -### A. "Specify number of elements" — top-K mass constraint +### Constraint knobs added in this PR (A / B / C) -**Use case**: "give me exactly 3-element recipes" / "at most K elements". - -**Suggested API**: -```python -optimize_composition(..., max_elements: int | None = None) -``` - -**Implementation sketch** (inside `_w_from_logits`): -```python -if max_elements is not None and max_elements < n_components: - # Top-K hardening: keep the K largest weights per row, zero the rest, renormalise. - topk_vals, topk_idx = w.topk(max_elements, dim=-1) - mask = torch.zeros_like(w).scatter_(-1, topk_idx, 1.0) - w = w * mask - w = w / w.sum(dim=-1, keepdim=True).clamp(min=1e-12) -``` - -Notes: -- `topk` returns the K largest indices — this is non-differentiable at the "K-th - vs (K+1)-th" boundary, but the gradient through the K kept values is correct. - In practice, with `diversity_scale < 1` to drive peakiness *before* the hard - cutoff, the boundary doesn't oscillate. -- Validate `1 ≤ max_elements ≤ n_components` in the arg-block. 
-- Tests to add: pattern after - [test_optimize_composition_element_step_scale_locks_symbols](../src/foundation_model/models/flexible_multi_task_model_test.py) - — assert that `(w > 1e-6).sum(dim=-1) <= max_elements` for every row of the - output. - -### B. "Fix Au at exactly 65 %" — explicit fixed-amount API - -**Use case**: chemistry-driven prior says "I want exactly 65 % Au, 20 % Ga, -optimiser picks the remaining 15 % freely". +| Kwarg | Range | What it does | Implementation | +|---|---|---|---| +| `max_elements` (A) | `int` ∈ `[1, n_components]` or `None` | "at most K non-zero elements" cardinality cap | differentiable Plötz–Roth iterative soft top-K mask multiplies `softmax(lg)` inside `_w_from_logits`; final hard top-K projection at the very end so the returned recipe is exactly K-hot (subject to floor C dropping below-floor positions further down) | +| `annealing_scale` (A) | `[0, 1]`, default 0.5 | single-knob "softness" of the K-hot annealing schedule; maps to `τ_start = 25**scale` (0→1, 0.5→5, 1→25) | drives the τ schedule for the soft top-K mask; default schedule is geometric from `25**scale` down to `τ_end = 0.01` | +| `annealing_schedule` (A) | `dict` or `None` | advanced piecewise override — `{"step": [...], "scale": [...], "annealing_func": [...]}` with per-segment normalised scales and interpolation funcs (`geometric`/`linear`/`cosine`/`constant`) | overrides the front of the simple schedule; if `step[-1] < 1.0`, the tail falls back to a geometric drop to `τ_end = 0.01` | +| `fixed_amounts` (B) | `{symbol: float}` or `None` | pin elements at user-specified absolute amounts (e.g. 
`{"Au": 0.65, "Ga": 0.20}`); does **not** require `initial_weights` | reuses the existing `locked_mask` / `locked_w0` lock-paste machinery; merged with `element_step_scale=0` locks (validated disjoint) | +| `min_nonzero_weight` (C) | `[0, 1]`, default 0.0 | drop unlocked positions with `0 < w < floor` and re-distribute the freed mass | applied at the very end of `_w_from_logits` (after lock-paste) and again after the final hard projection; locked positions are exempt; per-row fallback when the floor would empty unlocked mass — that row is left unfloored to keep the simplex valid | -**Already half-possible** via `element_step_scale = {"Au": 0.0, "Ga": 0.0}` + -seed has those amounts. But it requires constructing a seed; cleaner standalone -API: +## The pipeline inside `_w_from_logits` ```python -optimize_composition(..., fixed_amounts: Mapping[str, float] | None = None) +# Per-step pipeline (every Adam step) — see flexible_multi_task_model.py near the +# ``_w_from_logits`` definition for the live source. +lg = mask_disallowed(lg, allowed_elements) # whitelist → -inf +w_soft = softmax(lg) # natural simplex +if max_elements is not None: + m = soft_topk_mask(lg, K=max_elements, τ=current_τ) # Plötz-Roth, force-locked positions in + w = (w_soft * m) / Σ # K-hot weighted by softmax ratios +else: + w = w_soft +w = apply_lock_paste(w) # paste pinned values (B and step_scale=0) +w = apply_min_floor(w) # zero unlocked below floor, renorm ``` -**Implementation sketch**: -- Validate that `sum(fixed_amounts.values()) < 1.0` (need free mass) and each - symbol resolves in `DEFAULT_ELEMENTS`. -- Compute `fixed_w0: (n_components,)` with those positions set, zeros elsewhere. 
-- Reuse the existing `locked_mask` / `locked_w0` infrastructure - ([line 2557-2574](../src/foundation_model/models/flexible_multi_task_model.py#L2557-L2574)) - — basically: set `locked_mask = (fixed_w0 > 0)` and `locked_w0 = fixed_w0` for - every row in the batch, skip the "needs `initial_weights`" requirement that - the `element_step_scale=0` branch has. -- The existing `_w_from_logits` already does the right paste + renormalise; no - change needed there. - -Tests: assert `w[:, fixed_idx] ≈ fixed_amount` exactly after every step. - -### C. Min-weight floor / "if you use Au, use ≥ 10 %" - -**Use case**: avoid trace-amount appearances (`Pt = 0.5 %`) that are not -synthesisable. - -**Suggested API**: `min_nonzero_weight: float = 0.0`. After top-K (B) or simplex -projection, zero out any weight below the floor, renormalise. - -Implementation goes in the same `_w_from_logits` block, after any top-K / -locking. Same test pattern. - -## What lives where (for the future agent) - -| Concern | Location | -|---|---| -| Method docstring (the user-facing contract) | `optimize_composition` docstring, lines 2243–2370 | -| Kwarg validation | arg-block lines 2393–2465 (mirror the pattern: per-kwarg validation block + a `*_arg` local prepared for the inner loop) | -| One-time setup (locked indices, scaled steps, …) | lines 2536–2574 (before the `for _ in range(steps)` loop) | -| Per-step constraint application | `_w_from_logits` (line 2576) — single point | -| Loss term additions (entropy etc.) 
| inner loop, line ~2614 (`if diversity_scale < 1.0:`) | -| Per-step trajectory recording | already wired via `record_weights_trajectory` (line 2603 + 2622) — new constraints automatically reflected in the trajectory because we record post-`_w_from_logits` weights | -| Tests | [flexible_multi_task_model_test.py](../src/foundation_model/models/flexible_multi_task_model_test.py) — search `test_optimize_composition_*` (38 existing tests cover the current surface) | - -The latent path (`optimize_latent`, line 1735) is more rigid: the optimisation -variable is `h`, not a simplex, so the same constraints don't translate -naturally. Most new constraint features will only make sense for -`optimize_composition` — call this out in any new kwarg's docstring. - -## Pre-merge checklist - -When extending: keep the surgical-edits pattern. The existing PR already added a -lot; future extensions should be one-kwarg-per-PR with the validation + -`_w_from_logits` change + at least one test that pins the contract end-to-end -(input kwarg → output `w` rows satisfy the constraint). - -Reference tests to mimic: -- `test_optimize_composition_element_step_scale_locks_symbols` — - contract test for an existing constraint kwarg. -- `test_optimize_composition_runs_and_returns_simplex_weights` — smoke test - that the simplex is preserved (rows sum to 1, all ≥ 0). +After the optimisation loop, the final state additionally runs a **hard top-K +projection** + a re-paste + re-floor so the returned recipe is clean — at τ_end ≈ 0.01 +the soft state is already near K-hot, so this just cleans residual sub-threshold mass. + +## How the constraints compose + +Designed to be orthogonal — any subset can be used together. 
The validation enforces +the few impossible combinations up-front so a bad mix raises before model state is +touched: + +| Constraint pair | Validation | Behaviour | +|---|---|---| +| A × `allowed_elements` | `max_elements ≤ |allowed|` | only allowed positions can enter the K-hot mask | +| A × `element_step_scale=0` | `max_elements ≥ n_locked` | locked positions counted toward K; force-selected into the mask | +| A × B | `max_elements > n_locked_total` (strict — `fixed_amounts` has `Σ < 1`, leftover mass needs a free slot) | B locks count toward K | +| A × C | `min_nonzero_weight ≤ 1 / max_elements` | floor compatible with K-element simplex | +| B × C | `min(fixed_amounts.values()) ≥ floor` | floor cannot override a user pin | +| B × `element_step_scale=0` | disjoint symbol sets | one lock mechanism per element | +| C × `element_step_scale=0` | per-row locked seed values ≥ floor (runtime) | floor cannot drop a locked seed | + +Edge case (C): if dropping every below-floor position would leave a row with zero +unlocked mass, the floor is **skipped for that row only** — preserving the simplex +invariant. The "at most K" promise still holds; some rows can land below K. + +## Annealing schedule (A) + +`annealing_scale ∈ [0, 1]` is the single-knob shortcut. Internally each scale value +maps to a raw temperature via `τ = 25**scale`: + +| scale | τ_start (raw) | Calibration notes | +|---|---|---| +| 0.0 | 1.0 | minimal exploration — constraint nearly hard from step 0 | +| **0.5** | **5.0** | **default**; safe choice — QC stays within ±0.02 of unconstrained baseline across all three paper scenarios | +| 1.0 | 25.0 | max exploration — best for escaping local optima at the cost of slower QC refinement | + +The full default schedule is **geometric** from `τ_start(scale)` to `τ_end = 0.01`. For +finer control, supply `annealing_schedule = {"step": [...], "scale": [...], "annealing_func": [...]}` +— see the kwarg docstring. 
+ +**Calibration source**: reproducible via [`logs/sweep_tau_schedule.py`](../logs/sweep_tau_schedule.py) ++ [`logs/plot_sweep.py`](../logs/plot_sweep.py) — the (scale × schedule × K) sweep on the +inverse-design fine-tuned model that placed the 0.5 default in the safe region. JSON / PNG +outputs are git-ignored; rerun the scripts to regenerate. + +## Tests + +All behaviour is contract-tested in +[`flexible_multi_task_model_test.py`](../src/foundation_model/models/flexible_multi_task_model_test.py). +Search patterns: + +- `test_optimize_composition_max_elements_*` — A's contract (≤ K, annealing, K=n no-op, + locked interaction, validation). +- `test_optimize_composition_fixed_amounts_*` — B's contract (exact pin, no-init mode, + combined with A/C, validation). +- `test_optimize_composition_min_nonzero_weight_*` — C's contract (≥ floor, no-op at 0, + fallback, validation against fixed_amounts / step_scale=0 locks). +- `test_optimize_composition_annealing_*` — annealing knob endpoints + dict override. + +Reference contract: `test_optimize_composition_runs_and_returns_simplex_weights` (rows +sum to 1, non-negative — invariant across every combination of the constraints). + +## End-to-end behavioural evidence (reproducible scripts) + +All evaluation outputs are git-ignored; rerun the scripts below to regenerate them. + +- [`logs/eval_abc_intuition.py`](../logs/eval_abc_intuition.py) + + [`logs/plot_abc_intuition.py`](../logs/plot_abc_intuition.py) — 80+ runs on the + inverse-design fine-tuned model across A/B/C and their combinations × two paper + scenarios. Verifies every contract (≤K, exact pins, ≥ floor) and prints PASS/FAIL + per intuition check. Original run (this PR): all 13 contract checks pass; 11/12 + monotone-intuition checks pass (the one "failure" is a legitimate multi-objective + trade-off, not a bug — FE flat while klat improves with K under fixed-Au+Ga). 
+- [`logs/sweep_tau_schedule.py`](../logs/sweep_tau_schedule.py) — the calibration grid + (τ_start × schedule × K × target-set) used to pick `annealing_scale = 0.5` as the + safe default. +- [`logs/test_max_elements_smoke.py`](../logs/test_max_elements_smoke.py) — minimal + smoke test confirming the byte-identical reproducibility of K=3/K=5 default + (`annealing_scale=0.5` ≡ the previous `τ_start=5.0` calibration). +- [`logs/eval_combined_abc.py`](../logs/eval_combined_abc.py) + + [`logs/plot_combined_abc.py`](../logs/plot_combined_abc.py) — the 9-config combined + evaluation chart (baseline + each of A/B/C alone + every pair + full stack at two + annealing settings). +- `paper_inverse_3scenarios` with `--output-dir + artifacts/inverse_design_run/inverse_design_max_elements/` — the three paper + scenarios rerun with the new A bars added (`paper_inverse_comparison.py` now threads + `max_elements` / `annealing_scale` / `annealing_schedule` from each comp-config + row); existing 5 comp + 3 latent bars are byte-identical to before. diff --git a/logs/eval_abc_intuition.py b/logs/eval_abc_intuition.py new file mode 100644 index 0000000..8ce5680 --- /dev/null +++ b/logs/eval_abc_intuition.py @@ -0,0 +1,360 @@ +"""Behavioral sweep + intuition check for max_elements / fixed_amounts / min_nonzero_weight. + +For each feature alone (1-D sweep) and a few combined cases (pairwise + full stack), this +script: + +1. Runs the optimisation on the inverse-design fine-tuned model. +2. Records achieved targets (FE, Mag, klat), QC probability, non-zero count, sample recipe. +3. Verifies the user-facing contract (e.g., ``∀ row: Au == pinned``; ``nz ≤ K``; + ``every non-zero ≥ floor``). +4. Compares the observed trend against a pre-stated intuition and flags mismatches. + +Output: + * ``logs/eval_abc_intuition.json`` — every config's numbers. + * ``logs/eval_abc_intuition.png`` — 1-D sweeps for A / B / C + combinations. + * stdout — markdown summary with PASS/FAIL on each intuition. 
+ +Two scenarios are exercised: scenario1 (FE↓, Mag↑) and scenario3 (FE↓, klat↑). The 2-target +scenario1 lets us isolate the magnetization channel; scenario3 swaps in klat as a more +"reachable" objective. +""" +from __future__ import annotations + +import json +import time +import tomllib +from pathlib import Path + +import torch +from lightning import seed_everything + +from foundation_model.scripts.continual_rehearsal_demo import ( + QC_CLASSES, + ContinualRehearsalConfig, + ContinualRehearsalRunner, +) +from foundation_model.scripts.eval_inverse_methods import _qc_prob, _seed_weights_from_compositions +from foundation_model.utils.kmd_plus import DEFAULT_ELEMENTS + +REPO = Path(__file__).resolve().parents[1] +CFG_PATH = REPO / "samples/continual_rehearsal_demo_config_inverse_baseline.toml" +CKPT = REPO / "artifacts/inverse_design_run/finetune/final_model.pt" +OUT_JSON = REPO / "logs/eval_abc_intuition.json" + + +def _build(scenario_name: str): + raw = tomllib.loads(CFG_PATH.read_text(encoding="utf-8")) + scenarios = raw.pop("inverse_scenarios", []) + sc = next(s for s in scenarios if s["name"] == scenario_name) + raw["inverse_reg_tasks"] = sc["reg_tasks"] + raw["inverse_reg_targets"] = sc["reg_targets"] + config = ContinualRehearsalConfig(**raw) + seed_everything(config.random_seed, workers=True) + runner = ContinualRehearsalRunner(config) + model = runner._build_full_model() + state = torch.load(CKPT, map_location="cpu", weights_only=True) + state_dict = state["model"] if isinstance(state, dict) and "model" in state else state + model.load_state_dict(state_dict) + model.eval() + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype + kernel = runner._kmd.kernel_torch(device=device, dtype=dtype) + def _qc_fn(x): return _qc_prob(model, x) + seeds = runner._select_seeds(model, device, _qc_fn)[:8] + w_seed = _seed_weights_from_compositions(seeds, n_components=kernel.shape[0]) + return model, kernel, w_seed, dict(sc), seeds + + +def 
_run(model, kernel, w_seed, reg_targets, **kwargs) -> dict: + common = dict( + task_targets=reg_targets, + class_targets={"material_type": QC_CLASSES}, + class_target_weight=5.0, + initial_weights=w_seed, + seed_blend=0.95, + steps=300, + lr=0.05, + ) + common.update(kwargs) + torch.manual_seed(0) + t0 = time.perf_counter() + res = model.optimize_composition(kernel, **common) + elapsed = time.perf_counter() - t0 + w = res.optimized_weights + targets = res.optimized_target.cpu().numpy() + achieved = {t: {"mean": float(targets[:, j].mean()), "std": float(targets[:, j].std())} + for j, t in enumerate(reg_targets)} + qc = float(_qc_prob(model, res.optimized_descriptor).mean()) + nz = (w > 1e-6).sum(dim=-1) + # Row-0 top 5 elements for inspection. + top = sorted( + ((float(w[0, i]), DEFAULT_ELEMENTS[i]) for i in range(kernel.shape[0]) if float(w[0, i]) > 1e-4), + reverse=True, + )[:5] + return { + "elapsed_s": round(elapsed, 3), + "achieved": achieved, + "qc": qc, + "nz_mean": float(nz.float().mean()), + "nz_min": int(nz.min()), + "nz_max": int(nz.max()), + "smallest_nonzero": float(w[w > 1e-6].min()) if (w > 1e-6).any() else 0.0, + "row0_recipe": [(s, round(v, 4)) for v, s in top], + # For contract checks: + "_au_col": [float(w[b, DEFAULT_ELEMENTS.index("Au")]) for b in range(w.shape[0])], + "_ga_col": [float(w[b, DEFAULT_ELEMENTS.index("Ga")]) for b in range(w.shape[0])], + } + + +def main() -> None: + print("=" * 80) + print("Behavioral evaluation: A (max_elements) + B (fixed_amounts) + C (min_nonzero_weight)") + print("=" * 80) + + all_results: dict[str, list[dict]] = {} + + for scenario_name in ["scenario1_fe_down_magnetic_up", "scenario3_fe_down_klat_up"]: + print(f"\n\n### {scenario_name}") + model, kernel, w_seed, sc, _ = _build(scenario_name) + reg_targets = dict(zip(sc["reg_tasks"], sc["reg_targets"])) + bucket: list[dict] = [] + + # === A sweep === + print("\n[A] max_elements sweep") + for K in (None, 2, 3, 4, 5, 6, 8, 10): + out = _run(model, kernel, w_seed, 
reg_targets, **({} if K is None else {"max_elements": K})) + out["experiment"] = "A" + out["K"] = K + out["floor"] = 0.0 + out["n_fixed"] = 0 + bucket.append(out) + + # === B sweep (Ga fixed at 0.20, vary Au) === + print("[B] fixed_amounts sweep (Ga=0.20, Au varies)") + for au in (0.0, 0.30, 0.45, 0.65, 0.75): + if au == 0.0: + fa = {"Ga": 0.20} # Au omitted → only Ga pinned + else: + fa = {"Au": au, "Ga": 0.20} + out = _run(model, kernel, w_seed, reg_targets, fixed_amounts=fa) + out["experiment"] = "B" + out["K"] = None + out["au_fixed"] = au + out["n_fixed"] = len(fa) + bucket.append(out) + + # === C sweep === + print("[C] min_nonzero_weight sweep") + for floor in (0.0, 0.05, 0.10, 0.15, 0.20): + kw = {} if floor == 0.0 else {"min_nonzero_weight": floor} + out = _run(model, kernel, w_seed, reg_targets, **kw) + out["experiment"] = "C" + out["K"] = None + out["floor"] = floor + out["n_fixed"] = 0 + bucket.append(out) + + # === A+B (K varies, Au+Ga fixed) === + print("[A+B] K varies, fixed Au=0.65, Ga=0.20") + for K in (3, 4, 5, 6, 8): + out = _run(model, kernel, w_seed, reg_targets, + max_elements=K, fixed_amounts={"Au": 0.65, "Ga": 0.20}) + out["experiment"] = "A+B" + out["K"] = K + out["n_fixed"] = 2 + bucket.append(out) + + # === A+C (K=5, floor varies) === + print("[A+C] K=5, floor varies") + for floor in (0.0, 0.05, 0.10, 0.15, 0.20): + kw = {"max_elements": 5} + if floor > 0: + kw["min_nonzero_weight"] = floor + out = _run(model, kernel, w_seed, reg_targets, **kw) + out["experiment"] = "A+C" + out["K"] = 5 + out["floor"] = floor + bucket.append(out) + + # === B+C (fix Au=0.30 Ga=0.20, floor varies) === + print("[B+C] fixed Au=0.30, Ga=0.20, floor varies") + for floor in (0.0, 0.05, 0.10, 0.15): + kw = {"fixed_amounts": {"Au": 0.30, "Ga": 0.20}} + if floor > 0: + kw["min_nonzero_weight"] = floor + out = _run(model, kernel, w_seed, reg_targets, **kw) + out["experiment"] = "B+C" + out["K"] = None + out["floor"] = floor + bucket.append(out) + + # === A+B+C full 
stack at 3 annealing settings === + print("[A+B+C] K=4, Au=0.30 Ga=0.20, floor=0.10, annealing varies") + for scale_label, scale in [("scale=0.3", 0.3), ("scale=0.5 default", 0.5), ("scale=0.8", 0.8)]: + out = _run(model, kernel, w_seed, reg_targets, + max_elements=4, + fixed_amounts={"Au": 0.30, "Ga": 0.20}, + min_nonzero_weight=0.10, + annealing_scale=scale) + out["experiment"] = "A+B+C" + out["annealing_scale_label"] = scale_label + out["K"] = 4 + out["n_fixed"] = 2 + out["floor"] = 0.10 + bucket.append(out) + + all_results[scenario_name] = bucket + + OUT_JSON.write_text(json.dumps(all_results, indent=2)) + print(f"\n[saved] {OUT_JSON}") + + # === Intuition checks === + _print_intuition_checks(all_results) + + +def _print_intuition_checks(all_results: dict[str, list[dict]]) -> None: + """Each intuition: a one-line description + observed values + PASS/FAIL. + + A failure here doesn't mean the implementation is broken — it means the model's loss + landscape doesn't satisfy the assumed monotone relationship. Useful regardless. + """ + print("\n" + "=" * 100) + print("INTUITION CHECKS") + print("=" * 100) + + for scenario_name, bucket in all_results.items(): + print(f"\n## {scenario_name}") + sc_short = "FE/Mag" if "magnetic" in scenario_name else "FE/klat" + reg_keys = list(bucket[0]["achieved"].keys()) + + def filter_exp(name: str) -> list[dict]: + return [r for r in bucket if r["experiment"] == name] + + # Check 1: A — nz exactly == K (or ≤ K). 
+ a_rows = filter_exp("A") + constrained = [r for r in a_rows if r["K"] is not None] + nz_eq_K = all(r["nz_max"] <= r["K"] for r in constrained) + print(f"\n [A] nz_max ≤ K for every K-constrained config: " + f"{'PASS' if nz_eq_K else 'FAIL'}") + for r in a_rows: + tag = f"K={r['K']}" if r["K"] is not None else "baseline" + t_str = " ".join(f"{k}={r['achieved'][k]['mean']:+.2f}" for k in reg_keys) + print(f" {tag:<10} nz∈[{r['nz_min']}, {r['nz_max']}] QC={r['qc']:.2f} {t_str}") + + # Check 2: A — primary target (first reg) improves as K grows. + fe_seq = [r["achieved"][reg_keys[0]]["mean"] for r in a_rows if r["K"] is not None] + Ks = [r["K"] for r in a_rows if r["K"] is not None] + # Trend: lower-is-better for FE (target is -2.0). Compare smallest K to largest K. + fe_improves_with_K = fe_seq[0] >= fe_seq[-1] + print(f" [A] {reg_keys[0]} decreases (improves toward target) as K grows (K={Ks[0]}→{Ks[-1]}): " + f"{fe_seq[0]:+.2f} → {fe_seq[-1]:+.2f} " + f"{'PASS' if fe_improves_with_K else 'FAIL'}") + + # Check 3: B — Au and Ga pinned exactly across the batch. 
+ b_rows = filter_exp("B") + all_pinned_ok = True + for r in b_rows: + if r["au_fixed"] > 0: + if not all(abs(v - r["au_fixed"]) < 1e-4 for v in r["_au_col"]): + all_pinned_ok = False + if not all(abs(v - 0.20) < 1e-4 for v in r["_ga_col"]): + all_pinned_ok = False + print(f"\n [B] fixed Au/Ga held exactly across all batch rows: " + f"{'PASS' if all_pinned_ok else 'FAIL'}") + for r in b_rows: + au_lbl = f"Au={r['au_fixed']:.2f}" if r["au_fixed"] > 0 else "Au not fixed" + t_str = " ".join(f"{k}={r['achieved'][k]['mean']:+.2f}" for k in reg_keys) + print(f" {au_lbl:<14} nz∈[{r['nz_min']}, {r['nz_max']}] QC={r['qc']:.2f} {t_str}") + + # Check 4: B — as Au grows from 0.30 to 0.80, less free mass → primary target worsens + b_pinned_rows = [r for r in b_rows if r["au_fixed"] > 0] + fe_seq_b = [r["achieved"][reg_keys[0]]["mean"] for r in b_pinned_rows] + fe_worsens_with_au = fe_seq_b[0] <= fe_seq_b[-1] + print(f" [B] {reg_keys[0]} worsens as Au pin grows (Au=0.30→0.75): " + f"{fe_seq_b[0]:+.2f} → {fe_seq_b[-1]:+.2f} " + f"{'PASS' if fe_worsens_with_au else 'FAIL'}") + + # Check 5: C — every non-zero ≥ floor; smallest_nonzero ≥ floor (within tol). + c_rows = filter_exp("C") + floor_held = all(r["smallest_nonzero"] >= r["floor"] - 1e-5 or r["floor"] == 0 for r in c_rows) + print(f"\n [C] smallest non-zero ≥ floor for every floored config: " + f"{'PASS' if floor_held else 'FAIL'}") + for r in c_rows: + t_str = " ".join(f"{k}={r['achieved'][k]['mean']:+.2f}" for k in reg_keys) + print(f" floor={r['floor']:.2f} nz∈[{r['nz_min']}, {r['nz_max']}] " + f"min_nz={r['smallest_nonzero']:.3f} QC={r['qc']:.2f} {t_str}") + + # Check 6: C — nz decreases monotonically as floor grows. + nz_seq = [r["nz_mean"] for r in c_rows] + nz_monotone = all(nz_seq[i] >= nz_seq[i+1] for i in range(len(nz_seq) - 1)) + print(f" [C] nz_mean decreases monotonically with floor: " + f"{[round(n, 1) for n in nz_seq]} " + f"{'PASS' if nz_monotone else 'FAIL'}") + + # Check 7: A+B — nz exactly = K, Au+Ga held. 
+ ab_rows = filter_exp("A+B") + ab_nz_ok = all(r["nz_max"] <= r["K"] for r in ab_rows) + ab_pin_ok = all( + all(abs(v - 0.65) < 1e-4 for v in r["_au_col"]) + and all(abs(v - 0.20) < 1e-4 for v in r["_ga_col"]) + for r in ab_rows + ) + print(f"\n [A+B] nz ≤ K AND Au/Ga held exactly: " + f"{'PASS' if (ab_nz_ok and ab_pin_ok) else 'FAIL'}") + for r in ab_rows: + t_str = " ".join(f"{k}={r['achieved'][k]['mean']:+.2f}" for k in reg_keys) + print(f" K={r['K']:<2} nz∈[{r['nz_min']}, {r['nz_max']}] QC={r['qc']:.2f} {t_str}") + + # Check 8: A+B — K=3 (only 1 free slot) is worse than K=8 (6 free slots) on primary target. + fe_K3 = next(r["achieved"][reg_keys[0]]["mean"] for r in ab_rows if r["K"] == 3) + fe_K8 = next(r["achieved"][reg_keys[0]]["mean"] for r in ab_rows if r["K"] == 8) + print(f" [A+B] {reg_keys[0]} at K=3 ≥ K=8 (less freedom is worse): " + f"{fe_K3:+.2f} ≥ {fe_K8:+.2f} " + f"{'PASS' if fe_K3 >= fe_K8 else 'FAIL'}") + + # Check 9: A+C — nz_mean ≤ K=5, decreases as floor grows. + ac_rows = filter_exp("A+C") + ac_nz_le_K = all(r["nz_max"] <= 5 for r in ac_rows) + ac_nz_drop = all(ac_rows[i]["nz_mean"] >= ac_rows[i+1]["nz_mean"] for i in range(len(ac_rows) - 1)) + print(f"\n [A+C] nz ≤ K=5 AND nz_mean decreases with floor: " + f"{'PASS' if (ac_nz_le_K and ac_nz_drop) else 'FAIL'}") + for r in ac_rows: + t_str = " ".join(f"{k}={r['achieved'][k]['mean']:+.2f}" for k in reg_keys) + print(f" K=5, floor={r['floor']:.2f} nz∈[{r['nz_min']}, {r['nz_max']}] " + f"min_nz={r['smallest_nonzero']:.3f} QC={r['qc']:.2f} {t_str}") + + # Check 10: B+C — Au=0.30, Ga=0.20 held; floor respected. 
+ bc_rows = filter_exp("B+C") + bc_pin_ok = all( + all(abs(v - 0.30) < 1e-4 for v in r["_au_col"]) + and all(abs(v - 0.20) < 1e-4 for v in r["_ga_col"]) + for r in bc_rows + ) + bc_floor_ok = all(r["smallest_nonzero"] >= r["floor"] - 1e-5 or r["floor"] == 0 for r in bc_rows) + print(f"\n [B+C] fixed values held AND floor respected: " + f"{'PASS' if (bc_pin_ok and bc_floor_ok) else 'FAIL'}") + for r in bc_rows: + t_str = " ".join(f"{k}={r['achieved'][k]['mean']:+.2f}" for k in reg_keys) + print(f" fix Au=.30 Ga=.20, floor={r['floor']:.2f} " + f"nz∈[{r['nz_min']}, {r['nz_max']}] min_nz={r['smallest_nonzero']:.3f} " + f"QC={r['qc']:.2f} {t_str}") + + # Check 11: A+B+C — all three contracts hold simultaneously. + abc_rows = filter_exp("A+B+C") + all_contracts_ok = all( + r["nz_max"] <= r["K"] + and all(abs(v - 0.30) < 1e-4 for v in r["_au_col"]) + and all(abs(v - 0.20) < 1e-4 for v in r["_ga_col"]) + and r["smallest_nonzero"] >= r["floor"] - 1e-5 + for r in abc_rows + ) + print(f"\n [A+B+C] all three contracts hold simultaneously: " + f"{'PASS' if all_contracts_ok else 'FAIL'}") + for r in abc_rows: + t_str = " ".join(f"{k}={r['achieved'][k]['mean']:+.2f}" for k in reg_keys) + print(f" {r['annealing_scale_label']:<20} nz∈[{r['nz_min']}, {r['nz_max']}] " + f"min_nz={r['smallest_nonzero']:.3f} QC={r['qc']:.2f} {t_str}") + + +if __name__ == "__main__": + main() diff --git a/logs/eval_combined_abc.py b/logs/eval_combined_abc.py new file mode 100644 index 0000000..0e40b53 --- /dev/null +++ b/logs/eval_combined_abc.py @@ -0,0 +1,203 @@ +"""Comprehensive evaluation of A (max_elements) + B (fixed_amounts) + C (min_nonzero_weight). + +Exercises each feature in isolation and combined, on the inverse-design fine-tuned model. +Verifies that: + 1. Each constraint enforces its contract (≤ K non-zero; fixed values held; floor respected). + 2. Combinations compose cleanly (the simplex invariant is preserved everywhere). + 3. 
The annealing_scale knob still works on top of the full feature stack. + +Reports per-config: achieved targets (FE, Mag), QC, non-zero count, row-0 recipe. +""" + +from __future__ import annotations + +import tomllib +from pathlib import Path + +import torch +from lightning import seed_everything + +from foundation_model.scripts.continual_rehearsal_demo import ( + QC_CLASSES, + ContinualRehearsalConfig, + ContinualRehearsalRunner, +) +from foundation_model.scripts.eval_inverse_methods import _qc_prob, _seed_weights_from_compositions +from foundation_model.utils.kmd_plus import DEFAULT_ELEMENTS + +REPO = Path(__file__).resolve().parents[1] +CFG_PATH = REPO / "samples/continual_rehearsal_demo_config_inverse_baseline.toml" +CKPT = REPO / "artifacts/inverse_design_run/finetune/final_model.pt" +SCENARIO = "scenario1_fe_down_magnetic_up" + + +def _build(): + raw = tomllib.loads(CFG_PATH.read_text(encoding="utf-8")) + scenarios = raw.pop("inverse_scenarios", []) + sc = next(s for s in scenarios if s["name"] == SCENARIO) + raw["inverse_reg_tasks"] = sc["reg_tasks"] + raw["inverse_reg_targets"] = sc["reg_targets"] + config = ContinualRehearsalConfig(**raw) + seed_everything(config.random_seed, workers=True) + runner = ContinualRehearsalRunner(config) + model = runner._build_full_model() + state = torch.load(CKPT, map_location="cpu", weights_only=True) + state_dict = state["model"] if isinstance(state, dict) and "model" in state else state + model.load_state_dict(state_dict) + model.eval() + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype + kernel = runner._kmd.kernel_torch(device=device, dtype=dtype) + return runner, model, kernel, device + + +def _summarise(label: str, res, *, expected=None, floor=None): + """Print one config's results + verify the user-facing contract.""" + w = res.optimized_weights + B = w.shape[0] + nz_mask = w > 1e-6 + nz_per_row = nz_mask.sum(dim=-1).tolist() + row_sums = w.sum(dim=-1) + targets = 
res.optimized_target.cpu().numpy() + print(f"\n [{label}]") + print(f" elapsed: {res.optimized_weights.shape[0]} rows; nz per row: {nz_per_row}") + print(f" row sums (should be 1.0): min={row_sums.min():.5f}, max={row_sums.max():.5f}") + print(f" FE = {targets[:, 0].mean():+.3f} ± {targets[:, 0].std():.3f} (target -2.0)") + print(f" Mag = {targets[:, 1].mean():+.3f} ± {targets[:, 1].std():.3f} (target +2.0)") + # Row 0 recipe. + w0 = w[0].cpu().numpy() + top = sorted(((float(w0[i]), DEFAULT_ELEMENTS[i]) for i in range(len(w0)) if float(w0[i]) > 1e-4), reverse=True) + print(" row 0 recipe: " + ", ".join(f"{s}={v:.3f}" for v, s in top)) + # Contract checks. + if expected is not None: + for sym, want in expected.items(): + idx = DEFAULT_ELEMENTS.index(sym) + got = w[:, idx].tolist() + ok = all(abs(g - want) < 1e-4 for g in got) + mark = "✓" if ok else "✗" + print(f" [{mark}] {sym} pinned at {want}: got {[round(g, 4) for g in got]}") + if floor is not None: + # Every non-zero unlocked element should be ≥ floor. 
+ violated = (w > 1e-6) & (w < floor - 1e-5) + n_violations = int(violated.sum().item()) + smallest_nz = w[w > 1e-6].min().item() if (w > 1e-6).any() else 0.0 + mark = "✓" if n_violations == 0 else f"✗ ({n_violations} positions)" + print(f" [{mark}] every non-zero ≥ {floor}: smallest non-zero = {smallest_nz:.4f}") + + +def main() -> None: + print(f"[loading] {CKPT}") + runner, model, kernel, device = _build() + + def _qc_fn(x): + return _qc_prob(model, x) + + seeds = runner._select_seeds(model, device, _qc_fn)[:8] + w_seed = _seed_weights_from_compositions(seeds, n_components=kernel.shape[0]) + common = dict( + task_targets={"formation_energy": -2.0, "magnetization": 2.0}, + class_targets={"material_type": QC_CLASSES}, + class_target_weight=5.0, + initial_weights=w_seed, + seed_blend=0.95, + steps=300, + lr=0.05, + ) + + print("\n=== individual features ===") + torch.manual_seed(0) + _summarise("baseline (no constraint)", model.optimize_composition(kernel, **common)) + + torch.manual_seed(0) + _summarise("A only — max_elements=3", model.optimize_composition(kernel, max_elements=3, **common)) + + torch.manual_seed(0) + _summarise( + "B only — fixed Au=0.65, Ga=0.20", + model.optimize_composition(kernel, fixed_amounts={"Au": 0.65, "Ga": 0.20}, **common), + expected={"Au": 0.65, "Ga": 0.20}, + ) + + torch.manual_seed(0) + _summarise( + "C only — floor=0.10", + model.optimize_composition(kernel, min_nonzero_weight=0.10, **common), + floor=0.10, + ) + + print("\n=== pairwise combinations ===") + torch.manual_seed(0) + _summarise( + "A + B — K=4, fixed Au=0.65 Ga=0.20", + model.optimize_composition(kernel, max_elements=4, fixed_amounts={"Au": 0.65, "Ga": 0.20}, **common), + expected={"Au": 0.65, "Ga": 0.20}, + ) + + torch.manual_seed(0) + _summarise( + "A + C — K=5, floor=0.10", + model.optimize_composition(kernel, max_elements=5, min_nonzero_weight=0.10, **common), + floor=0.10, + ) + + torch.manual_seed(0) + _summarise( + "B + C — fixed Au=0.30 Ga=0.20, floor=0.10", + 
model.optimize_composition(kernel, fixed_amounts={"Au": 0.30, "Ga": 0.20}, min_nonzero_weight=0.10, **common), + expected={"Au": 0.30, "Ga": 0.20}, + floor=0.10, + ) + + print("\n=== full stack A+B+C ===") + torch.manual_seed(0) + _summarise( + "A + B + C — K=4, fixed Au=0.30 Ga=0.20, floor=0.10 (default annealing)", + model.optimize_composition( + kernel, + max_elements=4, + fixed_amounts={"Au": 0.30, "Ga": 0.20}, + min_nonzero_weight=0.10, + **common, + ), + expected={"Au": 0.30, "Ga": 0.20}, + floor=0.10, + ) + + torch.manual_seed(0) + _summarise( + "A + B + C — same + annealing_scale=0.8 (more exploration)", + model.optimize_composition( + kernel, + max_elements=4, + fixed_amounts={"Au": 0.30, "Ga": 0.20}, + min_nonzero_weight=0.10, + annealing_scale=0.8, + **common, + ), + expected={"Au": 0.30, "Ga": 0.20}, + floor=0.10, + ) + + torch.manual_seed(0) + _summarise( + "A + B + C — same + advanced schedule (warm-up then linear)", + model.optimize_composition( + kernel, + max_elements=4, + fixed_amounts={"Au": 0.30, "Ga": 0.20}, + min_nonzero_weight=0.10, + annealing_scale=0.5, + annealing_schedule={ + "step": [0.2, 0.7, 1.0], + "scale": [0.9, 0.5, 0.0], + "annealing_func": ["constant", "linear", "linear"], + }, + **common, + ), + expected={"Au": 0.30, "Ga": 0.20}, + floor=0.10, + ) + + +if __name__ == "__main__": + main() diff --git a/logs/plot_abc_intuition.py b/logs/plot_abc_intuition.py new file mode 100644 index 0000000..b88a4df --- /dev/null +++ b/logs/plot_abc_intuition.py @@ -0,0 +1,120 @@ +"""Multi-panel sweep plot from eval_abc_intuition.json.""" +from __future__ import annotations + +import json +from pathlib import Path + +import matplotlib.pyplot as plt + +REPO = Path(__file__).resolve().parents[1] +JSON_PATH = REPO / "logs/eval_abc_intuition.json" +OUT_PNG = REPO / "logs/eval_abc_intuition.png" + +SCENARIOS = ("scenario1_fe_down_magnetic_up", "scenario3_fe_down_klat_up") +PRETTY = { + "scenario1_fe_down_magnetic_up": "Scenario 1 — FE↓, Mag↑", + 
"scenario3_fe_down_klat_up": "Scenario 3 — FE↓, klat↑", +} + + +def main() -> None: + data = json.loads(JSON_PATH.read_text()) + + # Two rows (scenarios) × five columns (one per experiment: A, B, C, A+B / A+C / B+C / A+B+C) + # We'll do: row=scenario, columns = A | B | C | A+C | B+C + fig, axes = plt.subplots(2, 5, figsize=(22, 8), squeeze=False) + + for r, scen in enumerate(SCENARIOS): + bucket = data[scen] + reg_keys = list(bucket[0]["achieved"].keys()) + primary = reg_keys[0] # "formation_energy" + secondary = reg_keys[1] # mag or klat + + # --- Column 0: A sweep (K vs targets, QC, nz) --- + a_rows = [r for r in bucket if r["experiment"] == "A" and r["K"] is not None] + Ks = [r["K"] for r in a_rows] + ax = axes[r, 0] + ax.plot(Ks, [r["achieved"][primary]["mean"] for r in a_rows], "o-", label=primary, color="#2563EB") + ax.plot(Ks, [r["achieved"][secondary]["mean"] for r in a_rows], "s-", label=secondary, color="#E0762A") + ax.plot(Ks, [r["qc"] for r in a_rows], "^-", label="QC", color="#55A868") + # baseline reference lines + base = next(r for r in bucket if r["experiment"] == "A" and r["K"] is None) + for v, c, ls in [(base["achieved"][primary]["mean"], "#2563EB", "--"), + (base["achieved"][secondary]["mean"], "#E0762A", "--"), + (base["qc"], "#55A868", "--")]: + ax.axhline(v, color=c, linestyle=ls, alpha=0.4, linewidth=0.8) + ax.set_xlabel("max_elements (K)") + ax.set_title(f"{PRETTY[scen]}\nA — vary K") + ax.legend(fontsize=8, loc="best") + ax.grid(True, alpha=0.3) + + # --- Column 1: B sweep (Au pin vs targets, QC) --- + b_rows = [r for r in bucket if r["experiment"] == "B"] + aus = [r["au_fixed"] for r in b_rows] + ax = axes[r, 1] + ax.plot(aus, [r["achieved"][primary]["mean"] for r in b_rows], "o-", label=primary, color="#2563EB") + ax.plot(aus, [r["achieved"][secondary]["mean"] for r in b_rows], "s-", label=secondary, color="#E0762A") + ax.plot(aus, [r["qc"] for r in b_rows], "^-", label="QC", color="#55A868") + ax.set_xlabel("fixed Au amount (Ga=0.20)") 
+ ax.set_title("B — vary Au pin") + ax.legend(fontsize=8, loc="best") + ax.grid(True, alpha=0.3) + + # --- Column 2: C sweep (floor vs targets, nz_mean) --- + c_rows = [r for r in bucket if r["experiment"] == "C"] + floors = [r["floor"] for r in c_rows] + ax = axes[r, 2] + ax.plot(floors, [r["achieved"][primary]["mean"] for r in c_rows], "o-", label=primary, color="#2563EB") + ax.plot(floors, [r["achieved"][secondary]["mean"] for r in c_rows], "s-", label=secondary, color="#E0762A") + ax.plot(floors, [r["qc"] for r in c_rows], "^-", label="QC", color="#55A868") + ax_nz = ax.twinx() + ax_nz.plot(floors, [r["nz_mean"] for r in c_rows], "d:", label="nz_mean", color="#888") + ax_nz.set_ylabel("mean nz", color="#888") + ax_nz.set_yscale("symlog") + ax.set_xlabel("min_nonzero_weight (floor)") + ax.set_title("C — vary floor") + ax.legend(fontsize=8, loc="best") + ax.grid(True, alpha=0.3) + + # --- Column 3: A+C (K=5, floor sweep) --- + ac_rows = [r for r in bucket if r["experiment"] == "A+C"] + floors = [r["floor"] for r in ac_rows] + ax = axes[r, 3] + ax.plot(floors, [r["achieved"][primary]["mean"] for r in ac_rows], "o-", label=primary, color="#2563EB") + ax.plot(floors, [r["achieved"][secondary]["mean"] for r in ac_rows], "s-", label=secondary, color="#E0762A") + ax.plot(floors, [r["qc"] for r in ac_rows], "^-", label="QC", color="#55A868") + ax_nz = ax.twinx() + ax_nz.plot(floors, [r["nz_mean"] for r in ac_rows], "d:", color="#888") + ax_nz.set_ylabel("mean nz", color="#888") + ax.set_xlabel("floor (K=5 fixed)") + ax.set_title("A+C — K=5, vary floor") + ax.legend(fontsize=8, loc="best") + ax.grid(True, alpha=0.3) + + # --- Column 4: B+C (fix Au=0.30 Ga=0.20, floor sweep) --- + bc_rows = [r for r in bucket if r["experiment"] == "B+C"] + floors = [r["floor"] for r in bc_rows] + ax = axes[r, 4] + ax.plot(floors, [r["achieved"][primary]["mean"] for r in bc_rows], "o-", label=primary, color="#2563EB") + ax.plot(floors, [r["achieved"][secondary]["mean"] for r in bc_rows], 
"s-", label=secondary, color="#E0762A") + ax.plot(floors, [r["qc"] for r in bc_rows], "^-", label="QC", color="#55A868") + ax_nz = ax.twinx() + ax_nz.plot(floors, [r["nz_mean"] for r in bc_rows], "d:", color="#888") + ax_nz.set_ylabel("mean nz", color="#888") + ax.set_xlabel("floor (fix Au=0.30 Ga=0.20)") + ax.set_title("B+C — fix + vary floor") + ax.legend(fontsize=8, loc="best") + ax.grid(True, alpha=0.3) + + fig.suptitle( + "Behavioural sweeps for A (max_elements) · B (fixed_amounts) · C (min_nonzero_weight)\n" + "dashed horizontal lines in A panels = unconstrained baseline; lower FE / higher Mag/klat / higher QC is better", + fontsize=12, + ) + plt.tight_layout() + plt.savefig(OUT_PNG, dpi=110, bbox_inches="tight") + print(f"saved: {OUT_PNG}") + + +if __name__ == "__main__": + main() diff --git a/logs/plot_combined_abc.py b/logs/plot_combined_abc.py new file mode 100644 index 0000000..2037f36 --- /dev/null +++ b/logs/plot_combined_abc.py @@ -0,0 +1,159 @@ +"""Visualise the A+B+C comprehensive evaluation as a comparison bar chart.""" +from __future__ import annotations + +import json +import tomllib +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import torch +from lightning import seed_everything + +from foundation_model.scripts.continual_rehearsal_demo import ( + QC_CLASSES, + ContinualRehearsalConfig, + ContinualRehearsalRunner, +) +from foundation_model.scripts.eval_inverse_methods import _qc_prob, _seed_weights_from_compositions + +REPO = Path(__file__).resolve().parents[1] +CFG_PATH = REPO / "samples/continual_rehearsal_demo_config_inverse_baseline.toml" +CKPT = REPO / "artifacts/inverse_design_run/finetune/final_model.pt" +SCENARIO = "scenario1_fe_down_magnetic_up" +OUT_PNG = REPO / "logs/combined_abc_comparison.png" + + +def _build(): + raw = tomllib.loads(CFG_PATH.read_text(encoding="utf-8")) + scenarios = raw.pop("inverse_scenarios", []) + sc = next(s for s in scenarios if s["name"] == SCENARIO) + 
raw["inverse_reg_tasks"] = sc["reg_tasks"] + raw["inverse_reg_targets"] = sc["reg_targets"] + config = ContinualRehearsalConfig(**raw) + seed_everything(config.random_seed, workers=True) + runner = ContinualRehearsalRunner(config) + model = runner._build_full_model() + state = torch.load(CKPT, map_location="cpu", weights_only=True) + state_dict = state["model"] if isinstance(state, dict) and "model" in state else state + model.load_state_dict(state_dict) + model.eval() + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype + kernel = runner._kmd.kernel_torch(device=device, dtype=dtype) + return runner, model, kernel, device + + +def main() -> None: + runner, model, kernel, device = _build() + def _qc_fn(x): return _qc_prob(model, x) + seeds = runner._select_seeds(model, device, _qc_fn)[:8] + w_seed = _seed_weights_from_compositions(seeds, n_components=kernel.shape[0]) + common = dict( + task_targets={"formation_energy": -2.0, "magnetization": 2.0}, + class_targets={"material_type": QC_CLASSES}, + class_target_weight=5.0, + initial_weights=w_seed, + seed_blend=0.95, + steps=300, + lr=0.05, + ) + + configs = [ + ("baseline", {}), + ("A: K=3", {"max_elements": 3}), + ("B: fix Au=.65 Ga=.20", {"fixed_amounts": {"Au": 0.65, "Ga": 0.20}}), + ("C: floor=.10", {"min_nonzero_weight": 0.10}), + ("A+B: K=4 + fix", {"max_elements": 4, "fixed_amounts": {"Au": 0.65, "Ga": 0.20}}), + ("A+C: K=5 + floor", {"max_elements": 5, "min_nonzero_weight": 0.10}), + ("B+C: fix Au=.30 Ga=.20 + floor=.10", + {"fixed_amounts": {"Au": 0.30, "Ga": 0.20}, "min_nonzero_weight": 0.10}), + ("A+B+C: K=4 + fix + floor", + {"max_elements": 4, "fixed_amounts": {"Au": 0.30, "Ga": 0.20}, "min_nonzero_weight": 0.10}), + ("A+B+C, scale=0.8", + {"max_elements": 4, "fixed_amounts": {"Au": 0.30, "Ga": 0.20}, + "min_nonzero_weight": 0.10, "annealing_scale": 0.8}), + ] + + results = [] + for label, extras in configs: + torch.manual_seed(0) + res = model.optimize_composition(kernel, 
**common, **extras) + w = res.optimized_weights + results.append({ + "label": label, + "fe_mean": float(res.optimized_target[:, 0].mean()), + "fe_std": float(res.optimized_target[:, 0].std()), + "mag_mean": float(res.optimized_target[:, 1].mean()), + "mag_std": float(res.optimized_target[:, 1].std()), + "qc": float(_qc_prob(model, res.optimized_descriptor).mean()), + "nz_mean": float((w > 1e-6).sum(dim=-1).float().mean()), + }) + + # Plot + fig, axes = plt.subplots(1, 4, figsize=(22, 5.5), squeeze=False) + axes = axes[0] + labels = [r["label"] for r in results] + colors = ["#888"] + ["#2563EB"] * (len(results) - 1) + x = np.arange(len(results)) + + # Panel 1: FE + ax = axes[0] + fe_means = [r["fe_mean"] for r in results] + fe_stds = [r["fe_std"] for r in results] + ax.bar(x, fe_means, yerr=fe_stds, color=colors, capsize=4) + ax.axhline(-2.0, color="red", linestyle="--", label="target -2.0") + ax.set_xticks(x) + ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=8) + ax.set_ylabel("formation_energy") + ax.set_title("Formation energy (↓)") + ax.legend() + + # Panel 2: Mag + ax = axes[1] + mag_means = [r["mag_mean"] for r in results] + mag_stds = [r["mag_std"] for r in results] + ax.bar(x, mag_means, yerr=mag_stds, color=colors, capsize=4) + ax.axhline(2.0, color="red", linestyle="--", label="target +2.0") + ax.set_xticks(x) + ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=8) + ax.set_ylabel("magnetization") + ax.set_title("Magnetization (↑)") + ax.legend() + + # Panel 3: QC + ax = axes[2] + qcs = [r["qc"] for r in results] + ax.bar(x, qcs, color=colors) + ax.axhline(1.0, color="red", linestyle="--", label="target 1.0") + ax.set_xticks(x) + ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=8) + ax.set_ylabel("P(material_type ∈ QC)") + ax.set_ylim(0, 1) + ax.set_title("QC probability (↑)") + ax.legend() + + # Panel 4: nz + ax = axes[3] + nzs = [r["nz_mean"] for r in results] + ax.bar(x, nzs, color=colors) + ax.set_xticks(x) + 
ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=8) + ax.set_ylabel("mean # non-zero elements") + ax.set_title("Recipe complexity (lower = simpler)") + ax.set_yscale("symlog") + + fig.suptitle( + "A (max_elements) + B (fixed_amounts) + C (min_nonzero_weight) — combined evaluation\n" + "scenario1 (FE↓, Mag↑) · 8 seeded starts · 300 steps", + fontsize=12, + ) + plt.tight_layout() + plt.savefig(OUT_PNG, dpi=110, bbox_inches="tight") + print(f"saved: {OUT_PNG}") + # also dump JSON for replay + (REPO / "logs/combined_abc_results.json").write_text(json.dumps(results, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/logs/plot_sweep.py b/logs/plot_sweep.py new file mode 100644 index 0000000..b15e2fa --- /dev/null +++ b/logs/plot_sweep.py @@ -0,0 +1,106 @@ +"""Visualise sweep_tau_schedule_results.json as heatmaps + scatter.""" +from __future__ import annotations + +import json +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + +REPO = Path(__file__).resolve().parents[1] +RESULTS = REPO / "logs/sweep_tau_schedule_results.json" +OUT_PNG = REPO / "logs/sweep_tau_schedule.png" + +TAU_STARTS = [1.0, 2.0, 5.0, 10.0, 20.0] +SCHEDULES = ["geometric", "linear", "cosine"] + + +def _target_distance(achieved: dict, targets: dict) -> float: + """Mean |achieved - target| across regression objectives. 
Lower is better.""" + return float(np.mean([abs(achieved[t]["mean"] - v) for t, v in targets.items()])) + + +def main() -> None: + data = json.loads(RESULTS.read_text()) + target_sets = { + "2T": {"formation_energy": -2.0, "magnetization": 2.0}, + "1T": {"magnetization": 2.0}, + } + Ks = [3, 5] + + fig, axes = plt.subplots(2, 4, figsize=(20, 10), squeeze=False) + + # Top row: target-distance heatmaps (lower is better, blue = good) + # Bottom row: QC-probability heatmaps (higher is better, blue = good) + for col, (tset_name, tgt) in enumerate(target_sets.items()): + for k_idx, K in enumerate(Ks): + col_idx = col * 2 + k_idx + # Build (5, 3) grid: τ_start × schedule + dist_grid = np.full((len(TAU_STARTS), len(SCHEDULES)), np.nan) + qc_grid = np.full((len(TAU_STARTS), len(SCHEDULES)), np.nan) + for r in data: + if r["target_set"] != tset_name or r["K"] != K: + continue + if r["schedule"] not in SCHEDULES or r["tau_start"] not in TAU_STARTS: + continue + i = TAU_STARTS.index(r["tau_start"]) + j = SCHEDULES.index(r["schedule"]) + dist_grid[i, j] = _target_distance(r["achieved"], tgt) + qc_grid[i, j] = r["qc_after"] + + # Get baselines for annotation + base = next(r for r in data if r["target_set"] == tset_name and r["schedule"] == "baseline") + base_dist = _target_distance(base["achieved"], tgt) + base_qc = base["qc_after"] + const_hi = next(r for r in data if r["target_set"] == tset_name + and r["K"] == K and r["schedule"] == "const_t1.0") + post_dist = _target_distance(const_hi["achieved"], tgt) + post_qc = const_hi["qc_after"] + + # Top row: target distance + ax = axes[0, col_idx] + im = ax.imshow(dist_grid, aspect="auto", cmap="RdYlBu_r", origin="lower") + ax.set_xticks(range(len(SCHEDULES))) + ax.set_xticklabels(SCHEDULES) + ax.set_yticks(range(len(TAU_STARTS))) + ax.set_yticklabels([str(t) for t in TAU_STARTS]) + ax.set_xlabel("schedule") + ax.set_ylabel("τ_start") + ax.set_title(f"{tset_name} K={K}\ntarget-dist (baseline={base_dist:.2f}, 
post-hoc={post_dist:.2f})") + for i in range(len(TAU_STARTS)): + for j in range(len(SCHEDULES)): + val = dist_grid[i, j] + ax.text(j, i, f"{val:.2f}", ha="center", va="center", + color="white" if val > np.nanmedian(dist_grid) else "black", fontsize=9) + plt.colorbar(im, ax=ax, fraction=0.046) + + # Bottom row: QC + ax = axes[1, col_idx] + im = ax.imshow(qc_grid, aspect="auto", cmap="RdYlBu", origin="lower", vmin=0.5, vmax=1.0) + ax.set_xticks(range(len(SCHEDULES))) + ax.set_xticklabels(SCHEDULES) + ax.set_yticks(range(len(TAU_STARTS))) + ax.set_yticklabels([str(t) for t in TAU_STARTS]) + ax.set_xlabel("schedule") + ax.set_ylabel("τ_start") + ax.set_title(f"{tset_name} K={K}\nQC prob (baseline={base_qc:.2f}, post-hoc={post_qc:.2f})") + for i in range(len(TAU_STARTS)): + for j in range(len(SCHEDULES)): + val = qc_grid[i, j] + ax.text(j, i, f"{val:.2f}", ha="center", va="center", + color="white" if val < 0.75 else "black", fontsize=9) + plt.colorbar(im, ax=ax, fraction=0.046) + + fig.suptitle( + "max_elements sweep: τ_start × schedule × K × target_set\n" + "top row: |achieved − target| averaged across regression objectives (LOWER is better, blue)\n" + "bottom row: P(material_type ∈ QC_classes) (HIGHER is better, blue)", + fontsize=12, + ) + plt.tight_layout() + plt.savefig(OUT_PNG, dpi=110, bbox_inches="tight") + print(f"saved: {OUT_PNG}") + + +if __name__ == "__main__": + main() diff --git a/logs/sweep_tau_schedule.py b/logs/sweep_tau_schedule.py new file mode 100644 index 0000000..3396b93 --- /dev/null +++ b/logs/sweep_tau_schedule.py @@ -0,0 +1,282 @@ +"""Systematic sweep over topk_tau_start × topk_schedule × K × target-set. + +Goal: find a sweet spot for the annealing knobs by comparing achieved targets vs the +unconstrained baseline and vs the "constant τ" controls (which approximate post-hoc +projection at high τ and hard-from-start at low τ). 
+ +Output: + logs/sweep_tau_schedule_results.json — all per-config results + stdout: pivot tables for quick reading + +We test: + - **target_set**: '2T' (FE + mag, 3 objectives with QC) vs '1T' (mag only, 2 objectives with QC) + Rationale: 3 targets may be over-constrained; user asked to retry without FE if needed. + - **K**: 3 (the user's main case) and 5 (a looser constraint for comparison) + - **τ_start**: 1.0, 2.0, 5.0, 10.0, 20.0 + - **schedule**: geometric, linear, cosine + - Controls per (target_set, K): + * no-constraint baseline + * constant τ=1.0 (effectively post-hoc projection) + * constant τ=0.01 (effectively hard-from-start) + +For each config we report: + - achieved target means/stds across the n_starts=8 batch + - QC probability (P(material_type ∈ QC_CLASSES)) post-optimisation + - the most common K-element recipe (intersection of top-K per row) + +Run from repo root: + uv run python logs/sweep_tau_schedule.py +""" + +from __future__ import annotations + +import itertools +import json +import time +import tomllib +from collections import Counter +from pathlib import Path + +import torch +from lightning import seed_everything + +from foundation_model.scripts.continual_rehearsal_demo import ( + QC_CLASSES, + ContinualRehearsalConfig, + ContinualRehearsalRunner, +) +from foundation_model.scripts.eval_inverse_methods import _qc_prob, _seed_weights_from_compositions +from foundation_model.utils.kmd_plus import DEFAULT_ELEMENTS + +REPO = Path(__file__).resolve().parents[1] +CFG_PATH = REPO / "samples/continual_rehearsal_demo_config_inverse_baseline.toml" +CKPT = REPO / "artifacts/inverse_design_run/finetune/final_model.pt" +SCENARIO = "scenario1_fe_down_magnetic_up" +OUT_JSON = REPO / "logs/sweep_tau_schedule_results.json" + +# Per-objective targets used in the scenario; 2T uses both, 1T drops FE. 
+TARGETS_FULL = {"formation_energy": -2.0, "magnetization": 2.0} + + +def _build(): + raw = tomllib.loads(CFG_PATH.read_text(encoding="utf-8")) + scenarios = raw.pop("inverse_scenarios", []) + sc = next(s for s in scenarios if s["name"] == SCENARIO) + raw["inverse_reg_tasks"] = sc["reg_tasks"] + raw["inverse_reg_targets"] = sc["reg_targets"] + config = ContinualRehearsalConfig(**raw) + seed_everything(config.random_seed, workers=True) + runner = ContinualRehearsalRunner(config) + model = runner._build_full_model() + state = torch.load(CKPT, map_location="cpu", weights_only=True) + state_dict = state["model"] if isinstance(state, dict) and "model" in state else state + model.load_state_dict(state_dict) + model.eval() + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype + kernel = runner._kmd.kernel_torch(device=device, dtype=dtype) + return runner, model, kernel, device + + +def _qc_prob_mean(model, x_desc: torch.Tensor) -> float: + """Mean over batch of P(class ∈ QC_CLASSES).""" + p = _qc_prob(model, x_desc) # (B,) per-row QC probability + return float(p.mean()) + + +def _tau_to_scale(tau: float) -> float: + """Convert a raw τ to the normalised ``annealing_scale`` (inverse of ``τ = 25**scale``). + + Clamped to [0, 1]: τ values outside [1, 25] saturate at the endpoints. + """ + import math + + if tau <= 1.0: + return 0.0 + if tau >= 25.0: + return 1.0 + return math.log(tau) / math.log(25.0) + + +def _evaluate(model, kernel, w_seed, *, max_elements, tau_start, tau_end, schedule, targets, seed=0, steps=300): + """Adapter from the old (τ_start, τ_end, schedule) triple to the new annealing API. + + - schedule="constant": single-segment dict that holds the start scale for the full run. + - schedule="geometric" with default τ_end=0.01: just set ``annealing_scale`` (default geometric + from ``25**scale`` down to 0.01). + - schedule="linear"/"cosine": single-segment dict that interpolates scale_start → 0 (i.e. 
+ τ → 1.0) with the named func; the default tail from step=1.0 → ... is absent (we cover the + full optimisation). Note the linear/cosine sweep cannot reach τ=0.01 inside the loop + without the geometric tail — the final hard projection still gives K-hot output. + """ + torch.manual_seed(seed) + t0 = time.perf_counter() + kwargs = dict( + task_targets=targets, + class_targets={"material_type": QC_CLASSES}, + class_target_weight=5.0, + initial_weights=w_seed, + seed_blend=0.95, + steps=steps, + lr=0.05, + ) + if max_elements is not None: + kwargs["max_elements"] = max_elements + scale_start = _tau_to_scale(tau_start) + if schedule == "geometric": + # Default schedule; just supply annealing_scale (covers τ_start → 0.01 geometrically). + kwargs["annealing_scale"] = scale_start + elif schedule == "constant": + kwargs["annealing_scale"] = scale_start + kwargs["annealing_schedule"] = { + "step": [1.0], + "scale": [scale_start], + "annealing_func": ["constant"], + } + elif schedule in ("linear", "cosine"): + # Single-segment schedule interpolating in scale space from scale_start down to 0 + # (τ from 25**scale_start down to 1.0). The hard projection at the end still + # cleans up to K-hot regardless of where the in-loop τ settles. 
+ kwargs["annealing_scale"] = scale_start + kwargs["annealing_schedule"] = { + "step": [1.0], + "scale": [0.0], + "annealing_func": [schedule], + } + else: + raise ValueError(f"unknown schedule {schedule!r}") + res = model.optimize_composition(kernel, **kwargs) + elapsed = time.perf_counter() - t0 + w = res.optimized_weights + nz = (w > 1e-6).sum(dim=-1).tolist() + targets_arr = res.optimized_target.cpu().numpy() + achieved = { + t: {"mean": float(targets_arr[:, j].mean()), "std": float(targets_arr[:, j].std())} + for j, t in enumerate(targets.keys()) + } + # QC probability after decode + qc_after = _qc_prob_mean(model, res.optimized_descriptor) + # Per-row top-K recipe → element frequency across batch + elem_counter: Counter[str] = Counter() + K_used = max_elements if max_elements is not None else 5 + for b in range(w.shape[0]): + top_idx = w[b].argsort(descending=True)[:K_used] + for i in top_idx: + if float(w[b, i]) > 1e-4: + elem_counter[DEFAULT_ELEMENTS[int(i)]] += 1 + return { + "elapsed_s": round(elapsed, 3), + "nz_per_row": nz, + "achieved": achieved, + "qc_after": qc_after, + "top_elements": elem_counter.most_common(8), + } + + +def main() -> None: + print(f"[loading] {CKPT}") + runner, model, kernel, device = _build() + + def _qc_fn(x): + return _qc_prob(model, x) + + seeds = runner._select_seeds(model, device, _qc_fn)[:8] + print(f"[seeds] {seeds}") + w_seed = _seed_weights_from_compositions(seeds, n_components=kernel.shape[0]) + + target_sets = { + "2T": TARGETS_FULL, # FE + mag + "1T": {"magnetization": 2.0}, # mag only (drop FE) + } + Ks = [3, 5] + tau_starts = [1.0, 2.0, 5.0, 10.0, 20.0] + schedules = ["geometric", "linear", "cosine"] + tau_end = 0.01 + + results: list[dict] = [] + n_total = 0 + for tset_name in target_sets: + # Baseline (no constraint) + n_total += 1 + n_total += len(Ks) * 2 # constant τ=1.0 and τ=0.01 per K + n_total += len(Ks) * len(tau_starts) * len(schedules) + print(f"[plan] {n_total} configs total") + + counter = 0 + for 
tset_name, tgt in target_sets.items(): + # No-constraint baseline + counter += 1 + print(f"\n[{counter}/{n_total}] {tset_name} baseline (no max_elements)") + out = _evaluate( + model, kernel, w_seed, max_elements=None, tau_start=None, tau_end=None, schedule=None, targets=tgt + ) + results.append({"target_set": tset_name, "K": None, "tau_start": None, "schedule": "baseline", **out}) + + for K in Ks: + # Controls: constant τ=1.0 (≈post-hoc) and τ=0.01 (hard from start) + for ctrl_name, ctrl_tau in [("const_t1.0", 1.0), ("const_t0.01", 0.01)]: + counter += 1 + print(f"[{counter}/{n_total}] {tset_name} K={K} {ctrl_name}") + out = _evaluate( + model, + kernel, + w_seed, + max_elements=K, + tau_start=ctrl_tau, + tau_end=ctrl_tau, + schedule="constant", + targets=tgt, + ) + results.append({"target_set": tset_name, "K": K, "tau_start": ctrl_tau, "schedule": ctrl_name, **out}) + + # The sweep + for tau_start, sched in itertools.product(tau_starts, schedules): + counter += 1 + print(f"[{counter}/{n_total}] {tset_name} K={K} τ_start={tau_start} sched={sched}") + out = _evaluate( + model, + kernel, + w_seed, + max_elements=K, + tau_start=tau_start, + tau_end=tau_end, + schedule=sched, + targets=tgt, + ) + results.append({"target_set": tset_name, "K": K, "tau_start": tau_start, "schedule": sched, **out}) + + OUT_JSON.write_text(json.dumps(results, indent=2)) + print(f"\n[saved] {OUT_JSON}") + _print_summary(results) + + +def _print_summary(results: list[dict]) -> None: + """Pivot the JSON into a per-target-set, per-K markdown table.""" + print("\n" + "=" * 100) + print("SUMMARY") + print("=" * 100) + for tset in ("2T", "1T"): + subset = [r for r in results if r["target_set"] == tset] + print(f"\n### Target set: {tset}") + tgt_keys = list(subset[0]["achieved"].keys()) if subset else [] + header = f"{'config':<35} {'nz_mean':>8} {'QC':>6} " + " ".join(f"{t[:10]:>10}" for t in tgt_keys) + print(header) + print("-" * len(header)) + + # Sort: baseline → controls → sweep (by K, then 
tau_start, then schedule) + def _key(r): + sched_order = {"baseline": 0, "const_t1.0": 1, "const_t0.01": 2, "geometric": 3, "linear": 4, "cosine": 5} + K = r["K"] if r["K"] is not None else 0 + tau = r["tau_start"] if r["tau_start"] is not None else 0.0 + return (K, tau, sched_order.get(r["schedule"], 99)) + + for r in sorted(subset, key=_key): + tag = "baseline" if r["schedule"] == "baseline" else f"K{r['K']} {r['schedule']:<10} τ0={r['tau_start']:<4}" + nz_mean = sum(r["nz_per_row"]) / len(r["nz_per_row"]) + row = f"{tag:<35} {nz_mean:>8.2f} {r['qc_after']:>6.3f} " + row += " ".join(f"{r['achieved'][t]['mean']:>+5.2f}±{r['achieved'][t]['std']:.2f}" for t in tgt_keys) + print(row) + + +if __name__ == "__main__": + main() diff --git a/logs/test_max_elements_smoke.py b/logs/test_max_elements_smoke.py new file mode 100644 index 0000000..13af9f1 --- /dev/null +++ b/logs/test_max_elements_smoke.py @@ -0,0 +1,154 @@ +"""Real-model smoke test for max_elements. + +Loads the inverse-design fine-tuned model (the same checkpoint paper_inverse_3scenarios uses), +constructs the right KMD kernel via ContinualRehearsalRunner, and compares: + +1. Baseline (no max_elements) — existing behaviour. +2. max_elements=3 / 5 with geometric annealing — should give exactly K-element recipes. +3. max_elements=3 with annealing_scale=0.0 / 1.0 and an advanced annealing_schedule dict — controls showing the schedule matters. + +Reports: non-zero counts, achieved regression targets, QC probability, recipe row 0.
+""" + +from __future__ import annotations + +import tomllib +import time +from pathlib import Path + +import torch +from lightning import seed_everything + +from foundation_model.scripts.continual_rehearsal_demo import ( + QC_CLASSES, + ContinualRehearsalConfig, + ContinualRehearsalRunner, +) +from foundation_model.scripts.eval_inverse_methods import _qc_prob, _seed_weights_from_compositions +from foundation_model.utils.kmd_plus import DEFAULT_ELEMENTS + +REPO = Path(__file__).resolve().parents[1] +CFG_PATH = REPO / "samples/continual_rehearsal_demo_config_inverse_baseline.toml" +CKPT = REPO / "artifacts/inverse_design_run/finetune/final_model.pt" +SCENARIO = "scenario1_fe_down_magnetic_up" + + +def _build(): + """Mirror the loading dance from paper_inverse_comparison.run().""" + raw = tomllib.loads(CFG_PATH.read_text(encoding="utf-8")) + # Inject scenario 1's overrides (formation_energy down, magnetization up). + scenarios = raw.pop("inverse_scenarios", []) + sc = next(s for s in scenarios if s["name"] == SCENARIO) + raw["inverse_reg_tasks"] = sc["reg_tasks"] + raw["inverse_reg_targets"] = sc["reg_targets"] + # Drop fields ContinualRehearsalConfig doesn't accept (inverse_seed_explicit_append OK). 
+ config = ContinualRehearsalConfig(**raw) + seed_everything(config.random_seed, workers=True) + + runner = ContinualRehearsalRunner(config) + model = runner._build_full_model() + state = torch.load(CKPT, map_location="cpu", weights_only=True) + state_dict = state["model"] if isinstance(state, dict) and "model" in state else state + model.load_state_dict(state_dict) + model.eval() + + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype + kernel = runner._kmd.kernel_torch(device=device, dtype=dtype) + return runner, model, kernel, config, device + + +def _report(label, res, B): + w = res.optimized_weights + nz = (w > 1e-4).sum(dim=-1) + targets = res.optimized_target.cpu().numpy() + print(f" [{label}]") + print(f" non-zero(>1e-4): {nz.tolist()} mean={nz.float().mean():.2f}") + print(f" formation_energy = {targets[:, 0].mean():+.3f} ± {targets[:, 0].std():.3f} (target -2.0)") + print(f" magnetization = {targets[:, 1].mean():+.3f} ± {targets[:, 1].std():.3f} (target +2.0)") + # Row 0 recipe + w0 = w[0].detach().cpu() + top = sorted(((float(w0[i]), DEFAULT_ELEMENTS[i]) for i in range(len(w0)) if float(w0[i]) > 1e-4), reverse=True) + print(f" row 0 recipe ({len(top)} elements): " + ", ".join(f"{s}={v:.3f}" for v, s in top)) + + +def main() -> None: + print(f"[loading] checkpoint={CKPT}") + runner, model, kernel, config, device = _build() + print(f"[loaded] kernel shape={kernel.shape}, encoder input_dim={getattr(model.encoder, 'input_dim', None)}") + + # Seed compositions (paper script's top-QC selection) + def _qc_fn(x): + return _qc_prob(model, x) + + seeds = runner._select_seeds(model, device, _qc_fn)[:8] # only 8 for speed + print(f"[seeds] {seeds}") + n = kernel.shape[0] + w_seed = _seed_weights_from_compositions(seeds, n_components=n) + + targets = {"formation_energy": -2.0, "magnetization": 2.0} + common = dict( + task_targets=targets, + class_targets={"material_type": QC_CLASSES}, + class_target_weight=5.0, + initial_weights=w_seed, 
+ seed_blend=0.95, + steps=300, + lr=0.05, + ) + + print("\n[run 1] baseline (no max_elements)") + torch.manual_seed(0) + t0 = time.perf_counter() + res0 = model.optimize_composition(kernel, **common) + el0 = time.perf_counter() - t0 + print(f" elapsed {el0:.2f}s") + _report("baseline", res0, len(seeds)) + + for K in (3, 5): + torch.manual_seed(0) + print(f"\n[run] max_elements={K} (default annealing_scale=0.5)") + t0 = time.perf_counter() + res = model.optimize_composition( + kernel, + max_elements=K, + record_weights_trajectory=True, + **common, + ) + el = time.perf_counter() - t0 + print(f" elapsed {el:.2f}s (overhead vs baseline: {el - el0:+.2f}s)") + _report(f"K={K}", res, len(seeds)) + # annealing visualisation + traj = res.weights_trajectory + nz_t = (traj > 1e-3).sum(dim=-1).float().mean(dim=-1) + chk = [0, 30, 100, 200, 290, 299] + print(" annealing nz over trajectory: " + ", ".join(f"step{s}={nz_t[s]:.1f}" for s in chk if s < len(nz_t))) + + torch.manual_seed(0) + print("\n[run] max_elements=3, annealing_scale=0.0 (no exploration, τ_start=1)") + res_c = model.optimize_composition(kernel, max_elements=3, annealing_scale=0.0, **common) + _report("K=3 scale=0.0", res_c, len(seeds)) + + torch.manual_seed(0) + print("\n[run] max_elements=3, annealing_scale=1.0 (max exploration, τ_start=25)") + res_h = model.optimize_composition(kernel, max_elements=3, annealing_scale=1.0, **common) + _report("K=3 scale=1.0", res_h, len(seeds)) + + torch.manual_seed(0) + print("\n[run] max_elements=3, advanced dict (linear warm-up, then geometric tail)") + res_d = model.optimize_composition( + kernel, + max_elements=3, + annealing_scale=0.5, + annealing_schedule={ + "step": [0.3, 0.6], + "scale": [0.9, 0.5], + "annealing_func": ["linear", "linear"], + }, + **common, + ) + _report("K=3 dict (warm-up→linear→tail)", res_d, len(seeds)) + + +if __name__ == "__main__": + main() diff --git a/src/foundation_model/models/flexible_multi_task_model.py 
b/src/foundation_model/models/flexible_multi_task_model.py index cfdf4af..7dd3ac2 100644 --- a/src/foundation_model/models/flexible_multi_task_model.py +++ b/src/foundation_model/models/flexible_multi_task_model.py @@ -73,7 +73,14 @@ # :meth:`optimize_composition`; when present it has shape ``(steps, B, n_components)``. CompositionOptimizationResult = namedtuple( "CompositionOptimizationResult", - ["optimized_weights", "optimized_descriptor", "optimized_target", "initial_score", "trajectory", "weights_trajectory"], + [ + "optimized_weights", + "optimized_descriptor", + "optimized_target", + "initial_score", + "trajectory", + "weights_trajectory", + ], defaults=[None], ) @@ -2236,7 +2243,12 @@ def optimize_composition( diversity_scale: float = 1.0, allowed_elements: str | list[str] = "all", element_step_scale: float | Mapping[str, float] = 1.0, + fixed_amounts: Mapping[str, float] | None = None, + min_nonzero_weight: float = 0.0, seed_blend: float = 0.95, + max_elements: int | None = None, + annealing_scale: float = 0.5, + annealing_schedule: Mapping[str, Any] | None = None, steps: int = 300, lr: float = 0.05, record_weights_trajectory: bool = False, @@ -2318,6 +2330,71 @@ def optimize_composition( it to the rest of the row, so this is a soft preference, not a hard guarantee. Symbols are resolved against ``DEFAULT_ELEMENTS`` (kernel alignment required, as above). + fixed_amounts : Mapping[str, float] | None, optional + Pin specific elements at user-specified weights for the entire optimisation; the + optimiser distributes the remaining mass ``1 − Σ fixed_amounts.values()`` across + the unfixed elements freely. + + Example: ``{"Au": 0.65, "Ga": 0.20}`` produces recipes with Au exactly 65 % and + Ga exactly 20 %; the remaining 15 % is split among other allowed elements as the + objective prefers. 
+ + Implementation reuses the same lock-paste machinery as ``element_step_scale = 0``: + a per-row tensor ``locked_w0`` is built with the user's amounts at the named + positions; ``_w_from_logits`` overwrites those positions every step and + renormalises the unlocked positions over ``1 − Σ locked``. + + Constraints: + * Each symbol must be in :data:`DEFAULT_ELEMENTS` (kernel alignment required). + * Each amount must be in ``(0, 1)``; ``Σ values < 1.0`` (need free mass). + * If ``allowed_elements`` is set, every fixed element must also be in the + whitelist (locking outside the whitelist is contradictory). + * If ``element_step_scale = 0`` is also used, the two sets of locked symbols + **must not overlap** — use one mechanism per element. + * If ``max_elements`` is also set, fixed elements count toward K (they're + always in the selection); strict inequality ``max_elements > n_locked_total`` + is enforced. + + Unlike ``element_step_scale = 0``'s hard lock, ``fixed_amounts`` does **not** + require ``initial_weights`` — the lock values come straight from this kwarg. + min_nonzero_weight : float, optional + Lower bound on every unlocked element's final weight: positions with + ``0 < w < min_nonzero_weight`` are zeroed out and their mass is redistributed across + the remaining unlocked positions. Default ``0.0`` (no floor). + + Use case: avoid trace-amount appearances (e.g. ``Pt = 0.5%``) that are not + synthesisable — "if you use it, use ≥ 10%". + + Implementation: applied as the *last* step in ``_w_from_logits`` (after soft top-K + and lock-paste) and again after the final hard top-K projection. Locked elements + (from ``element_step_scale = 0`` or ``fixed_amounts``) are **not** subject to the + floor — their values are set explicitly by the user. + + Constraints: + * ``0 ≤ min_nonzero_weight ≤ 1``. + * If ``max_elements`` is set: ``min_nonzero_weight ≤ 1 / max_elements`` (otherwise + ``K`` elements each ≥ floor can't sum to ≤ 1). 
+ * If ``fixed_amounts`` is set: every fixed value must be ≥ floor (else + contradiction). + * If ``element_step_scale = 0`` locks with ``initial_weights`` are present: every + locked seed value must be ≥ floor (checked at runtime once the seed is + normalised). + + Edge case: if dropping every below-floor position would leave a row with zero + unlocked mass (no element survives), the floor is skipped *for that row only* — + preserving the simplex (rows always sum to 1). When this happens, the row will + contain unlocked positions below ``min_nonzero_weight``; if you see this in + practice your floor is too aggressive for the model's preferred subset. + + Practical note: when ``max_elements`` is not set, no upper bound on the floor is + enforced beyond ``floor ≤ 1``. A very large floor (e.g. 0.5 with 94 components) will + silently trigger the per-row fallback on almost every row — the result is a valid + simplex but the floor is effectively ignored. Pair the floor with ``max_elements`` + (which enforces ``floor ≤ 1 / max_elements``) when you want a hard guarantee. + + "At most K" implication: when combined with ``max_elements``, the floor can drop + below-floor positions in the K-subset, so the final non-zero count can be **less + than K** (still ≤ K — the user-facing promise is unchanged). seed_blend : float, optional How much of the (per-row) seed prior to keep when ``initial_weights`` is given; ``w0 ← seed_blend · seed + (1 − seed_blend) · uniform_over_allowed``. Default ``0.95`` @@ -2327,6 +2404,82 @@ def optimize_composition( when they help the objective. Set to ``1.0`` to reproduce the strict seed-only behaviour (no new elements can enter the support set); ``0.0`` makes the seed irrelevant and starts from uniform. Ignored when ``initial_weights is None``. + max_elements : int | None, optional + If set, restricts the final composition to at most this many non-zero elements. 
+ Unlike a naive post-hoc top-K projection, the constraint **participates in + optimisation throughout** via a differentiable iterative-softmax K-hot mask + (Plötz–Roth, NeurIPS 2018) coupled with a temperature-annealing schedule. + + How it works in one paragraph: at each step we compute a soft K-hot mask + ``m ∈ [0,1]^n`` with ``Σm = K`` from the same logits the softmax uses, then form + ``w = (softmax(lg) · m) / Σ(softmax(lg) · m)``. Temperature ``τ`` controls how + "K-hot" ``m`` is: large τ → uniform-ish (the constraint is soft, gradient can flow + between candidate subsets), small τ → near one-hot per iteration (constraint is hard). + τ is driven by the ``annealing_scale`` / ``annealing_schedule`` kwargs below — by + default a geometric schedule from ``25**annealing_scale`` down to a fixed + ``τ_end = 0.01``. The annealing doubles as a continuation method that helps escape + local optima. + + After the loop, a final hard top-K projection is applied so the returned + ``optimized_weights`` has **at most** ``max_elements`` non-zero positions (subject + to any locked elements, which are always counted toward K — see below). The + count saturates at K when the optimiser left at least K positions with positive + ``w_soft`` mass; if it drove some logits all the way to zero, the row can land + below K — this is by design, not a bug ("at most K" is the user-facing promise). + + Constraints: + * ``1 ≤ max_elements ≤ n_components``. + * If any element is hard-locked via ``element_step_scale=0`` or ``fixed_amounts``, the lock counts + toward K; require strict ``max_elements > n_locked`` (equality is rejected). + * If ``allowed_elements`` restricts the support, require ``max_elements ≤ |allowed|``. + + ``None`` (default) or ``max_elements == n_components`` disables the constraint. + annealing_scale : float, optional + Single-knob "softness" of the annealing schedule, normalised to ``[0, 1]``. + Default ``0.5``.
Maps internally to raw temperature via ``τ_start = 25**scale``: + + * ``0.0`` → ``τ_start = 1.0`` (no exploration; constraint hard from the start) + * ``0.5`` → ``τ_start = 5.0`` (default; safe choice — QC stable, decent targets) + * ``1.0`` → ``τ_start = 25.0`` (max exploration; longer soft phase) + + The full schedule is geometric from ``τ_start(scale)`` down to ``τ_end = 0.01``. + Ignored when ``max_elements`` is None. + + **Calibration**: the 0.5 default was picked from a sweep on the inverse-design + fine-tuned model (300 steps, K∈{3, 5}; see ``logs/sweep_tau_schedule.png``). Across + the 3 paper scenarios it keeps QC within ±0.02 of the unconstrained baseline while + hitting K=3/5 cardinality. For aggressive target chasing, raise toward 0.8-1.0 + (and consider an advanced schedule with ``annealing_func="linear"`` to hold the + soft phase longer). For QC priority, leave at 0.5. + annealing_schedule : dict | None, optional + Advanced piecewise schedule. **Overrides the front of the simple schedule.** + When supplied, this dict takes precedence over ``annealing_scale``'s implicit + schedule for the steps it covers. The format is three parallel lists of length N: + + .. code-block:: python + + { + "step": [0.2, 0.5, 1.0], # fractional step boundaries (0,1] + "scale": [0.8, 0.5, 0.5], # normalised scale [0,1] at each boundary + "annealing_func": ["geometric", "geometric", "geometric"], # interpolation in each segment + } + + **Reading the dict**: the schedule starts at step=0 from the value given by + ``annealing_scale``. Segment ``i`` covers ``(step[i-1], step[i]]`` (taking + the lower bound of segment 0 to be ``0``); within that segment, the normalised scale interpolates from the + previous segment's endpoint (or ``annealing_scale`` for segment 0) to ``scale[i]`` + using ``annealing_func[i]``. The interpolated scale is then mapped to raw τ via the + same ``25**scale`` formula used by ``annealing_scale``.
+ + **If ``step[-1] < 1.0``**, the remaining ``(step[-1], 1.0]`` portion continues with + a default geometric tail: from the raw τ value at ``step[-1]`` (i.e. + ``25**scale[-1]``) down to ``τ_end = 0.01``. This guarantees the schedule always + reaches the hard end inside the loop (the final hard-projection cleans up K-hot + either way). + + **Allowed annealing_func values**: ``"geometric"``, ``"linear"``, ``"cosine"``, + ``"constant"``. ``"constant"`` holds the segment's starting value (``scale[i]`` is + ignored — useful for warm-up phases). steps : int Adam optimisation steps. Default 300. lr : float @@ -2456,6 +2609,155 @@ def optimize_composition( f"element_symbol → float; got {type(element_step_scale).__name__}." ) + # --- Validate fixed_amounts (per-element explicit pinning) ------------------------------- + # Build the (n_components,) tensors lazily: ``fixed_w0_vec`` (per-element pinned value, + # zero elsewhere) and ``fixed_mask_vec`` (bool: True at pinned positions). The actual + # batch-shaped ``locked_w0`` is materialised later (alongside step_scale=0 locks) once we + # know the batch size. + fixed_w0_vec: torch.Tensor | None = None + fixed_mask_vec: torch.Tensor | None = None + if fixed_amounts is not None: + if not isinstance(fixed_amounts, Mapping): + raise TypeError( + f"fixed_amounts must be a mapping of element_symbol → float or None; " + f"got {type(fixed_amounts).__name__}." + ) + if len(fixed_amounts) == 0: + raise ValueError("fixed_amounts must be non-empty when provided.") + sym_to_idx = {s: i for i, s in enumerate(DEFAULT_ELEMENTS)} + bad_syms = [s for s in fixed_amounts if s not in sym_to_idx] + if bad_syms: + raise ValueError(f"Unknown element symbol(s) in fixed_amounts: {bad_syms}.") + if n_components != len(DEFAULT_ELEMENTS): + raise ValueError( + f"fixed_amounts requires the kernel to align with DEFAULT_ELEMENTS " + f"(n_components={n_components}, expected {len(DEFAULT_ELEMENTS)})." 
+ ) + for sym, amt in fixed_amounts.items(): + if not 0.0 < float(amt) < 1.0: + raise ValueError(f"fixed_amounts['{sym}']={amt} must be strictly between 0 and 1.") + total = float(sum(fixed_amounts.values())) + if total >= 1.0: + raise ValueError( + f"sum(fixed_amounts.values())={total:.4f} must be strictly less than 1.0 " + "(the optimiser needs unfixed mass to allocate)." + ) + # Allowed-list compatibility — pinning outside the whitelist is contradictory. + if elem_mask_arg is not None: + bad_against_allowed = [s for s in fixed_amounts if not elem_mask_arg[sym_to_idx[s]]] + if bad_against_allowed: + raise ValueError( + f"fixed_amounts symbols {bad_against_allowed} are not in allowed_elements — " + "pinning a disallowed element is contradictory." + ) + # Mutual exclusion with element_step_scale = 0 (the other hard-lock path). + if step_scale_arg is not None: + overlap = [s for s in fixed_amounts if float(step_scale_arg[sym_to_idx[s]]) == 0.0] + if overlap: + raise ValueError( + f"Symbols {overlap} appear in both element_step_scale=0 and " + "fixed_amounts. Use one mechanism per element." + ) + fixed_w0_vec = torch.zeros(n_components) + fixed_mask_vec = torch.zeros(n_components, dtype=torch.bool) + for sym, amt in fixed_amounts.items(): + idx = sym_to_idx[sym] + fixed_w0_vec[idx] = float(amt) + fixed_mask_vec[idx] = True + + # --- Validate min_nonzero_weight (per-element floor) ------------------------------------- + if not 0.0 <= min_nonzero_weight <= 1.0: + raise ValueError(f"min_nonzero_weight must be in [0, 1]; got {min_nonzero_weight}.") + if min_nonzero_weight > 0.0: + # If max_elements is set, the floor must be feasible: K elements ≥ floor summing to 1 + # implies K * floor ≤ 1. + if max_elements is not None and min_nonzero_weight > 1.0 / max_elements: + raise ValueError( + f"min_nonzero_weight={min_nonzero_weight} exceeds 1 / max_elements=" + f"{1.0 / max_elements:.4f}. 
With at most {max_elements} non-zero positions, " + "no row can have every weight ≥ floor and still sum to 1." + ) + # Fixed amounts must themselves be ≥ the floor (else contradiction). + if fixed_amounts is not None: + bad = sorted((s, v) for s, v in fixed_amounts.items() if float(v) < min_nonzero_weight) + if bad: + raise ValueError( + f"fixed_amounts entries {bad} are below min_nonzero_weight=" + f"{min_nonzero_weight}. The floor cannot override an explicit pin." + ) + + # --- Validate cardinality constraint (max_elements + annealing knobs) ----------------------- + if max_elements is not None: + if not isinstance(max_elements, int) or isinstance(max_elements, bool): + raise TypeError(f"max_elements must be an int or None; got {type(max_elements).__name__}.") + if not 1 <= max_elements <= n_components: + raise ValueError(f"max_elements must be in [1, n_components={n_components}]; got {max_elements}.") + if elem_mask_arg is not None: + n_allowed = int(elem_mask_arg.sum().item()) + if max_elements > n_allowed: + raise ValueError( + f"max_elements={max_elements} exceeds the number of allowed elements " + f"({n_allowed}). Widen ``allowed_elements`` or lower ``max_elements``." + ) + # Lock-vs-K check: locked positions (element_step_scale=0 ∪ fixed_amounts) all count + # toward K. We require *strict* ``max_elements > n_locked`` for both lock paths: + # equality leaves the lock-paste with no unlocked slot to absorb the leftover mass + # (1 − Σ locked) and produces rows that sum to < 1 — silently breaking the simplex. + # For ``fixed_amounts`` this is definite (``Σ < 1`` enforced at kwarg time); for + # ``element_step_scale=0`` the seed values *could* sum to exactly 1, but K-constrained + # all-locked recipes have no degrees of freedom anyway, so rejecting equality is + # both safe and clearer. 
+ n_locked_pre = 0 + if step_scale_arg is not None: + n_locked_pre += int((step_scale_arg == 0).sum().item()) + if fixed_mask_vec is not None: + n_locked_pre += int(fixed_mask_vec.sum().item()) + if n_locked_pre >= max_elements: + raise ValueError( + f"max_elements={max_elements} must be > total locked elements ({n_locked_pre}, " + "counting element_step_scale=0 ∪ fixed_amounts) — the lock-paste needs at " + "least one unlocked slot to absorb the leftover mass (1 − Σ locked); equality " + "would silently produce row sums < 1. Raise max_elements or unlock some." + ) + if not 0.0 <= annealing_scale <= 1.0: + raise ValueError(f"annealing_scale must be in [0, 1]; got {annealing_scale}.") + if annealing_schedule is not None: + if not isinstance(annealing_schedule, Mapping): + raise TypeError(f"annealing_schedule must be a mapping; got {type(annealing_schedule).__name__}.") + missing = {"step", "scale", "annealing_func"} - set(annealing_schedule) + if missing: + raise ValueError( + f"annealing_schedule missing required keys {sorted(missing)}. " + "Required: step, scale, annealing_func — all parallel lists." + ) + sched_steps = list(annealing_schedule["step"]) + sched_scales = list(annealing_schedule["scale"]) + sched_funcs = list(annealing_schedule["annealing_func"]) + if not (len(sched_steps) == len(sched_scales) == len(sched_funcs)): + raise ValueError( + f"annealing_schedule lists must be the same length; got " + f"step={len(sched_steps)}, scale={len(sched_scales)}, " + f"annealing_func={len(sched_funcs)}." 
+ ) + if len(sched_steps) == 0: + raise ValueError("annealing_schedule lists must be non-empty.") + prev_s = 0.0 + for s in sched_steps: + if not 0.0 < float(s) <= 1.0: + raise ValueError(f"annealing_schedule['step'] entries must be in (0, 1]; got {s}.") + if float(s) <= prev_s: + raise ValueError(f"annealing_schedule['step'] must be strictly increasing; got {sched_steps}.") + prev_s = float(s) + for t in sched_scales: + if not 0.0 <= float(t) <= 1.0: + raise ValueError(f"annealing_schedule['scale'] entries must be in [0, 1]; got {t}.") + allowed_funcs = ("geometric", "linear", "cosine", "constant") + for f in sched_funcs: + if f not in allowed_funcs: + raise ValueError( + f"annealing_schedule['annealing_func'] entries must be one of {allowed_funcs}; got {f!r}." + ) + # --- Validate the seed (BEFORE touching model state, so a bad input doesn't leave the # model in eval() / with params switched off). --------------------------------------- if initial_weights is None: @@ -2542,18 +2844,31 @@ def optimize_composition( # Move the element-constraint tensors onto the right device (validated above). elem_mask = elem_mask_arg.to(device=device) if elem_mask_arg is not None else None step_scale = step_scale_arg.to(device=device, dtype=dtype) if step_scale_arg is not None else None - - # --- Hard-lock setup for elements with step_scale == 0 ---------------------------------- - # Zeroing ``logit_i.grad`` keeps that logit constant but does NOT keep ``w_i`` constant, - # because softmax renormalises across all logits — when other (unlocked) logits move, the - # softmax denominator changes and so does the locked weight. To truly honour the docstring - # promise "freezes those elements at their seed values", we (a) detect locked indices, (b) - # capture their per-row seed weights, and (c) inside ``_w_from_logits`` paste those seed - # values back over the softmax output and renormalise the unlocked positions to fill the - # remaining ``1 − Σ locked_w`` mass per row. 
The gradient through the locked indices is - # automatically zero (the lock branch uses a constant), so we no longer need the - # ``step_scale.mul_`` zeroing for them — but we leave that path active for the genuinely - # soft case ``0 < step_scale < 1``. + fixed_w0_dev = fixed_w0_vec.to(device=device, dtype=dtype) if fixed_w0_vec is not None else None + fixed_mask_dev = fixed_mask_vec.to(device=device) if fixed_mask_vec is not None else None + + # --- Hard-lock setup ---------------------------------------------------------------------- + # Two hard-lock sources both end up in the same ``(locked_mask, locked_w0)`` pair so the + # downstream ``_w_from_logits`` / ``_apply_lock_paste`` logic is unchanged: + # + # 1. ``element_step_scale = 0``: pins the listed elements at their (un-blended) + # ``initial_weights`` values. Requires ``initial_weights`` because there's no other + # source for per-row seed values. + # 2. ``fixed_amounts``: pins the listed elements at user-given absolute amounts. No + # ``initial_weights`` required — the lock values come straight from the kwarg. + # + # The two paths must not overlap (validated above). When both are present, we just + # OR the masks and add the value tensors (disjoint by construction). + # + # Why this matters: zeroing ``logit_i.grad`` keeps that logit constant but does NOT keep + # ``w_i`` constant — softmax renormalises across all logits, so when other (unlocked) + # logits move, the softmax denominator changes and so does the locked weight. The fix + # is to (a) detect locked indices, (b) capture their per-row target weights, and (c) + # inside ``_w_from_logits`` paste those values back over the softmax output and + # renormalise the unlocked positions to fill the remaining ``1 − Σ locked_w`` mass per + # row. 
The gradient through the locked indices is automatically zero (the lock branch + # uses a constant), so we no longer need the ``step_scale.mul_`` zeroing for them — + # but we leave that path active for the genuinely soft case ``0 < step_scale < 1``. locked_mask: torch.Tensor | None = None locked_w0: torch.Tensor | None = None if step_scale is not None: @@ -2572,24 +2887,248 @@ def optimize_composition( locked_mask = locked_idx_mask # (n_components,) bool, on device # (B, n_components): seed values at locked positions, 0 elsewhere — constant. locked_w0 = (w0_seed * locked_mask.to(dtype)).detach() + if fixed_mask_dev is not None: + # Broadcast the per-element fixed values to every row in the batch. + B = logits.shape[0] + fixed_w0_batch = fixed_w0_dev.unsqueeze(0).expand(B, -1).detach() + if locked_mask is None: + locked_mask = fixed_mask_dev + locked_w0 = fixed_w0_batch + else: + locked_mask = locked_mask | fixed_mask_dev # validated disjoint + locked_w0 = locked_w0 + fixed_w0_batch + + # Runtime sanity: combined lock sum must leave room (or fit exactly) for the simplex. + # ``fixed_amounts`` enforces ``Σ < 1`` at kwarg time, and ``element_step_scale=0`` + # locks at seed values which sum to ≤ 1 per row — but the *combined* total could + # exceed 1 (e.g. seed-lock Mg=0.50 + fix Au=0.65). Check here, with a tiny tolerance + # for float noise. + if locked_w0 is not None: + lock_sums = locked_w0.sum(dim=-1) + if (lock_sums > 1.0 + 1e-5).any(): + raise ValueError( + f"Combined locked mass exceeds 1.0 on at least one row " + f"(max row-sum = {float(lock_sums.max()):.4f}). Likely cause: " + "``element_step_scale=0`` locks plus ``fixed_amounts`` together claim more " + "than 100% of the simplex. Lower one set of values or drop a lock." 
+ ) - def _w_from_logits(lg: torch.Tensor) -> torch.Tensor: - """Softmax over logits; mask disallowed elements; hard-lock the chosen ones at seed.""" - if elem_mask is not None: - lg = lg.masked_fill(~elem_mask, float("-inf")) - w = torch.softmax(lg, dim=-1) + # Runtime sanity: floored elements must not contradict the lock-paste targets. + # ``fixed_amounts`` was checked at kwarg time; ``element_step_scale=0`` locks have + # per-row seed values we couldn't see earlier — verify them now. + if min_nonzero_weight > 0.0 and locked_mask is not None and locked_w0 is not None: + locked_below_floor = (locked_w0 > 0) & (locked_w0 < min_nonzero_weight) + if locked_below_floor.any(): + raise ValueError( + f"At least one locked element's value falls below min_nonzero_weight=" + f"{min_nonzero_weight}. Likely cause: an element_step_scale=0 lock points " + "at a seed value below the floor (raise the seed, lower the floor, or " + "drop the lock)." + ) + + # --- Soft top-K (cardinality constraint) helpers ---------------------------------------- + # Schedule shape (controlled by ``annealing_scale`` and optionally ``annealing_schedule``): + # + # * Normalised scale ∈ [0, 1] is the user-facing knob; raw τ is derived via + # ``τ = 25**scale`` (so scale=0 → τ=1, scale=0.5 → τ=5, scale=1 → τ=25). + # * Default schedule when no dict is given: geometric from ``τ_start=25**annealing_scale`` + # at fractional step 0 down to ``_TAU_END=0.01`` at fractional step 1. + # * When ``annealing_schedule`` dict is provided, its segments override the front of + # the schedule; the segment from ``step[-1]`` to 1.0 (if not already at 1.0) falls + # back to the geometric tail from ``25**scale[-1]`` down to ``_TAU_END``. + # + # ``current_tau`` lives in a list so the optimisation loop can mutate it each step + # without rebuilding the ``_w_from_logits`` closure that reads it. 
+ _TAU_FLOOR = 1e-3 # numerical lower bound; below this softmax(lg/τ) loses precision + _TAU_END = 0.01 # fixed final hardness for the default schedule's tail + _SCALE_TAU_BASE = 25.0 # τ = _SCALE_TAU_BASE**scale → 0→1, 0.5→5, 1→25 + + def _scale_to_tau(scale: float) -> float: + return float(_SCALE_TAU_BASE ** max(0.0, min(1.0, scale))) + + def _interp_scalar(a: float, b: float, t: float, func: str) -> float: + """Interpolate from ``a`` to ``b`` at local-time ``t`` ∈ [0, 1].""" + if func == "constant": + return a + if func == "linear": + return a + (b - a) * t + if func == "cosine": + return b + 0.5 * (a - b) * (1.0 + math.cos(math.pi * t)) + # geometric — guard against zero/sign issues by working in log space when both >0. + if a > 0.0 and b > 0.0: + return a * (b / a) ** t + # Fall back to linear for degenerate cases (shouldn't trigger in normal use). + return a + (b - a) * t + + # Materialise schedule arrays once (validated above), so the per-step lookup is light. + _sched_steps: list[float] = ( + [float(s) for s in annealing_schedule["step"]] if annealing_schedule is not None else [] + ) + _sched_scales: list[float] = ( + [float(t) for t in annealing_schedule["scale"]] if annealing_schedule is not None else [] + ) + _sched_funcs: list[str] = ( + list(annealing_schedule["annealing_func"]) if annealing_schedule is not None else [] + ) + + def _tau_for_step(step: int) -> float: + """Return the raw τ for integer optimisation step ``step``.""" + if max_elements is None or steps <= 1: + return float(max(_TAU_END, _TAU_FLOOR)) + # Fractional progress in [0, 1]. + s = step / (steps - 1) + # Default schedule (used directly when no dict, or for the tail when dict ends < 1.0). + default_tau_start = _scale_to_tau(annealing_scale) + default_tau_end = _TAU_END + + if _sched_steps: + # Walk through dict segments to find the one containing ``s``. 
+ prev_step = 0.0 + prev_scale = annealing_scale # segment 0 starts at the simple knob's value + for i, seg_end in enumerate(_sched_steps): + if s <= seg_end: + local_t = (s - prev_step) / max(seg_end - prev_step, 1e-12) + scale_now = _interp_scalar(prev_scale, _sched_scales[i], local_t, _sched_funcs[i]) + return float(max(_scale_to_tau(scale_now), _TAU_FLOOR)) + prev_step = seg_end + prev_scale = _sched_scales[i] + # ``s`` is past the dict's last step → use the geometric tail from + # ``25**scale[-1]`` at ``step[-1]`` down to ``_TAU_END`` at 1.0. + tail_start_tau = _scale_to_tau(_sched_scales[-1]) + tail_end_step = 1.0 + tail_local_t = (s - _sched_steps[-1]) / max(tail_end_step - _sched_steps[-1], 1e-12) + val = tail_start_tau * (default_tau_end / tail_start_tau) ** tail_local_t + return float(max(val, _TAU_FLOOR)) + + # No dict — default geometric schedule from τ_start(annealing_scale) to _TAU_END. + val = default_tau_start * (default_tau_end / default_tau_start) ** s + return float(max(val, _TAU_FLOOR)) + + current_tau = [_tau_for_step(0)] + + def _soft_topk_mask( + lg: torch.Tensor, K: int, tau: float, *, force_select: torch.Tensor | None = None + ) -> torch.Tensor: + """Plötz–Roth iterative softmax. Returns m ∈ [0,1]^(B, n) with Σm = K. + + ``force_select`` (n_components,) bool marks positions that must be in the K + selection (e.g. hard-locked elements). Instead of boosting those logits — which + would make the iterative softmax pick them K times in a row, never moving on — + we **pre-seed** the mask with 1.0 at those positions and run only ``K - n_locked`` + iterations on the *unlocked* positions (the locked positions' logits are masked to + ``-inf`` before the iterations start, so they never compete). + """ + if force_select is None: + alpha = lg + m = torch.zeros_like(lg) + n_iter = K + else: + # Pre-mark locked positions as fully selected; iterate only on the rest.
+ n_locked = int(force_select.sum().item()) + n_iter = K - n_locked + locked_row = force_select.to(lg.dtype).unsqueeze(0).expand_as(lg) + m = locked_row.clone() + alpha = lg.masked_fill(force_select, float("-inf")) + for _ in range(n_iter): + p = torch.softmax(alpha / tau, dim=-1) + m = m + p + # The shift in scaled-logit space at the selected position is + # ``log(1−p)/τ`` — at small τ this is enormously negative, so the next + # iteration cannot re-pick the same position. (We must NOT multiply by τ here.) + alpha = alpha + torch.log((1.0 - p).clamp(min=1e-12)) + return m + + def _hard_topk_project(w: torch.Tensor, K: int) -> torch.Tensor: + """Hard top-K projection: keep K largest per row, zero rest, renormalise. + + If ``locked_mask`` is set, every locked position is forced into the kept set + (so the lock-paste below still has a place to write its seed values); the + remaining ``K − n_locked`` slots are filled by the largest unlocked weights. + """ + if locked_mask is None: + _, idx = w.topk(K, dim=-1) + keep = torch.zeros_like(w).scatter_(-1, idx, 1.0) + else: + n_locked = int(locked_mask.sum().item()) + n_free = K - n_locked + locked_row = locked_mask.to(w.dtype).unsqueeze(0).expand_as(w) + if n_free > 0: + # Exclude locked positions from the unlocked competition by sending them + # to ``-inf`` before topk; locked positions are added back via ``locked_row``. 
+ w_for_free = w.masked_fill(locked_mask.unsqueeze(0), float("-inf")) + _, idx = w_for_free.topk(n_free, dim=-1) + free_keep = torch.zeros_like(w).scatter_(-1, idx, 1.0) + keep = (locked_row + free_keep).clamp(max=1.0) + else: + keep = locked_row + w = w * keep + return w / w.sum(dim=-1, keepdim=True).clamp(min=1e-12) + + def _apply_lock_paste(w: torch.Tensor) -> torch.Tensor: + """Paste locked seed values onto ``w`` and renormalise unlocked positions.""" if locked_mask is None: return w - # Locked rows hold their seed values; unlocked rows are renormalised to fill the - # remaining mass ``1 − Σ_locked seed``. Differentiable: the lock branch is a constant - # so its gradient is 0; the unlocked branch's gradient flows through the rescale. - free_mask_f = (~locked_mask).to(w.dtype) # (n_components,) - w_unlocked = w * free_mask_f # zero at locked positions - # type: ignore[union-attr] — locked_w0 is set together with locked_mask above. + free_mask_f = (~locked_mask).to(w.dtype) + w_unlocked = w * free_mask_f free_mass = (1.0 - locked_w0.sum(dim=-1, keepdim=True)).clamp(min=0.0) w_unlocked = w_unlocked / w_unlocked.sum(dim=-1, keepdim=True).clamp(min=1e-12) * free_mass return w_unlocked + locked_w0 + def _apply_min_floor(w: torch.Tensor) -> torch.Tensor: + """Drop unlocked positions below ``min_nonzero_weight`` and re-fill free mass. + + Locked positions are exempt (their values are user-set). If dropping below-floor + positions would leave a row with zero unlocked mass, the floor is skipped for + that row — preserving the simplex invariant. The "at most K" guarantee still + holds; some rows may end up with fewer than K non-zero positions. 
+ """ + if min_nonzero_weight <= 0.0: + return w + if locked_mask is not None: + unlocked_f = (~locked_mask).to(w.dtype) + free_mass = (1.0 - locked_w0.sum(dim=-1, keepdim=True)).clamp(min=0.0) + unlocked_bool = (~locked_mask).unsqueeze(0).expand_as(w) + else: + unlocked_f = torch.ones_like(w[0]) + free_mass = torch.ones(w.shape[0], 1, dtype=w.dtype, device=w.device) + unlocked_bool = torch.ones_like(w, dtype=torch.bool) + below = (w > 0) & (w < min_nonzero_weight) & unlocked_bool + if not below.any(): + return w + w_drop = w.masked_fill(below, 0.0) + # Per-row unlocked sum after the tentative drop. + unlocked_after = w_drop * unlocked_f + unlocked_sum = unlocked_after.sum(dim=-1, keepdim=True) + # Rows where the drop is safe — at least one unlocked position survives. + can_drop = unlocked_sum > 1e-12 + # Renormalise unlocked portion to fit the free mass; locked stays as-is. + safe_sum = unlocked_sum.clamp(min=1e-12) + if locked_mask is not None: + locked_part = w_drop * locked_mask.to(w.dtype) + w_renorm = locked_part + unlocked_after * (free_mass / safe_sum) + else: + w_renorm = w_drop / safe_sum + return torch.where(can_drop.expand_as(w), w_renorm, w) + + def _w_from_logits(lg: torch.Tensor) -> torch.Tensor: + """Softmax → optional soft top-K → optional hard-lock paste → optional min-floor. + + Reads ``current_tau[0]`` (set by the outer loop) for the soft top-K temperature. + """ + if elem_mask is not None: + lg = lg.masked_fill(~elem_mask, float("-inf")) + w_soft = torch.softmax(lg, dim=-1) + if max_elements is not None and max_elements < n_components: + # Force locked positions to always sit in the K-hot mask so the lock-paste + # below has somewhere to write. ``w_soft`` itself is computed from the + # *unboosted* logits, so the within-K ratios reflect the optimisation state. 
+ m_topk = _soft_topk_mask(lg, max_elements, current_tau[0], force_select=locked_mask) + w = w_soft * m_topk + w = w / w.sum(dim=-1, keepdim=True).clamp(min=1e-12) + else: + w = w_soft + return _apply_min_floor(_apply_lock_paste(w)) + def _heads_forward(h_task: torch.Tensor) -> tuple[list[torch.Tensor], list[torch.Tensor]]: """Run regression heads, return (per-task predictions, loss terms).""" preds, terms = [], [] @@ -2616,6 +3155,10 @@ def _stack(values: list[torch.Tensor], B: int) -> torch.Tensor: return torch.stack(values, dim=-1) if values else torch.zeros((B, 0), device=device, dtype=dtype) # --- Record initial scores -------------------------------------------------------------- + # Initial scoring uses τ at step 0 of the annealing schedule — i.e. the softest end + # of the (annealing_scale + annealing_schedule)-derived τ curve, where the optimisation + # actually begins. + current_tau[0] = _tau_for_step(0) with torch.no_grad(): w0_tensor = _w_from_logits(logits) h0 = torch.tanh(self.encoder(w0_tensor @ kmd_kernel)) @@ -2627,7 +3170,8 @@ def _stack(values: list[torch.Tensor], B: int) -> torch.Tensor: # gradient only on ``logits`` — no stale grads accumulate on encoder/heads. trajectory: list[torch.Tensor] = [] weights_trajectory: list[torch.Tensor] = [] if record_weights_trajectory else [] - for _ in range(steps): + for step in range(steps): + current_tau[0] = _tau_for_step(step) optimizer.zero_grad() w = _w_from_logits(logits) x = w @ kmd_kernel @@ -2647,15 +3191,29 @@ def _stack(values: list[torch.Tensor], B: int) -> torch.Tensor: optimizer.step() trajectory.append(_stack([p.detach() for p in preds], logits.shape[0])) if record_weights_trajectory: - # Snapshot the post-step weights (after the softmax+hard-lock applied next iter - # would re-clean them, but the user wants what the *current* recipe looks like). 
+ # Snapshot the post-step weights at the *current* (still-soft) τ — the + # trajectory thus reflects the annealing schedule, not the hard projection. # Stored on CPU to keep GPU memory flat for long trajectories on large B. with torch.no_grad(): weights_trajectory.append(_w_from_logits(logits).detach().cpu()) # --- Final state ------------------------------------------------------------------------ + # Use the hardest τ for the final readout, then (if ``max_elements`` is active) apply + # a hard top-K projection so the returned ``optimized_weights`` has **at most** K + # non-zero positions (the floor below may reduce that further) — at τ_end ≈ 0.01 the + # soft mask is already near-K-hot, so the projection just cleans up residual + # sub-threshold weights. + current_tau[0] = float(max(_TAU_END, _TAU_FLOOR)) with torch.no_grad(): w_final = _w_from_logits(logits) + if max_elements is not None and max_elements < n_components: + w_final = _hard_topk_project(w_final, max_elements) + # Re-apply lock-paste — the projection may have re-distributed mass across + # unlocked positions, and lock-paste's "free mass" renormalisation needs to + # be re-run so the row still sums to exactly 1. Then re-floor: the projection + # may have promoted a previously-zeroed below-floor position back in. 
+ w_final = _apply_lock_paste(w_final) + w_final = _apply_min_floor(w_final) x_final = w_final @ kmd_kernel h_final = torch.tanh(self.encoder(x_final)) final_preds, _ = _heads_forward(h_final) diff --git a/src/foundation_model/models/flexible_multi_task_model_test.py b/src/foundation_model/models/flexible_multi_task_model_test.py index 18a450c..799dc11 100644 --- a/src/foundation_model/models/flexible_multi_task_model_test.py +++ b/src/foundation_model/models/flexible_multi_task_model_test.py @@ -1547,6 +1547,726 @@ def test_optimize_composition_uses_kmd_kernel_torch(): assert torch.allclose(res.optimized_weights.sum(dim=-1), torch.ones(3), atol=1e-5) +def test_optimize_composition_max_elements_enforces_K_cardinality(): + """max_elements=K → final composition has *at most* K non-zero elements per row. + + The hard top-K projection picks K positions, but if any of those has zero ``w_soft`` mass + (can happen when the optimiser drove other logits very negative), it stays at zero after + renormalisation — so the contract is "≤ K", not "= K". On a non-degenerate synthetic + setup K is usually saturated; on a real-model load some rows can land below K. + """ + torch.manual_seed(0) + model, kernel, _ = _build_aligned_model_and_kernel() + K = 3 + res = model.optimize_composition( + kernel, + task_targets={"prop": 1.0}, + class_targets={"cls": [1]}, + class_target_weight=2.0, + n_starts=4, + max_elements=K, + steps=120, + lr=0.2, + ) + w = res.optimized_weights + # Simplex preserved. + assert torch.allclose(w.sum(dim=-1), torch.ones(w.shape[0]), atol=1e-5) + assert (w >= 0).all() + # At most K non-zero positions per row. + nz = (w > 1e-6).sum(dim=-1) + assert torch.all(nz <= K), f"expected ≤ {K} non-zero per row, got {nz.tolist()}" + # On this toy setup with uniform-ish init we additionally expect saturation at K. 
+ assert torch.all(nz == K), f"toy model should saturate at K={K}, got {nz.tolist()}" + + +def test_optimize_composition_max_elements_full_is_noop(): + """max_elements == n_components disables the constraint (results match the unconstrained run).""" + torch.manual_seed(0) + model = _make_reg_clf_model() + kernel = torch.randn(6, INPUT_DIM) + init = torch.full((2, 6), 1.0 / 6) + base = model.optimize_composition(kernel, task_targets={"prop": 1.0}, initial_weights=init, steps=30, lr=0.1) + torch.manual_seed(0) + constrained = model.optimize_composition( + kernel, + task_targets={"prop": 1.0}, + initial_weights=init, + max_elements=6, # == n_components → no-op + steps=30, + lr=0.1, + ) + # max_elements == n_components ⇒ no soft top-K, no hard projection ⇒ identical trajectory. + assert torch.allclose(base.optimized_weights, constrained.optimized_weights, atol=1e-5) + + +def test_optimize_composition_max_elements_with_allowed_elements(): + """When ``allowed_elements`` whitelists the support, top-K picks from inside the whitelist.""" + torch.manual_seed(0) + model, kernel, elements = _build_aligned_model_and_kernel() + whitelist = ["Mg", "Al", "Cu", "Ni", "Fe"] + K = 3 + res = model.optimize_composition( + kernel, + task_targets={"prop": 1.0}, + n_starts=3, + allowed_elements=whitelist, + max_elements=K, + steps=80, + lr=0.2, + ) + w = res.optimized_weights + nz = (w > 1e-6).sum(dim=-1) + assert torch.all(nz == K) + # Non-whitelisted positions still exactly zero. + forbidden = [i for i, s in enumerate(elements) if s not in whitelist] + assert (w[:, forbidden] == 0).all() + + +def test_optimize_composition_max_elements_keeps_locked_in_support(): + """A hard-locked element must remain non-zero (and at its seed value) even with top-K.""" + torch.manual_seed(0) + model, kernel, elements = _build_aligned_model_and_kernel() + # Lock Mg at 0.30; allow the optimiser to pick the other K-1 freely. 
+ init_w = torch.zeros(1, len(elements)) + init_w[0, elements.index("Mg")] = 0.30 + init_w[0, elements.index("Al")] = 0.25 + init_w[0, elements.index("Cu")] = 0.25 + init_w[0, elements.index("Ni")] = 0.20 + K = 3 + res = model.optimize_composition( + kernel, + task_targets={"prop": 5.0}, + initial_weights=init_w, + element_step_scale={"Mg": 0.0}, # hard lock + max_elements=K, + steps=120, + lr=0.3, + ) + w = res.optimized_weights[0] + # Mg is held at its un-blended seed value. + assert torch.isclose(w[elements.index("Mg")], torch.tensor(0.30, dtype=w.dtype), atol=1e-4) + # At most K non-zero (Mg + ≤K-1 free, saturated to K on this non-degenerate setup). + nz = int((w > 1e-6).sum().item()) + assert nz <= K, f"expected ≤ {K} non-zero with Mg locked, got {nz}" + assert nz == K, f"toy model with Mg locked should saturate at K={K}, got {nz}" + + +def test_optimize_composition_max_elements_validation(): + """All max_elements / annealing_* validation errors fire before model state is touched.""" + model, kernel, elements = _build_aligned_model_and_kernel() + with pytest.raises(ValueError, match=r"max_elements must be in \[1, n_components"): + model.optimize_composition(kernel, task_targets={"prop": 0.0}, max_elements=0, n_starts=2, steps=2) + with pytest.raises(ValueError, match=r"max_elements must be in \[1, n_components"): + model.optimize_composition(kernel, task_targets={"prop": 0.0}, max_elements=999, n_starts=2, steps=2) + with pytest.raises(TypeError, match="max_elements must be an int"): + model.optimize_composition(kernel, task_targets={"prop": 0.0}, max_elements=2.5, n_starts=2, steps=2) # type: ignore[arg-type] + # max_elements > |allowed_elements| → rejected with a specific message. + with pytest.raises(ValueError, match="exceeds the number of allowed elements"): + model.optimize_composition( + kernel, + task_targets={"prop": 0.0}, + allowed_elements=["Mg", "Al"], + max_elements=5, + n_starts=2, + steps=2, + ) + # max_elements < n_locked → rejected.
+ init_w = torch.zeros(1, len(elements)) + init_w[0, elements.index("Mg")] = 0.3 + init_w[0, elements.index("Al")] = 0.3 + init_w[0, elements.index("Cu")] = 0.4 + with pytest.raises(ValueError, match="must be > total locked elements"): + model.optimize_composition( + kernel, + task_targets={"prop": 0.0}, + initial_weights=init_w, + element_step_scale={"Mg": 0.0, "Al": 0.0, "Cu": 0.0}, + max_elements=2, + steps=2, + ) + # Bad annealing_scale. + with pytest.raises(ValueError, match=r"annealing_scale must be in \[0, 1\]"): + model.optimize_composition( + kernel, task_targets={"prop": 0.0}, max_elements=2, annealing_scale=-0.1, n_starts=2, steps=2 + ) + with pytest.raises(ValueError, match=r"annealing_scale must be in \[0, 1\]"): + model.optimize_composition( + kernel, task_targets={"prop": 0.0}, max_elements=2, annealing_scale=1.5, n_starts=2, steps=2 + ) + # Bad annealing_schedule dict. + with pytest.raises(ValueError, match="annealing_schedule missing required keys"): + model.optimize_composition( + kernel, + task_targets={"prop": 0.0}, + max_elements=2, + annealing_schedule={"step": [0.5], "scale": [0.5]}, # no annealing_func + n_starts=2, + steps=2, + ) + with pytest.raises(ValueError, match="annealing_schedule lists must be the same length"): + model.optimize_composition( + kernel, + task_targets={"prop": 0.0}, + max_elements=2, + annealing_schedule={"step": [0.5, 1.0], "scale": [0.5], "annealing_func": ["geometric"]}, + n_starts=2, + steps=2, + ) + with pytest.raises(ValueError, match=r"annealing_schedule\['step'\] entries must be in \(0, 1\]"): + model.optimize_composition( + kernel, + task_targets={"prop": 0.0}, + max_elements=2, + annealing_schedule={"step": [0.0, 1.0], "scale": [0.5, 0.0], "annealing_func": ["geometric", "geometric"]}, + n_starts=2, + steps=2, + ) + with pytest.raises(ValueError, match=r"annealing_schedule\['step'\] must be strictly increasing"): + model.optimize_composition( + kernel, + task_targets={"prop": 0.0}, + max_elements=2, + 
annealing_schedule={"step": [0.5, 0.5], "scale": [0.5, 0.3], "annealing_func": ["geometric", "geometric"]}, + n_starts=2, + steps=2, + ) + with pytest.raises(ValueError, match=r"annealing_schedule\['scale'\] entries must be in \[0, 1\]"): + model.optimize_composition( + kernel, + task_targets={"prop": 0.0}, + max_elements=2, + annealing_schedule={"step": [1.0], "scale": [1.5], "annealing_func": ["geometric"]}, + n_starts=2, + steps=2, + ) + with pytest.raises(ValueError, match=r"annealing_schedule\['annealing_func'\] entries must be one of"): + model.optimize_composition( + kernel, + task_targets={"prop": 0.0}, + max_elements=2, + annealing_schedule={"step": [1.0], "scale": [0.5], "annealing_func": ["exponential_decay"]}, + n_starts=2, + steps=2, + ) + + +def test_optimize_composition_max_elements_trajectory_softens_to_hard(): + """The per-step trajectory shows annealing: early steps are softer (more nonzeros) than late.""" + torch.manual_seed(0) + model, kernel, _ = _build_aligned_model_and_kernel() + K = 3 + res = model.optimize_composition( + kernel, + task_targets={"prop": 1.0}, + n_starts=2, + max_elements=K, + steps=60, + lr=0.2, + record_weights_trajectory=True, + ) + traj = res.weights_trajectory # (steps, B, n) + assert traj is not None and traj.shape[0] == 60 + early_nz = (traj[2] > 1e-3).sum(dim=-1).float().mean().item() # avg #non-zero at step 2 + late_nz = (traj[-1] > 1e-3).sum(dim=-1).float().mean().item() # avg #non-zero at last step + # Early (large τ) should carry more spread mass; late (small τ) should be near K. + assert early_nz > late_nz, f"annealing not visible in trajectory: early={early_nz}, late={late_nz}" + # Late state should be at most a hair above K (in soft state; final returned is hard-projected). 
+ assert late_nz <= K + 1, f"final soft state too diffuse: {late_nz} non-zero (target K={K})" + + +def test_optimize_composition_max_elements_constant_schedule_no_anneal(): + """An ``annealing_func='constant'`` segment covering the full run holds τ; hard-projection still gives K.""" + torch.manual_seed(0) + model, kernel, _ = _build_aligned_model_and_kernel() + K = 4 + res = model.optimize_composition( + kernel, + task_targets={"prop": 1.0}, + n_starts=2, + max_elements=K, + annealing_scale=0.3, # initial scale; segment will hold this + annealing_schedule={"step": [1.0], "scale": [0.3], "annealing_func": ["constant"]}, + steps=40, + lr=0.2, + ) + w = res.optimized_weights + nz = (w > 1e-6).sum(dim=-1) + assert torch.all(nz == K) + + +def test_optimize_composition_annealing_scale_endpoints(): + """annealing_scale=0 and annealing_scale=1 both run cleanly and enforce K (the two endpoints + of the user-facing knob; calibration: 0→τ_start=1, 0.5→5, 1→25).""" + torch.manual_seed(0) + model, kernel, _ = _build_aligned_model_and_kernel() + K = 3 + for scale in (0.0, 1.0): + res = model.optimize_composition( + kernel, + task_targets={"prop": 1.0}, + n_starts=2, + max_elements=K, + annealing_scale=scale, + steps=30, + lr=0.2, + ) + nz = (res.optimized_weights > 1e-6).sum(dim=-1) + assert torch.all(nz <= K), f"scale={scale}: nz={nz.tolist()}" + + +def test_optimize_composition_annealing_schedule_dict_overrides_front(): + """A dict with step[-1] < 1.0 takes over the front; the tail falls back to default.""" + torch.manual_seed(0) + model, kernel, _ = _build_aligned_model_and_kernel() + K = 3 + # Use a two-segment dict that only covers the first 50% of steps. 
+ res = model.optimize_composition( + kernel, + task_targets={"prop": 1.0}, + n_starts=2, + max_elements=K, + annealing_scale=0.5, + annealing_schedule={ + "step": [0.2, 0.5], + "scale": [0.9, 0.7], + "annealing_func": ["linear", "cosine"], + }, + steps=60, + lr=0.2, + record_weights_trajectory=True, + ) + # Hard-projected final has at most K non-zero positions per row (the contract is "≤ K"). + nz = (res.optimized_weights > 1e-6).sum(dim=-1) + assert torch.all(nz <= K) + # Sanity: the trajectory was recorded — used by visualisation downstream. + assert res.weights_trajectory is not None and res.weights_trajectory.shape[0] == 60 + + +def test_optimize_composition_max_elements_gradient_flows_to_all_logits(): + """The soft top-K must let gradient flow back to logits at *all* positions, not just the chosen K. + + This is the qualitative difference vs. a post-hoc projection: at any τ > 0, all positions + carry non-trivial gradient — so the optimiser can re-select which K to include. + """ + torch.manual_seed(0) + model, kernel, _ = _build_aligned_model_and_kernel() + n = kernel.shape[0] + # Manually replicate one step at a moderate τ to peek at the gradient pattern. + logits = torch.zeros(1, n, requires_grad=True) + w_soft = torch.softmax(logits, dim=-1) + # Build the soft top-K mask inline (mirrors the production code). + K, tau = 3, 0.5 + alpha = logits.clone() + m = torch.zeros_like(logits) + for _ in range(K): + p = torch.softmax(alpha / tau, dim=-1) + m = m + p + alpha = alpha + torch.log((1.0 - p).clamp(min=1e-12)) + w = (w_soft * m) / (w_soft * m).sum(dim=-1, keepdim=True).clamp(min=1e-12) + # Loss against an arbitrary target → gradient should populate everywhere. + target = torch.zeros_like(w) + target[0, 5] = 1.0 + loss = ((w - target) ** 2).mean() + loss.backward() + assert logits.grad is not None + # All entries (not just K) should have non-trivial gradient.
+ abs_grad = logits.grad.abs() + n_nontrivial = int((abs_grad > 1e-8).sum().item()) + assert n_nontrivial == n, f"expected gradient at all {n} positions, got {n_nontrivial}" + + +def test_optimize_composition_fixed_amounts_pins_single_symbol(): + """fixed_amounts={'Au': 0.65} holds Au at exactly 0.65; the remaining 0.35 spreads to others.""" + torch.manual_seed(0) + model, kernel, elements = _build_aligned_model_and_kernel() + res = model.optimize_composition( + kernel, + task_targets={"prop": 1.0}, + fixed_amounts={"Au": 0.65}, + n_starts=3, + steps=40, + lr=0.2, + ) + w = res.optimized_weights + au = elements.index("Au") + # Au at exactly the pinned value across the batch (atol=1e-4 same as other lock tests). + assert torch.allclose(w[:, au], torch.full((3,), 0.65, dtype=w.dtype), atol=1e-4) + # Remaining mass is 1 - 0.65 = 0.35 across the non-Au columns. + rest_sum = w.sum(dim=-1) - w[:, au] + assert torch.allclose(rest_sum, torch.full((3,), 0.35, dtype=w.dtype), atol=1e-4) + + +def test_optimize_composition_fixed_amounts_multi_symbol(): + """Two pinned elements both hold at their assigned values; the rest sum to 1 - Σ fixed.""" + torch.manual_seed(0) + model, kernel, elements = _build_aligned_model_and_kernel() + res = model.optimize_composition( + kernel, + task_targets={"prop": 1.0}, + fixed_amounts={"Au": 0.65, "Ga": 0.20}, + n_starts=2, + steps=40, + lr=0.2, + ) + w = res.optimized_weights + au, ga = elements.index("Au"), elements.index("Ga") + assert torch.allclose(w[:, au], torch.full((2,), 0.65, dtype=w.dtype), atol=1e-4) + assert torch.allclose(w[:, ga], torch.full((2,), 0.20, dtype=w.dtype), atol=1e-4) + rest = w.sum(dim=-1) - w[:, au] - w[:, ga] + assert torch.allclose(rest, torch.full((2,), 0.15, dtype=w.dtype), atol=1e-4) + + +def test_optimize_composition_fixed_amounts_works_without_initial_weights(): + """fixed_amounts does not require initial_weights (unlike element_step_scale=0).""" + torch.manual_seed(0) + model, kernel, elements = 
_build_aligned_model_and_kernel() + # No ``initial_weights`` → uses n_starts random init. Should succeed. + res = model.optimize_composition( + kernel, + task_targets={"prop": 1.0}, + fixed_amounts={"Au": 0.5}, + n_starts=4, + steps=20, + lr=0.2, + ) + au = elements.index("Au") + assert torch.allclose( + res.optimized_weights[:, au], torch.full((4,), 0.5, dtype=res.optimized_weights.dtype), atol=1e-4 + ) + + +def test_optimize_composition_fixed_amounts_with_max_elements(): + """K=3 + 2 fixed → at most 3 non-zero per row, with the 2 fixed at their pinned values.""" + torch.manual_seed(0) + model, kernel, elements = _build_aligned_model_and_kernel() + K = 3 + res = model.optimize_composition( + kernel, + task_targets={"prop": 1.0}, + fixed_amounts={"Au": 0.65, "Ga": 0.20}, + max_elements=K, + n_starts=2, + steps=60, + lr=0.2, + ) + w = res.optimized_weights + nz = (w > 1e-6).sum(dim=-1) + assert torch.all(nz <= K), f"got nz={nz.tolist()}, expected ≤ {K}" + au, ga = elements.index("Au"), elements.index("Ga") + assert torch.allclose(w[:, au], torch.full((2,), 0.65, dtype=w.dtype), atol=1e-4) + assert torch.allclose(w[:, ga], torch.full((2,), 0.20, dtype=w.dtype), atol=1e-4) + + +def test_optimize_composition_fixed_amounts_with_allowed_elements(): + """A fixed element not in allowed_elements is contradictory and rejected.""" + model, kernel, _ = _build_aligned_model_and_kernel() + with pytest.raises(ValueError, match="not in allowed_elements"): + model.optimize_composition( + kernel, + task_targets={"prop": 0.0}, + fixed_amounts={"Au": 0.65}, + allowed_elements=["Mg", "Ga", "Cu"], # Au absent + n_starts=2, + steps=2, + ) + + +def test_optimize_composition_fixed_amounts_mutex_with_element_step_scale_zero(): + """The same symbol in both fixed_amounts and element_step_scale=0 is ambiguous → reject.""" + model, kernel, elements = _build_aligned_model_and_kernel() + init_w = torch.zeros(1, len(elements)) + init_w[0, elements.index("Au")] = 0.40 + init_w[0,
elements.index("Cu")] = 0.60 + with pytest.raises(ValueError, match="appear in both element_step_scale=0 and"): + model.optimize_composition( + kernel, + task_targets={"prop": 0.0}, + initial_weights=init_w, + fixed_amounts={"Au": 0.65}, + element_step_scale={"Au": 0.0}, + steps=2, + ) + + +def test_optimize_composition_fixed_amounts_validation(): + """Sum<1, value range, type, unknown symbols — every guard fires before model state is touched.""" + model, kernel, _ = _build_aligned_model_and_kernel() + # Sum >= 1.0 rejected (no leftover mass). + with pytest.raises(ValueError, match="must be strictly less than 1.0"): + model.optimize_composition( + kernel, + task_targets={"prop": 0.0}, + fixed_amounts={"Au": 0.7, "Ga": 0.4}, + n_starts=2, + steps=2, + ) + # Value out of (0, 1). + with pytest.raises(ValueError, match="must be strictly between 0 and 1"): + model.optimize_composition( + kernel, + task_targets={"prop": 0.0}, + fixed_amounts={"Au": 0.0}, + n_starts=2, + steps=2, + ) + with pytest.raises(ValueError, match="must be strictly between 0 and 1"): + model.optimize_composition( + kernel, + task_targets={"prop": 0.0}, + fixed_amounts={"Au": 1.0}, + n_starts=2, + steps=2, + ) + # Unknown symbol. + with pytest.raises(ValueError, match="Unknown element symbol"): + model.optimize_composition( + kernel, + task_targets={"prop": 0.0}, + fixed_amounts={"NotAnElement": 0.5}, + n_starts=2, + steps=2, + ) + # Empty mapping. + with pytest.raises(ValueError, match="must be non-empty"): + model.optimize_composition( + kernel, + task_targets={"prop": 0.0}, + fixed_amounts={}, + n_starts=2, + steps=2, + ) + # Wrong type. + with pytest.raises(TypeError, match="must be a mapping"): + model.optimize_composition( + kernel, + task_targets={"prop": 0.0}, + fixed_amounts=[("Au", 0.5)], # type: ignore[arg-type] + n_starts=2, + steps=2, + ) + # max_elements <= n_locked when fixed_amounts present. 
+ with pytest.raises(ValueError, match="must be > total locked elements"): + model.optimize_composition( + kernel, + task_targets={"prop": 0.0}, + fixed_amounts={"Au": 0.4, "Ga": 0.3}, + max_elements=2, + n_starts=2, + steps=2, + ) + + +def test_optimize_composition_max_elements_K_equals_one_one_hot(): + """K=1 produces a one-hot recipe (the smallest cardinality; exercises the n_iter=1 branch + of the iterative softmax). + """ + torch.manual_seed(0) + model, kernel, _ = _build_aligned_model_and_kernel() + res = model.optimize_composition( + kernel, + task_targets={"prop": 1.0}, + n_starts=3, + max_elements=1, + steps=40, + lr=0.2, + ) + w = res.optimized_weights + nz = (w > 1e-6).sum(dim=-1) + assert torch.all(nz == 1), f"expected one-hot per row, got nz={nz.tolist()}" + assert torch.allclose(w.sum(dim=-1), torch.ones(3, dtype=w.dtype), atol=1e-5) + # The single non-zero must be exactly 1.0. + max_per_row = w.max(dim=-1).values + assert torch.allclose(max_per_row, torch.ones(3, dtype=w.dtype), atol=1e-5) + + +def test_optimize_composition_combined_locks_exceeding_simplex_rejected(): + """element_step_scale=0 (locking seed-heavy elements) + fixed_amounts together cannot + claim more than 100% of the simplex — the combined-lock runtime check catches this.""" + torch.manual_seed(0) + model, kernel, elements = _build_aligned_model_and_kernel() + # Seed with Mg=0.50 (locked) + Cu=0.50 (free); fixed_amounts={"Au": 0.65}. Combined locked + # mass would be 0.50 (Mg) + 0.65 (Au) = 1.15 > 1.0. 
+ init_w = torch.zeros(1, len(elements)) + init_w[0, elements.index("Mg")] = 0.50 + init_w[0, elements.index("Cu")] = 0.50 + with pytest.raises(ValueError, match="Combined locked mass exceeds 1.0"): + model.optimize_composition( + kernel, + task_targets={"prop": 0.0}, + initial_weights=init_w, + element_step_scale={"Mg": 0.0}, + fixed_amounts={"Au": 0.65}, + steps=2, + ) + + +def test_optimize_composition_max_elements_equals_n_locked_rejected(): + """``max_elements == n_locked`` is rejected (no unlocked slot to absorb leftover mass). + + Previously this combination passed validation but silently produced rows with sum < 1 + when the locked seed values summed to less than 1 (e.g. Mg=Al=Cu=0.20, max_elements=3). + The validation now enforces strict ``max_elements > n_locked`` for both lock paths. + """ + model, kernel, elements = _build_aligned_model_and_kernel() + init_w = torch.zeros(1, len(elements)) + init_w[0, elements.index("Mg")] = 0.30 + init_w[0, elements.index("Al")] = 0.30 + init_w[0, elements.index("Cu")] = 0.40 + with pytest.raises(ValueError, match="must be > total locked elements"): + model.optimize_composition( + kernel, + task_targets={"prop": 0.0}, + initial_weights=init_w, + element_step_scale={"Mg": 0.0, "Al": 0.0, "Cu": 0.0}, + max_elements=3, + steps=2, + ) + + +def test_optimize_composition_min_nonzero_weight_drops_traces(): + """A floor of 0.1 makes every non-zero unlocked element ≥ 0.1 (no trace amounts).""" + torch.manual_seed(0) + model, kernel, _ = _build_aligned_model_and_kernel() + res = model.optimize_composition( + kernel, + task_targets={"prop": 1.0}, + n_starts=4, + max_elements=5, + min_nonzero_weight=0.1, + steps=80, + lr=0.2, + ) + w = res.optimized_weights + # Every non-zero weight is at least the floor (within float tolerance). + nz_mask = w > 0 + assert (w[nz_mask] >= 0.1 - 1e-5).all(), f"floor violated: smallest non-zero = {w[nz_mask].min().item():.4f}" + # Rows still sum to 1. 
+ assert torch.allclose(w.sum(dim=-1), torch.ones(4, dtype=w.dtype), atol=1e-5) + + +def test_optimize_composition_min_nonzero_weight_zero_is_noop(): + """min_nonzero_weight=0.0 (default) produces identical output to omitting the kwarg.""" + torch.manual_seed(0) + model = _make_reg_clf_model() + kernel = torch.randn(6, INPUT_DIM) + init = torch.full((2, 6), 1.0 / 6) + base = model.optimize_composition(kernel, task_targets={"prop": 1.0}, initial_weights=init, steps=20, lr=0.1) + torch.manual_seed(0) + floored = model.optimize_composition( + kernel, + task_targets={"prop": 1.0}, + initial_weights=init, + min_nonzero_weight=0.0, + steps=20, + lr=0.1, + ) + assert torch.allclose(base.optimized_weights, floored.optimized_weights, atol=1e-6) + + +def test_optimize_composition_min_nonzero_weight_with_fixed_amounts_below_floor(): + """A floor above even one fixed amount is contradictory — caught at kwarg time.""" + model, kernel, _ = _build_aligned_model_and_kernel() + with pytest.raises(ValueError, match="below min_nonzero_weight"): + model.optimize_composition( + kernel, + task_targets={"prop": 0.0}, + fixed_amounts={"Au": 0.05, "Ga": 0.20}, + min_nonzero_weight=0.10, + n_starts=2, + steps=2, + ) + + +def test_optimize_composition_min_nonzero_weight_with_fixed_amounts_compatible(): + """A floor ≤ all fixed amounts works; fixed Au stays at 0.30 with floor=0.10.""" + torch.manual_seed(0) + model, kernel, elements = _build_aligned_model_and_kernel() + res = model.optimize_composition( + kernel, + task_targets={"prop": 1.0}, + fixed_amounts={"Au": 0.30, "Ga": 0.20}, + min_nonzero_weight=0.10, + max_elements=4, + n_starts=2, + steps=40, + lr=0.2, + ) + w = res.optimized_weights + au = elements.index("Au") + assert torch.allclose(w[:, au], torch.full((2,), 0.30, dtype=w.dtype), atol=1e-4) + nz_mask = w > 0 + assert (w[nz_mask] >= 0.10 - 1e-5).all() + + +def test_optimize_composition_min_nonzero_weight_with_locked_seed_below_floor(): + """An element_step_scale=0 lock 
pinning at a seed value below the floor is rejected.""" + model, kernel, elements = _build_aligned_model_and_kernel() + init_w = torch.zeros(1, len(elements)) + init_w[0, elements.index("Mg")] = 0.05 # locked below floor + init_w[0, elements.index("Cu")] = 0.95 + with pytest.raises(ValueError, match="locked element.*below min_nonzero_weight"): + model.optimize_composition( + kernel, + task_targets={"prop": 0.0}, + initial_weights=init_w, + element_step_scale={"Mg": 0.0}, + min_nonzero_weight=0.10, + steps=2, + ) + + +def test_optimize_composition_min_nonzero_weight_validation(): + """floor out of [0,1], floor > 1/max_elements — all rejected pre-state.""" + model, kernel, _ = _build_aligned_model_and_kernel() + with pytest.raises(ValueError, match=r"min_nonzero_weight must be in \[0, 1\]"): + model.optimize_composition( + kernel, + task_targets={"prop": 0.0}, + min_nonzero_weight=-0.1, + n_starts=2, + steps=2, + ) + with pytest.raises(ValueError, match=r"min_nonzero_weight must be in \[0, 1\]"): + model.optimize_composition( + kernel, + task_targets={"prop": 0.0}, + min_nonzero_weight=1.5, + n_starts=2, + steps=2, + ) + # floor > 1/K: K=3 → 1/K=0.333; floor=0.5 > 0.333 → reject. + with pytest.raises(ValueError, match="exceeds 1 / max_elements"): + model.optimize_composition( + kernel, + task_targets={"prop": 0.0}, + min_nonzero_weight=0.5, + max_elements=3, + n_starts=2, + steps=2, + ) + + +def test_optimize_composition_min_nonzero_weight_fallback_preserves_simplex(): + """When the floor would empty a row's unlocked mass, the row falls back to unfloored + (preserves sum=1 instead of breaking the simplex).""" + torch.manual_seed(0) + model, kernel, _ = _build_aligned_model_and_kernel() + # Use a high floor + small K so dropping below-floor positions could empty unlocked. 
+ res = model.optimize_composition( + kernel, + task_targets={"prop": 1.0}, + n_starts=4, + max_elements=3, + min_nonzero_weight=0.33, # at the edge: needs all 3 ≥ 0.33 → exactly 1/3 each + steps=40, + lr=0.2, + ) + w = res.optimized_weights + # Whatever the floor did, the simplex must be preserved. + assert torch.allclose(w.sum(dim=-1), torch.ones(4, dtype=w.dtype), atol=1e-5) + assert (w >= 0).all() + + def test_optimize_latent_space_with_ae(): model = _make_model() model.eval() diff --git a/src/foundation_model/scripts/paper_inverse_comparison.py b/src/foundation_model/scripts/paper_inverse_comparison.py index 97d0226..2ab5c2b 100644 --- a/src/foundation_model/scripts/paper_inverse_comparison.py +++ b/src/foundation_model/scripts/paper_inverse_comparison.py @@ -158,6 +158,47 @@ "diversity": 0.0, }, {"label": "comp\n(random)", "init": "random", "blend": 0.95, "allowed": "all", "scale": 1.0, "diversity": 1.0}, + # max_elements (cardinality constraint via differentiable iterative-softmax + annealing). + # Builds on the "element list" + seed + 5%-all configuration — the most realistic baseline — + # and adds K=3 / K=5 hard caps. Uses the calibrated default annealing (geometric, τ_start=5); + # the K=5 row also includes a linear-schedule aggressive variant to demonstrate the + # explore-then-commit behaviour that occasionally beats the unconstrained baseline. + { + "label": "comp\n(seed, 5% all,\nelement list, K=3)", + "init": "seed", + "blend": 0.95, + "allowed": DEFAULT_ALLOY_PALETTE, + "scale": 1.0, + "diversity": 1.0, + "max_elements": 3, + }, + { + "label": "comp\n(seed, 5% all,\nelement list, K=5)", + "init": "seed", + "blend": 0.95, + "allowed": DEFAULT_ALLOY_PALETTE, + "scale": 1.0, + "diversity": 1.0, + "max_elements": 5, + }, + { + # Aggressive variant: linear schedule + high annealing_scale (≈ old τ_start=10). 
+ # ``25 ** 0.715 ≈ 10`` — same starting τ as the previous "linear τ=10" but expressed + # in the unified scale knob, and the linear interpolation now lives in the dict. + "label": "comp\n(seed, 5% all,\nelement list, K=5, linear scale=0.72)", + "init": "seed", + "blend": 0.95, + "allowed": DEFAULT_ALLOY_PALETTE, + "scale": 1.0, + "diversity": 1.0, + "max_elements": 5, + "annealing_scale": 0.715, + "annealing_schedule": { + "step": [1.0], + "scale": [0.0], # decay all the way to τ ≈ 1.0 by the end + "annealing_func": ["linear"], + }, + }, ] LATENT_ALIGN_SCALES = [0.0, 0.25, 1.0] # ae_align_scale ∈ [0, 1] — three points: failure / mid / max @@ -968,9 +1009,7 @@ def _emit_trajectory_outputs( # the separate ``comparison.png``. reg_names = list(reg_targets) # Mean reg trajectory across seeds (per step → per task). - reg_traj_dict: dict[str, np.ndarray] = { - t: traj_targets[:, :, j] for j, t in enumerate(reg_names) - } + reg_traj_dict: dict[str, np.ndarray] = {t: traj_targets[:, :, j] for j, t in enumerate(reg_names)} # Mean variant: use QC after-decode (final value) as a flat baseline-vs-target progress # line only if it's available; otherwise drop QC. For the inverse-design case QC is in # results dict but not per-step; we synthesise a "flat" QC progress line from the final @@ -1082,6 +1121,15 @@ def _run_composition_config( else: raise ValueError(f"Unknown init mode in config: {cfg['init']!r}") + # Optional cardinality knobs — fall through to ``optimize_composition`` defaults when absent + # so the original 5 comp rows behave identically to before this change. 
+ topk_kwargs: dict[str, Any] = {} + if "max_elements" in cfg: + topk_kwargs["max_elements"] = cfg["max_elements"] + for k in ("annealing_scale", "annealing_schedule"): + if k in cfg: + topk_kwargs[k] = cfg[k] + t0 = time.perf_counter() res = model.optimize_composition( kernel, @@ -1094,6 +1142,7 @@ def _run_composition_config( steps=steps, lr=lr, record_weights_trajectory=record_trajectory, + **topk_kwargs, **init_kwargs, ) elapsed = time.perf_counter() - t0