Commit 7eed65b: Merge branch 'dev'

zhreshold committed Jan 30, 2018 · 2 parents 48d53a2 + 87ee8e1
Showing 17 changed files with 96 additions and 182 deletions.
9 changes: 7 additions & 2 deletions README.md
@@ -1,6 +1,6 @@
# YOLO-v2: Real-Time Object Detection

Still under development. 71 mAP on VOC2007 achieved so far.
Still under development. 71 mAP (darknet) and 74 mAP (resnet50) on VOC2007 achieved so far.

This is a pre-release version.

@@ -21,14 +21,15 @@ custom operators are not present in official MXNet. [Instructions](http://mxne
- Download the pretrained [model](https://github.com/zhreshold/mxnet-yolo/releases/download/0.1-alpha/yolo2_darknet19_416_pascalvoc0712_trainval.zip), and extract to `model/` directory.
- Run
```
# cd /paht/to/mxnet-yolo
# cd /path/to/mxnet-yolo
python demo.py --cpu
# available options
python demo.py -h
```

### Train the model
- Grab a pretrained model, e.g. [`darknet19`](https://github.com/zhreshold/mxnet-yolo/releases/download/0.1-alpha/darknet19_416_ILSVRC2012.zip)
- (optional) Grab a pretrained resnet50 model: [`resnet-50-0000.params`](http://data.dmlc.ml/models/imagenet/resnet/50-layers/resnet-50-0000.params) and [`resnet-50-symbol.json`](http://data.dmlc.ml/models/imagenet/resnet/50-layers/resnet-50-symbol.json). This produces slightly better mAP than `darknet19` in my experiments.
- Download PASCAL VOC dataset.
```
cd /path/to/where_you_store_datasets/
@@ -52,4 +53,8 @@ python tools/prepare_dataset.py --dataset pascal --year 2007 --set test --target
- Start training
```
python train.py --gpus 0,1,2,3 --epoch 0
# choose different networks, such as resnet50_yolo
python train.py --gpus 0,1,2,3 --network resnet50_yolo --data-shape 416 --pretrained model/resnet-50 --epoch 0
# see advanced arguments for training
python train.py -h
```
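For the `resnet50_yolo` run, `--pretrained model/resnet-50 --epoch 0` assumes the two resnet-50 files linked above were saved under `model/` with MXNet's standard checkpoint naming, so that prefix plus epoch resolves to:
```
model/resnet-50-symbol.json
model/resnet-50-0000.params
```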
2 changes: 1 addition & 1 deletion config/config.py → config/default_config.py
@@ -1,5 +1,5 @@
import os
from utils import DotDict, namedtuple_with_defaults, zip_namedtuple, config_as_dict
from config.utils import DotDict, namedtuple_with_defaults, zip_namedtuple, config_as_dict

RandCropper = namedtuple_with_defaults('RandCropper',
'min_crop_scales, max_crop_scales, \
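This one-line change, like the matching edits in the `dataset/` and `symbol/` files below, replaces a Python 2 implicit relative import with an absolute package path, which also resolves on Python 3 (where implicit relative imports were removed). A minimal sketch of the equivalent spellings, assuming the repo root is on `sys.path`:
```
# Python 2 only -- implicit relative import, removed in Python 3:
# from imdb import Imdb

# what this commit uses -- absolute import rooted at the repo:
from dataset.imdb import Imdb

# also valid on both versions, if dataset/ is a proper package:
# from .imdb import Imdb
```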
2 changes: 1 addition & 1 deletion dataset/concat_db.py
@@ -1,4 +1,4 @@
from imdb import Imdb
from dataset.imdb import Imdb
import random

class ConcatDB(Imdb):
2 changes: 1 addition & 1 deletion dataset/pascal_voc.py
@@ -1,7 +1,7 @@
from __future__ import print_function
import os
import numpy as np
from imdb import Imdb
from dataset.imdb import Imdb
import xml.etree.ElementTree as ET
from evaluate.eval_voc import voc_eval
import cv2
2 changes: 1 addition & 1 deletion dataset/testdb.py
@@ -1,5 +1,5 @@
import os
from imdb import Imdb
from dataset.imdb import Imdb


class TestDB(Imdb):
2 changes: 1 addition & 1 deletion dataset/yolo_format.py
@@ -1,6 +1,6 @@
import os
import numpy as np
from imdb import Imdb
from dataset.imdb import Imdb


class YoloFormat(Imdb):
2 changes: 1 addition & 1 deletion demo.py
@@ -36,7 +36,7 @@ def get_detector(net, prefix, epoch, data_shape, mean_pixels, ctx,
"""
sys.path.append(os.path.join(os.getcwd(), 'symbol'))
if net is not None:
prefix = prefix + "_" + net.strip('_yolo') + '_' + str(416)
prefix = prefix + "_" + net.strip('_yolo') + '_' + str(data_shape)
net = importlib.import_module("symbol_" + net) \
.get_symbol(len(CLASSES), nms_thresh, force_nms)
detector = Detector(net, prefix, epoch, \
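One caveat in `get_detector` above: `net.strip('_yolo')` strips any run of the characters `_`, `y`, `o`, `l` from both ends of the string, not the literal suffix. It happens to give the right answer for `darknet19_yolo` and `resnet50_yolo`, but `darknet19_lyolo` would come out as `darknet19`, silently dropping the `l`. A suffix-safe sketch (hypothetical helper, not part of this commit):
```
def strip_suffix(name, suffix='_yolo'):
    # remove a literal trailing suffix, unlike str.strip
    return name[:-len(suffix)] if name.endswith(suffix) else name

strip_suffix('darknet19_yolo')   # 'darknet19'
strip_suffix('darknet19_lyolo')  # unchanged: its suffix is 'lyolo', not '_yolo'
```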
8 changes: 4 additions & 4 deletions evaluate.py
@@ -17,8 +17,8 @@ def parse_args():
default=os.path.join(os.getcwd(), 'data', 'val.rec'), type=str)
parser.add_argument('--list-path', dest='list_path', help='which list file to use',
default="", type=str)
parser.add_argument('--network', dest='network', type=str, default='vgg16_ssd_300',
choices=['vgg16_ssd_300', 'vgg16_ssd_512'], help='which network to use')
parser.add_argument('--network', dest='network', type=str, default='resnet50_yolo',
help='which network to use')
parser.add_argument('--batch-size', dest='batch_size', type=int, default=32,
help='evaluation batch size')
parser.add_argument('--num-class', dest='num_class', type=int, default=20,
@@ -28,12 +28,12 @@ def parse_args():
parser.add_argument('--epoch', dest='epoch', help='epoch of pretrained model',
default=0, type=int)
parser.add_argument('--prefix', dest='prefix', help='load model prefix',
default=os.path.join(os.getcwd(), 'model', 'ssd'), type=str)
default=os.path.join(os.getcwd(), 'model', 'yolo2_resnet50'), type=str)
parser.add_argument('--gpus', dest='gpu_id', help='GPU devices to evaluate with',
default='0', type=str)
parser.add_argument('--cpu', dest='cpu', help='use cpu to evaluate, this can be slow',
action='store_true')
parser.add_argument('--data-shape', dest='data_shape', type=int, default=300,
parser.add_argument('--data-shape', dest='data_shape', type=int, default=416,
help='set image shape')
parser.add_argument('--mean-r', dest='mean_r', type=float, default=123,
help='red mean value')
8 changes: 4 additions & 4 deletions evaluate/evaluate_net.py
@@ -4,7 +4,7 @@
import importlib
import mxnet as mx
from dataset.iterator import DetRecordIter
from config.config import cfg
from config.default_config import cfg
from evaluate.eval_metric import MApMetric, VOC07MApMetric
import logging

@@ -74,12 +74,12 @@ class names in string, must correspond to num_classes if set
sys.path.append(os.path.join(cfg.ROOT_DIR, 'symbol'))
net = importlib.import_module("symbol_" + net) \
.get_symbol(num_classes, nms_thresh, force_nms)
if not 'label' in net.list_arguments():
label = mx.sym.Variable(name='label')
if not 'yolo_output_label' in net.list_arguments():
label = mx.sym.Variable(name='yolo_output_label')
net = mx.sym.Group([net, label])

# init module
mod = mx.mod.Module(net, label_names=('label',), logger=logger, context=ctx,
mod = mx.mod.Module(net, label_names=('yolo_output_label',), logger=logger, context=ctx,
fixed_param_names=net.list_arguments())
mod.bind(data_shapes=eval_iter.provide_data, label_shapes=eval_iter.provide_label)
mod.set_params(args, auxs, allow_missing=False, force_init=True)
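The rename from `label` to `yolo_output_label` has to be consistent end to end, because `Module` matches label arrays to symbol arguments purely by name: the `Variable`, `label_names`, and the iterator's `provide_label` entry must all agree. A minimal sketch of that binding contract, with assumed shapes for illustration:
```
import mxnet as mx

data = mx.sym.Variable('data')
label = mx.sym.Variable('yolo_output_label')        # must match label_names
net = mx.sym.Group([mx.sym.BlockGrad(data), label])

mod = mx.mod.Module(net, data_names=('data',),
                    label_names=('yolo_output_label',))
mod.bind(data_shapes=[('data', (1, 3, 416, 416))],
         label_shapes=[('yolo_output_label', (1, 56, 6))])  # assumed shape
```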
2 changes: 1 addition & 1 deletion mxnet
Submodule mxnet updated 1962 files
156 changes: 9 additions & 147 deletions symbol/common.py
@@ -37,150 +37,12 @@ def conv_act_layer(from_layer, name, num_filter, kernel=(1,1), pad=(0,0), \
name="{}{}".format(act_type, name))
return relu

def multibox_layer(from_layers, num_classes, sizes=[.2, .95],
ratios=[1], normalization=-1, num_channels=[],
clip=True, interm_layer=0, steps=[]):
"""
the basic aggregation module for SSD detection. Takes in multiple layers,
generate multiple object detection targets by customized layers
Parameters:
----------
from_layers : list of mx.symbol
generate multibox detection from layers
num_classes : int
number of classes excluding background, will automatically handle
background in this function
sizes : list or list of list
[min_size, max_size] for all layers or [[], [], []...] for specific layers
ratios : list or list of list
[ratio1, ratio2...] for all layers or [[], [], ...] for specific layers
normalizations : int or list of int
use normalizations value for all layers or [...] for specific layers,
-1 indicate no normalizations and scales
num_channels : list of int
number of input layer channels, used when normalization is enabled, the
length of list should equals to number of normalization layers
clip : bool
whether to clip out-of-image boxes
interm_layer : int
if > 0, will add a intermediate Convolution layer
steps : list
specify steps for each MultiBoxPrior layer, leave empty, it will calculate
according to layer dimensions
Returns:
----------
list of outputs, as [loc_preds, cls_preds, anchor_boxes]
loc_preds : localization regression prediction
cls_preds : classification prediction
anchor_boxes : generated anchor boxes
"""
assert len(from_layers) > 0, "from_layers must not be empty list"
assert num_classes > 0, \
"num_classes {} must be larger than 0".format(num_classes)

assert len(ratios) > 0, "aspect ratios must not be empty list"
if not isinstance(ratios[0], list):
# provided only one ratio list, broadcast to all from_layers
ratios = [ratios] * len(from_layers)
assert len(ratios) == len(from_layers), \
"ratios and from_layers must have same length"

assert len(sizes) > 0, "sizes must not be empty list"
if len(sizes) == 2 and not isinstance(sizes[0], list):
# provided size range, we need to compute the sizes for each layer
assert sizes[0] > 0 and sizes[0] < 1
assert sizes[1] > 0 and sizes[1] < 1 and sizes[1] > sizes[0]
tmp = np.linspace(sizes[0], sizes[1], num=(len(from_layers)-1))
min_sizes = [start_offset] + tmp.tolist()
max_sizes = tmp.tolist() + [tmp[-1]+start_offset]
sizes = zip(min_sizes, max_sizes)
assert len(sizes) == len(from_layers), \
"sizes and from_layers must have same length"

if not isinstance(normalization, list):
normalization = [normalization] * len(from_layers)
assert len(normalization) == len(from_layers)

assert sum(x > 0 for x in normalization) == len(num_channels), \
"must provide number of channels for each normalized layer"

if steps:
assert len(steps) == len(from_layers), "provide steps for all layers or leave empty"

loc_pred_layers = []
cls_pred_layers = []
anchor_layers = []
num_classes += 1 # always use background as label 0

for k, from_layer in enumerate(from_layers):
from_name = from_layer.name
# normalize
if normalization[k] > 0:
from_layer = mx.symbol.L2Normalization(data=from_layer, \
mode="channel", name="{}_norm".format(from_name))
scale = mx.symbol.Variable(name="{}_scale".format(from_name),
shape=(1, num_channels.pop(0), 1, 1),
init=mx.init.Constant(normalization[k]))
from_layer = mx.symbol.broadcast_mul(lhs=scale, rhs=from_layer)
if interm_layer > 0:
from_layer = mx.symbol.Convolution(data=from_layer, kernel=(3,3), \
stride=(1,1), pad=(1,1), num_filter=interm_layer, \
name="{}_inter_conv".format(from_name))
from_layer = mx.symbol.Activation(data=from_layer, act_type="relu", \
name="{}_inter_relu".format(from_name))

# estimate number of anchors per location
# here I follow the original version in caffe
# TODO: better way to shape the anchors??
size = sizes[k]
assert len(size) > 0, "must provide at least one size"
size_str = "(" + ",".join([str(x) for x in size]) + ")"
ratio = ratios[k]
assert len(ratio) > 0, "must provide at least one ratio"
ratio_str = "(" + ",".join([str(x) for x in ratio]) + ")"
num_anchors = len(size) -1 + len(ratio)

# create location prediction layer
num_loc_pred = num_anchors * 4
bias = mx.symbol.Variable(name="{}_loc_pred_conv_bias".format(from_name),
init=mx.init.Constant(0.0), attr={'__lr_mult__': '2.0'})
loc_pred = mx.symbol.Convolution(data=from_layer, bias=bias, kernel=(3,3), \
stride=(1,1), pad=(1,1), num_filter=num_loc_pred, \
name="{}_loc_pred_conv".format(from_name))
loc_pred = mx.symbol.transpose(loc_pred, axes=(0,2,3,1))
loc_pred = mx.symbol.Flatten(data=loc_pred)
loc_pred_layers.append(loc_pred)

# create class prediction layer
num_cls_pred = num_anchors * num_classes
bias = mx.symbol.Variable(name="{}_cls_pred_conv_bias".format(from_name),
init=mx.init.Constant(0.0), attr={'__lr_mult__': '2.0'})
cls_pred = mx.symbol.Convolution(data=from_layer, bias=bias, kernel=(3,3), \
stride=(1,1), pad=(1,1), num_filter=num_cls_pred, \
name="{}_cls_pred_conv".format(from_name))
cls_pred = mx.symbol.transpose(cls_pred, axes=(0,2,3,1))
cls_pred = mx.symbol.Flatten(data=cls_pred)
cls_pred_layers.append(cls_pred)

# create anchor generation layer
if steps:
step = (steps[k], steps[k])
else:
step = '(-1.0, -1.0)'
anchors = mx.contrib.symbol.MultiBoxPrior(from_layer, sizes=size_str, ratios=ratio_str, \
clip=clip, name="{}_anchors".format(from_name), steps=step)
anchors = mx.symbol.Flatten(data=anchors)
anchor_layers.append(anchors)

loc_preds = mx.symbol.Concat(*loc_pred_layers, num_args=len(loc_pred_layers), \
dim=1, name="multibox_loc_pred")
cls_preds = mx.symbol.Concat(*cls_pred_layers, num_args=len(cls_pred_layers), \
dim=1)
cls_preds = mx.symbol.Reshape(data=cls_preds, shape=(0, -1, num_classes))
cls_preds = mx.symbol.transpose(cls_preds, axes=(0, 2, 1), name="multibox_cls_pred")
anchor_boxes = mx.symbol.Concat(*anchor_layers, \
num_args=len(anchor_layers), dim=1)
anchor_boxes = mx.symbol.Reshape(data=anchor_boxes, shape=(0, -1, 4), name="multibox_anchors")
return [loc_preds, cls_preds, anchor_boxes]
def stack_neighbor(from_layer, factor=2):
    """Downsample spatial dimensions and collapse neighbors into the channel dimension by factor"""
    out = mx.sym.reshape(from_layer, shape=(0, 0, -4, -1, factor, -2))  # (b, c, h/2, 2, w)
    out = mx.sym.transpose(out, axes=(0, 1, 3, 2, 4))  # (b, c, 2, h/2, w)
    out = mx.sym.reshape(out, shape=(0, -3, -1, -2))  # (b, c * 2, h/2, w)
    out = mx.sym.reshape(out, shape=(0, 0, 0, -4, -1, factor))  # (b, c * 2, h/2, w/2, 2)
    out = mx.sym.transpose(out, axes=(0, 1, 4, 2, 3))  # (b, c*2, 2, h/2, w/2)
    out = mx.sym.reshape(out, shape=(0, -3, -1, -2))  # (b, c*4, h/2, w/2)
    return out
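`stack_neighbor` is a space-to-depth reorg: each `factor`×`factor` spatial neighborhood is folded into the channel dimension, so YOLOv2's fine-grained passthrough features can be concatenated with a coarser grid. The `0/-1/-2/-3/-4` values are MXNet reshape codes (keep, infer, copy-rest, merge, split). A NumPy re-implementation to make the layout concrete (a sketch, not repo code):
```
import numpy as np

def stack_neighbor_np(x, factor=2):
    b, c, h, w = x.shape
    x = x.reshape(b, c, h // factor, factor, w)           # split h
    x = x.transpose(0, 1, 3, 2, 4)                        # (b, c, f, h/f, w)
    x = x.reshape(b, c * factor, h // factor, w)          # merge f into c
    x = x.reshape(b, c * factor, h // factor, w // factor, factor)  # split w
    x = x.transpose(0, 1, 4, 2, 3)                        # (b, c*f, f, h/f, w/f)
    return x.reshape(b, c * factor * factor, h // factor, w // factor)

x = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
y = stack_neighbor_np(x)
print(y.shape)        # (1, 4, 2, 2)
print(y[0, :, 0, 0])  # [0. 1. 4. 5.] -- the top-left 2x2 block, as channels
```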
4 changes: 2 additions & 2 deletions symbol/symbol_darknet19_lyolo.py
@@ -4,8 +4,8 @@
"https://arxiv.org/pdf/1612.08242.pdf"
"""
import mxnet as mx
from symbol_darknet19 import get_symbol as get_darknet19
from symbol_darknet19 import conv_act_layer
from symbol.symbol_darknet19 import get_symbol as get_darknet19
from symbol.symbol_darknet19 import conv_act_layer

def get_symbol(num_classes=20, nms_thresh=0.5, force_nms=False, **kwargs):
bone = get_darknet19(num_classes=num_classes, **kwargs)
22 changes: 15 additions & 7 deletions symbol/symbol_darknet19_yolo.py
@@ -4,13 +4,13 @@
"https://arxiv.org/pdf/1612.08242.pdf"
"""
import mxnet as mx
from symbol_darknet19 import get_symbol as get_darknet19
from symbol_darknet19 import conv_act_layer
from symbol.symbol_darknet19 import get_symbol as get_darknet19
from symbol.symbol_darknet19 import conv_act_layer

def get_symbol(num_classes=20, nms_thresh=0.5, force_nms=False, **kwargs):
bone = get_darknet19(num_classes=num_classes, **kwargs)
conv5_5 = bone.get_internals()["conv5_5_output"]
conv6_5 = bone.get_internals()["conv6_5_output"]
conv5_5 = bone.get_internals()["leaky_conv5_5_output"]
conv6_5 = bone.get_internals()["leaky_conv6_5_output"]
# anchors
anchors = [
1.3221, 1.73145,
@@ -21,14 +21,22 @@ def get_symbol(num_classes=20, nms_thresh=0.5, force_nms=False, **kwargs):
num_anchor = len(anchors) // 2

# extra layers
conv5_6 = conv_act_layer(conv5_5, 'conv5_6', 1024, kernel=(3, 3), pad=(1, 1),
act_type='leaky')
conv7_1 = conv_act_layer(conv6_5, 'conv7_1', 1024, kernel=(3, 3), pad=(1, 1),
act_type='leaky')
conv7_2 = conv_act_layer(conv7_1, 'conv7_2', 1024, kernel=(3, 3), pad=(1, 1),
act_type='leaky')

# re-organize conv5_5 and concat conv7_2
conv5_6 = mx.sym.stack_neighbor(data=conv5_5, kernel=(2, 2), name='stack_downsample')
concat = mx.sym.Concat(*[conv5_6, conv7_2], dim=1)
# re-organize conv5_6 and concat conv7_2
# conv5_7 = mx.sym.stack_neighbor(data=conv5_6, kernel=(2, 2), name='stack_downsample')
conv5_7 = mx.sym.reshape(conv5_6, shape=(0, 0, -4, -1, 2, -2)) # (b, c, h/2, 2, w)
conv5_7 = mx.sym.transpose(conv5_7, axes=(0, 1, 3, 2, 4)) # (b, c, 2, h/2, w)
conv5_7 = mx.sym.reshape(conv5_7, shape=(0, -3, -1, -2)) # (b, c * 2, h/2, w)
conv5_7 = mx.sym.reshape(conv5_7, shape=(0, 0, 0, -4, -1, 2)) # (b, c * 2, h/2, w/2, 2)
conv5_7 = mx.sym.transpose(conv5_7, axes=(0, 1, 4, 2, 3)) # (b, c*2, 2, h/2, w/2)
conv5_7 = mx.sym.reshape(conv5_7, shape=(0, -3, -1, -2)) # (b, c*4, h/2, w/2)
concat = mx.sym.Concat(*[conv5_7, conv7_2], dim=1)
# concat = conv7_2
conv8_1 = conv_act_layer(concat, 'conv8_1', 1024, kernel=(3, 3), pad=(1, 1),
act_type='leaky')
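The six reshape/transpose lines added above are the inline equivalent of the new `stack_neighbor` helper in `symbol/common.py`; the commented-out `mx.sym.stack_neighbor` call suggests that fused operator only exists in the bundled MXNet fork. A quick shape check of the passthrough path, assuming a 416×416 input so that conv5_6 sits on a 26×26 grid and conv7_2 on 13×13:
```
import mxnet as mx

x = mx.sym.Variable('x')  # stands in for conv5_6: (batch, 1024, 26, 26)
y = mx.sym.reshape(x, shape=(0, 0, -4, -1, 2, -2))
y = mx.sym.transpose(y, axes=(0, 1, 3, 2, 4))
y = mx.sym.reshape(y, shape=(0, -3, -1, -2))
y = mx.sym.reshape(y, shape=(0, 0, 0, -4, -1, 2))
y = mx.sym.transpose(y, axes=(0, 1, 4, 2, 3))
y = mx.sym.reshape(y, shape=(0, -3, -1, -2))

_, out, _ = y.infer_shape(x=(1, 1024, 26, 26))
print(out)  # [(1, 4096, 13, 13)] -- now concat-compatible with conv7_2
```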
37 changes: 37 additions & 0 deletions symbol/symbol_darknet_syolo.py
@@ -0,0 +1,37 @@
"""
Reference:
Redmon, Joseph, and Ali Farhadi. "YOLO9000: Better, Faster, Stronger."
"https://arxiv.org/pdf/1612.08242.pdf"
"""
import mxnet as mx
from symbol.symbol_darknet19 import get_symbol as get_darknet19
from symbol.symbol_darknet19 import conv_act_layer

def get_symbol(num_classes=20, nms_thresh=0.5, force_nms=False, **kwargs):
    bone = get_darknet19(num_classes=num_classes, **kwargs)
    conv5_5 = bone.get_internals()["leaky_conv5_5_output"]
    conv6_5 = bone.get_internals()["leaky_conv6_5_output"]
    # anchors
    anchors = [
        1.3221, 1.73145,
        3.19275, 4.00944,
        5.05587, 8.09892,
        9.47112, 4.84053,
        11.2364, 10.0071]
    num_anchor = len(anchors) // 2

    # extra layers
    conv7_1 = conv_act_layer(conv6_5, 'conv7_1', 1024, kernel=(3, 3), pad=(1, 1),
        act_type='leaky')
    conv7_2 = conv_act_layer(conv7_1, 'conv7_2', 1024, kernel=(3, 3), pad=(1, 1),
        act_type='leaky')
    conv8_1 = conv_act_layer(conv7_2, 'conv8_1', 1024, kernel=(3, 3), pad=(1, 1),
        act_type='leaky')
    pred = mx.symbol.Convolution(data=conv8_1, name='conv_pred', kernel=(1, 1),
        num_filter=num_anchor * (num_classes + 4 + 1))

    out = mx.contrib.symbol.YoloOutput(data=pred, num_class=num_classes,
        num_anchor=num_anchor, object_grad_scale=5.0, background_grad_scale=1.0,
        coord_grad_scale=1.0, class_grad_scale=1.0, anchors=anchors,
        nms_topk=400, warmup_samples=12800, name='yolo_output')
    return out
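Unlike `symbol_darknet19_yolo`, this "simple" variant skips the passthrough reorg and goes straight from conv7_2 to the head. Each anchor at each grid cell predicts 4 box offsets, 1 objectness score, and `num_classes` class scores, which fixes the width of the 1×1 `conv_pred` layer. For the VOC defaults above:
```
num_classes, num_anchor = 20, 5
num_filter = num_anchor * (num_classes + 4 + 1)
print(num_filter)  # 125 output channels from the 1x1 prediction convolution
```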