From 9c92d05b169acee97885451d6b494b61e9be1b42 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 30 Apr 2026 15:02:24 +0200 Subject: [PATCH 1/2] FIX Error when prefix tuning Gemma 4 There was an issue with applying prefix tuning to Gemma 4 because the model uses different head dimensions for layers that use sliding window attention. As prefix tuning only initializes a single projection matrix that is used for all layers, this would lead to a shape mismatch. The solution is to "overprovision" the matrix and then slice the prefix down to size of the layer is smaller. This is not quite as parameter efficient as it could be, but the overhead shouldn't be too large. For robustness, we also skip layers if the matrix is underprovisioned, but we warn about it and raise an error if all layers are skipped. Alternatively, we could implement one project per layer, each with the right size, like in https://github.com/google-deepmind/gemma/pull/631/. However, this would be a big refactor and also very hard to make backwards compatible with existing checkpoints, so going with the less efficient solution is preferable. This PR also contains an independent, single line fix to a prefix tuning test that was referencing a non-existing model. --- src/peft/peft_model.py | 70 ++++++++++++++++++++++++++++-- src/peft/utils/other.py | 12 +++++- tests/test_decoder_models.py | 82 ++++++++++++++++++++++++++++++++++++ tests/test_gpu_examples.py | 2 +- 4 files changed, 160 insertions(+), 6 deletions(-) diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index 6f167c365b..86b8aa64b3 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -69,6 +69,28 @@ ) +def _get_layer_kv_target_shape(base_config, layer_idx: int) -> tuple[int, int] | None: + """Per-layer (num_kv_heads, head_dim) for prefix-tuning injection, or None for uniform models. + + Models with heterogeneous attention (e.g. Gemma4) expose `global_head_dim` / `num_global_key_value_heads` alongside + the sliding-layer `head_dim` / `num_key_value_heads`. The provisioned prefix is sized for the global footprint; + this returns the shape each layer actually expects so the caller can slice down or skip layers that don't fit. + """ + layer_types = getattr(base_config, "layer_types", None) + global_head_dim = getattr(base_config, "global_head_dim", None) + if not layer_types or global_head_dim is None: + return None + + is_sliding = layer_types[layer_idx] == "sliding_attention" + head_dim = base_config.head_dim if is_sliding else global_head_dim + num_global_kv = getattr(base_config, "num_global_key_value_heads", None) + if not is_sliding and num_global_kv is not None: + num_kv_heads = num_global_kv + else: + num_kv_heads = base_config.num_key_value_heads + return num_kv_heads, head_dim + + class PeftModel(PushToHubMixin, torch.nn.Module): """ Base model encompassing various Peft methods. @@ -785,7 +807,7 @@ def get_prompt( if TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get(self.config.model_type, None) is not None: post_process_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING[self.config.model_type] past_key_values = post_process_fn(past_key_values) - elif ("gemma2" in model_type) or ("gemma3_text" in model_type): + elif ("gemma2" in model_type) or ("gemma3_text" in model_type) or ("gemma4" in model_type): # TODO: remove this logic once transformers < 4.56 is dropped transformers_lt_4_56 = packaging.version.parse(transformers.__version__) < packaging.version.parse( "4.56.0.dev0" @@ -815,12 +837,54 @@ def get_prompt( # transformers 4.56+ uses DynamicCache for gemma new_cache = DynamicCache(config=base_config) cache_position = torch.arange(peft_config.num_virtual_tokens, device=past_key_values[0].device) - for layer_idx in range(peft_config.num_layers): - key_states, value_states = past_key_values[0][layer_idx], past_key_values[1][layer_idx] + # Layers from `num_hidden_layers - num_kv_shared_layers` onward share KV with an earlier layer (no own + # k_proj/v_proj) and never call `cache.update`; the prefix reaches them transitively via the source + # layer. + num_kv_shared_layers = getattr(base_config, "num_kv_shared_layers", 0) + first_kv_shared_layer_idx = ( + getattr(base_config, "num_hidden_layers", peft_config.num_layers) - num_kv_shared_layers + ) + injected_layers: list[int] = [] + skipped_layers: list[int] = [] + # past_key_values is a tuple of `num_layers` per-layer tensors each shaped + # [2, batch, num_heads, num_virtual_tokens, head_dim], where dim 0 stacks K and V. + for layer_idx, layer_past_key_values in enumerate(past_key_values): + if num_kv_shared_layers > 0 and layer_idx >= first_kv_shared_layer_idx: + skipped_layers.append(layer_idx) + continue + key_states, value_states = layer_past_key_values + shape_or_none = _get_layer_kv_target_shape(base_config, layer_idx) + if shape_or_none is not None: # e.g. gemma 4 + n_h, d = shape_or_none + # Provisioned shape: [batch, num_heads, num_virtual_tokens, head_dim]. If a layer's KV is wider + # than what we provisioned, we cannot slice up; skip rather than silently truncating to a shape + # the model won't accept. + if n_h > key_states.shape[1] or d > key_states.shape[3]: + skipped_layers.append(layer_idx) + continue + key_states = key_states[:, :n_h, :, :d] + value_states = value_states[:, :n_h, :, :d] new_cache.update( key_states, value_states, layer_idx, cache_kwargs={"cache_position": cache_position} ) + injected_layers.append(layer_idx) past_key_values = new_cache + + if not injected_layers: + # raise if no layer was matched; similar logic as in target_modules not matching any layer + raise ValueError( + "Prefix tuning skipped every layer because no layer's KV shape matched the provisioned prefix " + f"(num_attention_heads={peft_config.num_attention_heads}, " + f"head_dim={peft_config.token_dim // peft_config.num_attention_heads}). Override `token_dim` " + "and `num_attention_heads` in `PrefixTuningConfig` to match a layer that should receive the " + "prefix." + ) + if skipped_layers: + warnings.warn( + f"Prefix tuning injected into layers {injected_layers}; skipped {skipped_layers} due to KV " + "shape mismatch or shared-KV layers." + ) + elif peft_config.num_transformer_submodules == 1: # Dont' apply this to encoder-decoder models and not to models requiring special processing. # TODO: remove from_legacy_cache once transformers < 4.56 is dropped diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py index e4723297b0..074dff2829 100644 --- a/src/peft/utils/other.py +++ b/src/peft/utils/other.py @@ -1161,11 +1161,19 @@ def _prepare_prompt_learning_config(peft_config, model_config): # For grouped-query attention, see #1901. if (peft_config.peft_type in {"PREFIX_TUNING", "CARTRIDGE"}) and ("num_key_value_heads" in model_config): - num_key_value_heads = model_config["num_key_value_heads"] - if model_config.get("head_dim", None) is not None: + # Models with heterogeneous attention (e.g. Gemma4) expose distinct shapes for global vs. sliding layers via + # `global_head_dim` / `num_global_key_value_heads`. Provision the prefix for the global-layer footprint; sliding + # layers whose KV shape doesn't match are skipped per-layer at injection time. Matches the default in + # google-deepmind/gemma#631. + if model_config.get("global_head_dim") is not None: + head_dim = model_config["global_head_dim"] + num_key_value_heads = model_config.get("num_global_key_value_heads") or model_config["num_key_value_heads"] + elif model_config.get("head_dim", None) is not None: head_dim = model_config["head_dim"] + num_key_value_heads = model_config["num_key_value_heads"] else: head_dim = peft_config.token_dim // peft_config.num_attention_heads + num_key_value_heads = model_config["num_key_value_heads"] peft_config.token_dim = head_dim * num_key_value_heads peft_config.num_attention_heads = num_key_value_heads diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py index ba1569ce15..6b37d8ae0a 100644 --- a/tests/test_decoder_models.py +++ b/tests/test_decoder_models.py @@ -1106,3 +1106,85 @@ def test_merge_and_unload_fixes_tie_word_embeddings_config(self): assert not merged.config.tie_word_embeddings assert merged.lm_head.weight is not merged.model.embed_tokens.weight assert merged.lm_head.weight.data_ptr() != merged.model.embed_tokens.weight.data_ptr() + + def test_prefix_tuning_gemma4_works(self): + # see #3201 + # The issue was that head dim differs depending on whether sliding window attention is being used or not: + # https://github.com/huggingface/transformers/blob/223fe5231b783fbfb25296bb0a243dad5d158cde/src/transformers/models/gemma4/modeling_gemma4.py#L1147 + # Prefix tuning could deal with different sizes, resulting in a size error + + model_id = "google/gemma-4-E2B" + with hub_online_once(model_id): + model = AutoModelForCausalLM.from_pretrained( + model_id, + dtype=torch.bfloat16, + ).to(self.torch_device) + config = PrefixTuningConfig( + task_type=TaskType.CAUSAL_LM, + num_virtual_tokens=20, + prefix_projection=False, + ) + model = get_peft_model(model, config) + + inputs = torch.arange(10).view(1, -1).to(self.torch_device) + model(inputs) # does not raise + + # do mini training run + optim = torch.optim.SGD(model.parameters(), lr=0.001) + losses = [] + for _ in range(5): + optim.zero_grad() + outputs = model(inputs) + label = torch.zeros_like(outputs.logits) + label[:, :, 1] = 1 + loss = torch.nn.functional.cross_entropy(outputs.logits, label) + loss.backward() + optim.step() + losses.append(loss) + + assert torch.isfinite(loss) + assert not torch.isclose(losses[0], losses[-1], atol=1e-5, rtol=1e-2) + + def test_prefix_tuning_gemma4_warns_if_some_layers_skipped(self): + # See previous test_prefix_tuning_gemma4_works. When the embedding matrix is too small to fit any layer targeted + # by prefix tuning, raise an error + model_id = "google/gemma-4-E2B" + with hub_online_once(model_id): + model = AutoModelForCausalLM.from_pretrained( + model_id, + dtype=torch.bfloat16, + ).to(self.torch_device) + config = PrefixTuningConfig( + task_type=TaskType.CAUSAL_LM, + num_virtual_tokens=20, + prefix_projection=False, + ) + text_config = model.config.get_text_config() + text_config.num_kv_shared_layers = 1 # set to lower value (was 2) + model = get_peft_model(model, config) + + inputs = torch.arange(10).view(1, -1).to(self.torch_device) + with pytest.warns(UserWarning, match=r"skipped \[.*\] due to KV shape"): + model(inputs) + + def test_prefix_tuning_gemma4_raises_if_all_layers_skipped(self): + # See previous test_prefix_tuning_gemma4_works. When the embedding matrix is too small to fit any layer targeted + # by prefix tuning, raise an error + model_id = "google/gemma-4-E2B" + with hub_online_once(model_id): + model = AutoModelForCausalLM.from_pretrained( + model_id, + dtype=torch.bfloat16, + ).to(self.torch_device) + config = PrefixTuningConfig( + task_type=TaskType.CAUSAL_LM, + num_virtual_tokens=20, + prefix_projection=False, + ) + model = get_peft_model(model, config) + text_config = model.config.get_text_config() + text_config.num_key_value_heads = 999 + + inputs = torch.arange(10).view(1, -1).to(self.torch_device) + with pytest.raises(ValueError, match="skipped every layer"): + model(inputs) diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index 3b1ea37978..2e2e943d61 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -5401,7 +5401,7 @@ def test_prefix_tuning_multiple_devices_decoder_model(self): @require_torch_multi_accelerator def test_prefix_tuning_multiple_devices_encoder_decoder_model(self): # See issue 2134 - model_id = "peft-internal-testing/tiny-random-T5Model" + model_id = "peft-internal-testing/tiny-random-t5" tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left") inputs = tokenizer(["A list of colors: red, blue"], return_tensors="pt").to(self.device) device_map = { From c020c11c397cdf2d66a34dccceab4246517a28c1 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 30 Apr 2026 15:50:43 +0200 Subject: [PATCH 2/2] Change tolerances for Windows (!!!) Tests passed on Linux but not on Windows. Trying to guess tolerances that could work. --- tests/test_decoder_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py index 6b37d8ae0a..ca02b0b252 100644 --- a/tests/test_decoder_models.py +++ b/tests/test_decoder_models.py @@ -1143,7 +1143,7 @@ def test_prefix_tuning_gemma4_works(self): losses.append(loss) assert torch.isfinite(loss) - assert not torch.isclose(losses[0], losses[-1], atol=1e-5, rtol=1e-2) + assert not torch.isclose(losses[0], losses[-1], atol=1e-6, rtol=1e-3) def test_prefix_tuning_gemma4_warns_if_some_layers_skipped(self): # See previous test_prefix_tuning_gemma4_works. When the embedding matrix is too small to fit any layer targeted