From 4a70f601ee1120841fc0661b0959305f9df1914d Mon Sep 17 00:00:00 2001 From: Jokcer <519548295@qq.com> Date: Mon, 30 Dec 2024 22:36:56 +0800 Subject: [PATCH] fix: adapt unitable no content label token --- rapid_table/main.py | 3 +++ rapid_table/table_matcher/matcher.py | 5 ++++- rapid_table/table_structure/table_structure_unitable.py | 8 +++++--- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/rapid_table/main.py b/rapid_table/main.py index 202f6e0..b412790 100644 --- a/rapid_table/main.py +++ b/rapid_table/main.py @@ -65,6 +65,9 @@ def __call__( if self.model_type == "slanet-plus": pred_bboxes = self.adapt_slanet_plus(img, pred_bboxes) pred_html = self.table_matcher(pred_structures, pred_bboxes, dt_boxes, rec_res) + # 过滤掉占位的bbox + mask = ~np.all(pred_bboxes == 0, axis=1) + pred_bboxes = pred_bboxes[mask] # 避免低版本升级后出现问题,默认不打开 if return_logic_points: logic_points = self.table_matcher.decode_logic_points(pred_structures) diff --git a/rapid_table/table_matcher/matcher.py b/rapid_table/table_matcher/matcher.py index 0453929..bc976ed 100644 --- a/rapid_table/table_matcher/matcher.py +++ b/rapid_table/table_matcher/matcher.py @@ -29,7 +29,7 @@ def __call__(self, pred_structures, pred_bboxes, dt_boxes, rec_res): pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res) return pred_html - def match_result(self, dt_boxes, pred_bboxes): + def match_result(self, dt_boxes, pred_bboxes, min_iou=0.1 ** 8): matched = {} for i, gt_box in enumerate(dt_boxes): distances = [] @@ -49,6 +49,9 @@ def match_result(self, dt_boxes, pred_bboxes): sorted_distances = sorted( sorted_distances, key=lambda item: (item[1], item[0]) ) + # must > min_iou + if sorted_distances[0][1] >= 1 - min_iou: + continue if distances.index(sorted_distances[0]) not in matched.keys(): matched[distances.index(sorted_distances[0])] = [i] else: diff --git a/rapid_table/table_structure/table_structure_unitable.py b/rapid_table/table_structure/table_structure_unitable.py index da9ed7f..fdc42b4 100644 --- a/rapid_table/table_structure/table_structure_unitable.py +++ b/rapid_table/table_structure/table_structure_unitable.py @@ -183,8 +183,8 @@ def decode_tokens(self, context): td_attrs = td_match.group(1).strip() td_content = td_match.group(2).strip() if td_attrs: - decoded_list.append('') decoded_list.append('') else: @@ -197,7 +197,9 @@ def decode_tokens(self, context): # 将坐标转换为从左上角开始顺时针到左下角的点的坐标 coords = np.array([xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax]) bbox_coords.append(coords) - + else: + # 填充占位的bbox,保证后续流程统一 + bbox_coords.append(np.array([0, 0, 0,0,0,0, 0, 0])) decoded_list.append('') # 将 bbox_coords 转换为 numpy 数组