-
Notifications
You must be signed in to change notification settings - Fork 2.3k
FIX Error when prefix tuning Gemma 4 #3205
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -69,6 +69,28 @@ | |
| ) | ||
|
|
||
|
|
||
| def _get_layer_kv_target_shape(base_config, layer_idx: int) -> tuple[int, int] | None: | ||
| """Per-layer (num_kv_heads, head_dim) for prefix-tuning injection, or None for uniform models. | ||
|
|
||
| Models with heterogeneous attention (e.g. Gemma4) expose `global_head_dim` / `num_global_key_value_heads` alongside | ||
| the sliding-layer `head_dim` / `num_key_value_heads`. The provisioned prefix is sized for the global footprint; | ||
| this returns the shape each layer actually expects so the caller can slice down or skip layers that don't fit. | ||
| """ | ||
| layer_types = getattr(base_config, "layer_types", None) | ||
| global_head_dim = getattr(base_config, "global_head_dim", None) | ||
| if not layer_types or global_head_dim is None: | ||
| return None | ||
|
|
||
| is_sliding = layer_types[layer_idx] == "sliding_attention" | ||
| head_dim = base_config.head_dim if is_sliding else global_head_dim | ||
| num_global_kv = getattr(base_config, "num_global_key_value_heads", None) | ||
| if not is_sliding and num_global_kv is not None: | ||
| num_kv_heads = num_global_kv | ||
| else: | ||
| num_kv_heads = base_config.num_key_value_heads | ||
| return num_kv_heads, head_dim | ||
|
|
||
|
|
||
| class PeftModel(PushToHubMixin, torch.nn.Module): | ||
| """ | ||
| Base model encompassing various Peft methods. | ||
|
|
@@ -785,7 +807,7 @@ def get_prompt( | |
| if TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get(self.config.model_type, None) is not None: | ||
| post_process_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING[self.config.model_type] | ||
| past_key_values = post_process_fn(past_key_values) | ||
| elif ("gemma2" in model_type) or ("gemma3_text" in model_type): | ||
| elif ("gemma2" in model_type) or ("gemma3_text" in model_type) or ("gemma4" in model_type): | ||
| # TODO: remove this logic once transformers < 4.56 is dropped | ||
| transformers_lt_4_56 = packaging.version.parse(transformers.__version__) < packaging.version.parse( | ||
| "4.56.0.dev0" | ||
|
|
@@ -815,12 +837,54 @@ def get_prompt( | |
| # transformers 4.56+ uses DynamicCache for gemma | ||
| new_cache = DynamicCache(config=base_config) | ||
| cache_position = torch.arange(peft_config.num_virtual_tokens, device=past_key_values[0].device) | ||
| for layer_idx in range(peft_config.num_layers): | ||
| key_states, value_states = past_key_values[0][layer_idx], past_key_values[1][layer_idx] | ||
| # Layers from `num_hidden_layers - num_kv_shared_layers` onward share KV with an earlier layer (no own | ||
| # k_proj/v_proj) and never call `cache.update`; the prefix reaches them transitively via the source | ||
| # layer. | ||
| num_kv_shared_layers = getattr(base_config, "num_kv_shared_layers", 0) | ||
| first_kv_shared_layer_idx = ( | ||
| getattr(base_config, "num_hidden_layers", peft_config.num_layers) - num_kv_shared_layers | ||
| ) | ||
| injected_layers: list[int] = [] | ||
| skipped_layers: list[int] = [] | ||
| # past_key_values is a tuple of `num_layers` per-layer tensors each shaped | ||
| # [2, batch, num_heads, num_virtual_tokens, head_dim], where dim 0 stacks K and V. | ||
| for layer_idx, layer_past_key_values in enumerate(past_key_values): | ||
| if num_kv_shared_layers > 0 and layer_idx >= first_kv_shared_layer_idx: | ||
| skipped_layers.append(layer_idx) | ||
| continue | ||
| key_states, value_states = layer_past_key_values | ||
|
Comment on lines
+852
to
+855
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nice, the previous Gemma 3 handling also skipped layers, so we shouldn't need a prefix cache for it. |
||
| shape_or_none = _get_layer_kv_target_shape(base_config, layer_idx) | ||
| if shape_or_none is not None: # e.g. gemma 4 | ||
| n_h, d = shape_or_none | ||
| # Provisioned shape: [batch, num_heads, num_virtual_tokens, head_dim]. If a layer's KV is wider | ||
| # than what we provisioned, we cannot slice up; skip rather than silently truncating to a shape | ||
| # the model won't accept. | ||
| if n_h > key_states.shape[1] or d > key_states.shape[3]: | ||
| skipped_layers.append(layer_idx) | ||
| continue | ||
| key_states = key_states[:, :n_h, :, :d] | ||
| value_states = value_states[:, :n_h, :, :d] | ||
| new_cache.update( | ||
| key_states, value_states, layer_idx, cache_kwargs={"cache_position": cache_position} | ||
| ) | ||
| injected_layers.append(layer_idx) | ||
| past_key_values = new_cache | ||
|
|
||
| if not injected_layers: | ||
| # raise if no layer was matched; similar logic as in target_modules not matching any layer | ||
| raise ValueError( | ||
| "Prefix tuning skipped every layer because no layer's KV shape matched the provisioned prefix " | ||
| f"(num_attention_heads={peft_config.num_attention_heads}, " | ||
| f"head_dim={peft_config.token_dim // peft_config.num_attention_heads}). Override `token_dim` " | ||
| "and `num_attention_heads` in `PrefixTuningConfig` to match a layer that should receive the " | ||
| "prefix." | ||
| ) | ||
| if skipped_layers: | ||
| warnings.warn( | ||
| f"Prefix tuning injected into layers {injected_layers}; skipped {skipped_layers} due to KV " | ||
| "shape mismatch or shared-KV layers." | ||
| ) | ||
|
|
||
| elif peft_config.num_transformer_submodules == 1: | ||
| # Don't apply this to encoder-decoder models or to models requiring special processing. | ||
| # TODO: remove from_legacy_cache once transformers < 4.56 is dropped | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1106,3 +1106,85 @@ def test_merge_and_unload_fixes_tie_word_embeddings_config(self): | |
| assert not merged.config.tie_word_embeddings | ||
| assert merged.lm_head.weight is not merged.model.embed_tokens.weight | ||
| assert merged.lm_head.weight.data_ptr() != merged.model.embed_tokens.weight.data_ptr() | ||
|
|
||
def test_prefix_tuning_gemma4_works(self):
    """Prefix tuning on Gemma 4 survives a forward pass and a short training run.

    Regression test for #3201.
    """
    # see #3201
    # The issue was that head dim differs depending on whether sliding window attention is being used or not:
    # https://github.com/huggingface/transformers/blob/223fe5231b783fbfb25296bb0a243dad5d158cde/src/transformers/models/gemma4/modeling_gemma4.py#L1147
    # Prefix tuning could not deal with the different sizes, resulting in a size error

    model_id = "google/gemma-4-E2B"
    # NOTE(review): the original indentation is not visible in this view; only the model loading is assumed to sit
    # inside the `hub_online_once` context -- confirm against the real file.
    with hub_online_once(model_id):
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            dtype=torch.bfloat16,
        ).to(self.torch_device)
    config = PrefixTuningConfig(
        task_type=TaskType.CAUSAL_LM,
        num_virtual_tokens=20,
        prefix_projection=False,
    )
    model = get_peft_model(model, config)

    inputs = torch.arange(10).view(1, -1).to(self.torch_device)
    model(inputs)  # does not raise

    # do mini training run
    optim = torch.optim.SGD(model.parameters(), lr=0.001)
    losses = []
    for _ in range(5):
        optim.zero_grad()
        outputs = model(inputs)
        label = torch.zeros_like(outputs.logits)
        label[:, :, 1] = 1
        loss = torch.nn.functional.cross_entropy(outputs.logits, label)
        loss.backward()
        optim.step()
        # Store a detached copy so the recorded losses do not keep references to each step's autograd graph.
        losses.append(loss.detach())

    # Every step must produce a finite loss, not just the last one.
    assert all(torch.isfinite(step_loss) for step_loss in losses)
    # The loss has to move, otherwise the prefix parameters were not trained.
    assert not torch.isclose(losses[0], losses[-1], atol=1e-6, rtol=1e-3)
|
|
||
def test_prefix_tuning_gemma4_warns_if_some_layers_skipped(self):
    # Companion to test_prefix_tuning_gemma4_works. When prefix tuning reaches some layers but has to skip others,
    # the forward pass should still run and emit a UserWarning that lists the skipped layers.
    model_id = "google/gemma-4-E2B"
    # NOTE(review): the original indentation is not visible in this view; only the model loading is assumed to sit
    # inside the `hub_online_once` context -- confirm against the real file.
    with hub_online_once(model_id):
        base_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            dtype=torch.bfloat16,
        )
        model = base_model.to(self.torch_device)

    peft_config = PrefixTuningConfig(
        task_type=TaskType.CAUSAL_LM,
        num_virtual_tokens=20,
        prefix_projection=False,
    )
    # Adjust the text config before wrapping the model with PEFT.
    model.config.get_text_config().num_kv_shared_layers = 1  # set to lower value (was 2)
    model = get_peft_model(model, peft_config)

    input_ids = torch.arange(10).view(1, -1).to(self.torch_device)
    with pytest.warns(UserWarning, match=r"skipped \[.*\] due to KV shape"):
        model(input_ids)
|
|
||
| def test_prefix_tuning_gemma4_raises_if_all_layers_skipped(self): | ||
| # See previous test_prefix_tuning_gemma4_works. When the embedding matrix is too small to fit any layer targeted | ||
| # by prefix tuning, raise an error | ||
| model_id = "google/gemma-4-E2B" | ||
| with hub_online_once(model_id): | ||
| model = AutoModelForCausalLM.from_pretrained( | ||
| model_id, | ||
| dtype=torch.bfloat16, | ||
| ).to(self.torch_device) | ||
| config = PrefixTuningConfig( | ||
| task_type=TaskType.CAUSAL_LM, | ||
| num_virtual_tokens=20, | ||
| prefix_projection=False, | ||
| ) | ||
| model = get_peft_model(model, config) | ||
| text_config = model.config.get_text_config() | ||
| text_config.num_key_value_heads = 999 | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Curious about this: if the config's dimensions are changed, doesn't that mean the key/value cache will also be a larger tensor?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not 100% sure, but I think it works because the model is already initialized and the cache is already created at this point, so changing the config won't affect it. But I haven't checked the full code path. |
||
|
|
||
| inputs = torch.arange(10).view(1, -1).to(self.torch_device) | ||
| with pytest.raises(ValueError, match="skipped every layer"): | ||
| model(inputs) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so ig we're supporting specifically gemma4 with hardcoded attr names
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, if there is a more general approach, LMK, otherwise I'm okay with a Gemma-specific solution.