-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathdvi-tf_test.py
110 lines (90 loc) · 5.33 KB
/
dvi-tf_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
########################################################################################################################
# IMPORT #
########################################################################################################################
import torch
import sys
import os
import json
import argparse
from singleVis.data import NormalDataProvider
from singleVis.projector import tfDVIProjector
from singleVis.eval.evaluator import Evaluator
########################################################################################################################
# DVI PARAMETERS #
########################################################################################################################
"""This serve as an example of DeepVisualInsight implementation in pytorch."""
VIS_METHOD = "tfDVI" # DeepVisualInsight
########################################################################################################################
# LOAD PARAMETERS #
########################################################################################################################
# os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152
# os.environ["CUDA_VISIBLE_DEVICES"]="0"
parser = argparse.ArgumentParser(description='Process hyperparameters...')
parser.add_argument('--content_path',"-c", type=str)
args = parser.parse_args()
CONTENT_PATH = args.content_path
sys.path.append(CONTENT_PATH)
with open(os.path.join(CONTENT_PATH, "config.json"), "r") as f:
config = json.load(f)
config = config[VIS_METHOD]
# record output information
# now = time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime(time.time()))
# sys.stdout = open(os.path.join(CONTENT_PATH, now+".txt"), "w")
SETTING = config["SETTING"]
CLASSES = config["CLASSES"]
DATASET = config["DATASET"]
PREPROCESS = config["VISUALIZATION"]["PREPROCESS"]
GPU_ID = config["GPU"]
EPOCH_START = config["EPOCH_START"]
EPOCH_END = config["EPOCH_END"]
EPOCH_PERIOD = config["EPOCH_PERIOD"]
# Training parameter (subject model)
TRAINING_PARAMETER = config["TRAINING"]
NET = TRAINING_PARAMETER["NET"]
LEN = TRAINING_PARAMETER["train_num"]
# Training parameter (visualization model)
VISUALIZATION_PARAMETER = config["VISUALIZATION"]
LAMBDA1 = VISUALIZATION_PARAMETER["LAMBDA1"]
LAMBDA2 = VISUALIZATION_PARAMETER["LAMBDA2"]
B_N_EPOCHS = VISUALIZATION_PARAMETER["BOUNDARY"]["B_N_EPOCHS"]
L_BOUND = VISUALIZATION_PARAMETER["BOUNDARY"]["L_BOUND"]
ENCODER_DIMS = VISUALIZATION_PARAMETER["ENCODER_DIMS"]
DECODER_DIMS = VISUALIZATION_PARAMETER["DECODER_DIMS"]
S_N_EPOCHS = VISUALIZATION_PARAMETER["S_N_EPOCHS"]
N_NEIGHBORS = VISUALIZATION_PARAMETER["N_NEIGHBORS"]
PATIENT = VISUALIZATION_PARAMETER["PATIENT"]
FLAG = VISUALIZATION_PARAMETER["FLAG"]
MAX_EPOCH = VISUALIZATION_PARAMETER["MAX_EPOCH"]
EVALUATION_NAME = VISUALIZATION_PARAMETER["EVALUATION_NAME"]
# Define hyperparameters
DEVICE = torch.device("cuda:{}".format(GPU_ID) if torch.cuda.is_available() else "cpu")
import Model.model as subject_model
net = eval("subject_model.{}()".format(NET))
########################################################################################################################
# TRAINING SETTING #
########################################################################################################################
# Define data_provider
data_provider = NormalDataProvider(CONTENT_PATH, net, EPOCH_START, EPOCH_END, EPOCH_PERIOD, device=DEVICE, classes=CLASSES,epoch_name="Epoch", verbose=1)
if PREPROCESS:
data_provider._meta_data()
if B_N_EPOCHS >0:
data_provider._estimate_boundary(LEN//10, l_bound=L_BOUND)
# Define Projector
projector = tfDVIProjector(CONTENT_PATH, flag=FLAG)
########################################################################################################################
# VISUALIZATION #
########################################################################################################################
# from singleVis.visualizer import visualizer
# vis = visualizer(data_provider, projector, 200, "tab10")
# save_dir = os.path.join(data_provider.content_path, "img")
# if not os.path.exists(save_dir):
# os.mkdir(save_dir)
# for i in range(EPOCH_START, EPOCH_END+1, EPOCH_PERIOD):
# vis.save_default_fig(i, path=os.path.join(save_dir, "{}_{}_{}.png".format(DATASET, i, VIS_METHOD)))
########################################################################################################################
# EVALUATION #
########################################################################################################################
eval_epochs = range(EPOCH_START, EPOCH_END+1, EPOCH_PERIOD)
evaluator = Evaluator(data_provider, projector)
for eval_epoch in eval_epochs:
evaluator.save_epoch_eval(eval_epoch, 15, temporal_k=5, file_name="{}".format(EVALUATION_NAME))