From 3ff69b9ea32f59ac866780b19a6465f9b9d28053 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Sun, 14 Jul 2024 21:17:29 -0400 Subject: [PATCH] Support ProMax union model (#2998) * Support ProMax union model * nit --- scripts/cldm.py | 6 +++--- scripts/controlnet_model_guess.py | 2 +- scripts/enums.py | 8 ++++++++ 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/scripts/cldm.py b/scripts/cldm.py index c404fd956..54759dbb8 100644 --- a/scripts/cldm.py +++ b/scripts/cldm.py @@ -91,7 +91,7 @@ def __init__( use_linear_in_transformer=False, adm_in_channels=None, transformer_depth_middle=None, - union_controlnet=False, + union_controlnet_num_control_type=None, device=None, global_average_pooling=False, ): @@ -282,8 +282,8 @@ def __init__( self.middle_block_out = self.make_zero_conv(ch) self._feature_size += ch - if union_controlnet: - self.num_control_type = 6 + if union_controlnet_num_control_type is not None: + self.num_control_type = union_controlnet_num_control_type num_trans_channel = 320 num_trans_head = 8 num_trans_layer = 1 diff --git a/scripts/controlnet_model_guess.py b/scripts/controlnet_model_guess.py index 030f43deb..49c2000e7 100644 --- a/scripts/controlnet_model_guess.py +++ b/scripts/controlnet_model_guess.py @@ -222,7 +222,7 @@ def build_model_by_guess(state_dict, unet, model_path: str) -> ControlModel: state_dict = final_state_dict if "control_add_embedding.linear_1.bias" in state_dict: # Controlnet Union - config["union_controlnet"] = True + config["union_controlnet_num_control_type"] = state_dict["task_embedding"].shape[0] final_state_dict = {} for k in list(state_dict.keys()): new_k = k.replace('.attn.in_proj_', '.attn.in_proj.') diff --git a/scripts/enums.py b/scripts/enums.py index e4f7b2d65..306050876 100644 --- a/scripts/enums.py +++ b/scripts/enums.py @@ -292,6 +292,8 @@ class ControlNetUnionControlType(Enum): HARD_EDGE = "Hard Edge" NORMAL_MAP = "Normal Map" SEGMENTATION = "Segmentation" + TILE = "Tile" + INPAINT = "Inpaint" UNKNOWN = "Unknown" @@ -308,6 +310,8 @@ def all_tags() -> List[str]: "mlsd", "normalmap", "segmentation", + "inpaint", + "tile", ] @staticmethod @@ -326,6 +330,10 @@ def from_str(s: str) -> ControlNetUnionControlType: return ControlNetUnionControlType.NORMAL_MAP elif s == "segmentation": return ControlNetUnionControlType.SEGMENTATION + elif s in ["tile", "blur"]: + return ControlNetUnionControlType.TILE + elif s == "inpaint": + return ControlNetUnionControlType.INPAINT return ControlNetUnionControlType.UNKNOWN