From 97b685ff187a7f76c2deee0f9c265eb4b178f4a5 Mon Sep 17 00:00:00 2001 From: Joker1212 <519548295@qq.com> Date: Mon, 30 Dec 2024 17:07:04 +0800 Subject: [PATCH 1/3] feat: optimize torch inference --- .github/workflows/publish_whl.yml | 4 + .github/workflows/publish_whl_torch.yml | 59 ----- README.md | 64 +---- demo_torch.py | 50 ++-- .../download_model.py | 24 +- {rapid_table_torch => rapid_table}/logger.py | 0 rapid_table/main.py | 29 +- rapid_table/params.py | 40 +++ rapid_table/table_structure/__init__.py | 1 + .../table_structure_unitable.py | 13 +- .../table_structure/unitable_modules.py | 0 rapid_table/utils.py | 1 + rapid_table_torch/__init__.py | 2 - rapid_table_torch/main.py | 155 ----------- rapid_table_torch/table_matcher/__init__.py | 4 - rapid_table_torch/table_matcher/matcher.py | 192 -------------- rapid_table_torch/table_matcher/utils.py | 248 ------------------ rapid_table_torch/table_structure/__init__.py | 1 - rapid_table_torch/table_structure/utils.py | 28 -- rapid_table_torch/utils.py | 210 --------------- requirements_torch.txt | 8 - setup.py | 1 + setup_torch.py | 69 ----- tests/test_table_torch.py | 10 +- 24 files changed, 130 insertions(+), 1083 deletions(-) delete mode 100644 .github/workflows/publish_whl_torch.yml rename {rapid_table_torch => rapid_table}/download_model.py (70%) rename {rapid_table_torch => rapid_table}/logger.py (100%) create mode 100644 rapid_table/params.py rename rapid_table_torch/table_structure/table_structure.py => rapid_table/table_structure/table_structure_unitable.py (93%) rename rapid_table_torch/table_structure/components.py => rapid_table/table_structure/unitable_modules.py (100%) delete mode 100644 rapid_table_torch/__init__.py delete mode 100644 rapid_table_torch/main.py delete mode 100644 rapid_table_torch/table_matcher/__init__.py delete mode 100644 rapid_table_torch/table_matcher/matcher.py delete mode 100644 rapid_table_torch/table_matcher/utils.py delete mode 100644 rapid_table_torch/table_structure/__init__.py delete mode 100644 rapid_table_torch/table_structure/utils.py delete mode 100644 rapid_table_torch/utils.py delete mode 100644 requirements_torch.txt delete mode 100644 setup_torch.py diff --git a/.github/workflows/publish_whl.yml b/.github/workflows/publish_whl.yml index 99e137c..7e135b3 100644 --- a/.github/workflows/publish_whl.yml +++ b/.github/workflows/publish_whl.yml @@ -34,8 +34,12 @@ jobs: pip install -r requirements.txt pip install rapidocr_onnxruntime + pip install torch + pip install torchvision + pip install tokenizers pip install pytest pytest tests/test_table.py + pytest tests/test_table_torch.py GenerateWHL_PushPyPi: needs: UnitTesting diff --git a/.github/workflows/publish_whl_torch.yml b/.github/workflows/publish_whl_torch.yml deleted file mode 100644 index 09bee91..0000000 --- a/.github/workflows/publish_whl_torch.yml +++ /dev/null @@ -1,59 +0,0 @@ -name: Push rapidocr_table_torch to pypi - -on: - push: - tags: - - torch_v* - -env: - RESOURCES_URL: https://github.com/RapidAI/RapidTable/releases/download/assets/unitable.zip - - -jobs: - UnitTesting: - runs-on: ubuntu-latest - steps: - - name: Pull latest code - uses: actions/checkout@v3 - - - name: Set up Python 3.10 - uses: actions/setup-python@v4 - with: - python-version: '3.10' - architecture: 'x64' - - - name: Display Python version - run: python -c "import sys; print(sys.version)" - - - name: Unit testings - run: | - pip install -r requirements_torch.txt - pip install rapidocr_onnxruntime - pip install pytest - pytest tests/test_table_torch.py - - GenerateWHL_PushPyPi: - needs: UnitTesting - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - - - name: Set up Python 3.10 - uses: actions/setup-python@v4 - with: - python-version: '3.10' - architecture: 'x64' - - - name: Run setup_torch.py - run: | - pip install -r requirements_torch.txt - python -m pip install --upgrade pip - pip install wheel get_pypi_latest_version - python setup_torch.py bdist_wheel ${{ github.ref_name }} - - - name: Publish distribution 📦 to PyPI - uses: pypa/gh-action-pypi-publish@v1.5.0 - with: - password: ${{ secrets.RAPID_TABLE }} - packages_dir: dist/ diff --git a/README.md b/README.md index a621571..89b0f59 100644 --- a/README.md +++ b/README.md @@ -105,8 +105,8 @@ RapidTable是整理自PP-Structure中表格识别部分而来。由于PP-Structu ```bash pip install rapidocr_onnxruntime pip install rapid_table -#pip install rapid_table_torch # for unitable inference -#pip install onnxruntime-gpu # for gpu inference +#pip install rapid_table[torch] # for unitable inference +#pip install onnxruntime-gpu # for onnx gpu inference ``` ### 使用方式 @@ -117,6 +117,7 @@ RapidTable类提供model_path参数,可以自行指定上述2个模型,默 ```python table_engine = RapidTable() +# table_engine = RapidTable(use_cuda=True, device="cuda:0", model_type="unitable") ``` 完整示例: @@ -132,6 +133,8 @@ from rapid_table.table_structure.utils import trans_char_ocr_res table_engine = RapidTable() # 开启onnx-gpu推理 # table_engine = RapidTable(use_cuda=True) +# 使用torch推理版本的unitable模型 +# table_engine = RapidTable(use_cuda=True, device="cuda:0", model_type="unitable") ocr_engine = RapidOCR() viser = VisTable() @@ -159,41 +162,8 @@ viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_p print(table_html_str) ``` -#### torch版本 -```python -from pathlib import Path -from rapidocr_onnxruntime import RapidOCR - -from rapid_table_torch import RapidTable, VisTable -from rapid_table_torch.table_structure.utils import trans_char_ocr_res - -if __name__ == '__main__': -# Init -ocr_engine = RapidOCR() -table_engine = RapidTable(device="cpu") # 默认使用cpu,若使用cuda,则传入device="cuda:0" -viser = VisTable() -img_path = "tests/test_files/image34.png" -# OCR,本模型检测框比较精准,配合单字匹配效果更好 -ocr_result, _ = ocr_engine(img_path, return_word_box=True) -ocr_result = trans_char_ocr_res(ocr_result) -boxes, txts, scores = list(zip(*ocr_result)) -# Save -save_dir = Path("outputs") -save_dir.mkdir(parents=True, exist_ok=True) - -save_html_path = save_dir / f"{Path(img_path).stem}.html" -save_drawed_path = save_dir / f"{Path(img_path).stem}_table_vis{Path(img_path).suffix}" -# 返回逻辑坐标 -table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, ocr_result) -save_logic_path = save_dir / f"vis_logic_{Path(img_path).name}" -vis_imged = viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path, logic_points, - save_logic_path) -print(f"elapse:{elapse}") -``` - #### 终端运行 -##### onnx: - 用法: ```bash @@ -214,29 +184,7 @@ print(f"elapse:{elapse}") ```bash rapid_table -v -img test_images/table.jpg ``` - -##### pytorch: -- 用法: - - ```bash - $ rapid_table_torch -h - usage: rapid_table_torch [-h] [-v] -img IMG_PATH [-d DEVICE] - - optional arguments: - -h, --help show this help message and exit - -v, --vis Whether to visualize the layout results. - -img IMG_PATH, --img_path IMG_PATH - Path to image for layout. - -d DEVICE, --device device - The model device used for inference. - ``` - -- 示例: - - ```bash - rapid_table_torch -v -img test_images/table.jpg - ``` - + ### 结果 #### 返回结果 diff --git a/demo_torch.py b/demo_torch.py index e8dc1e0..72a2d44 100644 --- a/demo_torch.py +++ b/demo_torch.py @@ -4,31 +4,29 @@ from pathlib import Path from rapidocr_onnxruntime import RapidOCR -from rapid_table_torch import RapidTable, VisTable -from rapid_table_torch.table_structure.utils import trans_char_ocr_res +from rapid_table import RapidTable, VisTable +from rapid_table.table_structure.utils import trans_char_ocr_res -if __name__ == '__main__': - # Init - ocr_engine = RapidOCR() - table_engine = RapidTable(encoder_path="rapid_table_torch/models/encoder.pth", - decoder_path="rapid_table_torch/models/decoder.pth", - vocab_path="rapid_table_torch/models/vocab.json", - device="cpu") - viser = VisTable() - img_path = "tests/test_files/image34.png" - # OCR - ocr_result, _ = ocr_engine(img_path, return_word_box=True) - ocr_result = trans_char_ocr_res(ocr_result) - boxes, txts, scores = list(zip(*ocr_result)) - # Save - save_dir = Path("outputs") - save_dir.mkdir(parents=True, exist_ok=True) +ocr_engine = RapidOCR() +table_engine = RapidTable(use_cuda=True, device="cuda:0", model_type="unitable") +viser = VisTable() +img_path = "tests/test_files/table.jpg" +# OCR +ocr_result, _ = ocr_engine(img_path, return_word_box=True) +ocr_result = trans_char_ocr_res(ocr_result) +boxes, txts, scores = list(zip(*ocr_result)) +# Save +save_dir = Path("outputs") +save_dir.mkdir(parents=True, exist_ok=True) - save_html_path = save_dir / f"{Path(img_path).stem}.html" - save_drawed_path = save_dir / f"{Path(img_path).stem}_table_vis{Path(img_path).suffix}" - # 返回逻辑坐标 - table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, ocr_result) - save_logic_path = save_dir / f"vis_logic_{Path(img_path).name}" - vis_imged = viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path, logic_points, - save_logic_path) - print(f"elapse:{elapse}") +save_html_path = save_dir / f"{Path(img_path).stem}.html" +save_drawed_path = save_dir / f"{Path(img_path).stem}_table_vis{Path(img_path).suffix}" + +table_html_str, table_cell_bboxes, elapse = table_engine(img_path, ocr_result) +viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path) +# 返回逻辑坐标 +# table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, ocr_result, return_logic_points=True) +# save_logic_path = save_dir / f"vis_logic_{Path(img_path).name}" +# vis_imged = viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path, logic_points, +# save_logic_path) +print(f"elapse:{elapse}") diff --git a/rapid_table_torch/download_model.py b/rapid_table/download_model.py similarity index 70% rename from rapid_table_torch/download_model.py rename to rapid_table/download_model.py index d636f94..9e9c496 100644 --- a/rapid_table_torch/download_model.py +++ b/rapid_table/download_model.py @@ -10,11 +10,33 @@ logger = get_logger("DownloadModel") CUR_DIR = Path(__file__).resolve() PROJECT_DIR = CUR_DIR.parent - +ROOT_URL = "https://www.modelscope.cn/studio/jockerK/TableRec/resolve/master/models/table_rec/unitable/" +KEY_TO_MODEL_URL = { + "unitable": { + "encoder": f"{ROOT_URL}/encoder.pth", + "decoder": f"{ROOT_URL}/decoder.pth", + "vocab": f"{ROOT_URL}/vocab.json", + } +} class DownloadModel: cur_dir = PROJECT_DIR + @staticmethod + def get_model_path(model_type: str, sub_file_type: str, path: Union[str, Path, None]) -> str: + if path is not None: + return path + + model_url = KEY_TO_MODEL_URL.get(model_type, {}).get(sub_file_type, None) + if model_url: + model_path = DownloadModel.download(model_url) + return model_path + + logger.info( + "model url is None, using the default download model %s", path + ) + return path + @classmethod def download(cls, model_full_url: Union[str, Path]) -> str: save_dir = cls.cur_dir / "models" diff --git a/rapid_table_torch/logger.py b/rapid_table/logger.py similarity index 100% rename from rapid_table_torch/logger.py rename to rapid_table/logger.py diff --git a/rapid_table/main.py b/rapid_table/main.py index d620fe5..202f6e0 100644 --- a/rapid_table/main.py +++ b/rapid_table/main.py @@ -5,33 +5,34 @@ import copy import importlib import time +from dataclasses import asdict from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import List, Tuple, Union import cv2 import numpy as np +from .download_model import DownloadModel +from .params import accept_kwargs_as_dataclass, BaseConfig from .table_matcher import TableMatch -from .table_structure import TableStructurer +from .table_structure import TableStructurer, TableStructureUnitable from .utils import LoadImage, VisTable root_dir = Path(__file__).resolve().parent class RapidTable: - def __init__(self, model_path: Optional[str] = None, model_type: str = None, use_cuda: bool = False): - if model_path is None: - model_path = str( - root_dir / "models" / "slanet-plus.onnx" - ) - model_type = "slanet-plus" - self.model_type = model_type + @accept_kwargs_as_dataclass(BaseConfig) + def __init__(self, config: BaseConfig): + self.model_type = config.model_type self.load_img = LoadImage() - config = { - "model_path": model_path, - "use_cuda": use_cuda, - } - self.table_structure = TableStructurer(config) + if self.model_type == "unitable": + config.encoder_path = DownloadModel.get_model_path(self.model_type, "encoder", config.encoder_path) + config.decoder_path = DownloadModel.get_model_path(self.model_type, "decoder", config.decoder_path) + config.vocab_path = DownloadModel.get_model_path(self.model_type, "vocab", config.vocab_path) + self.table_structure = TableStructureUnitable(asdict(config)) + else: + self.table_structure = TableStructurer(asdict(config)) self.table_matcher = TableMatch() try: diff --git a/rapid_table/params.py b/rapid_table/params.py new file mode 100644 index 0000000..c5756aa --- /dev/null +++ b/rapid_table/params.py @@ -0,0 +1,40 @@ +from dataclasses import dataclass, fields +from functools import wraps +from pathlib import Path + +from rapid_table.logger import get_logger + +root_dir = Path(__file__).resolve().parent +logger = get_logger("params") + +@dataclass +class BaseConfig: + model_type: str = "slanet-plus" + model_path: str = str(root_dir / "models" / "slanet-plus.onnx") + use_cuda: bool = False + device: str = "cpu" + encoder_path: str = None + decoder_path: str = None + vocab_path: str = None + + +def accept_kwargs_as_dataclass(cls): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + if len(args) == 2 and isinstance(args[1], cls): + # 如果已经传递了 ModelConfig 实例,直接调用函数 + return func(*args, **kwargs) + else: + # 提取 cls 中定义的字段 + cls_fields = {field.name for field in fields(cls)} + # 过滤掉未定义的字段 + filtered_kwargs = {k: v for k, v in kwargs.items() if k in cls_fields} + # 发出警告对于未定义的字段 + for k in (kwargs.keys() - cls_fields): + logger.warning(f"Warning: '{k}' is not a valid field in {cls.__name__} and will be ignored.") + # 创建 ModelConfig 实例并调用函数 + config = cls(**filtered_kwargs) + return func(args[0], config=config) + return wrapper + return decorator \ No newline at end of file diff --git a/rapid_table/table_structure/__init__.py b/rapid_table/table_structure/__init__.py index a548391..3e638d9 100644 --- a/rapid_table/table_structure/__init__.py +++ b/rapid_table/table_structure/__init__.py @@ -12,3 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. from .table_structure import TableStructurer +from .table_structure_unitable import TableStructureUnitable diff --git a/rapid_table_torch/table_structure/table_structure.py b/rapid_table/table_structure/table_structure_unitable.py similarity index 93% rename from rapid_table_torch/table_structure/table_structure.py rename to rapid_table/table_structure/table_structure_unitable.py index 32f3fa0..da9ed7f 100644 --- a/rapid_table_torch/table_structure/table_structure.py +++ b/rapid_table/table_structure/table_structure_unitable.py @@ -1,5 +1,6 @@ import re import time +from typing import Dict, Any import cv2 import numpy as np @@ -7,7 +8,7 @@ from PIL import Image from tokenizers import Tokenizer -from .components import Encoder, GPTFastDecoder +from .unitable_modules import Encoder, GPTFastDecoder from torchvision import transforms IMG_SIZE = 448 @@ -77,8 +78,14 @@ ] -class TableStructurer: - def __init__(self, encoder_path: str, decoder_path: str, vocab_path: str, device: str): +class TableStructureUnitable: + def __init__(self, config:Dict[str, Any]): + # encoder_path: str, decoder_path: str, vocab_path: str, device: str + vocab_path = config["vocab_path"] + encoder_path = config["encoder_path"] + decoder_path = config["decoder_path"] + device = config.get("device", "cuda:0") if config["use_cuda"] else "cpu" + self.vocab = Tokenizer.from_file(vocab_path) self.token_white_list = [self.vocab.token_to_id(i) for i in VALID_HTML_BBOX_TOKENS] self.bbox_token_ids = set([self.vocab.token_to_id(i) for i in BBOX_TOKENS]) diff --git a/rapid_table_torch/table_structure/components.py b/rapid_table/table_structure/unitable_modules.py similarity index 100% rename from rapid_table_torch/table_structure/components.py rename to rapid_table/table_structure/unitable_modules.py diff --git a/rapid_table/utils.py b/rapid_table/utils.py index e7135c5..77224ce 100644 --- a/rapid_table/utils.py +++ b/rapid_table/utils.py @@ -208,3 +208,4 @@ def save_img(save_path: Union[str, Path], img: np.ndarray): def save_html(save_path: Union[str, Path], html: str): with open(save_path, "w", encoding="utf-8") as f: f.write(html) + diff --git a/rapid_table_torch/__init__.py b/rapid_table_torch/__init__.py deleted file mode 100644 index 86c7b1c..0000000 --- a/rapid_table_torch/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .main import RapidTable -from .utils import VisTable diff --git a/rapid_table_torch/main.py b/rapid_table_torch/main.py deleted file mode 100644 index 5ef3dcc..0000000 --- a/rapid_table_torch/main.py +++ /dev/null @@ -1,155 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Author: SWHL -# @Contact: liekkaskono@163.com -import argparse -import copy -import importlib -import os -import time -from pathlib import Path -from typing import List, Union - -import cv2 -import numpy as np - -from .download_model import DownloadModel -from .logger import get_logger -from .table_matcher import TableMatch -from .table_structure import TableStructurer -from .utils import LoadImage, VisTable - -root_dir = Path(__file__).resolve().parent -model_dir = os.path.join(root_dir, "models") -logger = get_logger("rapid_table_torch") -default_config = os.path.join(root_dir, "config.yaml") -ROOT_URL = "https://www.modelscope.cn/studio/jockerK/TableRec/resolve/master/models/table_rec/unitable/" -KEY_TO_MODEL_URL = { - "unitable": { - "encoder": f"{ROOT_URL}/encoder.pth", - "decoder": f"{ROOT_URL}/decoder.pth", - "vocab": f"{ROOT_URL}/vocab.json", - } -} - - -class RapidTable: - def __init__(self, encoder_path: str = None, decoder_path: str = None, vocab_path: str = None, - model_type: str = "unitable", - device: str = "cpu"): - self.model_type = model_type - self.load_img = LoadImage() - encoder_path = self.get_model_path(model_type, "encoder", encoder_path) - decoder_path = self.get_model_path(model_type, "decoder", decoder_path) - vocab_path = self.get_model_path(model_type, "vocab", vocab_path) - self.table_structure = TableStructurer(encoder_path, decoder_path, vocab_path, device) - self.table_matcher = TableMatch() - try: - self.ocr_engine = importlib.import_module("rapidocr_onnxruntime").RapidOCR() - except ModuleNotFoundError: - self.ocr_engine = None - - def __call__( - self, - img_content: Union[str, np.ndarray, bytes, Path], - ocr_result: List[Union[List[List[float]], str, str]] = None - ): - if self.ocr_engine is None and ocr_result is None: - raise ValueError( - "One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed." - ) - - img = self.load_img(img_content) - - s = time.time() - h, w = img.shape[:2] - if ocr_result is None: - ocr_result, _ = self.ocr_engine(img) - dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w) - - pred_structures, pred_bboxes, _ = self.table_structure(copy.deepcopy(img)) - pred_html = self.table_matcher(pred_structures, pred_bboxes, dt_boxes, rec_res) - logic_points = self.table_matcher.decode_logic_points(pred_structures) - elapse = time.time() - s - return pred_html, pred_bboxes, logic_points, elapse - - def get_boxes_recs( - self, ocr_result: List[Union[List[List[float]], str, str]], h: int, w: int - ): - dt_boxes, rec_res, scores = list(zip(*ocr_result)) - rec_res = list(zip(rec_res, scores)) - - r_boxes = [] - for box in dt_boxes: - box = np.array(box) - x_min = max(0, box[:, 0].min() - 1) - x_max = min(w, box[:, 0].max() + 1) - y_min = max(0, box[:, 1].min() - 1) - y_max = min(h, box[:, 1].max() + 1) - box = [x_min, y_min, x_max, y_max] - r_boxes.append(box) - dt_boxes = np.array(r_boxes) - return dt_boxes, rec_res - - @staticmethod - def get_model_path(model_type: str, sub_file_type: str, path: Union[str, Path, None]) -> str: - if path is not None: - return path - - model_url = KEY_TO_MODEL_URL.get(model_type, {}).get(sub_file_type, None) - if model_url: - model_path = DownloadModel.download(model_url) - return model_path - - logger.info( - "model url is None, using the default download model %s", path - ) - return path -def main(): - parser = argparse.ArgumentParser() - parser.add_argument( - "-v", - "--vis", - action="store_true", - help="Wheter to visualize the layout results.", - ) - parser.add_argument( - "-img", "--img_path", type=str, required=True, help="Path to image for layout." - ) - parser.add_argument( - "-d", "--device", type=str, default="cpu", help="device to use" - ) - args = parser.parse_args() - - try: - ocr_engine = importlib.import_module("rapidocr_onnxruntime").RapidOCR() - except ModuleNotFoundError as exc: - raise ModuleNotFoundError( - "Please install the rapidocr_onnxruntime by pip install rapidocr_onnxruntime." - ) from exc - - rapid_table = RapidTable(device=args.device) - - img = cv2.imread(args.img_path) - - ocr_result, _ = ocr_engine(img) - table_html_str, table_cell_bboxes, elapse = rapid_table(img, ocr_result) - print(table_html_str) - - viser = VisTable() - if args.vis: - img_path = Path(args.img_path) - - save_dir = img_path.resolve().parent - save_html_path = save_dir / f"{Path(img_path).stem}.html" - save_drawed_path = save_dir / f"vis_{Path(img_path).name}" - viser( - img_path, - table_html_str, - save_html_path, - table_cell_bboxes, - save_drawed_path, - ) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/rapid_table_torch/table_matcher/__init__.py b/rapid_table_torch/table_matcher/__init__.py deleted file mode 100644 index 9bff7d7..0000000 --- a/rapid_table_torch/table_matcher/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Author: SWHL -# @Contact: liekkaskono@163.com -from .matcher import TableMatch diff --git a/rapid_table_torch/table_matcher/matcher.py b/rapid_table_torch/table_matcher/matcher.py deleted file mode 100644 index 0453929..0000000 --- a/rapid_table_torch/table_matcher/matcher.py +++ /dev/null @@ -1,192 +0,0 @@ -# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -*- encoding: utf-8 -*- -import numpy as np - -from .utils import compute_iou, distance - - -class TableMatch: - def __init__(self, filter_ocr_result=True, use_master=False): - self.filter_ocr_result = filter_ocr_result - self.use_master = use_master - - def __call__(self, pred_structures, pred_bboxes, dt_boxes, rec_res): - if self.filter_ocr_result: - dt_boxes, rec_res = self._filter_ocr_result(pred_bboxes, dt_boxes, rec_res) - matched_index = self.match_result(dt_boxes, pred_bboxes) - pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res) - return pred_html - - def match_result(self, dt_boxes, pred_bboxes): - matched = {} - for i, gt_box in enumerate(dt_boxes): - distances = [] - for j, pred_box in enumerate(pred_bboxes): - if len(pred_box) == 8: - pred_box = [ - np.min(pred_box[0::2]), - np.min(pred_box[1::2]), - np.max(pred_box[0::2]), - np.max(pred_box[1::2]), - ] - distances.append( - (distance(gt_box, pred_box), 1.0 - compute_iou(gt_box, pred_box)) - ) # compute iou and l1 distance - sorted_distances = distances.copy() - # select det box by iou and l1 distance - sorted_distances = sorted( - sorted_distances, key=lambda item: (item[1], item[0]) - ) - if distances.index(sorted_distances[0]) not in matched.keys(): - matched[distances.index(sorted_distances[0])] = [i] - else: - matched[distances.index(sorted_distances[0])].append(i) - return matched - - def get_pred_html(self, pred_structures, matched_index, ocr_contents): - end_html = [] - td_index = 0 - for tag in pred_structures: - if "" not in tag: - end_html.append(tag) - continue - - if "