From 179b557a84fe89f3086453ed6429c694d9da0502 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= <865105819@qq.com> Date: Sun, 12 Nov 2023 15:37:26 +0800 Subject: [PATCH 1/3] Add support for IP Adapter Full Face --- scripts/controlmodel_ipadapter.py | 55 ++++++++++++++++++++++--------- 1 file changed, 40 insertions(+), 15 deletions(-) diff --git a/scripts/controlmodel_ipadapter.py b/scripts/controlmodel_ipadapter.py index cfd3223f8..8e28caee2 100644 --- a/scripts/controlmodel_ipadapter.py +++ b/scripts/controlmodel_ipadapter.py @@ -7,6 +7,20 @@ SD_V12_CHANNELS = [320] * 4 + [640] * 4 + [1280] * 4 + [1280] * 6 + [640] * 6 + [320] * 6 + [1280] * 2 SD_XL_CHANNELS = [640] * 8 + [1280] * 40 + [1280] * 60 + [640] * 12 + [1280] * 20 +class MLPProjModel(torch.nn.Module): + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024): + super().__init__() + + self.proj = torch.nn.Sequential( + torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim), + torch.nn.GELU(), + torch.nn.Linear(clip_embeddings_dim, cross_attention_dim), + torch.nn.LayerNorm(cross_attention_dim) + ) + + def forward(self, image_embeds): + clip_extra_context_tokens = self.proj(image_embeds) + return clip_extra_context_tokens class ImageProjModel(torch.nn.Module): """Projection Model""" @@ -158,27 +172,34 @@ def forward(self, x): class IPAdapterModel(torch.nn.Module): - def __init__(self, state_dict, clip_embeddings_dim, cross_attention_dim, is_plus, sdxl_plus): + def __init__(self, state_dict, clip_embeddings_dim, cross_attention_dim, is_plus, sdxl_plus, is_full): super().__init__() self.device = "cpu" self.cross_attention_dim = cross_attention_dim self.is_plus = is_plus self.sdxl_plus = sdxl_plus + self.is_full = is_full if self.is_plus: - self.clip_extra_context_tokens = 16 - - self.image_proj_model = Resampler( - dim=1280 if sdxl_plus else cross_attention_dim, - depth=4, - dim_head=64, - heads=20 if sdxl_plus else 12, - num_queries=self.clip_extra_context_tokens, - embedding_dim=clip_embeddings_dim, - output_dim=self.cross_attention_dim, - ff_mult=4 - ) + if self.is_full: + self.image_proj_model = MLPProjModel( + cross_attention_dim=cross_attention_dim, + clip_embeddings_dim=clip_embeddings_dim + ) + else: + self.clip_extra_context_tokens = 16 + + self.image_proj_model = Resampler( + dim=1280 if sdxl_plus else cross_attention_dim, + depth=4, + dim_head=64, + heads=20 if sdxl_plus else 12, + num_queries=self.clip_extra_context_tokens, + embedding_dim=clip_embeddings_dim, + output_dim=self.cross_attention_dim, + ff_mult=4 + ) else: self.clip_extra_context_tokens = state_dict["image_proj"]["proj.weight"].shape[0] // self.cross_attention_dim @@ -294,7 +315,8 @@ def clear_all_ip_adapter(): class PlugableIPAdapter(torch.nn.Module): def __init__(self, state_dict): super().__init__() - self.is_plus = "latents" in state_dict["image_proj"] + self.is_full = "proj.0.weight" in state_dict['image_proj'] + self.is_plus = self.is_full or "latents" in state_dict["image_proj"] cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[1] self.sdxl = cross_attention_dim == 2048 self.sdxl_plus = self.sdxl and self.is_plus @@ -302,6 +324,8 @@ def __init__(self, state_dict): if self.is_plus: if self.sdxl_plus: clip_embeddings_dim = int(state_dict["image_proj"]["latents"].shape[2]) + elif self.is_full: + clip_embeddings_dim = int(state_dict["image_proj"]["proj.0.weight"].shape[1]) else: clip_embeddings_dim = int(state_dict['image_proj']['proj_in.weight'].shape[1]) else: @@ -311,7 +335,8 @@ def __init__(self, state_dict): clip_embeddings_dim=clip_embeddings_dim, cross_attention_dim=cross_attention_dim, is_plus=self.is_plus, - sdxl_plus=self.sdxl_plus) + sdxl_plus=self.sdxl_plus, + is_full=self.is_full) self.disable_memory_management = True self.dtype = None self.weight = 1.0 From e27497e1584df33ce33cd9776e7c44c56a84b4fe Mon Sep 17 00:00:00 2001 From: huchenlei Date: Sun, 12 Nov 2023 20:59:16 -0500 Subject: [PATCH 2/3] :white_check_mark: Add unittests --- .github/workflows/tests.yml | 17 ++++++++++--- tests/web_api/txt2img_test.py | 47 ++++++++++++++++++++++++++++++++--- 2 files changed, 57 insertions(+), 7 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 128c0833d..5771914ef 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -48,7 +48,7 @@ jobs: uses: actions/cache@v3 with: path: stable-diffusion-webui/extensions/sd-webui-controlnet/models/ - key: controlnet-models-v1 + key: controlnet-models-v2 - name: Cache Preprocessor models uses: actions/cache@v3 with: @@ -56,9 +56,18 @@ jobs: key: preprocessor-models-v1 - name: Download controlnet model for testing run: | - if [ ! -f "extensions/sd-webui-controlnet/models/control_v11p_sd15_canny.pth" ]; then - curl -Lo extensions/sd-webui-controlnet/models/control_v11p_sd15_canny.pth https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_canny.pth - fi + declare -a urls=( + "https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_canny.pth" + "https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter-full-face_sd15.safetensors" + "https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter-plus-face_sd15.safetensors" + ) + + for url in "${urls[@]}"; do + filename="extensions/sd-webui-controlnet/models/${url##*/}" # Extracts the last part of the URL + if [ ! -f "$filename" ]; then + curl -Lo "$filename" "$url" + fi + done working-directory: stable-diffusion-webui - name: Start test server run: > diff --git a/tests/web_api/txt2img_test.py b/tests/web_api/txt2img_test.py index 249d53c3b..928bd5cd4 100644 --- a/tests/web_api/txt2img_test.py +++ b/tests/web_api/txt2img_test.py @@ -8,9 +8,9 @@ class TestAlwaysonTxt2ImgWorking(unittest.TestCase): def setUp(self): - sd_version = StableDiffusionVersion(int( + self.sd_version = StableDiffusionVersion(int( os.environ.get("CONTROLNET_TEST_SD_VERSION", StableDiffusionVersion.SD1x.value))) - self.model = utils.get_model("canny", sd_version) + self.model = utils.get_model("canny", self.sd_version) controlnet_unit = { "enabled": True, @@ -167,7 +167,48 @@ def test_save_detected_map(self): resp = requests.post(self.url_txt2img, json=self.simple_txt2img).json() self.assertEqual(2 if save_map else 1, len(resp["images"])) - + + def test_ip_adapter_face(self): + match self.sd_version: + case StableDiffusionVersion.SDXL: + model = "ip-adapter-plus-face_sdxl_vit-h" + module = "ip-adapter_clip_sdxl_plus_vith" + case StableDiffusionVersion.SD1x: + model = "ip-adapter-plus-face_sd15" + module = "ip-adapter_clip_sd15" + case _: + # Skip the test for all other versions + return + + self.simple_txt2img["alwayson_scripts"]["ControlNet"]["args"] = [ + { + "input_image": utils.readImage("test/test_files/img2img_basic.png"), + "model": utils.get_model(model, self.sd_version), + "module": module, + } + ] + + self.assert_status_ok() + + def test_ip_adapter_fullface(self): + match self.sd_version: + case StableDiffusionVersion.SD1x: + model = "ip-adapter-full-face_sd15" + module = "ip-adapter_clip_sd15" + case _: + # Skip the test for all other versions + return + + self.simple_txt2img["alwayson_scripts"]["ControlNet"]["args"] = [ + { + "input_image": utils.readImage("test/test_files/img2img_basic.png"), + "model": utils.get_model(model, self.sd_version), + "module": module, + } + ] + + self.assert_status_ok() + if __name__ == "__main__": unittest.main() From b8782875d8a625931e0589104bbc181bf6f5864b Mon Sep 17 00:00:00 2001 From: huchenlei Date: Sun, 12 Nov 2023 21:23:58 -0500 Subject: [PATCH 3/3] nits --- .github/workflows/tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5771914ef..c71626f61 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -103,10 +103,10 @@ jobs: if: always() with: name: output - path: output.txt + path: stable-diffusion-webui/output.txt - name: Upload coverage HTML uses: actions/upload-artifact@v3 if: always() with: name: htmlcov - path: htmlcov + path: stable-diffusion-webui/htmlcov