Skip to content

Commit

Permalink
update process_data.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhimaoLin committed Apr 13, 2022
1 parent dd4b189 commit 17d1ebf
Showing 1 changed file with 49 additions and 10 deletions.
59 changes: 49 additions & 10 deletions process_data.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,48 @@
import os
import re
from shared_util.preprocess import PreProcess


import matplotlib.pyplot as plt
import matplotlib.image as mpimg
# from pain_detector import PainDetector

# global variables
PSPI_DIR = 'data\Frame_Labels\Frame_Labels\PSPI'
IMAGES_DIR = 'data\Images\Images'
PSPI_DIR = './data/Frame_Labels/Frame_Labels/PSPI'
IMAGES_DIR = './data/Images/Images'
DATA_SUMMARY_CSV = 'data_summary.csv'


#region arguments section
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self


def print_opts(opts):
"""Prints the values of all command-line arguments.
"""
print('=' * 80)
print('Opts'.center(80))
print('-' * 80)
for key in opts.__dict__:
if opts.__dict__[key]:
print('{:>30}: {:<30}'.format(key, opts.__dict__[key]).center(80))
print('=' * 80)


ARGS = AttrDict()
args_dict = {
# 'fan_checkpoint': os.path.abspath("../shared_util/face_alignment/checkpoints/59448122/59448122_3/model_epoch13.pt"),
# 'standard_face_path': os.path.abspath("../shared_util/face_alignment/standard_face_68.npy"),
'image_scale_to_before_crop': 256,
'image_size': 160
}
ARGS.update(args_dict)
#endregion



# Write a row to a file. The row has to be a tuple of strings
# Input: "./data_summary.csv", ("person_name", "video_name", "frame_number", "pspi_score", "image_path"), "w"
Expand All @@ -22,6 +56,8 @@ def write_row_to_file(file_path, row, mode="a"):
# Input: 'data\Frame_Labels\Frame_Labels\PSPI', 'data\Images\Images', 'data_summary.csv'
# output: The path to the csv file
def create_data_summary_csv(pspi_dir, images_dir, data_summary_csv):
preprocess = PreProcess(ARGS)

# The header of the csv
header = ("person_name", "video_name", "frame_number", "pspi_score", "image_path")
write_row_to_file(data_summary_csv, header, "w")
Expand All @@ -30,14 +66,14 @@ def create_data_summary_csv(pspi_dir, images_dir, data_summary_csv):
for file in files:
# Process labels
label_path = os.path.join(root, file)
label_path = label_path.replace("/", "\\")
label_path = label_path.replace("\\", "/")

path_array = label_path.split("\\")
person_name = path_array[4]
video_name = path_array[5]
path_array = label_path.split("/")
person_name = path_array[5]
video_name = path_array[6]

frame_number = '0'
m = re.search(video_name + '(\d+)', path_array[6])
m = re.search(video_name + '(\d+)', path_array[7])
if m:
frame_number = m.group(1)

Expand All @@ -52,9 +88,12 @@ def create_data_summary_csv(pspi_dir, images_dir, data_summary_csv):
image_path = os.path.join(images_dir, person_name, video_name, image_file_name)
image_path = os.path.abspath(image_path)

# Write to the csv
row = (person_name, video_name, frame_number, pspi_score, image_path)
write_row_to_file(data_summary_csv, row)
if preprocess.test_image(image_path):
# Write to the csv
row = (person_name, video_name, frame_number, pspi_score, image_path)
write_row_to_file(data_summary_csv, row)
else:
print(f"Skip image: [{image_path}]")

return data_summary_csv

Expand Down

0 comments on commit 17d1ebf

Please sign in to comment.