diff --git a/api/rapidocr_api/main.py b/api/rapidocr_api/main.py index cb2ca0cb0..2de71c2af 100644 --- a/api/rapidocr_api/main.py +++ b/api/rapidocr_api/main.py @@ -1,22 +1,21 @@ # -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com - import argparse import base64 +import importlib.util import io import os import sys from pathlib import Path from typing import Dict -import importlib.util import numpy as np import uvicorn from fastapi import FastAPI, Form, UploadFile from PIL import Image -if importlib.util.find_spec("rapidocr_runtime"): +if importlib.util.find_spec("rapidocr_onnxruntime"): from rapidocr_onnxruntime import RapidOCR elif importlib.util.find_spec("rapidocr_paddle"): from rapidocr_paddle import RapidOCR @@ -24,7 +23,7 @@ from rapidocr_openvino import RapidOCR else: raise ImportError( - "Pleas install one of [rapidocr-runtime,rapidocr-paddle,rapidocr-openvino]" + "Please install one of [rapidocr_onnxruntime,rapidocr-paddle,rapidocr-openvino]" ) sys.path.append(str(Path(__file__).resolve().parent.parent)) @@ -32,16 +31,18 @@ class OCRAPIUtils: def __init__(self) -> None: - # 从环境变量中读取参数 det_model_path = os.getenv("det_model_path", None) cls_model_path = os.getenv("cls_model_path", None) rec_model_path = os.getenv("rec_model_path", None) - self.ocr = RapidOCR( - det_model_path=det_model_path, - cls_model_path=cls_model_path, - rec_model_path=rec_model_path, - ) + if det_model_path is None or cls_model_path is None or rec_model_path is None: + self.ocr = RapidOCR() + else: + self.ocr = RapidOCR( + det_model_path=det_model_path, + cls_model_path=cls_model_path, + rec_model_path=rec_model_path, + ) def __call__( self, img: Image.Image, use_det=None, use_cls=None, use_rec=None, **kwargs @@ -54,7 +55,6 @@ def __call__( if not ocr_res: return {} - # 转换为字典格式: 兼容所有参数情况 out_dict = {} for i, dats in enumerate(ocr_res): values = {} diff --git a/api/setup.py b/api/setup.py index 61af9ec1c..0f852fa8d 100644 --- a/api/setup.py +++ b/api/setup.py @@ -33,7 +33,7 @@ def get_readme(): latest_version = obtainer(MODULE_NAME) except ValueError: latest_version = "0.0.1" -VERSION_NUM = obtainer.version_add_one(latest_version) +VERSION_NUM = obtainer.version_add_one(latest_version, add_patch=True) if len(sys.argv) > 2: match_str = " ".join(sys.argv[2:]) @@ -56,11 +56,6 @@ def get_readme(): license="Apache-2.0", include_package_data=True, install_requires=read_txt("requirements.txt"), - extras_require={ - 'onnx': ['rapidocr-onnxruntime'], - 'paddle': ['rapidocr-paddle'], - 'openvino': ['rapidocr-openvino'], - }, packages=[MODULE_NAME], package_data={"": ["*.ico", "*.css", "*.js", "*.html"]}, keywords=[ @@ -81,4 +76,9 @@ def get_readme(): f"{MODULE_NAME}={MODULE_NAME}.main:main", ], }, + extras_require={ + "onnx": ["rapidocr-onnxruntime"], + "paddle": ["rapidocr-paddle"], + "openvino": ["rapidocr-openvino"], + }, )