Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance positional encoding adjustment in SparseCtrl loading with exp… #83

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions adv_control/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,10 +523,34 @@ def convert_to_advanced(control, timestep_keyframe: TimestepKeyframeGroup=None):
def is_advanced_controlnet(input_object):
    """Return True when *input_object* exposes a ``sub_idxs`` attribute.

    The check is purely structural: any object carrying that attribute is
    reported as an advanced controlnet, regardless of its class.
    """
    marker_attribute = "sub_idxs"
    return hasattr(input_object, marker_attribute)

def adjust_positional_encoding_parameters(controlnet_data, expected_seq_len):
    """Resize the positional-encoding tensors of a state dict in place.

    Every entry whose key contains ``"pos_encoder.pe"`` must be a 3-D tensor;
    its second dimension is treated as the sequence length. Tensors that are
    longer than ``expected_seq_len`` are truncated, shorter ones are padded
    with zeros at the end, and tensors that already match are left untouched
    (the very same object stays in the dict).

    Args:
        controlnet_data: Mutable mapping of parameter names to tensors. It is
            modified in place.
        expected_seq_len: Target size of the second dimension. Coerced with
            ``int()`` so that a float such as ``32.0`` is accepted.

    Returns:
        None. The result is the mutation of ``controlnet_data``.

    Raises:
        ValueError: If a matching tensor is not 3-dimensional (shape unpack).
    """
    pe_keys = [key for key in controlnet_data if "pos_encoder.pe" in key]
    if not pe_keys:
        # Nothing to adjust; do not even look at expected_seq_len.
        return

    # Coerce once, BEFORE comparing, so the comparison and the allocation
    # always agree on the target length.
    target_len = int(expected_seq_len)

    for key in pe_keys:
        original_pe = controlnet_data[key]
        leading, seq_len, dim = original_pe.shape
        if seq_len == target_len:
            continue

        # new_zeros keeps the dtype and device of the checkpoint tensor; a
        # bare torch.zeros would silently turn e.g. an fp16 or CUDA tensor
        # into a float32 CPU one. The leading dimension is preserved rather
        # than assumed to be 1.
        adjusted_pe = original_pe.new_zeros((leading, target_len, dim))
        length_to_copy = min(seq_len, target_len)
        adjusted_pe[:, :length_to_copy, :] = original_pe[:, :length_to_copy, :]
        controlnet_data[key] = adjusted_pe


def load_sparsectrl(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, timestep_keyframe: TimestepKeyframeGroup=None, sparse_settings=SparseSettings.default(), model=None) -> SparseCtrlAdvanced:
if controlnet_data is None:
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)

# Adjust positional encoding parameters before loading parts of the model, using the expected_seq_len from sparse_settings
adjust_positional_encoding_parameters(controlnet_data, sparse_settings.expected_seq_len)

# first, separate out motion part from normal controlnet part and attempt to load that portion
motion_data = {}
for key in list(controlnet_data.keys()):
Expand Down
5 changes: 3 additions & 2 deletions adv_control/control_sparsectrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,17 @@ def __setitem__(self, *args, **kwargs):


class SparseSettings:
    """Plain container for the options handed to the SparseCtrl loader.

    The constructor only stores its arguments; no validation or conversion
    happens here.
    """

    def __init__(self, sparse_method: 'SparseMethod', use_motion: bool=True, motion_strength=1.0, motion_scale=1.0, merged=False, expected_seq_len=32):
        """Store the given options on the instance.

        Args:
            sparse_method: The ``SparseMethod`` to use.
            use_motion: Stored as-is (presumably toggles the motion part of
                the model; confirm against the consumer of this flag).
            motion_strength: Stored as-is.
            motion_scale: Stored as-is.
            merged: Stored as-is.
            expected_seq_len: Sequence length the positional encodings are
                expected to have; defaults to 32.
        """
        self.sparse_method = sparse_method
        self.use_motion = use_motion
        self.motion_strength = motion_strength
        self.motion_scale = motion_scale
        self.merged = merged
        self.expected_seq_len = expected_seq_len  # expected sequence length for positional encodings

    @classmethod
    def default(cls):
        """Return settings using a fresh ``SparseSpreadMethod``, motion enabled and a sequence length of 32."""
        return SparseSettings(sparse_method=SparseSpreadMethod(), use_motion=True, expected_seq_len=32)


class SparseMethod(ABC):
Expand Down
5 changes: 3 additions & 2 deletions adv_control/nodes_sparsectrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def INPUT_TYPES(s):
"use_motion": ("BOOLEAN", {"default": True}, ),
"motion_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
"motion_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
"expected_seq_len": ("INT", {"default": 32.0, "min": 1.0, "step": 1}, ),
},
"optional": {
"sparse_method": ("SPARSE_METHOD", ),
Expand All @@ -32,9 +33,9 @@ def INPUT_TYPES(s):

CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/SparseCtrl"

def load_controlnet(self, sparsectrl_name: str, use_motion: bool, motion_strength: float, motion_scale: float, expected_seq_len: int = 32, sparse_method: SparseMethod=SparseSpreadMethod(), tk_optional: TimestepKeyframeGroup=None):
    """Load a SparseCtrl model by name and return it as a one-element tuple.

    Args:
        sparsectrl_name: Name resolved to a full path through
            ``folder_paths.get_full_path("controlnet", ...)``.
        use_motion: Forwarded to ``SparseSettings``.
        motion_strength: Forwarded to ``SparseSettings``.
        motion_scale: Forwarded to ``SparseSettings``.
        expected_seq_len: Sequence length expected for the positional
            encodings; forwarded to ``SparseSettings`` as an ``int``.
        sparse_method: Forwarded to ``SparseSettings``.
        tk_optional: Optional timestep keyframe group passed to
            ``load_sparsectrl`` as ``timestep_keyframe``.

    Returns:
        A tuple whose single item is the object returned by
        ``load_sparsectrl``.
    """
    sparsectrl_path = folder_paths.get_full_path("controlnet", sparsectrl_name)
    sparse_settings = SparseSettings(
        sparse_method=sparse_method,
        use_motion=use_motion,
        motion_strength=motion_strength,
        motion_scale=motion_scale,
        # Normalise to int here so the settings never carry a float length
        # (e.g. 32.0 coming from a widget default).
        expected_seq_len=int(expected_seq_len),
    )
    sparsectrl = load_sparsectrl(sparsectrl_path, timestep_keyframe=tk_optional, sparse_settings=sparse_settings)
    return (sparsectrl,)

Expand Down