Skip to content

Commit

Permalink
Merge pull request #28 from RapidAI/use_adaptive_code
Browse files Browse the repository at this point in the history
adapt for py 3.8 & add text oritation for wired table
  • Loading branch information
Joker1212 authored Sep 19, 2024
2 parents b90689f + dcbc8c2 commit 5cbd8b2
Show file tree
Hide file tree
Showing 11 changed files with 119 additions and 92 deletions.
13 changes: 11 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,16 @@ print(f"elasp: {elasp}")
# # 可视化 ocr 识别框
# plot_rec_box(img_path, f"{output_dir}/ocr_box.jpg", ocr_res)
```

#### 偏移修正
```python
import cv2
img_path = f'tests/test_files/wired/squeeze_error.jpeg'
from wired_table_rec.utils import ImageOrientationCorrector
img_orientation_corrector = ImageOrientationCorrector()
img = cv2.imread(img_path)
img = img_orientation_corrector(img)
cv2.imwrite(f'img_rotated.jpg', img)
```
## FAQ (Frequently Asked Questions)

1. **问:偏移的图片能够处理吗?**
Expand All @@ -101,7 +110,7 @@ print(f"elasp: {elasp}")

### TODO List

- [ ] 识别前图片偏移修正
- [ ] 识别前图片偏移修正(完成有线表格小角度偏移修正)
- [ ] 增加数据集数量,增加更多评测对比
- [ ] 优化无线表格模型

Expand Down
4 changes: 2 additions & 2 deletions lineless_table_rec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,10 @@ def __call__(

def transform_res(
self,
cell_box_det_map: dict[int, List[any]],
cell_box_det_map: Dict[int, List[any]],
polygons: np.ndarray,
logi_points: List[np.ndarray],
) -> List[dict[str, any]]:
) -> List[Dict[str, any]]:
res = []
for i in range(len(polygons)):
ocr_res_list = cell_box_det_map.get(i)
Expand Down
32 changes: 16 additions & 16 deletions lineless_table_rec/utils_table_recover.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# @Contact: [email protected]
import os
import random
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Union, Set, Tuple

import cv2
import numpy as np
Expand Down Expand Up @@ -67,7 +67,7 @@ def compute_poly_iou(a: np.ndarray, b: np.ndarray) -> float:
return float(inter_area) / union_area


def filter_duplicated_box(table_boxes: list[list[float]]) -> set[int]:
def filter_duplicated_box(table_boxes: List[List[float]]) -> Set[int]:
"""
:param table_boxes: [[xmin,ymin,xmax,ymax]]
:return:
Expand Down Expand Up @@ -95,7 +95,9 @@ def filter_duplicated_box(table_boxes: list[list[float]]) -> set[int]:
return delete_idx


def calculate_iou(box1: list | np.ndarray, box2: list | np.ndarray) -> float:
def calculate_iou(
box1: Union[np.ndarray, List], box2: Union[np.ndarray, List]
) -> float:
"""
:param box1: Iterable [xmin,ymin,xmax,ymax]
:param box2: Iterable [xmin,ymin,xmax,ymax]
Expand Down Expand Up @@ -127,7 +129,7 @@ def calculate_iou(box1: list | np.ndarray, box2: list | np.ndarray) -> float:


def caculate_single_axis_iou(
box1: list | np.ndarray, box2: list | np.ndarray, axis="x"
box1: Union[np.ndarray, List], box2: Union[np.ndarray, List], axis="x"
) -> float:
"""
:param box1: Iterable [xmin,ymin,xmax,ymax]
Expand All @@ -151,8 +153,8 @@ def caculate_single_axis_iou(


def is_box_contained(
box1: list | np.ndarray, box2: list | np.ndarray, threshold=0.2
) -> int | None:
box1: Union[np.ndarray, List], box2: Union[np.ndarray, List], threshold=0.2
) -> Union[int, None]:
"""
:param box1: Iterable [xmin,ymin,xmax,ymax]
:param box2: Iterable [xmin,ymin,xmax,ymax]
Expand Down Expand Up @@ -195,8 +197,8 @@ def is_box_contained(


def is_single_axis_contained(
box1: list | np.ndarray, box2: list | np.ndarray, axis="x", threhold=0.2
) -> int | None:
box1: Union[np.ndarray, list], box2: Union[np.ndarray, list], axis="x", threhold=0.2
) -> Union[int, None]:
"""
:param box1: Iterable [xmin,ymin,xmax,ymax]
:param box2: Iterable [xmin,ymin,xmax,ymax]
Expand Down Expand Up @@ -228,8 +230,8 @@ def is_single_axis_contained(


def sorted_ocr_boxes(
dt_boxes: np.ndarray | list, threhold: float = 0.2
) -> tuple[np.ndarray | list, list[int]]:
dt_boxes: Union[np.ndarray, List], threhold: float = 0.2
) -> Tuple[Union[np.ndarray, list], List[int]]:
"""
Sort text boxes in order from top to bottom, left to right
args:
Expand Down Expand Up @@ -266,9 +268,7 @@ def sorted_ocr_boxes(
return _boxes, indices


def gather_ocr_list_by_row(
ocr_list: list[list[list[float], str]], thehold: float = 0.2
) -> list[list[list[float], str]]:
def gather_ocr_list_by_row(ocr_list: List[Any], thehold: float = 0.2) -> List[Any]:
"""
:param ocr_list: [[[xmin,ymin,xmax,ymax], text]]
:return:
Expand Down Expand Up @@ -305,12 +305,12 @@ def gather_ocr_list_by_row(
return ocr_list


def box_4_1_poly_to_box_4_2(poly_box: list | np.ndarray) -> list[list[float]]:
def box_4_1_poly_to_box_4_2(poly_box: Union[np.ndarray, list]) -> List[List[float]]:
xmin, ymin, xmax, ymax = tuple(poly_box)
return [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]


def box_4_2_poly_to_box_4_1(poly_box: list | np.ndarray) -> list[float]:
def box_4_2_poly_to_box_4_1(poly_box: Union[np.ndarray, list]) -> List[float]:
"""
将poly_box转换为box_4_1
:param poly_box:
Expand Down Expand Up @@ -407,7 +407,7 @@ def match_ocr_cell(dt_rec_boxes: List[List[Union[Any, str]]], pred_bboxes: np.nd


def plot_html_table(
logi_points: np.ndarray | list, cell_box_map: Dict[int, List[str]]
logi_points: Union[np.ndarray, list], cell_box_map: Dict[int, List[str]]
) -> str:
# 初始化最大行数和列数
max_row = 0
Expand Down
32 changes: 16 additions & 16 deletions setup_table_cls.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
import sys
from pathlib import Path
from typing import List, Union

import setuptools

# from get_pypi_latest_version import GetPyPiLatestVersion
from get_pypi_latest_version import GetPyPiLatestVersion


def read_txt(txt_path: Union[Path, str]) -> List[str]:
Expand All @@ -17,21 +18,20 @@ def read_txt(txt_path: Union[Path, str]) -> List[str]:

MODULE_NAME = "table_cls"

# obtainer = GetPyPiLatestVersion()
# try:
# latest_version = obtainer(MODULE_NAME)
# except Exception:
# latest_version = "0.0.0"
#
# VERSION_NUM = obtainer.version_add_one(latest_version)
VERSION_NUM = "1.0.0"

# if len(sys.argv) > 2:
# match_str = " ".join(sys.argv[2:])
# matched_versions = obtainer.extract_version(match_str)
# if matched_versions:
# VERSION_NUM = matched_versions
# sys.argv = sys.argv[:2]
obtainer = GetPyPiLatestVersion()
try:
latest_version = obtainer(MODULE_NAME)
except Exception:
latest_version = "0.0.0"

VERSION_NUM = obtainer.version_add_one(latest_version)

if len(sys.argv) > 2:
match_str = " ".join(sys.argv[2:])
matched_versions = obtainer.extract_version(match_str)
if matched_versions:
VERSION_NUM = matched_versions
sys.argv = sys.argv[:2]

setuptools.setup(
name=MODULE_NAME,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_wired_table_rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_squeeze_bug():
ocr_result, _ = ocr_engine(img_path)
table_str, *_ = table_recog(str(img_path), ocr_result)
td_nums = get_td_nums(table_str)
assert td_nums == 228
assert td_nums == 291


@pytest.mark.parametrize(
Expand Down
10 changes: 5 additions & 5 deletions wired_table_rec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import time
import traceback
from pathlib import Path
from typing import List, Optional, Tuple, Union, Dict
from typing import List, Optional, Tuple, Union, Dict, Any
import numpy as np
import cv2

Expand Down Expand Up @@ -50,7 +50,7 @@ def __call__(
self,
img: InputType,
ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None,
) -> Tuple[str, float, list]:
) -> Tuple[str, float, Any, Any, Any]:
if self.ocr is None and ocr_result is None:
raise ValueError(
"One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed."
Expand Down Expand Up @@ -109,10 +109,10 @@ def __call__(

def transform_res(
self,
cell_box_det_map: dict[int, List[any]],
cell_box_det_map: Dict[int, List[any]],
polygons: np.ndarray,
logi_points: List[np.ndarray],
) -> List[dict[str, any]]:
) -> List[Dict[str, any]]:
res = []
for i in range(len(polygons)):
ocr_res_list = cell_box_det_map.get(i)
Expand Down Expand Up @@ -152,7 +152,7 @@ def re_rec(
img: np.ndarray,
sorted_polygons: np.ndarray,
cell_box_map: Dict[int, List[str]],
) -> Dict[int, List[any]]:
) -> Dict[int, List[Any]]:
"""找到poly对应为空的框,尝试将直接将poly框直接送到识别中"""
#
for i in range(sorted_polygons.shape[0]):
Expand Down
4 changes: 3 additions & 1 deletion wired_table_rec/table_line_rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def __call__(self, img: np.ndarray) -> Optional[np.ndarray]:
[box_4_2_poly_to_box_4_1(box) for box in polygons]
)
polygons = np.delete(polygons, list(del_idxs), axis=0)
_, idx = sorted_ocr_boxes([box_4_2_poly_to_box_4_1(box) for box in polygons])
_, idx = sorted_ocr_boxes(
[box_4_2_poly_to_box_4_1(box) for box in polygons], threhold=0.4
)
polygons = polygons[idx]
polygons = merge_adjacent_polys(polygons)
return polygons
Expand Down
2 changes: 1 addition & 1 deletion wired_table_rec/table_line_rec_plus.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __call__(self, img: np.ndarray) -> Optional[np.ndarray]:
polygons[:, 3, :].copy(),
)
_, idx = sorted_ocr_boxes(
[box_4_2_poly_to_box_4_1(poly_box) for poly_box in polygons]
[box_4_2_poly_to_box_4_1(poly_box) for poly_box in polygons], threhold=0.4
)
polygons = polygons[idx]
return polygons
Expand Down
32 changes: 0 additions & 32 deletions wired_table_rec/table_recover.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,38 +115,6 @@ def get_benchmark_rows(
leftmost_cell_idxs = [v[0] for v in rows.values()]
benchmark_x = polygons[leftmost_cell_idxs][:, 0, 1]

theta = 15
# 遍历其他所有的框,按照y轴进行区间划分
range_res = {}
for cur_idx, cur_box in enumerate(polygons):
# fix cur_idx in benchmark_x
if cur_idx in leftmost_cell_idxs:
continue

cur_y = cur_box[0, 1]

start_idx, end_idx = None, None
for i, v in enumerate(benchmark_x):
if cur_y - theta <= v <= cur_y + theta:
break

if cur_y > v:
start_idx = i
continue

if cur_y < v:
end_idx = i
break

range_res[cur_idx] = [start_idx, end_idx]

sorted_res = dict(sorted(range_res.items(), key=lambda x: x[0], reverse=True))
for k, v in sorted_res.items():
if not all(v):
continue

benchmark_x = np.insert(benchmark_x, v[1], polygons[k][0, 1])

each_row_widths = (benchmark_x[1:] - benchmark_x[:-1]).tolist()

# 求出最后一行cell中,最大的高度作为最后一行的高度
Expand Down
45 changes: 45 additions & 0 deletions wired_table_rec/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- encoding: utf-8 -*-
import math
import traceback
from io import BytesIO
from pathlib import Path
Expand Down Expand Up @@ -350,3 +351,47 @@ def _scale_size(size, scale):
scale = (scale, scale)
w, h = size
return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)


class ImageOrientationCorrector:
"""
对图片小角度(-90 - + 90度进行修正)
"""

def __init__(self):
self.img_loader = LoadImage()

def __call__(self, img: InputType):
img = self.img_loader(img)
# 取灰度
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# 二值化
gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
# 边缘检测
edges = cv2.Canny(gray, 100, 250, apertureSize=3)
# 霍夫变换,摘自https://blog.csdn.net/feilong_csdn/article/details/81586322
lines = cv2.HoughLines(edges, 1, np.pi / 180, 0)
for rho, theta in lines[0]:
a = np.cos(theta)
b = np.sin(theta)
x0 = a * rho
y0 = b * rho
x1 = int(x0 + 1000 * (-b))
y1 = int(y0 + 1000 * (a))
x2 = int(x0 - 1000 * (-b))
y2 = int(y0 - 1000 * (a))
if x1 == x2 or y1 == y2:
return img
else:
t = float(y2 - y1) / (x2 - x1)
# 得到角度后
rotate_angle = math.degrees(math.atan(t))
if rotate_angle > 45:
rotate_angle = -90 + rotate_angle
elif rotate_angle < -45:
rotate_angle = 90 + rotate_angle
# 旋转图像
(h, w) = img.shape[:2]
center = (w // 2, h // 2)
M = cv2.getRotationMatrix2D(center, rotate_angle, 1.0)
return cv2.warpAffine(img, M, (w, h))
Loading

0 comments on commit 5cbd8b2

Please sign in to comment.