Merge PR #169 from Kosinkadink/develop - initial flux support
Initial flux support, refactoring weight control
Kosinkadink authored Aug 30, 2024
2 parents 949843e + ef16e3c commit dcc928b
Showing 7 changed files with 552 additions and 218 deletions.
84 changes: 61 additions & 23 deletions adv_control/control.py
@@ -8,7 +8,7 @@
 import comfy.model_management
 import comfy.model_detection
 import comfy.controlnet as comfy_cn
-from comfy.controlnet import ControlBase, ControlNet, ControlLora, T2IAdapter
+from comfy.controlnet import ControlBase, ControlNet, ControlLora, T2IAdapter, StrengthType
 from comfy.model_patcher import ModelPatcher

 from .control_sparsectrl import SparseModelPatcher, SparseControlNet, SparseCtrlMotionWrapper, SparseSettings, SparseConst
@@ -21,13 +21,23 @@


 class ControlNetAdvanced(ControlNet, AdvancedControlBase):
-    def __init__(self, control_model, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
+    def __init__(self, control_model, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT):
         super().__init__(control_model=control_model, global_average_pooling=global_average_pooling, compression_ratio=compression_ratio, latent_format=latent_format, device=device, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
         AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controlnet())
+        self.is_flux = False
+        self.x_noisy_shape = None

     def get_universal_weights(self) -> ControlWeights:
-        raw_weights = [(self.weights.base_multiplier ** float(12 - i)) for i in range(13)]
-        return self.weights.copy_with_new_weights(raw_weights)
+        def cn_weights_func(idx: int, control: dict[str, list[Tensor]], key: str):
+            if key == "middle":
+                return 1.0
+            c_len = len(control[key])
+            raw_weights = [(self.weights.base_multiplier ** float((c_len) - i)) for i in range(c_len+1)]
+            raw_weights = raw_weights[:-1]
+            if key == "input":
+                raw_weights.reverse()
+            return raw_weights[idx]
+        return self.weights.copy_with_new_weights(new_weight_func=cn_weights_func)

     def get_control_advanced(self, x_noisy, t, cond, batched_number):
         # perform special version of get_control that supports sliding context and masks
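
The refactor above replaces a fixed 13-entry weight list with a per-tensor callback, so the exponential falloff now adapts to however many control tensors the model emits for each key, and middle-block outputs stay pinned at full strength. A minimal standalone sketch of the same falloff logic (plain Python; the `fake_control` dict and the `base_multiplier` default are stand-ins for the real structures ComfyUI passes in):

# Sketch only: mirrors cn_weights_func from the diff with stand-in inputs.
def cn_weights_func(idx: int, control: dict, key: str, base_multiplier: float = 0.825) -> float:
    if key == "middle":
        return 1.0  # middle block stays at full strength
    c_len = len(control[key])
    # exponents c_len..0, then drop the trailing **0 entry -> c_len weights
    raw_weights = [base_multiplier ** float(c_len - i) for i in range(c_len + 1)]
    raw_weights = raw_weights[:-1]
    if key == "input":
        raw_weights.reverse()  # input blocks are ordered shallow-to-deep
    return raw_weights[idx]

fake_control = {"input": [None] * 12, "output": [None] * 12, "middle": [None]}
print([round(cn_weights_func(i, fake_control, "output"), 3) for i in range(12)])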
@@ -49,7 +59,6 @@ def sliding_get_control(self, x_noisy: Tensor, t, cond, batched_number):
         if self.manual_cast_dtype is not None:
             dtype = self.manual_cast_dtype

-        output_dtype = x_noisy.dtype
         # make cond_hint appropriate dimensions
         # TODO: change this to not require cond_hint upscaling every step when self.sub_idxs are present
         if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
@@ -64,9 +73,9 @@ def sliding_get_control(self, x_noisy: Tensor, t, cond, batched_number):
                 actual_cond_hint_orig = self.cond_hint_original
                 if self.cond_hint_original.size(0) < self.full_latent_length:
                     actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
-                self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, 'nearest-exact', "center")
+                self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
             else:
-                self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, 'nearest-exact', "center")
+                self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
             if self.vae is not None:
                 loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
                 self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
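
Both upscale calls stop hardcoding 'nearest-exact' and defer to self.upscale_algorithm, i.e. whatever resize method was configured for this controlnet. A rough PyTorch-only illustration of what that parameterization changes (comfy.utils.common_upscale itself is not used here; shapes are made up):

# Sketch: resizing a stand-in cond hint with two different algorithms.
import torch
import torch.nn.functional as F

hint = torch.rand(1, 3, 64, 64)    # stand-in for cond_hint_original (NCHW)
target_h, target_w = 128, 128      # latent H/W times compression_ratio
for algo in ("nearest-exact", "bilinear"):
    resized = F.interpolate(hint, size=(target_h, target_w), mode=algo)
    print(algo, tuple(resized.shape))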
@@ -81,25 +90,44 @@ def sliding_get_control(self, x_noisy: Tensor, t, cond, batched_number):
         self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype)

         context = cond.get('crossattn_controlnet', cond['c_crossattn'])
-        y = cond.get('y', None)
-        if y is not None:
-            y = y.to(dtype)
+        extra = self.extra_args.copy()
+        for c in self.extra_conds:
+            temp = cond.get(c, None)
+            if temp is not None:
+                extra[c] = temp.to(dtype)

         timestep = self.model_sampling_current.timestep(t)
         x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
-
-        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
-        return self.control_merge(control, control_prev, output_dtype)
+        self.x_noisy_shape = x_noisy.shape
+        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
+        return self.control_merge(control, control_prev, output_dtype=None)
+
+    def pre_run_advanced(self, *args, **kwargs):
+        self.is_flux = "Flux" in str(type(self.control_model).__name__)
+        return super().pre_run_advanced(*args, **kwargs)
+
+    def apply_advanced_strengths_and_masks(self, x: Tensor, batched_number: int, flux_shape=None):
+        if self.is_flux:
+            flux_shape = self.x_noisy_shape
+        return super().apply_advanced_strengths_and_masks(x, batched_number, flux_shape)

     def copy(self):
         c = ControlNetAdvanced(self.control_model, self.timestep_keyframes, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
         c.control_model = self.control_model
         c.control_model_wrapped = self.control_model_wrapped
         self.copy_to(c)
         self.copy_to_advanced(c)
         return c

+    def cleanup_advanced(self):
+        self.x_noisy_shape = None
+        return super().cleanup_advanced()
+
     @staticmethod
     def from_vanilla(v: ControlNet, timestep_keyframe: TimestepKeyframeGroup=None) -> 'ControlNetAdvanced':
         to_return = ControlNetAdvanced(control_model=v.control_model, timestep_keyframes=timestep_keyframe,
-                                       global_average_pooling=v.global_average_pooling, compression_ratio=v.compression_ratio, latent_format=v.latent_format, device=v.device, load_device=v.load_device, manual_cast_dtype=v.manual_cast_dtype)
+                                       global_average_pooling=v.global_average_pooling, compression_ratio=v.compression_ratio, latent_format=v.latent_format, device=v.device, load_device=v.load_device,
+                                       manual_cast_dtype=v.manual_cast_dtype)
         v.copy_to(to_return)
         return to_return
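
Two generalizations land in this hunk: the hardcoded SDXL-style y conditioning becomes a loop over self.extra_conds, so a Flux controlnet can request whatever extra tensors it needs without further special-casing, and the noisy latent's shape is stashed in self.x_noisy_shape so apply_advanced_strengths_and_masks can hand a flux_shape to the base class whenever pre_run_advanced detected a Flux control model. A standalone sketch of the extra-conds pattern (cond keys, tensor names, and sizes are invented for illustration):

# Sketch of the extra_conds gathering pattern from the diff.
import torch

def gather_extra_conds(cond: dict, extra_conds: list, dtype: torch.dtype) -> dict:
    extra = {}
    for c in extra_conds:
        temp = cond.get(c, None)
        if temp is not None:
            extra[c] = temp.to(dtype)  # cast only the entries that exist
    return extra

cond = {"y": torch.rand(1, 2816), "guidance": torch.tensor([3.5])}
extra = gather_extra_conds(cond, ["y", "guidance"], torch.float16)
# the real code then calls: control_model(x=..., hint=..., ..., **extra)
print({k: v.dtype for k, v in extra.items()})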

@@ -121,18 +149,28 @@ def control_merge_inject(self, control: dict[str, list[Tensor]], control_prev, output_dtype):
         return AdvancedControlBase.control_merge_inject(self, control, control_prev, output_dtype)

     def get_universal_weights(self) -> ControlWeights:
-        raw_weights = [(self.weights.base_multiplier ** float(7 - i)) for i in range(8)]
-        raw_weights = [raw_weights[-8], raw_weights[-3], raw_weights[-2], raw_weights[-1]]
-        raw_weights = get_properly_arranged_t2i_weights(raw_weights)
-        raw_weights.reverse() # need to reverse to match recent ComfyUI changes
-        return self.weights.copy_with_new_weights(raw_weights)
+        def t2i_weights_func(idx: int, control: dict[str, list[Tensor]], key: str):
+            if key == "middle":
+                return 1.0
+            c_len = 8 #len(control[key])
+            raw_weights = [(self.weights.base_multiplier ** float((c_len-1) - i)) for i in range(c_len)]
+            raw_weights = [raw_weights[-c_len], raw_weights[-3], raw_weights[-2], raw_weights[-1]]
+            raw_weights = get_properly_arranged_t2i_weights(raw_weights)
+            if key == "input":
+                raw_weights.reverse()
+            return raw_weights[idx]
+        return self.weights.copy_with_new_weights(new_weight_func=t2i_weights_func)

     def get_calc_pow(self, idx: int, control: dict[str, list[Tensor]], key: str) -> int:
+        if key == "middle":
+            return 0
         # match how T2IAdapterAdvanced deals with universal weights
-        indeces = [7 - i for i in range(8)]
-        indeces = [indeces[-8], indeces[-3], indeces[-2], indeces[-1]]
+        c_len = 8 #len(control[key])
+        indeces = [(c_len-1) - i for i in range(c_len)]
+        indeces = [indeces[-c_len], indeces[-3], indeces[-2], indeces[-1]]
         indeces = get_properly_arranged_t2i_weights(indeces)
-        indeces.reverse() # need to reverse to match recent ComfyUI changes
+        if key == "input":
+            indeces.reverse() # need to reverse to match recent ComfyUI changes
         return indeces[idx]

     def get_control_advanced(self, x_noisy, t, cond, batched_number):
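
T2IAdapterAdvanced gets the same callback treatment, with two quirks: c_len stays pinned at 8 (the commented-out len(control[key]) suggests the dynamic form is intended later), and only four of the eight falloff weights survive before get_properly_arranged_t2i_weights fans them out across the UNet blocks. A toy sketch follows; the fan-out step is an assumption (repeat each level three times), not the actual helper from this repo:

# Toy sketch of the T2I falloff; arrange_assumed() is a hypothetical stand-in
# for get_properly_arranged_t2i_weights(), which this diff does not show.
def arrange_assumed(levels):
    return [w for w in levels for _ in range(3)]  # 4 levels -> 12 entries

def t2i_weight(idx: int, key: str, base_multiplier: float = 0.825) -> float:
    if key == "middle":
        return 1.0
    c_len = 8  # fixed, mirroring the diff's `c_len = 8 #len(control[key])`
    raw = [base_multiplier ** float((c_len - 1) - i) for i in range(c_len)]
    raw = [raw[-c_len], raw[-3], raw[-2], raw[-1]]  # keep m**7, m**2, m**1, m**0
    raw = arrange_assumed(raw)
    if key == "input":
        raw.reverse()
    return raw[idx]

print([round(t2i_weight(i, "input"), 3) for i in range(12)])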
@@ -381,11 +419,11 @@ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
         control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
         return self.control_merge(control, control_prev, output_dtype)

-    def apply_advanced_strengths_and_masks(self, x: Tensor, batched_number: int):
+    def apply_advanced_strengths_and_masks(self, x: Tensor, batched_number: int, *args, **kwargs):
         # apply mults to indexes with and without a direct condhint
         x[self.local_sparse_idxs] *= self.sparse_settings.sparse_hint_mult * self.weights.extras.get(SparseConst.HINT_MULT, 1.0)
         x[self.local_sparse_idxs_inverse] *= self.sparse_settings.sparse_nonhint_mult * self.weights.extras.get(SparseConst.NONHINT_MULT, 1.0)
-        return super().apply_advanced_strengths_and_masks(x, batched_number)
+        return super().apply_advanced_strengths_and_masks(x, batched_number, *args, **kwargs)

     def pre_run_advanced(self, model, percent_to_timestep_function):
         super().pre_run_advanced(model, percent_to_timestep_function)
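
The SparseCtrl override only widens its signature to *args/**kwargs so the new flux_shape parameter (and anything added after it) passes straight through to the base implementation; the sparse hint/non-hint multipliers themselves are untouched. The pattern in isolation, as a minimal sketch:

# Sketch of the pass-through override: the subclass does its own work, then
# forwards unknown arguments so the base signature can grow safely.
class Base:
    def apply(self, x, batched_number, flux_shape=None):
        return ("base", x, batched_number, flux_shape)

class Sparse(Base):
    def apply(self, x, batched_number, *args, **kwargs):
        x = x * 0.5  # stand-in for the sparse multiplier logic
        return super().apply(x, batched_number, *args, **kwargs)

print(Sparse().apply(2.0, 1, flux_shape=(1, 16, 64, 64)))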
13 changes: 10 additions & 3 deletions adv_control/control_plusplus.py
@@ -237,9 +237,16 @@ def __init__(self, control_model: ControlNetPlusPlus, timestep_keyframes: Timest
         self.single_control_type: str = None

     def get_universal_weights(self) -> ControlWeights:
-        # TODO: match actual layer count of model
-        raw_weights = [(self.weights.base_multiplier ** float(12 - i)) for i in range(13)]
-        return self.weights.copy_with_new_weights(raw_weights)
+        def cn_weights_func(idx: int, control: dict[str, list[Tensor]], key: str):
+            if key == "middle":
+                return 1.0
+            c_len = len(control[key])
+            raw_weights = [(self.weights.base_multiplier ** float((c_len) - i)) for i in range(c_len+1)]
+            raw_weights = raw_weights[:-1]
+            if key == "input":
+                raw_weights.reverse()
+            return raw_weights[idx]
+        return self.weights.copy_with_new_weights(new_weight_func=cn_weights_func)

     def verify_control_type(self, model_name: str, pp_group: PlusPlusInputGroup=None):
         if pp_group is not None:
[Diff not loaded for the remaining 5 changed files.]
