Commit 3011ff6

support all recent ip-adapters

lllyasviel authored Nov 1, 2023
1 parent f2aafcf commit 3011ff6

Showing 6 changed files with 36 additions and 19 deletions.
5 changes: 5 additions & 0 deletions annotator/clipvision/__init__.py

@@ -1,4 +1,5 @@
 import os
+import cv2
 import torch
 
 from modules import devices
@@ -80,6 +81,9 @@
 clip_vision_h_uc = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'clip_vision_h_uc.data')
 clip_vision_h_uc = torch.load(clip_vision_h_uc)['uc']
 
+clip_vision_vith_uc = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'clip_vision_vith_uc.data')
+clip_vision_vith_uc = torch.load(clip_vision_vith_uc)['uc']
+
 
 class ClipVisionDetector:
     def __init__(self, config):
@@ -118,6 +122,7 @@ def unload_model(self):
 
     def __call__(self, input_image):
        with torch.no_grad():
+            input_image = cv2.resize(input_image, (224, 224), interpolation=cv2.INTER_AREA)
             clip_vision_model = self.model.cpu()
             feat = self.processor(images=input_image, return_tensors="pt")
             feat['pixel_values'] = feat['pixel_values'].cpu()
Binary file added annotator/clipvision/clip_vision_vith_uc.data (not shown).
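
The loader added above shows the layout of these .data blobs: a torch-pickled dict whose 'uc' entry is a precomputed unconditional image embedding (for SDXL plus ViT-H it is used directly as the negative embedding, without another pass through the resampler; see the controlmodel_ipadapter.py hunk below). A minimal sketch for inspecting the new file, assuming a local checkout of the extension:

```python
import os
import torch

# Path relative to a checkout of the extension (assumed layout).
path = os.path.join('annotator', 'clipvision', 'clip_vision_vith_uc.data')

blob = torch.load(path, map_location='cpu')
uc = blob['uc']  # precomputed unconditional embedding, per the loader above
print(uc.shape, uc.dtype)
```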
37 changes: 26 additions & 11 deletions scripts/controlmodel_ipadapter.py

@@ -158,22 +158,22 @@ def forward(self, x):
 
 
 class IPAdapterModel(torch.nn.Module):
-    def __init__(self, state_dict, clip_embeddings_dim, is_plus):
+    def __init__(self, state_dict, clip_embeddings_dim, cross_attention_dim, is_plus, sdxl_plus):
         super().__init__()
         self.device = "cpu"
 
-        # cross_attention_dim is equal to text_encoder output
-        self.cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[1]
+        self.cross_attention_dim = cross_attention_dim
         self.is_plus = is_plus
+        self.sdxl_plus = sdxl_plus
 
         if self.is_plus:
             self.clip_extra_context_tokens = 16
 
             self.image_proj_model = Resampler(
-                dim=self.cross_attention_dim,
+                dim=1280 if sdxl_plus else cross_attention_dim,
                 depth=4,
                 dim_head=64,
-                heads=12,
+                heads=20 if sdxl_plus else 12,
                 num_queries=self.clip_extra_context_tokens,
                 embedding_dim=clip_embeddings_dim,
                 output_dim=self.cross_attention_dim,
@@ -200,9 +200,9 @@ def get_image_embeds(self, clip_vision_output):
         self.image_proj_model.cpu()
 
         if self.is_plus:
-            from annotator.clipvision import clip_vision_h_uc
+            from annotator.clipvision import clip_vision_h_uc, clip_vision_vith_uc
             cond = self.image_proj_model(clip_vision_output['hidden_states'][-2].to(device='cpu', dtype=torch.float32))
-            uncond = self.image_proj_model(clip_vision_h_uc.to(cond))
+            uncond = clip_vision_vith_uc.to(cond) if self.sdxl_plus else self.image_proj_model(clip_vision_h_uc.to(cond))
             return cond, uncond
 
         clip_image_embeds = clip_vision_output['image_embeds'].to(device='cpu', dtype=torch.float32)
@@ -292,11 +292,26 @@ def clear_all_ip_adapter():
 
 
 class PlugableIPAdapter(torch.nn.Module):
-    def __init__(self, state_dict, clip_embeddings_dim, is_plus):
+    def __init__(self, state_dict):
         super().__init__()
-        self.sdxl = clip_embeddings_dim == 1280 and not is_plus
-        self.is_plus = is_plus
-        self.ipadapter = IPAdapterModel(state_dict, clip_embeddings_dim=clip_embeddings_dim, is_plus=is_plus)
+        self.is_plus = "latents" in state_dict["image_proj"]
+        cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[1]
+        self.sdxl = cross_attention_dim == 2048
+        self.sdxl_plus = self.sdxl and self.is_plus
+
+        if self.is_plus:
+            if self.sdxl_plus:
+                clip_embeddings_dim = int(state_dict["image_proj"]["latents"].shape[2])
+            else:
+                clip_embeddings_dim = int(state_dict['image_proj']['proj_in.weight'].shape[1])
+        else:
+            clip_embeddings_dim = int(state_dict['image_proj']['proj.weight'].shape[1])
+
+        self.ipadapter = IPAdapterModel(state_dict,
+                                        clip_embeddings_dim=clip_embeddings_dim,
+                                        cross_attention_dim=cross_attention_dim,
+                                        is_plus=self.is_plus,
+                                        sdxl_plus=self.sdxl_plus)
         self.disable_memory_management = True
        self.dtype = None
         self.weight = 1.0
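
The shape-based auto-detection in PlugableIPAdapter above is the heart of this commit: the variant of any recent IP-Adapter checkpoint is inferred from its state dict alone. A standalone sketch of that logic (keys and shapes are taken from the diff; the checkpoint filename is illustrative):

```python
import torch

def detect_ip_adapter_variant(state_dict):
    # "Plus" checkpoints ship a Resampler, whose learned queries are stored
    # under image_proj.latents; non-plus models use a simple linear proj.
    is_plus = "latents" in state_dict["image_proj"]
    # SDXL UNets cross-attend over 2048-dim context; SD1.5 uses 768.
    cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[1]
    sdxl = cross_attention_dim == 2048
    sdxl_plus = sdxl and is_plus

    if is_plus:
        if sdxl_plus:
            clip_embeddings_dim = int(state_dict["image_proj"]["latents"].shape[2])
        else:
            clip_embeddings_dim = int(state_dict["image_proj"]["proj_in.weight"].shape[1])
    else:
        clip_embeddings_dim = int(state_dict["image_proj"]["proj.weight"].shape[1])

    return dict(is_plus=is_plus, sdxl=sdxl, sdxl_plus=sdxl_plus,
                clip_embeddings_dim=clip_embeddings_dim,
                cross_attention_dim=cross_attention_dim)

# Illustrative usage; the filename is hypothetical.
sd = torch.load("ip-adapter-plus_sdxl_vit-h.bin", map_location="cpu")
print(detect_ip_adapter_variant(sd))
```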
7 changes: 1 addition & 6 deletions scripts/controlnet_model_guess.py

@@ -225,12 +225,7 @@ def build_model_by_guess(state_dict, unet, model_path):
         return network
 
     if 'ip_adapter' in state_dict:
-        plus = "latents" in state_dict["image_proj"]
-        if plus:
-            channel = int(state_dict['image_proj']['proj_in.weight'].shape[1])
-        else:
-            channel = int(state_dict['image_proj']['proj.weight'].shape[1])
-        network = PlugableIPAdapter(state_dict, channel, plus)
+        network = PlugableIPAdapter(state_dict)
         network.to('cpu')
         return network
 
2 changes: 1 addition & 1 deletion scripts/controlnet_version.py

@@ -1,4 +1,4 @@
-version_flag = 'v1.1.411'
+version_flag = 'v1.1.415'
 
 from scripts.logging import logger
 
4 changes: 3 additions & 1 deletion scripts/global_state.py

@@ -11,7 +11,7 @@
 
 from typing import Dict, Callable, Optional, Tuple, List
 
-CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors"]
+CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin"]
 cn_models_dir = os.path.join(models_path, "ControlNet")
 cn_models_dir_old = os.path.join(scripts.basedir(), "models")
 cn_models = OrderedDict()  # "My_Lora(abcd1234)" -> C:/path/to/model.safetensors
@@ -67,6 +67,7 @@ def unified_preprocessor(preprocessor_name: str, *args, **kwargs):
     "revision_clipvision": functools.partial(clip, config='clip_g'),
     "revision_ignore_prompt": functools.partial(clip, config='clip_g'),
     "ip-adapter_clip_sd15": functools.partial(clip, config='clip_h'),
+    "ip-adapter_clip_sdxl_plus_vith": functools.partial(clip, config='clip_h'),
     "ip-adapter_clip_sdxl": functools.partial(clip, config='clip_g'),
     "color": color,
     "pidinet": pidinet,
@@ -110,6 +111,7 @@ def unified_preprocessor(preprocessor_name: str, *args, **kwargs):
     "revision_clipvision": functools.partial(unload_clip, config='clip_g'),
     "revision_ignore_prompt": functools.partial(unload_clip, config='clip_g'),
     "ip-adapter_clip_sd15": functools.partial(unload_clip, config='clip_h'),
+    "ip-adapter_clip_sdxl_plus_vith": functools.partial(unload_clip, config='clip_h'),
     "ip-adapter_clip_sdxl": functools.partial(unload_clip, config='clip_g'),
     "depth": unload_midas,
     "depth_leres": unload_leres,
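
Note the encoder pairing encoded in these tables: the new SDXL "plus vit-h" adapter was trained against the ViT-H image encoder, so its preprocessor maps to config='clip_h' even though the base model is SDXL, while the plain SDXL adapter uses clip_g (ViT-bigG). A hedged sketch of the lookup the extension performs (the cn_preprocessor_modules dict name is assumed from global_state.py, and this only runs inside the webui process):

```python
import numpy as np
from scripts.global_state import cn_preprocessor_modules  # name assumed

# Key added in this commit; the partial binds the clip_h (ViT-H) encoder.
preprocess = cn_preprocessor_modules["ip-adapter_clip_sdxl_plus_vith"]
image = np.zeros((512, 512, 3), dtype=np.uint8)  # dummy RGB input
clip_vision_output = preprocess(image)
```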

10 comments on commit 3011ff6

@FurkanGozukara

Where should we download them from?

@moonshinegloss

@lllyasviel
Collaborator Author

No, not these files.
IP-Adapters can be downloaded from the official IP-Adapter Hugging Face repo. We now also support .bin models.
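
For example, a hedged sketch of fetching one of the official .bin checkpoints with huggingface_hub (the repo id matches the link below; the exact filename and destination are assumptions to check against the Hugging Face listing):

```python
from huggingface_hub import hf_hub_download

hf_hub_download(
    repo_id="h94/IP-Adapter",
    # Filename assumed from the repo's sdxl_models/ listing.
    filename="sdxl_models/ip-adapter-plus_sdxl_vit-h.bin",
    # Assumed webui layout; hf_hub_download keeps the sdxl_models/ subfolder,
    # so move the file directly into models/ControlNet afterwards if needed.
    local_dir="models/ControlNet",
)
```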

@moonshinegloss

@lllyasviel
Collaborator Author

Yes. We should update the other notes later.

@FurkanGozukara

> Yes. We should update the other notes later.

Thank you. So we need to download from here:

https://huggingface.co/h94/IP-Adapter/tree/main/sdxl_models

Maybe you can add them to your repo? https://huggingface.co/lllyasviel/sd_control_collection/tree/main

@FurkanGozukara commented on 3011ff6, Nov 2, 2023

@lllyasviel safetensors files were added; will they work?

Tested: they failed.

@anwoflow

Would be great to get an idiot-proof guide on how to add those.

@FurkanGozukara

> Would be great to get an idiot-proof guide on how to add those.

I will hopefully make a video once it supports safetensors too.

@lllyasviel
Collaborator Author

safetensors IP-Adapters are not official IP-Adapters, AFAIK.
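
That failure is consistent with the loader: build_model_by_guess keys off a top-level 'ip_adapter' entry, and the official .bin checkpoints are nested dicts ({'image_proj': ..., 'ip_adapter': ...}), whereas third-party safetensors conversions are flat tensor maps. If a conversion merely joined the nesting with dots, a repack along these lines might restore the expected layout (the key layout is an assumption to verify against the actual file):

```python
import torch
from safetensors.torch import load_file

flat = load_file("ip-adapter_sdxl.safetensors")  # hypothetical community file

# Assumes keys like "image_proj.proj.weight" and "ip_adapter.1.to_k_ip.weight".
nested = {}
for name, tensor in flat.items():
    group, key = name.split(".", 1)
    nested.setdefault(group, {})[key] = tensor

torch.save(nested, "ip-adapter_sdxl.bin")  # now exposes the 'ip_adapter' key
```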
