-
Notifications
You must be signed in to change notification settings - Fork 43
/
Copy pathpred_main.py
21 lines (18 loc) · 1016 Bytes
/
pred_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from predictor import predict
import argparse
# This can run with only predictor dataloader and networks libraries
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--predict", action='store_true', default=False)
parser.add_argument("--visualize_coco", action='store_true', default=False)
parser.add_argument("--checkpoint_path", type=str, default='/home/noam/ZazuML/best_checkpoint.pt')
parser.add_argument("--dataset_path", type=str, default='')
parser.add_argument("--output_path", type=str, default='')
args = parser.parse_args()
if args.predict:
predict(pred_on_path=args.dataset_path, output_path=args.output_path,
checkpoint_path=args.checkpoint_path, threshold=0.5)
if args.visualize_coco:
from dataloaders import CocoDataset
dataset = CocoDataset(args.dataset_path, set_name='train')
dataset.visualize(args.output_path)