diff --git a/MainPaper/train.py b/MainPaper/train.py index e83f2d1..769148f 100644 --- a/MainPaper/train.py +++ b/MainPaper/train.py @@ -11,6 +11,8 @@ from util import * import matplotlib.pyplot as plt +import time + DATA_SUMMARY_CSV_PATH = '../data_summary.csv' DATA_SUMMARY_HEADER = {"person":"person_name", "video":"video_name", "frame":"frame_number", "pspi":"pspi_score", "image":"image_path"} @@ -49,12 +51,12 @@ def print_opts(opts): 'image_scale_to_before_crop': 320, 'image_size': 160, 'number_output': 1, - 'image_sample_size': 1, - 'batch_size': 100, + 'image_sample_size': 5, + 'batch_size': 50, 'drop_out': 0, 'fc2_size': 200, 'learning_rate': 0.001, - 'epoch': 1 + 'epoch': 10 } ARGS.update(args_dict) #endregion @@ -153,7 +155,7 @@ def train(train_data_path, ops): title = f"Epoch={epoch}, Train loss of each batch" x_label = "Batch number" y_label = "Loss" - save_path = os.path.join("./result", f"Epoch_{epoch}_loss.png") + save_path = os.path.join("./result", f"Epoch_{epoch+1}_loss.png") draw_line_chart(x_list, train_loss_for_each_batch_list, title, x_label, y_label, save_path) avg_loss_for_each_epoch = np.array(train_loss_for_each_batch_list).mean() @@ -172,7 +174,10 @@ def train(train_data_path, ops): -def evaluation(net, test_data_loader): +def evaluation(net, test_data_path, ops): + test_dataset = MyDataset(test_data_path, ops) + test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=ops.batch_size, shuffle=True) + mse_loss = nn.MSELoss() total_number_batch = 0 @@ -204,8 +209,13 @@ def main(): create_data_csv(DATA_CSV_PATH, DATA_SUMMARY_CSV_PATH, ARGS) split_dataset(DATA_CSV_PATH, TRAIN_DATA_CSV_PATH, TEST_DATA_CSV_PATH, RANDOM_SEED, TRAIN_FRACTION) + start = time.time() net = train(TRAIN_DATA_CSV_PATH, ARGS) - evaluation(net, TEST_DATA_CSV_PATH) + end = time.time() + print(f"Runtime of the program is {end - start}") + + + evaluation(net, TEST_DATA_CSV_PATH, ARGS)