-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
127 lines (117 loc) · 5.03 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
import utils
import numpy as np
import torch
import torch.nn as nn
from lda import LDAmodel
from model import LSTMnet,CNNnet,ANNnet,GCNnet
from dataset import FMGdataset
from model_config import build_args
import matplotlib.pyplot as plt
# NOTE(review): presumably a workaround for the Intel OpenMP runtime aborting
# when more than one copy of libiomp gets loaded (e.g. via torch and
# numpy/matplotlib); it is set at import time, before main() runs.
# Confirm it is still required on the target platform.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
def _checkpoint_path(args, root="./models"):
    """Return the best-accuracy checkpoint path for ``args``.

    The file name combines ``args.model_name``, ``args.dataset`` and
    ``args.subindex``, so every model / dataset / subject combination has its
    own checkpoint. Training writes this path and inference reads it, so both
    go through this one helper and cannot drift apart.
    """
    filename = "best{}_{}_S{}.pth".format(args.model_name, args.dataset, args.subindex)
    return os.path.join(root, filename)


def main(args):
    """Build the model named by ``args.model_name``, then train or evaluate it.

    Args:
        args: configuration namespace (as produced by ``build_args``). Reads
            ``mode``, ``gpu``, ``model_name``, ``channels``, ``hidden_dim``,
            ``n_layer``, ``n_class``, ``L_win`` (GCN only), ``lr``,
            ``subindex``, ``test_ratio``, ``batch_size``, ``dataset`` and,
            in train mode, ``epochs``. ``args.device`` is set to the device
            that was selected.

    Returns:
        ``"train"`` mode: the best test accuracy seen over all epochs; the
        weights that achieved it are saved under ``./models``.
        ``"inference"`` mode: the ``(cm, acc, t_mean)`` tuple returned by
        ``utils.test`` for the saved best checkpoint.

    Raises:
        ValueError: if ``args.mode`` or ``args.model_name`` is not recognised.
    """
    # Validate before any model construction or (potentially slow) data loading.
    if args.mode not in ("train", "inference"):
        raise ValueError(
            "Unknown mode {!r}; expected 'train' or 'inference'.".format(args.mode)
        )

    # The GPU is chosen by explicit index. Do not also export
    # CUDA_VISIBLE_DEVICES=<gpu>: masking would renumber that GPU to cuda:0,
    # and the "cuda:<gpu>" index below would then refer to a missing device.
    device = torch.device("cuda:{}".format(args.gpu) if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))
    args.device = device  # NOTE(review): presumably read by utils/FMGdataset — confirm

    adj = None  # only the GCN branch sets this; it is passed through to utils.train/test
    if args.model_name == "LSTM":
        model = LSTMnet(in_dim=len(args.channels), hidden_dim=args.hidden_dim,
                        n_layer=args.n_layer, n_class=args.n_class)
    elif args.model_name == "ANN":
        model = ANNnet(in_dim=len(args.channels), hidden_dim=args.hidden_dim,
                       n_layer=args.n_layer, n_class=args.n_class)
    elif args.model_name == "CNN":
        model = CNNnet(in_dim=len(args.channels), hidden_dim=args.hidden_dim,
                       n_layer=args.n_layer, n_class=args.n_class)
    elif args.model_name == "GCN":
        model = GCNnet(input_dim=args.L_win, hidden_dim=args.hidden_dim,
                       output_dim=args.n_class, num_channel=len(args.channels))
        adj = utils.get_adjmatrix(args)
    else:
        # Raising (instead of returning -1) gives callers a clear error rather
        # than a confusing failure when they unpack the result.
        raise ValueError("Model's name is not in the list: {!r}".format(args.model_name))

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.BCELoss(reduction='none')

    # train/val split for the selected subject
    print("TrainVal_subjects_index:{}".format(args.subindex))
    train_dataset = FMGdataset(args, test_ratio=args.test_ratio, phase="train")
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                               shuffle=True, num_workers=1)
    test_dataset = FMGdataset(args, test_ratio=args.test_ratio, phase="test")
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1,
                                              shuffle=False, num_workers=1)

    if args.mode == "train":
        save_dir = "./models"
        utils.make_dir(save_dir)
        best_acc = 0.0
        for epoch in range(args.epochs):
            utils.train(epoch, model, train_loader, optimizer, criterion, device, adj)
            _, acc, _ = utils.test(epoch, model, test_loader, criterion, args.n_class, device, adj)
            if acc > best_acc:
                # Keep only the weights of the best epoch seen so far.
                best_acc = acc
                torch.save(model.state_dict(), _checkpoint_path(args, save_dir))
            print("Current accuracy:{:.3f},Best accuracy:{:.3f}".format(acc, best_acc))
            print("#-----------------------------------------------------------------------#")
        return best_acc

    # args.mode == "inference" (validated at the top)
    print("inference start")
    # map_location lets a checkpoint written during a GPU run be loaded on a
    # CPU-only machine or on a different GPU index.
    state_dict = torch.load(_checkpoint_path(args), map_location=device)
    model.load_state_dict(state_dict)
    cm, acc, t_mean = utils.test(0, model, test_loader, criterion, args.n_class, device, adj)
    print("Current accuracy:{:.3f}".format(acc))
    return cm, acc, t_mean
def get_model_performance(model_names=("LDA", "ANN", "CNN", "LSTM", "GCN"), n_subjects=5):
    """Train and evaluate every model on every subject and save the results.

    For each name in ``model_names`` and each subject index in
    ``range(n_subjects)``:
    - the LDA baseline is run through ``LDAmodel``;
    - every other model is trained with ``main`` (``mode="train"``), and its
      best checkpoint is then evaluated with ``main`` (``mode="inference"``).

    Per model, two arrays are written to ``<output_root>/model_performance/``:
    - ``<model>_CM.npy``: the stacked confusion matrices, shape
      ``(n_subjects, n_class, n_class)``;
    - ``<model>_ACC.npy``: the accuracies, shape ``(n_subjects,)``.

    Args:
        model_names: models to evaluate; defaults to all five supported ones.
        n_subjects: number of subjects, indexed ``0 .. n_subjects - 1``.
    """
    utils.set_seed(0)  # seeded once for the whole sweep, not per model/subject
    for model_name in model_names:
        args = build_args(model_name)
        cms = np.zeros((n_subjects, args.n_class, args.n_class))
        accs = np.zeros(n_subjects)
        for subject in range(n_subjects):
            print("#-----------------------------------------------------------------------#")
            print(model_name, subject)
            args.subindex = subject
            if model_name == "LDA":
                cm, acc, _ = LDAmodel(args)
            else:
                args.mode = "train"
                main(args)
                args.mode = "inference"
                cm, acc, _ = main(args)
            cms[subject, :, :] = cm
            accs[subject] = acc
        save_dir = os.path.join(args.output_root, "model_performance")
        utils.make_dir(save_dir)
        np.save(os.path.join(save_dir, "{}_CM.npy".format(model_name)), cms)
        np.save(os.path.join(save_dir, "{}_ACC.npy".format(model_name)), accs)
def get_time_delay(model_names=("LDA", "ANN", "CNN", "LSTM", "GCN"), n_subjects=5):
    """Compute a per-subject time delay for each model and save it.

    Models are run as follows:
    - LDA is run through ``LDAmodel``;
    - every other model is run through ``main`` in ``"inference"`` mode, which
      loads the checkpoint that an earlier training run saved under
      ``./models``.

    The value stored for each subject is the ``t_mean`` the model reports plus
    ``args.L_win / 1000``. Per model, the values are saved to
    ``<output_root>/model_usedtime/<model>_Time.npy`` and printed.

    Args:
        model_names: models to measure; defaults to all five supported ones.
        n_subjects: number of subjects, indexed ``0 .. n_subjects - 1``.
    """
    for model_name in model_names:
        print(model_name)
        args = build_args(model_name)
        delays = np.zeros(n_subjects)
        for subject in range(n_subjects):
            args.subindex = subject
            if model_name == "LDA":
                _, _, t_mean = LDAmodel(args)
            else:
                args.mode = "inference"
                _, _, t_mean = main(args)
            # NOTE(review): presumably adds the window-acquisition time, with
            # L_win in ms (or samples at 1 kHz) and t_mean in seconds — confirm.
            delays[subject] = t_mean + args.L_win / 1000
        save_dir = os.path.join(args.output_root, "model_usedtime")
        utils.make_dir(save_dir)
        np.save(os.path.join(save_dir, "{}_Time.npy".format(model_name)), delays)
        print(delays)
if __name__ == "__main__":
    # Default run: train and then evaluate one model on one subject, print its
    # average inference time and plot/save its confusion matrix.
    utils.set_seed(0)
    model_name = "GCN"
    args = build_args(model_name)
    args.subindex = 0  # index of the subject to train/evaluate
    if model_name == "LDA":
        cm, acc, t_mean = LDAmodel(args)
    else:
        args.mode = "train"
        main(args)
        args.mode = "inference"
        cm, acc, t_mean = main(args)
    print("Average inference time:{}".format(t_mean))
    title = "Confusion matrix of " + args.model_name
    figure_dir = "./figure"
    utils.make_dir(figure_dir)  # make sure the PDF's directory exists before saving
    utils.v_confusion_matrix(cm, args.part_actions, title=title,
                             save_path=os.path.join(figure_dir, "{}_CM.pdf".format(args.model_name)))
    plt.show()

    # Alternative experiments (uncomment to run):
    # model performance over all models and subjects
    # get_model_performance()
    # time-delay analysis
    # get_time_delay()