diff --git a/docs/index.rst b/docs/index.rst
index 0f92a202..41fc73d0 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -40,6 +40,7 @@
examples_guide
_collections/examples/dp_sgd_flax_linen_mnist
_collections/examples/dp_sgd_keras_gemma3_lora_finetuning_samsum
+ _collections/examples/dp_sgd_keras_gemma3_dpsapf
_collections/examples/dp_sgd_keras_gemma3_synthetic_data
.. toctree::
diff --git a/examples/dp_sgd_keras_gemma3_dpsapf.ipynb b/examples/dp_sgd_keras_gemma3_dpsapf.ipynb
new file mode 100644
index 00000000..f43644d1
--- /dev/null
+++ b/examples/dp_sgd_keras_gemma3_dpsapf.ipynb
@@ -0,0 +1,454 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# https://www.apache.org/licenses/LICENSE-2.0"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": "# Auto-LoRA + DP-SGD fine-tuning of Gemma3\n\n**Copyright 2026 DeepMind Technologies Limited.**\n\nThis notebook adapts **DP-SAPF** (*Saliency-Aware Parameter Fine-tuning of Public Models for Differentially Private Image Synthesis*; [arXiv:2605.30312](https://arxiv.org/abs/2605.30312)) to LLM fine-tuning. The core idea of DP-SAPF: under DP-SGD, Gaussian noise of the same scale is added to every trainable coordinate, so the total noise norm grows with the number of trainable parameters, and adapting **every** layer wastes privacy budget on uninformative directions. Instead, first run a **DP-aware saliency probe** to identify which layers carry the most useful gradient signal for the task, then fine-tune only that small subset, which gives a better signal-to-noise ratio under the same (eps, delta) budget.\n\nConcretely, this notebook extends [`dp_sgd_keras_gemma3_lora_finetuning_samsum.ipynb`](dp_sgd_keras_gemma3_lora_finetuning_samsum.ipynb) with two new pieces:\n\n1. **Probe (DP)**: for each training sample, vote +1 on the top-k LoRA-candidate layers ranked by per-sample gradient L2 norm; add Gaussian noise to the vote histogram (L2 sensitivity = $\\sqrt k$).\n2. **DP-SGD fine-tune**: enable LoRA only on the top `LORA_TOP_K_PERCENT`% of candidate layers by noisy score; calibrate the training noise so the **composed** (probe + DP-SGD) cost matches a single target $\\varepsilon$.\n\nSelection-specific code (the probe, top-k voting, LoRA gating, DP composition) and the small dataset loaders live in [`dpsapf_utils.py`](https://github.com/google-deepmind/jax_privacy/tree/main/examples/dpsapf_utils.py); everything else (shuffling and batching, model setup, optimizer, fit, eval) stays inline so the moving parts are visible."
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Install and configure"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "!pip install -q -U \"keras>=3\" keras-nlp keras-hub rouge-score tqdm ipywidgets\n",
+ "!pip install -q dp_accounting jaxtyping drjax tensorflow_datasets datasets\n",
+ "!pip install -q beautifulsoup4 lxml jax_privacy==1.0.0"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "# Redirect big artifacts to a shared scratch dir; set env vars BEFORE\n",
+ "# importing keras / kagglehub / datasets.\n",
+ "CACHE_ROOT = \"./\" # change to a writable path\n",
+ "os.makedirs(CACHE_ROOT, exist_ok=True)\n",
+ "for k, sub in [\n",
+ " (\"KERAS_HOME\", \"keras\"),\n",
+ " (\"KAGGLEHUB_CACHE\", \"kagglehub\"),\n",
+ " (\"TFDS_DATA_DIR\", \"tfds\"),\n",
+ " (\"HF_HOME\", \"huggingface\"),\n",
+ " (\"JAX_COMPILATION_CACHE_DIR\", \"jax_compilation_cache\"),\n",
+ "]:\n",
+ " os.environ.setdefault(k, os.path.join(CACHE_ROOT, sub))\n",
+ " os.makedirs(os.environ[k], exist_ok=True)\n",
+ "\n",
+ "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n",
+ "os.environ.setdefault(\"XLA_PYTHON_CLIENT_PREALLOCATE\", \"false\")\n",
+ "os.environ.setdefault(\"XLA_PYTHON_CLIENT_MEM_FRACTION\", \"0.85\")\n",
+ "os.environ.setdefault(\"TF_CPP_MIN_LOG_LEVEL\", \"3\")\n",
+ "\n",
+ "import gc\n",
+ "import keras\n",
+ "import keras_hub\n",
+ "import tensorflow as tf\n",
+ "import tqdm\n",
+ "import jax\n",
+ "from jax_privacy import keras_api\n",
+ "\n",
+ "# Selection-only helpers — everything else is inline.\n",
+ "import examples.dpsapf_utils as ALU"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Kaggle login\n",
+ "\n",
+ "Needed to download Gemma3 weights. You must **accept the Gemma3 license** in your Kaggle account for each variant ([1B](https://www.kaggle.com/models/keras/gemma3/keras/gemma3_instruct_1b), [4B](https://www.kaggle.com/models/keras/gemma3/keras/gemma3_instruct_4b_text)) or the download returns 403."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import kagglehub\n",
+ "\n",
+ "kagglehub.login()\n",
+ "# Or: export KAGGLE_USERNAME=... KAGGLE_KEY=... before launching the notebook."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Hyper-parameters\n",
+ "\n",
+ "Two datasets: `cnn_dailymail` (TFDS) and `xsum_hf` (HuggingFace). Both ~200-290K examples — well-sized for DP subsampling amplification."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "TEST_RUN = True\n",
+ "DATASET = \"cnn_dailymail\" # or \"xsum_hf\"\n",
+ "\n",
+ "# Model / training\n",
+ "GEMMA3_MODEL_TYPE = \"gemma3_instruct_4b_text\"\n",
+ "SEQUENCE_LENGTH = 1024\n",
+ "TEST_DS_SEQUENCE_LENGTH = 1024\n",
+ "EPOCHS = 3\n",
+ "BATCH_SIZE = 4\n",
+ "GRADIENT_ACCUMULATION_STEPS = 128 # effective batch = 4 * 128 = 512\n",
+ "TEST_DS_BATCH_SIZE = 4\n",
+ "LORA_RANK = 64\n",
+ "LEARNING_RATE = 3e-3\n",
+ "SEED = 0\n",
+ "USE_MIXED_PRECISION = False\n",
+ "\n",
+ "# Probe + selection\n",
+ "PROBE_SAMPLES = 10000\n",
+ "PROBE_TOPK = 8\n",
+ "PROBE_NOISE_MULTIPLIER = 40.0 # in units of L2 sensitivity sqrt(topk)\n",
+ "LORA_TOP_K_PERCENT = 10.0 # % of candidate layers kept; 1-5% works best on CNN/DM at sequence length 1024 with DP (the Results table uses 5.0)\n",
+ "LORA_ATTN_ONLY = True\n",
+ "\n",
+ "# Composed DP target\n",
+ "TOTAL_EPSILON = 4.0\n",
+ "DELTA = 2e-5\n",
+ "CLIPPING_NORM = 0.001\n",
+ "\n",
+ "if TEST_RUN:\n",
+ " GEMMA3_MODEL_TYPE = \"gemma3_instruct_1b\"\n",
+ " PROBE_SAMPLES = 10000\n",
+ "\n",
+ "DATASET_CFG = ALU.DATASET_REGISTRY[DATASET]\n",
+ "print(f\"Dataset: {DATASET} ({DATASET_CFG['note']})\")\n",
+ "print(f\"Model: {GEMMA3_MODEL_TYPE} (test_run={TEST_RUN})\")\n",
+ "print(f\"Effective train batch: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": "## Data\n\nLoad + format using helpers from `dpsapf_utils` (registry + HF/TFDS dispatch + prompt formatter), then shuffle/batch inline."
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fmt = ALU.make_source_to_gemma3_format(DATASET_CFG)\n",
+ "TRAIN_DS = ALU.load_dataset_split(DATASET_CFG, \"train\").map(fmt)\n",
+ "VALIDATION_DS = ALU.load_dataset_split(DATASET_CFG, \"validation\").map(fmt)\n",
+ "TEST_DS = ALU.load_dataset_split(DATASET_CFG, \"test\").map(fmt)\n",
+ "\n",
+ "TRAIN_SIZE = int(TRAIN_DS.cardinality().numpy())\n",
+ "print(\n",
+ " f\"Train: {TRAIN_SIZE} Val: {int(VALIDATION_DS.cardinality().numpy())} \"\n",
+ " f\"Test: {int(TEST_DS.cardinality().numpy())}\"\n",
+ ")\n",
+ "\n",
+ "TRAIN_DS_BATCHED = TRAIN_DS.shuffle(2048).batch(BATCH_SIZE, drop_remainder=True)\n",
+ "VALIDATION_DS_BATCHED = VALIDATION_DS.batch(BATCH_SIZE, drop_remainder=True)\n",
+ "TEST_DS_BATCHED = TEST_DS.batch(TEST_DS_BATCH_SIZE)\n",
+ "\n",
+ "EXAMPLE = (\n",
+ " VALIDATION_DS.take(1)\n",
+ " .batch(1, drop_remainder=True)\n",
+ " .as_numpy_iterator()\n",
+ " .next()\n",
+ ")\n",
+ "for k, v in EXAMPLE.items():\n",
+ " s = v[0].decode(\"utf-8\")\n",
+ " print(f'{k}: \"{s[:200]}{\"...\" if len(s) > 200 else \"\"}\"\\n')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Load Gemma3 (no LoRA yet) and run baseline inference"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "MODEL_WEIGHTS_DTYPE = None\n",
+ "if USE_MIXED_PRECISION:\n",
+ " keras.mixed_precision.set_global_policy(\"mixed_bfloat16\")\n",
+ " MODEL_WEIGHTS_DTYPE = \"bfloat16\"\n",
+ "\n",
+ "keras.distribution.set_distribution(keras.distribution.DataParallel())\n",
+ "\n",
+ "gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset(\n",
+ " GEMMA3_MODEL_TYPE, dtype=MODEL_WEIGHTS_DTYPE\n",
+ ")\n",
+ "gemma_lm.preprocessor.sequence_length = SEQUENCE_LENGTH\n",
+ "\n",
+ "\n",
+ "def show_inference():\n",
+ " print(gemma_lm.generate(EXAMPLE[\"prompts\"])[0])\n",
+ " print(f\"\\nGold:\\n{EXAMPLE['responses'][0].decode('utf-8')}\")\n",
+ "\n",
+ "\n",
+ "print(\"=== Baseline (before fine-tuning) ===\")\n",
+ "show_inference()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## DP probe + layer selection\n",
+ "\n",
+ "Per sample, vote +1 on the top-k LoRA-candidate layers ranked by per-sample gradient L2 norm. Gaussian noise on the vote histogram; L2 sensitivity = $\\sqrt k$ under ADD/REMOVE neighboring. `ALU.run_probe` returns the noisy scores, the top-k% selected paths, and DP metadata used downstream for composition."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "probe = ALU.run_probe(\n",
+ " gemma_lm,\n",
+ " TRAIN_DS_BATCHED,\n",
+ " top_k_percent=LORA_TOP_K_PERCENT,\n",
+ " num_probe_samples=PROBE_SAMPLES,\n",
+ " topk=PROBE_TOPK,\n",
+ " noise_multiplier=PROBE_NOISE_MULTIPLIER,\n",
+ " attn_only=LORA_ATTN_ONLY,\n",
+ " seed=SEED,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Enable LoRA on selected layers, set up composed DP-SGD\n",
+ "\n",
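+ "Both the probe and DP-SGD are accounted for as Poisson-subsampled Gaussian mechanisms: the probe as a single step with sampling rate `probe.n_seen / TRAIN_SIZE`, and DP-SGD as repeated steps with sampling rate `EFFECTIVE_BATCH_SIZE / TRAIN_SIZE`. We solve for the smallest $\\sigma_{\\text{train}}$ whose Rényi-DP composed cost stays within `TOTAL_EPSILON`."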
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ALU.enable_lora_on_paths(gemma_lm.backbone, probe.selected_paths, LORA_RANK)\n",
+ "print(f\"LoRA enabled on {len(probe.selected_paths)} layers (rank={LORA_RANK}).\")\n",
+ "\n",
+ "STEPS_PER_EPOCH = TRAIN_SIZE // BATCH_SIZE\n",
+ "TOTAL_TRAIN_STEPS = EPOCHS * STEPS_PER_EPOCH\n",
+ "EFFECTIVE_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS\n",
+ "\n",
+ "SIGMA_TRAIN = ALU.calibrate_train_sigma(\n",
+ " target_eps=TOTAL_EPSILON,\n",
+ " probe_sigma=probe.noise_multiplier,\n",
+ " num_probe_samples=probe.n_seen,\n",
+ " effective_batch_size=EFFECTIVE_BATCH_SIZE,\n",
+ " train_steps=TOTAL_TRAIN_STEPS // GRADIENT_ACCUMULATION_STEPS, # optimizer updates: one noisy step per accumulated batch\n",
+ " train_size=TRAIN_SIZE,\n",
+ " delta=DELTA,\n",
+ ")\n",
+ "composed_eps = ALU.compose_probe_and_training(\n",
+ " probe_sigma=probe.noise_multiplier,\n",
+ " num_probe_samples=probe.n_seen,\n",
+ " train_sigma=SIGMA_TRAIN,\n",
+ " effective_batch_size=EFFECTIVE_BATCH_SIZE,\n",
+ " train_steps=TOTAL_TRAIN_STEPS // GRADIENT_ACCUMULATION_STEPS, # must match the calibration call above\n",
+ " train_size=TRAIN_SIZE,\n",
+ " delta=DELTA,\n",
+ ")\n",
+ "print(\n",
+ " f\"target total_eps={TOTAL_EPSILON} -> sigma_train={SIGMA_TRAIN:.4f}; \"\n",
+ " f\"composed (eps, delta) ~= ({composed_eps:.4f}, {DELTA:.1e})\"\n",
+ ")\n",
+ "\n",
+ "dp_cfg = keras_api.DPKerasConfig(\n",
+ " epsilon=TOTAL_EPSILON,\n",
+ " delta=DELTA,\n",
+ " noise_multiplier=SIGMA_TRAIN,\n",
+ " clipping_norm=CLIPPING_NORM,\n",
+ " batch_size=BATCH_SIZE,\n",
+ " train_steps=TOTAL_TRAIN_STEPS,\n",
+ " train_size=TRAIN_SIZE,\n",
+ " gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,\n",
+ " seed=SEED,\n",
+ ")\n",
+ "gemma_lm = keras_api.make_private(gemma_lm, dp_cfg)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Compile and fine-tune\n",
+ "\n",
+ "`fit()` may be called only once per `make_private` — the declared (eps, delta) budget rules out further training."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "optimizer = keras.optimizers.Adam(\n",
+ " learning_rate=LEARNING_RATE,\n",
+ " gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,\n",
+ ")\n",
+ "optimizer.exclude_from_weight_decay(var_names=[\"bias\", \"scale\"])\n",
+ "gemma_lm.compile(\n",
+ " loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+ " optimizer=optimizer,\n",
+ " weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],\n",
+ ")\n",
+ "\n",
+ "gemma_lm.fit(\n",
+ " x=TRAIN_DS_BATCHED,\n",
+ " epochs=EPOCHS,\n",
+ " validation_data=VALIDATION_DS_BATCHED,\n",
+ ")\n",
+ "\n",
+ "print(\"\\n=== After fine-tuning ===\")\n",
+ "show_inference()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Evaluate ROUGE on the test split"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "gemma_lm.preprocessor.sequence_length = TEST_DS_SEQUENCE_LENGTH\n",
+ "\n",
+ "METRIC_FNS = {\n",
+ " \"rouge_1\": keras_hub.metrics.RougeN(order=1),\n",
+ " \"rouge_2\": keras_hub.metrics.RougeN(order=2),\n",
+ " \"rouge_l\": keras_hub.metrics.RougeL(),\n",
+ "}\n",
+ "\n",
+ "\n",
+ "def _common_prefix(a, b):\n",
+ " i = 0\n",
+ " while i < len(a) and i < len(b) and a[i] == b[i]:\n",
+ " i += 1\n",
+ " return i\n",
+ "\n",
+ "\n",
+ "for batch in tqdm.tqdm(TEST_DS_BATCHED):\n",
+ " prompts = [p.decode(\"utf-8\") for p in batch[\"prompts\"].numpy()]\n",
+ " outputs = gemma_lm.generate(prompts)\n",
+ " outputs = [o[_common_prefix(p, o) :] for p, o in zip(prompts, outputs)]\n",
+ " targets = [s.decode(\"utf-8\") for s in batch[\"responses\"].numpy()]\n",
+ " for m in METRIC_FNS.values():\n",
+ " m.update_state(targets, outputs)\n",
+ "\n",
+ "RESULT = {k: float(m.result()[\"f1_score\"]) for k, m in METRIC_FNS.items()}\n",
+ "print(RESULT)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Results\n",
+ "\n",
+ "ROUGE F1 on the test split, Gemma3-1B + LoRA (rank 64) + DP-SGD at $\\varepsilon=4$, sequence length 1024. **Baseline** is `keras_hub`'s default `enable_lora()` (query + value only); **100%** enables LoRA on every attention projection (query/key/value/attention_output) in every block; **top 5%** keeps only the top-5% attention layers ranked by the DP probe.\n",
+ "\n",
+ "| Dataset | Method | ROUGE-1 | ROUGE-2 | ROUGE-L |\n",
+ "|---|---|---|---|---|\n",
+ "| CNN/DailyMail | baseline (query + value, keras_hub default) | 0.217 | 0.084 | 0.157 |\n",
+ "| CNN/DailyMail | 100% attn projections | 0.154 | 0.046 | 0.114 |\n",
+ "| CNN/DailyMail | **top 5% attn (probe)** | **0.237** | **0.096** | **0.168** |\n",
+ "| XSum | baseline (query + value, keras_hub default) | 0.247 | 0.064 | 0.192 |\n",
+ "| XSum | 100% attn projections | 0.205 | 0.046 | 0.158 |\n",
+ "| XSum | **top 5% attn (probe)** | **0.264** | **0.073** | **0.203** |\n",
+ "\n",
+ "Selecting the top 5% attention layers via the DP probe beats the `keras_hub` default on both datasets (ROUGE-1 0.237 vs 0.217 on CNN/DailyMail, 0.264 vs 0.247 on XSum) and clearly beats enabling LoRA on every attention projection (ROUGE-1 0.154 and 0.205): with fewer trainable parameters, the DP noise is spread over fewer directions. The effect is consistent across two different summarization styles (multi-sentence highlights vs single-sentence summaries)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Citation\n",
+ "\n",
+ "If you find this notebook useful in your research, please cite:\n",
+ "\n",
+ "```bibtex\n",
+ "@article{gong2026dpsapf,\n",
+ " title = {DP-SAPF: Saliency-Aware Parameter Fine-tuning of Public Models for Differentially Private Image Synthesis},\n",
+ " author = {Chen Gong and Kecen Li and Zinan Lin and Tianhao Wang},\n",
+ " journal = {arXiv preprint arXiv:2605.30312},\n",
+ " year = {2026},\n",
+ " url = {https://arxiv.org/abs/2605.30312}\n",
+ "}\n",
+ "```"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
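The probe step described in the notebook (top-k voting followed by Gaussian noise on the vote histogram) can be illustrated with a small NumPy sketch. This is not the code in `dpsapf_utils.py`, which computes per-sample gradients with JAX; the function name `noisy_topk_votes` and the toy gradient-norm matrix are assumptions made for illustration only. It shows why the L2 sensitivity is sqrt(k) under add/remove neighboring: one sample contributes exactly k unit votes.

```python
import numpy as np


def noisy_topk_votes(grad_norms, k, noise_multiplier, seed=0):
    """Hypothetical helper: top-k vote histogram plus Gaussian noise.

    grad_norms has shape (num_samples, num_layers) and holds per-sample,
    per-layer gradient L2 norms (assumed precomputed for this sketch).
    """
    num_layers = grad_norms.shape[1]
    # Indices of the k largest norms per sample; order inside the top-k
    # does not matter because every selected layer gets the same +1 vote.
    top_idx = np.argpartition(-grad_norms, k - 1, axis=1)[:, :k]
    votes = np.zeros(num_layers, dtype=np.float64)
    np.add.at(votes, top_idx.ravel(), 1.0)
    # Adding or removing one sample changes exactly k entries by 1,
    # so the L2 sensitivity of the histogram is sqrt(k).
    l2_sensitivity = np.sqrt(k)
    rng = np.random.default_rng(seed)
    noise = rng.normal(0.0, noise_multiplier * l2_sensitivity, size=num_layers)
    return votes, votes + noise


# Toy input: 3 samples, 4 candidate layers.
norms = np.array([[5.0, 1.0, 3.0, 0.0],
                  [0.0, 2.0, 4.0, 1.0],
                  [9.0, 8.0, 0.0, 1.0]])
votes, noisy = noisy_topk_votes(norms, k=2, noise_multiplier=0.0)
votes_without_last, _ = noisy_topk_votes(norms[:-1], k=2, noise_multiplier=0.0)
```

Comparing `votes` with `votes_without_last` shows the histogram moving by exactly sqrt(2) in L2 norm when one sample is removed, which is the quantity the probe noise is scaled against.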
diff --git a/examples/dpsapf_utils.py b/examples/dpsapf_utils.py
new file mode 100644
index 00000000..46d58a99
--- /dev/null
+++ b/examples/dpsapf_utils.py
@@ -0,0 +1,395 @@
+# Copyright 2026 DeepMind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Selection-only helpers for the auto-LoRA + DP-SGD notebook.
+
+Scope: dataset loading/formatting, the DP probe, selective LoRA enabling, and
+the composed (probe + DP-SGD) accounting. Batching, model setup, training, and
+ROUGE eval stay in the notebook so the moving parts are visible.
+
+Public API:
+ DATASET_REGISTRY # cnn_dailymail, xsum_hf
+ make_source_to_gemma3_format(cfg) # tf.data.map prompt/response fn
+ load_dataset_split(cfg, split_spec) # TFDS vs HuggingFace dispatcher
+
+ get_lora_candidate_layers(backbone, ...) # enumerate LoRA-eligible layers
+ run_probe(gemma_lm, train_ds, top_k_percent, **kw) -> ProbeResult
+ enable_lora_on_paths(backbone, paths, rank)
+
+ compose_probe_and_training(...) # RDP composition (probe + DP-SGD)
+ calibrate_train_sigma(target_eps, **kw) # binary search for sigma_train
+"""
+
+import dataclasses
+import functools
+import gc
+
+import dp_accounting
+import jax
+import jax.numpy as jnp
+import keras
+import tensorflow as tf
+import tensorflow_datasets as tfds
+
+from jax_privacy.accounting import accountants
+
+# HuggingFace `datasets` is only needed for the xsum_hf entry of
+# DATASET_REGISTRY; keep it optional so TFDS-only users don't need it.
+try:
+ from datasets import load_dataset # pytype: disable=import-error
+except ImportError:
+ load_dataset = None
+
+
+# ---------------------------------------------------------------------------
+# Dataset registry + minimal data helpers (needed to materialise the probe input
+# in the right prompt/response format).
+# ---------------------------------------------------------------------------
+
+DATASET_REGISTRY = {
+ "cnn_dailymail": {
+ "loader": "tfds",
+ "tfds_name": "cnn_dailymail",
+ "input_field": "article",
+ "output_field": "highlights",
+ "prompt_prefix": "Summarize the following news article:\n",
+ "prompt_suffix": "\nHighlights:\n",
+ "note": (
+ "~287K news articles -> multi-sentence highlights. "
+ "Long inputs. Requires `pip install beautifulsoup4 lxml`."
+ ),
+ },
+ "xsum_hf": {
+ "loader": "hf",
+ "hf_name": "EdinburghNLP/xsum",
+ "input_field": "document",
+ "output_field": "summary",
+ "prompt_prefix": "Summarize the following article in one sentence:\n",
+ "prompt_suffix": "\nSummary:\n",
+ "note": "~204K BBC articles -> 1-sentence summaries via HuggingFace.",
+ },
+}
+
+
+def make_source_to_gemma3_format(cfg):
+ """Returns a tf.data.map fn that emits {prompts, responses} string dicts."""
+ in_field, out_field = cfg["input_field"], cfg["output_field"]
+ prefix, suffix = cfg["prompt_prefix"], cfg["prompt_suffix"]
+
+ def fn(d):
+ return {
+ "prompts": tf.strings.join([prefix, d[in_field], suffix]),
+ "responses": d[out_field],
+ }
+
+ return fn
+
+
+def load_dataset_split(cfg, split_spec):
+ """Loads one split, dispatching by cfg['loader']. Returns a
+ tf.data.Dataset with known cardinality (so train_size is meaningful)."""
+ if cfg.get("loader") == "hf":
+ if load_dataset is None:
+ raise ImportError(
+ "The HuggingFace `datasets` package is required for the xsum_hf "
+ "dataset. Install it with `pip install datasets`."
+ )
+ in_field, out_field = cfg["input_field"], cfg["output_field"]
+ hf_ds = load_dataset(cfg["hf_name"], split=split_spec)
+
+ def gen():
+ for ex in hf_ds:
+ yield {in_field: ex[in_field], out_field: ex[out_field]}
+
+ return tf.data.Dataset.from_generator(
+ gen,
+ output_signature={
+ in_field: tf.TensorSpec(shape=(), dtype=tf.string),
+ out_field: tf.TensorSpec(shape=(), dtype=tf.string),
+ },
+ ).apply(tf.data.experimental.assert_cardinality(len(hf_ds)))
+
+ return tfds.load(cfg["tfds_name"], split=split_spec)
+
+
+# ---------------------------------------------------------------------------
+# Layer enumeration + selective LoRA enabling.
+# ---------------------------------------------------------------------------
+
+
+def get_lora_candidate_layers(backbone, attn_only=False):
+ """Dense / EinsumDense sublayers of `backbone` eligible for a LoRA adapter."""
+ out, seen = [], set()
+ # pylint: disable-next=protected-access
+ for layer in backbone._flatten_layers(recursive=True, include_self=False):
+ if id(layer) in seen:
+ continue
+ if not isinstance(layer, (keras.layers.Dense, keras.layers.EinsumDense)):
+ continue
+ if not (hasattr(layer, "kernel") and hasattr(layer, "enable_lora")):
+ continue
+ if attn_only and "attention" not in layer.path:
+ continue
+ seen.add(id(layer))
+ out.append(layer)
+ return out
+
+
+def enable_lora_on_paths(backbone, paths, rank):
+ """Enable LoRA on layers whose `.path` is in `paths`; freeze the rest."""
+ p2l = {
+ l.path: l for l in get_lora_candidate_layers(backbone, attn_only=False)
+ }
+ ids = {id(p2l[p]) for p in paths if p in p2l}
+ backbone.trainable = True
+ backbone._lora_rank = rank # pylint: disable=protected-access
+ # pylint: disable-next=protected-access
+ for layer in backbone._flatten_layers(include_self=False):
+ if id(layer) in ids:
+ layer.trainable = True
+ layer.enable_lora(rank=rank)
+ bias = getattr(layer, "bias", None)
+ if bias is not None:
+ bias.trainable = False
+ else:
+ layer.trainable = False
+ return ids
+
+
+# ---------------------------------------------------------------------------
+# DP probe: per-sample top-k voting over LoRA candidates.
+# ---------------------------------------------------------------------------
+
+
+@dataclasses.dataclass
+class ProbeResult:
+ selected_paths: set # paths kept after `top_k_percent` threshold
+ ranked: list # [(path, score), ...] descending
+ n_seen: int # number of probe samples actually processed
+ topk: int # k in the per-sample top-k vote
+ noise_multiplier: float # in units of sqrt(topk) sensitivity
+ used_noise: bool
+
+
+def _probe_topk_vote(
+ gemma_lm,
+ train_ds,
+ *,
+ num_probe_samples,
+ topk,
+ noise_multiplier,
+ microbatch_size,
+ use_noise,
+ seed,
+ attn_only,
+):
+ """Internal: returns (layer_scores: dict[path->score], n_seen)."""
+ trainable_vars = list(gemma_lm.trainable_variables)
+ ntvars = [v.value for v in gemma_lm.non_trainable_variables]
+ tvars = [v.value for v in trainable_vars]
+
+ candidates = get_lora_candidate_layers(gemma_lm.backbone, attn_only=attn_only)
+ kid_to_path = {id(l.kernel): l.path for l in candidates}
+ cand_indices = tuple(
+ i for i, v in enumerate(trainable_vars) if id(v) in kid_to_path
+ )
+ cand_paths = [kid_to_path[id(trainable_vars[i])] for i in cand_indices]
+ num_cand = len(cand_indices)
+ if num_cand == 0:
+ raise ValueError("No LoRA-candidate layers found.")
+ if topk > num_cand:
+ raise ValueError(f"topk={topk} > num_candidates={num_cand}.")
+
+ loss_obj = keras.losses.SparseCategoricalCrossentropy(
+ from_logits=True, reduction="sum_over_batch_size"
+ )
+
+ def _loss(tvars, x_b, y_b, sw_b):
+ y_pred, _ = gemma_lm.stateless_call(tvars, ntvars, x_b, training=False)
+ return loss_obj(y_b, y_pred.astype(jnp.float32), sample_weight=sw_b)
+
+ per_grad = jax.grad(_loss, argnums=0)
+
+ def _norms_one(tvars, x_one, y_one, sw_one):
+ # vmap strips the batch axis; functional Gemma needs (1, seq, ...) shapes.
+ x_b = jax.tree.map(lambda v: v[None, ...], x_one)
+ grads = per_grad(tvars, x_b, y_one[None, ...], sw_one[None, ...])
+ return jnp.stack(
+ [jnp.linalg.norm(grads[i].astype(jnp.float32)) for i in cand_indices]
+ )
+
+ @functools.partial(jax.jit, donate_argnums=(0,))
+ def _step(votes, tvars, x, y, sw):
+ in_axes_x = jax.tree.map(lambda _: 0, x)
+ norms = jax.vmap(_norms_one, in_axes=(None, in_axes_x, 0, 0))(
+ tvars, x, y, sw
+ )
+ _, idx = jax.lax.top_k(norms, k=topk)
+ return votes + jax.nn.one_hot(idx, num_cand, dtype=jnp.float32).sum((0, 1))
+
+ def _to_jnp(v):
+ return jnp.asarray(v.numpy() if hasattr(v, "numpy") else v)
+
+ preproc = gemma_lm.preprocessor
+ ds = (
+ train_ds.unbatch()
+ .take(num_probe_samples)
+ .batch(microbatch_size, drop_remainder=True)
+ )
+
+ l2_sens = float(topk) ** 0.5
+ print(
+ f" probe: top-{topk} vote over {num_probe_samples} samples, "
+ f"{num_cand} candidates, L2 sensitivity={l2_sens:.3f}"
+ )
+
+ votes = jnp.zeros((num_cand,), dtype=jnp.float32)
+ n_seen = 0
+ for chunk in ds:
+ x, y, sw = preproc(chunk)
+ x = {k: _to_jnp(v) for k, v in x.items()}
+ y = _to_jnp(y)
+ sw = _to_jnp(sw) if sw is not None else jnp.ones_like(y, dtype=jnp.float32)
+ votes = _step(votes, tvars, x, y, sw)
+ n_seen += microbatch_size
+ if (n_seen // microbatch_size) % 32 == 0:
+ print(f" probe: {n_seen}/{num_probe_samples}")
+
+ if use_noise:
+ noise = (noise_multiplier * l2_sens) * jax.random.normal(
+ jax.random.PRNGKey(seed), (num_cand,), dtype=jnp.float32
+ )
+ votes = votes + noise
+
+ return {cand_paths[i]: float(votes[i]) for i in range(num_cand)}, n_seen
+
+
+def run_probe(
+ gemma_lm,
+ train_ds,
+ *,
+ top_k_percent,
+ num_probe_samples=1024,
+ topk=8,
+ noise_multiplier=6.0,
+ microbatch_size=1,
+ use_noise=True,
+ seed=0,
+ attn_only=True,
+ clear_caches_after=True,
+):
+ """DP probe + selection in one call.
+
+ Per training sample, vote +1 on the topk LoRA candidates with the largest
+ per-sample gradient L2 norm. Add Gaussian noise to the vote histogram
+ (std = noise_multiplier * sqrt(topk)). Return the top `top_k_percent`%
+ layers by noisy score, plus the full ranked list and DP metadata needed
+ for downstream composition.
+
+ Frees the probe-time JIT cache before returning unless told otherwise.
+ """
+ scores, n_seen = _probe_topk_vote(
+ gemma_lm,
+ train_ds,
+ num_probe_samples=num_probe_samples,
+ topk=topk,
+ noise_multiplier=noise_multiplier,
+ microbatch_size=microbatch_size,
+ use_noise=use_noise,
+ seed=seed,
+ attn_only=attn_only,
+ )
+ ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
+ num_keep = max(1, round(len(ranked) * top_k_percent / 100.0))
+ selected = {p for p, _ in ranked[:num_keep]}
+
+ print(
+ f"\n{n_seen} samples; keeping top {num_keep}/{len(ranked)} "
+ f"({top_k_percent}%):"
+ )
+ for i, (p, s) in enumerate(ranked[:num_keep]):
+ print(f" #{i+1:3d} score={s:.3e} {p}")
+
+ if clear_caches_after:
+ gc.collect()
+ jax.clear_caches()
+
+ return ProbeResult(
+ selected_paths=selected,
+ ranked=ranked,
+ n_seen=n_seen,
+ topk=topk,
+ noise_multiplier=noise_multiplier,
+ used_noise=use_noise,
+ )
+
+
+# ---------------------------------------------------------------------------
+# DP accounting: probe (1 Poisson-Gaussian) + DP-SGD (T Poisson-Gaussian).
+# ---------------------------------------------------------------------------
+
+
+def compose_probe_and_training(
+ probe_sigma,
+ num_probe_samples,
+ train_sigma,
+ effective_batch_size,
+ train_steps,
+ train_size,
+ delta,
+):
+ """Rényi-DP composed epsilon at `delta` for the probe plus `train_steps` DP-SGD steps."""
+ acc = accountants.RdpAccountantConfig().create_accountant()
+ acc.compose(
+ dp_accounting.PoissonSampledDpEvent(
+ sampling_probability=num_probe_samples / train_size,
+ event=dp_accounting.GaussianDpEvent(probe_sigma),
+ ),
+ 1,
+ )
+ acc.compose(
+ dp_accounting.PoissonSampledDpEvent(
+ sampling_probability=effective_batch_size / train_size,
+ event=dp_accounting.GaussianDpEvent(train_sigma),
+ ),
+ train_steps,
+ )
+ return acc.get_epsilon(delta)
+
+
+def calibrate_train_sigma(target_eps, **kw):
+ """Binary-search the smallest sigma_train (within 1e-3) with composed eps <= target_eps."""
+
+ def at(s):
+ return compose_probe_and_training(train_sigma=s, **kw)
+
+ if at(1e10) >= target_eps:
+ raise ValueError(
+ "Probe alone exceeds target_eps; raise the probe noise multiplier "
+ "(probe_sigma) or reduce num_probe_samples."
+ )
+ lo, hi = 0.1, 10.0
+ while at(hi) > target_eps:
+ lo, hi = hi, hi * 2
+ while at(lo) <= target_eps and lo > 1e-3:
+ hi, lo = lo, lo * 0.5
+ for _ in range(80):
+ mid = 0.5 * (lo + hi)
+ if at(mid) > target_eps:
+ lo = mid
+ else:
+ hi = mid
+ if hi - lo < 1e-3:
+ break
+ return hi
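The bracketing-plus-bisection search in `calibrate_train_sigma` can be exercised in isolation with a stand-in for the accountant. The sketch below is a hedged illustration, not the module's code: `calibrate_sigma` and `toy_eps` are hypothetical names, and `toy_eps` (epsilon = 8 / sigma) is an invented monotone function, not a real privacy bound. It relies only on the property the real search also needs, namely that epsilon does not increase as sigma grows.

```python
def calibrate_sigma(eps_fn, target_eps, lo=0.1, hi=10.0, tol=1e-3):
    """Smallest sigma (within tol) such that eps_fn(sigma) <= target_eps.

    Assumes eps_fn is non-increasing in sigma.
    """
    # Grow hi until it satisfies the budget.
    while eps_fn(hi) > target_eps:
        lo, hi = hi, hi * 2
    # Shrink lo until it violates the budget or becomes negligibly small.
    while eps_fn(lo) <= target_eps and lo > tol:
        hi, lo = lo, lo * 0.5
    # Invariant during bisection: hi always satisfies the budget.
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if eps_fn(mid) > target_eps:
            lo = mid
        else:
            hi = mid
    return hi


def toy_eps(sigma):
    # Hypothetical stand-in for the RDP accountant; NOT a privacy guarantee.
    return 8.0 / sigma


sigma = calibrate_sigma(toy_eps, target_eps=4.0)
```

Returning `hi` rather than the midpoint keeps the result on the feasible side, so the reported sigma never spends more than the target epsilon under the supplied epsilon function.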