diff --git a/src/retention.py b/src/retention.py index 642d29f..e34ce07 100644 --- a/src/retention.py +++ b/src/retention.py @@ -105,7 +105,7 @@ def _get_D(self, sequence_length): m = torch.arange(sequence_length).unsqueeze(0) # Broadcast self.gamma ** (n - m) with appropriate masking to set values where n < m to 0 - D = (self.gamma ** (n - m)) * (n >= m).float() #this results in some NaN when n is much larger than m + D = (self.gamma ** (n - m)) * (n >= m).float() #this results in some NaN when m is much larger than n # fill the NaN with 0 D[D != D] = 0