-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsplit_data_to_classes.py
61 lines (48 loc) · 1.66 KB
/
split_data_to_classes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import concurrent.futures
import json
import os
import sys
from pathlib import Path
from urllib.error import HTTPError
import requests
from dotenv import load_dotenv
from loguru import logger
from helpers import load_export_data
def main(data, *, timeout=30):
    """Download each annotated image into ``dataset/<label>/``.

    Every item in ``data`` is read as a mapping with an ``'image'`` URL and a
    ``'label'`` list whose entries carry a ``'rectanglelabels'`` list. The
    first rectangle label of each annotation selects the directory the image
    is written to. Items that cannot be fetched or that lack the expected
    keys are collected and written to ``not_downloaded.json``.

    Args:
        data: Sequence of exported task dicts as described above.
        timeout: Seconds passed to ``requests.get`` so that a stalled server
            cannot block a worker thread forever.
    """
    # Create one directory per distinct rectangle label up front so the
    # worker threads never race on directory creation.
    labels = {
        name
        for item in data
        if 'label' in item
        for annotation in item['label']
        for name in annotation['rectanglelabels']
    }
    Path('dataset').mkdir(exist_ok=True)
    for label in labels:
        Path(f'dataset/{label}').mkdir(exist_ok=True)

    def process(x):
        """Fetch one item's image and store it; return True on success."""
        try:
            # Only the first rectangle label of an annotation picks the
            # folder; a set avoids writing the same file more than once.
            targets = {label['rectanglelabels'][0] for label in x['label']}
            if not targets:
                return True
            url = x['image']
            # Download once per item rather than once per annotation.
            r = requests.get(url, timeout=timeout)
            # Without this, 4xx/5xx error pages would be saved as images.
            r.raise_for_status()
            filename = Path(url).name
            for target in targets:
                with open(f'dataset/{target}/{filename}', 'wb') as f:
                    f.write(r.content)
        except (requests.RequestException, HTTPError, KeyError, IndexError):
            # requests raises RequestException subclasses (connection errors,
            # timeouts, bad statuses); KeyError/IndexError cover items that
            # do not have the expected shape.
            logger.warning(f'{sys.exc_info()}: {x}')
            return False
        return True

    not_downloaded = []
    with concurrent.futures.ThreadPoolExecutor() as executor:
        pending = {executor.submit(process, x): x for x in data}
        # Progress and failures are tallied here in the calling thread, so
        # no shared counter has to be mutated from the workers.
        completed = concurrent.futures.as_completed(pending)
        for done, future in enumerate(completed, start=1):
            if not future.result():
                not_downloaded.append(pending[future])
            logger.debug(f'{done}/{len(data)}')

    with open('not_downloaded.json', 'w') as j:
        json.dump(not_downloaded, j, indent=4)
    logger.debug(f'Not downloaded:\n{not_downloaded}')
if __name__ == '__main__':
    # Script entry point: read settings, fetch the export, then run main().
    load_dotenv()
    # Module-level counter initialised before main() runs.
    i = 0
    # A missing TOKEN environment variable raises KeyError here.
    token = os.environ['TOKEN']
    data = load_export_data(project_id=1, TOKEN=token)
    main(data=data)