Skip to content

Commit

Permalink
padding_mode in R2Conv, v0.1.1
Browse files Browse the repository at this point in the history
  • Loading branch information
Gabri95 committed Oct 8, 2020
2 parents 01a3259 + fc64997 commit 1abf950
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 11 deletions.
2 changes: 1 addition & 1 deletion e2cnn/__about__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
__title__ = "e2cnn"
__summary__ = "E(2)-Equivariant CNNs Library for PyTorch"
__url__ = 'https://github.com/QUVA-Lab/e2cnn'
__version__ = "0.1"
__version__ = "0.1.1"
__author__ = "Gabriele Cesa, Maurice Weiler"
__email__ = "[email protected]"
__license__ = "BSD 3-Clause Clear"
45 changes: 35 additions & 10 deletions e2cnn/nn/modules/r2_conv/r2convolution.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

from torch.nn.functional import conv2d
from torch.nn.functional import conv2d, pad

from e2cnn.nn import init
from e2cnn.nn import FieldType
Expand Down Expand Up @@ -31,6 +31,7 @@ def __init__(self,
padding: int = 0,
stride: int = 1,
dilation: int = 1,
padding_mode: str = 'zeros',
groups: int = 1,
bias: bool = True,
basisexpansion: str = 'blocks',
Expand Down Expand Up @@ -110,9 +111,10 @@ def __init__(self,
in_type (FieldType): the type of the input field, specifying its transformation law
out_type (FieldType): the type of the output field, specifying its transformation law
kernel_size (int): the size of the (square) filter
padding(int, optional): implicit zero paddings on both sides of the input. Default: ``0``
stride(int, optional): the stride of the kernel. Default: ``1``
dilation(int, optional): the spacing between kernel elements. Default: ``1``
padding (int, optional): implicit zero paddings on both sides of the input. Default: ``0``
padding_mode(str, optional): ``zeros``, ``reflect``, ``replicate`` or ``circular``. Default: ``zeros``
stride (int, optional): the stride of the kernel. Default: ``1``
dilation (int, optional): the spacing between kernel elements. Default: ``1``
groups (int, optional): number of blocked connections from input channels to output channels.
It allows depthwise convolution. When used, the input and output types need to be
divisible in ``groups`` groups, all equal to each other.
Expand Down Expand Up @@ -160,8 +162,21 @@ def __init__(self,
self.stride = stride
self.dilation = dilation
self.padding = padding
self.padding_mode = padding_mode
self.groups = groups

if isinstance(padding, tuple) and len(padding) == 2:
_padding = padding
elif isinstance(padding, int):
_padding = (padding, padding)
else:
raise ValueError('padding needs to be either an integer or a tuple containing two integers but {} found'.format(padding))

padding_modes = {'zeros', 'reflect', 'replicate', 'circular'}
if padding_mode not in padding_modes:
raise ValueError("padding_mode must be one of [{}], but got padding_mode='{}'".format(padding_modes, padding_mode))
self._reversed_padding_repeated_twice = tuple(x for x in reversed(_padding) for _ in range(2))

if groups > 1:
# Check the input and output classes can be split in `groups` groups, all equal to each other
# first, check that the number of fields is divisible by `groups`
Expand Down Expand Up @@ -310,13 +325,23 @@ def forward(self, input: GeometricTensor):
filter, bias = self.expand_parameters()

# use it for convolution and return the result
output = conv2d(input.tensor, filter,
padding=self.padding,
stride=self.stride,
dilation=self.dilation,
groups=self.groups,
bias=bias)

if self.padding_mode != 'zeros':
output = conv2d(input.tensor, filter,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
bias=bias)
else:
output = conv2d(pad(input.tensor, self._reversed_padding_repeated_twice, self.padding_mode),
filter,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
bias=bias)

return GeometricTensor(output, self.out_type)

def train(self, mode=True):
Expand Down
30 changes: 30 additions & 0 deletions test/nn/test_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,36 @@ def test_flip(self):
cl.eval()
cl.check_equivariance()

def test_padding_mode_reflect(self):
g = Flip2dOnR2(axis=np.pi / 2)

r1 = FieldType(g, [g.trivial_repr])
r2 = FieldType(g, [g.regular_repr])

s = 3
cl = R2Conv(r1, r2, s, bias=True, padding=1, padding_mode='reflect', initialize=False)

for _ in range(32):
init.generalized_he_init(cl.weights.data, cl.basisexpansion)
cl.eval()
cl.check_equivariance()

def test_padding_mode_circular(self):
g = FlipRot2dOnR2(4, axis=np.pi / 2)

r1 = FieldType(g, [g.trivial_repr])
r2 = FieldType(g, [g.regular_repr])

for mode in ['circular', 'reflect', 'replicate']:
for s in [3, 5, 7]:
padding = s // 2
cl = R2Conv(r1, r2, s, bias=True, padding=padding, padding_mode=mode, initialize=False)

for _ in range(10):
init.generalized_he_init(cl.weights.data, cl.basisexpansion)
cl.eval()
cl.check_equivariance()


if __name__ == '__main__':
unittest.main()

0 comments on commit 1abf950

Please sign in to comment.