forked from winterant/DeepCoNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
78 lines (61 loc) · 3.3 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import time
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from config import Config
from model import DeepCoNN
from utils import load_embedding, DeepCoNNDataset, predict_mse, date
def train(train_dataloader, valid_dataloader, model, config, model_path):
print(f'{date()}## Start the training!')
train_mse = predict_mse(model, train_dataloader, config.device)
valid_mse = predict_mse(model, valid_dataloader, config.device)
print(f'{date()}#### Initial train mse {train_mse:.6f}, validation mse {valid_mse:.6f}')
start_time = time.perf_counter()
opt = torch.optim.Adam(model.parameters(), config.learning_rate, weight_decay=config.l2_regularization)
lr_sch = torch.optim.lr_scheduler.ExponentialLR(opt, config.learning_rate_decay)
best_loss, best_epoch = 100, 0
for epoch in range(config.train_epochs):
model.train() # 将模型设置为训练状态
total_loss, total_samples = 0, 0
for batch in train_dataloader:
user_reviews, item_reviews, ratings = map(lambda x: x.to(config.device), batch)
predict = model(user_reviews, item_reviews)
loss = F.mse_loss(predict, ratings, reduction='sum') # 平方和误差
opt.zero_grad() # 梯度清零
loss.backward() # 反向传播计算梯度
opt.step() # 根据梯度信息更新所有可训练参数
total_loss += loss.item()
total_samples += len(predict)
lr_sch.step()
model.eval() # 停止训练状态
valid_mse = predict_mse(model, valid_dataloader, config.device)
train_loss = total_loss / total_samples
print(f"{date()}#### Epoch {epoch:3d}; train mse {train_loss:.6f}; validation mse {valid_mse:.6f}")
if best_loss > valid_mse:
best_loss = valid_mse
torch.save(model, model_path)
end_time = time.perf_counter()
print(f'{date()}## End of training! Time used {end_time - start_time:.0f} seconds.')
def test(dataloader, model):
print(f'{date()}## Start the testing!')
start_time = time.perf_counter()
test_loss = predict_mse(model, dataloader, config.device)
end_time = time.perf_counter()
print(f"{date()}## Test end, test mse is {test_loss:.6f}, time used {end_time - start_time:.0f} seconds.")
if __name__ == '__main__':
config = Config()
print(config)
print(f'{date()}## Load embedding and data...')
word_emb, word_dict = load_embedding(config.word2vec_file)
train_dataset = DeepCoNNDataset(config.train_file, word_dict, config)
valid_dataset = DeepCoNNDataset(config.valid_file, word_dict, config, retain_rui=False)
test_dataset = DeepCoNNDataset(config.test_file, word_dict, config, retain_rui=False)
train_dlr = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
valid_dlr = DataLoader(valid_dataset, batch_size=config.batch_size)
test_dlr = DataLoader(test_dataset, batch_size=config.batch_size)
model = DeepCoNN(config, word_emb).to(config.device)
del train_dataset, valid_dataset, test_dataset, word_emb, word_dict
os.makedirs(os.path.dirname(config.model_file), exist_ok=True) # 文件夹不存在则创建
train(train_dlr, valid_dlr, model, config, config.model_file)
test(test_dlr, torch.load(config.model_file))