#!/usr/bin/env python3.8
# -*- coding: utf-8 -*-
## Main file to train the model
"""
Created on Sun Mar 06 23:47:19 2022
@author: Carlos Gómez-Huélamo
"""
# General purpose imports
import sys
import yaml
import git
import logging
import os
import argparse
import importlib
import pdb
from prodict import Prodict
from datetime import datetime
#######################################
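# Resolve the repository root with git so that project modules
# (e.g. model.trainers) can be imported from any working directory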
repo = git.Repo('.', search_parent_directories=True)
BASE_DIR = repo.working_tree_dir
sys.path.append(BASE_DIR)
# Architectures with a training module under model.trainers
TRAINER_LIST = [
    "social_lstm_mhsa",
    "social_lstm_mhsa_mm",
    "gan_social_lstm_mhsa",
    "pv_lstm",
    "pv_lstm_mm",
    "sophie",
    "sophie_mm",
    "mapfe4mp",
    "cghformer"
]
def create_logger(file_path):
    """
    Create a logger that writes INFO messages to stdout and appends them to file_path.
    """
    FORMAT = '[%(levelname)s: %(lineno)4d]: %(message)s'
    logger = logging.getLogger(__name__)
    logging.basicConfig(level=logging.INFO, format=FORMAT)
    file_handler = logging.FileHandler(file_path, mode="a")
    file_handler.setLevel(logging.INFO)
    logger.addHandler(file_handler)
    return logger
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--trainer", required=True, type=str, choices=TRAINER_LIST)
    parser.add_argument("--device_gpu", required=True, type=int)
    parser.add_argument("--from_exp", type=str, default=None)
    # store_true instead of type=bool: argparse's type=bool treats any non-empty string as True
    parser.add_argument("--overwrite_exp", action="store_true")
    parser.add_argument("--num_ckpt", type=str, default="0")
    parser.add_argument("--batch_size", type=int, default=0)
    parser.add_argument("--output_dir", type=str, default="save")
    args = parser.parse_args()
    print(args.trainer)
    # Get the model trainer for the chosen architecture,
    # e.g. --trainer mapfe4mp resolves to model.trainers.trainer_mapfe4mp
    curr_trainer = "model.trainers.trainer_%s" % args.trainer
    curr_trainer_module = importlib.import_module(curr_trainer)
    model_trainer = curr_trainer_module.model_trainer
    # Get the configuration for the chosen architecture
    if not args.from_exp:
        # Initialize a new experiment from the config file in the config folder
        config_path = "./config/config_%s.yml" % args.trainer
    else:
        # Continue training from a previous checkpoint with its config file
        assert os.path.isdir(args.from_exp), "Checkpoint not found!"
        config_path = os.path.join(args.from_exp, "config_file.yml")
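    # e.g. a new run with --trainer mapfe4mp reads ./config/config_mapfe4mp.yml,
    # while a resumed run reads <from_exp>/config_file.yml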
    now = datetime.now()
    exp_name = now.strftime("exp-%Y-%m-%d_%Hh")  # e.g. "exp-2022-03-06_23h"
    with open(config_path) as config_file:
        config = yaml.safe_load(config_file)

    config["device_gpu"] = args.device_gpu
    config["base_dir"] = BASE_DIR
    if not config["hyperparameters"]["exp_name"]:  # Empty -> fill with the current day and hour
        config["hyperparameters"]["exp_name"] = exp_name
    split_percentage_str = str(int(100 * config["dataset"]["split_percentage"])) + "_percent"  # e.g. 0.5 -> "50_percent"
    config["hyperparameters"]["output_dir"] = os.path.join(config["hyperparameters"]["save_root_dir"],
                                                           config["model"]["name"],
                                                           split_percentage_str,
                                                           config["hyperparameters"]["exp_name"])
    route_path = os.path.join(config["hyperparameters"]["output_dir"], "config_file.yml")

    if args.from_exp and os.path.isdir(args.from_exp):  # Overwrite checkpoint_start_from
        model = config["dataset_name"] + "_" + args.num_ckpt + "_with_model.pt"
        config["hyperparameters"]["checkpoint_start_from"] = os.path.join(args.from_exp, model)
    else:
        dataset_name = config["dataset_name"]
        filename = os.path.join(config["hyperparameters"]["output_dir"], f"{dataset_name}_0_with_model.pt")
        if args.overwrite_exp:
            path_to_remove = config["hyperparameters"]["output_dir"]
            os.system(f"rm -rf {path_to_remove}")
        else:
            assert not os.path.exists(filename), "This path already has a checkpoint!"
    if not os.path.exists(config["hyperparameters"]["output_dir"]):
        print("Create experiment path: ", config["hyperparameters"]["output_dir"])
        os.makedirs(config["hyperparameters"]["output_dir"])  # makedirs creates intermediate folders

    with open(route_path, 'w') as yaml_file:
        yaml.dump(config, yaml_file, default_flow_style=False)
    config = Prodict.from_dict(config)
    # Create logger
    logger = create_logger(os.path.join(config["hyperparameters"]["output_dir"], f"{config.dataset_name}_{exp_name}.log"))
    logger.info("Config file: {}".format(config_path))
    # Override some configuration variables from the input arguments.
    # E.g. the original config has batch_size = 1024, but your GPU is almost
    # full -> keep 1024 in the file (as a reference) and pass a smaller
    # batch_size here (e.g. 512); these temporary changes should not be
    # written back to the original config file.
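    # For instance (illustrative invocation):
    #   python train.py --trainer mapfe4mp --device_gpu 0 --batch_size 512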
    if args.batch_size: config.dataset.batch_size = args.batch_size
    if args.output_dir != "save": config.hyperparameters.output_dir = args.output_dir

    model_trainer(config, logger)