# Smooth_AP_loss.py
# requirements:
# python 3.x
# torch >= 1.5 recommended (the code below uses torch.nonzero(..., as_tuple=True)
# and bool/float type promotion, which are not available in torch 1.1.0)
import torch


def sigmoid(tensor, temp=1.0):
    """Temperature-controlled sigmoid.

    Passes the input tensor through a sigmoid whose steepness is set by the
    temperature `temp`; a low temperature closely approximates the Heaviside
    step function.
    """
    exponent = -tensor / temp
    # clamp the exponent for numerical stability
    exponent = torch.clamp(exponent, min=-50, max=50)
    y = 1.0 / (1.0 + torch.exp(exponent))
    return y
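

# A minimal sanity-check sketch (our addition, not part of the original file)
# showing how the temperature trades smoothness against fidelity to the step
# function. The helper name `_demo_sigmoid_temperature` is purely illustrative.
def _demo_sigmoid_temperature():
    x = torch.tensor([-1.0, -0.1, 0.0, 0.1, 1.0])
    # temp=1.0 gives a gentle transition around 0.5
    print(sigmoid(x, temp=1.0))   # approx. tensor([0.269, 0.475, 0.500, 0.525, 0.731])
    # temp=0.01 is near-binary, close to the Heaviside step function
    print(sigmoid(x, temp=0.01))  # approx. tensor([0.0, 0.0, 0.5, 1.0, 1.0])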


def compute_aff(x, y):
    """Computes the affinity (inner-product) matrix between two sets of row vectors."""
    return torch.mm(x, y.t())
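

# Note: the comments below describe these affinities as cosine similarities;
# that reading holds only if the embeddings are L2-normalised beforehand. A
# sketch of the normalisation a caller might apply (an assumption on our part;
# the original file does not show where its inputs are normalised):
#
#   import torch.nn.functional as F
#   preds = F.normalize(preds, p=2, dim=-1)  # rows become unit vectors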


class SmoothAP(torch.nn.Module):
    """PyTorch implementation of the Smooth-AP loss for cross-domain retrieval.

    Takes a mini-batch of cartoon embeddings (the queries) and a mini-batch of
    portrait embeddings (the retrieval set), together with their class labels,
    and returns the value of the Smooth-AP loss. For each cartoon embedding
    that has at least one same-class portrait in the batch, the loss computes
    the Smooth-AP of retrieving its positives from the portrait set, and
    returns one minus the average Smooth-AP over those queries. Because the
    positive set is derived from the label tensors, no particular class
    ordering is required within the batch.

    Args:
        anneal : float
            the temperature of the sigmoid used to smooth the ranking function.
            A low temperature gives a steep sigmoid that tightly approximates
            the Heaviside step function in the ranking function.
        cartoon_len : int
            the number of cartoon (query) embeddings in the batch.
        portrait_len : int
            the number of portrait (retrieval-set) embeddings in the batch.
    Shape:
        - Input: cartoon_preds (cartoon_len, feat_dims), cartoon_labels
          (cartoon_len, 1), portrait_preds (portrait_len, feat_dims),
          portrait_labels (portrait_len, 1); all cuda torch tensors
        - Output: scalar
    Examples::
        >>> loss = SmoothAP(0.01, 60, 60)
        >>> cartoons = torch.randn(60, 256, requires_grad=True).cuda()
        >>> portraits = torch.randn(60, 256, requires_grad=True).cuda()
        >>> c_labels = torch.randint(0, 6, (60, 1)).cuda()
        >>> p_labels = torch.randint(0, 6, (60, 1)).cuda()
        >>> output = loss(cartoons, c_labels, portraits, p_labels)
        >>> output.backward()
    """

    def __init__(self, anneal, cartoon_len, portrait_len):
        """
        Parameters
        ----------
        anneal : float
            the temperature of the sigmoid used to smooth the ranking function
        cartoon_len : int
            the number of cartoon (query) embeddings in the batch
        portrait_len : int
            the number of portrait (retrieval-set) embeddings in the batch
        """
        super(SmoothAP, self).__init__()
        self.anneal = anneal
        self.cartoon_len = cartoon_len
        self.portrait_len = portrait_len

    def forward(self, cartoon_preds, cartoon_labels, portrait_preds, portrait_labels):
        """Forward pass: cartoon_preds/portrait_preds are (len x feat_dims) embeddings."""
        # boolean mask: pos_mask[i, j] is True iff cartoon i and portrait j share a class
        pos_mask = (cartoon_labels == portrait_labels.repeat(1, self.cartoon_len).T)
        # keep only the cartoon queries that have at least one positive portrait
        batch_mask = torch.nonzero(pos_mask.sum(-1), as_tuple=True)
        pos_mask = pos_mask[batch_mask]
        cartoon_preds = cartoon_preds[batch_mask]
        # zero loss that still participates in autograd if no query has a positive
        loss = torch.tensor(0.0).cuda()
        loss.requires_grad = True
        if len(pos_mask) == 0:
            return loss
        # ------ differentiable ranking of the full retrieval set ------
        # relevance scores: inner products of the embedding vectors (cosine
        # similarity if the embeddings are L2-normalised)
        sim_all = compute_aff(cartoon_preds, portrait_preds)
        sim_all_repeat = sim_all.unsqueeze(dim=1).repeat(1, self.portrait_len, 1)
        # pairwise score differences: sim_diff[i, m, k] = s_ik - s_im
        sim_diff = sim_all_repeat - sim_all_repeat.permute(0, 2, 1)
        # pass through the temperature-controlled sigmoid
        sim_sg = sigmoid(sim_diff, temp=self.anneal)
        # smoothed rank of each portrait among all portraits
        sim_all_rk = torch.sum(sim_sg, dim=-1) + 1
        # ------ differentiable ranking within the positive set only ------
        # mask that gives non-zero weight to the positive set only
        sg_pos_mask = pos_mask.unsqueeze(1).repeat(1, self.portrait_len, 1)
        sim_pos_sg = sim_sg * sg_pos_mask
        # smoothed rank of each positive among the positives
        sim_pos_rk = (torch.sum(sim_pos_sg, dim=-1) + 1) * pos_mask
        # Smooth-AP per query, averaged over its positives
        ap = (sim_pos_rk / sim_all_rk).sum(-1) / pos_mask.sum(-1)
        loss = 1 - ap.mean()
        return loss
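

# A minimal usage sketch (our addition, not in the original file). The label
# shapes (column vectors) are an assumption inferred from the broadcasting in
# `forward` above; the helper name `_demo_smooth_ap` is purely illustrative.
def _demo_smooth_ap():
    cartoon_len, portrait_len, feat_dims = 6, 6, 256
    criterion = SmoothAP(anneal=0.01, cartoon_len=cartoon_len, portrait_len=portrait_len)
    cartoon_preds = torch.randn(cartoon_len, feat_dims, requires_grad=True).cuda()
    portrait_preds = torch.randn(portrait_len, feat_dims, requires_grad=True).cuda()
    # three classes, two instances each; ordering does not matter for this class
    cartoon_labels = torch.tensor([[0], [0], [1], [1], [2], [2]]).cuda()
    portrait_labels = torch.tensor([[0], [1], [2], [0], [1], [2]]).cuda()
    loss = criterion(cartoon_preds, cartoon_labels, portrait_preds, portrait_labels)
    loss.backward()
    print(loss.item())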


class SmoothAP_test(torch.nn.Module):
    """PyTorch implementation of the Smooth-AP loss for paired mini-batches.

    Takes a mini-batch of cartoon embeddings and a mini-batch of portrait
    embeddings (CNN-produced) and returns the value of the Smooth-AP loss.
    Each mini-batch must be formed of a defined number of classes, each class
    must have the same number of instances, and the instances must be ordered
    sequentially by class. E.g. the labels for a mini-batch with batch size 9
    and 3 represented classes (A, B, C) must look like:

        labels = (A, A, A, B, B, B, C, C, C)

    (which class comes first does not matter). For each cartoon embedding used
    as the query, the portrait mini-batch is used as the retrieval set; the
    positive set is formed of the portraits from the same class. The loss
    returns one minus the average Smooth-AP across all queries in the
    mini-batch.

    Args:
        anneal : float
            the temperature of the sigmoid used to smooth the ranking function.
            A low temperature gives a steep sigmoid that tightly approximates
            the Heaviside step function in the ranking function.
        batch_size : int
            the batch size being used during training.
        num_id : int
            the number of different classes represented in the batch.
        feat_dims : int
            the dimension of the input feature embeddings.
    Shape:
        - Input: cartoon_preds and portrait_preds, each (batch_size, feat_dims)
          (must be cuda torch float tensors)
        - Output: scalar
    Examples::
        >>> loss = SmoothAP_test(0.01, 60, 6, 256)
        >>> cartoons = torch.randn(60, 256, requires_grad=True).cuda()
        >>> portraits = torch.randn(60, 256, requires_grad=True).cuda()
        >>> output = loss(cartoons, portraits)
        >>> output.backward()
    """

    def __init__(self, anneal, batch_size, num_id, feat_dims):
        """
        Parameters
        ----------
        anneal : float
            the temperature of the sigmoid used to smooth the ranking function
        batch_size : int
            the batch size being used
        num_id : int
            the number of different classes that are represented in the batch
        feat_dims : int
            the dimension of the input feature embeddings
        """
        super(SmoothAP_test, self).__init__()
        assert batch_size % num_id == 0
        self.anneal = anneal
        self.batch_size = batch_size
        self.num_id = num_id
        self.feat_dims = feat_dims

    def forward(self, cartoon_preds, portrait_preds):
        """Forward pass: cartoon_preds/portrait_preds are (batch_size x feat_dims) embeddings."""
        group = self.batch_size // self.num_id  # instances per class
        # ------ differentiable ranking of the full retrieval set ------
        # relevance scores: inner products of the embedding vectors (cosine
        # similarity if the embeddings are L2-normalised)
        sim_all = compute_aff(cartoon_preds, portrait_preds)
        sim_all_repeat = sim_all.unsqueeze(dim=1).repeat(1, self.batch_size, 1)
        # pairwise score differences
        sim_diff = sim_all_repeat - sim_all_repeat.permute(0, 2, 1)
        # pass through the temperature-controlled sigmoid
        sim_sg = sigmoid(sim_diff, temp=self.anneal)
        # smoothed rank of each portrait among all portraits
        sim_all_rk = torch.sum(sim_sg, dim=-1) + 1
        # ------ differentiable ranking within the positive set only ------
        # group the embeddings by class: (num_id, group, feat_dims)
        cartoon_xs = cartoon_preds.view(self.num_id, group, self.feat_dims)
        portrait_xs = portrait_preds.view(self.num_id, group, self.feat_dims)
        # relevance scores within each class
        sim_pos = torch.bmm(cartoon_xs, portrait_xs.permute(0, 2, 1))
        sim_pos_repeat = sim_pos.unsqueeze(dim=2).repeat(1, 1, group, 1)
        # pairwise score differences within each class
        sim_pos_diff = sim_pos_repeat - sim_pos_repeat.permute(0, 1, 3, 2)
        # pass through the temperature-controlled sigmoid
        sim_pos_sg = sigmoid(sim_pos_diff, temp=self.anneal)
        # smoothed rank of each positive among the positives of its class
        sim_pos_rk = torch.sum(sim_pos_sg, dim=-1) + 1
        # average the Smooth-AP of every query in the mini-batch
        ap = torch.zeros(1).cuda()
        for ind in range(self.num_id):
            pos_divide = torch.sum(sim_pos_rk[ind] / sim_all_rk[(ind * group):((ind + 1) * group), (ind * group):((ind + 1) * group)])
            ap = ap + ((pos_divide / group) / self.batch_size)
        return 1 - ap
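

# A hedged end-to-end sketch (our addition): exercises SmoothAP_test with the
# batch layout described in its docstring -- 6 classes x 10 instances, ordered
# by class. Requires a CUDA device, as does the rest of this file.
if __name__ == "__main__":
    _demo_sigmoid_temperature()
    _demo_smooth_ap()
    criterion = SmoothAP_test(anneal=0.01, batch_size=60, num_id=6, feat_dims=256)
    cartoons = torch.randn(60, 256, requires_grad=True).cuda()
    portraits = torch.randn(60, 256, requires_grad=True).cuda()
    loss = criterion(cartoons, portraits)
    loss.backward()
    print(loss.item())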