@@ -15,7 +15,7 @@ import {
15
15
} from 'features/nodes/util/graph/constants' ;
16
16
import { isVectorMaskLayer } from 'features/regionalPrompts/store/regionalPromptsSlice' ;
17
17
import { getRegionalPromptLayerBlobs } from 'features/regionalPrompts/util/getLayerBlobs' ;
18
- import { size , sumBy } from 'lodash-es' ;
18
+ import { size } from 'lodash-es' ;
19
19
import { imagesApi } from 'services/api/endpoints/images' ;
20
20
import type { CollectInvocation , Edge , IPAdapterInvocation , NonNullableGraph , S } from 'services/api/types' ;
21
21
import { assert } from 'tsafe' ;
@@ -39,6 +39,16 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
39
39
return hasTextPrompt || hasIPAdapter ;
40
40
} ) ;
41
41
42
+ const regionalIPAdapters = selectAllIPAdapters ( state . controlAdapters ) . filter (
43
+ ( { id, model, controlImage, isEnabled } ) => {
44
+ const hasModel = Boolean ( model ) ;
45
+ const doesBaseMatch = model ?. base === state . generation . model ?. base ;
46
+ const hasControlImage = controlImage ;
47
+ const isRegional = layers . some ( ( l ) => l . ipAdapterIds . includes ( id ) ) ;
48
+ return isEnabled && hasModel && doesBaseMatch && hasControlImage && isRegional ;
49
+ }
50
+ ) ;
51
+
42
52
const layerIds = layers . map ( ( l ) => l . id ) ;
43
53
const blobs = await getRegionalPromptLayerBlobs ( layerIds ) ;
44
54
assert ( size ( blobs ) === size ( layerIds ) , 'Mismatch between layer IDs and blobs' ) ;
@@ -105,7 +115,7 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
105
115
} ,
106
116
} ) ;
107
117
108
- if ( ! graph . nodes [ IP_ADAPTER_COLLECT ] && sumBy ( layers , ( l ) => l . ipAdapterIds . length ) > 0 ) {
118
+ if ( ! graph . nodes [ IP_ADAPTER_COLLECT ] && regionalIPAdapters . length > 0 ) {
109
119
const ipAdapterCollectNode : CollectInvocation = {
110
120
id : IP_ADAPTER_COLLECT ,
111
121
type : 'collect' ,
@@ -284,8 +294,16 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
284
294
}
285
295
286
296
for ( const ipAdapterId of layer . ipAdapterIds ) {
287
- const ipAdapter = selectAllIPAdapters ( state . controlAdapters ) . find ( ( ca ) => ca . id === ipAdapterId ) ;
288
- console . log ( ipAdapter ) ;
297
+ const ipAdapter = selectAllIPAdapters ( state . controlAdapters )
298
+ . filter ( ( { id, model, controlImage, isEnabled } ) => {
299
+ const hasModel = Boolean ( model ) ;
300
+ const doesBaseMatch = model ?. base === state . generation . model ?. base ;
301
+ const hasControlImage = controlImage ;
302
+ const isRegional = layers . some ( ( l ) => l . ipAdapterIds . includes ( id ) ) ;
303
+ return isEnabled && hasModel && doesBaseMatch && hasControlImage && isRegional ;
304
+ } )
305
+ . find ( ( ca ) => ca . id === ipAdapterId ) ;
306
+
289
307
if ( ! ipAdapter ?. model ) {
290
308
return ;
291
309
}
0 commit comments