Skip to content

Commit

Permalink
feat: optim param use for table cls
Browse files Browse the repository at this point in the history
  • Loading branch information
Joker1212 committed Oct 28, 2024
1 parent 490f328 commit 9ea4a74
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
17 changes: 9 additions & 8 deletions table_cls/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@


class TableCls:
def __init__(self, model="yolo"):
if model == "yolo":
self.table_engine = YoloCls()
def __init__(self, model_type="yolo", model_path=yolo_cls_model_path):
if model_type == "yolo":
self.table_engine = YoloCls(model_path)
else:
self.table_engine = QanythingCls()
model_path = q_cls_model_path
self.table_engine = QanythingCls(model_path)
self.load_img = LoadImage()

def __call__(self, content: InputType):
Expand All @@ -30,8 +31,8 @@ def __call__(self, content: InputType):


class QanythingCls:
def __init__(self):
self.table_cls = OrtInferSession(q_cls_model_path)
def __init__(self, model_path):
self.table_cls = OrtInferSession(model_path)
self.inp_h = 224
self.inp_w = 224
self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
Expand Down Expand Up @@ -60,8 +61,8 @@ def __call__(self, img):


class YoloCls:
def __init__(self):
self.table_cls = OrtInferSession(yolo_cls_model_path)
def __init__(self, model_path):
self.table_cls = OrtInferSession(model_path)
self.cls = {0: "wireless", 1: "wired"}

def preprocess(self, img):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_table_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@pytest.mark.parametrize(
"img_path, expected",
[("wired_table.png", "wired"), ("lineless_table.png", "wireless")],
[("wired_table.jpg", "wired"), ("lineless_table.png", "wireless")],
)
def test_input_normal(img_path, expected):
img_path = test_file_dir / img_path
Expand Down

0 comments on commit 9ea4a74

Please sign in to comment.