-
Notifications
You must be signed in to change notification settings - Fork 42
/
Copy pathmain.py
60 lines (48 loc) · 1.73 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from utils import parse_args, create_experiment_dirs, calculate_flops
from model import MobileNet
from train import Train
from data_loader import DataLoader
from summarizer import Summarizer
import tensorflow as tf
def main():
    """Run one experiment end to end: configure, load data, build, train, test.

    Steps, in order:
      1. Parse the arguments with ``parse_args()``.
      2. Create the experiment directories and store the summary and
         checkpoint directories on the config object.
      3. Reset the default TensorFlow graph and open a session.
      4. Load the data and copy the image dimensions and dataset sizes
         returned by ``data.load_data()`` onto the config object.
      5. Build the model, the summarizer and the trainer.
      6. Train if ``config_args.to_train`` is set, then test if
         ``config_args.to_test`` is set.

    The session is always closed before the function returns or raises.

    Raises:
        SystemExit: with status 1 when ``parse_args()`` raises an ordinary
            exception. A ``SystemExit`` or ``KeyboardInterrupt`` raised by
            ``parse_args()`` itself is propagated unchanged.
    """
    # Function-scope import so this block does not depend on the file's
    # import section being edited.
    import sys

    # Parse the JSON arguments
    try:
        config_args = parse_args()
    except Exception:
        # Not a bare ``except``: that would also trap SystemExit (for example
        # a ``--help`` exit raised by an argument parser) and KeyboardInterrupt
        # and rewrite them into this hint with exit status 1.
        print("Add a config file using \'--config file_name.json\'")
        # ``sys.exit`` instead of the ``exit`` helper, which is only injected
        # by the ``site`` module and may be missing (``python -S``, frozen apps).
        sys.exit(1)

    # Create the experiment directories
    _, config_args.summary_dir, config_args.checkpoint_dir = create_experiment_dirs(
        config_args.experiment_dir)

    # Reset the default Tensorflow graph
    tf.reset_default_graph()

    # Tensorflow specific configuration
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    try:
        # Data loading
        data = DataLoader(config_args.batch_size, config_args.shuffle)
        print("Loading Data...")
        (config_args.img_height,
         config_args.img_width,
         config_args.num_channels,
         config_args.train_data_size,
         config_args.test_data_size) = data.load_data()
        print("Data loaded\n\n")

        # Model creation
        print("Building the model...")
        model = MobileNet(config_args)
        print("Model is built successfully\n\n")

        # Summarizer creation
        summarizer = Summarizer(sess, config_args.summary_dir)

        # Train class
        trainer = Train(sess, model, data, summarizer)

        if config_args.to_train:
            try:
                print("Training...")
                trainer.train()
                print("Training Finished\n\n")
            except KeyboardInterrupt:
                # Ctrl-C during training: save the model instead of aborting.
                # Execution then continues to the optional test phase below.
                trainer.save_model()

        if config_args.to_test:
            print("Final test!")
            trainer.test('val')
            print("Testing Finished\n\n")
    finally:
        # Release the session's resources on every exit path.
        sess.close()
# Script entry point: run the experiment only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()