diff --git a/gemma/gm/nn/gemma3n/_transformer.py b/gemma/gm/nn/gemma3n/_transformer.py index 44ddfa59..64dbbe14 100644 --- a/gemma/gm/nn/gemma3n/_transformer.py +++ b/gemma/gm/nn/gemma3n/_transformer.py @@ -227,13 +227,13 @@ def __call__( # pytype: disable=signature-mismatch *, images: UInt8['*B N H W C'] | UInt8['*B H W C'] | None = None, # TODO(epot): Cleanup and simplify the API. - positions: Int['*B L'] | None = None, + positions: Int['*B L_with_mm'] | None = None, positions_offset: Int['*B'] | None = None, cache: _config.Cache | None = None, # During training and pre-filling, the attention mask is `*B L L` # When sampling (after prefilling), tokens are decoded one by one, # so the attention mask is `*B 1 cache_length` - attention_mask: Bool['*B L cache_length'] | None = None, + attention_mask: Bool['*B L_with_mm cache_length'] | None = None, return_last_only: bool | None = None, return_hidden_states: bool | None = None, ) -> Output: # Output['*B'] @@ -375,8 +375,8 @@ def _encode_and_get_inputs( *, tokens: Int['B L_no_mm'], images: UInt8['B H W C'] | UInt8['B N H W C'] | None = None, - attention_mask: Bool['B L_no_mm cache_length'] | None = None, - positions: Int['B L_no_mm'] | None = None, + attention_mask: Bool['B L_with_mm cache_length'] | None = None, + positions: Int['B L_with_mm'] | None = None, positions_offset: Int['B'] | None = None, ) -> _Inputs: """Encode the text tokens, eventually including the vision embeddings.""" @@ -400,6 +400,7 @@ def _encode_and_get_inputs( # Currently, The placeholders are required so the mask, positions are # correctly computed. x = self.embedder.encode(inputs.tokens_with_mm) + seq_len_with_mm = inputs.tokens_with_mm.shape[1] # Encode the vision tokens and merge them with the text embeddings. if inputs.images is not None: @@ -422,15 +423,27 @@ def _encode_and_get_inputs( # it's the user responsibility to correctly take into account the extra # tokens inserted for the images. # This is what the `gm.text.Sampler` implementation does. + # if positions is None: + # positions = _pos_utils.build_positions_from_mask(inputs.inputs_mask) + # # For multi-turn, during the pre-fill phase, the positions should be + # # shifted to take into account the previous turns. + # if positions_offset is not None: + # positions += positions_offset[..., None] if positions is None: + # Build correct positions aligned with MM-expanded tokens positions = _pos_utils.build_positions_from_mask(inputs.inputs_mask) - # For multi-turn, during the pre-fill phase, the positions should be - # shifted to take into account the previous turns. if positions_offset is not None: positions += positions_offset[..., None] + else: + if positions.shape[1] != seq_len_with_mm: + # Expand positions using inputs_mask (safe + correct) + positions = _pos_utils.build_positions_from_mask(inputs.inputs_mask) if attention_mask is None: attention_mask = inputs.attention_mask + else: + if attention_mask.shape[1] != seq_len_with_mm: + attention_mask = inputs.attention_mask return _Inputs( embeddings=x,