diff --git a/.github/workflows/publish_whl.yml b/.github/workflows/publish_whl.yml
index 99e137c..7e135b3 100644
--- a/.github/workflows/publish_whl.yml
+++ b/.github/workflows/publish_whl.yml
@@ -34,8 +34,12 @@ jobs:
pip install -r requirements.txt
pip install rapidocr_onnxruntime
+ pip install torch
+ pip install torchvision
+ pip install tokenizers
pip install pytest
pytest tests/test_table.py
+ pytest tests/test_table_torch.py
GenerateWHL_PushPyPi:
needs: UnitTesting
diff --git a/.github/workflows/publish_whl_torch.yml b/.github/workflows/publish_whl_torch.yml
deleted file mode 100644
index 09bee91..0000000
--- a/.github/workflows/publish_whl_torch.yml
+++ /dev/null
@@ -1,59 +0,0 @@
-name: Push rapidocr_table_torch to pypi
-
-on:
- push:
- tags:
- - torch_v*
-
-env:
- RESOURCES_URL: https://github.com/RapidAI/RapidTable/releases/download/assets/unitable.zip
-
-
-jobs:
- UnitTesting:
- runs-on: ubuntu-latest
- steps:
- - name: Pull latest code
- uses: actions/checkout@v3
-
- - name: Set up Python 3.10
- uses: actions/setup-python@v4
- with:
- python-version: '3.10'
- architecture: 'x64'
-
- - name: Display Python version
- run: python -c "import sys; print(sys.version)"
-
- - name: Unit testings
- run: |
- pip install -r requirements_torch.txt
- pip install rapidocr_onnxruntime
- pip install pytest
- pytest tests/test_table_torch.py
-
- GenerateWHL_PushPyPi:
- needs: UnitTesting
- runs-on: ubuntu-latest
-
- steps:
- - uses: actions/checkout@v3
-
- - name: Set up Python 3.10
- uses: actions/setup-python@v4
- with:
- python-version: '3.10'
- architecture: 'x64'
-
- - name: Run setup_torch.py
- run: |
- pip install -r requirements_torch.txt
- python -m pip install --upgrade pip
- pip install wheel get_pypi_latest_version
- python setup_torch.py bdist_wheel ${{ github.ref_name }}
-
- - name: Publish distribution 📦 to PyPI
- uses: pypa/gh-action-pypi-publish@v1.5.0
- with:
- password: ${{ secrets.RAPID_TABLE }}
- packages_dir: dist/
diff --git a/README.md b/README.md
index a621571..d0f6594 100644
--- a/README.md
+++ b/README.md
@@ -105,8 +105,8 @@ RapidTable是整理自PP-Structure中表格识别部分而来。由于PP-Structu
```bash
pip install rapidocr_onnxruntime
pip install rapid_table
-#pip install rapid_table_torch # for unitable inference
-#pip install onnxruntime-gpu # for gpu inference
+#pip install rapid_table[torch] # for unitable inference
+#pip install onnxruntime-gpu # for onnx gpu inference
```
### 使用方式
@@ -117,11 +117,11 @@ RapidTable类提供model_path参数,可以自行指定上述2个模型,默
```python
table_engine = RapidTable()
+# table_engine = RapidTable(use_cuda=True, device="cuda:0", model_type="unitable")
```
完整示例:
-#### onnx版本
```python
from pathlib import Path
@@ -132,6 +132,8 @@ from rapid_table.table_structure.utils import trans_char_ocr_res
table_engine = RapidTable()
# 开启onnx-gpu推理
# table_engine = RapidTable(use_cuda=True)
+# Use the unitable model (torch-based inference)
+# table_engine = RapidTable(use_cuda=True, device="cuda:0", model_type="unitable")
ocr_engine = RapidOCR()
viser = VisTable()
@@ -159,41 +161,8 @@ viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_p
print(table_html_str)
```
-#### torch版本
-```python
-from pathlib import Path
-from rapidocr_onnxruntime import RapidOCR
-
-from rapid_table_torch import RapidTable, VisTable
-from rapid_table_torch.table_structure.utils import trans_char_ocr_res
-
-if __name__ == '__main__':
-# Init
-ocr_engine = RapidOCR()
-table_engine = RapidTable(device="cpu") # 默认使用cpu,若使用cuda,则传入device="cuda:0"
-viser = VisTable()
-img_path = "tests/test_files/image34.png"
-# OCR,本模型检测框比较精准,配合单字匹配效果更好
-ocr_result, _ = ocr_engine(img_path, return_word_box=True)
-ocr_result = trans_char_ocr_res(ocr_result)
-boxes, txts, scores = list(zip(*ocr_result))
-# Save
-save_dir = Path("outputs")
-save_dir.mkdir(parents=True, exist_ok=True)
-
-save_html_path = save_dir / f"{Path(img_path).stem}.html"
-save_drawed_path = save_dir / f"{Path(img_path).stem}_table_vis{Path(img_path).suffix}"
-# 返回逻辑坐标
-table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, ocr_result)
-save_logic_path = save_dir / f"vis_logic_{Path(img_path).name}"
-vis_imged = viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path, logic_points,
- save_logic_path)
-print(f"elapse:{elapse}")
-```
-
#### 终端运行
-##### onnx:
- 用法:
```bash
@@ -214,29 +183,7 @@ print(f"elapse:{elapse}")
```bash
rapid_table -v -img test_images/table.jpg
```
-
-##### pytorch:
-- 用法:
-
- ```bash
- $ rapid_table_torch -h
- usage: rapid_table_torch [-h] [-v] -img IMG_PATH [-d DEVICE]
-
- optional arguments:
- -h, --help show this help message and exit
- -v, --vis Whether to visualize the layout results.
- -img IMG_PATH, --img_path IMG_PATH
- Path to image for layout.
- -d DEVICE, --device device
- The model device used for inference.
- ```
-
-- 示例:
-
- ```bash
- rapid_table_torch -v -img test_images/table.jpg
- ```
-
+
### 结果
#### 返回结果
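For reference, a minimal end-to-end sketch of the merged API that the README changes above describe, assuming `rapid_table` is installed with the `torch` extra and `rapidocr_onnxruntime` is available (the image path is the one used by the tests in this change):

```python
from pathlib import Path

from rapidocr_onnxruntime import RapidOCR

from rapid_table import RapidTable, VisTable

ocr_engine = RapidOCR()
# Default is the slanet-plus ONNX model; "unitable" switches to the torch backend
table_engine = RapidTable(model_type="unitable")
viser = VisTable()

img_path = "tests/test_files/table.jpg"
ocr_result, _ = ocr_engine(img_path)

# By default the engine returns (html, cell bboxes, elapsed time)
table_html_str, table_cell_bboxes, elapse = table_engine(img_path, ocr_result)

save_dir = Path("outputs")
save_dir.mkdir(parents=True, exist_ok=True)
viser(img_path, table_html_str, save_dir / "table.html", table_cell_bboxes, save_dir / "table_vis.jpg")
print(table_html_str)
```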
diff --git a/demo_torch.py b/demo_torch.py
index e8dc1e0..72a2d44 100644
--- a/demo_torch.py
+++ b/demo_torch.py
@@ -4,31 +4,29 @@
from pathlib import Path
from rapidocr_onnxruntime import RapidOCR
-from rapid_table_torch import RapidTable, VisTable
-from rapid_table_torch.table_structure.utils import trans_char_ocr_res
+from rapid_table import RapidTable, VisTable
+from rapid_table.table_structure.utils import trans_char_ocr_res
-if __name__ == '__main__':
- # Init
- ocr_engine = RapidOCR()
- table_engine = RapidTable(encoder_path="rapid_table_torch/models/encoder.pth",
- decoder_path="rapid_table_torch/models/decoder.pth",
- vocab_path="rapid_table_torch/models/vocab.json",
- device="cpu")
- viser = VisTable()
- img_path = "tests/test_files/image34.png"
- # OCR
- ocr_result, _ = ocr_engine(img_path, return_word_box=True)
- ocr_result = trans_char_ocr_res(ocr_result)
- boxes, txts, scores = list(zip(*ocr_result))
- # Save
- save_dir = Path("outputs")
- save_dir.mkdir(parents=True, exist_ok=True)
+ocr_engine = RapidOCR()
+table_engine = RapidTable(use_cuda=True, device="cuda:0", model_type="unitable")
+viser = VisTable()
+img_path = "tests/test_files/table.jpg"
+# OCR
+ocr_result, _ = ocr_engine(img_path, return_word_box=True)
+ocr_result = trans_char_ocr_res(ocr_result)
+boxes, txts, scores = list(zip(*ocr_result))
+# Save
+save_dir = Path("outputs")
+save_dir.mkdir(parents=True, exist_ok=True)
- save_html_path = save_dir / f"{Path(img_path).stem}.html"
- save_drawed_path = save_dir / f"{Path(img_path).stem}_table_vis{Path(img_path).suffix}"
- # 返回逻辑坐标
- table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, ocr_result)
- save_logic_path = save_dir / f"vis_logic_{Path(img_path).name}"
- vis_imged = viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path, logic_points,
- save_logic_path)
- print(f"elapse:{elapse}")
+save_html_path = save_dir / f"{Path(img_path).stem}.html"
+save_drawed_path = save_dir / f"{Path(img_path).stem}_table_vis{Path(img_path).suffix}"
+
+table_html_str, table_cell_bboxes, elapse = table_engine(img_path, ocr_result)
+viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path)
+# Return logical coordinates as well
+# table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, ocr_result, return_logic_points=True)
+# save_logic_path = save_dir / f"vis_logic_{Path(img_path).name}"
+# vis_imged = viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path, logic_points,
+# save_logic_path)
+print(f"elapse:{elapse}")
diff --git a/rapid_table_torch/download_model.py b/rapid_table/download_model.py
similarity index 70%
rename from rapid_table_torch/download_model.py
rename to rapid_table/download_model.py
index d636f94..9e9c496 100644
--- a/rapid_table_torch/download_model.py
+++ b/rapid_table/download_model.py
@@ -10,11 +10,33 @@
logger = get_logger("DownloadModel")
CUR_DIR = Path(__file__).resolve()
PROJECT_DIR = CUR_DIR.parent
-
+ROOT_URL = "https://www.modelscope.cn/studio/jockerK/TableRec/resolve/master/models/table_rec/unitable/"
+KEY_TO_MODEL_URL = {
+ "unitable": {
+ "encoder": f"{ROOT_URL}/encoder.pth",
+ "decoder": f"{ROOT_URL}/decoder.pth",
+ "vocab": f"{ROOT_URL}/vocab.json",
+ }
+}
class DownloadModel:
cur_dir = PROJECT_DIR
+ @staticmethod
+ def get_model_path(model_type: str, sub_file_type: str, path: Union[str, Path, None]) -> str:
+ if path is not None:
+ return path
+
+ model_url = KEY_TO_MODEL_URL.get(model_type, {}).get(sub_file_type, None)
+ if model_url:
+ model_path = DownloadModel.download(model_url)
+ return model_path
+
+ logger.info(
+ "model url is None, using the default download model %s", path
+ )
+ return path
+
@classmethod
def download(cls, model_full_url: Union[str, Path]) -> str:
save_dir = cls.cur_dir / "models"
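The new `get_model_path` helper prefers an explicitly supplied path and only falls back to downloading from `KEY_TO_MODEL_URL`; a short sketch of the expected behaviour (the local path below is illustrative):

```python
from rapid_table.download_model import DownloadModel

# An explicit path short-circuits the download and is returned unchanged
encoder_path = DownloadModel.get_model_path("unitable", "encoder", "my_models/encoder.pth")
assert encoder_path == "my_models/encoder.pth"

# With no path given, the file is fetched once into rapid_table/models/ and reused afterwards
vocab_path = DownloadModel.get_model_path("unitable", "vocab", None)
print(vocab_path)
```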
diff --git a/rapid_table_torch/logger.py b/rapid_table/logger.py
similarity index 100%
rename from rapid_table_torch/logger.py
rename to rapid_table/logger.py
diff --git a/rapid_table/main.py b/rapid_table/main.py
index d620fe5..b412790 100644
--- a/rapid_table/main.py
+++ b/rapid_table/main.py
@@ -5,33 +5,34 @@
import copy
import importlib
import time
+from dataclasses import asdict
from pathlib import Path
-from typing import List, Optional, Tuple, Union
+from typing import List, Tuple, Union
import cv2
import numpy as np
+from .download_model import DownloadModel
+from .params import accept_kwargs_as_dataclass, BaseConfig
from .table_matcher import TableMatch
-from .table_structure import TableStructurer
+from .table_structure import TableStructurer, TableStructureUnitable
from .utils import LoadImage, VisTable
root_dir = Path(__file__).resolve().parent
class RapidTable:
- def __init__(self, model_path: Optional[str] = None, model_type: str = None, use_cuda: bool = False):
- if model_path is None:
- model_path = str(
- root_dir / "models" / "slanet-plus.onnx"
- )
- model_type = "slanet-plus"
- self.model_type = model_type
+ @accept_kwargs_as_dataclass(BaseConfig)
+ def __init__(self, config: BaseConfig):
+ self.model_type = config.model_type
self.load_img = LoadImage()
- config = {
- "model_path": model_path,
- "use_cuda": use_cuda,
- }
- self.table_structure = TableStructurer(config)
+ if self.model_type == "unitable":
+ config.encoder_path = DownloadModel.get_model_path(self.model_type, "encoder", config.encoder_path)
+ config.decoder_path = DownloadModel.get_model_path(self.model_type, "decoder", config.decoder_path)
+ config.vocab_path = DownloadModel.get_model_path(self.model_type, "vocab", config.vocab_path)
+ self.table_structure = TableStructureUnitable(asdict(config))
+ else:
+ self.table_structure = TableStructurer(asdict(config))
self.table_matcher = TableMatch()
try:
@@ -64,6 +65,9 @@ def __call__(
if self.model_type == "slanet-plus":
pred_bboxes = self.adapt_slanet_plus(img, pred_bboxes)
pred_html = self.table_matcher(pred_structures, pred_bboxes, dt_boxes, rec_res)
+        # Filter out placeholder bboxes
+ mask = ~np.all(pred_bboxes == 0, axis=1)
+ pred_bboxes = pred_bboxes[mask]
# 避免低版本升级后出现问题,默认不打开
if return_logic_points:
logic_points = self.table_matcher.decode_logic_points(pred_structures)
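The unitable decoder pads cells without a predicted bbox with an all-zero row (see `table_structure_unitable.py` below), and the mask added here drops those rows after matching. A standalone sketch of that filtering step:

```python
import numpy as np

# Four 8-point cell polygons; rows of zeros are placeholders for cells without a bbox
pred_bboxes = np.array(
    [
        [10, 10, 50, 10, 50, 30, 10, 30],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [60, 10, 90, 10, 90, 30, 60, 30],
        [0, 0, 0, 0, 0, 0, 0, 0],
    ],
    dtype=float,
)

mask = ~np.all(pred_bboxes == 0, axis=1)  # keep rows that are not entirely zero
print(pred_bboxes[mask].shape)  # (2, 8)
```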
diff --git a/rapid_table/params.py b/rapid_table/params.py
new file mode 100644
index 0000000..c5756aa
--- /dev/null
+++ b/rapid_table/params.py
@@ -0,0 +1,40 @@
+from dataclasses import dataclass, fields
+from functools import wraps
+from pathlib import Path
+
+from rapid_table.logger import get_logger
+
+root_dir = Path(__file__).resolve().parent
+logger = get_logger("params")
+
+@dataclass
+class BaseConfig:
+ model_type: str = "slanet-plus"
+ model_path: str = str(root_dir / "models" / "slanet-plus.onnx")
+ use_cuda: bool = False
+ device: str = "cpu"
+ encoder_path: str = None
+ decoder_path: str = None
+ vocab_path: str = None
+
+
+def accept_kwargs_as_dataclass(cls):
+ def decorator(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ if len(args) == 2 and isinstance(args[1], cls):
+                # A config instance was passed in directly; call the function as-is
+ return func(*args, **kwargs)
+ else:
+                # Collect the field names defined on cls
+ cls_fields = {field.name for field in fields(cls)}
+                # Keep only the keyword arguments defined on cls
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in cls_fields}
+                # Warn about keyword arguments that are not fields of the dataclass
+                for k in kwargs.keys() - cls_fields:
+                    logger.warning(f"'{k}' is not a valid field in {cls.__name__} and will be ignored.")
+                # Build the config dataclass and call the function with it
+ config = cls(**filtered_kwargs)
+ return func(args[0], config=config)
+ return wrapper
+ return decorator
\ No newline at end of file
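A usage sketch of `accept_kwargs_as_dataclass`: known keyword arguments populate `BaseConfig`, unknown ones are logged and dropped, and a ready-made config instance is passed through untouched (the `Demo` class is illustrative):

```python
from rapid_table.params import BaseConfig, accept_kwargs_as_dataclass


class Demo:  # stand-in for RapidTable
    @accept_kwargs_as_dataclass(BaseConfig)
    def __init__(self, config: BaseConfig):
        self.config = config


d = Demo(model_type="unitable", use_cuda=True, not_a_field=123)
print(d.config.model_type)  # unitable
print(d.config.use_cuda)    # True
# "not_a_field" only triggers a warning and is ignored

# Passing a BaseConfig positionally bypasses the kwargs handling entirely
d2 = Demo(BaseConfig(model_type="slanet-plus"))
print(d2.config.model_type)  # slanet-plus
```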
diff --git a/rapid_table/table_matcher/matcher.py b/rapid_table/table_matcher/matcher.py
index 0453929..bc976ed 100644
--- a/rapid_table/table_matcher/matcher.py
+++ b/rapid_table/table_matcher/matcher.py
@@ -29,7 +29,7 @@ def __call__(self, pred_structures, pred_bboxes, dt_boxes, rec_res):
pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res)
return pred_html
- def match_result(self, dt_boxes, pred_bboxes):
+ def match_result(self, dt_boxes, pred_bboxes, min_iou=0.1 ** 8):
matched = {}
for i, gt_box in enumerate(dt_boxes):
distances = []
@@ -49,6 +49,9 @@ def match_result(self, dt_boxes, pred_bboxes):
sorted_distances = sorted(
sorted_distances, key=lambda item: (item[1], item[0])
)
+            # the best match must have IoU greater than min_iou
+ if sorted_distances[0][1] >= 1 - min_iou:
+ continue
if distances.index(sorted_distances[0]) not in matched.keys():
matched[distances.index(sorted_distances[0])] = [i]
else:
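In `match_result` each candidate is stored as `(l1_distance, 1 - IoU)`, so the new guard skips an OCR box whose best structure cell overlaps it with an IoU of at most `min_iou` (0.1 ** 8 by default). A small numeric sketch of that check (the helper name is illustrative):

```python
def keeps_match(best_one_minus_iou: float, min_iou: float = 0.1 ** 8) -> bool:
    # Mirrors the guard above: the match is kept only when IoU > min_iou
    return best_one_minus_iou < 1 - min_iou


print(keeps_match(0.4))  # True: IoU = 0.6, the OCR box is assigned to a cell
print(keeps_match(1.0))  # False: IoU = 0.0, the box is skipped entirely
```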
diff --git a/rapid_table/table_structure/__init__.py b/rapid_table/table_structure/__init__.py
index a548391..3e638d9 100644
--- a/rapid_table/table_structure/__init__.py
+++ b/rapid_table/table_structure/__init__.py
@@ -12,3 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .table_structure import TableStructurer
+from .table_structure_unitable import TableStructureUnitable
diff --git a/rapid_table_torch/table_structure/table_structure.py b/rapid_table/table_structure/table_structure_unitable.py
similarity index 90%
rename from rapid_table_torch/table_structure/table_structure.py
rename to rapid_table/table_structure/table_structure_unitable.py
index 32f3fa0..fdc42b4 100644
--- a/rapid_table_torch/table_structure/table_structure.py
+++ b/rapid_table/table_structure/table_structure_unitable.py
@@ -1,5 +1,6 @@
import re
import time
+from typing import Dict, Any
import cv2
import numpy as np
@@ -7,7 +8,7 @@
from PIL import Image
from tokenizers import Tokenizer
-from .components import Encoder, GPTFastDecoder
+from .unitable_modules import Encoder, GPTFastDecoder
from torchvision import transforms
IMG_SIZE = 448
@@ -77,8 +78,14 @@
]
-class TableStructurer:
- def __init__(self, encoder_path: str, decoder_path: str, vocab_path: str, device: str):
+class TableStructureUnitable:
+    def __init__(self, config: Dict[str, Any]):
+ # encoder_path: str, decoder_path: str, vocab_path: str, device: str
+ vocab_path = config["vocab_path"]
+ encoder_path = config["encoder_path"]
+ decoder_path = config["decoder_path"]
+ device = config.get("device", "cuda:0") if config["use_cuda"] else "cpu"
+
self.vocab = Tokenizer.from_file(vocab_path)
self.token_white_list = [self.vocab.token_to_id(i) for i in VALID_HTML_BBOX_TOKENS]
self.bbox_token_ids = set([self.vocab.token_to_id(i) for i in BBOX_TOKENS])
@@ -176,8 +183,8 @@ def decode_tokens(self, context):
td_attrs = td_match.group(1).strip()
td_content = td_match.group(2).strip()
if td_attrs:
-                decoded_list.append('<td ' + td_attrs + '>')
                 decoded_list.append('</td>')
else:
@@ -190,7 +197,9 @@ def decode_tokens(self, context):
# 将坐标转换为从左上角开始顺时针到左下角的点的坐标
coords = np.array([xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax])
bbox_coords.append(coords)
-
+ else:
+                # Pad with a placeholder bbox so downstream processing stays aligned
+                bbox_coords.append(np.array([0, 0, 0, 0, 0, 0, 0, 0]))
             decoded_list.append('</td>')
# 将 bbox_coords 转换为 numpy 数组
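`TableStructureUnitable` now takes the same config dict as the ONNX structurer; the device line in its `__init__` boils down to the rule sketched here (the helper name is illustrative):

```python
from typing import Any, Dict


def pick_device(config: Dict[str, Any]) -> str:
    # Same rule as TableStructureUnitable.__init__: CUDA only when use_cuda is set
    return config.get("device", "cuda:0") if config.get("use_cuda") else "cpu"


print(pick_device({"use_cuda": False, "device": "cuda:0"}))  # cpu
print(pick_device({"use_cuda": True, "device": "cuda:1"}))   # cuda:1
```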
diff --git a/rapid_table_torch/table_structure/components.py b/rapid_table/table_structure/unitable_modules.py
similarity index 100%
rename from rapid_table_torch/table_structure/components.py
rename to rapid_table/table_structure/unitable_modules.py
diff --git a/rapid_table/utils.py b/rapid_table/utils.py
index e7135c5..77224ce 100644
--- a/rapid_table/utils.py
+++ b/rapid_table/utils.py
@@ -208,3 +208,4 @@ def save_img(save_path: Union[str, Path], img: np.ndarray):
def save_html(save_path: Union[str, Path], html: str):
with open(save_path, "w", encoding="utf-8") as f:
f.write(html)
+
diff --git a/rapid_table_torch/__init__.py b/rapid_table_torch/__init__.py
deleted file mode 100644
index 86c7b1c..0000000
--- a/rapid_table_torch/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .main import RapidTable
-from .utils import VisTable
diff --git a/rapid_table_torch/main.py b/rapid_table_torch/main.py
deleted file mode 100644
index 5ef3dcc..0000000
--- a/rapid_table_torch/main.py
+++ /dev/null
@@ -1,155 +0,0 @@
-# -*- encoding: utf-8 -*-
-# @Author: SWHL
-# @Contact: liekkaskono@163.com
-import argparse
-import copy
-import importlib
-import os
-import time
-from pathlib import Path
-from typing import List, Union
-
-import cv2
-import numpy as np
-
-from .download_model import DownloadModel
-from .logger import get_logger
-from .table_matcher import TableMatch
-from .table_structure import TableStructurer
-from .utils import LoadImage, VisTable
-
-root_dir = Path(__file__).resolve().parent
-model_dir = os.path.join(root_dir, "models")
-logger = get_logger("rapid_table_torch")
-default_config = os.path.join(root_dir, "config.yaml")
-ROOT_URL = "https://www.modelscope.cn/studio/jockerK/TableRec/resolve/master/models/table_rec/unitable/"
-KEY_TO_MODEL_URL = {
- "unitable": {
- "encoder": f"{ROOT_URL}/encoder.pth",
- "decoder": f"{ROOT_URL}/decoder.pth",
- "vocab": f"{ROOT_URL}/vocab.json",
- }
-}
-
-
-class RapidTable:
- def __init__(self, encoder_path: str = None, decoder_path: str = None, vocab_path: str = None,
- model_type: str = "unitable",
- device: str = "cpu"):
- self.model_type = model_type
- self.load_img = LoadImage()
- encoder_path = self.get_model_path(model_type, "encoder", encoder_path)
- decoder_path = self.get_model_path(model_type, "decoder", decoder_path)
- vocab_path = self.get_model_path(model_type, "vocab", vocab_path)
- self.table_structure = TableStructurer(encoder_path, decoder_path, vocab_path, device)
- self.table_matcher = TableMatch()
- try:
- self.ocr_engine = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
- except ModuleNotFoundError:
- self.ocr_engine = None
-
- def __call__(
- self,
- img_content: Union[str, np.ndarray, bytes, Path],
- ocr_result: List[Union[List[List[float]], str, str]] = None
- ):
- if self.ocr_engine is None and ocr_result is None:
- raise ValueError(
- "One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed."
- )
-
- img = self.load_img(img_content)
-
- s = time.time()
- h, w = img.shape[:2]
- if ocr_result is None:
- ocr_result, _ = self.ocr_engine(img)
- dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w)
-
- pred_structures, pred_bboxes, _ = self.table_structure(copy.deepcopy(img))
- pred_html = self.table_matcher(pred_structures, pred_bboxes, dt_boxes, rec_res)
- logic_points = self.table_matcher.decode_logic_points(pred_structures)
- elapse = time.time() - s
- return pred_html, pred_bboxes, logic_points, elapse
-
- def get_boxes_recs(
- self, ocr_result: List[Union[List[List[float]], str, str]], h: int, w: int
- ):
- dt_boxes, rec_res, scores = list(zip(*ocr_result))
- rec_res = list(zip(rec_res, scores))
-
- r_boxes = []
- for box in dt_boxes:
- box = np.array(box)
- x_min = max(0, box[:, 0].min() - 1)
- x_max = min(w, box[:, 0].max() + 1)
- y_min = max(0, box[:, 1].min() - 1)
- y_max = min(h, box[:, 1].max() + 1)
- box = [x_min, y_min, x_max, y_max]
- r_boxes.append(box)
- dt_boxes = np.array(r_boxes)
- return dt_boxes, rec_res
-
- @staticmethod
- def get_model_path(model_type: str, sub_file_type: str, path: Union[str, Path, None]) -> str:
- if path is not None:
- return path
-
- model_url = KEY_TO_MODEL_URL.get(model_type, {}).get(sub_file_type, None)
- if model_url:
- model_path = DownloadModel.download(model_url)
- return model_path
-
- logger.info(
- "model url is None, using the default download model %s", path
- )
- return path
-def main():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "-v",
- "--vis",
- action="store_true",
- help="Wheter to visualize the layout results.",
- )
- parser.add_argument(
- "-img", "--img_path", type=str, required=True, help="Path to image for layout."
- )
- parser.add_argument(
- "-d", "--device", type=str, default="cpu", help="device to use"
- )
- args = parser.parse_args()
-
- try:
- ocr_engine = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
- except ModuleNotFoundError as exc:
- raise ModuleNotFoundError(
- "Please install the rapidocr_onnxruntime by pip install rapidocr_onnxruntime."
- ) from exc
-
- rapid_table = RapidTable(device=args.device)
-
- img = cv2.imread(args.img_path)
-
- ocr_result, _ = ocr_engine(img)
- table_html_str, table_cell_bboxes, elapse = rapid_table(img, ocr_result)
- print(table_html_str)
-
- viser = VisTable()
- if args.vis:
- img_path = Path(args.img_path)
-
- save_dir = img_path.resolve().parent
- save_html_path = save_dir / f"{Path(img_path).stem}.html"
- save_drawed_path = save_dir / f"vis_{Path(img_path).name}"
- viser(
- img_path,
- table_html_str,
- save_html_path,
- table_cell_bboxes,
- save_drawed_path,
- )
-
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/rapid_table_torch/table_matcher/__init__.py b/rapid_table_torch/table_matcher/__init__.py
deleted file mode 100644
index 9bff7d7..0000000
--- a/rapid_table_torch/table_matcher/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-# -*- encoding: utf-8 -*-
-# @Author: SWHL
-# @Contact: liekkaskono@163.com
-from .matcher import TableMatch
diff --git a/rapid_table_torch/table_matcher/matcher.py b/rapid_table_torch/table_matcher/matcher.py
deleted file mode 100644
index 0453929..0000000
--- a/rapid_table_torch/table_matcher/matcher.py
+++ /dev/null
@@ -1,192 +0,0 @@
-# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# -*- encoding: utf-8 -*-
-import numpy as np
-
-from .utils import compute_iou, distance
-
-
-class TableMatch:
- def __init__(self, filter_ocr_result=True, use_master=False):
- self.filter_ocr_result = filter_ocr_result
- self.use_master = use_master
-
- def __call__(self, pred_structures, pred_bboxes, dt_boxes, rec_res):
- if self.filter_ocr_result:
- dt_boxes, rec_res = self._filter_ocr_result(pred_bboxes, dt_boxes, rec_res)
- matched_index = self.match_result(dt_boxes, pred_bboxes)
- pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res)
- return pred_html
-
- def match_result(self, dt_boxes, pred_bboxes):
- matched = {}
- for i, gt_box in enumerate(dt_boxes):
- distances = []
- for j, pred_box in enumerate(pred_bboxes):
- if len(pred_box) == 8:
- pred_box = [
- np.min(pred_box[0::2]),
- np.min(pred_box[1::2]),
- np.max(pred_box[0::2]),
- np.max(pred_box[1::2]),
- ]
- distances.append(
- (distance(gt_box, pred_box), 1.0 - compute_iou(gt_box, pred_box))
- ) # compute iou and l1 distance
- sorted_distances = distances.copy()
- # select det box by iou and l1 distance
- sorted_distances = sorted(
- sorted_distances, key=lambda item: (item[1], item[0])
- )
- if distances.index(sorted_distances[0]) not in matched.keys():
- matched[distances.index(sorted_distances[0])] = [i]
- else:
- matched[distances.index(sorted_distances[0])].append(i)
- return matched
-
- def get_pred_html(self, pred_structures, matched_index, ocr_contents):
- end_html = []
- td_index = 0
- for tag in pred_structures:
-            if "</td>" not in tag:
- end_html.append(tag)
- continue
-
-            if "<td></td>" == tag:
-                end_html.extend("<td>")
-
- if td_index in matched_index.keys():
- b_with = False
- if (
-                    "<b>" in ocr_contents[matched_index[td_index][0]]
- and len(matched_index[td_index]) > 1
- ):
- b_with = True
-                    end_html.extend("<b>")
-
- for i, td_index_index in enumerate(matched_index[td_index]):
- content = ocr_contents[td_index_index][0]
- if len(matched_index[td_index]) > 1:
- if len(content) == 0:
- continue
-
- if content[0] == " ":
- content = content[1:]
-
-                        if "<b>" in content:
- content = content[3:]
-
-                        if "</b>" in content:
- content = content[:-4]
-
- if len(content) == 0:
- continue
-
- if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
- content += " "
- end_html.extend(content)
-
- if b_with:
-                end_html.extend("</b>")
-
-            if "<td></td>" == tag:
-                end_html.append("</td>")
- else:
- end_html.append(tag)
-
- td_index += 1
-
-        # Filter <thead></thead><tbody></tbody> elements
-        filter_elements = ["<thead>", "</thead>", "<tbody>", "</tbody>"]
- end_html = [v for v in end_html if v not in filter_elements]
- return "".join(end_html), end_html
- def decode_logic_points(self, pred_structures):
- logic_points = []
- current_row = 0
- current_col = 0
- max_rows = 0
- max_cols = 0
- occupied_cells = {} # 用于记录已经被占用的单元格
-
- def is_occupied(row, col):
- return (row, col) in occupied_cells
-
- def mark_occupied(row, col, rowspan, colspan):
- for r in range(row, row + rowspan):
- for c in range(col, col + colspan):
- occupied_cells[(r, c)] = True
-
- i = 0
- while i < len(pred_structures):
- token = pred_structures[i]
-
-            if token == '<tr>':
-                current_col = 0  # 每次遇到 <tr> 时，重置当前列号
-            elif token == '</tr>':
-                current_row += 1  # 行结束，行号增加
-            elif token.startswith('<td'):
-                colspan = 1
-                rowspan = 1
-                j = i
-                if token != '<td></td>':
-                    j += 1
- # 提取 colspan 和 rowspan 属性
- while j < len(pred_structures) and not pred_structures[j].startswith('>'):
- if 'colspan=' in pred_structures[j]:
- colspan = int(pred_structures[j].split('=')[1].strip('"\''))
- elif 'rowspan=' in pred_structures[j]:
- rowspan = int(pred_structures[j].split('=')[1].strip('"\''))
- j += 1
-
- # 跳过已经处理过的属性 token
- i = j
-
- # 找到下一个未被占用的列
- while is_occupied(current_row, current_col):
- current_col += 1
-
- # 计算逻辑坐标
- r_start = current_row
- r_end = current_row + rowspan - 1
- col_start = current_col
- col_end = current_col + colspan - 1
-
- # 记录逻辑坐标
- logic_points.append([r_start, r_end, col_start, col_end])
-
- # 标记占用的单元格
- mark_occupied(r_start, col_start, rowspan, colspan)
-
- # 更新当前列号
- current_col += colspan
-
- # 更新最大行数和列数
- max_rows = max(max_rows, r_end + 1)
- max_cols = max(max_cols, col_end + 1)
-
- i += 1
-
- return logic_points
-
- def _filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res):
- y1 = pred_bboxes[:, 1::2].min()
- new_dt_boxes = []
- new_rec_res = []
-
- for box, rec in zip(dt_boxes, rec_res):
- if np.max(box[1::2]) < y1:
- continue
- new_dt_boxes.append(box)
- new_rec_res.append(rec)
- return new_dt_boxes, new_rec_res
diff --git a/rapid_table_torch/table_matcher/utils.py b/rapid_table_torch/table_matcher/utils.py
deleted file mode 100644
index 3ec8fcc..0000000
--- a/rapid_table_torch/table_matcher/utils.py
+++ /dev/null
@@ -1,248 +0,0 @@
-# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# -*- encoding: utf-8 -*-
-# @Author: SWHL
-# @Contact: liekkaskono@163.com
-import re
-
-
-def deal_isolate_span(thead_part):
- """
- Deal with isolate span cases in this function.
- It causes by wrong prediction in structure recognition model.
-    eg. predict <td rowspan="2"></td> to <td></td> rowspan="2"></b></td>.
- :param thead_part:
- :return:
- """
- # 1. find out isolate span tokens.
-        '<td></td> rowspan="(\d)+" colspan="(\d)+"></b></td>|'
-        '<td></td> colspan="(\d)+" rowspan="(\d)+"></b></td>|'
-        '<td></td> rowspan="(\d)+"></b></td>|'
-        '<td></td> colspan="(\d)+"></b></td>'
- ' | colspan="(\d)+">'
- )
- isolate_iter = re.finditer(isolate_pattern, thead_part)
- isolate_list = [i.group() for i in isolate_iter]
-
- # 2. find out span number, by step 1 results.
- span_pattern = (
- ' rowspan="(\d)+" colspan="(\d)+"|'
- ' colspan="(\d)+" rowspan="(\d)+"|'
- ' rowspan="(\d)+"|'
- ' colspan="(\d)+"'
- )
- corrected_list = []
- for isolate_item in isolate_list:
- span_part = re.search(span_pattern, isolate_item)
- spanStr_in_isolateItem = span_part.group()
- # 3. merge the span number into the span token format string.
- if spanStr_in_isolateItem is not None:
-            corrected_item = "<td{}></td>".format(spanStr_in_isolateItem)
- corrected_list.append(corrected_item)
- else:
- corrected_list.append(None)
-
- # 4. replace original isolated token.
- for corrected_item, isolate_item in zip(corrected_list, isolate_list):
- if corrected_item is not None:
- thead_part = thead_part.replace(isolate_item, corrected_item)
- else:
- pass
- return thead_part
-
-
-def deal_duplicate_bb(thead_part):
- """
-    Deal duplicate <b> or </b> after replace.
-    Keep one <b></b> in a <td></td> token.
- :param thead_part:
- :return:
- """
-    # 1. find out <td></td> in <thead></thead>.
-    td_pattern = (
-        '<td rowspan="(\d)+" colspan="(\d)+">(.+?)</td>|'
-        '<td colspan="(\d)+" rowspan="(\d)+">(.+?)</td>|'
-        '<td rowspan="(\d)+">(.+?)</td>|'
-        '<td colspan="(\d)+">(.+?)</td>|'
-        "<td>(.*?)</td>"
- )
- td_iter = re.finditer(td_pattern, thead_part)
- td_list = [t.group() for t in td_iter]
-
-    # 2. is multiply <b></b> in <td></td> or not?
- new_td_list = []
- for td_item in td_list:
-        if td_item.count("<b>") > 1 or td_item.count("</b>") > 1:
-            # multiply <b></b> in <td></td> case.
-            # 1. remove all <b></b>
-            td_item = td_item.replace("<b>", "").replace("</b>", "")
-            # 2. replace <tb> -> <tb><b>, </tb> -> </b></tb>.
-            td_item = td_item.replace("<td>", "<td><b>").replace("</td>", "</b></td>")
- new_td_list.append(td_item)
- else:
- new_td_list.append(td_item)
-
- # 3. replace original thead part.
- for td_item, new_td_item in zip(td_list, new_td_list):
- thead_part = thead_part.replace(td_item, new_td_item)
- return thead_part
-
-
-def deal_bb(result_token):
- """
-    In our opinion, <b></b> always occurs in <thead></thead> text's context.
-    This function will find out all tokens in <thead></thead> and insert <b></b> by manual.
- :param result_token:
- :return:
- """
-    # find out <thead></thead> parts.
-    thead_pattern = "<thead>(.*?)</thead>"
- if re.search(thead_pattern, result_token) is None:
- return result_token
- thead_part = re.search(thead_pattern, result_token).group()
- origin_thead_part = copy.deepcopy(thead_part)
-
-    # check "rowspan" or "colspan" occur in <thead></thead> parts or not .
-    span_pattern = '<td rowspan="(\d)+" colspan="(\d)+">|<td colspan="(\d)+" rowspan="(\d)+">|<td rowspan="(\d)+">|<td colspan="(\d)+">'
- span_iter = re.finditer(span_pattern, thead_part)
- span_list = [s.group() for s in span_iter]
- has_span_in_head = True if len(span_list) > 0 else False
-
- if not has_span_in_head:
-        # <thead></thead> not include "rowspan" or "colspan" branch 1.
-        # 1. replace <td> to <td><b>, and </td> to </b></td>
-        # 2. it is possible to predict text include <b> or </b> by Text-line recognition,
-        #    so we replace <b><b> to <b>, and </b></b> to </b>
-        thead_part = (
-            thead_part.replace("<td>", "<td><b>")
-            .replace("</td>", "</b></td>")
-            .replace("<b><b>", "<b>")
-            .replace("</b></b>", "</b>")
-        )
- else:
-        # <thead></thead> include "rowspan" or "colspan" branch 2.
-        # Firstly, we deal rowspan or colspan cases.
-        # 1. replace > to ><b>
-        # 2. replace </td> to </b></td>
-        # 3. it is possible to predict text include <b> or </b> by Text-line recognition,
-        #    so we replace <b><b> to <b>, and </b></b> to </b>
-
- # Secondly, deal ordinary cases like branch 1
-
-        # replace ">" to "><b>"
-        replaced_span_list = []
-        for sp in span_list:
-            replaced_span_list.append(sp.replace(">", "><b>"))
- for sp, rsp in zip(span_list, replaced_span_list):
- thead_part = thead_part.replace(sp, rsp)
-
-        # replace "</td>" to "</b></td>"
-        thead_part = thead_part.replace("</td>", "</b></td>")
-
-        # remove duplicated <b> by re.sub
-        mb_pattern = "(<b>)+"
-        single_b_string = "<b>"
-        thead_part = re.sub(mb_pattern, single_b_string, thead_part)
-
-        mgb_pattern = "(</b>)+"
-        single_gb_string = "</b>"
- thead_part = re.sub(mgb_pattern, single_gb_string, thead_part)
-
- # ordinary cases like branch 1
-        thead_part = thead_part.replace("<td>", "<td><b>").replace("<b><b>", "<b>")
-
-    # convert <tb><b></b></tb> back to <tb></tb>, empty cell has no <b></b>.
-    # but space cell(<tb> </tb>)  is suitable for <td><b> </b></td>
-    thead_part = thead_part.replace("<td><b></b></td>", "<td></td>")
-    # deal with duplicated <b></b>
- thead_part = deal_duplicate_bb(thead_part)
- # deal with isolate span tokens, which causes by wrong predict by structure prediction.
- # eg.PMC5994107_011_00.png
- thead_part = deal_isolate_span(thead_part)
- # replace original result with new thead part.
- result_token = result_token.replace(origin_thead_part, thead_part)
- return result_token
-
-
-def deal_eb_token(master_token):
- """
-    post process with <eb></eb>, <eb1></eb1>, ...
-    emptyBboxTokenDict = {
-        "[]": '<eb></eb>',
-        "[' ']": '<eb1></eb1>',
-        "['<b>', ' ', '</b>']": '<eb2></eb2>',
-        "['\\u2028', '\\u2028']": '<eb3></eb3>',
-        "['<sup>', ' ', '</sup>']": '<eb4></eb4>',
-        "['<b>', '</b>']": '<eb5></eb5>',
-        "['<i>', ' ', '</i>']": '<eb6></eb6>',
-        "['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
-        "['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
-        "['<i>', '</i>']": '<eb9></eb9>',
-        "['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']": '<eb10></eb10>',
- }
- :param master_token:
- :return:
- """
-    master_token = master_token.replace("<eb></eb>", "<td></td>")
-    master_token = master_token.replace("<eb1></eb1>", "<td> </td>")
-    master_token = master_token.replace("<eb2></eb2>", "<td><b> </b></td>")
-    master_token = master_token.replace("<eb3></eb3>", "<td>\u2028\u2028</td>")
-    master_token = master_token.replace("<eb4></eb4>", "<td><sup> </sup></td>")
-    master_token = master_token.replace("<eb5></eb5>", "<td><b></b></td>")
-    master_token = master_token.replace("<eb6></eb6>", "<td><i> </i></td>")
-    master_token = master_token.replace("<eb7></eb7>", "<td><b><i></i></b></td>")
-    master_token = master_token.replace("<eb8></eb8>", "<td><b><i> </i></b></td>")
-    master_token = master_token.replace("<eb9></eb9>", "<td><i></i></td>")
-    master_token = master_token.replace(
-        "<eb10></eb10>", "<td><b> \u2028 \u2028 </b></td>"
-    )
- return master_token
-
-
-def distance(box_1, box_2):
- x1, y1, x2, y2 = box_1
- x3, y3, x4, y4 = box_2
- dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
- dis_2 = abs(x3 - x1) + abs(y3 - y1)
- dis_3 = abs(x4 - x2) + abs(y4 - y2)
- return dis + min(dis_2, dis_3)
-
-
-def compute_iou(rec1, rec2):
- """
- computing IoU
- :param rec1: (y0, x0, y1, x1), which reflects
- (top, left, bottom, right)
- :param rec2: (y0, x0, y1, x1)
- :return: scala value of IoU
- """
- # computing area of each rectangles
- S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
- S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
-
- # computing the sum_area
- sum_area = S_rec1 + S_rec2
-
- # find the each edge of intersect rectangle
- left_line = max(rec1[1], rec2[1])
- right_line = min(rec1[3], rec2[3])
- top_line = max(rec1[0], rec2[0])
- bottom_line = min(rec1[2], rec2[2])
-
- # judge if there is an intersect
- if left_line >= right_line or top_line >= bottom_line:
- return 0.0
- else:
- intersect = (right_line - left_line) * (bottom_line - top_line)
- return (intersect / (sum_area - intersect)) * 1.0
diff --git a/rapid_table_torch/table_structure/__init__.py b/rapid_table_torch/table_structure/__init__.py
deleted file mode 100644
index 4582f3d..0000000
--- a/rapid_table_torch/table_structure/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .table_structure import TableStructurer
diff --git a/rapid_table_torch/table_structure/utils.py b/rapid_table_torch/table_structure/utils.py
deleted file mode 100644
index f2d76df..0000000
--- a/rapid_table_torch/table_structure/utils.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# -*- encoding: utf-8 -*-
-# @Author: SWHL
-# @Contact: liekkaskono@163.com
-
-def trans_char_ocr_res(ocr_res):
- word_result = []
- for res in ocr_res:
- score = res[2]
- for word_box, word in zip(res[3], res[4]):
- word_res = []
- word_res.append(word_box)
- word_res.append(word)
- word_res.append(score)
- word_result.append(word_res)
- return word_result
diff --git a/rapid_table_torch/utils.py b/rapid_table_torch/utils.py
deleted file mode 100644
index e7135c5..0000000
--- a/rapid_table_torch/utils.py
+++ /dev/null
@@ -1,210 +0,0 @@
-# -*- encoding: utf-8 -*-
-# @Author: SWHL
-# @Contact: liekkaskono@163.com
-import os
-from io import BytesIO
-from pathlib import Path
-from typing import Optional, Union, List
-
-import cv2
-import numpy as np
-from PIL import Image, UnidentifiedImageError
-
-InputType = Union[str, np.ndarray, bytes, Path]
-
-
-class LoadImage:
- def __init__(
- self,
- ):
- pass
-
- def __call__(self, img: InputType) -> np.ndarray:
- if not isinstance(img, InputType.__args__):
- raise LoadImageError(
- f"The img type {type(img)} does not in {InputType.__args__}"
- )
-
- img = self.load_img(img)
-
- if img.ndim == 2:
- return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-
- if img.ndim == 3 and img.shape[2] == 4:
- return self.cvt_four_to_three(img)
-
- return img
-
- def load_img(self, img: InputType) -> np.ndarray:
- if isinstance(img, (str, Path)):
- self.verify_exist(img)
- try:
- img = np.array(Image.open(img))
- img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
- except UnidentifiedImageError as e:
- raise LoadImageError(f"cannot identify image file {img}") from e
- return img
-
- if isinstance(img, bytes):
- img = np.array(Image.open(BytesIO(img)))
- img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
- return img
-
- if isinstance(img, np.ndarray):
- return img
-
- raise LoadImageError(f"{type(img)} is not supported!")
-
- @staticmethod
- def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
- """RGBA → RGB"""
- r, g, b, a = cv2.split(img)
- new_img = cv2.merge((b, g, r))
-
- not_a = cv2.bitwise_not(a)
- not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
-
- new_img = cv2.bitwise_and(new_img, new_img, mask=a)
- new_img = cv2.add(new_img, not_a)
- return new_img
-
- @staticmethod
- def verify_exist(file_path: Union[str, Path]):
- if not Path(file_path).exists():
- raise LoadImageError(f"{file_path} does not exist.")
-
-
-class LoadImageError(Exception):
- pass
-
-
-class VisTable:
- def __init__(
- self,
- ):
- self.load_img = LoadImage()
-
- def __call__(
- self,
- img_path: Union[str, Path],
- table_html_str: str,
- save_html_path: Optional[str] = None,
- table_cell_bboxes: Optional[np.ndarray] = None,
- save_drawed_path: Optional[str] = None,
- logic_points: List[List[float]] = None,
- save_logic_path: Optional[str] = None,
- ) -> None:
- if save_html_path:
- html_with_border = self.insert_border_style(table_html_str)
- self.save_html(save_html_path, html_with_border)
-
- if table_cell_bboxes is None:
- return None
-
- img = self.load_img(img_path)
-
- dims_bboxes = table_cell_bboxes.shape[1]
- if dims_bboxes == 4:
- drawed_img = self.draw_rectangle(img, table_cell_bboxes)
- elif dims_bboxes == 8:
- drawed_img = self.draw_polylines(img, table_cell_bboxes)
- else:
- raise ValueError("Shape of table bounding boxes is not between in 4 or 8.")
-
- if save_drawed_path:
- self.save_img(save_drawed_path, drawed_img)
- if save_logic_path and logic_points:
- polygons = [[box[0],box[1], box[4], box[5]] for box in table_cell_bboxes]
- self.plot_rec_box_with_logic_info(img_path, save_logic_path, logic_points, polygons)
- return drawed_img
-
- def insert_border_style(self, table_html_str: str):
-        style_res = f"""<meta charset="UTF-8"><style>
-        table {{
-            border-collapse: collapse;
-            width: 100%;
-        }}
-        th, td {{
-            border: 1px solid black;
-            padding: 8px;
-            text-align: center;
-        }}
-        th {{
-            background-color: #f2f2f2;
-        }}
-                    </style>"""
-        prefix_table, suffix_table = table_html_str.split("<body>")
- html_with_border = f"{prefix_table}{style_res}{suffix_table}"
- return html_with_border
-
- def plot_rec_box_with_logic_info(self, img_path, output_path, logic_points, sorted_polygons):
- """
- :param img_path
- :param output_path
- :param logic_points: [row_start,row_end,col_start,col_end]
- :param sorted_polygons: [xmin,ymin,xmax,ymax]
- :return:
- """
- # 读取原图
- img = cv2.imread(img_path)
- img = cv2.copyMakeBorder(
- img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255]
- )
- # 绘制 polygons 矩形
- for idx, polygon in enumerate(sorted_polygons):
- x0, y0, x1, y1 = polygon[0], polygon[1], polygon[2], polygon[3]
- x0 = round(x0)
- y0 = round(y0)
- x1 = round(x1)
- y1 = round(y1)
- cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1)
- # 增大字体大小和线宽
- font_scale = 0.9 # 原先是0.5
- thickness = 1 # 原先是1
- logic_point = logic_points[idx]
- cv2.putText(
- img,
- f"row: {logic_point[0]}-{logic_point[1]}",
- (x0 + 3, y0 + 8),
- cv2.FONT_HERSHEY_PLAIN,
- font_scale,
- (0, 0, 255),
- thickness,
- )
- cv2.putText(
- img,
- f"col: {logic_point[2]}-{logic_point[3]}",
- (x0 + 3, y0 + 18),
- cv2.FONT_HERSHEY_PLAIN,
- font_scale,
- (0, 0, 255),
- thickness,
- )
- os.makedirs(os.path.dirname(output_path), exist_ok=True)
- # 保存绘制后的图像
- self.save_img(output_path, img)
-
- @staticmethod
- def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray:
- img_copy = img.copy()
- for box in boxes.astype(int):
- x1, y1, x2, y2 = box
- cv2.rectangle(img_copy, (x1, y1), (x2, y2), (255, 0, 0), 2)
- return img_copy
-
- @staticmethod
- def draw_polylines(img: np.ndarray, points) -> np.ndarray:
- img_copy = img.copy()
- for point in points.astype(int):
- point = point.reshape(4, 2)
- cv2.polylines(img_copy, [point.astype(int)], True, (255, 0, 0), 2)
- return img_copy
-
- @staticmethod
- def save_img(save_path: Union[str, Path], img: np.ndarray):
- cv2.imwrite(str(save_path), img)
-
- @staticmethod
- def save_html(save_path: Union[str, Path], html: str):
- with open(save_path, "w", encoding="utf-8") as f:
- f.write(html)
diff --git a/requirements_torch.txt b/requirements_torch.txt
deleted file mode 100644
index 4ba3572..0000000
--- a/requirements_torch.txt
+++ /dev/null
@@ -1,8 +0,0 @@
-onnxruntime>=1.7.0
-opencv_python>=4.5.1.48
-numpy>=1.21.6,<2
-Pillow
-requests
-torch
-torchvision
-tokenizers
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 0539473..0489760 100644
--- a/setup.py
+++ b/setup.py
@@ -66,4 +66,5 @@ def get_readme():
],
python_requires=">=3.6,<3.13",
entry_points={"console_scripts": [f"{MODULE_NAME}={MODULE_NAME}.main:main"]},
+ extras_require={"torch": ["torch", "torchvision", "tokenizers"]},
)
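With the `torch` extra declared above, the unitable dependencies stay optional. One way a caller might guard against a missing extra before requesting the torch model (the check is illustrative, not part of the package):

```python
import importlib.util

from rapid_table import RapidTable

if importlib.util.find_spec("torch") is not None:
    table_engine = RapidTable(model_type="unitable")  # pip install rapid_table[torch]
else:
    table_engine = RapidTable()  # fall back to the default slanet-plus ONNX model
```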
diff --git a/setup_torch.py b/setup_torch.py
deleted file mode 100644
index b332a20..0000000
--- a/setup_torch.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# -*- encoding: utf-8 -*-
-# @Author: SWHL
-# @Contact: liekkaskono@163.com
-import sys
-from pathlib import Path
-
-import setuptools
-from get_pypi_latest_version import GetPyPiLatestVersion
-
-
-def get_readme():
- root_dir = Path(__file__).resolve().parent
- readme_path = str(root_dir / "docs" / "doc_whl_rapid_table.md")
- with open(readme_path, "r", encoding="utf-8") as f:
- readme = f.read()
- return readme
-
-
-MODULE_NAME = "rapid_table_torch"
-obtainer = GetPyPiLatestVersion()
-try:
- latest_version = obtainer(MODULE_NAME)
-except Exception:
- latest_version = "0.0.0"
-VERSION_NUM = obtainer.version_add_one(latest_version)
-
-if len(sys.argv) > 2:
- match_str = " ".join(sys.argv[2:])
- matched_versions = obtainer.extract_version(match_str)
- if matched_versions:
- VERSION_NUM = matched_versions
-sys.argv = sys.argv[:2]
-
-setuptools.setup(
- name=MODULE_NAME,
- version=VERSION_NUM,
- platforms="Any",
- long_description=get_readme(),
- long_description_content_type="text/markdown",
- description="Tools for parsing table structures based pytorch.",
- author="SWHL",
- author_email="liekkaskono@163.com",
- url="https://github.com/RapidAI/RapidTable",
- license="Apache-2.0",
- include_package_data=True,
- install_requires=[
- "opencv_python>=4.5.1.48",
- "numpy>=1.21.6",
- "Pillow",
- "torch",
- "torchvision",
- "tokenizers"
- ],
- packages=[
- MODULE_NAME,
- f"{MODULE_NAME}.models",
- f"{MODULE_NAME}.table_matcher",
- f"{MODULE_NAME}.table_structure",
- ],
- package_data={"": [".gitkeep"]},
- keywords=["unitable,table,rapidocr,rapid_table"],
- classifiers=[
- "Programming Language :: Python :: 3.10",
- "Programming Language :: Python :: 3.11",
- "Programming Language :: Python :: 3.12",
- ],
- python_requires=">=3.10,<3.13",
- entry_points={"console_scripts": [f"{MODULE_NAME}={MODULE_NAME}.main:main"]},
-)
diff --git a/tests/test_table_torch.py b/tests/test_table_torch.py
index 640257c..1eb4505 100644
--- a/tests/test_table_torch.py
+++ b/tests/test_table_torch.py
@@ -11,10 +11,10 @@
sys.path.append(str(root_dir))
-from rapid_table_torch import RapidTable
+from rapid_table import RapidTable
ocr_engine = RapidOCR()
-table_engine = RapidTable()
+table_engine = RapidTable(model_type="unitable")
test_file_dir = cur_dir / "test_files"
img_path = str(test_file_dir / "table.jpg")
@@ -22,14 +22,14 @@
def test_ocr_input():
ocr_res, _ = ocr_engine(img_path)
- table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, ocr_res)
+ table_html_str, table_cell_bboxes, elapse = table_engine(img_path, ocr_res)
     assert table_html_str.count("<tr>") == 16
def test_input_ocr_none():
- table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path)
+ table_html_str, table_cell_bboxes, elapse = table_engine(img_path)
     assert table_html_str.count("<tr>") == 16
def test_logic_points_out():
- table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path)
+ table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, return_logic_points=True)
assert len(table_cell_bboxes) == len(logic_points)
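Since both backends now live in one package, the two test modules could in principle share a single parametrized test; a hedged sketch of that idea (not part of this change):

```python
import pytest
from rapidocr_onnxruntime import RapidOCR

from rapid_table import RapidTable

img_path = "tests/test_files/table.jpg"
ocr_engine = RapidOCR()


@pytest.mark.parametrize("model_type", ["slanet-plus", "unitable"])
def test_both_backends(model_type):
    table_engine = RapidTable(model_type=model_type)
    ocr_res, _ = ocr_engine(img_path)
    table_html_str, table_cell_bboxes, elapse = table_engine(img_path, ocr_res)
    assert table_html_str.count("<tr>") == 16
```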