-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTrainTest.py
65 lines (51 loc) · 2.74 KB
/
TrainTest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import torchaudio
import torch.nn as nn
from comet_ml import Experiment
from comet_ml.integration.pytorch import log_model
from torchmetrics.text import WordErrorRate
import utils
def train(model, device, train_loader, criterion, optimizer, scheduler, epoch, iter_meter, experiment):
    """Run one training epoch of the CTC acoustic model.

    Args:
        model: network mapping spectrograms -> (batch, time, n_class) logits.
        device: torch.device the batch tensors are moved to.
        train_loader: yields (spectrograms, labels, input_lengths, label_lengths);
            must expose ``.dataset`` for the progress printout.
        criterion: CTC-style loss taking (log_probs, targets, input_lengths, label_lengths).
        optimizer / scheduler: stepped once per batch.
        epoch: epoch number, used only for logging.
        iter_meter: global step counter with ``get()``/``step()``.
        experiment: comet_ml Experiment; ``experiment.train()`` scopes the metrics.
    """
    model.train()
    data_len = len(train_loader.dataset)
    with experiment.train():
        for batch_idx, _data in enumerate(train_loader):
            spectrograms, labels, input_lengths, label_lengths = _data
            spectrograms, labels = spectrograms.to(device), labels.to(device)

            optimizer.zero_grad()
            output = model(spectrograms)  # (batch, time, n_class)
            output = nn.functional.log_softmax(output, dim=2)
            output = output.transpose(0, 1)  # CTC expects (time, batch, n_class)

            loss = criterion(output, labels, input_lengths, label_lengths)
            loss.backward()

            experiment.log_metric('loss', loss.item(), step=iter_meter.get())
            # get_last_lr() replaces the deprecated get_lr() (which is only valid
            # inside step()); it returns one lr per param group, so log the scalar.
            experiment.log_metric('learning_rate', scheduler.get_last_lr()[0], step=iter_meter.get())

            optimizer.step()
            scheduler.step()
            iter_meter.step()
            # Print every 100 batches and on the final batch. The original compared
            # `batch_idx == data_len` — a batch index against a *sample* count —
            # which almost never triggers.
            if batch_idx % 100 == 0 or batch_idx == len(train_loader) - 1:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(spectrograms), data_len,
                    100. * batch_idx / len(train_loader), loss.item()))
def test(model, device, test_loader, criterion, epoch, iter_meter, experiment):
    """Evaluate the model: average CTC loss and word error rate over test_loader.

    Args:
        model: network mapping spectrograms -> (batch, time, n_class) logits.
        device: torch.device the batch tensors are moved to.
        test_loader: yields (spectrograms, labels, input_lengths, label_lengths).
        criterion: CTC-style loss taking (log_probs, targets, input_lengths, label_lengths).
        epoch: unused here beyond signature symmetry with ``train``.
        iter_meter: global step counter used only for the metric step.
        experiment: comet_ml Experiment; ``experiment.test()`` scopes the metrics.
    """
    print('\nevaluating…')
    model.eval()
    test_loss = 0
    test_wer = []
    # torchmetrics' WordErrorRate is a Metric object: instantiate it once and
    # *call* it with (preds, target). The original constructed a new metric per
    # utterance with invalid kwargs (preds=/targets=), which raises at runtime.
    wer_metric = WordErrorRate()
    with experiment.test():
        with torch.no_grad():
            for i, _data in enumerate(test_loader):
                spectrograms, labels, input_lengths, label_lengths = _data
                spectrograms, labels = spectrograms.to(device), labels.to(device)

                output = model(spectrograms)  # (batch, time, n_class)
                output = nn.functional.log_softmax(output, dim=2)
                output = output.transpose(0, 1)  # (time, batch, n_class)

                loss = criterion(output, labels, input_lengths, label_lengths)
                test_loss += loss.item() / len(test_loader)

                # GreedyDecoder wants (batch, time, n_class), so undo the transpose.
                decoded_preds, decoded_targets = utils.GreedyDecoder(output.transpose(0, 1), labels, label_lengths)
                for j in range(len(decoded_preds)):
                    # .item() converts the 0-d tensor score to a plain float.
                    test_wer.append(wer_metric(decoded_preds[j], decoded_targets[j]).item())

    # Guard against an empty test loader so the average doesn't divide by zero.
    avg_wer = sum(test_wer) / len(test_wer) if test_wer else 0.0
    experiment.log_metric('test_loss', test_loss, step=iter_meter.get())
    experiment.log_metric('wer', avg_wer, step=iter_meter.get())
    print('Test set: Average loss: {:.4f}, Average WER: {:.4f}\n'.format(test_loss, avg_wer))