Showing 78 changed files with 12,198 additions and 5 deletions.
@@ -1,11 +1,95 @@
# Actor-Critic Instance Segmentation
Official implementation of the actor-critic model (Lua Torch) to accompany our paper

> Nikita Araslanov, Constantin Rothkopf and Stefan Roth, **Actor-Critic Instance Segmentation**, CVPR 2019.

- arXiv preprint: [https://arxiv.org/abs/1904.05126](https://arxiv.org/abs/1904.05126)
- CVPR 2019 proceedings: [PDF](http://openaccess.thecvf.com/content_CVPR_2019/papers/Araslanov_Actor-Critic_Instance_Segmentation_CVPR_2019_paper.pdf)
- Contact: Nikita Araslanov, [fname.lname]@visinf.tu-darmstadt.de
## Installation
We tested the code with ```Lua 5.2.4```, ```CUDA 8.0``` and ```cuDNN 5.1``` on Ubuntu 16.04.
- To install Lua Torch, please follow the [official docs](http://torch.ch/docs/getting-started.html#_).
- cuDNN 5.1 and CUDA 8.0 are available on the [NVIDIA website](https://developer.nvidia.com).
- Compile the Hungarian algorithm implementation (```hungarian.c```):
```
cd <acis>/code && make
```
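If you want to call the compiled ```libhungarian.so``` from Lua yourself, one common route under LuaJIT is the FFI. The sketch below is only an assumption about how such a library can be loaded, not the loader used by this repository; the real C prototypes have to be copied from ```hungarian.c```.
```lua
-- Illustrative only: loading the compiled library via the LuaJIT FFI.
local ffi = require 'ffi'

-- Declare the C prototypes before calling into the library, e.g.
-- ffi.cdef[[ int hungarian_solve(...); ]]  -- placeholder, not the real signature

local libhungarian = ffi.load('./libhungarian.so')
```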

<hr>

## Running the code
### Training
Training the actor-critic model from scratch consists of four steps:
1) Training the pre-processing network to predict angle quantisation
2) Generating augmented auxiliary data (angles and foreground) from the official dataset with the pre-processing network
3) Pre-training the decoder part of the actor-critic model
4) Training the actor-critic model on sequential prediction with gradually increasing sequence length

These steps are described in detail for the CVPPP dataset next. You can download the dataset from the [official website](https://www.plant-phenotyping.org/datasets-download). For instance segmentation, we use only the A1 subset of the dataset, 128 annotated images in total.

The commands referenced below assume ```<acis>/code``` as the current directory, unless mentioned otherwise.

#### Training the pre-processing network
As described in our paper, the actor-critic model uses angle quantisation [1] and the foreground mask, following Ren & Zemel [2].
The pre-processing network is an FCN and can be trained using ```main_preproc.lua```.
The bash script ```runs/cvppp_preproc.sh``` contains an example command for training the network.
Running
```
./runs/cvppp_preproc.sh
```
will create a directory ```checkpoints/cvppp_preproc```, where intermediate checkpoints, validation results and the training progress will be logged.
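For intuition, the following is a minimal sketch (ours, not code from this repository) of how an angle-quantisation target can be derived from a single instance mask in the spirit of [1]: every foreground pixel is labelled with the discretised direction towards its instance centroid. The function name and the number of bins are hypothetical.
```lua
-- Illustrative only: derive a per-pixel angle-quantisation target for one
-- instance mask (function name and bin count are hypothetical).
require 'torch'

-- mask: HxW tensor with 1 for pixels of one instance, 0 elsewhere (non-empty)
-- nBins: number of quantisation bins, e.g. 8
local function angleQuantisation(mask, nBins)
   local H, W = mask:size(1), mask:size(2)
   local ys = torch.range(1, H):view(H, 1):expand(H, W)
   local xs = torch.range(1, W):view(1, W):expand(H, W)

   local fg = mask:double()
   local area = fg:sum()
   local cy = ys:clone():cmul(fg):sum() / area   -- centroid row
   local cx = xs:clone():cmul(fg):sum() / area   -- centroid column

   -- angle of every pixel w.r.t. the instance centroid, shifted to [0, 2*pi]
   local theta = torch.atan2(ys - cy, xs - cx):add(math.pi)

   -- discretise into labels {1, ..., nBins}; background pixels stay 0
   local bins = theta:div(2 * math.pi / nBins):floor():add(1):clamp(1, nBins)
   return bins:cmul(fg)
end
```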

> [1] Uhrig J., Cordts M., Franke U., and Brox T. Pixel-level encoding and depth layering for instance-level semantic labeling. In GCPR, 2016.<br>
> [2] Ren M. and Zemel R. End-to-end instance segmentation and counting with recurrent attention. In CVPR, 2017.
#### Generating augmented data
Instead of keeping the pre-processing net around while training the actor-critic model, we generate the auxiliary data (the angles and the foreground mask) in advance, *with augmentation*, using the pre-processing net. This saves GPU memory and improves runtime at the expense of some disk space.
```
./runs/cvppp_preproc_save.sh
```
The script will iterate through the dataset (300 epochs for CVPPP), each time with random augmentation (such as rotation and flipping) switched on. You can change the amount of generated data with the parameter ```-preproc_epoch``` (see ```cvppp_preproc_save.sh```).
By default, you should account for 80 GB of generated data (disk space). The data will be saved into ```data/cvppp/A1_RAW/train/augm``` and ```data/cvppp/A1_RAW/val/augm```. For the next steps, please move the augmented data into ```data/cvppp/A1_AUG/train``` and ```data/cvppp/A1_AUG/val/```:
```
mv data/cvppp/A1_RAW/train/augm/* data/cvppp/A1_AUG/train/
mv data/cvppp/A1_RAW/val/augm/* data/cvppp/A1_AUG/val/
```
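The augmentation itself is performed by the release scripts; purely as an illustration (all names below are ours), consistent random flipping and rotation of an image and its label map can be done with the standard ```image``` package. Note that an angle channel would additionally need its values remapped after a flip.
```lua
-- Schematic augmentation example (not taken from the release scripts):
-- apply the same random flip and rotation to an image and its label map.
require 'image'

local function randomAugment(img, labels)
   -- horizontal flip with probability 0.5
   if torch.uniform() < 0.5 then
      img = image.hflip(img)
      labels = image.hflip(labels)
   end
   -- random rotation of up to +/- 30 degrees; nearest-neighbour ('simple')
   -- interpolation keeps the label values intact
   local theta = torch.uniform(-math.pi / 6, math.pi / 6)
   img = image.rotate(img, theta, 'bilinear')
   labels = image.rotate(labels, theta, 'simple')
   return img, labels
end
```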

#### Pre-training
The purpose of the pre-training stage is to learn a compact representation for masks (the action space). The training is equivalent to training a Variational Auto-Encoder (VAE), where the reconstruction loss is computed for a single target mask.
Assuming the augmented data has been generated in the previous step, run
```
./runs/cvppp_pretrain.sh
```
The script will use ```pretrain_main.lua``` to create and train the actor model, reducing the learning rate in stages. After training, the script will create checkpoints for the decoder and for the encoder trained to predict a single mask; both are used in the final training step, described next.
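As a reminder of what such an objective looks like, here is a small sketch (assumed, not the code in ```pretrain_main.lua```) of a VAE-style loss combining a single-mask reconstruction term with a KL regulariser; the function and the weighting are hypothetical.
```lua
-- Sketch of a VAE-style pre-training objective (illustrative only):
-- binary cross-entropy reconstruction of one target mask plus a KL term
-- that pulls the latent distribution towards a standard normal.
require 'nn'

local bce = nn.BCECriterion()

-- recon, target: predicted and ground-truth masks with values in [0, 1]
-- mu, logvar:    encoder outputs parameterising q(z|x)
-- klWeight:      trade-off between reconstruction and regularisation
local function vaeLoss(recon, target, mu, logvar, klWeight)
   local reconLoss = bce:forward(recon, target)
   -- KL( q(z|x) || N(0, I) ) = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
   local kl = -0.5 * (logvar + 1 - torch.pow(mu, 2) - torch.exp(logvar)):sum()
   return reconLoss + (klWeight or 1.0) * kl
end
```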

#### Training
```main.lua``` is the entry script to train the actor-critic model on sequential prediction. Logging is realised with [crayon](https://github.com/torrvision/crayon). Please follow [this README](https://github.com/arnike/acis_release/blob/master_release/code/README.md) to set it up.

To train *BL-Trunc* (the baseline with truncated backprop), run
```
./runs/cvppp_train_btrunc.sh
```
To train *AC* (the actor-critic model), run
```
./runs/cvppp_train_ac.sh
```
Both scripts will train the respective models stagewise: the training sequence length gradually increases (5, 10, 15) while the learning rate is cut. ```schedules.lua``` contains the training schedule for both models.
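For orientation, such a schedule could be represented as below; the numbers are made up for illustration, the actual values live in ```schedules.lua```.
```lua
-- Hypothetical illustration of a stagewise schedule (the real values are in
-- schedules.lua): each stage fixes a sequence length and a learning rate.
local schedule = {
   { seqLength = 5,  learningRate = 1e-4 },
   { seqLength = 10, learningRate = 5e-5 },
   { seqLength = 15, learningRate = 2e-5 },
}

for _, stage in ipairs(schedule) do
   -- train(stage.seqLength, stage.learningRate)  -- hypothetical training call
   print(string.format('stage: T = %d, lr = %g', stage.seqLength, stage.learningRate))
end
```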

### Using trained models for inference
The directory ```eval/``` contains the code to produce the final results for evaluation.
Put the pre-trained model checkpoints for evaluation into ```eval/cvppp/models/<MODEL-ID>```.
Then, run
```
th cvppp_main.lua -dataIn [DATA-IN] -dataOut [DATA-OUT] -modelIdx <MODEL-ID>
```
where the parameters in brackets should be replaced with your own values. The script will save the final predictions in ```[DATA-OUT]```.

## Citation
```
@inproceedings{Araslanov:2019:ACIS,
  title     = {Actor-Critic Instance Segmentation},
  author    = {Araslanov, Nikita and Rothkopf, Constantin and Roth, Stefan},
  booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year      = {2019}
}
```
@@ -0,0 +1,5 @@
*.sw*
checkpoints/
tensorboard/
debug/
media/
@@ -0,0 +1,20 @@

# Copyright (c) 2016-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.

CC = cc
CFLAGS = -std=c99 -fPIC -Wall -Ofast -c
LDFLAGS = -shared

opt: lib

lib: hungarian.c
	$(CC) $(CFLAGS) hungarian.c
	$(CC) $(LDFLAGS) -o libhungarian.so hungarian.o

clean:
	rm -rf *.o *.so
@@ -0,0 +1,46 @@
## ACIS/code

### Crayon

#### Installation

Install the crayon client:
```
luarocks install crayon OPENSSL_INCDIR=~/.linuxbrew/include OPENSSL_DIR=~/.linuxbrew
```

A slightly modified crayon server is already supplied in ```<acis>/code/crayon/server```.
The modification simply adds two additional arguments:
- ```--tb-port``` specifies the tensorboard port;
- ```--logdir``` specifies the log directory of the tensorboard server.

#### Run
Go to ```<acis>/code```. Then,

1. Start tensorboard:
```
nohup tensorboard --logdir tensorboard --port 6038 > tensorboard/tb.log 2>&1 &
```
2. Start crayon:
```
nohup python crayon/server/server.py --port 6039 --logdir tensorboard --tb-port 6038 > tensorboard/crayon.log 2>&1 &
```

The crayon server can now be accessed through port 6039.
It will in turn access tensorboard via port 6038 and save data in the ```tensorboard``` directory.

#### Managing experiments
Session example:

```lua
-- load the client library
crayon = require 'crayon'

-- initialise the client
cc = crayon.CrayonClient("localhost", 6039)

-- get a list of experiment names
cc:get_experiment_names()

-- remove experiments
cc:remove_experiment("cvppp_600_btrunc/train")
cc:remove_experiment("cvppp_600_btrunc/train_val")
cc:remove_experiment("cvppp_600_btrunc/val")
```
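Creating an experiment and logging values follows the same pattern; the snippet below assumes the standard crayon client calls (```create_experiment```, ```add_scalar_value```), and the experiment name and values are made up.
```lua
-- assumed usage of the crayon client: create an experiment and log scalars
crayon = require 'crayon'
cc = crayon.CrayonClient("localhost", 6039)

exp = cc:create_experiment("cvppp_example/train")
exp:add_scalar_value("loss", 0.42)
exp:add_scalar_value("reward", 1.3)
```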
@@ -0,0 +1,45 @@
analysis = require 'analysis'

local M = {}
local Analyser = torch.class('ris.Analyser', M)

function Analyser.create()
   -- Factory method: returns a new Analyser instance
   return M.Analyser()
end

function Analyser:__init()
   self.size = 0
   self.criteria_stat = {}
   self.criteria_fn = {}
end

-- Register an evaluation metric by name (currently only 'sbd' is supported)
function Analyser:addMetric(name)

   if name == 'sbd' then
      self.criteria_fn[name] = analysis.sbd
   else
      error("No metric found with name " .. name)
   end

   self.criteria_stat[name] = {}
end

-- Evaluate all registered metrics on one (prediction, ground truth) pair
-- and accumulate the results
function Analyser:updateStat(prediction, gt_segments)
   for key, value in pairs(self.criteria_fn) do
      table.insert(self.criteria_stat[key], value(prediction, gt_segments))
   end
   self.size = self.size + 1
end

-- Print the mean of every registered metric over all evaluated samples
function Analyser:printStat()
   for key, value in pairs(self.criteria_stat) do
      local criteria_sum = 0
      for n = 1, self.size do
         criteria_sum = criteria_sum + value[n]
      end
      print(key .. ": ", criteria_sum / self.size)
   end
end

return M
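For reference, the class is used roughly as follows; the ```require``` path is an assumption, since the file name is not shown here, and ```predictions```/```groundTruth``` stand for parallel tables of predicted and ground-truth segmentations.
```lua
-- Illustrative usage of the Analyser class (require path is assumed)
local analysers = require 'analyser'

local analyser = analysers.Analyser.create()
analyser:addMetric('sbd')

for i = 1, #predictions do
   analyser:updateStat(predictions[i], groundTruth[i])
end

analyser:printStat()   -- prints the mean of each metric over all samples
```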