@@ -4,13 +4,17 @@ import { deepClone } from 'common/util/deepClone';
44import { stagingAreaImageStaged } from 'features/controlLayers/store/canvasStagingAreaSlice' ;
55import { boardIdSelected , galleryViewChanged , imageSelected , offsetChanged } from 'features/gallery/store/gallerySlice' ;
66import { $nodeExecutionStates , upsertExecutionState } from 'features/nodes/hooks/useNodeExecutionState' ;
7+ import { isImageField , isImageFieldCollection } from 'features/nodes/types/common' ;
78import { zNodeStatus } from 'features/nodes/types/invocation' ;
89import { CANVAS_OUTPUT_PREFIX } from 'features/nodes/util/graph/graphBuilderUtils' ;
10+ import type { ApiTagDescription } from 'services/api' ;
911import { boardsApi } from 'services/api/endpoints/boards' ;
1012import { getImageDTOSafe , imagesApi } from 'services/api/endpoints/images' ;
1113import type { ImageDTO , S } from 'services/api/types' ;
1214import { getCategories , getListImagesUrl } from 'services/api/util' ;
1315import { $lastProgressEvent } from 'services/events/stores' ;
16+ import type { Param0 } from 'tsafe' ;
17+ import { objectEntries } from 'tsafe' ;
1418import type { JsonObject } from 'type-fest' ;
1519
1620const log = logger ( 'events' ) ;
@@ -22,58 +26,98 @@ const isCanvasOutputNode = (data: S['InvocationCompleteEvent']) => {
2226const nodeTypeDenylist = [ 'load_image' , 'image' ] ;
2327
2428export const buildOnInvocationComplete = ( getState : ( ) => RootState , dispatch : AppDispatch ) => {
25- const addImageToGallery = ( data : S [ 'InvocationCompleteEvent' ] , imageDTO : ImageDTO ) => {
29+ const addImagesToGallery = ( data : S [ 'InvocationCompleteEvent' ] , imageDTOs : ImageDTO [ ] ) => {
2630 if ( nodeTypeDenylist . includes ( data . invocation . type ) ) {
27- log . trace ( ' Skipping node type denylisted' ) ;
31+ log . trace ( ` Skipping denylisted node type ( ${ data . invocation . type } )` ) ;
2832 return ;
2933 }
3034
31- if ( imageDTO . is_intermediate ) {
35+ // For efficiency's sake, we want to minimize the number of dispatches and invalidations we do.
36+ // We'll keep track of each change we need to make and do them all at once.
37+ const boardTotalAdditions : Record < string , number > = { } ;
38+ const boardTagIdsToInvalidate : Set < string > = new Set ( ) ;
39+ const imageListTagIdsToInvalidate : Set < string > = new Set ( ) ;
40+
41+ for ( const imageDTO of imageDTOs ) {
42+ if ( imageDTO . is_intermediate ) {
43+ return ;
44+ }
45+
46+ const boardId = imageDTO . board_id ?? 'none' ;
47+ // update the total images for the board
48+ boardTotalAdditions [ boardId ] = ( boardTotalAdditions [ boardId ] || 0 ) + 1 ;
49+ // invalidate the board tag
50+ boardTagIdsToInvalidate . add ( boardId ) ;
51+ // invalidate the image list tag
52+ imageListTagIdsToInvalidate . add (
53+ getListImagesUrl ( {
54+ board_id : boardId ,
55+ categories : getCategories ( imageDTO ) ,
56+ } )
57+ ) ;
58+ }
59+
60+ // Update all the board image totals at once
61+ const entries : Param0 < typeof boardsApi . util . upsertQueryEntries > = [ ] ;
62+ for ( const [ boardId , amountToAdd ] of objectEntries ( boardTotalAdditions ) ) {
63+ // upsertQueryEntries doesn't provide a "recipe" function for the update - we must provide the new value
64+ // directly. So we need to select the board totals first.
65+ const total = boardsApi . endpoints . getBoardImagesTotal . select ( boardId ) ( getState ( ) ) . data ?. total ;
66+ if ( total === undefined ) {
67+ // No cache exists for this board, so we can't update it.
68+ continue ;
69+ }
70+ entries . push ( {
71+ endpointName : 'getBoardImagesTotal' ,
72+ arg : boardId ,
73+ value : { total : total + amountToAdd } ,
74+ } ) ;
75+ }
76+ dispatch ( boardsApi . util . upsertQueryEntries ( entries ) ) ;
77+
78+ // Invalidate all tags at once
79+ const boardTags : ApiTagDescription [ ] = Array . from ( boardTagIdsToInvalidate ) . map ( ( boardId ) => ( {
80+ type : 'Board' as const ,
81+ id : boardId ,
82+ } ) ) ;
83+ const imageListTags : ApiTagDescription [ ] = Array . from ( imageListTagIdsToInvalidate ) . map ( ( imageListId ) => ( {
84+ type : 'ImageList' as const ,
85+ id : imageListId ,
86+ } ) ) ;
87+ dispatch ( imagesApi . util . invalidateTags ( [ ...boardTags , ...imageListTags ] ) ) ;
88+
89+ // Finally, we may need to autoswitch to the new image. We'll only do it for the last image in the list.
90+
91+ const lastImageDTO = imageDTOs . at ( - 1 ) ;
92+
93+ if ( ! lastImageDTO ) {
3294 return ;
3395 }
3496
35- // update the total images for the board
36- dispatch (
37- boardsApi . util . updateQueryData ( 'getBoardImagesTotal' , imageDTO . board_id ?? 'none' , ( draft ) => {
38- draft . total += 1 ;
39- } )
40- ) ;
41-
42- dispatch (
43- imagesApi . util . invalidateTags ( [
44- { type : 'Board' , id : imageDTO . board_id ?? 'none' } ,
45- {
46- type : 'ImageList' ,
47- id : getListImagesUrl ( {
48- board_id : imageDTO . board_id ?? 'none' ,
49- categories : getCategories ( imageDTO ) ,
50- } ) ,
51- } ,
52- ] )
53- ) ;
97+ const { image_name, board_id } = lastImageDTO ;
5498
5599 const { shouldAutoSwitch, selectedBoardId, galleryView, offset } = getState ( ) . gallery ;
56100
57101 // If auto-switch is enabled, select the new image
58102 if ( shouldAutoSwitch ) {
59103 // If the image is from a different board, switch to that board - this will also select the image
60- if ( imageDTO . board_id && imageDTO . board_id !== selectedBoardId ) {
104+ if ( board_id && board_id !== selectedBoardId ) {
61105 dispatch (
62106 boardIdSelected ( {
63- boardId : imageDTO . board_id ,
64- selectedImageName : imageDTO . image_name ,
107+ boardId : board_id ,
108+ selectedImageName : image_name ,
65109 } )
66110 ) ;
67- } else if ( ! imageDTO . board_id && selectedBoardId !== 'none' ) {
111+ } else if ( ! board_id && selectedBoardId !== 'none' ) {
68112 dispatch (
69113 boardIdSelected ( {
70114 boardId : 'none' ,
71- selectedImageName : imageDTO . image_name ,
115+ selectedImageName : image_name ,
72116 } )
73117 ) ;
74118 } else {
75119 // Else just select the image, no need to switch boards
76- dispatch ( imageSelected ( imageDTO ) ) ;
120+ dispatch ( imageSelected ( lastImageDTO ) ) ;
77121
78122 if ( galleryView !== 'images' ) {
79123 // We also need to update the gallery view to images. This also updates the offset.
@@ -86,12 +130,25 @@ export const buildOnInvocationComplete = (getState: () => RootState, dispatch: A
86130 }
87131 } ;
88132
89- const getResultImageDTO = ( data : S [ 'InvocationCompleteEvent' ] ) => {
133+ const getResultImageDTOs = async ( data : S [ 'InvocationCompleteEvent' ] ) : Promise < ImageDTO [ ] > => {
90134 const { result } = data ;
91- if ( result . type === 'image_output' ) {
92- return getImageDTOSafe ( result . image . image_name ) ;
135+ const imageDTOs : ImageDTO [ ] = [ ] ;
136+ for ( const [ _name , value ] of objectEntries ( result ) ) {
137+ if ( isImageField ( value ) ) {
138+ const imageDTO = await getImageDTOSafe ( value . image_name ) ;
139+ if ( imageDTO ) {
140+ imageDTOs . push ( imageDTO ) ;
141+ }
142+ } else if ( isImageFieldCollection ( value ) ) {
143+ for ( const imageField of value ) {
144+ const imageDTO = await getImageDTOSafe ( imageField . image_name ) ;
145+ if ( imageDTO ) {
146+ imageDTOs . push ( imageDTO ) ;
147+ }
148+ }
149+ }
93150 }
94- return null ;
151+ return imageDTOs ;
95152 } ;
96153
97154 const handleOriginWorkflows = async ( data : S [ 'InvocationCompleteEvent' ] ) => {
@@ -107,16 +164,15 @@ export const buildOnInvocationComplete = (getState: () => RootState, dispatch: A
107164 upsertExecutionState ( nes . nodeId , nes ) ;
108165 }
109166
110- const imageDTO = await getResultImageDTO ( data ) ;
111-
112- if ( imageDTO && ! imageDTO . is_intermediate ) {
113- addImageToGallery ( data , imageDTO ) ;
114- }
167+ const imageDTOs = await getResultImageDTOs ( data ) ;
168+ addImagesToGallery ( data , imageDTOs ) ;
115169 } ;
116170
117171 const handleOriginCanvas = async ( data : S [ 'InvocationCompleteEvent' ] ) => {
118- const imageDTO = await getResultImageDTO ( data ) ;
172+ const imageDTOs = await getResultImageDTOs ( data ) ;
119173
174+ // We expect only a single image in the canvas output
175+ const imageDTO = imageDTOs [ 0 ] ;
120176 if ( ! imageDTO ) {
121177 return ;
122178 }
@@ -127,20 +183,17 @@ export const buildOnInvocationComplete = (getState: () => RootState, dispatch: A
127183 if ( data . result . type === 'image_output' ) {
128184 dispatch ( stagingAreaImageStaged ( { stagingAreaImage : { imageDTO, offsetX : 0 , offsetY : 0 } } ) ) ;
129185 }
130- addImageToGallery ( data , imageDTO ) ;
186+ addImagesToGallery ( data , [ imageDTO ] ) ;
131187 }
132188 } else if ( ! imageDTO . is_intermediate ) {
133189 // Desintaion is gallery
134- addImageToGallery ( data , imageDTO ) ;
190+ addImagesToGallery ( data , [ imageDTO ] ) ;
135191 }
136192 } ;
137193
138194 const handleOriginOther = async ( data : S [ 'InvocationCompleteEvent' ] ) => {
139- const imageDTO = await getResultImageDTO ( data ) ;
140-
141- if ( imageDTO && ! imageDTO . is_intermediate ) {
142- addImageToGallery ( data , imageDTO ) ;
143- }
195+ const imageDTOs = await getResultImageDTOs ( data ) ;
196+ addImagesToGallery ( data , imageDTOs ) ;
144197 } ;
145198
146199 return async ( data : S [ 'InvocationCompleteEvent' ] ) => {
0 commit comments