From 99d61274d1cd9bf87d71395b4f860b6d6866922a Mon Sep 17 00:00:00 2001 From: fcakyon Date: Fri, 20 Mar 2020 10:34:28 +0300 Subject: [PATCH] added cropped text region export --- file_utils.py | 75 +++++++++++++++----- predict.py | 31 ++++---- test.py | 191 -------------------------------------------------- 3 files changed, 73 insertions(+), 224 deletions(-) delete mode 100755 test.py diff --git a/file_utils.py b/file_utils.py index 4d35507..ec760dd 100644 --- a/file_utils.py +++ b/file_utils.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- import os -import numpy as np import cv2 +import copy +import numpy as np def create_dir(_dir): @@ -12,7 +13,6 @@ def create_dir(_dir): os.makedirs(_dir) -# borrowed from https://github.com/lengstrom/fast-style-transfer/blob/master/src/utils.py def get_files(img_dir): imgs, masks, xmls = list_files(img_dir) return imgs, masks, xmls @@ -41,30 +41,69 @@ def list_files(in_path): return img_files, mask_files, gt_files -def saveResult(img_file, - img, - boxes, - dirname='./result/', - verticals=None, - texts=None): +def export_detected_region(image, points, file_path): + # points should have 1*4*2 shape + if len(points.shape) == 2: + points = np.array([np.array(points).astype(np.int32)]) + + # create mask with shape of image + mask = np.zeros(image.shape[0:2], dtype=np.uint8) + + # method 1 smooth region + cv2.drawContours(mask, [points], -1, (255, 255, 255), -1, cv2.LINE_AA) + + # method 2 not so smooth region + # cv2.fillPoly(mask, points, (255)) + + res = cv2.bitwise_and(image, image, mask=mask) + rect = cv2.boundingRect(points) # returns (x,y,w,h) of the rect + cropped = res[rect[1]: rect[1] + rect[3], rect[0]: rect[0] + rect[2]] + + # export corpped region + cv2.imwrite(file_path, cropped) + + +def export_detected_regions(image_path, image, polys, + output_dir: str = "output/"): + # deepcopy image so that original is not altered + image = copy.deepcopy(image) + + # get file name + file_name, file_ext = os.path.splitext(os.path.basename(image_path)) + + # create crops dir + crops_dir = os.path.join(output_dir, file_name + "_crops") + create_dir(crops_dir) + + for ind, poly in enumerate(polys): + file_path = os.path.join(crops_dir, "crop_" + str(ind) + ".png") + export_detected_region(image, points=poly, file_path=file_path) + + +def export_extra_results(image_path, + image, + boxes, + output_dir='output/', + verticals=None, + texts=None): """ save text detection result one by one Args: - img_file (str): image file name - img (array): raw image context + image_path (str): image file name + image (array): raw image context boxes (array): array of result file Shape: [num_detections, 4] for BB output / [num_detections, 4] for QUAD output Return: None """ - img = np.array(img) + image = np.array(image) # make result file list - filename, file_ext = os.path.splitext(os.path.basename(img_file)) + filename, file_ext = os.path.splitext(os.path.basename(image_path)) # result directory - res_file = dirname + "res_" + filename + '.txt' - res_img_file = dirname + "res_" + filename + '.jpg' + res_file = output_dir + "res_" + filename + '.txt' + res_img_file = output_dir + "res_" + filename + '.jpg' with open(res_file, 'w') as f: for i, box in enumerate(boxes): @@ -73,7 +112,7 @@ def saveResult(img_file, f.write(strResult) poly = poly.reshape(-1, 2) - cv2.polylines(img, + cv2.polylines(image, [poly.reshape((-1, 1, 2))], True, color=(0, 0, 255), @@ -86,13 +125,13 @@ def saveResult(img_file, if texts is not None: font = cv2.FONT_HERSHEY_SIMPLEX font_scale = 0.5 - cv2.putText(img, "{}".format(texts[i]), + cv2.putText(image, "{}".format(texts[i]), (poly[0][0]+1, poly[0][1]+1), font, font_scale, (0, 0, 0), thickness=1) - cv2.putText(img, + cv2.putText(image, "{}".format(texts[i]), tuple(poly[0]), font, @@ -101,4 +140,4 @@ def saveResult(img_file, thickness=1) # Save result image - cv2.imwrite(res_img_file, img) + cv2.imwrite(res_img_file, image) diff --git a/predict.py b/predict.py index 64b4aaf..12d1730 100755 --- a/predict.py +++ b/predict.py @@ -1,9 +1,3 @@ -""" -Copyright (c) 2019-present NAVER Corp. -MIT License -""" - -# -*- coding: utf-8 -*- import os import time @@ -154,7 +148,8 @@ def detect_text(image_path: str, mag_ratio: float = 1.5, poly: bool = False, show_time: bool = False, - refiner: bool = False): + refiner: bool = False, + export_extra: bool = True): """ Arguments: image_path: path to the image to be processed @@ -168,6 +163,7 @@ def detect_text(image_path: str, poly: enable polygon type show_time: show processing time refiner: enable link refiner + export_extra: export score map, detection points, box visualization """ # create output dir @@ -194,14 +190,19 @@ def detect_text(image_path: str, poly, show_time) - # save score text - filename, file_ext = os.path.splitext(os.path.basename(image_path)) - mask_file = os.path.join(output_dir, "res_" + filename + '_mask.jpg') - cv2.imwrite(mask_file, score_text) + # export detected text regions + file_utils.export_detected_regions(image_path, image, polys, output_dir) + + if export_extra: + # export score map + filename, file_ext = os.path.splitext(os.path.basename(image_path)) + mask_file = os.path.join(output_dir, "res_" + filename + '_mask.jpg') + cv2.imwrite(mask_file, score_text) - file_utils.saveResult(image_path, - image[:, :, ::-1], - polys, - dirname=output_dir) + # export detected points and box visualization + file_utils.export_extra_results(image_path, + image[:, :, ::-1], + polys, + output_dir=output_dir) print("elapsed time : {}s".format(time.time() - t)) diff --git a/test.py b/test.py deleted file mode 100755 index f9fea90..0000000 --- a/test.py +++ /dev/null @@ -1,191 +0,0 @@ -""" -Copyright (c) 2019-present NAVER Corp. -MIT License -""" - -# -*- coding: utf-8 -*- -import os -import time -import argparse - -import torch -import torch.backends.cudnn as cudnn -from torch.autograd import Variable - -import cv2 -import numpy as np -import craft_utils -import imgproc -import file_utils - -from craft import CRAFT - -from collections import OrderedDict - - -def copyStateDict(state_dict): - if list(state_dict.keys())[0].startswith("module"): - start_idx = 1 - else: - start_idx = 0 - new_state_dict = OrderedDict() - for k, v in state_dict.items(): - name = ".".join(k.split(".")[start_idx:]) - new_state_dict[name] = v - return new_state_dict - - -def str2bool(v): - return v.lower() in ("yes", "y", "true", "t", "1") - - -parser = argparse.ArgumentParser(description='CRAFT Text Detection') -parser.add_argument('--trained_model', default='weights/craft_mlt_25k.pth', type=str, help='pretrained model') -parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold') -parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score') -parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold') -parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda for inference') -parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference') -parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio') -parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type') -parser.add_argument('--show_time', default=False, action='store_true', help='show processing time') -parser.add_argument('--test_folder', default='/data/', type=str, help='folder path to input images') -parser.add_argument('--refine', default=False, action='store_true', help='enable link refiner') -parser.add_argument('--refiner_model', default='weights/craft_refiner_CTW1500.pth', type=str, help='pretrained refiner model') - -args = parser.parse_args() - - -""" For test images in a folder """ -image_list, _, _ = file_utils.get_files(args.test_folder) - -result_folder = './result/' -if not os.path.isdir(result_folder): - os.mkdir(result_folder) - - -def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, - refine_net=None): - t0 = time.time() - - # resize - img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio( - image, args.canvas_size, interpolation=cv2.INTER_LINEAR, - mag_ratio=args.mag_ratio) - ratio_h = ratio_w = 1 / target_ratio - - # preprocessing - x = imgproc.normalizeMeanVariance(img_resized) - x = torch.from_numpy(x).permute(2, 0, 1) # [h, w, c] to [c, h, w] - x = Variable(x.unsqueeze(0)) # [c, h, w] to [b, c, h, w] - if cuda: - x = x.cuda() - - # forward pass - with torch.no_grad(): - y, feature = net(x) - - # make score and link map - score_text = y[0, :, :, 0].cpu().data.numpy() - score_link = y[0, :, :, 1].cpu().data.numpy() - - # refine link - if refine_net is not None: - with torch.no_grad(): - y_refiner = refine_net(y, feature) - score_link = y_refiner[0, :, :, 0].cpu().data.numpy() - - t0 = time.time() - t0 - t1 = time.time() - - # Post-processing - boxes, polys = craft_utils.getDetBoxes( - score_text, score_link, text_threshold, link_threshold, low_text, - poly) - - # coordinate adjustment - boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h) - polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h) - for k in range(len(polys)): - if polys[k] is None: - polys[k] = boxes[k] - - t1 = time.time() - t1 - - # render results (optional) - render_img = score_text.copy() - render_img = np.hstack((render_img, score_link)) - ret_score_text = imgproc.cvt2HeatmapImg(render_img) - - if args.show_time: - print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1)) - - return boxes, polys, ret_score_text - - -if __name__ == '__main__': - # load net - net = CRAFT() # initialize - - print('Loading weights from checkpoint (' + args.trained_model + ')') - if args.cuda: - net.load_state_dict(copyStateDict(torch.load(args.trained_model))) - else: - net.load_state_dict(copyStateDict(torch.load(args.trained_model, - map_location='cpu'))) - - if args.cuda: - net = net.cuda() - net = torch.nn.DataParallel(net) - cudnn.benchmark = False - - net.eval() - - # LinkRefiner - refine_net = None - if args.refine: - from refinenet import RefineNet - refine_net = RefineNet() - print('Loading weights of refiner from checkpoint (' + - args.refiner_model + ')') - if args.cuda: - refine_net.load_state_dict( - copyStateDict(torch.load(args.refiner_model))) - refine_net = refine_net.cuda() - refine_net = torch.nn.DataParallel(refine_net) - else: - refine_net.load_state_dict( - copyStateDict(torch.load(args.refiner_model, - map_location='cpu'))) - - refine_net.eval() - args.poly = True - - t = time.time() - - # load data - for k, image_path in enumerate(image_list): - print("Test image {:d}/{:d}: {:s}" - .format(k+1, len(image_list), image_path), end='\r') - image = imgproc.loadImage(image_path) - - bboxes, polys, score_text = test_net(net, - image, - args.text_threshold, - args.link_threshold, - args.low_text, - args.cuda, - args.poly, - refine_net) - - # save score text - filename, file_ext = os.path.splitext(os.path.basename(image_path)) - mask_file = result_folder + "/res_" + filename + '_mask.jpg' - cv2.imwrite(mask_file, score_text) - - file_utils.saveResult(image_path, - image[:, :, ::-1], - polys, - dirname=result_folder) - - print("elapsed time : {}s".format(time.time() - t))