This repository has been archived by the owner on Oct 12, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 45
/
Copy path1_prepareData.py
74 lines (63 loc) · 2.78 KB
/
1_prepareData.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# -*- coding: utf-8 -*-
from helpers import *
locals().update(importlib.import_module("PARAMETERS").__dict__)
####################################
# Main
####################################
# Fix the seed of Python's random module so that the random.random() draws
# below (and presumably the shuffling done inside randomizeList) are
# repeatable from run to run.
random.seed(0)
makeDirectory(procDir)

# Validate the split mode once, before doing any per-directory work. Checking
# it only inside the loop would let an invalid setting go unnoticed whenever
# imgDir contains no subdirectories. ValueError is a subclass of Exception, so
# anything catching the previous generic Exception still works.
if imagesSplitBy not in ('filename', 'subdir'):
    raise ValueError("Variable 'imagesSplitBy' has to be either 'filename' or 'subdir'")

# Both dictionaries map a subdirectory name to the list of .jpg filenames
# assigned to that split.
imgFilenamesTest = dict()
imgFilenamesTrain = dict()
print("Split images into train or test...")
subdirs = getDirectoriesInDirectory(imgDir)
for subdir in subdirs:
    # NOTE(review): plain string concatenation, so imgDir is assumed to end
    # with a path separator -- confirm against PARAMETERS.
    filenames = getFilesInDirectory(imgDir + subdir, ".jpg")

    if imagesSplitBy == 'filename':
        # Randomly assign images into train or test: shuffle, then cut the
        # list so that the first ratioTrainTest fraction goes to training.
        filenames = randomizeList(filenames)
        splitIndex = int(ratioTrainTest * len(filenames))
        imgFilenamesTrain[subdir] = filenames[:splitIndex]
        imgFilenamesTest[subdir] = filenames[splitIndex:]
    else:
        # imagesSplitBy == 'subdir': randomly assign whole subdirectories to
        # train or test, so each subdirectory ends up in exactly one split.
        if random.random() < ratioTrainTest:
            imgFilenamesTrain[subdir] = filenames
        else:
            imgFilenamesTest[subdir] = filenames

    # Debug print
    if subdir in imgFilenamesTrain:
        print("Training: {:5} images in directory {}".format(len(imgFilenamesTrain[subdir]), subdir))
    if subdir in imgFilenamesTest:
        print("Testing: {:5} images in directory {}".format(len(imgFilenamesTest[subdir]), subdir))

# Save assignments of images to train or test
saveToPickle(imgFilenamesTrainPath, imgFilenamesTrain)
saveToPickle(imgFilenamesTestPath, imgFilenamesTest)
# Lookup tables translating between a label (the subdirectory name) and a
# consecutive integer id, numbered in the iteration order of imgFilenamesTrain.
# NOTE(review): ids are assigned to training subdirectories only; when images
# are split by 'subdir', directories that landed in the test set have no entry
# here -- confirm that downstream code never looks them up.
lutId2Label = dict(enumerate(imgFilenamesTrain))
lutLabel2Id = {label: labelId for labelId, label in enumerate(imgFilenamesTrain)}

# Persist both directions of the mapping
saveToPickle(lutLabel2IdPath, lutLabel2Id)
saveToPickle(lutId2LabelPath, lutId2Label)
# Compute positive and negative image pairs
def _computeAndSavePairs(message, imgFilenames, maxQueryImgsPerSubdir, maxNegImgsPerQueryImg, outputPath):
    """Announce the step, build the image pairs for one split and pickle them.

    Returns the object produced by getImagePairs so the caller can keep it.
    """
    print(message)
    imgInfos = getImagePairs(imgFilenames, maxQueryImgsPerSubdir, maxNegImgsPerQueryImg)
    saveToPickle(outputPath, imgInfos)
    return imgInfos


# Training pairs are generated before test pairs, same as the order of the
# two steps has always been.
imgInfosTrain = _computeAndSavePairs("Generate training data ...", imgFilenamesTrain,
                                     train_maxQueryImgsPerSubdir, train_maxNegImgsPerQueryImg,
                                     imgInfosTrainPath)
imgInfosTest = _computeAndSavePairs("Generate test data ...", imgFilenamesTest,
                                    test_maxQueryImgsPerSubdir, test_maxNegImgsPerQueryImg,
                                    imgInfosTestPath)
# Sanity check - the pickled test and training sets must not have any image
# in common. The data is read back from disk rather than taken from memory,
# so what gets checked is what was actually written.
print("Verifying if training and test set are disjoint:")
testImgPaths = getImgPaths(loadFromPickle(imgInfosTestPath))
trainImgPaths = getImgPaths(loadFromPickle(imgInfosTrainPath))

# Count the image paths that occur in both sets; anything but zero is fatal
numSharedImgs = len(trainImgPaths.intersection(testImgPaths))
if numSharedImgs != 0:
    raise Exception("Training and test set share %d images." % numSharedImgs)
print(" Check passed: Training and test set share no images.")
print("DONE.")