Skip to content

Commit

Permalink
fine tune
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhimaoLin committed Apr 16, 2022
1 parent 1d70645 commit 183433f
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 19 deletions.
14 changes: 9 additions & 5 deletions ExtendedMTL4Pain/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,16 @@ def print_opts(opts):

ARGS = AttrDict()
args_dict = {
'image_scale_to_before_crop': 256,
'image_scale_to_before_crop': 256, # 300 is better
'image_size': 160,
'number_output':1,
'batch_size': 50,
'overall_learning_rate': 0.00001,
'last_layer_learning_rate': 0.0001,
'weight_decay': 0.0005,
'epoch': 2
'epoch': 3,
'train_sample': 5000,
'test_sample': 300
}
ARGS.update(args_dict)
#endregion
Expand Down Expand Up @@ -215,21 +217,23 @@ def main():
split_dataset(DATA_SUMMARY_CSV_PATH, TRAIN_DATA_CSV_PATH, TEST_DATA_CSV_PATH, RANDOM_SEED, train_fraction=TRAIN_FRACTION)

# region Test Code
sample_data(TRAIN_DATA_CSV_PATH, 50, RANDOM_SEED)
sample_data(TRAIN_DATA_CSV_PATH, ARGS.train_sample, RANDOM_SEED)
# endregion

print_opts(ARGS)

start = time.time()
net = train(TRAIN_DATA_CSV_PATH, ARGS)
end = time.time()
print(f"Runtime of the program is [{end - start}] seconds")
print(f"Runtime of the program is [{(end - start)/60}] minutes")

model_path = os.path.join(RESULT_PATH, "model.pt")
save_trained_model(net, model_path)

old_model = load_trained_model(model_path, create_model, ARGS)

# region Test Code
sample_data(TEST_DATA_CSV_PATH, 50, RANDOM_SEED)
sample_data(TEST_DATA_CSV_PATH, ARGS.test_sample, RANDOM_SEED)
# endregion

evaluation(old_model, TEST_DATA_CSV_PATH, ARGS)
Expand Down
38 changes: 24 additions & 14 deletions MainPaper/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,13 @@ def print_opts(opts):
'image_size': 160,
'number_output': 1,
'image_sample_size': 5,
'batch_size': 50,
'batch_size': 100,
'drop_out': 0,
'fc2_size': 200,
'learning_rate': 0.001,
'epoch': 2
'epoch': 3,
'train_sample': 3000,
'test_sample': 150
}
ARGS.update(args_dict)
#endregion
Expand Down Expand Up @@ -221,24 +223,32 @@ def main():
split_dataset(DATA_CSV_PATH, TRAIN_DATA_CSV_PATH, TEST_DATA_CSV_PATH, RANDOM_SEED, TRAIN_FRACTION)

# region Test Code
sample_data(TRAIN_DATA_CSV_PATH, 1000, RANDOM_SEED)
sample_data(TRAIN_DATA_CSV_PATH, ARGS.train_sample, RANDOM_SEED)
# endregion

start = time.time()
net = train(TRAIN_DATA_CSV_PATH, ARGS)
end = time.time()
print(f"Runtime of the program is [{end - start}] seconds")
for dropout in [0, 0.25]:
ARGS.drop_out = dropout
global RESULT_PATH
RESULT_PATH = "result"
RESULT_PATH = os.path.join(RESULT_PATH, f"dropout_{dropout}")

model_path = os.path.join(RESULT_PATH, "model.pt")
save_trained_model(net, model_path)

old_model = load_trained_model(model_path, create_model, ARGS)
start = time.time()
print_opts(ARGS)
net = train(TRAIN_DATA_CSV_PATH, ARGS)
end = time.time()
print(f"Runtime of the program is [{(end - start)/60}] minutes")

# region Test Code
sample_data(TEST_DATA_CSV_PATH, 100, RANDOM_SEED)
# endregion
model_path = os.path.join(RESULT_PATH, "model.pt")
save_trained_model(net, model_path)

old_model = load_trained_model(model_path, create_model, ARGS)

# region Test Code
sample_data(TEST_DATA_CSV_PATH, ARGS.test_sample, RANDOM_SEED)
# endregion

evaluation(old_model, TEST_DATA_CSV_PATH, ARGS)
evaluation(old_model, TEST_DATA_CSV_PATH, ARGS)



Expand Down

0 comments on commit 183433f

Please sign in to comment.