diff --git a/xlstm/components/conv.py b/xlstm/components/conv.py index 5144fd2..b1b8374 100644 --- a/xlstm/components/conv.py +++ b/xlstm/components/conv.py @@ -29,27 +29,19 @@ def conv1d_step( ) -> tuple[torch.Tensor, torch.Tensor]: """ B: batch size - S: sequence length + S: sequence length (must equal 1 in step mode) D: feature dimension KS: kernel size - Args: - x (torch.Tensor): (B, S, D) - conv_state (torch.Tensor): (B, KS, D) - conv1d_weight (torch.Tensor): (KS, D) """ - assert ( - x.shape[0] == conv_state.shape[0] - ), f"x has batch size {x.shape[0]} but conv_state has batch size {conv_state.shape[0]}" - assert ( - x.shape[2] == conv_state.shape[2] - ), f"x has feature dimension {x.shape[2]} but conv_state has feature dimension {conv_state.shape[2]}" - assert x.shape[1] == 1, f"x has sequence length {x.shape[1]} but it should be 1" - conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=1)) - conv_state[:, -1:, :] = x - y = torch.sum(conv_state * conv1d_weight, dim=1, keepdim=True) + assert x.shape[1] == 1, "x must have a sequence length equal to 1" + new_conv_state = torch.roll(conv_state, shifts=-1, dims=1).clone() + new_conv_state[:, -1:, :] = x + + y = torch.sum(new_conv_state * conv1d_weight, dim=1, keepdim=True) if conv1d_bias is not None: - y += conv1d_bias - return y, conv_state + y = y + conv1d_bias + + return y, new_conv_state class CausalConv1d(nn.Module):