-- MaskZeroCriterion.lua
-- Forked from Element-Research/rnn.
------------------------------------------------------------------------
--[[ MaskZeroCriterion ]]--
-- Decorator that zeros err and gradInputs of the encapsulated criterion
-- for commensurate input rows which are tensors of zeros
------------------------------------------------------------------------
-- Register the class "nn.MaskZeroCriterion" as a subclass of "nn.Criterion".
-- `parent` is the base class; it is used below to chain to the inherited
-- constructor and type() implementation.
local MaskZeroCriterion, parent = torch.class("nn.MaskZeroCriterion", "nn.Criterion")
-- Constructor.
-- criterion : the nn.Criterion instance to decorate.
-- nInputDim : number of dimensions of a single (non-batch) input sample;
--             updateOutput expects inputs with exactly one extra (batch) dim.
function MaskZeroCriterion:__init(criterion, nInputDim)
   parent.__init(self)
   self.criterion = criterion

   -- validate both arguments before storing the dimensionality
   local isCriterion = torch.isTypeOf(criterion, 'nn.Criterion')
   assert(isCriterion)
   local isNumber = (torch.type(nInputDim) == 'number')
   assert(isNumber, 'Expecting nInputDim number at arg 2')

   self.nInputDim = nInputDim
end
-- Descends through nested tables, always following element [1], and returns
-- the first tensor that is reached. Asserts that what is found is a tensor.
function MaskZeroCriterion:recursiveGetFirst(input)
   local first = input
   while torch.type(first) == 'table' do
      first = first[1]
   end
   assert(torch.isTensor(first))
   return first
end
-- Selects the rows listed in `mask` (indices along dimension 1) from every
-- tensor found in `src`, mirroring the nesting of `src` in the result.
-- dst  : previous result to reuse as buffer (tensor, table, or nil).
-- src  : tensor or (nested) array-like table of tensors.
-- mask : index tensor handed to Tensor:index.
-- Returns the (possibly newly created) destination.
function MaskZeroCriterion:recursiveMask(dst, src, mask)
   -- leaf case: a single tensor
   if torch.type(src) ~= 'table' then
      assert(torch.isTensor(src))
      local out = dst
      if not torch.isTensor(out) then
         out = src.new()
      end
      out:index(src, 1, mask)
      return out
   end

   -- table case: recurse into each sequential element
   local out = dst
   if torch.type(out) ~= 'table' then
      out = {}
   end
   for idx, item in ipairs(src) do
      out[idx] = self:recursiveMask(out[idx], item, mask)
   end
   return out
end
-- Forwards only the non-zero rows of the batch through the decorated criterion.
--
-- The first tensor found in `input` is flattened to (batchSize, -1) and the
-- L2 norm of every row is computed. Rows whose norm is zero are dropped;
-- the indices of the remaining rows are stored in self.zeroMask (despite its
-- name it lists the rows that are KEPT). `input` and `target` are then indexed
-- along dimension 1 with those indices and cached in self.input / self.target.
--
-- input  : tensor or (nested) table of tensors, batch mode only.
-- target : tensor or (nested) table of tensors, masked the same way as input.
-- Returns the decorated criterion's output, or 0 when every row is masked.
-- Raises an error for non-batch input or when the number of dimensions does
-- not match nInputDim + 1.
function MaskZeroCriterion:updateOutput(input, target)
   -- recurrent module input is always the first one
   local rmi = self:recursiveGetFirst(input):contiguous()
   if rmi:dim() == self.nInputDim then
      error("does not support online (i.e. non-batch) mode")
   elseif rmi:dim() - 1 == self.nInputDim then
      rmi = rmi:view(rmi:size(1), -1) -- collapse non-batch dims
   else
      error("nInputDim error: "..rmi:dim()..", "..self.nInputDim)
   end

   -- build mask: per-row L2 norm over the collapsed (last) dimension
   local vectorDim = rmi:dim()
   self._zeroMask = self._zeroMask or rmi.new()
   self._zeroMask:norm(rmi, 2, vectorDim)
   local zeroMask = self._zeroMask
   if torch.isTypeOf(zeroMask, 'torch.CudaTensor') or
      torch.isTypeOf(zeroMask, 'torch.ClTensor') then
      -- copy the norms into a host FloatTensor and iterate over THAT copy;
      -- previously the copy was made but the device tensor was still used
      self.__zeroMask = self.__zeroMask or torch.FloatTensor()
      self.__zeroMask:resize(self._zeroMask:size()):copy(self._zeroMask)
      zeroMask = self.__zeroMask
   end

   -- collect the indices of rows with a non-zero norm
   self.zeroMask = self.zeroMask or torch.LongTensor()
   self.zeroMask:resize(self._zeroMask:size(1)):zero()

   local i, j = 0, 0
   zeroMask:apply(function(norm)
      i = i + 1
      if norm ~= 0 then
         j = j + 1
         self.zeroMask[j] = i
      end
   end)
   -- shrink to the number of kept rows (may be 0)
   self.zeroMask:resize(j)

   if j > 0 then
      self.input = self:recursiveMask(self.input, input, self.zeroMask)
      self.target = self:recursiveMask(self.target, target, self.zeroMask)

      -- forward through decorated criterion
      self.output = self.criterion:updateOutput(self.input, self.target)
   else
      -- when all samples are masked, then loss is zero (issue 128)
      self.output = 0
   end

   return self.output
end
-- Scatters the gradients computed on the kept rows back into zero-filled
-- tensors shaped like the original `input`, mirroring its nesting.
-- dst   : previous result to reuse as buffer (tensor, table, or nil).
-- mask  : index tensor of the kept rows (indices along dimension 1).
-- src   : gradients of the decorated criterion for the kept rows.
-- input : the original (unmasked) input; defines shape and nesting.
-- Returns the destination; errors when `input` is neither table nor tensor.
function MaskZeroCriterion:recursiveMaskGradInput(dst, mask, src, input)
   if torch.type(input) == 'table' then
      -- wrap non-table dst/src so they can be indexed like `input`
      local dstTable = dst
      if torch.type(dstTable) ~= 'table' then
         dstTable = {dst}
      end
      local srcTable = src
      if torch.type(srcTable) ~= 'table' then
         srcTable = {src}
      end

      for key in pairs(input) do
         dstTable[key] = self:recursiveMaskGradInput(dstTable[key], mask, srcTable[key], input[key])
      end

      -- drop trailing entries left over from a longer previous input
      for extra = #input + 1, #dstTable do
         dstTable[extra] = nil
      end
      return dstTable
   end

   if not torch.isTensor(input) then
      error("expecting nested tensors or tables. Got "..
         torch.type(dst).." and "..torch.type(input).." instead")
   end

   -- tensor leaf: start from zeros, then copy in the rows that were kept
   local out = dst
   if not torch.isTensor(out) then
      out = input.new()
   end
   out:resizeAs(input):zero()
   if mask:nElement() > 0 then
      assert(src)
      out:indexCopy(1, mask, src)
   end
   return out
end
-- Computes the gradient w.r.t. `input`: the decorated criterion is run
-- backward on the cached masked input/target (only when at least one row was
-- kept by the last updateOutput) and the result is scattered into zero-filled
-- tensors shaped like `input`.
-- Returns self.gradInput.
function MaskZeroCriterion:updateGradInput(input, target)
   local keptRows = self.zeroMask
   if keptRows:nElement() > 0 then
      assert(self.input and self.target)
      self._gradInput = self.criterion:updateGradInput(self.input, self.target)
   end

   local scattered = self:recursiveMaskGradInput(self.gradInput, keptRows, self._gradInput, input)
   self.gradInput = scattered
   return scattered
end
-- Type conversion. The cached mask buffers and masked input/target/gradient
-- are discarded first (they are re-created lazily by updateOutput and
-- updateGradInput); the actual conversion is delegated to the parent class.
function MaskZeroCriterion:type(type, ...)
   local cachedFields = {
      'zeroMask', '_zeroMask', '__zeroMask',
      'input', 'target', '_gradInput',
   }
   for _, field in ipairs(cachedFields) do
      self[field] = nil
   end
   return parent.type(self, type, ...)
end