main.py
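
# End-to-end speech recognition training script: trains a CNN+RNN acoustic
# model (Model.SpeechRecognitionModel) on LibriSpeech with CTC loss and logs
# the run to Comet. Model, utils, and TrainTest are local modules in this repo.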
from comet_ml import Experiment  # comet_ml is imported before torch so Comet can hook into training
import torch
import torch.nn as nn
import torchaudio
import os

import Model
import utils
import TrainTest


def main(learning_rate=5e-4,
         batch_size=20,
         epochs=10,
         train_url="train-clean-100",
         test_url="test-clean",
         experiment=None):
    # Create the Comet experiment lazily: a default argument is evaluated once,
    # at import time, which would start a run even if main() were never called.
    # The API key is read from the environment rather than hardcoded in source
    # (the original embedded a literal key, which should not be committed).
    if experiment is None:
        experiment = Experiment(
            api_key=os.environ.get("COMET_API_KEY"),
            project_name="general",
            workspace="dustin-wang",
        )
    hparams = {
        "n_cnn_layers": 3,
        "n_rnn_layers": 5,
        "rnn_dim": 512,
        "n_class": 29,
        "n_feats": 128,
        "stride": 2,
        "dropout": 0.1,
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "epochs": epochs,
    }
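
    # n_class = 29: presumably the 28 text characters produced by
    # utils.data_processing plus one CTC blank at index 28 (matching
    # CTCLoss(blank=28) below) -- an assumption, as utils is not shown here.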
    # Select the "soundfile" audio backend (note: set_audio_backend is
    # deprecated in recent torchaudio releases).
    torchaudio.set_audio_backend("soundfile")
    experiment.log_parameters(hparams)

    use_cuda = torch.cuda.is_available()
    torch.manual_seed(7)  # fixed seed for reproducibility
    device = torch.device("cuda" if use_cuda else "cpu")

    # Download the LibriSpeech splits into ./data on first run.
    os.makedirs("./data", exist_ok=True)
    train_dataset = torchaudio.datasets.LIBRISPEECH("./data", url=train_url, download=True)
    test_dataset = torchaudio.datasets.LIBRISPEECH("./data", url=test_url, download=True)

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=hparams['batch_size'],
                                               shuffle=True,
                                               collate_fn=lambda x: utils.data_processing(x, 'train'),
                                               **kwargs)  # collate each raw batch via utils.data_processing
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=hparams['batch_size'],
                                              shuffle=False,
                                              collate_fn=lambda x: utils.data_processing(x, 'valid'),
                                              **kwargs)
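
    # Caveat: lambdas cannot be pickled, so these inline collate_fns will fail
    # with num_workers > 0 on platforms that spawn worker processes (e.g.
    # Windows/macOS); a module-level function or functools.partial avoids that.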
    model = Model.SpeechRecognitionModel(
        hparams['n_cnn_layers'], hparams['n_rnn_layers'], hparams['rnn_dim'],
        hparams['n_class'], hparams['n_feats'], hparams['stride'], hparams['dropout']
    ).to(device)
    print(model)
    print('Num Model Parameters', sum(param.nelement() for param in model.parameters()))

    optimizer = torch.optim.AdamW(model.parameters(), lr=hparams['learning_rate'])
    # CTC loss with the blank token at index 28, the last of the 29 classes.
    criterion = nn.CTCLoss(blank=28).to(device)
    # One-cycle LR schedule over the whole run; OneCycleLR is meant to be
    # stepped once per batch, hence steps_per_epoch=len(train_loader).
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=hparams['learning_rate'],
                                                    steps_per_epoch=len(train_loader),
                                                    epochs=hparams['epochs'],
                                                    anneal_strategy='linear')
    iter_meter = utils.IterMeter()
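
    # One training pass and one evaluation pass per epoch. TrainTest (not shown
    # in this file) is assumed to log its metrics to the Comet experiment and
    # to call scheduler.step() once per batch, as OneCycleLR expects.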
    for epoch in range(1, epochs + 1):
        TrainTest.train(model, device, train_loader, criterion, optimizer, scheduler, epoch, iter_meter, experiment)
        TrainTest.test(model, device, test_loader, criterion, epoch, iter_meter, experiment)

    experiment.end()
    # Persist the trained weights; "./model_saves" is a single state-dict file
    # here, not a directory.
    torch.save(model.state_dict(), "./model_saves")


if __name__ == "__main__":
    main()
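
# Example invocation, assuming the local Model/utils/TrainTest modules sit
# alongside this file and COMET_API_KEY is set in the environment (the env
# var introduced above in place of the hardcoded key):
#   python main.py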