From 4e26b7b86d02aaca7ef81a18657c61398ce59e68 Mon Sep 17 00:00:00 2001 From: huchenlei Date: Sun, 14 Jul 2024 20:58:43 -0400 Subject: [PATCH 1/2] Support ProMax union model --- scripts/cldm.py | 6 +++--- scripts/controlnet_model_guess.py | 2 +- scripts/enums.py | 6 ++++++ 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/scripts/cldm.py b/scripts/cldm.py index c404fd956..54759dbb8 100644 --- a/scripts/cldm.py +++ b/scripts/cldm.py @@ -91,7 +91,7 @@ def __init__( use_linear_in_transformer=False, adm_in_channels=None, transformer_depth_middle=None, - union_controlnet=False, + union_controlnet_num_control_type=None, device=None, global_average_pooling=False, ): @@ -282,8 +282,8 @@ def __init__( self.middle_block_out = self.make_zero_conv(ch) self._feature_size += ch - if union_controlnet: - self.num_control_type = 6 + if union_controlnet_num_control_type is not None: + self.num_control_type = union_controlnet_num_control_type num_trans_channel = 320 num_trans_head = 8 num_trans_layer = 1 diff --git a/scripts/controlnet_model_guess.py b/scripts/controlnet_model_guess.py index 030f43deb..49c2000e7 100644 --- a/scripts/controlnet_model_guess.py +++ b/scripts/controlnet_model_guess.py @@ -222,7 +222,7 @@ def build_model_by_guess(state_dict, unet, model_path: str) -> ControlModel: state_dict = final_state_dict if "control_add_embedding.linear_1.bias" in state_dict: # Controlnet Union - config["union_controlnet"] = True + config["union_controlnet_num_control_type"] = state_dict["task_embedding"].shape[0] final_state_dict = {} for k in list(state_dict.keys()): new_k = k.replace('.attn.in_proj_', '.attn.in_proj.') diff --git a/scripts/enums.py b/scripts/enums.py index e4f7b2d65..834847ac1 100644 --- a/scripts/enums.py +++ b/scripts/enums.py @@ -292,6 +292,8 @@ class ControlNetUnionControlType(Enum): HARD_EDGE = "Hard Edge" NORMAL_MAP = "Normal Map" SEGMENTATION = "Segmentation" + TILE = "Tile" + INPAINT = "Inpaint" UNKNOWN = "Unknown" @@ -326,6 +328,10 @@ def from_str(s: str) -> ControlNetUnionControlType: return ControlNetUnionControlType.NORMAL_MAP elif s == "segmentation": return ControlNetUnionControlType.SEGMENTATION + elif s in ["tile", "blur"]: + return ControlNetUnionControlType.TILE + elif s == "inpaint": + return ControlNetUnionControlType.INPAINT return ControlNetUnionControlType.UNKNOWN From 8768a0728abc286bbd06cf39d309a5ce2f8d3238 Mon Sep 17 00:00:00 2001 From: huchenlei Date: Sun, 14 Jul 2024 21:16:30 -0400 Subject: [PATCH 2/2] nit --- scripts/enums.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scripts/enums.py b/scripts/enums.py index 834847ac1..306050876 100644 --- a/scripts/enums.py +++ b/scripts/enums.py @@ -310,6 +310,8 @@ def all_tags() -> List[str]: "mlsd", "normalmap", "segmentation", + "inpaint", + "tile", ] @staticmethod