Skip to content

Commit

Permalink
refactored for modular prediction pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
fcakyon committed Mar 19, 2020
1 parent bd7a2dd commit e051841
Show file tree
Hide file tree
Showing 5 changed files with 245 additions and 20 deletions.
2 changes: 1 addition & 1 deletion craft.py → craftnet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""
"""
Copyright (c) 2019-present NAVER Corp.
MIT License
"""
Expand Down
11 changes: 8 additions & 3 deletions file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@
import cv2


def create_dir(_dir):
    """
    Creates given directory (including missing parents) if it is not present.

    Arguments:
        _dir: path of the directory to create

    Raises:
        FileExistsError: if the path already exists but is not a directory
    """
    # exist_ok=True replaces the separate os.path.exists() check, which left
    # a window where another process could create the directory between the
    # check and the makedirs() call and make makedirs() fail.
    os.makedirs(_dir, exist_ok=True)


# borrowed from https://github.com/lengstrom/fast-style-transfer/blob/master/src/utils.py
def get_files(img_dir):
imgs, masks, xmls = list_files(img_dir)
Expand Down Expand Up @@ -58,9 +66,6 @@ def saveResult(img_file,
res_file = dirname + "res_" + filename + '.txt'
res_img_file = dirname + "res_" + filename + '.jpg'

if not os.path.isdir(dirname):
os.mkdir(dirname)

with open(res_file, 'w') as f:
for i, box in enumerate(boxes):
poly = np.array(box).astype(np.int32).reshape((-1))
Expand Down
41 changes: 28 additions & 13 deletions imgproc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""
"""
Copyright (c) 2019-present NAVER Corp.
MIT License
"""
Expand All @@ -8,24 +8,38 @@
from skimage import io
import cv2


def loadImage(img_file):
    """
    Reads an image file and returns it as a numpy array with three channels.

    Arguments:
        img_file: path of the image, passed straight to skimage's io.imread

    Returns:
        numpy array; grayscale input is expanded with cv2.COLOR_GRAY2RGB and
        four-channel input is cut down to its first three channels
    """
    loaded = io.imread(img_file)  # RGB order

    # a leading axis of length 2 is reduced to its first entry
    if loaded.shape[0] == 2:
        loaded = loaded[0]

    # 2-D data has no channel axis yet
    if loaded.ndim == 2:
        loaded = cv2.cvtColor(loaded, cv2.COLOR_GRAY2RGB)

    # keep only the first three channels of a four-channel image
    if loaded.shape[2] == 4:
        loaded = loaded[..., :3]

    return np.array(loaded)

def normalizeMeanVariance(in_img,
                          mean=(0.485, 0.456, 0.406),
                          variance=(0.229, 0.224, 0.225)):
    """
    Returns a float32 copy of in_img normalized per channel.

    Each channel c is computed as
    (pixel - mean[c] * 255) / (variance[c] * 255); the input is not modified.

    Arguments:
        in_img: image array whose last axis holds three channels (RGB order)
        mean: per-channel mean, expressed on a 0-1 scale
        variance: per-channel divisor, expressed on a 0-1 scale
    """
    # should be RGB order
    channels = range(3)
    shift = np.array([mean[c] * 255.0 for c in channels], dtype=np.float32)
    scale = np.array([variance[c] * 255.0 for c in channels],
                     dtype=np.float32)

    as_float = in_img.copy().astype(np.float32)
    return (as_float - shift) / scale

def denormalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)):

def denormalizeMeanVariance(in_img,
mean=(0.485, 0.456, 0.406),
variance=(0.229, 0.224, 0.225)):
# should be RGB order
img = in_img.copy()
img *= variance
Expand All @@ -34,6 +48,7 @@ def denormalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229,
img = np.clip(img, 0, 255).astype(np.uint8)
return img


def resize_aspect_ratio(img, square_size, interpolation, mag_ratio=1):
height, width, channel = img.shape

Expand All @@ -43,12 +58,11 @@ def resize_aspect_ratio(img, square_size, interpolation, mag_ratio=1):
# set original image size
if target_size > square_size:
target_size = square_size

ratio = target_size / max(height, width)

target_h, target_w = int(height * ratio), int(width * ratio)
proc = cv2.resize(img, (target_w, target_h), interpolation = interpolation)
ratio = target_size / max(height, width)

target_h, target_w = int(height * ratio), int(width * ratio)
proc = cv2.resize(img, (target_w, target_h), interpolation=interpolation)

# make canvas and paste image
target_h32, target_w32 = target_h, target_w
Expand All @@ -64,6 +78,7 @@ def resize_aspect_ratio(img, square_size, interpolation, mag_ratio=1):

return resized, ratio, size_heatmap


def cvt2HeatmapImg(img):
img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
img = cv2.applyColorMap(img, cv2.COLORMAP_JET)
Expand Down
207 changes: 207 additions & 0 deletions predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
"""
Copyright (c) 2019-present NAVER Corp.
MIT License
"""

# -*- coding: utf-8 -*-
import os
import time

import torch
import torch.backends.cudnn as cudnn
from torch.autograd import Variable

import cv2
import numpy as np

import craft_utils as craft_utils
import imgproc as imgproc
import file_utils as file_utils
from craftnet import CRAFT

from collections import OrderedDict


def copyStateDict(state_dict):
    """
    Returns an OrderedDict copy of state_dict without a leading "module"
    key component (e.g. the "module." prefix that wrapping a network in
    torch.nn.DataParallel puts in front of parameter names).

    Whether the prefix is stripped is decided from the first key only and
    then applied to every key.

    Arguments:
        state_dict: mapping of dotted parameter names to values

    Returns:
        OrderedDict with the same values in the same order; empty if
        state_dict is empty
    """
    new_state_dict = OrderedDict()
    if not state_dict:
        # nothing to inspect; indexing the first key would raise IndexError
        return new_state_dict

    first_key = next(iter(state_dict))
    # Compare the whole first component: a plain startswith("module") would
    # also strip unrelated names such as "modules.0.weight".
    start_idx = 1 if first_key.split(".")[0] == "module" else 0

    for k, v in state_dict.items():
        name = ".".join(k.split(".")[start_idx:])
        new_state_dict[name] = v
    return new_state_dict


def str2bool(v):
    """
    Returns True when v, lower-cased, is one of the accepted "true"
    spellings ("yes", "y", "true", "t", "1"); False for any other string.
    """
    accepted = ("yes", "y", "true", "t", "1")
    lowered = v.lower()
    return lowered in accepted


def load_weights(net, model_path: str, cuda: bool = False):
    """
    Loads the checkpoint at model_path into net and switches it to eval mode.

    Arguments:
        net: network instance to receive the weights
        model_path: path of the checkpoint file read with torch.load
        cuda: when True the network is moved to the GPU and wrapped in
            torch.nn.DataParallel; otherwise the checkpoint is mapped to CPU

    Returns:
        the prepared network (the DataParallel wrapper when cuda is True)
    """
    # map_location=None is torch.load's default, i.e. the plain call used
    # for the GPU case.
    location = None if cuda else 'cpu'
    checkpoint = torch.load(model_path, map_location=location)
    net.load_state_dict(copyStateDict(checkpoint))

    if cuda:
        net = torch.nn.DataParallel(net.cuda())
        cudnn.benchmark = False

    net.eval()
    return net


def get_models(cuda: bool = False, refiner: bool = False):
    """
    Builds the CRAFT network and, optionally, the link refiner network.

    Weights are read from "weights/craft_mlt_25k.pth" and, for the refiner,
    "weights/craft_refiner_CTW1500.pth".

    Arguments:
        cuda: forwarded to load_weights to select the device
        refiner: also load the refiner network

    Returns:
        (craft_net, refine_net, poly) where refine_net is None and poly is
        False without a refiner, and poly is True when the refiner is loaded
    """
    # detector: instantiate, then load weights and place on the device
    craft_weights = os.path.join("weights", "craft_mlt_25k.pth")
    craft_net = load_weights(CRAFT(), craft_weights, cuda)

    if not refiner:
        return craft_net, None, False

    # refiner: imported only when it is actually requested
    refiner_weights = os.path.join("weights", "craft_refiner_CTW1500.pth")
    from refinenet import RefineNet
    refine_net = load_weights(RefineNet(), refiner_weights, cuda)
    return craft_net, refine_net, True


def get_prediction(craft_net,
                   refine_net,
                   image,
                   text_threshold: float = 0.7,
                   link_threshold: float = 0.4,
                   low_text: float = 0.4,
                   cuda: bool = False,
                   canvas_size: int = 1280,
                   mag_ratio: float = 1.5,
                   poly: bool = False,
                   show_time: bool = False):
    """
    Runs the detector (and the refiner, if given) on one image and
    post-processes the score maps into boxes and polygons.

    Arguments:
        craft_net: detector network, called on the preprocessed batch
        refine_net: refiner network or None
        image: input image array, [h, w, c]
        text_threshold: text confidence threshold
        link_threshold: link confidence threshold
        low_text: text low-bound score
        cuda: move the input batch to the GPU
        canvas_size: image size for inference
        mag_ratio: image magnification ratio
        poly: enable polygon type
        show_time: print inference / post-processing time

    Returns:
        (boxes, polys, heatmap) where entries of polys that are None are
        replaced by the matching box, and heatmap is the text and link
        score maps placed side by side and colour-mapped
    """
    infer_started = time.time()

    # resize; the inverse ratio maps results back to the input image
    resized, target_ratio, _size_heatmap = imgproc.resize_aspect_ratio(
        image, canvas_size,
        interpolation=cv2.INTER_LINEAR,
        mag_ratio=mag_ratio)
    ratio_w = 1 / target_ratio
    ratio_h = ratio_w

    # preprocessing
    normalized = imgproc.normalizeMeanVariance(resized)
    batch = torch.from_numpy(normalized).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
    batch = Variable(batch.unsqueeze(0))  # [c, h, w] to [b, c, h, w]
    if cuda:
        batch = batch.cuda()

    # forward pass
    with torch.no_grad():
        y, feature = craft_net(batch)

    # score and link maps
    score_text = y[0, :, :, 0].cpu().data.numpy()
    score_link = y[0, :, :, 1].cpu().data.numpy()

    # the refiner, when present, supplies the link map instead
    if refine_net is not None:
        with torch.no_grad():
            y_refiner = refine_net(y, feature)
        score_link = y_refiner[0, :, :, 0].cpu().data.numpy()

    infer_elapsed = time.time() - infer_started
    post_started = time.time()

    # post-processing
    boxes, polys = craft_utils.getDetBoxes(
        score_text, score_link,
        text_threshold, link_threshold, low_text,
        poly)

    # coordinate adjustment
    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
    for idx, candidate in enumerate(polys):
        if candidate is None:
            polys[idx] = boxes[idx]

    post_elapsed = time.time() - post_started

    # render results (optional)
    side_by_side = np.hstack((score_text.copy(), score_link))
    ret_score_text = imgproc.cvt2HeatmapImg(side_by_side)

    if show_time:
        print("\ninfer/postproc time : {:.3f}/{:.3f}".format(infer_elapsed,
                                                             post_elapsed))

    return boxes, polys, ret_score_text


def detect_text(image_path: str,
                output_dir: str = "output/",
                text_threshold: float = 0.7,
                link_threshold: float = 0.4,
                low_text: float = 0.4,
                cuda: bool = False,
                canvas_size: int = 1280,
                mag_ratio: float = 1.5,
                poly: bool = False,
                show_time: bool = False,
                refiner: bool = False):
    """
    Detects text in one image and writes the score heat-map and the
    detection results into output_dir.

    Arguments:
        image_path: path to the image to be processed
        output_dir: path to the results to be exported
        text_threshold: text confidence threshold
        low_text: text low-bound score
        link_threshold: link confidence threshold
        cuda: Use cuda for inference
        canvas_size: image size for inference
        mag_ratio: image magnification ratio
        poly: enable polygon type (always enabled when refiner is True)
        show_time: show processing time
        refiner: enable link refiner
    """

    # create output dir
    file_utils.create_dir(output_dir)

    # get models
    craft_net, refine_net, refiner_poly = get_models(cuda, refiner)
    # get_models reports poly=False whenever no refiner is loaded; combine it
    # with the caller's flag instead of overwriting it, so an explicit
    # poly=True is still honoured without a refiner.
    poly = poly or refiner_poly

    t = time.time()

    # load image
    image = imgproc.loadImage(image_path)

    # perform text detection
    bboxes, polys, score_text = get_prediction(craft_net,
                                               refine_net,
                                               image,
                                               text_threshold,
                                               link_threshold,
                                               low_text,
                                               cuda,
                                               canvas_size,
                                               mag_ratio,
                                               poly,
                                               show_time)

    # save score text
    filename, file_ext = os.path.splitext(os.path.basename(image_path))
    mask_file = os.path.join(output_dir, "res_" + filename + '_mask.jpg')
    cv2.imwrite(mask_file, score_text)

    # saveResult builds its file names as dirname + "res_" + ..., so the
    # directory must end with a path separator; joining with "" appends one
    # only when it is missing.
    file_utils.saveResult(image_path,
                          image[:, :, ::-1],  # reverse the channel order
                          polys,
                          dirname=os.path.join(output_dir, ""))

    print("elapsed time : {}s".format(time.time() - t))
4 changes: 1 addition & 3 deletions refinenet.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
"""
"""
Copyright (c) 2019-present NAVER Corp.
MIT License
"""

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from basenet.vgg16_bn import init_weights


Expand Down

0 comments on commit e051841

Please sign in to comment.