Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 67 additions & 3 deletions src/peft/peft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,28 @@
)


def _get_layer_kv_target_shape(base_config, layer_idx: int) -> tuple[int, int] | None:
"""Per-layer (num_kv_heads, head_dim) for prefix-tuning injection, or None for uniform models.

Models with heterogeneous attention (e.g. Gemma4) expose `global_head_dim` / `num_global_key_value_heads` alongside
the sliding-layer `head_dim` / `num_key_value_heads`. The provisioned prefix is sized for the global footprint;
this returns the shape each layer actually expects so the caller can slice down or skip layers that don't fit.
"""
layer_types = getattr(base_config, "layer_types", None)
Comment on lines +72 to +79

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so ig we're supporting specifically gemma4 with hardcoded attr names

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, if there is a more general approach, LMK, otherwise I'm okay with a Gemma-specific solution.

global_head_dim = getattr(base_config, "global_head_dim", None)
if not layer_types or global_head_dim is None:
return None

is_sliding = layer_types[layer_idx] == "sliding_attention"
head_dim = base_config.head_dim if is_sliding else global_head_dim
num_global_kv = getattr(base_config, "num_global_key_value_heads", None)
if not is_sliding and num_global_kv is not None:
num_kv_heads = num_global_kv
else:
num_kv_heads = base_config.num_key_value_heads
return num_kv_heads, head_dim


class PeftModel(PushToHubMixin, torch.nn.Module):
"""
Base model encompassing various Peft methods.
Expand Down Expand Up @@ -785,7 +807,7 @@ def get_prompt(
if TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get(self.config.model_type, None) is not None:
post_process_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING[self.config.model_type]
past_key_values = post_process_fn(past_key_values)
elif ("gemma2" in model_type) or ("gemma3_text" in model_type):
elif ("gemma2" in model_type) or ("gemma3_text" in model_type) or ("gemma4" in model_type):
# TODO: remove this logic once transformers < 4.56 is dropped
transformers_lt_4_56 = packaging.version.parse(transformers.__version__) < packaging.version.parse(
"4.56.0.dev0"
Expand Down Expand Up @@ -815,12 +837,54 @@ def get_prompt(
# transformers 4.56+ uses DynamicCache for gemma
new_cache = DynamicCache(config=base_config)
cache_position = torch.arange(peft_config.num_virtual_tokens, device=past_key_values[0].device)
for layer_idx in range(peft_config.num_layers):
key_states, value_states = past_key_values[0][layer_idx], past_key_values[1][layer_idx]
# Layers from `num_hidden_layers - num_kv_shared_layers` onward share KV with an earlier layer (no own
# k_proj/v_proj) and never call `cache.update`; the prefix reaches them transitively via the source
# layer.
num_kv_shared_layers = getattr(base_config, "num_kv_shared_layers", 0)
first_kv_shared_layer_idx = (
getattr(base_config, "num_hidden_layers", peft_config.num_layers) - num_kv_shared_layers
)
injected_layers: list[int] = []
skipped_layers: list[int] = []
# past_key_values is a tuple of `num_layers` per-layer tensors each shaped
# [2, batch, num_heads, num_virtual_tokens, head_dim], where dim 0 stacks K and V.
for layer_idx, layer_past_key_values in enumerate(past_key_values):
if num_kv_shared_layers > 0 and layer_idx >= first_kv_shared_layer_idx:
skipped_layers.append(layer_idx)
continue
key_states, value_states = layer_past_key_values
Comment on lines +852 to +855

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. The previous Gemma3 also used to skip layers, so we shouldn't need a prefix cache for it.

shape_or_none = _get_layer_kv_target_shape(base_config, layer_idx)
if shape_or_none is not None: # e.g. gemma 4
n_h, d = shape_or_none
# Provisioned shape: [batch, num_heads, num_virtual_tokens, head_dim]. If a layer's KV is wider
# than what we provisioned, we cannot slice up; skip rather than silently truncating to a shape
# the model won't accept.
if n_h > key_states.shape[1] or d > key_states.shape[3]:
skipped_layers.append(layer_idx)
continue
key_states = key_states[:, :n_h, :, :d]
value_states = value_states[:, :n_h, :, :d]
new_cache.update(
key_states, value_states, layer_idx, cache_kwargs={"cache_position": cache_position}
)
injected_layers.append(layer_idx)
past_key_values = new_cache

if not injected_layers:
# raise if no layer was matched; similar logic as in target_modules not matching any layer
raise ValueError(
"Prefix tuning skipped every layer because no layer's KV shape matched the provisioned prefix "
f"(num_attention_heads={peft_config.num_attention_heads}, "
f"head_dim={peft_config.token_dim // peft_config.num_attention_heads}). Override `token_dim` "
"and `num_attention_heads` in `PrefixTuningConfig` to match a layer that should receive the "
"prefix."
)
if skipped_layers:
warnings.warn(
f"Prefix tuning injected into layers {injected_layers}; skipped {skipped_layers} due to KV "
"shape mismatch or shared-KV layers."
)

elif peft_config.num_transformer_submodules == 1:
# Don't apply this to encoder-decoder models or to models requiring special processing.
# TODO: remove from_legacy_cache once transformers < 4.56 is dropped
Expand Down
12 changes: 10 additions & 2 deletions src/peft/utils/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -1161,11 +1161,19 @@ def _prepare_prompt_learning_config(peft_config, model_config):

# For grouped-query attention, see #1901.
if (peft_config.peft_type in {"PREFIX_TUNING", "CARTRIDGE"}) and ("num_key_value_heads" in model_config):
num_key_value_heads = model_config["num_key_value_heads"]
if model_config.get("head_dim", None) is not None:
# Models with heterogeneous attention (e.g. Gemma4) expose distinct shapes for global vs. sliding layers via
# `global_head_dim` / `num_global_key_value_heads`. Provision the prefix for the global-layer footprint; at injection
# time, each layer's prefix is sliced down to that layer's KV shape, and layers whose KV shape is wider than the
# provisioned prefix are skipped. Matches the default in google-deepmind/gemma#631.
if model_config.get("global_head_dim") is not None:
head_dim = model_config["global_head_dim"]
num_key_value_heads = model_config.get("num_global_key_value_heads") or model_config["num_key_value_heads"]
elif model_config.get("head_dim", None) is not None:
head_dim = model_config["head_dim"]
num_key_value_heads = model_config["num_key_value_heads"]
else:
head_dim = peft_config.token_dim // peft_config.num_attention_heads
num_key_value_heads = model_config["num_key_value_heads"]
peft_config.token_dim = head_dim * num_key_value_heads
peft_config.num_attention_heads = num_key_value_heads

Expand Down
82 changes: 82 additions & 0 deletions tests/test_decoder_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1106,3 +1106,85 @@ def test_merge_and_unload_fixes_tie_word_embeddings_config(self):
assert not merged.config.tie_word_embeddings
assert merged.lm_head.weight is not merged.model.embed_tokens.weight
assert merged.lm_head.weight.data_ptr() != merged.model.embed_tokens.weight.data_ptr()

def test_prefix_tuning_gemma4_works(self):
    # Regression test for #3201.
    # In Gemma4, the head dim differs depending on whether a layer uses sliding window attention or not:
    # https://github.com/huggingface/transformers/blob/223fe5231b783fbfb25296bb0a243dad5d158cde/src/transformers/models/gemma4/modeling_gemma4.py#L1147
    # Prefix tuning could not deal with the differing sizes, which resulted in a size error.
    model_id = "google/gemma-4-E2B"
    with hub_online_once(model_id):
        base_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16)
        base_model = base_model.to(self.torch_device)

    prefix_config = PrefixTuningConfig(
        task_type=TaskType.CAUSAL_LM,
        num_virtual_tokens=20,
        prefix_projection=False,
    )
    model = get_peft_model(base_model, prefix_config)

    # A single sequence of 10 token ids.
    inputs = torch.arange(10).view(1, -1).to(self.torch_device)
    model(inputs)  # does not raise

    # Mini training run: a handful of SGD steps must yield a finite loss that actually moves.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    loss_history = []
    for _ in range(5):
        optimizer.zero_grad()
        logits = model(inputs).logits
        # NOTE(review): assuming logits are (batch, seq, vocab), `cross_entropy` treats dim 1 (the sequence axis)
        # as the class dimension here, not the vocabulary. Fine for a smoke test, but confirm this is intended.
        target = torch.zeros_like(logits)
        target[:, :, 1] = 1
        step_loss = torch.nn.functional.cross_entropy(logits, target)
        step_loss.backward()
        optimizer.step()
        loss_history.append(step_loss)

    # Only the loss of the final step is checked for finiteness.
    assert torch.isfinite(loss_history[-1])
    assert not torch.isclose(loss_history[0], loss_history[-1], atol=1e-6, rtol=1e-3)

def test_prefix_tuning_gemma4_warns_if_some_layers_skipped(self):
    # Companion to test_prefix_tuning_gemma4_works. When the prefix is injected into some layers while others are
    # skipped, the forward pass should emit a warning that lists the skipped layers.
    model_id = "google/gemma-4-E2B"
    with hub_online_once(model_id):
        base_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16)
        base_model = base_model.to(self.torch_device)

    prefix_config = PrefixTuningConfig(
        task_type=TaskType.CAUSAL_LM,
        num_virtual_tokens=20,
        prefix_projection=False,
    )
    # The config is modified *before* the PEFT model is created.
    base_model.config.get_text_config().num_kv_shared_layers = 1  # set to lower value (was 2)
    model = get_peft_model(base_model, prefix_config)

    inputs = torch.arange(10).view(1, -1).to(self.torch_device)
    with pytest.warns(UserWarning, match=r"skipped \[.*\] due to KV shape"):
        model(inputs)

def test_prefix_tuning_gemma4_raises_if_all_layers_skipped(self):
    # Companion to test_prefix_tuning_gemma4_works. When the embedding matrix is too small to fit any layer
    # targeted by prefix tuning, an error should be raised.
    model_id = "google/gemma-4-E2B"
    with hub_online_once(model_id):
        base_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16)
        base_model = base_model.to(self.torch_device)

    prefix_config = PrefixTuningConfig(
        task_type=TaskType.CAUSAL_LM,
        num_virtual_tokens=20,
        prefix_projection=False,
    )
    model = get_peft_model(base_model, prefix_config)

    # Inflate the KV head count only *after* the PEFT model has been created.
    # NOTE(review): presumably this works because the model is already initialized at this point, so changing the
    # config no longer affects the size of the existing key/value tensors. This has not been verified along the
    # full code path.
    model.config.get_text_config().num_key_value_heads = 999

    inputs = torch.arange(10).view(1, -1).to(self.torch_device)
    with pytest.raises(ValueError, match="skipped every layer"):
        model(inputs)
2 changes: 1 addition & 1 deletion tests/test_gpu_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -5401,7 +5401,7 @@ def test_prefix_tuning_multiple_devices_decoder_model(self):
@require_torch_multi_accelerator
def test_prefix_tuning_multiple_devices_encoder_decoder_model(self):
# See issue 2134
model_id = "peft-internal-testing/tiny-random-T5Model"
model_id = "peft-internal-testing/tiny-random-t5"
tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
inputs = tokenizer(["A list of colors: red, blue"], return_tensors="pt").to(self.device)
device_map = {
Expand Down
Loading