diff --git a/docs/index.rst b/docs/index.rst index 0f92a202..41fc73d0 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -40,6 +40,7 @@ examples_guide _collections/examples/dp_sgd_flax_linen_mnist _collections/examples/dp_sgd_keras_gemma3_lora_finetuning_samsum + _collections/examples/dp_sgd_keras_gemma3_dpsapf _collections/examples/dp_sgd_keras_gemma3_synthetic_data .. toctree:: diff --git a/examples/dp_sgd_keras_gemma3_dpsapf.ipynb b/examples/dp_sgd_keras_gemma3_dpsapf.ipynb new file mode 100644 index 00000000..f43644d1 --- /dev/null +++ b/examples/dp_sgd_keras_gemma3_dpsapf.ipynb @@ -0,0 +1,454 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# https://www.apache.org/licenses/LICENSE-2.0" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "# Auto-LoRA + DP-SGD fine-tuning of Gemma3\n\n**Copyright 2026 DeepMind Technologies Limited.**\n\nThis notebook adapts **DP-SAPF** (*Saliency-Aware Parameter Fine-tuning of Public Models for Differentially Private Image Synthesis*; [arXiv:2605.30312](https://arxiv.org/abs/2605.30312)) to LLM fine-tuning. The core idea of DP-SAPF: under DP-SGD, the noise added to every trainable parameter scales with the parameter count, so adapting **every** layer wastes privacy budget on uninformative directions. Instead, first run a **DP-aware saliency probe** to identify which layers carry the most useful gradient signal for the task, then fine-tune only that small subset — getting a better signal-to-noise ratio under the same (eps, delta) budget.\n\nConcretely, this notebook extends [`dp_sgd_keras_gemma3_lora_finetuning_samsum.ipynb`](dp_sgd_keras_gemma3_lora_finetuning_samsum.ipynb) with two new pieces:\n\n1. **Probe (DP)**: for each training sample, vote +1 on the top-k LoRA-candidate layers ranked by per-sample gradient L2 norm; add Gaussian noise to the vote histogram (L2 sensitivity = $\\sqrt k$).\n2. **DP-SGD fine-tune**: enable LoRA only on the top-k% layers by score; calibrate the training noise so the **composed** (probe + DP-SGD) cost matches a single target $\\varepsilon$.\n\nSelection-specific code (the probe, top-k voting, LoRA gating, DP composition) lives in [`dpsapf_utils.py`](https://github.com/google-deepmind/jax_privacy/tree/main/examples/dpsapf_utils.py); everything else (data loading, model setup, optimizer, fit, eval) stays inline so the moving parts are visible." + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Install and configure" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%capture\n", + "!pip install -q -U \"keras>=3\" keras-nlp keras-hub rouge-score tqdm ipywidgets\n", + "!pip install -q dp_accounting jaxtyping drjax tensorflow_datasets datasets\n", + "!pip install -q beautifulsoup4 lxml jax_privacy==1.0.0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "# Redirect big artifacts to a shared scratch dir; set env vars BEFORE\n", + "# importing keras / kagglehub / datasets.\n", + "CACHE_ROOT = \"./\" # change to a writable path\n", + "os.makedirs(CACHE_ROOT, exist_ok=True)\n", + "for k, sub in [\n", + " (\"KERAS_HOME\", \"keras\"),\n", + " (\"KAGGLEHUB_CACHE\", \"kagglehub\"),\n", + " (\"TFDS_DATA_DIR\", \"tfds\"),\n", + " (\"HF_HOME\", \"huggingface\"),\n", + " (\"JAX_COMPILATION_CACHE_DIR\", \"jax_compilation_cache\"),\n", + "]:\n", + " os.environ.setdefault(k, os.path.join(CACHE_ROOT, sub))\n", + " os.makedirs(os.environ[k], exist_ok=True)\n", + "\n", + "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n", + "os.environ.setdefault(\"XLA_PYTHON_CLIENT_PREALLOCATE\", \"false\")\n", + "os.environ.setdefault(\"XLA_PYTHON_CLIENT_MEM_FRACTION\", \"0.85\")\n", + "os.environ.setdefault(\"TF_CPP_MIN_LOG_LEVEL\", \"3\")\n", + "\n", + "import gc\n", + "import keras\n", + "import keras_hub\n", + "import tensorflow as tf\n", + "import tqdm\n", + "import jax\n", + "from jax_privacy import keras_api\n", + "\n", + "# Selection-only helpers — everything else is inline.\n", + "import examples.dpsapf_utils as ALU" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Kaggle login\n", + "\n", + "Needed to download Gemma3 weights. You must **accept the Gemma3 license** in your Kaggle account for each variant ([1B](https://www.kaggle.com/models/keras/gemma3/keras/gemma3_instruct_1b), [4B](https://www.kaggle.com/models/keras/gemma3/keras/gemma3_instruct_4b_text)) or the download returns 403." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import kagglehub\n", + "\n", + "kagglehub.login()\n", + "# Or: export KAGGLE_USERNAME=... KAGGLE_KEY=... before launching the notebook." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Hyper-parameters\n", + "\n", + "Two datasets: `cnn_dailymail` (TFDS) and `xsum_hf` (HuggingFace). Both ~200-290K examples — well-sized for DP subsampling amplification." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "TEST_RUN = True\n", + "DATASET = \"cnn_dailymail\" # or \"xsum_hf\"\n", + "\n", + "# Model / training\n", + "GEMMA3_MODEL_TYPE = \"gemma3_instruct_4b_text\"\n", + "SEQUENCE_LENGTH = 1024\n", + "TEST_DS_SEQUENCE_LENGTH = 1024\n", + "EPOCHS = 3\n", + "BATCH_SIZE = 4\n", + "GRADIENT_ACCUMULATION_STEPS = 128 # effective batch = 4 * 128 = 512\n", + "TEST_DS_BATCH_SIZE = 4\n", + "LORA_RANK = 64\n", + "LEARNING_RATE = 3e-3\n", + "SEED = 0\n", + "USE_MIXED_PRECISION = False\n", + "\n", + "# Probe + selection\n", + "PROBE_SAMPLES = 10000\n", + "PROBE_TOPK = 8\n", + "PROBE_NOISE_MULTIPLIER = 40.0 # in units of L2 sensitivity sqrt(topk)\n", + "LORA_TOP_K_PERCENT = 10.0 # 1-5%% wins on CNN/DM at 1024+DP\n", + "LORA_ATTN_ONLY = True\n", + "\n", + "# Composed DP target\n", + "TOTAL_EPSILON = 4.0\n", + "DELTA = 2e-5\n", + "CLIPPING_NORM = 0.001\n", + "\n", + "if TEST_RUN:\n", + " GEMMA3_MODEL_TYPE = \"gemma3_instruct_1b\"\n", + " PROBE_SAMPLES = 10000\n", + "\n", + "DATASET_CFG = ALU.DATASET_REGISTRY[DATASET]\n", + "print(f\"Dataset: {DATASET} ({DATASET_CFG['note']})\")\n", + "print(f\"Model: {GEMMA3_MODEL_TYPE} (test_run={TEST_RUN})\")\n", + "print(f\"Effective train batch: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## Data\n\nLoad + format using helpers from `dpsapf_utils` (registry + HF/TFDS dispatch + prompt formatter), then shuffle/batch inline." + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fmt = ALU.make_source_to_gemma3_format(DATASET_CFG)\n", + "TRAIN_DS = ALU.load_dataset_split(DATASET_CFG, \"train\").map(fmt)\n", + "VALIDATION_DS = ALU.load_dataset_split(DATASET_CFG, \"validation\").map(fmt)\n", + "TEST_DS = ALU.load_dataset_split(DATASET_CFG, \"test\").map(fmt)\n", + "\n", + "TRAIN_SIZE = int(TRAIN_DS.cardinality().numpy())\n", + "print(\n", + " f\"Train: {TRAIN_SIZE} Val: {int(VALIDATION_DS.cardinality().numpy())} \"\n", + " f\"Test: {int(TEST_DS.cardinality().numpy())}\"\n", + ")\n", + "\n", + "TRAIN_DS_BATCHED = TRAIN_DS.shuffle(2048).batch(BATCH_SIZE, drop_remainder=True)\n", + "VALIDATION_DS_BATCHED = VALIDATION_DS.batch(BATCH_SIZE, drop_remainder=True)\n", + "TEST_DS_BATCHED = TEST_DS.batch(TEST_DS_BATCH_SIZE)\n", + "\n", + "EXAMPLE = (\n", + " VALIDATION_DS.take(1)\n", + " .batch(1, drop_remainder=True)\n", + " .as_numpy_iterator()\n", + " .next()\n", + ")\n", + "for k, v in EXAMPLE.items():\n", + " s = v[0].decode(\"utf-8\")\n", + " print(f'{k}: \"{s[:200]}{\"...\" if len(s) > 200 else \"\"}\"\\n')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Gemma3 (no LoRA yet) and run baseline inference" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_WEIGHTS_DTYPE = None\n", + "if USE_MIXED_PRECISION:\n", + " keras.mixed_precision.set_global_policy(\"mixed_bfloat16\")\n", + " MODEL_WEIGHTS_DTYPE = \"bfloat16\"\n", + "\n", + "keras.distribution.set_distribution(keras.distribution.DataParallel())\n", + "\n", + "gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset(\n", + " GEMMA3_MODEL_TYPE, dtype=MODEL_WEIGHTS_DTYPE\n", + ")\n", + "gemma_lm.preprocessor.sequence_length = SEQUENCE_LENGTH\n", + "\n", + "\n", + "def show_inference():\n", + " print(gemma_lm.generate(EXAMPLE[\"prompts\"])[0])\n", + " print(f\"\\nGold:\\n{EXAMPLE['responses'][0].decode('utf-8')}\")\n", + "\n", + "\n", + "print(\"=== Baseline (before fine-tuning) ===\")\n", + "show_inference()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## DP probe + layer selection\n", + "\n", + "Per sample, vote +1 on the top-k LoRA-candidate layers ranked by per-sample gradient L2 norm. Gaussian noise on the vote histogram; L2 sensitivity = $\\sqrt k$ under ADD/REMOVE neighboring. `ALU.run_probe` returns the noisy scores, the top-k% selected paths, and DP metadata used downstream for composition." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "probe = ALU.run_probe(\n", + " gemma_lm,\n", + " TRAIN_DS_BATCHED,\n", + " top_k_percent=LORA_TOP_K_PERCENT,\n", + " num_probe_samples=PROBE_SAMPLES,\n", + " topk=PROBE_TOPK,\n", + " noise_multiplier=PROBE_NOISE_MULTIPLIER,\n", + " attn_only=LORA_ATTN_ONLY,\n", + " seed=SEED,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Enable LoRA on selected layers, set up composed DP-SGD\n", + "\n", + "Both the probe and DP-SGD are Poisson-subsampled Gaussian mechanisms. We solve for $\\sigma_{\\text{train}}$ such that the Renyi-DP composed cost equals `TOTAL_EPSILON`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ALU.enable_lora_on_paths(gemma_lm.backbone, probe.selected_paths, LORA_RANK)\n", + "print(f\"LoRA enabled on {len(probe.selected_paths)} layers (rank={LORA_RANK}).\")\n", + "\n", + "STEPS_PER_EPOCH = TRAIN_SIZE // BATCH_SIZE\n", + "TOTAL_TRAIN_STEPS = EPOCHS * STEPS_PER_EPOCH\n", + "EFFECTIVE_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS\n", + "\n", + "SIGMA_TRAIN = ALU.calibrate_train_sigma(\n", + " target_eps=TOTAL_EPSILON,\n", + " probe_sigma=probe.noise_multiplier,\n", + " num_probe_samples=probe.n_seen,\n", + " effective_batch_size=EFFECTIVE_BATCH_SIZE,\n", + " train_steps=TOTAL_TRAIN_STEPS,\n", + " train_size=TRAIN_SIZE,\n", + " delta=DELTA,\n", + ")\n", + "composed_eps = ALU.compose_probe_and_training(\n", + " probe_sigma=probe.noise_multiplier,\n", + " num_probe_samples=probe.n_seen,\n", + " train_sigma=SIGMA_TRAIN,\n", + " effective_batch_size=EFFECTIVE_BATCH_SIZE,\n", + " train_steps=TOTAL_TRAIN_STEPS,\n", + " train_size=TRAIN_SIZE,\n", + " delta=DELTA,\n", + ")\n", + "print(\n", + " f\"target total_eps={TOTAL_EPSILON} -> sigma_train={SIGMA_TRAIN:.4f}; \"\n", + " f\"composed (eps, delta) ~= ({composed_eps:.4f}, {DELTA:.1e})\"\n", + ")\n", + "\n", + "dp_cfg = keras_api.DPKerasConfig(\n", + " epsilon=TOTAL_EPSILON,\n", + " delta=DELTA,\n", + " noise_multiplier=SIGMA_TRAIN,\n", + " clipping_norm=CLIPPING_NORM,\n", + " batch_size=BATCH_SIZE,\n", + " train_steps=TOTAL_TRAIN_STEPS,\n", + " train_size=TRAIN_SIZE,\n", + " gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,\n", + " seed=SEED,\n", + ")\n", + "gemma_lm = keras_api.make_private(gemma_lm, dp_cfg)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compile and fine-tune\n", + "\n", + "`fit()` may be called only once per `make_private` — the declared (eps, delta) budget rules out further training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = keras.optimizers.Adam(\n", + " learning_rate=LEARNING_RATE,\n", + " gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,\n", + ")\n", + "optimizer.exclude_from_weight_decay(var_names=[\"bias\", \"scale\"])\n", + "gemma_lm.compile(\n", + " loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", + " optimizer=optimizer,\n", + " weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],\n", + ")\n", + "\n", + "gemma_lm.fit(\n", + " x=TRAIN_DS_BATCHED,\n", + " epochs=EPOCHS,\n", + " validation_data=VALIDATION_DS_BATCHED,\n", + ")\n", + "\n", + "print(\"\\n=== After fine-tuning ===\")\n", + "show_inference()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluate ROUGE on the test split" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gemma_lm.preprocessor.sequence_length = TEST_DS_SEQUENCE_LENGTH\n", + "\n", + "METRIC_FNS = {\n", + " \"rouge_1\": keras_hub.metrics.RougeN(order=1),\n", + " \"rouge_2\": keras_hub.metrics.RougeN(order=2),\n", + " \"rouge_l\": keras_hub.metrics.RougeL(),\n", + "}\n", + "\n", + "\n", + "def _common_prefix(a, b):\n", + " i = 0\n", + " while i < len(a) and i < len(b) and a[i] == b[i]:\n", + " i += 1\n", + " return i\n", + "\n", + "\n", + "for batch in tqdm.tqdm(TEST_DS_BATCHED):\n", + " prompts = [p.decode(\"utf-8\") for p in batch[\"prompts\"].numpy()]\n", + " outputs = gemma_lm.generate(prompts)\n", + " outputs = [o[_common_prefix(p, o) :] for p, o in zip(prompts, outputs)]\n", + " targets = [s.decode(\"utf-8\") for s in batch[\"responses\"].numpy()]\n", + " for m in METRIC_FNS.values():\n", + " m.update_state(targets, outputs)\n", + "\n", + "RESULT = {k: float(m.result()[\"f1_score\"]) for k, m in METRIC_FNS.items()}\n", + "print(RESULT)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Results\n", + "\n", + "ROUGE F1 on the test split, Gemma3-1B + LoRA (rank 64) + DP-SGD at $\\varepsilon=4$, sequence length 1024. **Baseline** is `keras_hub`'s default `enable_lora()` (query + value only); **100%** enables LoRA on every attention projection (query/key/value/attention_output) in every block; **top 5%** keeps only the top-5% attention layers ranked by the DP probe.\n", + "\n", + "| Dataset | Method | ROUGE-1 | ROUGE-2 | ROUGE-L |\n", + "|---|---|---|---|---|\n", + "| CNN/DailyMail | baseline (query + value, keras_hub default) | 0.217 | 0.084 | 0.157 |\n", + "| CNN/DailyMail | 100% attn projections | 0.154 | 0.046 | 0.114 |\n", + "| CNN/DailyMail | **top 5% attn (probe)** | **0.237** | **0.096** | **0.168** |\n", + "| XSum | baseline (query + value, keras_hub default) | 0.247 | 0.064 | 0.192 |\n", + "| XSum | 100% attn projections | 0.205 | 0.046 | 0.158 |\n", + "| XSum | **top 5% attn (probe)** | **0.264** | **0.073** | **0.203** |\n", + "\n", + "Selecting the top 5% attention layers via the DP probe beats the `keras_hub` default on both datasets, and is dramatically better than enabling LoRA on every attention projection — fewer trainable parameters means less DP noise per useful direction. The effect is consistent across two different summarization styles (multi-sentence highlights vs single-sentence abstract)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Citation\n", + "\n", + "If you find this notebook useful in your research, please cite:\n", + "\n", + "```bibtex\n", + "@article{gong2026dpsapf,\n", + " title = {DP-SAPF: Saliency-Aware Parameter Fine-tuning of Public Models for Differentially Private Image Synthesis},\n", + " author = {Chen Gong and Kecen Li and Zinan Lin and Tianhao Wang},\n", + " journal = {arXiv preprint arXiv:2605.30312},\n", + " year = {2026},\n", + " url = {https://arxiv.org/abs/2605.30312}\n", + "}\n", + "```" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/examples/dpsapf_utils.py b/examples/dpsapf_utils.py new file mode 100644 index 00000000..46d58a99 --- /dev/null +++ b/examples/dpsapf_utils.py @@ -0,0 +1,395 @@ +# Copyright 2026 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Selection-only helpers for the auto-LoRA + DP-SGD notebook. + +Scope: the DP probe, layer-selection thresholding, and the probe-side DP +accounting. Data loading, model setup, training, and ROUGE eval stay in +the notebook so the moving parts are visible. + +Public API: + DATASET_REGISTRY # cnn_dailymail, xsum_hf + make_source_to_gemma3_format(cfg) # tf.data.map prompt/response fn + load_dataset_split(cfg, split_spec) # TFDS vs HuggingFace dispatcher + + get_lora_candidate_layers(backbone, ...) # enumerate LoRA-eligible layers + run_probe(gemma_lm, train_ds, top_k_percent, **kw) -> ProbeResult + enable_lora_on_paths(backbone, paths, rank) + + compose_probe_and_training(...) # RDP composition (probe + DP-SGD) + calibrate_train_sigma(target_eps, **kw) # binary search for sigma_train +""" + +import dataclasses +import functools +import gc + +import dp_accounting +import jax +import jax.numpy as jnp +import keras +import tensorflow as tf +import tensorflow_datasets as tfds + +from jax_privacy.accounting import accountants + +# HuggingFace `datasets` is only needed for the xsum_hf entry of +# DATASET_REGISTRY; keep it optional so TFDS-only users don't need it. +try: + from datasets import load_dataset # pytype: disable=import-error +except ImportError: + load_dataset = None + + +# --------------------------------------------------------------------------- +# Dataset registry + minimal data helpers (needed to materialise the probe input +# in the right prompt/response format). +# --------------------------------------------------------------------------- + +DATASET_REGISTRY = { + "cnn_dailymail": { + "loader": "tfds", + "tfds_name": "cnn_dailymail", + "input_field": "article", + "output_field": "highlights", + "prompt_prefix": "Summarize the following news article:\n", + "prompt_suffix": "\nHighlights:\n", + "note": ( + "~287K news articles -> multi-sentence highlights. " + "Long inputs. Requires `pip install beautifulsoup4 lxml`." + ), + }, + "xsum_hf": { + "loader": "hf", + "hf_name": "EdinburghNLP/xsum", + "input_field": "document", + "output_field": "summary", + "prompt_prefix": "Summarize the following article in one sentence:\n", + "prompt_suffix": "\nSummary:\n", + "note": "~204K BBC articles -> 1-sentence summaries via HuggingFace.", + }, +} + + +def make_source_to_gemma3_format(cfg): + """Returns a tf.data.map fn that emits {prompts, responses} string dicts.""" + in_field, out_field = cfg["input_field"], cfg["output_field"] + prefix, suffix = cfg["prompt_prefix"], cfg["prompt_suffix"] + + def fn(d): + return { + "prompts": tf.strings.join([prefix, d[in_field], suffix]), + "responses": d[out_field], + } + + return fn + + +def load_dataset_split(cfg, split_spec): + """Loads one split, dispatching by cfg['loader']. Returns a + tf.data.Dataset with known cardinality (so train_size is meaningful).""" + if cfg.get("loader") == "hf": + if load_dataset is None: + raise ImportError( + "The HuggingFace `datasets` package is required for the xsum_hf " + "dataset. Install it with `pip install datasets`." + ) + in_field, out_field = cfg["input_field"], cfg["output_field"] + hf_ds = load_dataset(cfg["hf_name"], split=split_spec) + + def gen(): + for ex in hf_ds: + yield {in_field: ex[in_field], out_field: ex[out_field]} + + return tf.data.Dataset.from_generator( + gen, + output_signature={ + in_field: tf.TensorSpec(shape=(), dtype=tf.string), + out_field: tf.TensorSpec(shape=(), dtype=tf.string), + }, + ).apply(tf.data.experimental.assert_cardinality(len(hf_ds))) + + return tfds.load(cfg["tfds_name"], split=split_spec) + + +# --------------------------------------------------------------------------- +# Layer enumeration + selective LoRA enabling. +# --------------------------------------------------------------------------- + + +def get_lora_candidate_layers(backbone, attn_only=False): + """Dense / EinsumDense sublayers of `backbone` eligible for a LoRA adapter.""" + out, seen = [], set() + # pylint: disable-next=protected-access + for layer in backbone._flatten_layers(recursive=True, include_self=False): + if id(layer) in seen: + continue + if not isinstance(layer, (keras.layers.Dense, keras.layers.EinsumDense)): + continue + if not (hasattr(layer, "kernel") and hasattr(layer, "enable_lora")): + continue + if attn_only and "attention" not in layer.path: + continue + seen.add(id(layer)) + out.append(layer) + return out + + +def enable_lora_on_paths(backbone, paths, rank): + """Enable LoRA on layers whose `.path` is in `paths`; freeze the rest.""" + p2l = { + l.path: l for l in get_lora_candidate_layers(backbone, attn_only=False) + } + ids = {id(p2l[p]) for p in paths if p in p2l} + backbone.trainable = True + backbone._lora_rank = rank # pylint: disable=protected-access + # pylint: disable-next=protected-access + for layer in backbone._flatten_layers(include_self=False): + if id(layer) in ids: + layer.trainable = True + layer.enable_lora(rank=rank) + bias = getattr(layer, "bias", None) + if bias is not None: + bias.trainable = False + else: + layer.trainable = False + return ids + + +# --------------------------------------------------------------------------- +# DP probe: per-sample top-k voting over LoRA candidates. +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass +class ProbeResult: + selected_paths: set # paths kept after `top_k_percent` threshold + ranked: list # [(path, score), ...] descending + n_seen: int # number of probe samples actually processed + topk: int # k in the per-sample top-k vote + noise_multiplier: float # in units of sqrt(topk) sensitivity + used_noise: bool + + +def _probe_topk_vote( + gemma_lm, + train_ds, + *, + num_probe_samples, + topk, + noise_multiplier, + microbatch_size, + use_noise, + seed, + attn_only, +): + """Internal: returns (layer_scores: dict[path->score], n_seen).""" + trainable_vars = list(gemma_lm.trainable_variables) + ntvars = [v.value for v in gemma_lm.non_trainable_variables] + tvars = [v.value for v in trainable_vars] + + candidates = get_lora_candidate_layers(gemma_lm.backbone, attn_only=attn_only) + kid_to_path = {id(l.kernel): l.path for l in candidates} + cand_indices = tuple( + i for i, v in enumerate(trainable_vars) if id(v) in kid_to_path + ) + cand_paths = [kid_to_path[id(trainable_vars[i])] for i in cand_indices] + num_cand = len(cand_indices) + if num_cand == 0: + raise ValueError("No LoRA-candidate layers found.") + if topk > num_cand: + raise ValueError(f"topk={topk} > num_candidates={num_cand}.") + + loss_obj = keras.losses.SparseCategoricalCrossentropy( + from_logits=True, reduction="sum_over_batch_size" + ) + + def _loss(tvars, x_b, y_b, sw_b): + y_pred, _ = gemma_lm.stateless_call(tvars, ntvars, x_b, training=False) + return loss_obj(y_b, y_pred.astype(jnp.float32), sample_weight=sw_b) + + per_grad = jax.grad(_loss, argnums=0) + + def _norms_one(tvars, x_one, y_one, sw_one): + # vmap strips the batch axis; functional Gemma needs (1, seq, ...) shapes. + x_b = jax.tree.map(lambda v: v[None, ...], x_one) + grads = per_grad(tvars, x_b, y_one[None, ...], sw_one[None, ...]) + return jnp.stack( + [jnp.linalg.norm(grads[i].astype(jnp.float32)) for i in cand_indices] + ) + + @functools.partial(jax.jit, donate_argnums=(0,)) + def _step(votes, tvars, x, y, sw): + in_axes_x = jax.tree.map(lambda _: 0, x) + norms = jax.vmap(_norms_one, in_axes=(None, in_axes_x, 0, 0))( + tvars, x, y, sw + ) + _, idx = jax.lax.top_k(norms, k=topk) + return votes + jax.nn.one_hot(idx, num_cand, dtype=jnp.float32).sum((0, 1)) + + def _to_jnp(v): + return jnp.asarray(v.numpy() if hasattr(v, "numpy") else v) + + preproc = gemma_lm.preprocessor + ds = ( + train_ds.unbatch() + .take(num_probe_samples) + .batch(microbatch_size, drop_remainder=True) + ) + + l2_sens = float(topk) ** 0.5 + print( + f" probe: top-{topk} vote over {num_probe_samples} samples, " + f"{num_cand} candidates, L2 sensitivity={l2_sens:.3f}" + ) + + votes = jnp.zeros((num_cand,), dtype=jnp.float32) + n_seen = 0 + for chunk in ds: + x, y, sw = preproc(chunk) + x = {k: _to_jnp(v) for k, v in x.items()} + y = _to_jnp(y) + sw = _to_jnp(sw) if sw is not None else jnp.ones_like(y, dtype=jnp.float32) + votes = _step(votes, tvars, x, y, sw) + n_seen += microbatch_size + if (n_seen // microbatch_size) % 32 == 0: + print(f" probe: {n_seen}/{num_probe_samples}") + + if use_noise: + noise = (noise_multiplier * l2_sens) * jax.random.normal( + jax.random.PRNGKey(seed), (num_cand,), dtype=jnp.float32 + ) + votes = votes + noise + + return {cand_paths[i]: float(votes[i]) for i in range(num_cand)}, n_seen + + +def run_probe( + gemma_lm, + train_ds, + *, + top_k_percent, + num_probe_samples=1024, + topk=8, + noise_multiplier=6.0, + microbatch_size=1, + use_noise=True, + seed=0, + attn_only=True, + clear_caches_after=True, +): + """DP probe + selection in one call. + + Per training sample, vote +1 on the topk LoRA candidates with the largest + per-sample gradient L2 norm. Add Gaussian noise to the vote histogram + (std = noise_multiplier * sqrt(topk)). Return the top `top_k_percent`% + layers by noisy score, plus the full ranked list and DP metadata needed + for downstream composition. + + Frees the probe-time JIT cache before returning unless told otherwise. + """ + scores, n_seen = _probe_topk_vote( + gemma_lm, + train_ds, + num_probe_samples=num_probe_samples, + topk=topk, + noise_multiplier=noise_multiplier, + microbatch_size=microbatch_size, + use_noise=use_noise, + seed=seed, + attn_only=attn_only, + ) + ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True) + num_keep = max(1, round(len(ranked) * top_k_percent / 100.0)) + selected = {p for p, _ in ranked[:num_keep]} + + print( + f"\n{n_seen} samples; keeping top {num_keep}/{len(ranked)} " + f"({top_k_percent}%):" + ) + for i, (p, s) in enumerate(ranked[:num_keep]): + print(f" #{i+1:3d} score={s:.3e} {p}") + + if clear_caches_after: + gc.collect() + jax.clear_caches() + + return ProbeResult( + selected_paths=selected, + ranked=ranked, + n_seen=n_seen, + topk=topk, + noise_multiplier=noise_multiplier, + used_noise=use_noise, + ) + + +# --------------------------------------------------------------------------- +# DP accounting: probe (1 Poisson-Gaussian) + DP-SGD (T Poisson-Gaussian). +# --------------------------------------------------------------------------- + + +def compose_probe_and_training( + probe_sigma, + num_probe_samples, + train_sigma, + effective_batch_size, + train_steps, + train_size, + delta, +): + """Renyi-DP composed epsilon at (delta) for probe + T DP-SGD steps.""" + acc = accountants.RdpAccountantConfig().create_accountant() + acc.compose( + dp_accounting.PoissonSampledDpEvent( + sampling_probability=num_probe_samples / train_size, + event=dp_accounting.GaussianDpEvent(probe_sigma), + ), + 1, + ) + acc.compose( + dp_accounting.PoissonSampledDpEvent( + sampling_probability=effective_batch_size / train_size, + event=dp_accounting.GaussianDpEvent(train_sigma), + ), + train_steps, + ) + return acc.get_epsilon(delta) + + +def calibrate_train_sigma(target_eps, **kw): + """Binary-search sigma_train so compose(probe, T x train) == target_eps.""" + + def at(s): + return compose_probe_and_training(train_sigma=s, **kw) + + if at(1e10) >= target_eps: + raise ValueError( + "Probe alone exceeds target_eps; raise noise_multiplier or " + "reduce num_probe_samples." + ) + lo, hi = 0.1, 10.0 + while at(hi) > target_eps: + lo, hi = hi, hi * 2 + while at(lo) <= target_eps and lo > 1e-3: + hi, lo = lo, lo * 0.5 + for _ in range(80): + mid = 0.5 * (lo + hi) + if at(mid) > target_eps: + lo = mid + else: + hi = mid + if hi - lo < 1e-3: + break + return hi