-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathengine.py
51 lines (39 loc) · 1.66 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: CC-BY-NC-4.0
import time
import torch
from utils import utils
def validate(val_loader, model, criterion, args, print_freq=None):
    """Evaluate ``model`` on ``val_loader`` and return top-1 / top-5 accuracy.

    The model is put in eval mode and run under ``torch.no_grad()``. Per-batch
    loss and top-1/top-5 accuracy are recorded in ``utils.AverageMeter``
    instances, weighted by batch size. The meters' raw ``sum``/``count``
    values are then passed through ``utils.torch_dist_sum``. The returned
    accuracies are the aggregated sum divided by the aggregated count.

    Args:
        val_loader: iterable yielding ``(images, target)`` batches.
        model: module to evaluate. It is left in eval mode on return; a
            training caller must switch it back with ``model.train()``.
        criterion: loss function, called as ``criterion(output, target)``.
        args: namespace. If ``args.gpu`` is not ``None``, each batch is copied
            to that CUDA device. ``args.gpu`` is also forwarded to
            ``utils.torch_dist_sum``.
        print_freq: if truthy, display progress every ``print_freq`` batches.
            The default ``None`` keeps validation silent, as it was before
            this parameter existed.

    Returns:
        ``(top1_acc, top5_acc)`` as tensors. NOTE(review): the zero-count case
        is not guarded; a total count of 0 presumably yields NaN (0/0).
    """
    batch_time = utils.AverageMeter('Time', ':6.3f')
    losses = utils.AverageMeter('Loss', ':.4e')
    top1 = utils.AverageMeter('Acc@1', ':6.2f')
    top5 = utils.AverageMeter('Acc@5', ':6.2f')

    # Build the progress meter only when display is requested. ProgressMeter
    # needs len(val_loader), which loaders without __len__ cannot provide.
    progress = None
    if print_freq:
        progress = utils.ProgressMeter(
            len(val_loader),
            [batch_time, losses, top1, top5],
            prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss, weighted by batch size
            batch_size = images.size(0)
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), batch_size)
            # assumes utils.accuracy returns per-k results whose [0] element
            # is the scalar accuracy - confirm against utils
            top1.update(acc1[0], batch_size)
            top5.update(acc5[0], batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if progress is not None and i % print_freq == 0:
                progress.display(i)

    # Aggregate the raw sums and counts rather than per-process averages, so
    # the final ratio is weighted by the total number of samples.
    sum1, cnt1, sum5, cnt5 = utils.torch_dist_sum(
        args.gpu, top1.sum, top1.count, top5.sum, top5.count)
    # Tensor .sum() reduces in a single op. The builtin sum() iterated the
    # tensor element by element in Python and fails outright on 0-d tensors.
    top1_acc = sum1.float().sum() / cnt1.float().sum()
    top5_acc = sum5.float().sum() / cnt5.float().sum()
    return top1_acc, top5_acc