diff --git a/README.md b/README.md index 0afd3eb..5e6ae93 100644 --- a/README.md +++ b/README.md @@ -85,7 +85,16 @@ print(f"elasp: {elasp}") # # 可视化 ocr 识别框 # plot_rec_box(img_path, f"{output_dir}/ocr_box.jpg", ocr_res) ``` - +#### 偏移修正 +```python +import cv2 +img_path = f'tests/test_files/wired/squeeze_error.jpeg' +from wired_table_rec.utils import ImageOrientationCorrector +img_orientation_corrector = ImageOrientationCorrector() +img = cv2.imread(img_path) +img = img_orientation_corrector(img) +cv2.imwrite(f'img_rotated.jpg', img) +``` ## FAQ (Frequently Asked Questions) 1. **问:偏移的图片能够处理吗?** @@ -101,7 +110,7 @@ print(f"elasp: {elasp}") ### TODO List -- [ ] 识别前图片偏移修正 +- [ ] 识别前图片偏移修正(完成有线表格小角度偏移修正) - [ ] 增加数据集数量,增加更多评测对比 - [ ] 优化无线表格模型 diff --git a/lineless_table_rec/main.py b/lineless_table_rec/main.py index 62a4e6b..cf0116b 100644 --- a/lineless_table_rec/main.py +++ b/lineless_table_rec/main.py @@ -119,10 +119,10 @@ def __call__( def transform_res( self, - cell_box_det_map: dict[int, List[any]], + cell_box_det_map: Dict[int, List[any]], polygons: np.ndarray, logi_points: List[np.ndarray], - ) -> List[dict[str, any]]: + ) -> List[Dict[str, any]]: res = [] for i in range(len(polygons)): ocr_res_list = cell_box_det_map.get(i) diff --git a/lineless_table_rec/utils_table_recover.py b/lineless_table_rec/utils_table_recover.py index 670364b..30921cb 100644 --- a/lineless_table_rec/utils_table_recover.py +++ b/lineless_table_rec/utils_table_recover.py @@ -3,7 +3,7 @@ # @Contact: liekkaskono@163.com import os import random -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Union, Set, Tuple import cv2 import numpy as np @@ -67,7 +67,7 @@ def compute_poly_iou(a: np.ndarray, b: np.ndarray) -> float: return float(inter_area) / union_area -def filter_duplicated_box(table_boxes: list[list[float]]) -> set[int]: +def filter_duplicated_box(table_boxes: List[List[float]]) -> Set[int]: """ :param table_boxes: [[xmin,ymin,xmax,ymax]] :return: @@ -95,7 
+95,9 @@ def filter_duplicated_box(table_boxes: list[list[float]]) -> set[int]: return delete_idx -def calculate_iou(box1: list | np.ndarray, box2: list | np.ndarray) -> float: +def calculate_iou( + box1: Union[np.ndarray, List], box2: Union[np.ndarray, List] +) -> float: """ :param box1: Iterable [xmin,ymin,xmax,ymax] :param box2: Iterable [xmin,ymin,xmax,ymax] @@ -127,7 +129,7 @@ def calculate_iou(box1: list | np.ndarray, box2: list | np.ndarray) -> float: def caculate_single_axis_iou( - box1: list | np.ndarray, box2: list | np.ndarray, axis="x" + box1: Union[np.ndarray, List], box2: Union[np.ndarray, List], axis="x" ) -> float: """ :param box1: Iterable [xmin,ymin,xmax,ymax] @@ -151,8 +153,8 @@ def caculate_single_axis_iou( def is_box_contained( - box1: list | np.ndarray, box2: list | np.ndarray, threshold=0.2 -) -> int | None: + box1: Union[np.ndarray, List], box2: Union[np.ndarray, List], threshold=0.2 +) -> Union[int, None]: """ :param box1: Iterable [xmin,ymin,xmax,ymax] :param box2: Iterable [xmin,ymin,xmax,ymax] @@ -195,8 +197,8 @@ def is_box_contained( def is_single_axis_contained( - box1: list | np.ndarray, box2: list | np.ndarray, axis="x", threhold=0.2 -) -> int | None: + box1: Union[np.ndarray, list], box2: Union[np.ndarray, list], axis="x", threhold=0.2 +) -> Union[int, None]: """ :param box1: Iterable [xmin,ymin,xmax,ymax] :param box2: Iterable [xmin,ymin,xmax,ymax] @@ -228,8 +230,8 @@ def is_single_axis_contained( def sorted_ocr_boxes( - dt_boxes: np.ndarray | list, threhold: float = 0.2 -) -> tuple[np.ndarray | list, list[int]]: + dt_boxes: Union[np.ndarray, List], threhold: float = 0.2 +) -> Tuple[Union[np.ndarray, list], List[int]]: """ Sort text boxes in order from top to bottom, left to right args: @@ -266,9 +268,7 @@ def sorted_ocr_boxes( return _boxes, indices -def gather_ocr_list_by_row( - ocr_list: list[list[list[float], str]], thehold: float = 0.2 -) -> list[list[list[float], str]]: +def gather_ocr_list_by_row(ocr_list: List[Any], 
thehold: float = 0.2) -> List[Any]: """ :param ocr_list: [[[xmin,ymin,xmax,ymax], text]] :return: @@ -305,12 +305,12 @@ def gather_ocr_list_by_row( return ocr_list -def box_4_1_poly_to_box_4_2(poly_box: list | np.ndarray) -> list[list[float]]: +def box_4_1_poly_to_box_4_2(poly_box: Union[np.ndarray, list]) -> List[List[float]]: xmin, ymin, xmax, ymax = tuple(poly_box) return [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]] -def box_4_2_poly_to_box_4_1(poly_box: list | np.ndarray) -> list[float]: +def box_4_2_poly_to_box_4_1(poly_box: Union[np.ndarray, list]) -> List[float]: """ 将poly_box转换为box_4_1 :param poly_box: @@ -407,7 +407,7 @@ def match_ocr_cell(dt_rec_boxes: List[List[Union[Any, str]]], pred_bboxes: np.nd def plot_html_table( - logi_points: np.ndarray | list, cell_box_map: Dict[int, List[str]] + logi_points: Union[np.ndarray, list], cell_box_map: Dict[int, List[str]] ) -> str: # 初始化最大行数和列数 max_row = 0 diff --git a/setup_table_cls.py b/setup_table_cls.py index 70df236..7a8e87d 100644 --- a/setup_table_cls.py +++ b/setup_table_cls.py @@ -1,12 +1,13 @@ # -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com +import sys from pathlib import Path from typing import List, Union import setuptools -# from get_pypi_latest_version import GetPyPiLatestVersion +from get_pypi_latest_version import GetPyPiLatestVersion def read_txt(txt_path: Union[Path, str]) -> List[str]: @@ -17,21 +18,20 @@ def read_txt(txt_path: Union[Path, str]) -> List[str]: MODULE_NAME = "table_cls" -# obtainer = GetPyPiLatestVersion() -# try: -# latest_version = obtainer(MODULE_NAME) -# except Exception: -# latest_version = "0.0.0" -# -# VERSION_NUM = obtainer.version_add_one(latest_version) -VERSION_NUM = "1.0.0" - -# if len(sys.argv) > 2: -# match_str = " ".join(sys.argv[2:]) -# matched_versions = obtainer.extract_version(match_str) -# if matched_versions: -# VERSION_NUM = matched_versions -# sys.argv = sys.argv[:2] +obtainer = GetPyPiLatestVersion() +try: + 
latest_version = obtainer(MODULE_NAME) +except Exception: + latest_version = "0.0.0" + +VERSION_NUM = obtainer.version_add_one(latest_version) + +if len(sys.argv) > 2: + match_str = " ".join(sys.argv[2:]) + matched_versions = obtainer.extract_version(match_str) + if matched_versions: + VERSION_NUM = matched_versions +sys.argv = sys.argv[:2] setuptools.setup( name=MODULE_NAME, diff --git a/tests/test_wired_table_rec.py b/tests/test_wired_table_rec.py index 37d46e0..b4f263a 100644 --- a/tests/test_wired_table_rec.py +++ b/tests/test_wired_table_rec.py @@ -41,7 +41,7 @@ def test_squeeze_bug(): ocr_result, _ = ocr_engine(img_path) table_str, *_ = table_recog(str(img_path), ocr_result) td_nums = get_td_nums(table_str) - assert td_nums == 228 + assert td_nums == 291 @pytest.mark.parametrize( diff --git a/wired_table_rec/main.py b/wired_table_rec/main.py index 0b0b784..c220f97 100644 --- a/wired_table_rec/main.py +++ b/wired_table_rec/main.py @@ -7,7 +7,7 @@ import time import traceback from pathlib import Path -from typing import List, Optional, Tuple, Union, Dict +from typing import List, Optional, Tuple, Union, Dict, Any import numpy as np import cv2 @@ -50,7 +50,7 @@ def __call__( self, img: InputType, ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None, - ) -> Tuple[str, float, list]: + ) -> Tuple[str, float, Any, Any, Any]: if self.ocr is None and ocr_result is None: raise ValueError( "One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed." 
@@ -109,10 +109,10 @@ def __call__( def transform_res( self, - cell_box_det_map: dict[int, List[any]], + cell_box_det_map: Dict[int, List[any]], polygons: np.ndarray, logi_points: List[np.ndarray], - ) -> List[dict[str, any]]: + ) -> List[Dict[str, any]]: res = [] for i in range(len(polygons)): ocr_res_list = cell_box_det_map.get(i) @@ -152,7 +152,7 @@ def re_rec( img: np.ndarray, sorted_polygons: np.ndarray, cell_box_map: Dict[int, List[str]], - ) -> Dict[int, List[any]]: + ) -> Dict[int, List[Any]]: """找到poly对应为空的框,尝试将直接将poly框直接送到识别中""" # for i in range(sorted_polygons.shape[0]): diff --git a/wired_table_rec/table_line_rec.py b/wired_table_rec/table_line_rec.py index 7be1004..2447066 100644 --- a/wired_table_rec/table_line_rec.py +++ b/wired_table_rec/table_line_rec.py @@ -48,7 +48,9 @@ def __call__(self, img: np.ndarray) -> Optional[np.ndarray]: [box_4_2_poly_to_box_4_1(box) for box in polygons] ) polygons = np.delete(polygons, list(del_idxs), axis=0) - _, idx = sorted_ocr_boxes([box_4_2_poly_to_box_4_1(box) for box in polygons]) + _, idx = sorted_ocr_boxes( + [box_4_2_poly_to_box_4_1(box) for box in polygons], threhold=0.4 + ) polygons = polygons[idx] polygons = merge_adjacent_polys(polygons) return polygons diff --git a/wired_table_rec/table_line_rec_plus.py b/wired_table_rec/table_line_rec_plus.py index 8880891..f4e2c77 100644 --- a/wired_table_rec/table_line_rec_plus.py +++ b/wired_table_rec/table_line_rec_plus.py @@ -44,7 +44,7 @@ def __call__(self, img: np.ndarray) -> Optional[np.ndarray]: polygons[:, 3, :].copy(), ) _, idx = sorted_ocr_boxes( - [box_4_2_poly_to_box_4_1(poly_box) for poly_box in polygons] + [box_4_2_poly_to_box_4_1(poly_box) for poly_box in polygons], threhold=0.4 ) polygons = polygons[idx] return polygons diff --git a/wired_table_rec/table_recover.py b/wired_table_rec/table_recover.py index be0502c..afb2c2d 100644 --- a/wired_table_rec/table_recover.py +++ b/wired_table_rec/table_recover.py @@ -115,38 +115,6 @@ def get_benchmark_rows( 
class ImageOrientationCorrector:
    """
    Correct a small skew of an image using the first Hough line found in it.

    The skew estimate is folded into the range [-45, 45] degrees, so only
    small-angle offsets are corrected; an image that is rotated by 90 degrees
    or more is NOT brought back to upright by this class.
    """

    def __init__(self):
        self.img_loader = LoadImage()

    def __call__(self, img: InputType):
        """
        Load ``img`` and return it rotated so the detected line is axis-aligned.

        :param img: anything accepted by ``LoadImage``.
        :return: the rotated image (same width/height as the loaded image), or
            the loaded image unchanged when no line is detected or the detected
            line is already horizontal/vertical.
        """
        img = self.img_loader(img)
        # NOTE(review): COLOR_BGR2GRAY assumes LoadImage yields a 3-channel BGR
        # array -- confirm against LoadImage.
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # Otsu binarisation before edge detection.
        gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
        edges = cv2.Canny(gray, 100, 250, apertureSize=3)
        # Hough transform, adapted from
        # https://blog.csdn.net/feilong_csdn/article/details/81586322
        lines = cv2.HoughLines(edges, 1, np.pi / 180, 0)
        # HoughLines yields None when nothing is detected (e.g. a blank image);
        # there is nothing to correct in that case.
        if lines is None or len(lines) == 0:
            return img

        # Only the first reported line is used (presumably the one with the
        # most votes -- verify against the OpenCV docs). Flattening to (N, 2)
        # makes this independent of the (N, 1, 2) vs (1, N, 2) output layout.
        rho, theta = np.asarray(lines).reshape(-1, 2)[0]

        # Convert the (rho, theta) normal form into two integer points that lie
        # 1000 px either side of the foot of the normal.
        a = np.cos(theta)
        b = np.sin(theta)
        x0 = a * rho
        y0 = b * rho
        x1 = int(x0 + 1000 * (-b))
        y1 = int(y0 + 1000 * (a))
        x2 = int(x0 - 1000 * (-b))
        y2 = int(y0 - 1000 * (a))

        # A vertical or horizontal line means the image is already aligned
        # (this also avoids dividing by zero below).
        if x1 == x2 or y1 == y2:
            return img

        slope = float(y2 - y1) / (x2 - x1)
        rotate_angle = math.degrees(math.atan(slope))
        # Fold into [-45, 45]: a near-vertical line is treated as a table
        # border that should become exactly vertical, not horizontal.
        if rotate_angle > 45:
            rotate_angle = -90 + rotate_angle
        elif rotate_angle < -45:
            rotate_angle = 90 + rotate_angle

        # Rotate about the image centre, keeping the original canvas size.
        (h, w) = img.shape[:2]
        center = (w // 2, h // 2)
        M = cv2.getRotationMatrix2D(center, rotate_angle, 1.0)
        return cv2.warpAffine(img, M, (w, h))
np.ndarray, box2: list | np.ndarray, axis="x", threhold: float = 0.2 -) -> int | None: + box1: Union[np.ndarray, List], + box2: Union[np.ndarray, List], + axis="x", + threhold: float = 0.2, +) -> Union[int, None]: """ :param box1: Iterable [xmin,ymin,xmax,ymax] :param box2: Iterable [xmin,ymin,xmax,ymax] @@ -168,7 +173,7 @@ def is_single_axis_contained( return None -def filter_duplicated_box(table_boxes: list[list[float]]) -> set[int]: +def filter_duplicated_box(table_boxes: List[List[float]]) -> Set[int]: """ :param table_boxes: [[xmin,ymin,xmax,ymax]] :return: @@ -197,8 +202,8 @@ def filter_duplicated_box(table_boxes: list[list[float]]) -> set[int]: def sorted_ocr_boxes( - dt_boxes: np.ndarray | list, threhold: float = 0.2 -) -> tuple[np.ndarray | list, list[int]]: + dt_boxes: Union[np.ndarray, list], threhold: float = 0.2 +) -> Tuple[Union[np.ndarray, list], List[int]]: """ Sort text boxes in order from top to bottom, left to right args: @@ -312,12 +317,12 @@ def plot_rec_box(img_path, output_path, sorted_polygons): cv2.imwrite(output_path, img) -def box_4_1_poly_to_box_4_2(poly_box: list | np.ndarray) -> list[list[float]]: +def box_4_1_poly_to_box_4_2(poly_box: Union[list, np.ndarray]) -> List[List[float]]: xmin, ymin, xmax, ymax = tuple(poly_box) return [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]] -def box_4_2_poly_to_box_4_1(poly_box: list | np.ndarray) -> list[float]: +def box_4_2_poly_to_box_4_1(poly_box: Union[list, np.ndarray]) -> List[Any]: """ 将poly_box转换为box_4_1 :param poly_box: @@ -357,9 +362,7 @@ def match_ocr_cell(dt_rec_boxes: List[List[Union[Any, str]]], pred_bboxes: np.nd return matched, not_match_orc_boxes -def gather_ocr_list_by_row( - ocr_list: list[list[list[float], str]], threhold: float = 0.2 -) -> list[list[list[float], str]]: +def gather_ocr_list_by_row(ocr_list: List[Any], threhold: float = 0.2) -> List[Any]: """ :param ocr_list: [[[xmin,ymin,xmax,ymax], text]] :return: @@ -555,7 +558,7 @@ def is_inclusive_each_other(box1: 
np.ndarray, box2: np.ndarray): def plot_html_table( - logi_points: np.ndarray | list, cell_box_map: Dict[int, List[str]] + logi_points: Union[Union[np.ndarray, List]], cell_box_map: Dict[int, List[str]] ) -> str: # 初始化最大行数和列数 max_row = 0