diff --git a/colabs/prefix_finetuning.ipynb b/colabs/prefix_finetuning.ipynb new file mode 100644 index 00000000..309bbeef --- /dev/null +++ b/colabs/prefix_finetuning.ipynb @@ -0,0 +1,423 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "DDUNCwVslByF" + }, + "source": [ + "# Prefix-Tuning\n", + "\n", + "This is an example on fine-tuning Gemma with Prefix tokens. It's best to first read the [finetuning](https://gemma-llm.readthedocs.io/en/latest/finetuning.html) colab to understand this one.\n", + "\n", + "See the [Prefix sampling](https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/prefix_sampling.ipynb) if you just want to do inference with Prefix-Tuning.\n" + ] + }, + { + "metadata": { + "id": "FgQUqR_cMZ2h" + }, + "cell_type": "code", + "source": [ + "!pip install -q gemma" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "TZQiYQy7EJe3" + }, + "outputs": [], + "source": [ + "# Common imports\n", + "import os\n", + "import optax\n", + "import treescope\n", + "\n", + "# Gemma imports\n", + "from kauldron import kd\n", + "from gemma import gm" + ] + }, + { + "metadata": { + "id": "vjp0xiR13fj6" + }, + "cell_type": "markdown", + "source": [ + "By default, Jax do not utilize the full GPU memory, but this can be overwritten. See [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html):" + ] + }, + { + "metadata": { + "id": "v4XpXYV13fj6" + }, + "cell_type": "code", + "source": [ + "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\"1.00\"" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5Yi0BzpgEMwS" + }, + "source": [ + "\n", + "## Config updates\n", + "\n", + "If you're familiar with the [finetuning](https://gemma-llm.readthedocs.io/en/latest/finetuning.html) tutorial, switching to Prefix-Tuning only require 3 changes to the trainer." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GV8_s4cTS7r8" + }, + "source": [ + "### 1. Model\n", + "\n", + "Wrap the model in the `gm.nn.PrefixTuning.from_model`. This will apply load the learnable key-values tokens into the cache." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "V1hVhUsNTDNs" + }, + "outputs": [], + "source": [ + "model = gm.nn.PrefixTuning.from_model(\n", + " prefix_length=100,\n", + " global_layers_only=True,\n", + " model=gm.nn.Gemma3_4B(text_only=True, tokens=\"batch.input\"),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bWpNYlosfA4p" + }, + "source": [ + "Internally, this uses the [`gemma.peft`](https://github.com/google-deepmind/gemma/blob/main/gemma/peft) mini-library to perform model surgery." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IVgOieKuS9Ac" + }, + "source": [ + "### 2. Checkpoint\n", + "\n", + "Wrap the init transform in a `gm.ckpts.SkipPeft`. The wrapper is required because the param structure with and without Prefix-Tuning is different.\n", + "\n", + "Only the initial pretrained weights are loaded, but the LoRA weights are kept to their random initialization." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "EIQ3yOhZe8Ep" + }, + "outputs": [], + "source": [ + "init_transform = gm.ckpts.SkipPeft(\n", + " wrapped=gm.ckpts.LoadCheckpoint(\n", + " path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NjJLqi_Xgrhh" + }, + "source": [ + "Note: If you're loading the weights directly with `gm.ckpts.load_params`, you can use the `peft.split_params` and `peft.merge_params` instead. See [Prefix sampling](third_party/py/gemma/colabs/prefix_sampling.ipynb) for an example." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "M9uec_APS_oi" + }, + "source": [ + "### 3. Optimizer\n", + "\n", + "Finally, we add a mask to the optimizer (with `kd.optim.partial_updates`), so only the Prefix weights are trained." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "g7HJYT4ae8nV" + }, + "outputs": [], + "source": [ + "optimizer = kd.optim.partial_updates(\n", + " optax.adafactor(learning_rate=1e-2),\n", + " # We only optimize the Prefix weights. The rest of the model is frozen.\n", + " mask=kd.optim.select(\"prefix_.*\"),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IY-2uskIj5z7" + }, + "source": [ + "## Training" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2hWTB015lM0Z" + }, + "source": [ + "### Data pipeline\n", + "\n", + "Like for the [finetuning](https://gemma-llm.readthedocs.io/en/latest/finetuning.html) example, we recreate the tokenizer:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bywIkAHklSlX" + }, + "outputs": [], + "source": [ + "tokenizer = gm.text.Gemma3Tokenizer()\n", + "\n", + "tokenizer.encode('This is an example sentence', add_bos=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2_J-Op0DlSNv" + }, + "source": [ + "And the data pipeline:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qf3-uXF6n2e0" + }, + "outputs": [], + "source": [ + "ds = kd.data.py.Tfds(\n", + " name='mtnt/en-fr',\n", + " split='train',\n", + " shuffle=True,\n", + " batch_size=8,\n", + " transforms=[\n", + " # Create the model inputs/targets/loss_mask.\n", + " gm.data.Seq2SeqTask(\n", + " # Select which field from the dataset to use.\n", + " # https://www.tensorflow.org/datasets/catalog/mtnt\n", + " in_prompt='src',\n", + " in_response='dst',\n", + " # Output batch is {'input': ..., 'target': ..., 'loss_mask': ...}\n", + " out_input='input',\n", + " out_target='target',\n", + " out_target_mask='loss_mask',\n", + " tokenizer=tokenizer,\n", + " # Padding parameters\n", + " max_length=200,\n", + " truncate=True,\n", + " ),\n", + " ],\n", + ")\n", + "\n", + "ex = ds[0]\n", + "\n", + "treescope.show(ex)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3ny2J07G2X7i" + }, + "source": [ + "We can decode an example from the batch to inspect the model input and check it is properly formatted:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Ep2uhBLh07cw" + }, + "outputs": [], + "source": [ + "text = tokenizer.decode(ex['input'][0])\n", + "\n", + "print(text)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "L_ND9CJDlcSy" + }, + "source": [ + "### Trainer" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "S3fXHa_4LnEH" + }, + "source": [ + "We then create the trainer, reusing the `model`, `init_transform` and `optimizer` created above:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Bv854FDSn7Z-" + }, + "outputs": [], + "source": [ + "trainer = kd.train.Trainer(\n", + " seed=42, # The seed of enlightenment\n", + " workdir='/tmp/ckpts', # TODO(epot): Make the workdir optional by default\n", + " # Dataset\n", + " train_ds=ds,\n", + " # Model\n", + " model=model,\n", + " init_transform=init_transform,\n", + " # Training parameters\n", + " num_train_steps=5000,\n", + " train_losses={\n", + " \"loss\": kd.losses.SoftmaxCrossEntropyWithIntLabels(\n", + " logits=\"preds.logits\",\n", + " labels=\"batch.target\",\n", + " mask=\"batch.loss_mask\",\n", + " ),\n", + " },\n", + " optimizer=optimizer,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xd1RcRekMkRG" + }, + "source": [ + "Trainning can be launched with the `.train()` method.\n", + "\n", + "Note that the trainer like the model are immutables, so it does not store the state nor params. Instead the state containing the trained parameters is returned." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "xvIDsFPz75GT" + }, + "outputs": [], + "source": [ + "state, aux = trainer.train()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QAbX_hwileZF" + }, + "source": [ + "## Evaluation\n", + "\n", + "Here, we only perform a qualitative evaluation by sampling a prompt.\n", + "\n", + "For more info on evals:\n", + "\n", + "* See the [sampling](https://gemma-llm.readthedocs.io/en/latest/sampling.html) tutorial for more info on running inference.\n", + "* To add evals during training, see the Kauldron [evaluator](https://kauldron.readthedocs.io/en/latest/eval.html) documentation.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "fOrkpxlAlf2V" + }, + "outputs": [], + "source": [ + "sampler = gm.text.ChatSampler(\n", + " model=model,\n", + " params=state.params,\n", + " tokenizer=tokenizer,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "x54YaAteRV94" + }, + "source": [ + "We test a sentence, using the same formatting used during fine-tuning:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yM0l9EnPMdHf" + }, + "outputs": [], + "source": [ + "sampler.chat('I\\'m feeling happy today!')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sPQUGkR3ZcO_" + }, + "source": [ + "We can confirm sampling works.\n", + "Note: The model predictions are bound to drift with with the addition of randomly initialized prefix tokens." + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "last_runtime": {}, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/colabs/prefix_sampling.ipynb b/colabs/prefix_sampling.ipynb new file mode 100644 index 00000000..0e953e44 --- /dev/null +++ b/colabs/prefix_sampling.ipynb @@ -0,0 +1,261 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "qKlB5QTDIV6S" + }, + "source": [ + "# Prefix Tuning (Sampling)\n", + "Example on using Prefix Tuning with Gemma (for inference)." + ] + }, + { + "metadata": { + "id": "TR-L25KVKT_F" + }, + "cell_type": "code", + "source": [ + "!pip install -q gemma" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "I6fEKB1tISVW" + }, + "outputs": [], + "source": [ + "# Common imports\n", + "import os\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import treescope\n", + "\n", + "# Gemma imports\n", + "from gemma import gm\n", + "from gemma import peft" + ] + }, + { + "metadata": { + "id": "cxGT2XeU4L47" + }, + "cell_type": "markdown", + "source": [ + "By default, Jax do not utilize the full GPU memory, but this can be overwritten. See [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html):" + ] + }, + { + "metadata": { + "id": "o4MidM--4L47" + }, + "cell_type": "code", + "source": [ + "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\"1.00\"" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-kdAZkvOIryQ" + }, + "source": [ + "## Initializing the model\n", + "\n", + "To use Gemma with Prefix Tuning, simply wrap any Gemma model in `gm.nn.PrefixTuning.from_model`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "x-BbrzCVIupV" + }, + "outputs": [], + "source": [ + "model = gm.nn.PrefixTuning.from_model(\n", + " prefix_length=100,\n", + " global_layers_only=True,\n", + " model=gm.nn.Gemma3_4B(text_only=True),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hI3Lg07SJff4" + }, + "source": [ + "Initialize the weights:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1shC1DpiJfsw" + }, + "outputs": [], + "source": [ + "token_ids = jnp.zeros((1, 256,), dtype=jnp.int32) # Create the (batch_size, seq_length)\n", + "\n", + "params = model.init(\n", + " jax.random.key(0),\n", + " token_ids,\n", + ")\n", + "\n", + "params = params['params']" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "T3dWILqKKzG3" + }, + "source": [ + "Inspect the params shape/structure. We can see Prefix weights have been added." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LMq2Z9nXKcad" + }, + "outputs": [], + "source": [ + "treescope.show(params)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bGJl5YpKKOf-" + }, + "source": [ + "Restore the pre-trained params. We use `peft.split_params` and `peft.merge_params` to replace the randomly initialized params with the pre-trained ones.\n", + "\n", + "When using `gm.ckpts.load_params`, make sure to pass the `params=original` kwarg. This ensure that:\n", + "\n", + "* The memory from the old params is released (so only a single copy of the weights stays in memory)\n", + "* The restored params reuse the same sharding as the input (here there's no sharding, so isn't required)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "AcO6oBuLKNjb" + }, + "outputs": [], + "source": [ + "# Splits the params into non-LoRA and LoRA weights\n", + "original, lora = peft.split_params(params)\n", + "\n", + "# Load the params from the checkpoint\n", + "original = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT, params=original)\n", + "\n", + "# Merge the pretrained params back with LoRA\n", + "params = peft.merge_params(original, lora)" + ] + }, + { + "metadata": { + "id": "b8y4YAAi9_Sv" + }, + "cell_type": "markdown", + "source": [ + "## Fine-tuning\n", + "\n", + "See our [finetuning guide](https://gemma-llm.readthedocs.io/en/latest/lora_finetuning.html) for more info.\n", + "\n", + "For a end-to-end finetuning example, refer to [prefix-Tuning](third_party/py/gemma/colabs/prefix_finetuning.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MvsQbQM4I4Cs" + }, + "source": [ + "## Inference\n", + "\n", + "Here's an example of running a single model call:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "eqU7a4eCI5Wr" + }, + "outputs": [], + "source": [ + "tokenizer = gm.text.Gemma3Tokenizer()\n", + "\n", + "prompt = tokenizer.encode('The capital of France is')\n", + "prompt = jnp.asarray([tokenizer.special_tokens.BOS] + prompt)\n", + "\n", + "\n", + "# Run the model\n", + "out = model.apply(\n", + " {'params': params},\n", + " tokens=prompt,\n", + " return_last_only=True, # Only predict the last token\n", + ")\n", + "\n", + "\n", + "# Show the token distribution\n", + "tokenizer.plot_logits(out.logits)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6dOSL9MHuMUa" + }, + "source": [ + "To sample an entire sentence:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_ckwREdyqown" + }, + "outputs": [], + "source": [ + "sampler = gm.text.ChatSampler(\n", + " model=model,\n", + " params=params,\n", + " tokenizer=tokenizer,\n", + ")\n", + "\n", + "sampler.chat('The capital of France is?')" + ] + } + ], + "metadata": { + "colab": { + "last_runtime": {}, + "private_outputs": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/gemma/gm/ckpts/__init__.py b/gemma/gm/ckpts/__init__.py index a0794b33..ce1c4758 100644 --- a/gemma/gm/ckpts/__init__.py +++ b/gemma/gm/ckpts/__init__.py @@ -22,3 +22,4 @@ from gemma.gm.ckpts._lora import SkipLoRA from gemma.gm.ckpts._paths import CheckpointPath from gemma.gm.ckpts._policy import AnchoredPolicyLoader +SkipPeft = SkipLoRA # Alias for SkipLoRA with a more generic name. diff --git a/gemma/gm/nn/__init__.py b/gemma/gm/nn/__init__.py index 59fd8a26..676c3033 100644 --- a/gemma/gm/nn/__init__.py +++ b/gemma/gm/nn/__init__.py @@ -49,6 +49,7 @@ # Wrapper (LoRA, quantization, DPO,...) # **************************************************************************** from gemma.gm.nn._lora import LoRA + from gemma.gm.nn._prefix import PrefixTuning from gemma.gm.nn._quantization import QuantizationAwareWrapper from gemma.gm.nn._quantization import IntWrapper from gemma.gm.nn._policy import AnchoredPolicy diff --git a/gemma/gm/nn/_prefix.py b/gemma/gm/nn/_prefix.py new file mode 100644 index 00000000..f3ff9016 --- /dev/null +++ b/gemma/gm/nn/_prefix.py @@ -0,0 +1,302 @@ +# Copyright 2026 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Prefix tuning module for Gemma KV cache injection.""" + +import dataclasses +from flax import linen as nn +from gemma.gm.nn import _config +from gemma.gm.nn import _modules +from gemma.gm.nn import _transformer +import jax +import jax.numpy as jnp + +# Very large negative position used to mask a token from the sliding window +# in local attention layers. By setting the position to a value far outside +# the typical range, tokens associated with this position are effectively +# excluded from the local attention window, preventing them from being attended to. +_MASKED_TOKEN_POSITION = -1000000 + + +class PrefixTuning(_transformer.Transformer): + """Wrapper around Gemma model to apply prefix tuning via KV cache injection. + + This class extends a Gemma Transformer to inject learnable prefixes into the + KV cache, allowing the model to condition on these prefixes. + + Attributes: + prefix_length: The length of the prefix to inject. + global_layers_only: Whether to apply prefixes only to global layers. + Note: If False, prefixes are applied to all layers. When using local + attention layers, ensure `prefix_length` is within the sliding window size + to allow tokens to attend to the prefix. + """ + + _: dataclasses.KW_ONLY + prefix_length: int + global_layers_only: bool = True + + @classmethod + def from_model( + cls, model: _transformer.Transformer, prefix_length: int, **kwargs + ): + """Creates a prefix-tuned model using the configuration from an existing model.""" + return cls( + config=model.config, + prefix_length=prefix_length, + dtype=model.dtype, + return_last_only=model.return_last_only, + # Tunnel through standard fields bound to the original model + tokens=model.tokens, + images=model.images, + positions=model.positions, + attention_mask=model.attention_mask, + **kwargs, + ) + + @nn.compact + def __call__( + self, + tokens: jax.Array, + *, + images: jax.Array | None = None, + positions: jax.Array | None = None, + cache: _config.Cache | None = None, + attention_mask: jax.Array | None = None, + return_last_only: bool | None = None, + return_hidden_states: bool | None = None, + ) -> _transformer.Output: + """Applies prefix tuning by injecting KV cache and adjusting the attention mask. + + This method injects learnable prefix parameters into the KV cache for + global attention layers (and optionally local layers if `global_layers_only` + is set to False). It modifies the cache to include these prefixes and + adjusts the attention mask to allow all tokens to attend to + the prefix tokens. + + Args: + tokens: Input token IDs. + images: Input images. + positions: Input positions. If not provided, they are inferred from the + input sequence length. + cache: An optional cache structure for incremental decoding. If not + provided, a new cache is initialized. + attention_mask: An optional attention mask. If not provided, a causal + mask is created, allowing attention to the prefix. + return_last_only: If true, only return the logits for the last token. + return_hidden_states: If true, return all hidden states. + + Returns: + An Output object containing logits and optionally hidden states. + """ + + config = self.config + + is_1d = tokens.ndim == 1 + if is_1d: + tokens = jnp.expand_dims(tokens, axis=0) + + batch_size = tokens.shape[0] + seq_len = tokens.shape[1] + + # 1. Define the prefix parameters + prefix_params = {} + for i, attn_type in enumerate(config.attention_types): + layer_name = f'layer_{i}' + + should_apply = ( + not self.global_layers_only + or attn_type == _modules.AttentionType.GLOBAL + ) + + if should_apply: + # Learnable prefixes + prefix_k = self.param( + f'prefix_k_{i}', + nn.initializers.xavier_uniform(), + (self.prefix_length, config.num_kv_heads, config.head_dim), + self.dtype, + ) + prefix_v = self.param( + f'prefix_v_{i}', + nn.initializers.xavier_uniform(), + (self.prefix_length, config.num_kv_heads, config.head_dim), + self.dtype, + ) + prefix_k_expanded = jnp.broadcast_to( + prefix_k, + ( + batch_size, + self.prefix_length, + config.num_kv_heads, + config.head_dim, + ), + ) + prefix_v_expanded = jnp.broadcast_to( + prefix_v, + ( + batch_size, + self.prefix_length, + config.num_kv_heads, + config.head_dim, + ), + ) + else: + # Fixed zeros + prefix_k_expanded = jnp.zeros( + ( + batch_size, + self.prefix_length, + config.num_kv_heads, + config.head_dim, + ), + dtype=self.dtype, + ) + prefix_v_expanded = jnp.zeros( + ( + batch_size, + self.prefix_length, + config.num_kv_heads, + config.head_dim, + ), + dtype=self.dtype, + ) + + prefix_params[layer_name] = { + 'k': prefix_k_expanded, + 'v': prefix_v_expanded, + } + + # 2. Prepare the cache if not provided + if cache is None: + cache_length = self.prefix_length + seq_len + cache = self.config.init_cache( + batch_size=batch_size, dtype=self.dtype, cache_length=cache_length + ) + else: + cache_length = cache['layer_0']['k'].shape[1] + if cache_length < seq_len + self.prefix_length: + new_cache_length = seq_len + self.prefix_length + new_cache = self.config.init_cache( + batch_size=batch_size, + dtype=self.dtype, + cache_length=new_cache_length, + ) + for l_name, layer_data in cache.items(): + for k_name, val in layer_data.items(): + if k_name in ('k', 'v'): + new_cache[l_name][k_name] = jax.lax.dynamic_update_slice( + new_cache[l_name][k_name], val, (0, 0, 0, 0) + ) + elif k_name == 'positions': + new_cache[l_name][k_name] = jax.lax.dynamic_update_slice( + new_cache[l_name][k_name], val, (0, 0) + ) + elif k_name == 'end_index': + new_cache[l_name][k_name] = val + + cache = new_cache + + if cache is not None: + def _inject_prefix_to_cache(c): + for i, attn_type in enumerate(config.attention_types): + layer_name = f'layer_{i}' + should_apply = ( + not self.global_layers_only + or attn_type == _modules.AttentionType.GLOBAL + ) + + # Shift end_index + c[layer_name]['end_index'] = ( + c[layer_name]['end_index'] + self.prefix_length + ) + + # Inject K and V using dynamic_update_slice + c[layer_name]['k'] = jax.lax.dynamic_update_slice( + c[layer_name]['k'], prefix_params[layer_name]['k'], (0, 0, 0, 0) + ) + c[layer_name]['v'] = jax.lax.dynamic_update_slice( + c[layer_name]['v'], prefix_params[layer_name]['v'], (0, 0, 0, 0) + ) + + # Set positions + if should_apply: + positions_val = jnp.broadcast_to( + jnp.arange(self.prefix_length), (batch_size, self.prefix_length) + ) + else: + positions_val = jnp.broadcast_to( + jnp.full( + (self.prefix_length,), + _MASKED_TOKEN_POSITION, + dtype=jnp.int32, + ), + (batch_size, self.prefix_length), + ) + + c[layer_name]['positions'] = jax.lax.dynamic_update_slice( + c[layer_name]['positions'], positions_val, (0, 0) + ) + return c + + cache = jax.lax.cond( + jnp.all(cache['layer_0']['end_index'] == 0), + _inject_prefix_to_cache, + lambda c: c, + cache, + ) + + # 3. Prepare Attention Mask + if attention_mask is None: + # Create default causal mask + prefix + prefix_mask = jnp.ones( + (batch_size, seq_len, self.prefix_length), dtype=jnp.bool_ + ) + causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_)) + causal_mask = jnp.broadcast_to( + causal_mask, (batch_size, seq_len, seq_len) + ) + attention_mask = jnp.concatenate([prefix_mask, causal_mask], axis=-1) + else: + prefix_mask = jnp.ones( + (batch_size, seq_len, self.prefix_length), dtype=jnp.bool_ + ) + attention_mask = jnp.concatenate([prefix_mask, attention_mask], axis=-1) + + if cache is not None: + cache_len = cache['layer_0']['k'].shape[1] + # Truncate attention mask to match the cache length + attention_mask = attention_mask[..., :cache_len] + + # 4. Call the base class + out = super().__call__( + tokens=tokens, + images=images, + positions=positions, + cache=cache, + attention_mask=attention_mask, + return_last_only=return_last_only, + return_hidden_states=return_hidden_states, + ) + + if is_1d: + logits = jnp.squeeze(out.logits, axis=0) + hidden_states = ( + jnp.squeeze(out.hidden_states, axis=0) + if out.hidden_states is not None + else None + ) + out = out.replace(logits=logits, hidden_states=hidden_states) + + return out diff --git a/gemma/gm/nn/_prefix_test.py b/gemma/gm/nn/_prefix_test.py new file mode 100644 index 00000000..ca6996a2 --- /dev/null +++ b/gemma/gm/nn/_prefix_test.py @@ -0,0 +1,90 @@ +# Copyright 2026 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from gemma import gm +from gemma.gm.nn._prefix import PrefixTuning +from gemma.gm.nn._transformer import ModelInfo +import jax +import jax.numpy as jnp +import numpy as np + + +class TinyGemma(gm.nn.Transformer): + """A small dummy Gemma3-like model for fast testing.""" + + config: gm.nn.config.TransformerConfig = gm.nn.config.TransformerConfig( + final_logit_softcap=None, + num_embed=128, + embed_dim=32, + hidden_dim=64, + num_heads=2, + head_dim=16, + num_kv_heads=1, + use_post_attn_norm=True, + use_post_ffw_norm=True, + use_qk_norm=True, + attention_types=tuple([ + gm.nn.AttentionType.LOCAL_SLIDING, + gm.nn.AttentionType.GLOBAL, + ]), + query_pre_attn_norm=gm.nn.config.QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM, + attn_logits_soft_cap=None, + sliding_window_size=32, + transpose_gating_einsum=True, + ) + INFO = ModelInfo(tokenizer_version=3) + + +def test_prefix_tuning_global_layers_only(): + model = TinyGemma() + prefix_model = PrefixTuning.from_model( + model, prefix_length=2, global_layers_only=True + ) + + batch_size = 1 + seq_len = 4 + tokens = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + rng = jax.random.key(0) + + # Use a single initialization to avoid redundancy + variables = prefix_model.init(rng, tokens) + params = variables['params'] + + # Verify that some prefix parameters are created + has_prefix = any(k.startswith('prefix_k_') for k in params) + assert has_prefix + + # Run forward pass capturing output cache and intermediates + out, state = prefix_model.apply( + variables, + tokens, + capture_intermediates=True, + mutable=['intermediates'], + ) + + assert out.cache is not None + + # Verify cache shape for layer 0 + layer_0_cache = out.cache['layer_0'] + assert layer_0_cache['k'].shape[1] == 6 # prefix (2) + seq (4) = 6 + + captured = state['intermediates'] + + # Find attention weights for layer 0 + attention_weights = captured['layer_0']['attn']['attention_weights'][ + '__call__' + ][0] + + prefix_weights = attention_weights[0, :, :, :2] + np.testing.assert_allclose(prefix_weights, 0.0, atol=1e-6) diff --git a/gemma/peft/_tree_utils.py b/gemma/peft/_tree_utils.py index 2546180b..48c4cf65 100644 --- a/gemma/peft/_tree_utils.py +++ b/gemma/peft/_tree_utils.py @@ -23,11 +23,11 @@ class SplittedParams(NamedTuple): original: _ParamsDict - lora: _ParamsDict + peft: _ParamsDict def split_params(params: _ParamsDict) -> SplittedParams: - """Split a nested tree into 2 trees, one with and without 'lora' branches. + """Split a nested tree into 2 trees, one with and without 'peft' branches. Example: @@ -45,7 +45,7 @@ def split_params(params: _ParamsDict) -> SplittedParams: } - original, lora = peft.split_params(params) + original, peft = peft.split_params(params) assert original == { 'dense': { @@ -54,7 +54,7 @@ def split_params(params: _ParamsDict) -> SplittedParams: }, 'other': other, } - assert lora == { + assert peft == { 'dense': { 'lora': { 'a': a, @@ -65,30 +65,33 @@ def split_params(params: _ParamsDict) -> SplittedParams: ``` Args: - params: A nested dictionary representing the input tree containing 'lora' - branches. + params: A nested dictionary representing the input tree containing `peft + branches (e.g. 'lora', 'prefix'). Returns: - A named tuple: `(original, lora)` + A named tuple: `(original, peft)` """ + node_name_prefixes = ('lora', 'prefix') original_tree = {} - lora_tree = {} + peft_tree = {} - def _split_recursive(input_subtree, original_subtree, lora_subtree): + def _split_recursive(input_subtree, original_subtree, peft_subtree): for key, value in input_subtree.items(): if isinstance(value, dict): - if key == 'lora': - lora_subtree[key] = value + if key.startswith(node_name_prefixes): + peft_subtree[key] = value else: original_subtree[key] = {} - lora_subtree[key] = {} - _split_recursive(value, original_subtree[key], lora_subtree[key]) - elif key != 'lora': + peft_subtree[key] = {} + _split_recursive(value, original_subtree[key], peft_subtree[key]) + elif key.startswith(node_name_prefixes): + peft_subtree[key] = value + else: original_subtree[key] = value - _split_recursive(params, original_tree, lora_tree) + _split_recursive(params, original_tree, peft_tree) - # Remove empty dicts in lora_tree + # Remove empty dicts in peft_tree def _remove_empty_dicts(tree): if not isinstance(tree, dict): return tree @@ -103,38 +106,38 @@ def _remove_empty_dicts(tree): new_tree[key] = value return new_tree - lora_tree = _remove_empty_dicts(lora_tree) + peft_tree = _remove_empty_dicts(peft_tree) - return SplittedParams(original_tree, lora_tree) + return SplittedParams(original_tree, peft_tree) -def merge_params(original: _ParamsDict, lora: _ParamsDict) -> _ParamsDict: +def merge_params(original: _ParamsDict, peft: _ParamsDict) -> _ParamsDict: """Inverse of `split_params`. Args: - original: The original tree without the 'lora' branches. - lora: The tree containing the 'lora' branches. + original: The original tree without the 'peft' branches. + peft: The tree containing the 'peft' branches. Returns: The merged tree. """ - def _merge_recursive(original_subtree, lora_subtree): + def _merge_recursive(original_subtree, peft_subtree): new_tree = {} for key, value in original_subtree.items(): - if isinstance(value, dict) and key in lora_subtree: - new_tree[key] = _merge_recursive(value, lora_subtree[key]) + if isinstance(value, dict) and key in peft_subtree: + new_tree[key] = _merge_recursive(value, peft_subtree[key]) else: new_tree[key] = value # Add the branches not present in the original tree - for k in sorted(set(lora_subtree) - set(original_subtree)): - new_tree[k] = lora_subtree[k] + for k in sorted(set(peft_subtree) - set(original_subtree)): + new_tree[k] = peft_subtree[k] return new_tree - return _merge_recursive(original, lora) + return _merge_recursive(original, peft) def fuse_params(): diff --git a/gemma/peft/_tree_utils_test.py b/gemma/peft/_tree_utils_test.py index bc55ddf3..a9c13373 100644 --- a/gemma/peft/_tree_utils_test.py +++ b/gemma/peft/_tree_utils_test.py @@ -32,11 +32,11 @@ def test_split_params(): }, }, 'other': 0, - # Nested branches are fully removed from the lora tree. + # Nested branches are fully removed from the peft tree. 'b': {'f': {'a': {}}}, } - original, lora = peft.split_params(params) + original, peft_params = peft.split_params(params) assert original == { 'dense': { 'kernel': 0, @@ -46,7 +46,7 @@ def test_split_params(): 'other': 0, 'b': {'f': {'a': {}}}, } - assert lora == { + assert peft_params == { 'dense': { 'lora': { 'a': 0, @@ -61,4 +61,31 @@ def test_split_params(): }, } - assert peft.merge_params(original, lora) == params + assert peft.merge_params(original, peft_params) == params + + +def test_split_params_with_prefix(): + params = { + 'dense': { + 'kernel': 0, + 'bias': 1, + }, + 'prefix_k_0': 0, + 'prefix_v_0': 1, + 'other': 0, + } + + original, prefix = peft.split_params(params) + assert original == { + 'dense': { + 'kernel': 0, + 'bias': 1, + }, + 'other': 0, + } + assert prefix == { + 'prefix_k_0': 0, + 'prefix_v_0': 1, + } + + assert peft.merge_params(original, prefix) == params