diff --git a/e2cnn/nn/modules/dropout/field.py b/e2cnn/nn/modules/dropout/field.py
index d890ad5e..9082c0d8 100644
--- a/e2cnn/nn/modules/dropout/field.py
+++ b/e2cnn/nn/modules/dropout/field.py
@@ -13,9 +13,34 @@ from torch.nn import Parameter
 from typing import List, Tuple, Any
 
 
+
 __all__ = ["FieldDropout"]
 
 
+def dropout_field(input: torch.Tensor, p: float, training: bool, inplace: bool):
+
+    if training:
+        shape = list(input.size())
+        shape[2] = 1
+
+        if input.device == torch.device('cpu'):
+            mask = torch.FloatTensor(*shape)
+        else:
+            device = input.device
+            mask = torch.cuda.FloatTensor(*shape, device=device)
+
+        mask = mask.uniform_() > p
+        mask = mask.to(torch.float)
+
+        if inplace:
+            input *= mask / (1. - p)
+            return input
+        else:
+            return input * mask / (1. - p)
+    else:
+        return input
+
+
 class FieldDropout(EquivariantModule):
     
     def __init__(self,
@@ -89,7 +114,7 @@ def __init__(self,
                 _indices[s] = torch.LongTensor(_indices[s])
             
             # register the indices tensors as parameters of this module
-            self.register_parameter('indices_{}'.format(s), _indices[s])
+            self.register_buffer('indices_{}'.format(s), _indices[s])
         
         self._order = list(self._contiguous.keys())
         
@@ -118,16 +143,21 @@ def forward(self, input: GeometricTensor) -> GeometricTensor:
         for s in self._order:
             
             indices = getattr(self, f"indices_{s}")
+            
+            shape = input.shape[:1] + (self._nfields[s], s) + input.shape[2:]
+            
             if self._contiguous[s]:
                 # if the fields are contiguous, we can use slicing
-                out = F.dropout(input[:, indices[0]:indices[1], ...], self.p, self.training, self.inplace)
+                out = dropout_field(input[:, indices[0]:indices[1], ...].view(shape), self.p, self.training, self.inplace)
                 if not self.inplace:
-                    output[:, indices[0]:indices[1], ...] = out
+                    shape = input.shape[:1] + (self._nfields[s] * s, ) + input.shape[2:]
+                    output[:, indices[0]:indices[1], ...] = out.view(shape)
             else:
                 # otherwise we have to use indexing
-                out = F.dropout(input[:, indices[0], ...], self.p, self.training, self.inplace)
+                out = dropout_field(input[:, indices, ...].view(shape), self.p, self.training, self.inplace)
                 if not self.inplace:
-                    output[:, indices[0], ...] = out
+                    shape = input.shape[:1] + (self._nfields[s] * s, ) + input.shape[2:]
+                    output[:, indices, ...] = out.view(shape)
         
         if self.inplace:
             output = input
diff --git a/test/nn/test_dropout.py b/test/nn/test_dropout.py
new file mode 100644
index 00000000..a8f3b44b
--- /dev/null
+++ b/test/nn/test_dropout.py
@@ -0,0 +1,126 @@
+import unittest
+from unittest import TestCase
+
+from e2cnn.nn import *
+from e2cnn.gspaces import *
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+import random
+
+
+class TestDropout(TestCase):
+    
+    def test_pointwise_do_unsorted_inplace(self):
+        N = 8
+        g = FlipRot2dOnR2(N)
+        
+        r = FieldType(g, [r for r in g.representations.values() if 'pointwise' in r.supported_nonlinearities]*3)
+        
+        do = PointwiseDropout(r, inplace=True)
+        
+        self.check_do(do)
+    
+    def test_pointwise_do_unsorted(self):
+        N = 8
+        g = FlipRot2dOnR2(N)
+        
+        r = FieldType(g, [r for r in g.representations.values() if 'pointwise' in r.supported_nonlinearities]*3)
+        
+        do = PointwiseDropout(r)
+        
+        self.check_do(do)
+    
+    def test_pointwise_do_sorted_inplace(self):
+        N = 8
+        g = FlipRot2dOnR2(N)
+        
+        r = FieldType(g, [r for r in g.representations.values() if 'pointwise' in r.supported_nonlinearities]*3).sorted()
+        
+        do = PointwiseDropout(r, inplace=True)
+        
+        self.check_do(do)
+    
+    def test_pointwise_do_sorted(self):
+        N = 8
+        g = FlipRot2dOnR2(N)
+        
+        r = FieldType(g, [r for r in g.representations.values() if 'pointwise' in r.supported_nonlinearities]*3).sorted()
+        
+        do = PointwiseDropout(r)
+        
+        self.check_do(do)
+    
+    def test_field_do_sorted(self):
+        N = 8
+        g = FlipRot2dOnR2(N)
+        
+        r = FieldType(g, list(g.representations.values())*3).sorted()
+        
+        bn = FieldDropout(r)
+        
+        self.check_do(bn)
+    
+    def test_field_do_unsorted(self):
+        N = 8
+        g = FlipRot2dOnR2(N)
+        
+        r = FieldType(g, list(g.representations.values())*3)
+        
+        bn = FieldDropout(r)
+        
+        self.check_do(bn)
+    
+    def test_field_do_sorted_inplace(self):
+        N = 8
+        g = FlipRot2dOnR2(N)
+        
+        r = FieldType(g, list(g.representations.values())*3).sorted()
+        
+        bn = FieldDropout(r, inplace=True)
+        
+        self.check_do(bn)
+    
+    def test_field_do_unsorted_inplace(self):
+        N = 8
+        g = FlipRot2dOnR2(N)
+        
+        r = FieldType(g, list(g.representations.values())*3)
+        
+        bn = FieldDropout(r, inplace=True)
+        
+        self.check_do(bn)
+    
+    def check_do(self, do: EquivariantModule):
+        
+        x = 5 * torch.randn(3000, do.in_type.size, 20, 20) + 10
+        x = torch.abs(x)
+        x1 = x
+        x2 = x.clone()
+        x1 = GeometricTensor(x1, do.in_type)
+        x2 = GeometricTensor(x2, do.in_type)
+        
+        do.train()
+        
+        y1 = do(x1)
+        
+        do.eval()
+        
+        y2 = do(x2)
+        
+        y1 = y1.tensor.permute(1, 0, 2, 3).reshape(do.in_type.size, -1)
+        y2 = y2.tensor.permute(1, 0, 2, 3).reshape(do.in_type.size, -1)
+        
+        m1 = y1.mean(1)
+        m2 = y2.mean(1)
+        
+        # print(m1)
+        # print(m2)
+        
+        self.assertTrue(torch.allclose(m1, m2, rtol=5e-3, atol=5e-3))
+
+
+if __name__ == '__main__':
+    unittest.main()
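
Note on the new dropout_field helper: the mask is sampled with a singleton field dimension (shape[2] = 1) and then broadcast over the s channels of each field, so a whole field is either kept or zeroed together at every spatial location. Dropping channels independently, as the previous F.dropout call did, treats the components of a non-trivial field asymmetrically; dropping entire fields is what the new FieldDropout tests exercise. Below is a minimal sketch of that broadcasting, not part of the patch; the tensor shapes (batch 2, 3 fields of size 4, a 5x5 grid) are made up for illustration.

import torch

# Toy input in the (batch, n_fields, s, H, W) layout produced by the
# .view(shape) calls in FieldDropout.forward above.
x = torch.randn(2, 3, 4, 5, 5)
p = 0.5

# One Bernoulli sample per field and per spatial location: the field
# dimension (dim 2) is collapsed to 1, mirroring shape[2] = 1.
mask = (torch.rand(2, 3, 1, 5, 5) > p).float()

# Broadcasting repeats the same mask value across all s channels of a
# field, so each field is dropped or rescaled as a whole.
y = x * mask / (1. - p)

# Sanity check: within every field, the channels are either all zero
# or all non-zero.
nonzero = (y != 0)
assert torch.equal(nonzero.any(dim=2), nonzero.all(dim=2))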