From d985dfe82156f8174a62b8ddfe0f57d3fff995ce Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 27 May 2025 16:42:45 +1000 Subject: [PATCH] refactor(ui): canvas flow events (wip) --- .../components/CanvasMainPanelContent.tsx | 69 +++---------------- .../controlLayers/store/paramsSlice.ts | 5 -- .../services/events/onInvocationComplete.tsx | 9 ++- .../src/services/events/setEventListeners.tsx | 6 +- .../web/src/services/events/stores.ts | 16 +++++ 5 files changed, 38 insertions(+), 67 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasMainPanelContent.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasMainPanelContent.tsx index e17df9b425..e93587ef18 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/CanvasMainPanelContent.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasMainPanelContent.tsx @@ -29,23 +29,17 @@ import { selectStagedImageIndex, selectStagedImages, stagingAreaImageSelected, - stagingAreaImageStaged, stagingAreaNextStagedImageSelected, stagingAreaPrevStagedImageSelected, } from 'features/controlLayers/store/canvasStagingAreaSlice'; -import { isImageField, type ProgressImage } from 'features/nodes/types/common'; -import { isCanvasOutputEvent } from 'features/nodes/util/graph/graphBuilderUtils'; -import type { Atom } from 'nanostores'; -import { atom } from 'nanostores'; -import { memo, useCallback, useEffect, useState } from 'react'; -import { flushSync } from 'react-dom'; +import type { ProgressImage } from 'features/nodes/types/common'; +import { memo, useCallback, useEffect } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; import { PiDotsThreeOutlineVerticalFill } from 'react-icons/pi'; -import { getImageDTOSafe } from 'services/api/endpoints/images'; import type { ImageDTO, S } from 'services/api/types'; -import { $socket } from 'services/events/stores'; +import { $lastCanvasProgressImage, $socket } from 'services/events/stores'; import type { Equals } from 'tsafe'; -import { assert, objectEntries } from 'tsafe'; +import { assert } from 'tsafe'; import { CanvasAlertsInvocationProgress } from './CanvasAlerts/CanvasAlertsInvocationProgress'; @@ -131,53 +125,14 @@ const SimpleActiveSession = memo(() => { const dispatch = useAppDispatch(); const isStaging = useAppSelector(selectIsStaging); const socket = useStore($socket); - const [$progressImage] = useState(() => atom(null)); useEffect(() => { if (!socket) { return; } - const onInvocationProgress = (event: S['InvocationProgressEvent']) => { - if (!event) { - return; - } - if (event.origin !== 'canvas') { - return; - } - if (!event.image) { - return; - } - $progressImage.set({ sessionId: event.session_id, image: event.image }); - }; - const onInvocationComplete = async (event: S['InvocationCompleteEvent']) => { - const progressImage = $progressImage.get(); - if (!progressImage) { - return; - } - if (progressImage.sessionId !== event.session_id) { - return; - } - if (!isCanvasOutputEvent(event)) { - return; - } - let imageDTO: ImageDTO | null = null; - for (const [_name, value] of objectEntries(event.result)) { - if (isImageField(value)) { - imageDTO = await getImageDTOSafe(value.image_name); - break; - } - } - if (!imageDTO) { - return; - } - flushSync(() => { - dispatch(stagingAreaImageStaged({ stagingAreaImage: { imageDTO, offsetX: 0, offsetY: 0 } })); - }); - $progressImage.set(null); - }; const onQueueItemStatusChanged = (event: S['QueueItemStatusChangedEvent']) => { - const progressImage = $progressImage.get(); + const progressImage = $lastCanvasProgressImage.get(); if (!progressImage) { return; } @@ -187,20 +142,16 @@ const SimpleActiveSession = memo(() => { if (event.status !== 'canceled' && event.status !== 'failed') { return; } - $progressImage.set(null); + $lastCanvasProgressImage.set(null); }; console.log('SUB session preview image listeners'); - socket.on('invocation_progress', onInvocationProgress); - socket.on('invocation_complete', onInvocationComplete); socket.on('queue_item_status_changed', onQueueItemStatusChanged); return () => { console.log('UNSUB session preview image listeners'); - socket.off('invocation_progress', onInvocationProgress); - socket.off('invocation_complete', onInvocationComplete); socket.off('queue_item_status_changed', onQueueItemStatusChanged); }; - }, [$progressImage, dispatch, socket]); + }, [dispatch, socket]); const onReset = useCallback(() => { dispatch(canvasReset()); @@ -226,15 +177,15 @@ const SimpleActiveSession = memo(() => { - + ); }); SimpleActiveSession.displayName = 'SimpleActiveSession'; -const SelectedImage = memo(({ $progressImage }: { $progressImage: Atom }) => { - const progressImage = useStore($progressImage); +const SelectedImage = memo(() => { + const progressImage = useStore($lastCanvasProgressImage); const selectedImage = useAppSelector(selectSelectedImage); if (progressImage) { diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts index 9fcf4a89f8..5fb816e709 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts @@ -1,7 +1,6 @@ import type { PayloadAction, Selector } from '@reduxjs/toolkit'; import { createSelector, createSlice } from '@reduxjs/toolkit'; import type { PersistConfig, RootState } from 'app/store/store'; -import { canvasSessionStarted } from 'features/controlLayers/store/canvasStagingAreaSlice'; import type { ParamsState, RgbaColor } from 'features/controlLayers/store/types'; import { getInitialParamsState } from 'features/controlLayers/store/types'; import { CLIP_SKIP_MAP } from 'features/parameters/types/constants'; @@ -25,7 +24,6 @@ import { clamp } from 'lodash-es'; import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models'; import { isNonRefinerMainModelConfig } from 'services/api/types'; - export const paramsSlice = createSlice({ name: 'params', initialState: getInitialParamsState(), @@ -188,9 +186,6 @@ export const paramsSlice = createSlice({ }, paramsReset: (state) => resetState(state), }, - extraReducers(builder) { - builder.addCase(canvasSessionStarted, (state) => resetState(state)); - }, }); const resetState = (state: ParamsState): ParamsState => { diff --git a/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx b/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx index 92ad0aa011..3c5f154950 100644 --- a/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx +++ b/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx @@ -7,12 +7,13 @@ import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks import { isImageField, isImageFieldCollection } from 'features/nodes/types/common'; import { zNodeStatus } from 'features/nodes/types/invocation'; import { isCanvasOutputEvent } from 'features/nodes/util/graph/graphBuilderUtils'; +import { flushSync } from 'react-dom'; import type { ApiTagDescription } from 'services/api'; import { boardsApi } from 'services/api/endpoints/boards'; import { getImageDTOSafe, imagesApi } from 'services/api/endpoints/images'; import type { ImageDTO, S } from 'services/api/types'; import { getCategories, getListImagesUrl } from 'services/api/util'; -import { $lastProgressEvent } from 'services/events/stores'; +import { $lastCanvasProgressImage, $lastProgressEvent } from 'services/events/stores'; import type { Param0 } from 'tsafe'; import { objectEntries } from 'tsafe'; import type { JsonObject } from 'type-fest'; @@ -176,7 +177,11 @@ export const buildOnInvocationComplete = (getState: () => RootState, dispatch: A return; } - dispatch(stagingAreaImageStaged({ stagingAreaImage: { imageDTO, offsetX: 0, offsetY: 0 } })); + flushSync(() => { + dispatch(stagingAreaImageStaged({ stagingAreaImage: { imageDTO, offsetX: 0, offsetY: 0 } })); + }); + + $lastCanvasProgressImage.set(null); }; const handleOriginOther = async (data: S['InvocationCompleteEvent']) => { diff --git a/invokeai/frontend/web/src/services/events/setEventListeners.tsx b/invokeai/frontend/web/src/services/events/setEventListeners.tsx index 15b5f545ed..f41e9b5da6 100644 --- a/invokeai/frontend/web/src/services/events/setEventListeners.tsx +++ b/invokeai/frontend/web/src/services/events/setEventListeners.tsx @@ -30,7 +30,7 @@ import type { ClientToServerEvents, ServerToClientEvents } from 'services/events import type { Socket } from 'socket.io-client'; import type { JsonObject } from 'type-fest'; -import { $lastProgressEvent } from './stores'; +import { $lastCanvasProgressEvent, $lastProgressEvent } from './stores'; const log = logger('events'); @@ -428,6 +428,10 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis // If the queue item is completed, failed, or cancelled, we want to clear the last progress event $lastProgressEvent.set(null); + if (data.origin === 'canvas') { + $lastCanvasProgressEvent.set(null); + } + // When a validation run is completed, we want to clear the validation run batch ID & set the workflow as published const validationRunData = $validationRunData.get(); if (!validationRunData || batch_status.batch_id !== validationRunData.batchId || status !== 'completed') { diff --git a/invokeai/frontend/web/src/services/events/stores.ts b/invokeai/frontend/web/src/services/events/stores.ts index d25a887c7d..b724f246c7 100644 --- a/invokeai/frontend/web/src/services/events/stores.ts +++ b/invokeai/frontend/web/src/services/events/stores.ts @@ -1,3 +1,4 @@ +import type { ProgressImage } from 'features/nodes/types/common'; import { round } from 'lodash-es'; import { atom, computed, map } from 'nanostores'; import type { S } from 'services/api/types'; @@ -15,18 +16,33 @@ $lastProgressEvent.subscribe((event) => { switch (event.destination) { case 'workflows': $lastWorkflowsProgressEvent.set(event); + if (event.image) { + $lastWorkflowsProgressImage.set({ sessionId: event.session_id, image: event.image }); + } break; case 'upscaling': $lastUpscalingProgressEvent.set(event); + if (event.image) { + $lastUpscalingProgressImage.set({ sessionId: event.session_id, image: event.image }); + } break; case 'canvas': $lastCanvasProgressEvent.set(event); + if (event.image) { + $lastCanvasProgressImage.set({ sessionId: event.session_id, image: event.image }); + } break; } }); + +type EphemeralProgressImage = { sessionId: string; image: ProgressImage }; + export const $lastCanvasProgressEvent = atom(null); +export const $lastCanvasProgressImage = atom(null); export const $lastWorkflowsProgressEvent = atom(null); +export const $lastWorkflowsProgressImage = atom(null); export const $lastUpscalingProgressEvent = atom(null); +export const $lastUpscalingProgressImage = atom(null); export const $progressImage = computed($lastProgressEvent, (val) => val?.image ?? null); export const $hasProgressImage = computed($lastProgressEvent, (val) => Boolean(val?.image));