Skip to content

Commit

Permalink
Add a custom u2net session (#482)
Browse files Browse the repository at this point in the history
  • Loading branch information
danielgatis authored Jun 28, 2023
1 parent 8c02c27 commit c0b08f8
Show file tree
Hide file tree
Showing 15 changed files with 83 additions and 37 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ Passing extras parameters
rembg i -m sam -x '{"input_labels": [1], "input_points": [[100,100]]}' path/to/input.png path/to/output.png
```

```
rembg i -m u2net_custom -x '{"model_path": "~/.u2net/u2net.onnx"}' path/to/input.png path/to/output.png
```

### rembg `p`

Used when input and output are folders.
Expand Down
2 changes: 1 addition & 1 deletion rembg/commands/b_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def rs_command(
except Exception:
pass

session = new_session(model)
session = new_session(model, **kwargs)
bytes_per_img = image_width * image_height * 3

if output_specifier:
Expand Down
2 changes: 1 addition & 1 deletion rembg/commands/i_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,4 @@ def i_command(model: str, extras: str, input: IO, output: IO, **kwargs) -> None:
except Exception:
pass

output.write(remove(input.read(), session=new_session(model), **kwargs))
output.write(remove(input.read(), session=new_session(model, **kwargs), **kwargs))
2 changes: 1 addition & 1 deletion rembg/commands/p_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def p_command(
except Exception:
pass

session = new_session(model)
session = new_session(model, **kwargs)

def process(each_input: pathlib.Path) -> None:
try:
Expand Down
29 changes: 13 additions & 16 deletions rembg/commands/s_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,9 @@ def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
return Response(
remove(
content,
session=sessions.setdefault(commons.model, new_session(commons.model)),
session=sessions.setdefault(
commons.model, new_session(commons.model, **kwargs)
),
alpha_matting=commons.a,
alpha_matting_foreground_threshold=commons.af,
alpha_matting_background_threshold=commons.ab,
Expand Down Expand Up @@ -245,32 +247,27 @@ async def post_index(
return await asyncify(im_without_bg)(file, commons) # type: ignore

def gr_app(app):
def inference(input_path, model):
def inference(input_path, model, cmd_args):
output_path = "output.png"

kwargs = {}
if cmd_args:
kwargs.update(json.loads(cmd_args))
kwargs["session"] = new_session(model, **kwargs)

with open(input_path, "rb") as i:
with open(output_path, "wb") as o:
input = i.read()
output = remove(input, session=new_session(model))
output = remove(input, **kwargs)
o.write(output)
return os.path.join(output_path)

interface = gr.Interface(
inference,
[
gr.components.Image(type="filepath", label="Input"),
gr.components.Dropdown(
[
"u2net",
"u2netp",
"u2net_human_seg",
"u2net_cloth_seg",
"silueta",
"isnet-general-use",
"isnet-anime",
],
value="u2net",
label="Models",
),
gr.components.Dropdown(sessions_names, value="u2net", label="Models"),
gr.components.Textbox(label="Arguments"),
],
gr.components.Image(type="filepath", label="Output"),
)
Expand Down
2 changes: 1 addition & 1 deletion rembg/sessions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(
self.providers.extend(_providers)

self.inner_session = ort.InferenceSession(
str(self.__class__.download_models()),
str(self.__class__.download_models(*args, **kwargs)),
providers=self.providers,
sess_options=sess_opts,
)
Expand Down
4 changes: 2 additions & 2 deletions rembg/sessions/dis_anime.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:

@classmethod
def download_models(cls, *args, **kwargs):
fname = f"{cls.name()}.onnx"
fname = f"{cls.name(*args, **kwargs)}.onnx"
pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-anime.onnx",
None
Expand All @@ -42,7 +42,7 @@ def download_models(cls, *args, **kwargs):
progressbar=True,
)

return os.path.join(cls.u2net_home(), fname)
return os.path.join(cls.u2net_home(*args, **kwargs), fname)

@classmethod
def name(cls, *args, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions rembg/sessions/dis_general_use.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:

@classmethod
def download_models(cls, *args, **kwargs):
fname = f"{cls.name()}.onnx"
fname = f"{cls.name(*args, **kwargs)}.onnx"
pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx",
None
Expand All @@ -42,7 +42,7 @@ def download_models(cls, *args, **kwargs):
progressbar=True,
)

return os.path.join(cls.u2net_home(), fname)
return os.path.join(cls.u2net_home(*args, **kwargs), fname)

@classmethod
def name(cls, *args, **kwargs):
Expand Down
8 changes: 4 additions & 4 deletions rembg/sessions/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ def predict(

@classmethod
def download_models(cls, *args, **kwargs):
fname_encoder = f"{cls.name()}_encoder.onnx"
fname_decoder = f"{cls.name()}_decoder.onnx"
fname_encoder = f"{cls.name(*args, **kwargs)}_encoder.onnx"
fname_decoder = f"{cls.name(*args, **kwargs)}_decoder.onnx"

pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
Expand All @@ -160,8 +160,8 @@ def download_models(cls, *args, **kwargs):
)

return (
os.path.join(cls.u2net_home(), fname_encoder),
os.path.join(cls.u2net_home(), fname_decoder),
os.path.join(cls.u2net_home(*args, **kwargs), fname_encoder),
os.path.join(cls.u2net_home(*args, **kwargs), fname_decoder),
)

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion rembg/sessions/silueta.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def download_models(cls, *args, **kwargs):
progressbar=True,
)

return os.path.join(cls.u2net_home(), fname)
return os.path.join(cls.u2net_home(*args, **kwargs), fname)

@classmethod
def name(cls, *args, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions rembg/sessions/u2net.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:

@classmethod
def download_models(cls, *args, **kwargs):
fname = f"{cls.name()}.onnx"
fname = f"{cls.name(*args, **kwargs)}.onnx"
pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx",
None
Expand All @@ -44,7 +44,7 @@ def download_models(cls, *args, **kwargs):
progressbar=True,
)

return os.path.join(cls.u2net_home(), fname)
return os.path.join(cls.u2net_home(*args, **kwargs), fname)

@classmethod
def name(cls, *args, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions rembg/sessions/u2net_cloth_seg.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:

@classmethod
def download_models(cls, *args, **kwargs):
fname = f"{cls.name()}.onnx"
fname = f"{cls.name(*args, **kwargs)}.onnx"
pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx",
None
Expand All @@ -105,7 +105,7 @@ def download_models(cls, *args, **kwargs):
progressbar=True,
)

return os.path.join(cls.u2net_home(), fname)
return os.path.join(cls.u2net_home(*args, **kwargs), fname)

@classmethod
def name(cls, *args, **kwargs):
Expand Down
45 changes: 45 additions & 0 deletions rembg/sessions/u2net_custom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import os
from typing import List

import numpy as np
import pooch
from PIL import Image
from PIL.Image import Image as PILImage

from .base import BaseSession


class U2netCustomSession(BaseSession):
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
ort_outs = self.inner_session.run(
None,
self.normalize(
img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
),
)

pred = ort_outs[0][:, 0, :, :]

ma = np.max(pred)
mi = np.min(pred)

pred = (pred - mi) / (ma - mi)
pred = np.squeeze(pred)

mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
mask = mask.resize(img.size, Image.LANCZOS)

return [mask]

@classmethod
def download_models(cls, *args, **kwargs):
model_path = kwargs.get("model_path")

if model_path is None:
raise ValueError("model_path is required")

return os.path.abspath(os.path.expanduser(model_path))

@classmethod
def name(cls, *args, **kwargs):
return "u2net_custom"
4 changes: 2 additions & 2 deletions rembg/sessions/u2net_human_seg.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:

@classmethod
def download_models(cls, *args, **kwargs):
fname = f"{cls.name()}.onnx"
fname = f"{cls.name(*args, **kwargs)}.onnx"
pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx",
None
Expand All @@ -44,7 +44,7 @@ def download_models(cls, *args, **kwargs):
progressbar=True,
)

return os.path.join(cls.u2net_home(), fname)
return os.path.join(cls.u2net_home(*args, **kwargs), fname)

@classmethod
def name(cls, *args, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions rembg/sessions/u2netp.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:

@classmethod
def download_models(cls, *args, **kwargs):
fname = f"{cls.name()}.onnx"
fname = f"{cls.name(*args, **kwargs)}.onnx"
pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx",
None
Expand All @@ -44,7 +44,7 @@ def download_models(cls, *args, **kwargs):
progressbar=True,
)

return os.path.join(cls.u2net_home(), fname)
return os.path.join(cls.u2net_home(*args, **kwargs), fname)

@classmethod
def name(cls, *args, **kwargs):
Expand Down

0 comments on commit c0b08f8

Please sign in to comment.