diff --git a/python/demo.py b/python/demo.py index 5e5d6f702..cd5cf1687 100644 --- a/python/demo.py +++ b/python/demo.py @@ -12,7 +12,7 @@ engine = RapidOCR() vis = VisRes() -image_path = "tests/test_files/test_without_det.jpg" +image_path = "tests/test_files/black_font_color_transparent.png" with open(image_path, "rb") as f: img = f.read() diff --git a/python/rapidocr_onnxruntime/utils.py b/python/rapidocr_onnxruntime/utils.py index 778066dd7..181dd91b5 100644 --- a/python/rapidocr_onnxruntime/utils.py +++ b/python/rapidocr_onnxruntime/utils.py @@ -208,7 +208,12 @@ def cvt_four_to_three(img: np.ndarray) -> np.ndarray: not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) new_img = cv2.bitwise_and(new_img, new_img, mask=a) - new_img = cv2.add(new_img, not_a) + + mean_color = np.mean(new_img) + if mean_color <= 0.0: + new_img = cv2.add(new_img, not_a) + else: + new_img = cv2.bitwise_not(new_img) return new_img @staticmethod diff --git a/python/tests/test_files/black_font_color_transparent.png b/python/tests/test_files/black_font_color_transparent.png new file mode 100644 index 000000000..dfd20a14d Binary files /dev/null and b/python/tests/test_files/black_font_color_transparent.png differ diff --git a/python/tests/test_files/white_font_color_transparent.png b/python/tests/test_files/white_font_color_transparent.png new file mode 100644 index 000000000..9f9f095eb Binary files /dev/null and b/python/tests/test_files/white_font_color_transparent.png differ diff --git a/python/tests/test_ort.py b/python/tests/test_ort.py index ca0e993f8..45e8971fb 100644 --- a/python/tests/test_ort.py +++ b/python/tests/test_ort.py @@ -20,6 +20,25 @@ package_name = "rapidocr_onnxruntime" +@pytest.mark.parametrize( + "img_name,gt", + [ + ( + "black_font_color_transparent.png", + "我是中国人", + ), + ( + "white_font_color_transparent.png", + "我是中国人", + ), + ], +) +def test_transparent_img(img_name: str, gt: str): + img_path = tests_dir / img_name + result, _ = engine(img_path) + assert result[0][1] == gt + + @pytest.mark.parametrize( "img_name,gt_len,gt_first_len", [