Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions gemma/gm/nn/gemma3n/_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,13 +227,13 @@ def __call__( # pytype: disable=signature-mismatch
*,
images: UInt8['*B N H W C'] | UInt8['*B H W C'] | None = None,
# TODO(epot): Cleanup and simplify the API.
positions: Int['*B L'] | None = None,
positions: Int['*B L_with_mm'] | None = None,
positions_offset: Int['*B'] | None = None,
cache: _config.Cache | None = None,
# During training and pre-filling, the attention mask is `*B L L`
# When sampling (after prefilling), tokens are decoded one by one,
# so the attention mask is `*B 1 cache_length`
attention_mask: Bool['*B L cache_length'] | None = None,
attention_mask: Bool['*B L_with_mm cache_length'] | None = None,
return_last_only: bool | None = None,
return_hidden_states: bool | None = None,
) -> Output: # Output['*B']
Expand Down Expand Up @@ -375,8 +375,8 @@ def _encode_and_get_inputs(
*,
tokens: Int['B L_no_mm'],
images: UInt8['B H W C'] | UInt8['B N H W C'] | None = None,
attention_mask: Bool['B L_no_mm cache_length'] | None = None,
positions: Int['B L_no_mm'] | None = None,
attention_mask: Bool['B L_with_mm cache_length'] | None = None,
positions: Int['B L_with_mm'] | None = None,
positions_offset: Int['B'] | None = None,
) -> _Inputs:
"""Encode the text tokens, eventually including the vision embeddings."""
Expand All @@ -400,6 +400,7 @@ def _encode_and_get_inputs(
# Currently, The placeholders are required so the mask, positions are
# correctly computed.
x = self.embedder.encode(inputs.tokens_with_mm)
seq_len_with_mm = inputs.tokens_with_mm.shape[1]

# Encode the vision tokens and merge them with the text embeddings.
if inputs.images is not None:
Expand All @@ -422,15 +423,27 @@ def _encode_and_get_inputs(
# it's the user responsibility to correctly take into account the extra
# tokens inserted for the images.
# This is what the `gm.text.Sampler` implementation does.
# if positions is None:
# positions = _pos_utils.build_positions_from_mask(inputs.inputs_mask)
# # For multi-turn, during the pre-fill phase, the positions should be
# # shifted to take into account the previous turns.
# if positions_offset is not None:
# positions += positions_offset[..., None]
if positions is None:
# Build correct positions aligned with MM-expanded tokens
positions = _pos_utils.build_positions_from_mask(inputs.inputs_mask)
# For multi-turn, during the pre-fill phase, the positions should be
# shifted to take into account the previous turns.
if positions_offset is not None:
positions += positions_offset[..., None]
else:
if positions.shape[1] != seq_len_with_mm:
# Expand positions using inputs_mask (safe + correct)
positions = _pos_utils.build_positions_from_mask(inputs.inputs_mask)

if attention_mask is None:
attention_mask = inputs.attention_mask
else:
if attention_mask.shape[1] != seq_len_with_mm:
attention_mask = inputs.attention_mask

return _Inputs(
embeddings=x,
Expand Down