-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathPGLoss.py
35 lines (27 loc) · 1023 Bytes
/
PGLoss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import math
class PGLoss(torch.nn.Module):
    """Policy-gradient loss: negated, reward-weighted log-probability of labels.

    For every batch element ``i`` and position ``t`` the value
    ``logprobs[i, t, label[i, t]]`` is selected, multiplied by ``reward[i]``
    and summed; the loss is the negation of that sum accumulated over the
    batch.

    Args:
        ignore_index: Optional class index. When set, the entries
            ``logprobs[:, :, ignore_index]`` are treated as 0, so positions
            labelled with that class contribute nothing to the loss.
        size_average: If true, the summed loss is divided by the batch size.
        reduce: Stored on the module for interface compatibility only;
            ``forward`` does not consult it.
    """

    def __init__(self, ignore_index=None, size_average=False, reduce=True):
        super(PGLoss, self).__init__()
        self.size_average = size_average
        self.ignore_index = ignore_index
        self.reduce = reduce

    def forward(self, logprobs, label, reward, use_cuda):
        """Compute the loss for one batch.

        Args:
            logprobs: 3-D tensor unpacked as ``(bsz, seqlen, n_classes)``.
                The caller's tensor is never modified.
            label: Integer tensor indexed as ``label[i, :]`` to pick one
                class per position -- assumed shape ``(bsz, seqlen)``.
            reward: Indexable by batch element; ``reward[i]`` must broadcast
                against a ``(seqlen,)`` tensor (e.g. a scalar per sample).
            use_cuda: Retained for backward compatibility. The position index
                is now created on ``logprobs.device``, so this flag no longer
                influences the computation.

        Returns:
            A 0-dim tensor holding the (optionally batch-averaged) loss. An
            empty batch yields a zero tensor rather than raising.
        """
        bsz, seqlen, _ = logprobs.size()

        # Masking is identical for every sample, so do it once up front.
        # Clone only here: this is the sole place the tensor is mutated.
        if self.ignore_index is not None:
            logprobs = logprobs.clone()
            logprobs[:, :, self.ignore_index] = 0

        # Position index shared by all samples; built on the tensor's own
        # device so it cannot disagree with where the data actually lives.
        positions = torch.arange(seqlen, device=logprobs.device)

        # Start from a tensor so the return type is stable for bsz == 0.
        loss = logprobs.new_zeros(())
        for i in range(bsz):
            picked = logprobs[i][positions, label[i, :]]
            loss = loss + (-torch.sum(picked * reward[i]))

        # Guard against dividing by an empty batch.
        if self.size_average and bsz > 0:
            loss = loss / bsz
        return loss