-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate_database.py
57 lines (42 loc) · 1.23 KB
/
generate_database.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import sys
import os
import random
import re
def generate_train(files, n):
    """Return a random sample of ``n`` distinct entries from ``files``.

    The input list is left unmodified.

    Args:
        files: list of candidate file paths.
        n: number of entries to draw; must satisfy ``0 <= n <= len(files)``.

    Returns:
        A new list of ``n`` entries chosen without replacement, in random order.

    Raises:
        ValueError: if ``n`` is negative or larger than ``len(files)``
            (raised by ``random.sample``).
    """
    # random.sample already returns its picks in random order, so a prior
    # random.shuffle(files) is redundant and would mutate the caller's list.
    return random.sample(files, n)
def generate_test(files, train_files):
    """Return the entries of ``files`` that do not appear in ``train_files``.

    Duplicate entries in ``files`` are collapsed. The result keeps the order
    in which entries first appear in ``files``, so the output is
    deterministic rather than depending on set iteration order.

    Args:
        files: list of all candidate file paths.
        train_files: paths already assigned to the training split.

    Returns:
        A new list of the remaining (test) paths.
    """
    excluded = set(train_files)
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return [path for path in dict.fromkeys(files) if path not in excluded]
# NOTE(review): this helper is disabled. It sits inside a bare string
# literal, so ``get_trailing_numbers`` is never defined. In the code shown
# here it is the only user of the ``re`` import. If it is re-enabled, note
# that ``re.search`` returns None when ``s`` has no trailing digits, so
# ``m.group()`` would raise AttributeError.
'''def get_trailing_numbers(s):
m = re.search(r'\d+$', s)
return m.group()
'''
def write_file(path, files):
    """Write ``files`` to ``path`` as UTF-8 text, one entry per line.

    Any existing file at ``path`` is overwritten.

    Args:
        path: destination file path.
        files: iterable of strings (here, file paths) to write.
    """
    print("[+]Write ", path)
    # Use explicit UTF-8 instead of the locale-dependent default encoding, so
    # the list is written the same way on every platform.
    # surrogateescape lets names that os.listdir decoded with surrogate
    # escapes (undecodable bytes on POSIX) be written back as their original
    # bytes instead of raising UnicodeEncodeError.
    with open(path, "w", encoding="utf-8", errors="surrogateescape") as f:
        f.writelines(entry + "\n" for entry in files)
def generate_data(src, db, train_ratio=0.8):
    """Split the files of every sub-folder of ``src`` into train/test lists.

    Each immediate sub-folder of ``src`` is split independently.
    ``int(n * train_ratio)`` of its ``n`` entries are sampled at random for
    training, and the remaining entries go to the test list. The two lists
    are written as ``train.txt`` and ``test.txt`` inside ``db``.

    Args:
        src: directory whose sub-folders contain the files to split.
            Non-directory entries directly under ``src`` are skipped.
        db: output directory; created if it does not already exist.
        train_ratio: fraction of each folder used for training, expected to
            be in ``[0, 1]`` (default 0.8).

    Side effects:
        Changes the process's current working directory to ``db``. Paths in
        the lists are built from ``src`` as given, so a relative ``src``
        yields paths relative to the original working directory, not to
        ``db``.
    """
    train_files = []
    test_files = []
    # Sorted listings make the output reproducible when `random` is seeded.
    for folder in sorted(os.listdir(src)):
        folder_path = os.path.join(src, folder)
        if not os.path.isdir(folder_path):
            # A stray file directly under src would otherwise make
            # os.listdir(folder_path) raise NotADirectoryError.
            print("[-]Skip non-folder ", folder)
            continue
        print("[+]Access folder ", folder)
        files = [
            os.path.join(folder_path, name)
            for name in sorted(os.listdir(folder_path))
        ]
        folder_train = generate_train(files, int(len(files) * train_ratio))
        train_files.extend(folder_train)
        # Compare against this folder's picks only, not the ever-growing
        # global train list, which made the loop quadratic overall.
        test_files.extend(generate_test(files, folder_train))
    print("[+]Create folder ", db)
    # exist_ok avoids failing when the output folder is already there.
    os.makedirs(db, exist_ok=True)
    print("[+]Change current wd to ", db)
    os.chdir(db)
    write_file("train.txt", train_files)
    write_file("test.txt", test_files)
def main(argv=None):
    """Command-line entry point: ``generate_database.py SRC_DIR DB_DIR``.

    Args:
        argv: argument vector including the program name; defaults to
            ``sys.argv``. Accepting it explicitly makes the function testable.

    Exits with status 2 and prints a usage message to stderr when either
    directory argument is missing.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 3:
        prog = argv[0] if argv else "generate_database.py"
        print("usage: %s SRC_DIR DB_DIR" % prog, file=sys.stderr)
        sys.exit(2)
    src = argv[1]
    db = argv[2]
    generate_data(src, db)


if __name__ == '__main__':
    main()