11
11
RefBasicTransformerBlock , RefTimestepEmbedSequential ,
12
12
InjectionBasicTransformerBlockHolder , InjectionTimestepEmbedSequentialHolder ,
13
13
_forward_inject_BasicTransformerBlock , factory_forward_inject_UNetModel ,
14
- REF_CONTROL_LIST_ALL )
14
+ handle_context_ref_setup ,
15
+ REF_CONTROL_LIST_ALL , CONTEXTREF_CLEAN_FUNC )
15
16
from .control_lllite import (ControlLLLiteAdvanced )
16
17
from .utils import torch_dfs
17
18
18
19
19
20
def support_sliding_context_windows (model , positive , negative ) -> tuple [bool , dict , dict ]:
20
- if not hasattr (model , "motion_injection_params" ):
21
- return False , positive , negative
22
- motion_injection_params = getattr (model , "motion_injection_params" )
23
- context_options = getattr (motion_injection_params , "context_options" )
24
- if context_options .context_length is None :
25
- return False , positive , negative
26
21
# convert to advanced, with report if anything was actually modified
27
22
modified , new_conds = convert_all_to_advanced ([positive , negative ])
28
23
positive , negative = new_conds
29
24
return modified , positive , negative
30
25
31
26
27
+ def has_sliding_context_windows (model ):
28
+ motion_injection_params = getattr (model , "motion_injection_params" , None )
29
+ if motion_injection_params is None :
30
+ return False
31
+ context_options = getattr (motion_injection_params , "context_options" )
32
+ return context_options .context_length is not None
33
+
34
+
35
+ def get_contextref_obj (model ):
36
+ motion_injection_params = getattr (model , "motion_injection_params" , None )
37
+ if motion_injection_params is None :
38
+ return None
39
+ context_options = getattr (motion_injection_params , "context_options" )
40
+ extras = getattr (context_options , "extras" , None )
41
+ if extras is None :
42
+ return None
43
+ return getattr (extras , "context_ref" , None )
44
+
45
+
32
46
def acn_sample_factory (orig_comfy_sample : Callable , is_custom = False ) -> Callable :
33
47
def get_refcn (control : ControlBase , order : int = - 1 ):
34
48
ref_set : set [ReferenceAdvanced ] = set ()
35
49
if control is None :
36
50
return ref_set
37
- if type (control ) == ReferenceAdvanced :
51
+ if type (control ) == ReferenceAdvanced and not control . is_context_ref :
38
52
control .order = order
39
53
order -= 1
40
54
ref_set .add (control )
@@ -59,13 +73,23 @@ def acn_sample(model: ModelPatcher, *args, **kwargs):
59
73
# check if positive or negative conds contain ref cn
60
74
positive = args [- 3 ]
61
75
negative = args [- 2 ]
62
- # if context options present, convert all CNs to Advanced if needed
63
- controlnets_modified , positive , negative = support_sliding_context_windows (model , positive , negative )
64
- if controlnets_modified :
65
- args = list (args )
66
- args [- 3 ] = positive
67
- args [- 2 ] = negative
68
- args = tuple (args )
76
+ # if context options present, perform some special actions that may be required
77
+ context_refs = []
78
+ if has_sliding_context_windows (model ):
79
+ model .model_options = model .model_options .copy ()
80
+ model .model_options ["transformer_options" ] = model .model_options ["transformer_options" ].copy ()
81
+ # convert all CNs to Advanced if needed
82
+ controlnets_modified , positive , negative = support_sliding_context_windows (model , positive , negative )
83
+ if controlnets_modified :
84
+ args = list (args )
85
+ args [- 3 ] = positive
86
+ args [- 2 ] = negative
87
+ args = tuple (args )
88
+ # enable ContextRef, if requested
89
+ existing_contextref_obj = get_contextref_obj (model )
90
+ if existing_contextref_obj is not None :
91
+ context_refs = handle_context_ref_setup (existing_contextref_obj , model .model_options ["transformer_options" ], positive , negative )
92
+ controlnets_modified = True
69
93
# look for Advanced ControlNets that will require intervention to work
70
94
ref_set = set ()
71
95
lllite_dict : dict [ControlLLLiteAdvanced , None ] = {} # dicts preserve insertion order since py3.7
@@ -88,7 +112,7 @@ def acn_sample(model: ModelPatcher, *args, **kwargs):
88
112
for lll in lllite_list :
89
113
lll .live_model_patches (model .model_options )
90
114
# if no ref cn found, do original function immediately
91
- if len (ref_set ) == 0 :
115
+ if len (ref_set ) == 0 and len ( context_refs ) == 0 :
92
116
return orig_comfy_sample (model , * args , ** kwargs )
93
117
# otherwise, injection time
94
118
try :
@@ -158,6 +182,7 @@ def acn_sample(model: ModelPatcher, *args, **kwargs):
158
182
new_model_options ["transformer_options" ] = model .model_options ["transformer_options" ].copy ()
159
183
ref_list : list [ReferenceAdvanced ] = list (ref_set )
160
184
new_model_options ["transformer_options" ][REF_CONTROL_LIST_ALL ] = sorted (ref_list , key = lambda x : x .order )
185
+ new_model_options ["transformer_options" ][CONTEXTREF_CLEAN_FUNC ] = reference_injections .clean_contextref_module_mem
161
186
model .model_options = new_model_options
162
187
# continue with original function
163
188
return orig_comfy_sample (model , * args , ** kwargs )
@@ -167,14 +192,14 @@ def acn_sample(model: ModelPatcher, *args, **kwargs):
167
192
attn_modules : list [RefBasicTransformerBlock ] = reference_injections .attn_modules
168
193
for module in attn_modules :
169
194
module .injection_holder .restore (module )
170
- module .injection_holder .clean ()
195
+ module .injection_holder .clean_all ()
171
196
del module .injection_holder
172
197
del attn_modules
173
198
# restore gn modules
174
199
gn_modules : list [RefTimestepEmbedSequential ] = reference_injections .gn_modules
175
200
for module in gn_modules :
176
201
module .injection_holder .restore (module )
177
- module .injection_holder .clean ()
202
+ module .injection_holder .clean_all ()
178
203
del module .injection_holder
179
204
del gn_modules
180
205
# restore diffusion_model forward function
0 commit comments