From 9ea4a74f52de788e2e4aed264cfe3c23975c78db Mon Sep 17 00:00:00 2001 From: Joker1212 <519548295@qq.com> Date: Mon, 28 Oct 2024 18:01:35 +0800 Subject: [PATCH] feat: optim param use for table cls --- table_cls/main.py | 17 +++++++++-------- tests/test_table_cls.py | 2 +- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/table_cls/main.py b/table_cls/main.py index d6ddb66..ca7ab4c 100644 --- a/table_cls/main.py +++ b/table_cls/main.py @@ -13,11 +13,12 @@ class TableCls: - def __init__(self, model="yolo"): - if model == "yolo": - self.table_engine = YoloCls() + def __init__(self, model_type="yolo", model_path=yolo_cls_model_path): + if model_type == "yolo": + self.table_engine = YoloCls(model_path) else: - self.table_engine = QanythingCls() + model_path = q_cls_model_path + self.table_engine = QanythingCls(model_path) self.load_img = LoadImage() def __call__(self, content: InputType): @@ -30,8 +31,8 @@ def __call__(self, content: InputType): class QanythingCls: - def __init__(self): - self.table_cls = OrtInferSession(q_cls_model_path) + def __init__(self, model_path): + self.table_cls = OrtInferSession(model_path) self.inp_h = 224 self.inp_w = 224 self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32) @@ -60,8 +61,8 @@ def __call__(self, img): class YoloCls: - def __init__(self): - self.table_cls = OrtInferSession(yolo_cls_model_path) + def __init__(self, model_path): + self.table_cls = OrtInferSession(model_path) self.cls = {0: "wireless", 1: "wired"} def preprocess(self, img): diff --git a/tests/test_table_cls.py b/tests/test_table_cls.py index d9d1813..b5c5611 100644 --- a/tests/test_table_cls.py +++ b/tests/test_table_cls.py @@ -15,7 +15,7 @@ @pytest.mark.parametrize( "img_path, expected", - [("wired_table.png", "wired"), ("lineless_table.png", "wireless")], + [("wired_table.jpg", "wired"), ("lineless_table.png", "wireless")], ) def test_input_normal(img_path, expected): img_path = test_file_dir / img_path