Skip to content

Commit

Permalink
fix: adapt unitable no content label token
Browse files Browse the repository at this point in the history
  • Loading branch information
Joker1212 committed Dec 30, 2024
1 parent 2e51412 commit 4a70f60
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
3 changes: 3 additions & 0 deletions rapid_table/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def __call__(
if self.model_type == "slanet-plus":
pred_bboxes = self.adapt_slanet_plus(img, pred_bboxes)
pred_html = self.table_matcher(pred_structures, pred_bboxes, dt_boxes, rec_res)
# 过滤掉占位的bbox
mask = ~np.all(pred_bboxes == 0, axis=1)
pred_bboxes = pred_bboxes[mask]
# 避免低版本升级后出现问题,默认不打开
if return_logic_points:
logic_points = self.table_matcher.decode_logic_points(pred_structures)
Expand Down
5 changes: 4 additions & 1 deletion rapid_table/table_matcher/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __call__(self, pred_structures, pred_bboxes, dt_boxes, rec_res):
pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res)
return pred_html

def match_result(self, dt_boxes, pred_bboxes):
def match_result(self, dt_boxes, pred_bboxes, min_iou=0.1 ** 8):
matched = {}
for i, gt_box in enumerate(dt_boxes):
distances = []
Expand All @@ -49,6 +49,9 @@ def match_result(self, dt_boxes, pred_bboxes):
sorted_distances = sorted(
sorted_distances, key=lambda item: (item[1], item[0])
)
# The best candidate must have IoU > min_iou (second sort key < 1 - min_iou); otherwise leave this box unmatched
if sorted_distances[0][1] >= 1 - min_iou:
continue
if distances.index(sorted_distances[0]) not in matched.keys():
matched[distances.index(sorted_distances[0])] = [i]
else:
Expand Down
8 changes: 5 additions & 3 deletions rapid_table/table_structure/table_structure_unitable.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@ def decode_tokens(self, context):
td_attrs = td_match.group(1).strip()
td_content = td_match.group(2).strip()
if td_attrs:
decoded_list.append('<td ')
decoded_list.append(td_attrs)
decoded_list.append('<td')
decoded_list.append(" " + td_attrs)
decoded_list.append('>')
decoded_list.append('</td>')
else:
Expand All @@ -197,7 +197,9 @@ def decode_tokens(self, context):
# 将坐标转换为从左上角开始顺时针到左下角的点的坐标
coords = np.array([xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax])
bbox_coords.append(coords)

else:
# 填充占位的bbox,保证后续流程统一
bbox_coords.append(np.array([0, 0, 0,0,0,0, 0, 0]))
decoded_list.append('</tr>')

# 将 bbox_coords 转换为 numpy 数组
Expand Down

0 comments on commit 4a70f60

Please sign in to comment.