diff --git a/gliner/modeling/span_rep.py b/gliner/modeling/span_rep.py index 954ac0f5..9f737169 100644 --- a/gliner/modeling/span_rep.py +++ b/gliner/modeling/span_rep.py @@ -374,14 +374,10 @@ def extract_elements(sequence, indices): Returns: torch.Tensor: Extracted elements of shape [B, K, D]. """ - D = sequence.size(-1) - - # Expand indices to [B, K, D] + B, L, D = sequence.size() + indices = torch.clamp(indices, 0, L - 1) expanded_indices = indices.unsqueeze(2).expand(-1, -1, D) - - # Gather the elements extracted_elements = torch.gather(sequence, 1, expanded_indices) - return extracted_elements