Skip to content

Commit 68e5cd4

Browse files
authored
Merge PR #145 from Kosinkadink/develop - ContextRef support for ADE
Added ContextRef support for ADE + ReferenceCN Refactor
2 parents b2c9b7a + 890e8cb commit 68e5cd4

7 files changed

+635
-185
lines changed

adv_control/control.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,10 @@
1616
from .control_svd import svd_unet_config_from_diffusers_unet, SVDControlNet, svd_unet_to_diffusers
1717
from .utils import (AdvancedControlBase, TimestepKeyframeGroup, LatentKeyframeGroup, AbstractPreprocWrapper, ControlWeightType, ControlWeights, WeightTypeException,
1818
manual_cast_clean_groupnorm, disable_weight_init_clean_groupnorm, prepare_mask_batch, get_properly_arranged_t2i_weights, load_torch_file_with_dict_factory,
19-
broadcast_image_to_extend, extend_to_batch_size)
19+
broadcast_image_to_extend, extend_to_batch_size, ORIG_PREVIOUS_CONTROLNET, CONTROL_INIT_BY_ACN)
2020
from .logger import logger
2121

2222

23-
ORIG_PREVIOUS_CONTROLNET = "_orig_previous_controlnet"
24-
25-
2623
class ControlNetAdvanced(ControlNet, AdvancedControlBase):
2724
def __init__(self, control_model, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
2825
super().__init__(control_model=control_model, global_average_pooling=global_average_pooling, compression_ratio=compression_ratio, latent_format=latent_format, device=device, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
@@ -574,7 +571,13 @@ def restore_all_controlnet_conns(conds: list[list[dict[str]]]):
574571
if main_cond is not None:
575572
for cond in main_cond:
576573
if "control" in cond[1]:
577-
_restore_all_controlnet_conns(cond[1]["control"])
574+
# if ACN is the one to have initialized it, delete it
575+
# TODO: maybe check if someone else did a similar hack, and carefully pluck out our stuff?
576+
if CONTROL_INIT_BY_ACN in cond[1]:
577+
cond[1].pop("control")
578+
cond[1].pop(CONTROL_INIT_BY_ACN)
579+
else:
580+
_restore_all_controlnet_conns(cond[1]["control"])
578581

579582

580583
def _restore_all_controlnet_conns(input_object: ControlBase):

adv_control/control_reference.py

+563-151
Large diffs are not rendered by default.

adv_control/nodes.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -231,10 +231,10 @@ def apply_controlnet(self, positive, negative, control_net, image, strength, sta
231231
NODE_DISPLAY_NAME_MAPPINGS = {
232232
# Keyframes
233233
"TimestepKeyframe": "Timestep Keyframe 🛂🅐🅒🅝",
234-
"ACN_TimestepKeyframeInterpolation": "Timestep Keyframe Interpolation 🛂🅐🅒🅝",
234+
"ACN_TimestepKeyframeInterpolation": "Timestep Keyframe Interp. 🛂🅐🅒🅝",
235235
"ACN_TimestepKeyframeFromStrengthList": "Timestep Keyframe From List 🛂🅐🅒🅝",
236236
"LatentKeyframe": "Latent Keyframe 🛂🅐🅒🅝",
237-
"LatentKeyframeTiming": "Latent Keyframe Interpolation 🛂🅐🅒🅝",
237+
"LatentKeyframeTiming": "Latent Keyframe Interp. 🛂🅐🅒🅝",
238238
"LatentKeyframeBatchedGroup": "Latent Keyframe From List 🛂🅐🅒🅝",
239239
"LatentKeyframeGroup": "Latent Keyframe Group 🛂🅐🅒🅝",
240240
# Conditioning

adv_control/nodes_keyframes.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def INPUT_TYPES(s):
8282
"inherit_missing": ("BOOLEAN", {"default": True},),
8383
"mask_optional": ("MASK", ),
8484
"print_keyframes": ("BOOLEAN", {"default": False}),
85-
"autosize": ("ACNAUTOSIZE", {"padding": 70}),
85+
"autosize": ("ACNAUTOSIZE", {"padding": 50}),
8686
}
8787
}
8888

@@ -355,7 +355,7 @@ def INPUT_TYPES(s):
355355
"optional": {
356356
"prev_latent_kf": ("LATENT_KEYFRAME", ),
357357
"print_keyframes": ("BOOLEAN", {"default": False}),
358-
"autosize": ("ACNAUTOSIZE", {"padding": 90}),
358+
"autosize": ("ACNAUTOSIZE", {"padding": 50}),
359359
}
360360
}
361361

adv_control/sampling.py

+43-18
Original file line numberDiff line numberDiff line change
@@ -11,30 +11,44 @@
1111
RefBasicTransformerBlock, RefTimestepEmbedSequential,
1212
InjectionBasicTransformerBlockHolder, InjectionTimestepEmbedSequentialHolder,
1313
_forward_inject_BasicTransformerBlock, factory_forward_inject_UNetModel,
14-
REF_CONTROL_LIST_ALL)
14+
handle_context_ref_setup,
15+
REF_CONTROL_LIST_ALL, CONTEXTREF_CLEAN_FUNC)
1516
from .control_lllite import (ControlLLLiteAdvanced)
1617
from .utils import torch_dfs
1718

1819

1920
def support_sliding_context_windows(model, positive, negative) -> tuple[bool, dict, dict]:
20-
if not hasattr(model, "motion_injection_params"):
21-
return False, positive, negative
22-
motion_injection_params = getattr(model, "motion_injection_params")
23-
context_options = getattr(motion_injection_params, "context_options")
24-
if context_options.context_length is None:
25-
return False, positive, negative
2621
# convert to advanced, with report if anything was actually modified
2722
modified, new_conds = convert_all_to_advanced([positive, negative])
2823
positive, negative = new_conds
2924
return modified, positive, negative
3025

3126

27+
def has_sliding_context_windows(model):
28+
motion_injection_params = getattr(model, "motion_injection_params", None)
29+
if motion_injection_params is None:
30+
return False
31+
context_options = getattr(motion_injection_params, "context_options")
32+
return context_options.context_length is not None
33+
34+
35+
def get_contextref_obj(model):
36+
motion_injection_params = getattr(model, "motion_injection_params", None)
37+
if motion_injection_params is None:
38+
return None
39+
context_options = getattr(motion_injection_params, "context_options")
40+
extras = getattr(context_options, "extras", None)
41+
if extras is None:
42+
return None
43+
return getattr(extras, "context_ref", None)
44+
45+
3246
def acn_sample_factory(orig_comfy_sample: Callable, is_custom=False) -> Callable:
3347
def get_refcn(control: ControlBase, order: int=-1):
3448
ref_set: set[ReferenceAdvanced] = set()
3549
if control is None:
3650
return ref_set
37-
if type(control) == ReferenceAdvanced:
51+
if type(control) == ReferenceAdvanced and not control.is_context_ref:
3852
control.order = order
3953
order -= 1
4054
ref_set.add(control)
@@ -59,13 +73,23 @@ def acn_sample(model: ModelPatcher, *args, **kwargs):
5973
# check if positive or negative conds contain ref cn
6074
positive = args[-3]
6175
negative = args[-2]
62-
# if context options present, convert all CNs to Advanced if needed
63-
controlnets_modified, positive, negative = support_sliding_context_windows(model, positive, negative)
64-
if controlnets_modified:
65-
args = list(args)
66-
args[-3] = positive
67-
args[-2] = negative
68-
args = tuple(args)
76+
# if context options present, perform some special actions that may be required
77+
context_refs = []
78+
if has_sliding_context_windows(model):
79+
model.model_options = model.model_options.copy()
80+
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
81+
# convert all CNs to Advanced if needed
82+
controlnets_modified, positive, negative = support_sliding_context_windows(model, positive, negative)
83+
if controlnets_modified:
84+
args = list(args)
85+
args[-3] = positive
86+
args[-2] = negative
87+
args = tuple(args)
88+
# enable ContextRef, if requested
89+
existing_contextref_obj = get_contextref_obj(model)
90+
if existing_contextref_obj is not None:
91+
context_refs = handle_context_ref_setup(existing_contextref_obj, model.model_options["transformer_options"], positive, negative)
92+
controlnets_modified = True
6993
# look for Advanced ControlNets that will require intervention to work
7094
ref_set = set()
7195
lllite_dict: dict[ControlLLLiteAdvanced, None] = {} # dicts preserve insertion order since py3.7
@@ -88,7 +112,7 @@ def acn_sample(model: ModelPatcher, *args, **kwargs):
88112
for lll in lllite_list:
89113
lll.live_model_patches(model.model_options)
90114
# if no ref cn found, do original function immediately
91-
if len(ref_set) == 0:
115+
if len(ref_set) == 0 and len(context_refs) == 0:
92116
return orig_comfy_sample(model, *args, **kwargs)
93117
# otherwise, injection time
94118
try:
@@ -158,6 +182,7 @@ def acn_sample(model: ModelPatcher, *args, **kwargs):
158182
new_model_options["transformer_options"] = model.model_options["transformer_options"].copy()
159183
ref_list: list[ReferenceAdvanced] = list(ref_set)
160184
new_model_options["transformer_options"][REF_CONTROL_LIST_ALL] = sorted(ref_list, key=lambda x: x.order)
185+
new_model_options["transformer_options"][CONTEXTREF_CLEAN_FUNC] = reference_injections.clean_contextref_module_mem
161186
model.model_options = new_model_options
162187
# continue with original function
163188
return orig_comfy_sample(model, *args, **kwargs)
@@ -167,14 +192,14 @@ def acn_sample(model: ModelPatcher, *args, **kwargs):
167192
attn_modules: list[RefBasicTransformerBlock] = reference_injections.attn_modules
168193
for module in attn_modules:
169194
module.injection_holder.restore(module)
170-
module.injection_holder.clean()
195+
module.injection_holder.clean_all()
171196
del module.injection_holder
172197
del attn_modules
173198
# restore gn modules
174199
gn_modules: list[RefTimestepEmbedSequential] = reference_injections.gn_modules
175200
for module in gn_modules:
176201
module.injection_holder.restore(module)
177-
module.injection_holder.clean()
202+
module.injection_holder.clean_all()
178203
del module.injection_holder
179204
del gn_modules
180205
# restore diffusion_model forward function

adv_control/utils.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
BIGMIN = -(2**53-1)
2222
BIGMAX = (2**53-1)
2323

24+
ORIG_PREVIOUS_CONTROLNET = "_orig_previous_controlnet"
25+
CONTROL_INIT_BY_ACN = "_control_init_by_ACN"
26+
27+
2428
def load_torch_file_with_dict_factory(controlnet_data: dict[str, Tensor], orig_load_torch_file: Callable):
2529
def load_torch_file_with_dict(*args, **kwargs):
2630
# immediately restore load_torch_file to original version
@@ -35,7 +39,7 @@ def wrapper_len_factory(orig_len: Callable) -> Callable:
3539
def wrapper_len(*args, **kwargs):
3640
cond_or_uncond = args[0]
3741
real_length = orig_len(*args, **kwargs)
38-
if real_length > 0 and type(cond_or_uncond) == list and (cond_or_uncond[0] in [0, 1]):
42+
if real_length > 0 and type(cond_or_uncond) == list and isinstance(cond_or_uncond[0], int) and (cond_or_uncond[0] in [0, 1]):
3943
try:
4044
to_return = IntWithCondOrUncond(real_length)
4145
setattr(to_return, "cond_or_uncond", cond_or_uncond)
@@ -569,7 +573,8 @@ def __init__(self, base: ControlBase, timestep_keyframes: TimestepKeyframeGroup,
569573
self.full_latent_length = 0
570574
self.context_length = 0
571575
# timesteps
572-
self.t: Tensor = None
576+
self.t: float = None
577+
self.prev_t: float = None
573578
self.batched_number: Union[int, IntWithCondOrUncond] = None
574579
self.batch_size: int = 0
575580
# weights + override
@@ -627,10 +632,11 @@ def set_timestep_keyframes(self, timestep_keyframes: TimestepKeyframeGroup):
627632
self.weights = None
628633
self.latent_keyframes = None
629634

630-
def prepare_current_timestep(self, t: Tensor, batched_number: int):
635+
def prepare_current_timestep(self, t: Tensor, batched_number: int=1):
631636
self.t = float(t[0])
632-
self.batched_number = batched_number
633-
self.batch_size = len(t)
637+
# check if t has changed (otherwise do nothing, as step already accounted for)
638+
if self.t == self.prev_t:
639+
return
634640
# get current step percent
635641
curr_t: float = self.t
636642
prev_index = self._current_timestep_index
@@ -666,7 +672,8 @@ def prepare_current_timestep(self, t: Tensor, batched_number: int):
666672
# if eval_tk is outside of percent range, stop looking further
667673
else:
668674
break
669-
675+
# update prev_t
676+
self.prev_t = self.t
670677
# update steps current keyframe is used
671678
self._current_used_steps += 1
672679
# if index changed, apply overrides
@@ -740,6 +747,8 @@ def should_run(self):
740747
return True
741748

742749
def get_control_inject(self, x_noisy, t, cond, batched_number):
750+
self.batched_number = batched_number
751+
self.batch_size = len(t)
743752
# prepare timestep and everything related
744753
self.prepare_current_timestep(t=t, batched_number=batched_number)
745754
# if should not perform any actions for the controlnet, exit without doing any work
@@ -932,6 +941,7 @@ def cleanup_advanced(self):
932941
self.full_latent_length = 0
933942
self.context_length = 0
934943
self.t = None
944+
self.prev_t = None
935945
self.batched_number = None
936946
self.batch_size = 0
937947
self.weights = None

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "comfyui-advanced-controlnet"
33
description = "Nodes for scheduling ControlNet strength across timesteps and batched latents, as well as applying custom weights and attention masks."
4-
version = "1.1.5"
4+
version = "1.2.0"
55
license = { file = "LICENSE" }
66
dependencies = []
77

0 commit comments

Comments
 (0)