Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 9 additions & 17 deletions xlstm/components/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,27 +29,19 @@ def conv1d_step(
) -> tuple[torch.Tensor, torch.Tensor]:
"""
B: batch size
S: sequence length
S: sequence length (must equal 1 in step mode)
D: feature dimension
KS: kernel size
Args:
x (torch.Tensor): (B, S, D)
conv_state (torch.Tensor): (B, KS, D)
conv1d_weight (torch.Tensor): (KS, D)
"""
assert (
x.shape[0] == conv_state.shape[0]
), f"x has batch size {x.shape[0]} but conv_state has batch size {conv_state.shape[0]}"
assert (
x.shape[2] == conv_state.shape[2]
), f"x has feature dimension {x.shape[2]} but conv_state has feature dimension {conv_state.shape[2]}"
assert x.shape[1] == 1, f"x has sequence length {x.shape[1]} but it should be 1"
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=1))
conv_state[:, -1:, :] = x
y = torch.sum(conv_state * conv1d_weight, dim=1, keepdim=True)
assert x.shape[1] == 1, "x must have a sequence length equal to 1"
new_conv_state = torch.roll(conv_state, shifts=-1, dims=1).clone()
new_conv_state[:, -1:, :] = x

y = torch.sum(new_conv_state * conv1d_weight, dim=1, keepdim=True)
if conv1d_bias is not None:
y += conv1d_bias
return y, conv_state
y = y + conv1d_bias

return y, new_conv_state


class CausalConv1d(nn.Module):
Expand Down