-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
36 lines (29 loc) · 1.23 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from save_vocab import save_vocab, get_class
from dataset import TextDataSet, idx_to_words, read_vocab
from train import train
import torch
from model import TextCNN
# Placeholder locations of the vocabulary and class-label files; replace
# both with real paths before running this script.
vocab_path = "path of vacab.txt"
class_path = "path of class.txt"

# Single source of truth for the batch size: the three DataLoaders and the
# TextCNN constructor all receive this value, so they cannot drift apart.
batch_size = 4

# Build the train, dev and test datasets.
dataset = TextDataSet("train", vocab_path)
dev_dataset = TextDataSet("dev", vocab_path)
test_dataset = TextDataSet("test", vocab_path)

# Wrap the datasets in DataLoaders. Only the training split is shuffled.
# NOTE(review): TextCNN is given batch_size explicitly; if it relies on every
# batch having exactly that size, the final (possibly smaller) batch of each
# split may need drop_last=True -- confirm against model.py.
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True
)
dev_dataloader = torch.utils.data.DataLoader(
    dev_dataset, batch_size=batch_size, shuffle=False
)
# test_dataloader is not passed to train() below; presumably it is meant for
# a later evaluation step -- verify.
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False
)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The vocabulary and class list determine the model's input/output sizes.
classes, class_to_idx = get_class(class_path)
words, word_to_idx = read_vocab(vocab_path)
vocab_size = len(words)
class_num = len(classes)
embedding_dim = 128

# Define the model, loss function and optimizer.
model = TextCNN(vocab_size, embedding_dim, class_num, batch_size=batch_size)
model = model.to(device)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.00005)
epoches = 10

# Train the model, passing the dev loader alongside the training loader.
train(dataloader, dev_dataloader, model, loss_fn, optimizer, epoches, device)