Skip to content

Commit

Permalink
More robust ipadapter detection
Browse files Browse the repository at this point in the history
  • Loading branch information
huchenlei committed May 27, 2024
1 parent ba3984a commit b6460b3
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 13 deletions.
21 changes: 10 additions & 11 deletions internal_controlnet/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
PuLIDMode,
)
from annotator.util import HWC3
from scripts.supported_preprocessor import Preprocessor


def _unimplemented_func(*args, **kwargs):
Expand Down Expand Up @@ -228,21 +229,12 @@ def parse_ipadapter_input(cls, value) -> Optional[List[torch.Tensor]]:
animatediff_batch: bool = False
batch_modifiers: list = []
batch_image_files: list = []
batch_keyframe_idx: Optional[str|list] = None
batch_keyframe_idx: Optional[str | list] = None

@property
def accepts_multiple_inputs(self) -> bool:
"""This unit can accept multiple input images."""
return self.module in (
"ip-adapter-auto",
"ip-adapter_clip_sdxl",
"ip-adapter_clip_sdxl_plus_vith",
"ip-adapter_clip_sd15",
"ip-adapter_face_id",
"ip-adapter_face_id_plus",
"ip-adapter_pulid",
"instant_id_face_embedding",
)
return self.is_ipadapter

@property
def is_animate_diff_batch(self) -> bool:
Expand All @@ -263,6 +255,13 @@ def uses_clip(self) -> bool:
def is_inpaint(self) -> bool:
return "inpaint" in self.module

@property
def is_ipadapter(self) -> bool:
p = Preprocessor.get_preprocessor(self.module)
if p is None:
return False
return "IP-Adapter" in p.tags

def get_actual_preprocessors(self) -> List[Any]:
p = ControlNetUnit.cls_get_preprocessor(self.module)
# Map "ip-adapter-auto" to actual preprocessor.
Expand Down
4 changes: 2 additions & 2 deletions scripts/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,7 +948,7 @@ def controlnet_main_entry(self, p):
elif unit.is_animate_diff_batch or control_model_type in [ControlModelType.SparseCtrl]:
cn_ad_keyframe_idx = getattr(unit, "batch_keyframe_idx", None)
def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe_idx):
if unit.accepts_multiple_inputs:
if unit.is_ipadapter:
ip_adapter_image_emb_cond = []
model_net.ipadapter.image_proj_model.to(torch.float32) # noqa
for c in cc:
Expand All @@ -975,7 +975,7 @@ def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe
for frame_idx, frame_path in zip(unit.batch_keyframe_idx, unit.batch_image_files):
logger.info(f"\t{frame_idx}: {frame_path}")
c = SparseCtrl.create_cond_mask(cn_ad_keyframe_idx, c, p.batch_size).cpu()
elif unit.accepts_multiple_inputs:
elif unit.is_ipadapter:
# ip-adapter should do prompt travel
logger.info("IP-Adapter: control prompts will be traveled in the following way:")
for frame_idx, frame_path in zip(unit.batch_keyframe_idx, unit.batch_image_files):
Expand Down

0 comments on commit b6460b3

Please sign in to comment.