From 97b685ff187a7f76c2deee0f9c265eb4b178f4a5 Mon Sep 17 00:00:00 2001 From: Joker1212 <519548295@qq.com> Date: Mon, 30 Dec 2024 17:07:04 +0800 Subject: [PATCH 1/3] feat: optimize torch inference --- .github/workflows/publish_whl.yml | 4 + .github/workflows/publish_whl_torch.yml | 59 ----- README.md | 64 +---- demo_torch.py | 50 ++-- .../download_model.py | 24 +- {rapid_table_torch => rapid_table}/logger.py | 0 rapid_table/main.py | 29 +- rapid_table/params.py | 40 +++ rapid_table/table_structure/__init__.py | 1 + .../table_structure_unitable.py | 13 +- .../table_structure/unitable_modules.py | 0 rapid_table/utils.py | 1 + rapid_table_torch/__init__.py | 2 - rapid_table_torch/main.py | 155 ----------- rapid_table_torch/table_matcher/__init__.py | 4 - rapid_table_torch/table_matcher/matcher.py | 192 -------------- rapid_table_torch/table_matcher/utils.py | 248 ------------------ rapid_table_torch/table_structure/__init__.py | 1 - rapid_table_torch/table_structure/utils.py | 28 -- rapid_table_torch/utils.py | 210 --------------- requirements_torch.txt | 8 - setup.py | 1 + setup_torch.py | 69 ----- tests/test_table_torch.py | 10 +- 24 files changed, 130 insertions(+), 1083 deletions(-) delete mode 100644 .github/workflows/publish_whl_torch.yml rename {rapid_table_torch => rapid_table}/download_model.py (70%) rename {rapid_table_torch => rapid_table}/logger.py (100%) create mode 100644 rapid_table/params.py rename rapid_table_torch/table_structure/table_structure.py => rapid_table/table_structure/table_structure_unitable.py (93%) rename rapid_table_torch/table_structure/components.py => rapid_table/table_structure/unitable_modules.py (100%) delete mode 100644 rapid_table_torch/__init__.py delete mode 100644 rapid_table_torch/main.py delete mode 100644 rapid_table_torch/table_matcher/__init__.py delete mode 100644 rapid_table_torch/table_matcher/matcher.py delete mode 100644 rapid_table_torch/table_matcher/utils.py delete mode 100644 rapid_table_torch/table_structure/__init__.py delete mode 100644 rapid_table_torch/table_structure/utils.py delete mode 100644 rapid_table_torch/utils.py delete mode 100644 requirements_torch.txt delete mode 100644 setup_torch.py diff --git a/.github/workflows/publish_whl.yml b/.github/workflows/publish_whl.yml index 99e137c..7e135b3 100644 --- a/.github/workflows/publish_whl.yml +++ b/.github/workflows/publish_whl.yml @@ -34,8 +34,12 @@ jobs: pip install -r requirements.txt pip install rapidocr_onnxruntime + pip install torch + pip install torchvision + pip install tokenizers pip install pytest pytest tests/test_table.py + pytest tests/test_table_torch.py GenerateWHL_PushPyPi: needs: UnitTesting diff --git a/.github/workflows/publish_whl_torch.yml b/.github/workflows/publish_whl_torch.yml deleted file mode 100644 index 09bee91..0000000 --- a/.github/workflows/publish_whl_torch.yml +++ /dev/null @@ -1,59 +0,0 @@ -name: Push rapidocr_table_torch to pypi - -on: - push: - tags: - - torch_v* - -env: - RESOURCES_URL: https://github.com/RapidAI/RapidTable/releases/download/assets/unitable.zip - - -jobs: - UnitTesting: - runs-on: ubuntu-latest - steps: - - name: Pull latest code - uses: actions/checkout@v3 - - - name: Set up Python 3.10 - uses: actions/setup-python@v4 - with: - python-version: '3.10' - architecture: 'x64' - - - name: Display Python version - run: python -c "import sys; print(sys.version)" - - - name: Unit testings - run: | - pip install -r requirements_torch.txt - pip install rapidocr_onnxruntime - pip install pytest - pytest tests/test_table_torch.py - - GenerateWHL_PushPyPi: - needs: UnitTesting - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - - - name: Set up Python 3.10 - uses: actions/setup-python@v4 - with: - python-version: '3.10' - architecture: 'x64' - - - name: Run setup_torch.py - run: | - pip install -r requirements_torch.txt - python -m pip install --upgrade pip - pip install wheel get_pypi_latest_version - python setup_torch.py bdist_wheel ${{ github.ref_name }} - - - name: Publish distribution 📦 to PyPI - uses: pypa/gh-action-pypi-publish@v1.5.0 - with: - password: ${{ secrets.RAPID_TABLE }} - packages_dir: dist/ diff --git a/README.md b/README.md index a621571..89b0f59 100644 --- a/README.md +++ b/README.md @@ -105,8 +105,8 @@ RapidTable是整理自PP-Structure中表格识别部分而来。由于PP-Structu ```bash pip install rapidocr_onnxruntime pip install rapid_table -#pip install rapid_table_torch # for unitable inference -#pip install onnxruntime-gpu # for gpu inference +#pip install rapid_table[torch] # for unitable inference +#pip install onnxruntime-gpu # for onnx gpu inference ``` ### 使用方式 @@ -117,6 +117,7 @@ RapidTable类提供model_path参数,可以自行指定上述2个模型,默 ```python table_engine = RapidTable() +# table_engine = RapidTable(use_cuda=True, device="cuda:0", model_type="unitable") ``` 完整示例: @@ -132,6 +133,8 @@ from rapid_table.table_structure.utils import trans_char_ocr_res table_engine = RapidTable() # 开启onnx-gpu推理 # table_engine = RapidTable(use_cuda=True) +# 使用torch推理版本的unitable模型 +# table_engine = RapidTable(use_cuda=True, device="cuda:0", model_type="unitable") ocr_engine = RapidOCR() viser = VisTable() @@ -159,41 +162,8 @@ viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_p print(table_html_str) ``` -#### torch版本 -```python -from pathlib import Path -from rapidocr_onnxruntime import RapidOCR - -from rapid_table_torch import RapidTable, VisTable -from rapid_table_torch.table_structure.utils import trans_char_ocr_res - -if __name__ == '__main__': -# Init -ocr_engine = RapidOCR() -table_engine = RapidTable(device="cpu") # 默认使用cpu,若使用cuda,则传入device="cuda:0" -viser = VisTable() -img_path = "tests/test_files/image34.png" -# OCR,本模型检测框比较精准,配合单字匹配效果更好 -ocr_result, _ = ocr_engine(img_path, return_word_box=True) -ocr_result = trans_char_ocr_res(ocr_result) -boxes, txts, scores = list(zip(*ocr_result)) -# Save -save_dir = Path("outputs") -save_dir.mkdir(parents=True, exist_ok=True) - -save_html_path = save_dir / f"{Path(img_path).stem}.html" -save_drawed_path = save_dir / f"{Path(img_path).stem}_table_vis{Path(img_path).suffix}" -# 返回逻辑坐标 -table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, ocr_result) -save_logic_path = save_dir / f"vis_logic_{Path(img_path).name}" -vis_imged = viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path, logic_points, - save_logic_path) -print(f"elapse:{elapse}") -``` - #### 终端运行 -##### onnx: - 用法: ```bash @@ -214,29 +184,7 @@ print(f"elapse:{elapse}") ```bash rapid_table -v -img test_images/table.jpg ``` - -##### pytorch: -- 用法: - - ```bash - $ rapid_table_torch -h - usage: rapid_table_torch [-h] [-v] -img IMG_PATH [-d DEVICE] - - optional arguments: - -h, --help show this help message and exit - -v, --vis Whether to visualize the layout results. - -img IMG_PATH, --img_path IMG_PATH - Path to image for layout. - -d DEVICE, --device device - The model device used for inference. - ``` - -- 示例: - - ```bash - rapid_table_torch -v -img test_images/table.jpg - ``` - + ### 结果 #### 返回结果 diff --git a/demo_torch.py b/demo_torch.py index e8dc1e0..72a2d44 100644 --- a/demo_torch.py +++ b/demo_torch.py @@ -4,31 +4,29 @@ from pathlib import Path from rapidocr_onnxruntime import RapidOCR -from rapid_table_torch import RapidTable, VisTable -from rapid_table_torch.table_structure.utils import trans_char_ocr_res +from rapid_table import RapidTable, VisTable +from rapid_table.table_structure.utils import trans_char_ocr_res -if __name__ == '__main__': - # Init - ocr_engine = RapidOCR() - table_engine = RapidTable(encoder_path="rapid_table_torch/models/encoder.pth", - decoder_path="rapid_table_torch/models/decoder.pth", - vocab_path="rapid_table_torch/models/vocab.json", - device="cpu") - viser = VisTable() - img_path = "tests/test_files/image34.png" - # OCR - ocr_result, _ = ocr_engine(img_path, return_word_box=True) - ocr_result = trans_char_ocr_res(ocr_result) - boxes, txts, scores = list(zip(*ocr_result)) - # Save - save_dir = Path("outputs") - save_dir.mkdir(parents=True, exist_ok=True) +ocr_engine = RapidOCR() +table_engine = RapidTable(use_cuda=True, device="cuda:0", model_type="unitable") +viser = VisTable() +img_path = "tests/test_files/table.jpg" +# OCR +ocr_result, _ = ocr_engine(img_path, return_word_box=True) +ocr_result = trans_char_ocr_res(ocr_result) +boxes, txts, scores = list(zip(*ocr_result)) +# Save +save_dir = Path("outputs") +save_dir.mkdir(parents=True, exist_ok=True) - save_html_path = save_dir / f"{Path(img_path).stem}.html" - save_drawed_path = save_dir / f"{Path(img_path).stem}_table_vis{Path(img_path).suffix}" - # 返回逻辑坐标 - table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, ocr_result) - save_logic_path = save_dir / f"vis_logic_{Path(img_path).name}" - vis_imged = viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path, logic_points, - save_logic_path) - print(f"elapse:{elapse}") +save_html_path = save_dir / f"{Path(img_path).stem}.html" +save_drawed_path = save_dir / f"{Path(img_path).stem}_table_vis{Path(img_path).suffix}" + +table_html_str, table_cell_bboxes, elapse = table_engine(img_path, ocr_result) +viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path) +# 返回逻辑坐标 +# table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, ocr_result, return_logic_points=True) +# save_logic_path = save_dir / f"vis_logic_{Path(img_path).name}" +# vis_imged = viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path, logic_points, +# save_logic_path) +print(f"elapse:{elapse}") diff --git a/rapid_table_torch/download_model.py b/rapid_table/download_model.py similarity index 70% rename from rapid_table_torch/download_model.py rename to rapid_table/download_model.py index d636f94..9e9c496 100644 --- a/rapid_table_torch/download_model.py +++ b/rapid_table/download_model.py @@ -10,11 +10,33 @@ logger = get_logger("DownloadModel") CUR_DIR = Path(__file__).resolve() PROJECT_DIR = CUR_DIR.parent - +ROOT_URL = "https://www.modelscope.cn/studio/jockerK/TableRec/resolve/master/models/table_rec/unitable/" +KEY_TO_MODEL_URL = { + "unitable": { + "encoder": f"{ROOT_URL}/encoder.pth", + "decoder": f"{ROOT_URL}/decoder.pth", + "vocab": f"{ROOT_URL}/vocab.json", + } +} class DownloadModel: cur_dir = PROJECT_DIR + @staticmethod + def get_model_path(model_type: str, sub_file_type: str, path: Union[str, Path, None]) -> str: + if path is not None: + return path + + model_url = KEY_TO_MODEL_URL.get(model_type, {}).get(sub_file_type, None) + if model_url: + model_path = DownloadModel.download(model_url) + return model_path + + logger.info( + "model url is None, using the default download model %s", path + ) + return path + @classmethod def download(cls, model_full_url: Union[str, Path]) -> str: save_dir = cls.cur_dir / "models" diff --git a/rapid_table_torch/logger.py b/rapid_table/logger.py similarity index 100% rename from rapid_table_torch/logger.py rename to rapid_table/logger.py diff --git a/rapid_table/main.py b/rapid_table/main.py index d620fe5..202f6e0 100644 --- a/rapid_table/main.py +++ b/rapid_table/main.py @@ -5,33 +5,34 @@ import copy import importlib import time +from dataclasses import asdict from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import List, Tuple, Union import cv2 import numpy as np +from .download_model import DownloadModel +from .params import accept_kwargs_as_dataclass, BaseConfig from .table_matcher import TableMatch -from .table_structure import TableStructurer +from .table_structure import TableStructurer, TableStructureUnitable from .utils import LoadImage, VisTable root_dir = Path(__file__).resolve().parent class RapidTable: - def __init__(self, model_path: Optional[str] = None, model_type: str = None, use_cuda: bool = False): - if model_path is None: - model_path = str( - root_dir / "models" / "slanet-plus.onnx" - ) - model_type = "slanet-plus" - self.model_type = model_type + @accept_kwargs_as_dataclass(BaseConfig) + def __init__(self, config: BaseConfig): + self.model_type = config.model_type self.load_img = LoadImage() - config = { - "model_path": model_path, - "use_cuda": use_cuda, - } - self.table_structure = TableStructurer(config) + if self.model_type == "unitable": + config.encoder_path = DownloadModel.get_model_path(self.model_type, "encoder", config.encoder_path) + config.decoder_path = DownloadModel.get_model_path(self.model_type, "decoder", config.decoder_path) + config.vocab_path = DownloadModel.get_model_path(self.model_type, "vocab", config.vocab_path) + self.table_structure = TableStructureUnitable(asdict(config)) + else: + self.table_structure = TableStructurer(asdict(config)) self.table_matcher = TableMatch() try: diff --git a/rapid_table/params.py b/rapid_table/params.py new file mode 100644 index 0000000..c5756aa --- /dev/null +++ b/rapid_table/params.py @@ -0,0 +1,40 @@ +from dataclasses import dataclass, fields +from functools import wraps +from pathlib import Path + +from rapid_table.logger import get_logger + +root_dir = Path(__file__).resolve().parent +logger = get_logger("params") + +@dataclass +class BaseConfig: + model_type: str = "slanet-plus" + model_path: str = str(root_dir / "models" / "slanet-plus.onnx") + use_cuda: bool = False + device: str = "cpu" + encoder_path: str = None + decoder_path: str = None + vocab_path: str = None + + +def accept_kwargs_as_dataclass(cls): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + if len(args) == 2 and isinstance(args[1], cls): + # 如果已经传递了 ModelConfig 实例,直接调用函数 + return func(*args, **kwargs) + else: + # 提取 cls 中定义的字段 + cls_fields = {field.name for field in fields(cls)} + # 过滤掉未定义的字段 + filtered_kwargs = {k: v for k, v in kwargs.items() if k in cls_fields} + # 发出警告对于未定义的字段 + for k in (kwargs.keys() - cls_fields): + logger.warning(f"Warning: '{k}' is not a valid field in {cls.__name__} and will be ignored.") + # 创建 ModelConfig 实例并调用函数 + config = cls(**filtered_kwargs) + return func(args[0], config=config) + return wrapper + return decorator \ No newline at end of file diff --git a/rapid_table/table_structure/__init__.py b/rapid_table/table_structure/__init__.py index a548391..3e638d9 100644 --- a/rapid_table/table_structure/__init__.py +++ b/rapid_table/table_structure/__init__.py @@ -12,3 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. from .table_structure import TableStructurer +from .table_structure_unitable import TableStructureUnitable diff --git a/rapid_table_torch/table_structure/table_structure.py b/rapid_table/table_structure/table_structure_unitable.py similarity index 93% rename from rapid_table_torch/table_structure/table_structure.py rename to rapid_table/table_structure/table_structure_unitable.py index 32f3fa0..da9ed7f 100644 --- a/rapid_table_torch/table_structure/table_structure.py +++ b/rapid_table/table_structure/table_structure_unitable.py @@ -1,5 +1,6 @@ import re import time +from typing import Dict, Any import cv2 import numpy as np @@ -7,7 +8,7 @@ from PIL import Image from tokenizers import Tokenizer -from .components import Encoder, GPTFastDecoder +from .unitable_modules import Encoder, GPTFastDecoder from torchvision import transforms IMG_SIZE = 448 @@ -77,8 +78,14 @@ ] -class TableStructurer: - def __init__(self, encoder_path: str, decoder_path: str, vocab_path: str, device: str): +class TableStructureUnitable: + def __init__(self, config:Dict[str, Any]): + # encoder_path: str, decoder_path: str, vocab_path: str, device: str + vocab_path = config["vocab_path"] + encoder_path = config["encoder_path"] + decoder_path = config["decoder_path"] + device = config.get("device", "cuda:0") if config["use_cuda"] else "cpu" + self.vocab = Tokenizer.from_file(vocab_path) self.token_white_list = [self.vocab.token_to_id(i) for i in VALID_HTML_BBOX_TOKENS] self.bbox_token_ids = set([self.vocab.token_to_id(i) for i in BBOX_TOKENS]) diff --git a/rapid_table_torch/table_structure/components.py b/rapid_table/table_structure/unitable_modules.py similarity index 100% rename from rapid_table_torch/table_structure/components.py rename to rapid_table/table_structure/unitable_modules.py diff --git a/rapid_table/utils.py b/rapid_table/utils.py index e7135c5..77224ce 100644 --- a/rapid_table/utils.py +++ b/rapid_table/utils.py @@ -208,3 +208,4 @@ def save_img(save_path: Union[str, Path], img: np.ndarray): def save_html(save_path: Union[str, Path], html: str): with open(save_path, "w", encoding="utf-8") as f: f.write(html) + diff --git a/rapid_table_torch/__init__.py b/rapid_table_torch/__init__.py deleted file mode 100644 index 86c7b1c..0000000 --- a/rapid_table_torch/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .main import RapidTable -from .utils import VisTable diff --git a/rapid_table_torch/main.py b/rapid_table_torch/main.py deleted file mode 100644 index 5ef3dcc..0000000 --- a/rapid_table_torch/main.py +++ /dev/null @@ -1,155 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Author: SWHL -# @Contact: liekkaskono@163.com -import argparse -import copy -import importlib -import os -import time -from pathlib import Path -from typing import List, Union - -import cv2 -import numpy as np - -from .download_model import DownloadModel -from .logger import get_logger -from .table_matcher import TableMatch -from .table_structure import TableStructurer -from .utils import LoadImage, VisTable - -root_dir = Path(__file__).resolve().parent -model_dir = os.path.join(root_dir, "models") -logger = get_logger("rapid_table_torch") -default_config = os.path.join(root_dir, "config.yaml") -ROOT_URL = "https://www.modelscope.cn/studio/jockerK/TableRec/resolve/master/models/table_rec/unitable/" -KEY_TO_MODEL_URL = { - "unitable": { - "encoder": f"{ROOT_URL}/encoder.pth", - "decoder": f"{ROOT_URL}/decoder.pth", - "vocab": f"{ROOT_URL}/vocab.json", - } -} - - -class RapidTable: - def __init__(self, encoder_path: str = None, decoder_path: str = None, vocab_path: str = None, - model_type: str = "unitable", - device: str = "cpu"): - self.model_type = model_type - self.load_img = LoadImage() - encoder_path = self.get_model_path(model_type, "encoder", encoder_path) - decoder_path = self.get_model_path(model_type, "decoder", decoder_path) - vocab_path = self.get_model_path(model_type, "vocab", vocab_path) - self.table_structure = TableStructurer(encoder_path, decoder_path, vocab_path, device) - self.table_matcher = TableMatch() - try: - self.ocr_engine = importlib.import_module("rapidocr_onnxruntime").RapidOCR() - except ModuleNotFoundError: - self.ocr_engine = None - - def __call__( - self, - img_content: Union[str, np.ndarray, bytes, Path], - ocr_result: List[Union[List[List[float]], str, str]] = None - ): - if self.ocr_engine is None and ocr_result is None: - raise ValueError( - "One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed." - ) - - img = self.load_img(img_content) - - s = time.time() - h, w = img.shape[:2] - if ocr_result is None: - ocr_result, _ = self.ocr_engine(img) - dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w) - - pred_structures, pred_bboxes, _ = self.table_structure(copy.deepcopy(img)) - pred_html = self.table_matcher(pred_structures, pred_bboxes, dt_boxes, rec_res) - logic_points = self.table_matcher.decode_logic_points(pred_structures) - elapse = time.time() - s - return pred_html, pred_bboxes, logic_points, elapse - - def get_boxes_recs( - self, ocr_result: List[Union[List[List[float]], str, str]], h: int, w: int - ): - dt_boxes, rec_res, scores = list(zip(*ocr_result)) - rec_res = list(zip(rec_res, scores)) - - r_boxes = [] - for box in dt_boxes: - box = np.array(box) - x_min = max(0, box[:, 0].min() - 1) - x_max = min(w, box[:, 0].max() + 1) - y_min = max(0, box[:, 1].min() - 1) - y_max = min(h, box[:, 1].max() + 1) - box = [x_min, y_min, x_max, y_max] - r_boxes.append(box) - dt_boxes = np.array(r_boxes) - return dt_boxes, rec_res - - @staticmethod - def get_model_path(model_type: str, sub_file_type: str, path: Union[str, Path, None]) -> str: - if path is not None: - return path - - model_url = KEY_TO_MODEL_URL.get(model_type, {}).get(sub_file_type, None) - if model_url: - model_path = DownloadModel.download(model_url) - return model_path - - logger.info( - "model url is None, using the default download model %s", path - ) - return path -def main(): - parser = argparse.ArgumentParser() - parser.add_argument( - "-v", - "--vis", - action="store_true", - help="Wheter to visualize the layout results.", - ) - parser.add_argument( - "-img", "--img_path", type=str, required=True, help="Path to image for layout." - ) - parser.add_argument( - "-d", "--device", type=str, default="cpu", help="device to use" - ) - args = parser.parse_args() - - try: - ocr_engine = importlib.import_module("rapidocr_onnxruntime").RapidOCR() - except ModuleNotFoundError as exc: - raise ModuleNotFoundError( - "Please install the rapidocr_onnxruntime by pip install rapidocr_onnxruntime." - ) from exc - - rapid_table = RapidTable(device=args.device) - - img = cv2.imread(args.img_path) - - ocr_result, _ = ocr_engine(img) - table_html_str, table_cell_bboxes, elapse = rapid_table(img, ocr_result) - print(table_html_str) - - viser = VisTable() - if args.vis: - img_path = Path(args.img_path) - - save_dir = img_path.resolve().parent - save_html_path = save_dir / f"{Path(img_path).stem}.html" - save_drawed_path = save_dir / f"vis_{Path(img_path).name}" - viser( - img_path, - table_html_str, - save_html_path, - table_cell_bboxes, - save_drawed_path, - ) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/rapid_table_torch/table_matcher/__init__.py b/rapid_table_torch/table_matcher/__init__.py deleted file mode 100644 index 9bff7d7..0000000 --- a/rapid_table_torch/table_matcher/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Author: SWHL -# @Contact: liekkaskono@163.com -from .matcher import TableMatch diff --git a/rapid_table_torch/table_matcher/matcher.py b/rapid_table_torch/table_matcher/matcher.py deleted file mode 100644 index 0453929..0000000 --- a/rapid_table_torch/table_matcher/matcher.py +++ /dev/null @@ -1,192 +0,0 @@ -# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -*- encoding: utf-8 -*- -import numpy as np - -from .utils import compute_iou, distance - - -class TableMatch: - def __init__(self, filter_ocr_result=True, use_master=False): - self.filter_ocr_result = filter_ocr_result - self.use_master = use_master - - def __call__(self, pred_structures, pred_bboxes, dt_boxes, rec_res): - if self.filter_ocr_result: - dt_boxes, rec_res = self._filter_ocr_result(pred_bboxes, dt_boxes, rec_res) - matched_index = self.match_result(dt_boxes, pred_bboxes) - pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res) - return pred_html - - def match_result(self, dt_boxes, pred_bboxes): - matched = {} - for i, gt_box in enumerate(dt_boxes): - distances = [] - for j, pred_box in enumerate(pred_bboxes): - if len(pred_box) == 8: - pred_box = [ - np.min(pred_box[0::2]), - np.min(pred_box[1::2]), - np.max(pred_box[0::2]), - np.max(pred_box[1::2]), - ] - distances.append( - (distance(gt_box, pred_box), 1.0 - compute_iou(gt_box, pred_box)) - ) # compute iou and l1 distance - sorted_distances = distances.copy() - # select det box by iou and l1 distance - sorted_distances = sorted( - sorted_distances, key=lambda item: (item[1], item[0]) - ) - if distances.index(sorted_distances[0]) not in matched.keys(): - matched[distances.index(sorted_distances[0])] = [i] - else: - matched[distances.index(sorted_distances[0])].append(i) - return matched - - def get_pred_html(self, pred_structures, matched_index, ocr_contents): - end_html = [] - td_index = 0 - for tag in pred_structures: - if "" not in tag: - end_html.append(tag) - continue - - if "" == tag: - end_html.extend("") - - if td_index in matched_index.keys(): - b_with = False - if ( - "" in ocr_contents[matched_index[td_index][0]] - and len(matched_index[td_index]) > 1 - ): - b_with = True - end_html.extend("") - - for i, td_index_index in enumerate(matched_index[td_index]): - content = ocr_contents[td_index_index][0] - if len(matched_index[td_index]) > 1: - if len(content) == 0: - continue - - if content[0] == " ": - content = content[1:] - - if "" in content: - content = content[3:] - - if "" in content: - content = content[:-4] - - if len(content) == 0: - continue - - if i != len(matched_index[td_index]) - 1 and " " != content[-1]: - content += " " - end_html.extend(content) - - if b_with: - end_html.extend("") - - if "" == tag: - end_html.append("") - else: - end_html.append(tag) - - td_index += 1 - - # Filter elements - filter_elements = ["", "", "", ""] - end_html = [v for v in end_html if v not in filter_elements] - return "".join(end_html), end_html - def decode_logic_points(self, pred_structures): - logic_points = [] - current_row = 0 - current_col = 0 - max_rows = 0 - max_cols = 0 - occupied_cells = {} # 用于记录已经被占用的单元格 - - def is_occupied(row, col): - return (row, col) in occupied_cells - - def mark_occupied(row, col, rowspan, colspan): - for r in range(row, row + rowspan): - for c in range(col, col + colspan): - occupied_cells[(r, c)] = True - - i = 0 - while i < len(pred_structures): - token = pred_structures[i] - - if token == '': - current_col = 0 # 每次遇到 时,重置当前列号 - elif token == '': - current_row += 1 # 行结束,行号增加 - elif token .startswith(''): - if 'colspan=' in pred_structures[j]: - colspan = int(pred_structures[j].split('=')[1].strip('"\'')) - elif 'rowspan=' in pred_structures[j]: - rowspan = int(pred_structures[j].split('=')[1].strip('"\'')) - j += 1 - - # 跳过已经处理过的属性 token - i = j - - # 找到下一个未被占用的列 - while is_occupied(current_row, current_col): - current_col += 1 - - # 计算逻辑坐标 - r_start = current_row - r_end = current_row + rowspan - 1 - col_start = current_col - col_end = current_col + colspan - 1 - - # 记录逻辑坐标 - logic_points.append([r_start, r_end, col_start, col_end]) - - # 标记占用的单元格 - mark_occupied(r_start, col_start, rowspan, colspan) - - # 更新当前列号 - current_col += colspan - - # 更新最大行数和列数 - max_rows = max(max_rows, r_end + 1) - max_cols = max(max_cols, col_end + 1) - - i += 1 - - return logic_points - - def _filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res): - y1 = pred_bboxes[:, 1::2].min() - new_dt_boxes = [] - new_rec_res = [] - - for box, rec in zip(dt_boxes, rec_res): - if np.max(box[1::2]) < y1: - continue - new_dt_boxes.append(box) - new_rec_res.append(rec) - return new_dt_boxes, new_rec_res diff --git a/rapid_table_torch/table_matcher/utils.py b/rapid_table_torch/table_matcher/utils.py deleted file mode 100644 index 3ec8fcc..0000000 --- a/rapid_table_torch/table_matcher/utils.py +++ /dev/null @@ -1,248 +0,0 @@ -# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -*- encoding: utf-8 -*- -# @Author: SWHL -# @Contact: liekkaskono@163.com -import re - - -def deal_isolate_span(thead_part): - """ - Deal with isolate span cases in this function. - It causes by wrong prediction in structure recognition model. - eg. predict to rowspan="2">. - :param thead_part: - :return: - """ - # 1. find out isolate span tokens. - isolate_pattern = ( - ' rowspan="(\d)+" colspan="(\d)+">|' - ' colspan="(\d)+" rowspan="(\d)+">|' - ' rowspan="(\d)+">|' - ' colspan="(\d)+">' - ) - isolate_iter = re.finditer(isolate_pattern, thead_part) - isolate_list = [i.group() for i in isolate_iter] - - # 2. find out span number, by step 1 results. - span_pattern = ( - ' rowspan="(\d)+" colspan="(\d)+"|' - ' colspan="(\d)+" rowspan="(\d)+"|' - ' rowspan="(\d)+"|' - ' colspan="(\d)+"' - ) - corrected_list = [] - for isolate_item in isolate_list: - span_part = re.search(span_pattern, isolate_item) - spanStr_in_isolateItem = span_part.group() - # 3. merge the span number into the span token format string. - if spanStr_in_isolateItem is not None: - corrected_item = "".format(spanStr_in_isolateItem) - corrected_list.append(corrected_item) - else: - corrected_list.append(None) - - # 4. replace original isolated token. - for corrected_item, isolate_item in zip(corrected_list, isolate_list): - if corrected_item is not None: - thead_part = thead_part.replace(isolate_item, corrected_item) - else: - pass - return thead_part - - -def deal_duplicate_bb(thead_part): - """ - Deal duplicate or after replace. - Keep one in a token. - :param thead_part: - :return: - """ - # 1. find out in . - td_pattern = ( - '(.+?)|' - '(.+?)|' - '(.+?)|' - '(.+?)|' - "(.*?)" - ) - td_iter = re.finditer(td_pattern, thead_part) - td_list = [t.group() for t in td_iter] - - # 2. is multiply in or not? - new_td_list = [] - for td_item in td_list: - if td_item.count("") > 1 or td_item.count("") > 1: - # multiply in case. - # 1. remove all - td_item = td_item.replace("", "").replace("", "") - # 2. replace -> , -> . - td_item = td_item.replace("", "").replace("", "") - new_td_list.append(td_item) - else: - new_td_list.append(td_item) - - # 3. replace original thead part. - for td_item, new_td_item in zip(td_list, new_td_list): - thead_part = thead_part.replace(td_item, new_td_item) - return thead_part - - -def deal_bb(result_token): - """ - In our opinion, always occurs in text's context. - This function will find out all tokens in and insert by manual. - :param result_token: - :return: - """ - # find out parts. - thead_pattern = "(.*?)" - if re.search(thead_pattern, result_token) is None: - return result_token - thead_part = re.search(thead_pattern, result_token).group() - origin_thead_part = copy.deepcopy(thead_part) - - # check "rowspan" or "colspan" occur in parts or not . - span_pattern = '|||' - span_iter = re.finditer(span_pattern, thead_part) - span_list = [s.group() for s in span_iter] - has_span_in_head = True if len(span_list) > 0 else False - - if not has_span_in_head: - # not include "rowspan" or "colspan" branch 1. - # 1. replace to , and to - # 2. it is possible to predict text include or by Text-line recognition, - # so we replace to , and to - thead_part = ( - thead_part.replace("", "") - .replace("", "") - .replace("", "") - .replace("", "") - ) - else: - # include "rowspan" or "colspan" branch 2. - # Firstly, we deal rowspan or colspan cases. - # 1. replace > to > - # 2. replace to - # 3. it is possible to predict text include or by Text-line recognition, - # so we replace to , and to - - # Secondly, deal ordinary cases like branch 1 - - # replace ">" to "" - replaced_span_list = [] - for sp in span_list: - replaced_span_list.append(sp.replace(">", ">")) - for sp, rsp in zip(span_list, replaced_span_list): - thead_part = thead_part.replace(sp, rsp) - - # replace "" to "" - thead_part = thead_part.replace("", "") - - # remove duplicated by re.sub - mb_pattern = "()+" - single_b_string = "" - thead_part = re.sub(mb_pattern, single_b_string, thead_part) - - mgb_pattern = "()+" - single_gb_string = "" - thead_part = re.sub(mgb_pattern, single_gb_string, thead_part) - - # ordinary cases like branch 1 - thead_part = thead_part.replace("", "").replace("", "") - - # convert back to , empty cell has no . - # but space cell( ) is suitable for - thead_part = thead_part.replace("", "") - # deal with duplicated - thead_part = deal_duplicate_bb(thead_part) - # deal with isolate span tokens, which causes by wrong predict by structure prediction. - # eg.PMC5994107_011_00.png - thead_part = deal_isolate_span(thead_part) - # replace original result with new thead part. - result_token = result_token.replace(origin_thead_part, thead_part) - return result_token - - -def deal_eb_token(master_token): - """ - post process with , , ... - emptyBboxTokenDict = { - "[]": '', - "[' ']": '', - "['', ' ', '']": '', - "['\\u2028', '\\u2028']": '', - "['', ' ', '']": '', - "['', '']": '', - "['', ' ', '']": '', - "['', '', '', '']": '', - "['', '', ' ', '', '']": '', - "['', '']": '', - "['', ' ', '\\u2028', ' ', '\\u2028', ' ', '']": '', - } - :param master_token: - :return: - """ - master_token = master_token.replace("", "") - master_token = master_token.replace("", " ") - master_token = master_token.replace("", " ") - master_token = master_token.replace("", "\u2028\u2028") - master_token = master_token.replace("", " ") - master_token = master_token.replace("", "") - master_token = master_token.replace("", " ") - master_token = master_token.replace("", "") - master_token = master_token.replace("", " ") - master_token = master_token.replace("", "") - master_token = master_token.replace( - "", " \u2028 \u2028 " - ) - return master_token - - -def distance(box_1, box_2): - x1, y1, x2, y2 = box_1 - x3, y3, x4, y4 = box_2 - dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2) - dis_2 = abs(x3 - x1) + abs(y3 - y1) - dis_3 = abs(x4 - x2) + abs(y4 - y2) - return dis + min(dis_2, dis_3) - - -def compute_iou(rec1, rec2): - """ - computing IoU - :param rec1: (y0, x0, y1, x1), which reflects - (top, left, bottom, right) - :param rec2: (y0, x0, y1, x1) - :return: scala value of IoU - """ - # computing area of each rectangles - S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1]) - S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1]) - - # computing the sum_area - sum_area = S_rec1 + S_rec2 - - # find the each edge of intersect rectangle - left_line = max(rec1[1], rec2[1]) - right_line = min(rec1[3], rec2[3]) - top_line = max(rec1[0], rec2[0]) - bottom_line = min(rec1[2], rec2[2]) - - # judge if there is an intersect - if left_line >= right_line or top_line >= bottom_line: - return 0.0 - else: - intersect = (right_line - left_line) * (bottom_line - top_line) - return (intersect / (sum_area - intersect)) * 1.0 diff --git a/rapid_table_torch/table_structure/__init__.py b/rapid_table_torch/table_structure/__init__.py deleted file mode 100644 index 4582f3d..0000000 --- a/rapid_table_torch/table_structure/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .table_structure import TableStructurer diff --git a/rapid_table_torch/table_structure/utils.py b/rapid_table_torch/table_structure/utils.py deleted file mode 100644 index f2d76df..0000000 --- a/rapid_table_torch/table_structure/utils.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -*- encoding: utf-8 -*- -# @Author: SWHL -# @Contact: liekkaskono@163.com - -def trans_char_ocr_res(ocr_res): - word_result = [] - for res in ocr_res: - score = res[2] - for word_box, word in zip(res[3], res[4]): - word_res = [] - word_res.append(word_box) - word_res.append(word) - word_res.append(score) - word_result.append(word_res) - return word_result diff --git a/rapid_table_torch/utils.py b/rapid_table_torch/utils.py deleted file mode 100644 index e7135c5..0000000 --- a/rapid_table_torch/utils.py +++ /dev/null @@ -1,210 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Author: SWHL -# @Contact: liekkaskono@163.com -import os -from io import BytesIO -from pathlib import Path -from typing import Optional, Union, List - -import cv2 -import numpy as np -from PIL import Image, UnidentifiedImageError - -InputType = Union[str, np.ndarray, bytes, Path] - - -class LoadImage: - def __init__( - self, - ): - pass - - def __call__(self, img: InputType) -> np.ndarray: - if not isinstance(img, InputType.__args__): - raise LoadImageError( - f"The img type {type(img)} does not in {InputType.__args__}" - ) - - img = self.load_img(img) - - if img.ndim == 2: - return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - - if img.ndim == 3 and img.shape[2] == 4: - return self.cvt_four_to_three(img) - - return img - - def load_img(self, img: InputType) -> np.ndarray: - if isinstance(img, (str, Path)): - self.verify_exist(img) - try: - img = np.array(Image.open(img)) - img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) - except UnidentifiedImageError as e: - raise LoadImageError(f"cannot identify image file {img}") from e - return img - - if isinstance(img, bytes): - img = np.array(Image.open(BytesIO(img))) - img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) - return img - - if isinstance(img, np.ndarray): - return img - - raise LoadImageError(f"{type(img)} is not supported!") - - @staticmethod - def cvt_four_to_three(img: np.ndarray) -> np.ndarray: - """RGBA → RGB""" - r, g, b, a = cv2.split(img) - new_img = cv2.merge((b, g, r)) - - not_a = cv2.bitwise_not(a) - not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) - - new_img = cv2.bitwise_and(new_img, new_img, mask=a) - new_img = cv2.add(new_img, not_a) - return new_img - - @staticmethod - def verify_exist(file_path: Union[str, Path]): - if not Path(file_path).exists(): - raise LoadImageError(f"{file_path} does not exist.") - - -class LoadImageError(Exception): - pass - - -class VisTable: - def __init__( - self, - ): - self.load_img = LoadImage() - - def __call__( - self, - img_path: Union[str, Path], - table_html_str: str, - save_html_path: Optional[str] = None, - table_cell_bboxes: Optional[np.ndarray] = None, - save_drawed_path: Optional[str] = None, - logic_points: List[List[float]] = None, - save_logic_path: Optional[str] = None, - ) -> None: - if save_html_path: - html_with_border = self.insert_border_style(table_html_str) - self.save_html(save_html_path, html_with_border) - - if table_cell_bboxes is None: - return None - - img = self.load_img(img_path) - - dims_bboxes = table_cell_bboxes.shape[1] - if dims_bboxes == 4: - drawed_img = self.draw_rectangle(img, table_cell_bboxes) - elif dims_bboxes == 8: - drawed_img = self.draw_polylines(img, table_cell_bboxes) - else: - raise ValueError("Shape of table bounding boxes is not between in 4 or 8.") - - if save_drawed_path: - self.save_img(save_drawed_path, drawed_img) - if save_logic_path and logic_points: - polygons = [[box[0],box[1], box[4], box[5]] for box in table_cell_bboxes] - self.plot_rec_box_with_logic_info(img_path, save_logic_path, logic_points, polygons) - return drawed_img - - def insert_border_style(self, table_html_str: str): - style_res = f"""""" - prefix_table, suffix_table = table_html_str.split("") - html_with_border = f"{prefix_table}{style_res}{suffix_table}" - return html_with_border - - def plot_rec_box_with_logic_info(self, img_path, output_path, logic_points, sorted_polygons): - """ - :param img_path - :param output_path - :param logic_points: [row_start,row_end,col_start,col_end] - :param sorted_polygons: [xmin,ymin,xmax,ymax] - :return: - """ - # 读取原图 - img = cv2.imread(img_path) - img = cv2.copyMakeBorder( - img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255] - ) - # 绘制 polygons 矩形 - for idx, polygon in enumerate(sorted_polygons): - x0, y0, x1, y1 = polygon[0], polygon[1], polygon[2], polygon[3] - x0 = round(x0) - y0 = round(y0) - x1 = round(x1) - y1 = round(y1) - cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1) - # 增大字体大小和线宽 - font_scale = 0.9 # 原先是0.5 - thickness = 1 # 原先是1 - logic_point = logic_points[idx] - cv2.putText( - img, - f"row: {logic_point[0]}-{logic_point[1]}", - (x0 + 3, y0 + 8), - cv2.FONT_HERSHEY_PLAIN, - font_scale, - (0, 0, 255), - thickness, - ) - cv2.putText( - img, - f"col: {logic_point[2]}-{logic_point[3]}", - (x0 + 3, y0 + 18), - cv2.FONT_HERSHEY_PLAIN, - font_scale, - (0, 0, 255), - thickness, - ) - os.makedirs(os.path.dirname(output_path), exist_ok=True) - # 保存绘制后的图像 - self.save_img(output_path, img) - - @staticmethod - def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray: - img_copy = img.copy() - for box in boxes.astype(int): - x1, y1, x2, y2 = box - cv2.rectangle(img_copy, (x1, y1), (x2, y2), (255, 0, 0), 2) - return img_copy - - @staticmethod - def draw_polylines(img: np.ndarray, points) -> np.ndarray: - img_copy = img.copy() - for point in points.astype(int): - point = point.reshape(4, 2) - cv2.polylines(img_copy, [point.astype(int)], True, (255, 0, 0), 2) - return img_copy - - @staticmethod - def save_img(save_path: Union[str, Path], img: np.ndarray): - cv2.imwrite(str(save_path), img) - - @staticmethod - def save_html(save_path: Union[str, Path], html: str): - with open(save_path, "w", encoding="utf-8") as f: - f.write(html) diff --git a/requirements_torch.txt b/requirements_torch.txt deleted file mode 100644 index 4ba3572..0000000 --- a/requirements_torch.txt +++ /dev/null @@ -1,8 +0,0 @@ -onnxruntime>=1.7.0 -opencv_python>=4.5.1.48 -numpy>=1.21.6,<2 -Pillow -requests -torch -torchvision -tokenizers \ No newline at end of file diff --git a/setup.py b/setup.py index 0539473..0489760 100644 --- a/setup.py +++ b/setup.py @@ -66,4 +66,5 @@ def get_readme(): ], python_requires=">=3.6,<3.13", entry_points={"console_scripts": [f"{MODULE_NAME}={MODULE_NAME}.main:main"]}, + extras_require={"torch": ["torch", "torchvision", "tokenizers"]}, ) diff --git a/setup_torch.py b/setup_torch.py deleted file mode 100644 index b332a20..0000000 --- a/setup_torch.py +++ /dev/null @@ -1,69 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Author: SWHL -# @Contact: liekkaskono@163.com -import sys -from pathlib import Path - -import setuptools -from get_pypi_latest_version import GetPyPiLatestVersion - - -def get_readme(): - root_dir = Path(__file__).resolve().parent - readme_path = str(root_dir / "docs" / "doc_whl_rapid_table.md") - with open(readme_path, "r", encoding="utf-8") as f: - readme = f.read() - return readme - - -MODULE_NAME = "rapid_table_torch" -obtainer = GetPyPiLatestVersion() -try: - latest_version = obtainer(MODULE_NAME) -except Exception: - latest_version = "0.0.0" -VERSION_NUM = obtainer.version_add_one(latest_version) - -if len(sys.argv) > 2: - match_str = " ".join(sys.argv[2:]) - matched_versions = obtainer.extract_version(match_str) - if matched_versions: - VERSION_NUM = matched_versions -sys.argv = sys.argv[:2] - -setuptools.setup( - name=MODULE_NAME, - version=VERSION_NUM, - platforms="Any", - long_description=get_readme(), - long_description_content_type="text/markdown", - description="Tools for parsing table structures based pytorch.", - author="SWHL", - author_email="liekkaskono@163.com", - url="https://github.com/RapidAI/RapidTable", - license="Apache-2.0", - include_package_data=True, - install_requires=[ - "opencv_python>=4.5.1.48", - "numpy>=1.21.6", - "Pillow", - "torch", - "torchvision", - "tokenizers" - ], - packages=[ - MODULE_NAME, - f"{MODULE_NAME}.models", - f"{MODULE_NAME}.table_matcher", - f"{MODULE_NAME}.table_structure", - ], - package_data={"": [".gitkeep"]}, - keywords=["unitable,table,rapidocr,rapid_table"], - classifiers=[ - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - ], - python_requires=">=3.10,<3.13", - entry_points={"console_scripts": [f"{MODULE_NAME}={MODULE_NAME}.main:main"]}, -) diff --git a/tests/test_table_torch.py b/tests/test_table_torch.py index 640257c..1eb4505 100644 --- a/tests/test_table_torch.py +++ b/tests/test_table_torch.py @@ -11,10 +11,10 @@ sys.path.append(str(root_dir)) -from rapid_table_torch import RapidTable +from rapid_table import RapidTable ocr_engine = RapidOCR() -table_engine = RapidTable() +table_engine = RapidTable(model_type="unitable") test_file_dir = cur_dir / "test_files" img_path = str(test_file_dir / "table.jpg") @@ -22,14 +22,14 @@ def test_ocr_input(): ocr_res, _ = ocr_engine(img_path) - table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, ocr_res) + table_html_str, table_cell_bboxes, elapse = table_engine(img_path, ocr_res) assert table_html_str.count("") == 16 def test_input_ocr_none(): - table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path) + table_html_str, table_cell_bboxes, elapse = table_engine(img_path) assert table_html_str.count("") == 16 def test_logic_points_out(): - table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path) + table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, return_logic_points=True) assert len(table_cell_bboxes) == len(logic_points) From 2e514121079bcca130c09482cfd2365014f778ad Mon Sep 17 00:00:00 2001 From: Joker1212 <519548295@qq.com> Date: Mon, 30 Dec 2024 17:09:03 +0800 Subject: [PATCH 2/3] chore: rm extra words --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 89b0f59..d0f6594 100644 --- a/README.md +++ b/README.md @@ -122,7 +122,6 @@ table_engine = RapidTable() 完整示例: -#### onnx版本 ```python from pathlib import Path From 4a70f601ee1120841fc0661b0959305f9df1914d Mon Sep 17 00:00:00 2001 From: Jokcer <519548295@qq.com> Date: Mon, 30 Dec 2024 22:36:56 +0800 Subject: [PATCH 3/3] fix: adapt unitable no content label token --- rapid_table/main.py | 3 +++ rapid_table/table_matcher/matcher.py | 5 ++++- rapid_table/table_structure/table_structure_unitable.py | 8 +++++--- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/rapid_table/main.py b/rapid_table/main.py index 202f6e0..b412790 100644 --- a/rapid_table/main.py +++ b/rapid_table/main.py @@ -65,6 +65,9 @@ def __call__( if self.model_type == "slanet-plus": pred_bboxes = self.adapt_slanet_plus(img, pred_bboxes) pred_html = self.table_matcher(pred_structures, pred_bboxes, dt_boxes, rec_res) + # 过滤掉占位的bbox + mask = ~np.all(pred_bboxes == 0, axis=1) + pred_bboxes = pred_bboxes[mask] # 避免低版本升级后出现问题,默认不打开 if return_logic_points: logic_points = self.table_matcher.decode_logic_points(pred_structures) diff --git a/rapid_table/table_matcher/matcher.py b/rapid_table/table_matcher/matcher.py index 0453929..bc976ed 100644 --- a/rapid_table/table_matcher/matcher.py +++ b/rapid_table/table_matcher/matcher.py @@ -29,7 +29,7 @@ def __call__(self, pred_structures, pred_bboxes, dt_boxes, rec_res): pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res) return pred_html - def match_result(self, dt_boxes, pred_bboxes): + def match_result(self, dt_boxes, pred_bboxes, min_iou=0.1 ** 8): matched = {} for i, gt_box in enumerate(dt_boxes): distances = [] @@ -49,6 +49,9 @@ def match_result(self, dt_boxes, pred_bboxes): sorted_distances = sorted( sorted_distances, key=lambda item: (item[1], item[0]) ) + # must > min_iou + if sorted_distances[0][1] >= 1 - min_iou: + continue if distances.index(sorted_distances[0]) not in matched.keys(): matched[distances.index(sorted_distances[0])] = [i] else: diff --git a/rapid_table/table_structure/table_structure_unitable.py b/rapid_table/table_structure/table_structure_unitable.py index da9ed7f..fdc42b4 100644 --- a/rapid_table/table_structure/table_structure_unitable.py +++ b/rapid_table/table_structure/table_structure_unitable.py @@ -183,8 +183,8 @@ def decode_tokens(self, context): td_attrs = td_match.group(1).strip() td_content = td_match.group(2).strip() if td_attrs: - decoded_list.append('') decoded_list.append('') else: @@ -197,7 +197,9 @@ def decode_tokens(self, context): # 将坐标转换为从左上角开始顺时针到左下角的点的坐标 coords = np.array([xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax]) bbox_coords.append(coords) - + else: + # 填充占位的bbox,保证后续流程统一 + bbox_coords.append(np.array([0, 0, 0,0,0,0, 0, 0])) decoded_list.append('') # 将 bbox_coords 转换为 numpy 数组