From 4a4c86bd6dd04949d1890e67192b7aafef5e94ac Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?=
Date: Thu, 2 Nov 2023 00:13:12 +0800
Subject: [PATCH] Add anime_face_segment

Update __init__.py
update
add aliases
Update global_state.py
fix typo
fix typo2
RGB2BGR
fix for device
fix for special resolution
---
 annotator/anime_face_segment/LICENSE     |  21 +++
 annotator/anime_face_segment/__init__.py | 174 +++++++++++++++++++++++
 scripts/global_state.py                  |   5 +-
 scripts/processor.py                     |  28 ++++
 4 files changed, 227 insertions(+), 1 deletion(-)
 create mode 100644 annotator/anime_face_segment/LICENSE
 create mode 100644 annotator/anime_face_segment/__init__.py

diff --git a/annotator/anime_face_segment/LICENSE b/annotator/anime_face_segment/LICENSE
new file mode 100644
index 000000000..9bad05450
--- /dev/null
+++ b/annotator/anime_face_segment/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 Miaomiao Li
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/annotator/anime_face_segment/__init__.py b/annotator/anime_face_segment/__init__.py
new file mode 100644
index 000000000..44f295ab0
--- /dev/null
+++ b/annotator/anime_face_segment/__init__.py
@@ -0,0 +1,174 @@
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from PIL import Image
+import fnmatch
+import cv2
+
+import sys
+
+import numpy as np
+from modules import devices
+from einops import rearrange
+from annotator.annotator_path import models_path
+
+import torchvision
+from torchvision.models import MobileNet_V2_Weights
+from torchvision import transforms
+
+COLOR_BACKGROUND = (255,255,0)
+COLOR_HAIR = (0,0,255)
+COLOR_EYE = (255,0,0)
+COLOR_MOUTH = (255,255,255)
+COLOR_FACE = (0,255,0)
+COLOR_SKIN = (0,255,255)
+COLOR_CLOTHES = (255,0,255)
+PALETTE = [COLOR_BACKGROUND,COLOR_HAIR,COLOR_EYE,COLOR_MOUTH,COLOR_FACE,COLOR_SKIN,COLOR_CLOTHES]
+
+class UNet(nn.Module):
+    def __init__(self):
+        super(UNet, self).__init__()
+        self.NUM_SEG_CLASSES = 7 # Background, hair, eye, mouth, face, skin, clothes
+
+        mobilenet_v2 = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)
+        mob_blocks = mobilenet_v2.features
+
+        # Encoder
+        self.en_block0 = nn.Sequential( # in_ch=3 out_ch=16
+            mob_blocks[0],
+            mob_blocks[1]
+        )
+        self.en_block1 = nn.Sequential( # in_ch=16 out_ch=24
+            mob_blocks[2],
+            mob_blocks[3],
+        )
+        self.en_block2 = nn.Sequential( # in_ch=24 out_ch=32
+            mob_blocks[4],
+            mob_blocks[5],
+            mob_blocks[6],
+        )
+        self.en_block3 = nn.Sequential( # in_ch=32 out_ch=96
+            mob_blocks[7],
+            mob_blocks[8],
+            mob_blocks[9],
+            mob_blocks[10],
+            mob_blocks[11],
+            mob_blocks[12],
+            mob_blocks[13],
+        )
+        self.en_block4 = nn.Sequential( # in_ch=96 out_ch=160
+            mob_blocks[14],
+            mob_blocks[15],
+            mob_blocks[16],
+        )
+
+        # Decoder
+        self.de_block4 = nn.Sequential( # in_ch=160 out_ch=96
+            nn.UpsamplingNearest2d(scale_factor=2),
+            nn.Conv2d(160, 96, kernel_size=3, padding=1),
+            nn.InstanceNorm2d(96),
+            nn.LeakyReLU(0.1),
+            nn.Dropout(p=0.2)
+        )
+        self.de_block3 = nn.Sequential( # in_ch=96x2 out_ch=32
+            nn.UpsamplingNearest2d(scale_factor=2),
+            nn.Conv2d(96*2, 32, kernel_size=3, padding=1),
+            nn.InstanceNorm2d(32),
+            nn.LeakyReLU(0.1),
+            nn.Dropout(p=0.2)
+        )
+        self.de_block2 = nn.Sequential( # in_ch=32x2 out_ch=24
+            nn.UpsamplingNearest2d(scale_factor=2),
+            nn.Conv2d(32*2, 24, kernel_size=3, padding=1),
+            nn.InstanceNorm2d(24),
+            nn.LeakyReLU(0.1),
+            nn.Dropout(p=0.2)
+        )
+        self.de_block1 = nn.Sequential( # in_ch=24x2 out_ch=16
+            nn.UpsamplingNearest2d(scale_factor=2),
+            nn.Conv2d(24*2, 16, kernel_size=3, padding=1),
+            nn.InstanceNorm2d(16),
+            nn.LeakyReLU(0.1),
+            nn.Dropout(p=0.2)
+        )
+
+        self.de_block0 = nn.Sequential( # in_ch=16x2 out_ch=7
+            nn.UpsamplingNearest2d(scale_factor=2),
+            nn.Conv2d(16*2, self.NUM_SEG_CLASSES, kernel_size=3, padding=1),
+            nn.Softmax2d()
+        )
+
+    def forward(self, x):
+        e0 = self.en_block0(x)
+        e1 = self.en_block1(e0)
+        e2 = self.en_block2(e1)
+        e3 = self.en_block3(e2)
+        e4 = self.en_block4(e3)
+
+        d4 = self.de_block4(e4)
+        d4 = F.interpolate(d4, size=e3.size()[2:], mode='bilinear', align_corners=True)
+        c4 = torch.cat((d4,e3),1)
+
+        d3 = self.de_block3(c4)
+        d3 = F.interpolate(d3, size=e2.size()[2:], mode='bilinear', align_corners=True)
+        c3 = torch.cat((d3,e2),1)
+
+        d2 = self.de_block2(c3)
+        d2 = F.interpolate(d2, size=e1.size()[2:], mode='bilinear', align_corners=True)
+        c2 = torch.cat((d2,e1),1)
+
+        d1 = self.de_block1(c2)
+        d1 = F.interpolate(d1, size=e0.size()[2:], mode='bilinear', align_corners=True)
+        c1 = torch.cat((d1,e0),1)
+        y = self.de_block0(c1)
+
+        return y
+
+
+class AnimeFaceSegment:
+
+    model_dir = os.path.join(models_path, "anime_face_segment")
+
+    def __init__(self):
+        self.model = None
+        self.device = devices.get_device_for("controlnet")
+
+    def load_model(self):
+        remote_model_path = "https://huggingface.co/bdsqlsz/qinglong_controlnet-lllite/resolve/main/Annotators/UNet.pth"
+        modelpath = os.path.join(self.model_dir, "UNet.pth")
+        if not os.path.exists(modelpath):
+            from basicsr.utils.download_util import load_file_from_url
+            load_file_from_url(remote_model_path, model_dir=self.model_dir)
+        net = UNet()
+        ckpt = torch.load(modelpath, map_location=self.device)
+        for key in list(ckpt.keys()):
+            if 'module.' in key:
+                ckpt[key.replace('module.', '')] = ckpt[key]
+                del ckpt[key]
+        net.load_state_dict(ckpt)
+        net.eval()
+        self.model = net.to(self.device)
+
+    def unload_model(self):
+        if self.model is not None:
+            self.model.cpu()
+
+    def __call__(self, input_image):
+
+        if self.model is None:
+            self.load_model()
+        self.model.to(self.device)
+        transform = transforms.Compose([
+            transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC),
+            transforms.ToTensor(),])
+        img = Image.fromarray(input_image)
+        with torch.no_grad():
+            img = transform(img).unsqueeze(dim=0).to(self.device)
+            seg = self.model(img).squeeze(dim=0)
+            seg = seg.cpu().detach().numpy()
+            img = rearrange(seg, 'c h w -> h w c')
+            img = [[PALETTE[np.argmax(val)] for val in buf] for buf in img]
+            return np.array(img).astype(np.uint8)
+
+
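Reviewer note: the final palette mapping in AnimeFaceSegment.__call__ runs a nested Python list comprehension that indexes PALETTE once per pixel, which is slow at 512px. A minimal vectorized sketch follows; colorize_segmentation is a hypothetical helper, not part of this patch, and assumes seg is the (7, H, W) class-probability array produced by the UNet:

    import numpy as np

    def colorize_segmentation(seg, palette):
        # seg: (7, H, W) float array of class probabilities; palette: 7 color tuples.
        lut = np.asarray(palette, dtype=np.uint8)  # (7, 3) color lookup table
        labels = np.argmax(seg, axis=0)            # (H, W) winning class per pixel
        return lut[labels]                         # (H, W, 3) uint8 color map

The fancy-indexed lookup lut[labels] produces the same output as the comprehension but stays entirely in NumPy.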
diff --git a/scripts/global_state.py b/scripts/global_state.py
index ecd340d05..4b9eb8ede 100644
--- a/scripts/global_state.py
+++ b/scripts/global_state.py
@@ -103,6 +103,7 @@ def unified_preprocessor(preprocessor_name: str, *args, **kwargs):
     "recolor_luminance": recolor_luminance,
     "recolor_intensity": recolor_intensity,
     "blur_gaussian": blur_gaussian,
+    "anime_face_segment": anime_face_segment,
 }
 
 cn_preprocessor_unloadable = {
@@ -133,7 +134,8 @@ def unified_preprocessor(preprocessor_name: str, *args, **kwargs):
     "lineart_coarse": unload_lineart_coarse,
     "lineart_anime": unload_lineart_anime,
     "lineart_anime_denoise": unload_lineart_anime_denoise,
-    "inpaint_only+lama": unload_lama_inpaint
+    "inpaint_only+lama": unload_lama_inpaint,
+    "anime_face_segment": unload_anime_face_segment,
 }
 
 preprocessor_aliases = {
@@ -154,6 +156,7 @@ def unified_preprocessor(preprocessor_name: str, *args, **kwargs):
     "oneformer_ade20k": "seg_ofade20k",
     "pidinet_scribble": "scribble_pidinet",
     "inpaint": "inpaint_global_harmonious",
+    "anime_face_segment": "seg_anime_face",
 }
 
 ui_preprocessor_keys = ['none', preprocessor_aliases['invert']]
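For context, the three dicts touched above are what keep the module key and its UI name in sync: the callable is registered under "anime_face_segment", the unloader under the same key, and the UI shows the alias "seg_anime_face". A simplified sketch of the lookup path (the dict names mirror global_state.py, but resolve() itself is illustrative and not code from this patch):

    cn_preprocessor_modules = {"anime_face_segment": anime_face_segment}
    preprocessor_aliases = {"anime_face_segment": "seg_anime_face"}
    reverse_preprocessor_aliases = {v: k for k, v in preprocessor_aliases.items()}

    def resolve(ui_name):
        # The UI lists "seg_anime_face"; map it back to the module key,
        # then fetch the callable registered in cn_preprocessor_modules.
        key = reverse_preprocessor_aliases.get(ui_name, ui_name)
        return cn_preprocessor_modules[key]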
diff --git a/scripts/processor.py b/scripts/processor.py
index 7582e83de..c7548de81 100644
--- a/scripts/processor.py
+++ b/scripts/processor.py
@@ -620,6 +620,26 @@ def blur_gaussian(img, res=512, thr_a=1.0, **kwargs):
     return result, True
 
 
+model_anime_face_segment = None
+
+
+def anime_face_segment(img, res=512, **kwargs):
+    img, remove_pad = resize_image_with_pad(img, res)
+    global model_anime_face_segment
+    if model_anime_face_segment is None:
+        from annotator.anime_face_segment import AnimeFaceSegment
+        model_anime_face_segment = AnimeFaceSegment()
+
+    result = model_anime_face_segment(img)
+    return remove_pad(result), True
+
+
+def unload_anime_face_segment():
+    global model_anime_face_segment
+    if model_anime_face_segment is not None:
+        model_anime_face_segment.unload_model()
+
+
 model_free_preprocessors = [
     "reference_only",
     "reference_adain",
@@ -986,6 +1006,14 @@ def blur_gaussian(img, res=512, thr_a=1.0, **kwargs):
             "step": 0.001
         }
     ],
+    "anime_face_segment": [
+        {
+            "name": flag_preprocessor_resolution,
+            "value": 512,
+            "min": 64,
+            "max": 2048
+        }
+    ],
 }
 
 preprocessor_filters = {
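A quick smoke test for the new preprocessor, assuming a webui checkout where scripts.processor imports cleanly and the UNet checkpoint can be fetched on first use (the input file name here is illustrative):

    import cv2
    from scripts.processor import anime_face_segment, unload_anime_face_segment

    # cv2.imread returns BGR; the annotator builds a PIL image, so convert to RGB.
    img = cv2.cvtColor(cv2.imread("face.png"), cv2.COLOR_BGR2RGB)
    result, is_image = anime_face_segment(img, res=512)  # (H, W, 3) uint8 color map
    unload_anime_face_segment()  # moves the UNet back to CPU between generations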