Incorrect implementation for residual connection #85

Description

@kensun619

The implementation for Residual() is incorrect in repvit.py.

if isinstance(self.m, Conv2d_BN):
    m = self.m.fuse()
    assert(m.groups == m.in_channels)
    identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
    identity = torch.nn.functional.pad(identity, [1,1,1,1])
    m.weight += identity.to(m.weight.device)
    return m

This looks like the code for converting a 1x1 conv to a 3x3 conv. For an identity connection, the implementation should be more like:

identity = torch.zeros_like(m.weight)
for i in range(m.weight.shape[0]):
    identity[i, i, 1, 1] = 1.0  # center of 3x3 kernel
m.weight += identity
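
One way to check either version is to compare x + conv(x) against a single conv with the identity folded into the kernel. The sketch below does that under some assumptions: identity_kernel and check are hypothetical helpers written for this comment (they are not part of repvit.py), the conv is stride 1 with "same" padding, and in_channels == out_channels. Note that PyTorch stores conv weights as (out_channels, in_channels // groups, k, k), so under the assert in the quoted code (m.groups == m.in_channels) m.weight.shape[1] is 1 and the second index has to be 0 rather than i. The helper handles that case and the dense case with the same indexing.

```python
import torch
import torch.nn.functional as F


def identity_kernel(weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: a kernel shaped like `weight` that maps x to x.

    Assumes an odd, square kernel, stride 1, "same" padding, and
    in_channels == out_channels. `weight` has shape
    (out_channels, in_channels // groups, k, k).
    """
    out_channels, in_per_group, kh, kw = weight.shape
    identity = torch.zeros_like(weight)
    for o in range(out_channels):
        # Output channel o must read input channel o, which sits at index
        # o % in_per_group inside its group (always 0 for depthwise convs).
        identity[o, o % in_per_group, kh // 2, kw // 2] = 1.0
    return identity


def check(channels: int, groups: int) -> bool:
    """Compare the residual form with the fused single-conv form."""
    torch.manual_seed(0)
    x = torch.randn(2, channels, 8, 8)
    w = torch.randn(channels, channels // groups, 3, 3)
    # Residual form: x + conv(x)
    expected = x + F.conv2d(x, w, padding=1, groups=groups)
    # Fused form: one conv whose kernel has the identity added in
    fused = F.conv2d(x, w + identity_kernel(w), padding=1, groups=groups)
    return torch.allclose(expected, fused, atol=1e-4)


if __name__ == "__main__":
    print("depthwise:", check(channels=8, groups=8))
    print("dense:    ", check(channels=8, groups=1))
    # Ones-and-pad construction from the quoted code, for a depthwise weight
    padded_ones = F.pad(torch.ones(8, 1, 1, 1), [1, 1, 1, 1])
    same = torch.equal(padded_ones, identity_kernel(torch.zeros(8, 1, 3, 3)))
    print("ones+pad equals identity_kernel (depthwise):", same)
```

Running this on the actual fused module from repvit.py, rather than on random weights, would be the more direct test. The sketch only covers the weight-shape question.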
