diff --git a/src/openpi/models_pytorch/gemma_pytorch.py b/src/openpi/models_pytorch/gemma_pytorch.py index 203b36be8a..ecc597a9a2 100644 --- a/src/openpi/models_pytorch/gemma_pytorch.py +++ b/src/openpi/models_pytorch/gemma_pytorch.py @@ -1,6 +1,5 @@ from typing import Literal -import pytest import torch from torch import nn from transformers import GemmaForCausalLM @@ -92,7 +91,7 @@ def forward( self, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, - past_key_values: list[torch.FloatTensor] | pytest.Cache | None = None, + past_key_values: list[torch.FloatTensor] | None = None, inputs_embeds: list[torch.FloatTensor] | None = None, use_cache: bool | None = None, adarms_cond: list[torch.Tensor] | None = None,