-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathlosses.py
87 lines (70 loc) · 2.86 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import torch
import torch.nn as nn
import torch.nn.functional as F
def multilabel_categorical_crossentropy(y_true, y_pred):
"""多标签分类的交叉熵
说明:y_true和y_pred的shape一致,y_true的元素非0即1,
1表示对应的类为目标类,0表示对应的类为非目标类。
警告:请保证y_pred的值域是全体实数,换言之一般情况下y_pred
不用加激活函数,尤其是不能加sigmoid或者softmax!预测
阶段则输出y_pred大于0的类。如有疑问,请仔细阅读并理解
本文。
"""
y_pred = (1 - 2 * y_true) * y_pred
y_pred_neg = y_pred - y_true * 1e30
y_pred_pos = y_pred - (1 - y_true) * 1e30
zeros = torch.zeros_like(y_pred[..., :1])
y_pred_neg = torch.cat([y_pred_neg, zeros],dim=-1)
y_pred_pos = torch.cat((y_pred_pos, zeros),dim=-1)
neg_loss = torch.logsumexp(y_pred_neg, axis=-1)
pos_loss = torch.logsumexp(y_pred_pos, axis=-1)
return neg_loss + pos_loss
class balanced_loss(nn.Module):
    """Loss module wrapping ``multilabel_categorical_crossentropy``.

    Decoding uses a fixed threshold of 0 on the logits, and column 0 of the
    decoded output is used as the "no other class selected" indicator.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits, labels):
        """Return the mean multi-label cross-entropy over all rows.

        Args:
            logits: Raw class scores (no sigmoid/softmax applied).
            labels: 0/1 targets with the same shape as ``logits``.

        Returns:
            Scalar tensor: the per-row losses averaged.
        """
        loss = multilabel_categorical_crossentropy(labels, logits)
        return loss.mean()

    def get_label(self, logits, num_labels=-1):
        """Decode logits into a 0/1 label matrix.

        A class is selected when its logit is greater than 0. If
        ``num_labels`` is positive, a selected class must additionally be
        among the ``num_labels`` highest logits of its row. Column 0 is then
        overwritten: it is 1 exactly when no class in columns ``1:`` was
        selected, otherwise 0.

        Args:
            logits: 2-D tensor indexed as ``(row, class)``.
            num_labels: Maximum number of top-scoring classes to keep per
                row; values <= 0 disable the limit. Values larger than the
                number of classes are treated as "all classes".

        Returns:
            Tensor shaped and typed like ``logits`` containing 0.0 / 1.0.
        """
        output = torch.zeros_like(logits)
        mask = logits > 0
        if num_labels > 0:
            # Clamp k: torch.topk raises when k exceeds the dimension size.
            k = min(num_labels, logits.size(1))
            top_v, _ = torch.topk(logits, k, dim=1)
            kth_value = top_v[:, -1].unsqueeze(1)
            mask = (logits >= kth_value) & mask
        output[mask] = 1.0
        # Column 0 flags rows in which none of the other classes fired.
        output[:, 0] = (output[:, 1:].sum(1) == 0.).to(logits)
        return output
class ATLoss(nn.Module):
    """Threshold-class ranking loss in which class 0 is the threshold (TH).

    Training pushes the logits of positive classes above the logit of class 0
    and the logit of class 0 above the logits of all remaining classes.
    Decoding selects the classes whose logit exceeds the class-0 logit.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits, labels):
        """Compute the mean threshold-ranking loss.

        Args:
            logits: 2-D tensor of raw scores indexed as ``(row, class)``;
                column 0 is the threshold class.
            labels: 0/1 targets with the same shape as ``logits``. Whatever
                is stored in column 0 is ignored. The tensor is not modified.

        Returns:
            Scalar tensor: the sum of both ranking terms, averaged over rows.
        """
        # Work on a copy so that zeroing column 0 does not silently alter the
        # caller's tensor (it may be reused for metrics or another loss).
        labels = labels.clone()
        labels[:, 0] = 0.0

        # TH label: one-hot on column 0, same dtype/device as the labels.
        th_label = torch.zeros_like(labels)
        th_label[:, 0] = 1.0

        p_mask = labels + th_label  # positive classes plus TH
        n_mask = 1 - labels         # everything except the positive classes

        # Rank positive classes above TH: softmax over {positives, TH} only.
        # A large finite offset (not -inf) keeps log_softmax finite at the
        # excluded positions so multiplying by a zero label yields 0, not NaN.
        logit1 = logits - (1 - p_mask) * 1e30
        loss1 = -(F.log_softmax(logit1, dim=-1) * labels).sum(1)

        # Rank TH above negative classes: softmax over {TH, negatives} only.
        logit2 = logits - (1 - n_mask) * 1e30
        loss2 = -(F.log_softmax(logit2, dim=-1) * th_label).sum(1)

        # Sum the two parts and average over rows.
        return (loss1 + loss2).mean()

    def get_label(self, logits, num_labels=-1):
        """Decode logits into a 0/1 label matrix.

        A class is selected when its logit is strictly greater than the
        class-0 logit of the same row. If ``num_labels`` is positive, a
        selected class must additionally be among the ``num_labels`` highest
        logits of its row. Column 0 is set to 1 exactly when no class was
        selected in that row.

        Args:
            logits: 2-D tensor indexed as ``(row, class)``.
            num_labels: Maximum number of top-scoring classes to keep per
                row; values <= 0 disable the limit. Values larger than the
                number of classes are treated as "all classes".

        Returns:
            Tensor shaped and typed like ``logits`` containing 0.0 / 1.0.
        """
        th_logit = logits[:, 0].unsqueeze(1)
        output = torch.zeros_like(logits)
        mask = logits > th_logit
        if num_labels > 0:
            # Clamp k: torch.topk raises when k exceeds the dimension size.
            k = min(num_labels, logits.size(1))
            top_v, _ = torch.topk(logits, k, dim=1)
            kth_value = top_v[:, -1].unsqueeze(1)
            mask = (logits >= kth_value) & mask
        output[mask] = 1.0
        # Column 0 can never pass the strict ">" test against itself, so the
        # row sum is zero exactly when nothing was selected.
        output[:, 0] = (output.sum(1) == 0.).to(logits)
        return output