From 4aa4f403eeacfb4d06600249122a3af6254a3d0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Thu, 2 Nov 2023 00:13:12 +0800 Subject: [PATCH 01/17] Add anime_face_segment --- annotator/anime_face_segment/LICENSE | 21 +++ annotator/anime_face_segment/__init__.py | 158 +++++++++++++++++++++++ scripts/global_state.py | 2 + scripts/processor.py | 20 +++ 4 files changed, 201 insertions(+) create mode 100644 annotator/anime_face_segment/LICENSE create mode 100644 annotator/anime_face_segment/__init__.py diff --git a/annotator/anime_face_segment/LICENSE b/annotator/anime_face_segment/LICENSE new file mode 100644 index 000000000..9bad05450 --- /dev/null +++ b/annotator/anime_face_segment/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Miaomiao Li + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/annotator/anime_face_segment/__init__.py b/annotator/anime_face_segment/__init__.py new file mode 100644 index 000000000..0aae8e317 --- /dev/null +++ b/annotator/anime_face_segment/__init__.py @@ -0,0 +1,158 @@ +import os +import torch +import torch.nn as nn +from PIL import Image +#import fnmatch +import cv2 + +#import sys + +import numpy as np +from modules import devices +from annotator.annotator_path import models_path + +import torchvision +from torchvision.models import MobileNet_V2_Weights + + +class UNet(nn.Module): + def __init__(self): + super(UNet, self).__init__() + self.NUM_SEG_CLASSES = 7 # Background, hair, face, eye, mouth, skin, clothes + + mobilenet_v2 = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1) + mob_blocks = mobilenet_v2.features + + # Encoder + self.en_block0 = nn.Sequential( # in_ch=3 out_ch=16 + mob_blocks[0], + mob_blocks[1] + ) + self.en_block1 = nn.Sequential( # in_ch=16 out_ch=24 + mob_blocks[2], + mob_blocks[3], + ) + self.en_block2 = nn.Sequential( # in_ch=24 out_ch=32 + mob_blocks[4], + mob_blocks[5], + mob_blocks[6], + ) + self.en_block3 = nn.Sequential( # in_ch=32 out_ch=96 + mob_blocks[7], + mob_blocks[8], + mob_blocks[9], + mob_blocks[10], + mob_blocks[11], + mob_blocks[12], + mob_blocks[13], + ) + self.en_block4 = nn.Sequential( # in_ch=96 out_ch=160 + mob_blocks[14], + mob_blocks[15], + mob_blocks[16], + ) + + # Decoder + self.de_block4 = nn.Sequential( # in_ch=160 out_ch=96 + nn.UpsamplingNearest2d(scale_factor=2), + nn.Conv2d(160, 96, kernel_size=3, padding=1), + nn.InstanceNorm2d(96), + nn.LeakyReLU(0.1), + nn.Dropout(p=0.2) + ) + self.de_block3 = nn.Sequential( # in_ch=96x2 out_ch=32 + nn.UpsamplingNearest2d(scale_factor=2), + nn.Conv2d(96*2, 32, kernel_size=3, padding=1), + nn.InstanceNorm2d(32), + nn.LeakyReLU(0.1), + nn.Dropout(p=0.2) + ) + self.de_block2 = nn.Sequential( # in_ch=32x2 out_ch=24 + nn.UpsamplingNearest2d(scale_factor=2), + nn.Conv2d(32*2, 24, kernel_size=3, padding=1), + nn.InstanceNorm2d(24), + nn.LeakyReLU(0.1), + nn.Dropout(p=0.2) + ) + self.de_block1 = nn.Sequential( # in_ch=24x2 out_ch=16 + nn.UpsamplingNearest2d(scale_factor=2), + nn.Conv2d(24*2, 16, kernel_size=3, padding=1), + nn.InstanceNorm2d(16), + nn.LeakyReLU(0.1), + nn.Dropout(p=0.2) + ) + + self.de_block0 = nn.Sequential( # in_ch=16x2 out_ch=7 + nn.UpsamplingNearest2d(scale_factor=2), + nn.Conv2d(16*2, self.NUM_SEG_CLASSES, kernel_size=3, padding=1), + nn.Softmax2d() + ) + + def forward(self, x): + e0 = self.en_block0(x) + e1 = self.en_block1(e0) + e2 = self.en_block2(e1) + e3 = self.en_block3(e2) + e4 = self.en_block4(e3) + + d4 = self.de_block4(e4) + c4 = torch.cat((d4,e3),1) + d3 = self.de_block3(c4) + c3 = torch.cat((d3,e2),1) + d2 = self.de_block2(c3) + c2 =torch.cat((d2,e1),1) + d1 = self.de_block1(c2) + c1 = torch.cat((d1,e0),1) + y = self.de_block0(c1) + + return y + + +class AnimeFaceSegment: + COLOR_BACKGROUND = (0,255,255) + COLOR_HAIR = (255,0,0) + COLOR_EYE = (0,0,255) + COLOR_MOUTH = (255,255,255) + COLOR_FACE = (0,255,0) + COLOR_SKIN = (255,255,0) + COLOR_CLOTHES = (255,0,255) + PALETTE = [COLOR_BACKGROUND,COLOR_HAIR,COLOR_EYE,COLOR_MOUTH,COLOR_FACE,COLOR_SKIN,COLOR_CLOTHES] + + model_dir = os.path.join(models_path, "anime_face_segment") + + def __init__(self): + self.model = None + self.device = devices.get_device_for("controlnet") + + def load_model(self): + remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/erika.pth" + modelpath = os.path.join(self.model_dir, "Unet.pth") + if 
not os.path.exists(modelpath): + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(remote_model_path, model_dir=self.model_dir) + net = UNet() + ckpt = torch.load(modelpath) + for key in list(ckpt.keys()): + if 'module.' in key: + ckpt[key.replace('module.', '')] = ckpt[key] + del ckpt[key] + net.load_state_dict(ckpt) + net.eval() + self.model = net.to(self.device) + + def unload_model(self): + if self.model is not None: + self.model.cpu() + + def __call__(self, input_image): + if self.model is None: + self.load_model() + self.model.to(self.device) + with torch.no_grad(): + seg = self.model(input_image) + seg = seg.cpu().detach().numpy() + img = np.moveaxis(seg,0,2) + img = [[PALETTE[np.argmax(val)] for val in buf]for buf in seg] + return np.array(img).astype(np.uint8) + + diff --git a/scripts/global_state.py b/scripts/global_state.py index 4821a9440..b4dc539dc 100644 --- a/scripts/global_state.py +++ b/scripts/global_state.py @@ -101,6 +101,7 @@ def unified_preprocessor(preprocessor_name: str, *args, **kwargs): "recolor_luminance": recolor_luminance, "recolor_intensity": recolor_intensity, "blur_gaussian": blur_gaussian, + "anime_face_segment": anime_face_segment } cn_preprocessor_unloadable = { @@ -132,6 +133,7 @@ def unified_preprocessor(preprocessor_name: str, *args, **kwargs): "lineart_anime": unload_lineart_anime, "lineart_anime_denoise": unload_lineart_anime_denoise, "inpaint_only+lama": unload_lama_inpaint + "anime_face_segment": unload_anime_face_segment } preprocessor_aliases = { diff --git a/scripts/processor.py b/scripts/processor.py index 7582e83de..295c03050 100644 --- a/scripts/processor.py +++ b/scripts/processor.py @@ -620,6 +620,26 @@ def blur_gaussian(img, res=512, thr_a=1.0, **kwargs): return result, True +model_anime_face_segement = None + + +def anime_face_segement(img, res=512, **kwargs): + img, remove_pad = resize_image_with_pad(img, res) + global model_anime_face_segement + if model_anime_face_segement is None: + from annotator.model_anime_face_segement import AnimeFaceSegment + model_anime_face_segement = AnimeFaceSegment() + + result = model_manga_line(img) + return remove_pad(result), True + + +def unload_anime_face_segement(): + global model_anime_face_segement + if model_anime_face_segement is not None: + model_anime_face_segement.unload_model() + + model_free_preprocessors = [ "reference_only", "reference_adain", From 42b3bc32a99836324edaa3cb754a0ad4f71fb62d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Thu, 2 Nov 2023 00:20:39 +0800 Subject: [PATCH 02/17] Update __init__.py --- annotator/anime_face_segment/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/annotator/anime_face_segment/__init__.py b/annotator/anime_face_segment/__init__.py index 0aae8e317..215567207 100644 --- a/annotator/anime_face_segment/__init__.py +++ b/annotator/anime_face_segment/__init__.py @@ -125,7 +125,7 @@ def __init__(self): self.device = devices.get_device_for("controlnet") def load_model(self): - remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/erika.pth" + remote_model_path = "https://huggingface.co/bdsqlsz/qinglong_controlnet-lllite/resolve/main/Annotators/UNet.pth" modelpath = os.path.join(self.model_dir, "Unet.pth") if not os.path.exists(modelpath): from basicsr.utils.download_util import load_file_from_url From 4a452dcaa1e00b4f7dd84d01ffa07c4d8649ca5b Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Thu, 2 Nov 2023 00:41:24 +0800 Subject: [PATCH 03/17] update add aliases --- scripts/global_state.py | 1 + scripts/processor.py | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/scripts/global_state.py b/scripts/global_state.py index b4dc539dc..ff1c72317 100644 --- a/scripts/global_state.py +++ b/scripts/global_state.py @@ -154,6 +154,7 @@ def unified_preprocessor(preprocessor_name: str, *args, **kwargs): "oneformer_ade20k": "seg_ofade20k", "pidinet_scribble": "scribble_pidinet", "inpaint": "inpaint_global_harmonious", + "anime_face_segment": "seg_anime_face", } ui_preprocessor_keys = ['none', preprocessor_aliases['invert']] diff --git a/scripts/processor.py b/scripts/processor.py index 295c03050..7fdc64886 100644 --- a/scripts/processor.py +++ b/scripts/processor.py @@ -1006,6 +1006,14 @@ def unload_anime_face_segement(): "step": 0.001 } ], + "anime_face_segement": [ + { + "name": flag_preprocessor_resolution, + "value": 512, + "min": 64, + "max": 2048 + } + ], } preprocessor_filters = { From 4b1a65b0b297d2c52ef762e041e079b98bb7c913 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Thu, 2 Nov 2023 00:44:11 +0800 Subject: [PATCH 04/17] Update global_state.py --- scripts/global_state.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/global_state.py b/scripts/global_state.py index ff1c72317..b5d0cbc8e 100644 --- a/scripts/global_state.py +++ b/scripts/global_state.py @@ -101,7 +101,7 @@ def unified_preprocessor(preprocessor_name: str, *args, **kwargs): "recolor_luminance": recolor_luminance, "recolor_intensity": recolor_intensity, "blur_gaussian": blur_gaussian, - "anime_face_segment": anime_face_segment + "anime_face_segment": anime_face_segment, } cn_preprocessor_unloadable = { @@ -132,8 +132,8 @@ def unified_preprocessor(preprocessor_name: str, *args, **kwargs): "lineart_coarse": unload_lineart_coarse, "lineart_anime": unload_lineart_anime, "lineart_anime_denoise": unload_lineart_anime_denoise, - "inpaint_only+lama": unload_lama_inpaint - "anime_face_segment": unload_anime_face_segment + "inpaint_only+lama": unload_lama_inpaint, + "anime_face_segment": unload_anime_face_segment, } preprocessor_aliases = { From 4ed905c7c7d588df07d6acb2860bc82220dd91c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Thu, 2 Nov 2023 00:49:57 +0800 Subject: [PATCH 05/17] fix typo --- scripts/processor.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/scripts/processor.py b/scripts/processor.py index 7fdc64886..86a05fdbb 100644 --- a/scripts/processor.py +++ b/scripts/processor.py @@ -620,24 +620,24 @@ def blur_gaussian(img, res=512, thr_a=1.0, **kwargs): return result, True -model_anime_face_segement = None +model_anime_face_segment = None -def anime_face_segement(img, res=512, **kwargs): +def anime_face_segment(img, res=512, **kwargs): img, remove_pad = resize_image_with_pad(img, res) - global model_anime_face_segement - if model_anime_face_segement is None: - from annotator.model_anime_face_segement import AnimeFaceSegment - model_anime_face_segement = AnimeFaceSegment() + global model_anime_face_segment + if model_anime_face_segment is None: + from annotator.model_anime_face_segment import AnimeFaceSegment + model_anime_face_segment = AnimeFaceSegment() result = model_manga_line(img) return remove_pad(result), True -def unload_anime_face_segement(): - global 
model_anime_face_segement - if model_anime_face_segement is not None: - model_anime_face_segement.unload_model() +def unload_anime_face_segment(): + global model_anime_face_segment + if model_anime_face_segment is not None: + model_anime_face_segment.unload_model() model_free_preprocessors = [ @@ -1006,7 +1006,7 @@ def unload_anime_face_segement(): "step": 0.001 } ], - "anime_face_segement": [ + "anime_face_segment": [ { "name": flag_preprocessor_resolution, "value": 512, From 1fe6185a59c023acc4a55cfdfa335f0187d78e9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Thu, 2 Nov 2023 00:54:15 +0800 Subject: [PATCH 06/17] fix typo2 --- scripts/processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/processor.py b/scripts/processor.py index 86a05fdbb..c7548de81 100644 --- a/scripts/processor.py +++ b/scripts/processor.py @@ -627,10 +627,10 @@ def anime_face_segment(img, res=512, **kwargs): img, remove_pad = resize_image_with_pad(img, res) global model_anime_face_segment if model_anime_face_segment is None: - from annotator.model_anime_face_segment import AnimeFaceSegment + from annotator.anime_face_segment import AnimeFaceSegment model_anime_face_segment = AnimeFaceSegment() - result = model_manga_line(img) + result = model_anime_face_segment(img) return remove_pad(result), True From 1aa1729db60772ac2bcfc112c1a18d41705d8809 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Thu, 2 Nov 2023 01:06:04 +0800 Subject: [PATCH 07/17] Update __init__.py --- annotator/anime_face_segment/__init__.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/annotator/anime_face_segment/__init__.py b/annotator/anime_face_segment/__init__.py index 215567207..ce729c601 100644 --- a/annotator/anime_face_segment/__init__.py +++ b/annotator/anime_face_segment/__init__.py @@ -2,10 +2,10 @@ import torch import torch.nn as nn from PIL import Image -#import fnmatch +import fnmatch import cv2 -#import sys +import sys import numpy as np from modules import devices @@ -148,8 +148,12 @@ def __call__(self, input_image): if self.model is None: self.load_model() self.model.to(self.device) + transform = transforms.Compose([ + transforms.Resize(512), + transforms.ToTensor(),]) with torch.no_grad(): - seg = self.model(input_image) + src = transform(input_image).unsqueeze(dim=0).cuda() + seg = self.model(src).squeeze(dim=0) seg = seg.cpu().detach().numpy() img = np.moveaxis(seg,0,2) img = [[PALETTE[np.argmax(val)] for val in buf]for buf in seg] From 595cf2c1810b5850ec7b39d67c920382cc371bb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Thu, 2 Nov 2023 01:10:54 +0800 Subject: [PATCH 08/17] Update __init__.py --- annotator/anime_face_segment/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/annotator/anime_face_segment/__init__.py b/annotator/anime_face_segment/__init__.py index ce729c601..c5643ca69 100644 --- a/annotator/anime_face_segment/__init__.py +++ b/annotator/anime_face_segment/__init__.py @@ -13,6 +13,7 @@ import torchvision from torchvision.models import MobileNet_V2_Weights +from torchvision import transforms class UNet(nn.Module): From 37f8f2a350481d2a9dd5006087af8f6d558ce602 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Thu, 2 Nov 2023 01:20:01 +0800 Subject: [PATCH 09/17] Update __init__.py --- annotator/anime_face_segment/__init__.py | 3 ++- 1 file changed, 2 
insertions(+), 1 deletion(-) diff --git a/annotator/anime_face_segment/__init__.py b/annotator/anime_face_segment/__init__.py index c5643ca69..339601a12 100644 --- a/annotator/anime_face_segment/__init__.py +++ b/annotator/anime_face_segment/__init__.py @@ -153,7 +153,8 @@ def __call__(self, input_image): transforms.Resize(512), transforms.ToTensor(),]) with torch.no_grad(): - src = transform(input_image).unsqueeze(dim=0).cuda() + openimg = Image.open(input_image) + src = transform(openimg).unsqueeze(dim=0).cuda() seg = self.model(src).squeeze(dim=0) seg = seg.cpu().detach().numpy() img = np.moveaxis(seg,0,2) From d043cda8ce49236d354660c52a1d06dac4cb213c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Thu, 2 Nov 2023 01:35:10 +0800 Subject: [PATCH 10/17] Update __init__.py --- annotator/anime_face_segment/__init__.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/annotator/anime_face_segment/__init__.py b/annotator/anime_face_segment/__init__.py index 339601a12..4ed43547e 100644 --- a/annotator/anime_face_segment/__init__.py +++ b/annotator/anime_face_segment/__init__.py @@ -9,6 +9,7 @@ import numpy as np from modules import devices +from einops import rearrange from annotator.annotator_path import models_path import torchvision @@ -149,13 +150,11 @@ def __call__(self, input_image): if self.model is None: self.load_model() self.model.to(self.device) - transform = transforms.Compose([ - transforms.Resize(512), - transforms.ToTensor(),]) + img = np.ascontiguousarray(input_image.copy()).copy() with torch.no_grad(): - openimg = Image.open(input_image) - src = transform(openimg).unsqueeze(dim=0).cuda() - seg = self.model(src).squeeze(dim=0) + image_feed = torch.from_numpy(img).float().to(self.device) + image_feed = rearrange(image_feed, 'h w c -> 1 c h w') + seg = self.model(image_feed).squeeze(dim=0) seg = seg.cpu().detach().numpy() img = np.moveaxis(seg,0,2) img = [[PALETTE[np.argmax(val)] for val in buf]for buf in seg] From 67dce611ff4731d4e88d9edbc24e0f1fac9b99e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Thu, 2 Nov 2023 01:38:34 +0800 Subject: [PATCH 11/17] Update __init__.py --- annotator/anime_face_segment/__init__.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/annotator/anime_face_segment/__init__.py b/annotator/anime_face_segment/__init__.py index 4ed43547e..af30266c4 100644 --- a/annotator/anime_face_segment/__init__.py +++ b/annotator/anime_face_segment/__init__.py @@ -16,6 +16,14 @@ from torchvision.models import MobileNet_V2_Weights from torchvision import transforms +COLOR_BACKGROUND = (0,255,255) +COLOR_HAIR = (255,0,0) +COLOR_EYE = (0,0,255) +COLOR_MOUTH = (255,255,255) +COLOR_FACE = (0,255,0) +COLOR_SKIN = (255,255,0) +COLOR_CLOTHES = (255,0,255) +PALETTE = [COLOR_BACKGROUND,COLOR_HAIR,COLOR_EYE,COLOR_MOUTH,COLOR_FACE,COLOR_SKIN,COLOR_CLOTHES] class UNet(nn.Module): def __init__(self): @@ -111,14 +119,6 @@ def forward(self, x): class AnimeFaceSegment: - COLOR_BACKGROUND = (0,255,255) - COLOR_HAIR = (255,0,0) - COLOR_EYE = (0,0,255) - COLOR_MOUTH = (255,255,255) - COLOR_FACE = (0,255,0) - COLOR_SKIN = (255,255,0) - COLOR_CLOTHES = (255,0,255) - PALETTE = [COLOR_BACKGROUND,COLOR_HAIR,COLOR_EYE,COLOR_MOUTH,COLOR_FACE,COLOR_SKIN,COLOR_CLOTHES] model_dir = os.path.join(models_path, "anime_face_segment") @@ -147,6 +147,7 @@ def unload_model(self): self.model.cpu() def __call__(self, input_image): + if self.model is 
None: self.load_model() self.model.to(self.device) From 7267f6e182328b6446c2b6a356fef03b5f44332d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Thu, 2 Nov 2023 01:42:56 +0800 Subject: [PATCH 12/17] Update __init__.py --- annotator/anime_face_segment/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/annotator/anime_face_segment/__init__.py b/annotator/anime_face_segment/__init__.py index af30266c4..384499949 100644 --- a/annotator/anime_face_segment/__init__.py +++ b/annotator/anime_face_segment/__init__.py @@ -158,6 +158,7 @@ def __call__(self, input_image): seg = self.model(image_feed).squeeze(dim=0) seg = seg.cpu().detach().numpy() img = np.moveaxis(seg,0,2) + print(img) img = [[PALETTE[np.argmax(val)] for val in buf]for buf in seg] return np.array(img).astype(np.uint8) From d4b0f88ac62cfa4e8cdc57a71b995fca2b116e04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Thu, 2 Nov 2023 02:01:59 +0800 Subject: [PATCH 13/17] Update __init__.py --- annotator/anime_face_segment/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/annotator/anime_face_segment/__init__.py b/annotator/anime_face_segment/__init__.py index 384499949..0e8d9327d 100644 --- a/annotator/anime_face_segment/__init__.py +++ b/annotator/anime_face_segment/__init__.py @@ -151,10 +151,8 @@ def __call__(self, input_image): if self.model is None: self.load_model() self.model.to(self.device) - img = np.ascontiguousarray(input_image.copy()).copy() with torch.no_grad(): - image_feed = torch.from_numpy(img).float().to(self.device) - image_feed = rearrange(image_feed, 'h w c -> 1 c h w') + image_feed = torch.from_numpy(input_image).unsqueeze(dim=0).cuda() seg = self.model(image_feed).squeeze(dim=0) seg = seg.cpu().detach().numpy() img = np.moveaxis(seg,0,2) From a79a5396abf4deb538b8b23199595f486e63c848 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Thu, 2 Nov 2023 02:16:12 +0800 Subject: [PATCH 14/17] Update __init__.py --- annotator/anime_face_segment/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/annotator/anime_face_segment/__init__.py b/annotator/anime_face_segment/__init__.py index 0e8d9327d..d5c72e093 100644 --- a/annotator/anime_face_segment/__init__.py +++ b/annotator/anime_face_segment/__init__.py @@ -152,10 +152,11 @@ def __call__(self, input_image): self.load_model() self.model.to(self.device) with torch.no_grad(): - image_feed = torch.from_numpy(input_image).unsqueeze(dim=0).cuda() + image_feed = torch.from_numpy(input_image).float().to(self.device) + image = rearrange(image, 'h w c -> 1 c h w') seg = self.model(image_feed).squeeze(dim=0) seg = seg.cpu().detach().numpy() - img = np.moveaxis(seg,0,2) + #img = np.moveaxis(seg,0,2) print(img) img = [[PALETTE[np.argmax(val)] for val in buf]for buf in seg] return np.array(img).astype(np.uint8) From 91f67ddcc7bc47537a6285864abfc12590f46c3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Thu, 2 Nov 2023 03:01:59 +0800 Subject: [PATCH 15/17] RGB2BGR --- annotator/anime_face_segment/__init__.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/annotator/anime_face_segment/__init__.py b/annotator/anime_face_segment/__init__.py index d5c72e093..095564189 100644 --- a/annotator/anime_face_segment/__init__.py +++ b/annotator/anime_face_segment/__init__.py @@ -16,12 +16,12 @@ from 
torchvision.models import MobileNet_V2_Weights from torchvision import transforms -COLOR_BACKGROUND = (0,255,255) -COLOR_HAIR = (255,0,0) -COLOR_EYE = (0,0,255) +COLOR_BACKGROUND = (255,255,0) +COLOR_HAIR = (0,0,255) +COLOR_EYE = (255,0,0) COLOR_MOUTH = (255,255,255) COLOR_FACE = (0,255,0) -COLOR_SKIN = (255,255,0) +COLOR_SKIN = (0,255,255) COLOR_CLOTHES = (255,0,255) PALETTE = [COLOR_BACKGROUND,COLOR_HAIR,COLOR_EYE,COLOR_MOUTH,COLOR_FACE,COLOR_SKIN,COLOR_CLOTHES] @@ -151,14 +151,16 @@ def __call__(self, input_image): if self.model is None: self.load_model() self.model.to(self.device) + transform = transforms.Compose([ + transforms.Resize(512), + transforms.ToTensor(),]) + img = Image.fromarray(input_image) with torch.no_grad(): - image_feed = torch.from_numpy(input_image).float().to(self.device) - image = rearrange(image, 'h w c -> 1 c h w') - seg = self.model(image_feed).squeeze(dim=0) + img = transform(img).unsqueeze(dim=0).cuda() + seg = self.model(img).squeeze(dim=0) seg = seg.cpu().detach().numpy() - #img = np.moveaxis(seg,0,2) - print(img) - img = [[PALETTE[np.argmax(val)] for val in buf]for buf in seg] + img = np.moveaxis(seg,0,2) + img = [[PALETTE[np.argmax(val)] for val in buf]for buf in img] return np.array(img).astype(np.uint8) From 5bed6aa2acd4d9595448f3c47ceae44e8dae8248 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Fri, 3 Nov 2023 20:47:15 +0800 Subject: [PATCH 16/17] fix for device --- annotator/anime_face_segment/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/annotator/anime_face_segment/__init__.py b/annotator/anime_face_segment/__init__.py index 095564189..8d91b820d 100644 --- a/annotator/anime_face_segment/__init__.py +++ b/annotator/anime_face_segment/__init__.py @@ -156,7 +156,7 @@ def __call__(self, input_image): transforms.ToTensor(),]) img = Image.fromarray(input_image) with torch.no_grad(): - img = transform(img).unsqueeze(dim=0).cuda() + img = transform(img).unsqueeze(dim=0).to(self.device) seg = self.model(img).squeeze(dim=0) seg = seg.cpu().detach().numpy() img = np.moveaxis(seg,0,2) From f6231582d7473b610df2d9006702285138d8f659 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Fri, 3 Nov 2023 22:22:54 +0800 Subject: [PATCH 17/17] fix for special resolution --- annotator/anime_face_segment/__init__.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/annotator/anime_face_segment/__init__.py b/annotator/anime_face_segment/__init__.py index 8d91b820d..44f295ab0 100644 --- a/annotator/anime_face_segment/__init__.py +++ b/annotator/anime_face_segment/__init__.py @@ -1,6 +1,7 @@ import os import torch import torch.nn as nn +import torch.nn.functional as F from PIL import Image import fnmatch import cv2 @@ -106,12 +107,19 @@ def forward(self, x): e4 = self.en_block4(e3) d4 = self.de_block4(e4) + d4 = F.interpolate(d4, size=e3.size()[2:], mode='bilinear', align_corners=True) c4 = torch.cat((d4,e3),1) + d3 = self.de_block3(c4) + d3 = F.interpolate(d3, size=e2.size()[2:], mode='bilinear', align_corners=True) c3 = torch.cat((d3,e2),1) + d2 = self.de_block2(c3) + d2 = F.interpolate(d2, size=e1.size()[2:], mode='bilinear', align_corners=True) c2 =torch.cat((d2,e1),1) + d1 = self.de_block1(c2) + d1 = F.interpolate(d1, size=e0.size()[2:], mode='bilinear', align_corners=True) c1 = torch.cat((d1,e0),1) y = self.de_block0(c1) @@ -152,14 +160,14 @@ def __call__(self, input_image): self.load_model() self.model.to(self.device) 
         transform = transforms.Compose([
-            transforms.Resize(512),
+            transforms.Resize(512,interpolation=transforms.InterpolationMode.BICUBIC),
             transforms.ToTensor(),])
         img = Image.fromarray(input_image)
         with torch.no_grad():
             img = transform(img).unsqueeze(dim=0).to(self.device)
             seg = self.model(img).squeeze(dim=0)
             seg = seg.cpu().detach().numpy()
-            img = np.moveaxis(seg,0,2)
+            img = rearrange(seg,'c h w -> h w c')
             img = [[PALETTE[np.argmax(val)] for val in buf]for buf in img]
             return np.array(img).astype(np.uint8)
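
After patch 17, the annotator's end-to-end recipe is: resize the input so its short side is 512 (bicubic), run the MobileNetV2-backed UNet, take the per-pixel argmax over the 7 softmax channels, and map each class index to its palette color. Below is a minimal standalone sketch of that recipe, not part of the series itself: unet_def.py is a hypothetical local copy of the UNet class above (the real annotator package imports webui modules at import time, so it cannot be imported outside the webui), UNet.pth is assumed to be a hand-downloaded copy of the checkpoint at the patch-02 URL, and face.png is a placeholder input.

import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from unet_def import UNet  # hypothetical standalone copy of the UNet class above

# Palette as of patch 15 ("RGB2BGR"; channel order matches the webui's convention):
# background, hair, eye, mouth, face, skin, clothes
PALETTE = np.array([(255,255,0), (0,0,255), (255,0,0), (255,255,255),
                    (0,255,0), (0,255,255), (255,0,255)], dtype=np.uint8)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = UNet()
ckpt = torch.load("UNet.pth", map_location="cpu")  # assumed local checkpoint
# Strip a possible DataParallel 'module.' prefix, as load_model() does above.
ckpt = {k.replace("module.", ""): v for k, v in ckpt.items()}
net.load_state_dict(ckpt)
net.eval().to(device)

transform = transforms.Compose([
    transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
])

img = Image.open("face.png").convert("RGB")  # placeholder input
with torch.no_grad():
    src = transform(img).unsqueeze(0).to(device)   # 1 x 3 x H x W, in [0, 1]
    seg = net(src).squeeze(0).cpu().numpy()        # 7 x H x W softmax scores

# Vectorized equivalent of the nested list comprehension in __call__:
idx = np.argmax(seg, axis=0)                       # H x W class indices
Image.fromarray(PALETTE[idx]).save("face_seg.png")

The fancy-indexed PALETTE[idx] lookup yields the same H x W x 3 color map as the nested per-pixel list comprehension in __call__, in one NumPy operation. Note also that patch 17's F.interpolate calls re-align each decoder output to the matching encoder feature map, so resized inputs whose sides are not clean multiples of the network's downsampling factor no longer break the channel-wise torch.cat calls, which appears to be what "fix for special resolution" addresses.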