From b152937f30dee033e7631529bb04ebbfd7e3bf1a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 17 Aug 2024 15:30:55 +1000 Subject: [PATCH] feat(ui): move socket event handling out of redux Download events and invocation status events (including progress images) are very frequent. There's no real need for these to pass through redux. Handling them outside redux is a significant performance win - far fewer store subscription calls, far fewer trips through middleware. All event handling is moved outside middleware. Cleanup of unused actions and listeners to follow. --- .../frontend/web/src/app/hooks/useSocketIO.ts | 10 +- .../addCommitStagingAreaImageListener.ts | 4 +- .../socketio/socketGeneratorProgress.ts | 4 +- .../listeners/socketio/socketQueueEvents.tsx | 10 +- .../src/common/hooks/useIsReadyToEnqueue.ts | 8 +- invokeai/frontend/web/src/common/types.ts | 5 + .../IPAdapter/IPAdapterImagePreview.tsx | 142 ++-- .../konva/CanvasProgressImage.ts | 2 +- .../controlLayers/konva/CanvasStagingArea.ts | 2 +- .../controlLayers/konva/CanvasStateApi.ts | 4 +- .../controlLayers/store/canvasV2Slice.ts | 2 - .../components/DeleteImageButton.tsx | 4 +- .../ImageViewer/CurrentImageButtons.tsx | 24 +- .../ImageViewer/CurrentImagePreview.tsx | 4 +- .../components/ImageViewer/ProgressImage.tsx | 12 +- .../nodes/CurrentImage/CurrentImageNode.tsx | 27 +- .../inputs/ImageFieldInputComponent.tsx | 6 +- .../features/queue/hooks/useCancelBatch.ts | 5 +- .../queue/hooks/useCancelCurrentQueueItem.ts | 5 +- .../queue/hooks/useCancelQueueItem.ts | 5 +- .../queue/hooks/useClearInvocationCache.ts | 5 +- .../src/features/queue/hooks/useClearQueue.ts | 6 +- .../queue/hooks/useDisableInvocationCache.ts | 5 +- .../queue/hooks/useEnableInvocationCache.ts | 5 +- .../features/queue/hooks/usePauseProcessor.ts | 5 +- .../src/features/queue/hooks/usePruneQueue.ts | 6 +- .../queue/hooks/useResumeProcessor.ts | 5 +- .../system/components/ProgressBar.tsx | 26 +- .../system/components/StatusIndicator.tsx | 5 +- .../src/services/events/setEventListeners.ts | 136 ---- .../src/services/events/setEventListeners.tsx | 621 ++++++++++++++++++ 31 files changed, 809 insertions(+), 301 deletions(-) delete mode 100644 invokeai/frontend/web/src/services/events/setEventListeners.ts create mode 100644 invokeai/frontend/web/src/services/events/setEventListeners.tsx diff --git a/invokeai/frontend/web/src/app/hooks/useSocketIO.ts b/invokeai/frontend/web/src/app/hooks/useSocketIO.ts index 8a530b8229..89cb1ae172 100644 --- a/invokeai/frontend/web/src/app/hooks/useSocketIO.ts +++ b/invokeai/frontend/web/src/app/hooks/useSocketIO.ts @@ -2,7 +2,7 @@ import { useStore } from '@nanostores/react'; import { $authToken } from 'app/store/nanostores/authToken'; import { $baseUrl } from 'app/store/nanostores/baseUrl'; import { $isDebugging } from 'app/store/nanostores/isDebugging'; -import { useAppDispatch } from 'app/store/storeHooks'; +import { useAppStore } from 'app/store/nanostores/store'; import type { MapStore } from 'nanostores'; import { atom, map } from 'nanostores'; import { useEffect, useMemo } from 'react'; @@ -28,13 +28,15 @@ export const getSocket = () => { return socket; }; export const $socketOptions = map>({}); + const $isSocketInitialized = atom(false); +export const $isConnected = atom(false); /** * Initializes the socket.io connection and sets up event listeners. */ export const useSocketIO = () => { - const dispatch = useAppDispatch(); + const { dispatch, getState } = useAppStore(); const baseUrl = useStore($baseUrl); const authToken = useStore($authToken); const addlSocketOptions = useStore($socketOptions); @@ -72,7 +74,7 @@ export const useSocketIO = () => { const socket: AppSocket = io(socketUrl, socketOptions); $socket.set(socket); - setEventListeners({ dispatch, socket }); + setEventListeners({ socket, dispatch, getState, setIsConnected: $isConnected.set }); socket.connect(); if ($isDebugging.get() || import.meta.env.MODE === 'development') { @@ -94,5 +96,5 @@ export const useSocketIO = () => { socket.disconnect(); $isSocketInitialized.set(false); }; - }, [dispatch, socketOptions, socketUrl]); + }, [dispatch, getState, socketOptions, socketUrl]); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts index d6cb10ff43..6b8d9782ca 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts @@ -1,7 +1,6 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { - $lastProgressEvent, rasterLayerAdded, sessionStagingAreaImageAccepted, sessionStagingAreaReset, @@ -11,6 +10,7 @@ import { imageDTOToImageObject } from 'features/controlLayers/store/types'; import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { queueApi } from 'services/api/endpoints/queue'; +import { $lastCanvasProgressEvent } from 'services/events/setEventListeners'; import { assert } from 'tsafe'; export const addStagingListeners = (startAppListening: AppStartListening) => { @@ -29,7 +29,7 @@ export const addStagingListeners = (startAppListening: AppStartListening) => { const { canceled } = await req.unwrap(); req.reset(); - $lastProgressEvent.set(null); + $lastCanvasProgressEvent.set(null); if (canceled > 0) { log.debug(`Canceled ${canceled} canvas batches`); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts index e28235da59..1aff46d0a3 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts @@ -2,10 +2,10 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { deepClone } from 'common/util/deepClone'; import { parseify } from 'common/util/serialize'; -import { $lastProgressEvent } from 'features/controlLayers/store/canvasV2Slice'; import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; import { zNodeStatus } from 'features/nodes/types/invocation'; import { socketGeneratorProgress } from 'services/events/actions'; +import { $lastCanvasProgressEvent } from 'services/events/setEventListeners'; const log = logger('socketio'); @@ -27,7 +27,7 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis } if (origin === 'canvas') { - $lastProgressEvent.set(action.payload.data); + $lastCanvasProgressEvent.set(action.payload.data); } }, }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueEvents.tsx b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueEvents.tsx index 5ba1013bb7..0b37104ca6 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueEvents.tsx +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueEvents.tsx @@ -1,7 +1,6 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { deepClone } from 'common/util/deepClone'; -import { $lastProgressEvent } from 'features/controlLayers/store/canvasV2Slice'; import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState'; import { zNodeStatus } from 'features/nodes/types/invocation'; import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription'; @@ -9,6 +8,7 @@ import { toast } from 'features/toast/toast'; import { forEach } from 'lodash-es'; import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue'; import { socketQueueItemStatusChanged } from 'services/events/actions'; +import { $lastCanvasProgressEvent } from 'services/events/setEventListeners'; const log = logger('socketio'); @@ -17,13 +17,13 @@ export const addSocketQueueEventsListeners = (startAppListening: AppStartListeni startAppListening({ matcher: queueApi.endpoints.clearQueue.matchFulfilled, effect: () => { - $lastProgressEvent.set(null); + $lastCanvasProgressEvent.set(null); }, }); startAppListening({ actionCreator: socketQueueItemStatusChanged, - effect: async (action, { dispatch, getState }) => { + effect: (action, { dispatch, getState }) => { // we've got new status for the queue item, batch and queue const { item_id, @@ -103,7 +103,7 @@ export const addSocketQueueEventsListeners = (startAppListening: AppStartListeni const isLocal = getState().config.isLocal ?? true; const sessionId = session_id; if (origin === 'canvas') { - $lastProgressEvent.set(null); + $lastCanvasProgressEvent.set(null); } toast({ @@ -122,7 +122,7 @@ export const addSocketQueueEventsListeners = (startAppListening: AppStartListeni ), }); } else if (status === 'canceled' && origin === 'canvas') { - $lastProgressEvent.set(null); + $lastCanvasProgressEvent.set(null); } }, }); diff --git a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts index eb2ff5bc27..d30ee3f964 100644 --- a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts +++ b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts @@ -1,4 +1,5 @@ import { useStore } from '@nanostores/react'; +import { $isConnected } from 'app/hooks/useSocketIO'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice'; @@ -25,7 +26,7 @@ const LAYER_TYPE_TO_TKEY = { control_layer: 'controlLayers.globalControlAdapter', } as const; -const createSelector = (templates: Templates) => +const createSelector = (templates: Templates, isConnected: boolean) => createMemoizedSelector( [ selectSystemSlice, @@ -41,8 +42,6 @@ const createSelector = (templates: Templates) => const { bbox } = canvasV2; const { model, positivePrompt } = canvasV2.params; - const { isConnected } = system; - const reasons: { prefix?: string; content: string }[] = []; // Cannot generate if not connected @@ -240,7 +239,8 @@ const createSelector = (templates: Templates) => export const useIsReadyToEnqueue = () => { const templates = useStore($templates); - const selector = useMemo(() => createSelector(templates), [templates]); + const isConnected = useStore($isConnected) + const selector = useMemo(() => createSelector(templates, isConnected), [templates, isConnected]); const value = useAppSelector(selector); return value; }; diff --git a/invokeai/frontend/web/src/common/types.ts b/invokeai/frontend/web/src/common/types.ts index f3037dcc2b..dd23638b8f 100644 --- a/invokeai/frontend/web/src/common/types.ts +++ b/invokeai/frontend/web/src/common/types.ts @@ -3,3 +3,8 @@ type JSONValue = string | number | boolean | null | JSONValue[] | { [key: string export interface JSONObject { [k: string]: JSONValue; } + +type SerializableValue = string | number | boolean | null | undefined | SerializableValue[] | SerializableObject; +export type SerializableObject = { + [k: string | number]: SerializableValue; +}; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/IPAdapter/IPAdapterImagePreview.tsx b/invokeai/frontend/web/src/features/controlLayers/components/IPAdapter/IPAdapterImagePreview.tsx index 9e76aa1b91..e1f6b07857 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/IPAdapter/IPAdapterImagePreview.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/IPAdapter/IPAdapterImagePreview.tsx @@ -1,5 +1,7 @@ import { Flex, useShiftModifier } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; import { skipToken } from '@reduxjs/toolkit/query'; +import { $isConnected } from 'app/hooks/useSocketIO'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIDndImage from 'common/components/IAIDndImage'; import IAIDndImageIcon from 'common/components/IAIDndImageIcon'; @@ -22,79 +24,85 @@ type Props = { postUploadAction: PostUploadAction; }; -export const IPAdapterImagePreview = memo(({ image, onChangeImage, ipAdapterId, droppableData, postUploadAction }: Props) => { - const { t } = useTranslation(); - const dispatch = useAppDispatch(); - const isConnected = useAppSelector((s) => s.system.isConnected); - const optimalDimension = useAppSelector(selectOptimalDimension); - const shift = useShiftModifier(); +export const IPAdapterImagePreview = memo( + ({ image, onChangeImage, ipAdapterId, droppableData, postUploadAction }: Props) => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const isConnected = useStore($isConnected); + const optimalDimension = useAppSelector(selectOptimalDimension); + const shift = useShiftModifier(); - const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery(image?.image_name ?? skipToken); - const handleResetControlImage = useCallback(() => { - onChangeImage(null); - }, [onChangeImage]); + const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery( + image?.image_name ?? skipToken + ); + const handleResetControlImage = useCallback(() => { + onChangeImage(null); + }, [onChangeImage]); - const handleSetControlImageToDimensions = useCallback(() => { - if (!controlImage) { - return; - } + const handleSetControlImageToDimensions = useCallback(() => { + if (!controlImage) { + return; + } - const options = { updateAspectRatio: true, clamp: true }; - if (shift) { - const { width, height } = controlImage; - dispatch(bboxWidthChanged({ width, ...options })); - dispatch(bboxHeightChanged({ height, ...options })); - } else { - const { width, height } = calculateNewSize( - controlImage.width / controlImage.height, - optimalDimension * optimalDimension - ); - dispatch(bboxWidthChanged({ width, ...options })); - dispatch(bboxHeightChanged({ height, ...options })); - } - }, [controlImage, dispatch, optimalDimension, shift]); + const options = { updateAspectRatio: true, clamp: true }; + if (shift) { + const { width, height } = controlImage; + dispatch(bboxWidthChanged({ width, ...options })); + dispatch(bboxHeightChanged({ height, ...options })); + } else { + const { width, height } = calculateNewSize( + controlImage.width / controlImage.height, + optimalDimension * optimalDimension + ); + dispatch(bboxWidthChanged({ width, ...options })); + dispatch(bboxHeightChanged({ height, ...options })); + } + }, [controlImage, dispatch, optimalDimension, shift]); - const draggableData = useMemo(() => { - if (controlImage) { - return { - id: ipAdapterId, - payloadType: 'IMAGE_DTO', - payload: { imageDTO: controlImage }, - }; - } - }, [controlImage, ipAdapterId]); + const draggableData = useMemo(() => { + if (controlImage) { + return { + id: ipAdapterId, + payloadType: 'IMAGE_DTO', + payload: { imageDTO: controlImage }, + }; + } + }, [controlImage, ipAdapterId]); - useEffect(() => { - if (isConnected && isErrorControlImage) { - handleResetControlImage(); - } - }, [handleResetControlImage, isConnected, isErrorControlImage]); + useEffect(() => { + if (isConnected && isErrorControlImage) { + handleResetControlImage(); + } + }, [handleResetControlImage, isConnected, isErrorControlImage]); - return ( - - + return ( + + - {controlImage && ( - - } - tooltip={t('controlnet.resetControlImage')} - /> - } - tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')} - /> - - )} - - ); -}); + {controlImage && ( + + } + tooltip={t('controlnet.resetControlImage')} + /> + } + tooltip={ + shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions') + } + /> + + )} + + ); + } +); IPAdapterImagePreview.displayName = 'IPAdapterImagePreview'; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasProgressImage.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasProgressImage.ts index 00c796b2c2..0739c267b5 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasProgressImage.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasProgressImage.ts @@ -48,7 +48,7 @@ export class CanvasProgressImage { image: null, }; - this.manager.stateApi.$lastProgressEvent.listen((event) => { + this.manager.stateApi.$lastCanvasProgressEvent.listen((event) => { this.lastProgressEvent = event; this.render(); }); diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStagingArea.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStagingArea.ts index c58186a14d..2ce21d239e 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStagingArea.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStagingArea.ts @@ -76,7 +76,7 @@ export class CanvasStagingArea { if (!this.image.isLoading && !this.image.isError) { await this.image.updateImageSource(imageDTO.image_name); - this.manager.stateApi.$lastProgressEvent.set(null); + this.manager.stateApi.$lastCanvasProgressEvent.set(null); } this.image.konva.group.visible(shouldShowStagedImage); } else { diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApi.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApi.ts index 262e28dce5..31a2aee8b2 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApi.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApi.ts @@ -11,7 +11,6 @@ import { $lastAddedPoint, $lastCursorPos, $lastMouseDownPos, - $lastProgressEvent, $shouldShowStagedImage, $spaceKey, $stageAttrs, @@ -51,6 +50,7 @@ import type { import { RGBA_RED } from 'features/controlLayers/store/types'; import type { WritableAtom } from 'nanostores'; import { atom } from 'nanostores'; +import { $lastCanvasProgressEvent } from 'services/events/setEventListeners'; type EntityStateAndAdapter = | { @@ -263,7 +263,7 @@ export class CanvasStateApi { $lastAddedPoint = $lastAddedPoint; $lastMouseDownPos = $lastMouseDownPos; $lastCursorPos = $lastCursorPos; - $lastProgressEvent = $lastProgressEvent; + $lastCanvasProgressEvent = $lastCanvasProgressEvent; $spaceKey = $spaceKey; $altKey = $alt; $ctrlKey = $ctrl; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts index bc939f0894..d0ae06c28a 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts @@ -20,7 +20,6 @@ import { initialAspectRatioState } from 'features/parameters/components/Document import { getOptimalDimension } from 'features/parameters/util/optimalDimension'; import { isEqual, pick } from 'lodash-es'; import { atom } from 'nanostores'; -import type { InvocationDenoiseProgressEvent } from 'services/events/types'; import { assert } from 'tsafe'; import type { @@ -622,7 +621,6 @@ export const $stageAttrs = atom({ scale: 0, }); export const $shouldShowStagedImage = atom(true); -export const $lastProgressEvent = atom(null); export const $isDrawing = atom(false); export const $isMouseDown = atom(false); export const $lastAddedPoint = atom(null); diff --git a/invokeai/frontend/web/src/features/deleteImageModal/components/DeleteImageButton.tsx b/invokeai/frontend/web/src/features/deleteImageModal/components/DeleteImageButton.tsx index 6855cb8e55..452d101fa2 100644 --- a/invokeai/frontend/web/src/features/deleteImageModal/components/DeleteImageButton.tsx +++ b/invokeai/frontend/web/src/features/deleteImageModal/components/DeleteImageButton.tsx @@ -1,5 +1,7 @@ import type { IconButtonProps } from '@invoke-ai/ui-library'; import { IconButton } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; +import { $isConnected } from 'app/hooks/useSocketIO'; import { useAppSelector } from 'app/store/storeHooks'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -12,7 +14,7 @@ type DeleteImageButtonProps = Omit & { export const DeleteImageButton = memo((props: DeleteImageButtonProps) => { const { onClick, isDisabled } = props; const { t } = useTranslation(); - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const imageSelectionLength: number = useAppSelector((s) => s.gallery.selection.length); const labelMessage: string = `${t('gallery.deleteImage', { count: imageSelectionLength })} (Del)`; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImageButtons.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImageButtons.tsx index 1ef91e7e2e..9ccd69b898 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImageButtons.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImageButtons.tsx @@ -1,7 +1,7 @@ import { ButtonGroup, IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; -import { createSelector } from '@reduxjs/toolkit'; import { skipToken } from '@reduxjs/toolkit/query'; +import { $isConnected } from 'app/hooks/useSocketIO'; import { adHocPostProcessingRequested } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton'; @@ -10,17 +10,15 @@ import SingleSelectionMenuItems from 'features/gallery/components/ImageContextMe import { useImageActions } from 'features/gallery/hooks/useImageActions'; import { sentImageToImg2Img } from 'features/gallery/store/actions'; import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors'; -import { selectGallerySlice } from 'features/gallery/store/gallerySlice'; import { parseAndRecallImageDimensions } from 'features/metadata/util/handlers'; import { $templates } from 'features/nodes/store/nodesSlice'; import { PostProcessingPopover } from 'features/parameters/components/PostProcessing/PostProcessingPopover'; import { useIsQueueMutationInProgress } from 'features/queue/hooks/useIsQueueMutationInProgress'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; -import { selectSystemSlice } from 'features/system/store/systemSlice'; import { setActiveTab } from 'features/ui/store/uiSlice'; import { useGetAndLoadEmbeddedWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow'; import { size } from 'lodash-es'; -import { memo, useCallback } from 'react'; +import { memo, useCallback, useMemo } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; import { useTranslation } from 'react-i18next'; import { @@ -33,23 +31,17 @@ import { PiRulerBold, } from 'react-icons/pi'; import { useGetImageDTOQuery } from 'services/api/endpoints/images'; - -const selectShouldDisableToolbarButtons = createSelector( - selectSystemSlice, - selectGallerySlice, - selectLastSelectedImage, - (system, gallery, lastSelectedImage) => { - const hasProgressImage = Boolean(system.denoiseProgress?.progress_image); - return hasProgressImage || !lastSelectedImage; - } -); +import { $progressImage } from 'services/events/setEventListeners'; const CurrentImageButtons = () => { const dispatch = useAppDispatch(); - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const lastSelectedImage = useAppSelector(selectLastSelectedImage); + const progressImage = useStore($progressImage); const selection = useAppSelector((s) => s.gallery.selection); - const shouldDisableToolbarButtons = useAppSelector(selectShouldDisableToolbarButtons); + const shouldDisableToolbarButtons = useMemo(() => { + return Boolean(progressImage) || !lastSelectedImage; + }, [lastSelectedImage, progressImage]); const templates = useStore($templates); const isUpscalingEnabled = useFeatureStatus('upscaling'); const isQueueMutationInProgress = useIsQueueMutationInProgress(); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx index a812391992..23e75498ec 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx @@ -1,4 +1,5 @@ import { Box, Flex } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; import { createSelector } from '@reduxjs/toolkit'; import { skipToken } from '@reduxjs/toolkit/query'; import { useAppSelector } from 'app/store/storeHooks'; @@ -14,6 +15,7 @@ import { memo, useCallback, useMemo, useRef, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { PiImageBold } from 'react-icons/pi'; import { useGetImageDTOQuery } from 'services/api/endpoints/images'; +import { $hasProgress } from 'services/events/setEventListeners'; import ProgressImage from './ProgressImage'; @@ -26,7 +28,7 @@ const CurrentImagePreview = () => { const { t } = useTranslation(); const shouldShowImageDetails = useAppSelector((s) => s.ui.shouldShowImageDetails); const imageName = useAppSelector(selectLastSelectedImageName); - const hasDenoiseProgress = useAppSelector((s) => Boolean(s.system.denoiseProgress)); + const hasDenoiseProgress = useStore($hasProgress); const shouldShowProgressInViewer = useAppSelector((s) => s.ui.shouldShowProgressInViewer); const { currentData: imageDTO } = useGetImageDTOQuery(imageName ?? skipToken); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressImage.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressImage.tsx index 0ee75fbcd4..46c1bd71c2 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressImage.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressImage.tsx @@ -1,10 +1,12 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library'; import { Image } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; import { useAppSelector } from 'app/store/storeHooks'; import { memo, useMemo } from 'react'; +import { $progressImage } from 'services/events/setEventListeners'; const CurrentImagePreview = () => { - const progress_image = useAppSelector((s) => s.system.denoiseProgress?.progress_image); + const progressImage = useStore($progressImage); const shouldAntialiasProgressImage = useAppSelector((s) => s.system.shouldAntialiasProgressImage); const sx = useMemo( @@ -14,15 +16,15 @@ const CurrentImagePreview = () => { [shouldAntialiasProgressImage] ); - if (!progress_image) { + if (!progressImage) { return null; } return ( { - const imageDTO = gallery.selection[gallery.selection.length - 1]; - - return { - imageDTO, - progressImage: system.denoiseProgress?.progress_image, - }; -}); +import { $lastProgressEvent } from 'services/events/setEventListeners'; const CurrentImageNode = (props: NodeProps) => { - const { progressImage, imageDTO } = useAppSelector(selector); + const imageDTO = useAppSelector((s) => s.gallery.selection[s.gallery.selection.length - 1]); + const lastProgressEvent = useStore($lastProgressEvent); - if (progressImage) { + if (lastProgressEvent?.progress_image) { return ( - + ); } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldInputComponent.tsx index c3224238c5..1ec0b575f6 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldInputComponent.tsx @@ -1,6 +1,8 @@ import { Flex, Text } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; import { skipToken } from '@reduxjs/toolkit/query'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { $isConnected } from 'app/hooks/useSocketIO'; +import { useAppDispatch } from 'app/store/storeHooks'; import IAIDndImage from 'common/components/IAIDndImage'; import IAIDndImageIcon from 'common/components/IAIDndImageIcon'; import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types'; @@ -17,7 +19,7 @@ import type { FieldComponentProps } from './types'; const ImageFieldInputComponent = (props: FieldComponentProps) => { const { nodeId, field } = props; const dispatch = useAppDispatch(); - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const { currentData: imageDTO, isError } = useGetImageDTOQuery(field.value?.image_name ?? skipToken); const handleReset = useCallback(() => { diff --git a/invokeai/frontend/web/src/features/queue/hooks/useCancelBatch.ts b/invokeai/frontend/web/src/features/queue/hooks/useCancelBatch.ts index 9d92eabff8..d9ad1a736f 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useCancelBatch.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useCancelBatch.ts @@ -1,11 +1,12 @@ -import { useAppSelector } from 'app/store/storeHooks'; +import { useStore } from '@nanostores/react'; +import { $isConnected } from 'app/hooks/useSocketIO'; import { toast } from 'features/toast/toast'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { useCancelByBatchIdsMutation, useGetBatchStatusQuery } from 'services/api/endpoints/queue'; export const useCancelBatch = (batch_id: string) => { - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const { isCanceled } = useGetBatchStatusQuery( { batch_id }, { diff --git a/invokeai/frontend/web/src/features/queue/hooks/useCancelCurrentQueueItem.ts b/invokeai/frontend/web/src/features/queue/hooks/useCancelCurrentQueueItem.ts index 057490ed99..9ae8e2dd2e 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useCancelCurrentQueueItem.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useCancelCurrentQueueItem.ts @@ -1,4 +1,5 @@ -import { useAppSelector } from 'app/store/storeHooks'; +import { useStore } from '@nanostores/react'; +import { $isConnected } from 'app/hooks/useSocketIO'; import { toast } from 'features/toast/toast'; import { isNil } from 'lodash-es'; import { useCallback, useMemo } from 'react'; @@ -6,7 +7,7 @@ import { useTranslation } from 'react-i18next'; import { useCancelQueueItemMutation, useGetQueueStatusQuery } from 'services/api/endpoints/queue'; export const useCancelCurrentQueueItem = () => { - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const { data: queueStatus } = useGetQueueStatusQuery(); const [trigger, { isLoading }] = useCancelQueueItemMutation(); const { t } = useTranslation(); diff --git a/invokeai/frontend/web/src/features/queue/hooks/useCancelQueueItem.ts b/invokeai/frontend/web/src/features/queue/hooks/useCancelQueueItem.ts index 268eca75cc..bf0af41605 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useCancelQueueItem.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useCancelQueueItem.ts @@ -1,11 +1,12 @@ -import { useAppSelector } from 'app/store/storeHooks'; +import { useStore } from '@nanostores/react'; +import { $isConnected } from 'app/hooks/useSocketIO'; import { toast } from 'features/toast/toast'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { useCancelQueueItemMutation } from 'services/api/endpoints/queue'; export const useCancelQueueItem = (item_id: number) => { - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const [trigger, { isLoading }] = useCancelQueueItemMutation(); const { t } = useTranslation(); const cancelQueueItem = useCallback(async () => { diff --git a/invokeai/frontend/web/src/features/queue/hooks/useClearInvocationCache.ts b/invokeai/frontend/web/src/features/queue/hooks/useClearInvocationCache.ts index 7ef9d93742..d177a72f5f 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useClearInvocationCache.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useClearInvocationCache.ts @@ -1,4 +1,5 @@ -import { useAppSelector } from 'app/store/storeHooks'; +import { useStore } from '@nanostores/react'; +import { $isConnected } from 'app/hooks/useSocketIO'; import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -7,7 +8,7 @@ import { useClearInvocationCacheMutation, useGetInvocationCacheStatusQuery } fro export const useClearInvocationCache = () => { const { t } = useTranslation(); const { data: cacheStatus } = useGetInvocationCacheStatusQuery(); - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const [trigger, { isLoading }] = useClearInvocationCacheMutation({ fixedCacheKey: 'clearInvocationCache', }); diff --git a/invokeai/frontend/web/src/features/queue/hooks/useClearQueue.ts b/invokeai/frontend/web/src/features/queue/hooks/useClearQueue.ts index ca7d1e4894..bb80f7aa10 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useClearQueue.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useClearQueue.ts @@ -1,4 +1,6 @@ -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useStore } from '@nanostores/react'; +import { $isConnected } from 'app/hooks/useSocketIO'; +import { useAppDispatch } from 'app/store/storeHooks'; import { listCursorChanged, listPriorityChanged } from 'features/queue/store/queueSlice'; import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; @@ -9,7 +11,7 @@ export const useClearQueue = () => { const { t } = useTranslation(); const dispatch = useAppDispatch(); const { data: queueStatus } = useGetQueueStatusQuery(); - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const [trigger, { isLoading }] = useClearQueueMutation({ fixedCacheKey: 'clearQueue', }); diff --git a/invokeai/frontend/web/src/features/queue/hooks/useDisableInvocationCache.ts b/invokeai/frontend/web/src/features/queue/hooks/useDisableInvocationCache.ts index 371e9198e7..cf71e4bd4b 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useDisableInvocationCache.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useDisableInvocationCache.ts @@ -1,4 +1,5 @@ -import { useAppSelector } from 'app/store/storeHooks'; +import { useStore } from '@nanostores/react'; +import { $isConnected } from 'app/hooks/useSocketIO'; import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -7,7 +8,7 @@ import { useDisableInvocationCacheMutation, useGetInvocationCacheStatusQuery } f export const useDisableInvocationCache = () => { const { t } = useTranslation(); const { data: cacheStatus } = useGetInvocationCacheStatusQuery(); - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const [trigger, { isLoading }] = useDisableInvocationCacheMutation({ fixedCacheKey: 'disableInvocationCache', }); diff --git a/invokeai/frontend/web/src/features/queue/hooks/useEnableInvocationCache.ts b/invokeai/frontend/web/src/features/queue/hooks/useEnableInvocationCache.ts index fb39cf7347..7f28bddd78 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useEnableInvocationCache.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useEnableInvocationCache.ts @@ -1,4 +1,5 @@ -import { useAppSelector } from 'app/store/storeHooks'; +import { useStore } from '@nanostores/react'; +import { $isConnected } from 'app/hooks/useSocketIO'; import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -7,7 +8,7 @@ import { useEnableInvocationCacheMutation, useGetInvocationCacheStatusQuery } fr export const useEnableInvocationCache = () => { const { t } = useTranslation(); const { data: cacheStatus } = useGetInvocationCacheStatusQuery(); - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const [trigger, { isLoading }] = useEnableInvocationCacheMutation({ fixedCacheKey: 'enableInvocationCache', }); diff --git a/invokeai/frontend/web/src/features/queue/hooks/usePauseProcessor.ts b/invokeai/frontend/web/src/features/queue/hooks/usePauseProcessor.ts index f5424c6b18..d25c8051e5 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/usePauseProcessor.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/usePauseProcessor.ts @@ -1,4 +1,5 @@ -import { useAppSelector } from 'app/store/storeHooks'; +import { useStore } from '@nanostores/react'; +import { $isConnected } from 'app/hooks/useSocketIO'; import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -6,7 +7,7 @@ import { useGetQueueStatusQuery, usePauseProcessorMutation } from 'services/api/ export const usePauseProcessor = () => { const { t } = useTranslation(); - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const { data: queueStatus } = useGetQueueStatusQuery(); const [trigger, { isLoading }] = usePauseProcessorMutation({ fixedCacheKey: 'pauseProcessor', diff --git a/invokeai/frontend/web/src/features/queue/hooks/usePruneQueue.ts b/invokeai/frontend/web/src/features/queue/hooks/usePruneQueue.ts index eaeabe5423..f9426291be 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/usePruneQueue.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/usePruneQueue.ts @@ -1,4 +1,6 @@ -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useStore } from '@nanostores/react'; +import { $isConnected } from 'app/hooks/useSocketIO'; +import { useAppDispatch } from 'app/store/storeHooks'; import { listCursorChanged, listPriorityChanged } from 'features/queue/store/queueSlice'; import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; @@ -8,7 +10,7 @@ import { useGetQueueStatusQuery, usePruneQueueMutation } from 'services/api/endp export const usePruneQueue = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const [trigger, { isLoading }] = usePruneQueueMutation({ fixedCacheKey: 'pruneQueue', }); diff --git a/invokeai/frontend/web/src/features/queue/hooks/useResumeProcessor.ts b/invokeai/frontend/web/src/features/queue/hooks/useResumeProcessor.ts index 851b268416..72d787103b 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useResumeProcessor.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useResumeProcessor.ts @@ -1,11 +1,12 @@ -import { useAppSelector } from 'app/store/storeHooks'; +import { useStore } from '@nanostores/react'; +import { $isConnected } from 'app/hooks/useSocketIO'; import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetQueueStatusQuery, useResumeProcessorMutation } from 'services/api/endpoints/queue'; export const useResumeProcessor = () => { - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const { data: queueStatus } = useGetQueueStatusQuery(); const { t } = useTranslation(); const [trigger, { isLoading }] = useResumeProcessorMutation({ diff --git a/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx b/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx index 4389431813..06c7e70c7f 100644 --- a/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx +++ b/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx @@ -1,28 +1,28 @@ import { Progress } from '@invoke-ai/ui-library'; -import { createSelector } from '@reduxjs/toolkit'; -import { useAppSelector } from 'app/store/storeHooks'; -import { selectSystemSlice } from 'features/system/store/systemSlice'; -import { memo } from 'react'; +import { useStore } from '@nanostores/react'; +import { $isConnected } from 'app/hooks/useSocketIO'; +import { memo, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetQueueStatusQuery } from 'services/api/endpoints/queue'; - -const selectProgressValue = createSelector( - selectSystemSlice, - (system) => (system.denoiseProgress?.percentage ?? 0) * 100 -); +import { $lastProgressEvent } from 'services/events/setEventListeners'; const ProgressBar = () => { const { t } = useTranslation(); const { data: queueStatus } = useGetQueueStatusQuery(); - const isConnected = useAppSelector((s) => s.system.isConnected); - const hasSteps = useAppSelector((s) => Boolean(s.system.denoiseProgress)); - const value = useAppSelector(selectProgressValue); + const isConnected = useStore($isConnected); + const lastProgressEvent = useStore($lastProgressEvent); + const value = useMemo(() => { + if (!lastProgressEvent) { + return 0; + } + return (lastProgressEvent.percentage ?? 0) * 100; + }, [lastProgressEvent]); return ( { - const isConnected = useAppSelector((s) => s.system.isConnected); + const isConnected = useStore($isConnected); const { t } = useTranslation(); if (!isConnected) { diff --git a/invokeai/frontend/web/src/services/events/setEventListeners.ts b/invokeai/frontend/web/src/services/events/setEventListeners.ts deleted file mode 100644 index 8c8c9da2e8..0000000000 --- a/invokeai/frontend/web/src/services/events/setEventListeners.ts +++ /dev/null @@ -1,136 +0,0 @@ -import { $baseUrl } from 'app/store/nanostores/baseUrl'; -import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId'; -import { $queueId } from 'app/store/nanostores/queueId'; -import type { AppDispatch } from 'app/store/store'; -import { toast } from 'features/toast/toast'; -import { - socketBatchEnqueued, - socketBulkDownloadComplete, - socketBulkDownloadError, - socketBulkDownloadStarted, - socketConnected, - socketDisconnected, - socketDownloadCancelled, - socketDownloadComplete, - socketDownloadError, - socketDownloadProgress, - socketDownloadStarted, - socketGeneratorProgress, - socketInvocationComplete, - socketInvocationError, - socketInvocationStarted, - socketModelInstallCancelled, - socketModelInstallComplete, - socketModelInstallDownloadProgress, - socketModelInstallDownloadsComplete, - socketModelInstallError, - socketModelInstallStarted, - socketModelLoadComplete, - socketModelLoadStarted, - socketQueueCleared, - socketQueueItemStatusChanged, -} from 'services/events/actions'; -import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types'; -import type { Socket } from 'socket.io-client'; - -type SetEventListenersArg = { - socket: Socket; - dispatch: AppDispatch; -}; - -export const setEventListeners = ({ socket, dispatch }: SetEventListenersArg) => { - socket.on('connect', () => { - dispatch(socketConnected()); - const queue_id = $queueId.get(); - socket.emit('subscribe_queue', { queue_id }); - if (!$baseUrl.get()) { - const bulk_download_id = $bulkDownloadId.get(); - socket.emit('subscribe_bulk_download', { bulk_download_id }); - } - }); - socket.on('connect_error', (error) => { - if (error && error.message) { - const data: string | undefined = (error as unknown as { data: string | undefined }).data; - if (data === 'ERR_UNAUTHENTICATED') { - toast({ - id: `connect-error-${error.message}`, - title: error.message, - status: 'error', - duration: 10000, - }); - } - } - }); - socket.on('disconnect', () => { - dispatch(socketDisconnected()); - }); - socket.on('invocation_started', (data) => { - dispatch(socketInvocationStarted({ data })); - }); - socket.on('invocation_denoise_progress', (data) => { - dispatch(socketGeneratorProgress({ data })); - }); - socket.on('invocation_error', (data) => { - dispatch(socketInvocationError({ data })); - }); - socket.on('invocation_complete', (data) => { - dispatch(socketInvocationComplete({ data })); - }); - socket.on('model_load_started', (data) => { - dispatch(socketModelLoadStarted({ data })); - }); - socket.on('model_load_complete', (data) => { - dispatch(socketModelLoadComplete({ data })); - }); - socket.on('download_started', (data) => { - dispatch(socketDownloadStarted({ data })); - }); - socket.on('download_progress', (data) => { - dispatch(socketDownloadProgress({ data })); - }); - socket.on('download_complete', (data) => { - dispatch(socketDownloadComplete({ data })); - }); - socket.on('download_cancelled', (data) => { - dispatch(socketDownloadCancelled({ data })); - }); - socket.on('download_error', (data) => { - dispatch(socketDownloadError({ data })); - }); - socket.on('model_install_started', (data) => { - dispatch(socketModelInstallStarted({ data })); - }); - socket.on('model_install_download_progress', (data) => { - dispatch(socketModelInstallDownloadProgress({ data })); - }); - socket.on('model_install_downloads_complete', (data) => { - dispatch(socketModelInstallDownloadsComplete({ data })); - }); - socket.on('model_install_complete', (data) => { - dispatch(socketModelInstallComplete({ data })); - }); - socket.on('model_install_error', (data) => { - dispatch(socketModelInstallError({ data })); - }); - socket.on('model_install_cancelled', (data) => { - dispatch(socketModelInstallCancelled({ data })); - }); - socket.on('queue_item_status_changed', (data) => { - dispatch(socketQueueItemStatusChanged({ data })); - }); - socket.on('queue_cleared', (data) => { - dispatch(socketQueueCleared({ data })); - }); - socket.on('batch_enqueued', (data) => { - dispatch(socketBatchEnqueued({ data })); - }); - socket.on('bulk_download_started', (data) => { - dispatch(socketBulkDownloadStarted({ data })); - }); - socket.on('bulk_download_complete', (data) => { - dispatch(socketBulkDownloadComplete({ data })); - }); - socket.on('bulk_download_error', (data) => { - dispatch(socketBulkDownloadError({ data })); - }); -}; diff --git a/invokeai/frontend/web/src/services/events/setEventListeners.tsx b/invokeai/frontend/web/src/services/events/setEventListeners.tsx new file mode 100644 index 0000000000..379b280032 --- /dev/null +++ b/invokeai/frontend/web/src/services/events/setEventListeners.tsx @@ -0,0 +1,621 @@ +import { ExternalLink } from '@invoke-ai/ui-library'; +import { logger } from 'app/logging/logger'; +import { $baseUrl } from 'app/store/nanostores/baseUrl'; +import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId'; +import { $queueId } from 'app/store/nanostores/queueId'; +import type { AppDispatch, RootState } from 'app/store/store'; +import type { SerializableObject } from 'common/types'; +import { deepClone } from 'common/util/deepClone'; +import { sessionImageStaged } from 'features/controlLayers/store/canvasV2Slice'; +import { boardIdSelected, galleryViewChanged, imageSelected, offsetChanged } from 'features/gallery/store/gallerySlice'; +import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; +import { zNodeStatus } from 'features/nodes/types/invocation'; +import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription'; +import { toast } from 'features/toast/toast'; +import { t } from 'i18next'; +import { forEach } from 'lodash-es'; +import { atom, computed } from 'nanostores'; +import { api, LIST_TAG } from 'services/api'; +import { boardsApi } from 'services/api/endpoints/boards'; +import { imagesApi } from 'services/api/endpoints/images'; +import { modelsApi } from 'services/api/endpoints/models'; +import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue'; +import { getCategories, getListImagesUrl } from 'services/api/util'; +import { socketConnected } from 'services/events/actions'; +import type { ClientToServerEvents, InvocationDenoiseProgressEvent, ServerToClientEvents } from 'services/events/types'; +import type { Socket } from 'socket.io-client'; + +const log = logger('socketio'); + +type SetEventListenersArg = { + socket: Socket; + dispatch: AppDispatch; + getState: () => RootState; + setIsConnected: (isConnected: boolean) => void; +}; + +const selectModelInstalls = modelsApi.endpoints.listModelInstalls.select(); +const nodeTypeDenylist = ['load_image', 'image']; +export const $lastProgressEvent = atom(null); +export const $lastCanvasProgressEvent = atom(null); +export const $hasProgress = computed($lastProgressEvent, (val) => Boolean(val)); +export const $progressImage = computed($lastProgressEvent, (val) => val?.progress_image ?? null); +const cancellations = new Set(); + +export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }: SetEventListenersArg) => { + socket.on('connect', () => { + log.debug('Connected'); + setIsConnected(true); + dispatch(socketConnected()); + const queue_id = $queueId.get(); + socket.emit('subscribe_queue', { queue_id }); + if (!$baseUrl.get()) { + const bulk_download_id = $bulkDownloadId.get(); + socket.emit('subscribe_bulk_download', { bulk_download_id }); + } + $lastProgressEvent.set(null); + $lastCanvasProgressEvent.set(null); + cancellations.clear(); + }); + + socket.on('connect_error', (error) => { + log.debug('Connect error'); + setIsConnected(false); + $lastProgressEvent.set(null); + $lastCanvasProgressEvent.set(null); + if (error && error.message) { + const data: string | undefined = (error as unknown as { data: string | undefined }).data; + if (data === 'ERR_UNAUTHENTICATED') { + toast({ + id: `connect-error-${error.message}`, + title: error.message, + status: 'error', + duration: 10000, + }); + } + } + cancellations.clear(); + }); + + socket.on('disconnect', () => { + log.debug('Disconnected'); + $lastProgressEvent.set(null); + $lastCanvasProgressEvent.set(null); + setIsConnected(false); + cancellations.clear(); + }); + + socket.on('invocation_started', (data) => { + const { invocation_source_id, invocation } = data; + log.debug({ data } as SerializableObject, `Invocation started (${invocation.type}, ${invocation_source_id})`); + const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]); + if (nes) { + nes.status = zNodeStatus.enum.IN_PROGRESS; + upsertExecutionState(nes.nodeId, nes); + } + cancellations.clear(); + }); + + socket.on('invocation_denoise_progress', (data) => { + const { invocation_source_id, invocation, step, total_steps, progress_image, origin, percentage, session_id } = + data; + + if (cancellations.has(session_id)) { + // Do not update the progress if this session has been cancelled. This prevents a race condition where we get a + // progress update after the session has been cancelled. + return; + } + + log.trace( + { data } as SerializableObject, + `Denoise ${Math.round(percentage * 100)}% (${invocation.type}, ${invocation_source_id})` + ); + + $lastProgressEvent.set(data); + + if (origin === 'workflows') { + const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]); + if (nes) { + nes.status = zNodeStatus.enum.IN_PROGRESS; + nes.progress = (step + 1) / total_steps; + nes.progressImage = progress_image ?? null; + upsertExecutionState(nes.nodeId, nes); + } + } + + if (origin === 'canvas') { + $lastCanvasProgressEvent.set(data); + } + }); + + socket.on('invocation_error', (data) => { + const { invocation_source_id, invocation, error_type, error_message, error_traceback } = data; + log.error({ data } as SerializableObject, `Invocation error (${invocation.type}, ${invocation_source_id})`); + const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]); + if (nes) { + nes.status = zNodeStatus.enum.FAILED; + nes.progress = null; + nes.progressImage = null; + nes.error = { + error_type, + error_message, + error_traceback, + }; + upsertExecutionState(nes.nodeId, nes); + } + }); + + socket.on('invocation_complete', async (data) => { + log.debug( + { data } as SerializableObject, + `Invocation complete (${data.invocation.type}, ${data.invocation_source_id})` + ); + + const { result, invocation_source_id } = data; + + if (data.origin === 'workflows') { + const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]); + if (nes) { + nes.status = zNodeStatus.enum.COMPLETED; + if (nes.progress !== null) { + nes.progress = 1; + } + nes.outputs.push(result); + upsertExecutionState(nes.nodeId, nes); + } + } + + // This complete event has an associated image output + if ( + (data.result.type === 'image_output' || data.result.type === 'canvas_v2_mask_and_crop_output') && + !nodeTypeDenylist.includes(data.invocation.type) + ) { + const { image_name } = data.result.image; + const { gallery, canvasV2 } = getState(); + + // This populates the `getImageDTO` cache + const imageDTORequest = dispatch( + imagesApi.endpoints.getImageDTO.initiate(image_name, { + forceRefetch: true, + }) + ); + + const imageDTO = await imageDTORequest.unwrap(); + imageDTORequest.unsubscribe(); + + // handle tab-specific logic + if (data.origin === 'canvas' && data.invocation_source_id === 'canvas_output') { + if (data.result.type === 'canvas_v2_mask_and_crop_output') { + const { offset_x, offset_y } = data.result; + if (canvasV2.session.isStaging) { + dispatch(sessionImageStaged({ stagingAreaImage: { imageDTO, offsetX: offset_x, offsetY: offset_y } })); + } + } else if (data.result.type === 'image_output') { + if (canvasV2.session.isStaging) { + dispatch(sessionImageStaged({ stagingAreaImage: { imageDTO, offsetX: 0, offsetY: 0 } })); + } + } + } + + if (!imageDTO.is_intermediate) { + // update the total images for the board + dispatch( + boardsApi.util.updateQueryData('getBoardImagesTotal', imageDTO.board_id ?? 'none', (draft) => { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + draft.total += 1; + }) + ); + + dispatch( + imagesApi.util.invalidateTags([ + { type: 'Board', id: imageDTO.board_id ?? 'none' }, + { + type: 'ImageList', + id: getListImagesUrl({ + board_id: imageDTO.board_id ?? 'none', + categories: getCategories(imageDTO), + }), + }, + ]) + ); + + const { shouldAutoSwitch } = gallery; + + // If auto-switch is enabled, select the new image + if (shouldAutoSwitch) { + // if auto-add is enabled, switch the gallery view and board if needed as the image comes in + if (gallery.galleryView !== 'images') { + dispatch(galleryViewChanged('images')); + } + + if (imageDTO.board_id && imageDTO.board_id !== gallery.selectedBoardId) { + dispatch( + boardIdSelected({ + boardId: imageDTO.board_id, + selectedImageName: imageDTO.image_name, + }) + ); + } + + dispatch(offsetChanged({ offset: 0 })); + + if (!imageDTO.board_id && gallery.selectedBoardId !== 'none') { + dispatch( + boardIdSelected({ + boardId: 'none', + selectedImageName: imageDTO.image_name, + }) + ); + } + + dispatch(imageSelected(imageDTO)); + } + } + } + + $lastProgressEvent.set(null); + }); + + socket.on('model_load_started', (data) => { + const { config, submodel_type } = data; + const { name, base, type } = config; + + const extras: string[] = [base, type]; + + if (submodel_type) { + extras.push(submodel_type); + } + + const message = `Model load started: ${name} (${extras.join(', ')})`; + + log.debug({ data }, message); + }); + + socket.on('model_load_complete', (data) => { + const { config, submodel_type } = data; + const { name, base, type } = config; + + const extras: string[] = [base, type]; + if (submodel_type) { + extras.push(submodel_type); + } + + const message = `Model load complete: ${name} (${extras.join(', ')})`; + + log.debug({ data }, message); + }); + + socket.on('download_started', (data) => { + log.debug({ data }, 'Download started'); + }); + + socket.on('download_progress', (data) => { + log.trace({ data }, 'Download progress'); + }); + + socket.on('download_complete', (data) => { + log.debug({ data }, 'Download complete'); + }); + + socket.on('download_cancelled', (data) => { + log.warn({ data }, 'Download cancelled'); + }); + + socket.on('download_error', (data) => { + log.error({ data }, 'Download error'); + }); + + socket.on('model_install_started', (data) => { + log.debug({ data }, 'Model install started'); + + const { id } = data; + const installs = selectModelInstalls(getState()).data; + + if (!installs?.find((install) => install.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'running'; + } + return draft; + }) + ); + } + }); + + socket.on('model_install_download_started', (data) => { + log.debug({ data }, 'Model install download started'); + + const { id } = data; + const installs = selectModelInstalls(getState()).data; + + if (!installs?.find((install) => install.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'downloading'; + } + return draft; + }) + ); + } + }); + + socket.on('model_install_download_progress', (data) => { + log.trace({ data }, 'Model install download progress'); + + const { bytes, total_bytes, id } = data; + const installs = selectModelInstalls(getState()).data; + + if (!installs?.find((install) => install.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.bytes = bytes; + modelImport.total_bytes = total_bytes; + modelImport.status = 'downloading'; + } + return draft; + }) + ); + } + }); + + socket.on('model_install_downloads_complete', (data) => { + log.debug({ data }, 'Model install downloads complete'); + + const { id } = data; + const installs = selectModelInstalls(getState()).data; + + if (!installs?.find((install) => install.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'downloads_done'; + } + return draft; + }) + ); + } + }); + + socket.on('model_install_complete', (data) => { + log.debug({ data }, 'Model install complete'); + + const { id } = data; + + const installs = selectModelInstalls(getState()).data; + + if (!installs?.find((install) => install.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'completed'; + } + return draft; + }) + ); + } + + dispatch(api.util.invalidateTags([{ type: 'ModelConfig', id: LIST_TAG }])); + dispatch(api.util.invalidateTags([{ type: 'ModelScanFolderResults', id: LIST_TAG }])); + }); + + socket.on('model_install_error', (data) => { + log.error({ data }, 'Model install error'); + + const { id, error, error_type } = data; + const installs = selectModelInstalls(getState()).data; + + if (!installs?.find((install) => install.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'error'; + modelImport.error_reason = error_type; + modelImport.error = error; + } + return draft; + }) + ); + } + }); + + socket.on('model_install_cancelled', (data) => { + log.warn({ data }, 'Model install cancelled'); + + const { id } = data; + const installs = selectModelInstalls(getState()).data; + + if (!installs?.find((install) => install.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'cancelled'; + } + return draft; + }) + ); + } + }); + + socket.on('queue_item_status_changed', (data) => { + // we've got new status for the queue item, batch and queue + const { + item_id, + session_id, + status, + started_at, + updated_at, + completed_at, + batch_status, + queue_status, + error_type, + error_message, + error_traceback, + origin, + } = data; + + log.debug({ data }, `Queue item ${item_id} status updated: ${status}`); + + // Update this specific queue item in the list of queue items (this is the queue item DTO, without the session) + dispatch( + queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => { + queueItemsAdapter.updateOne(draft, { + id: String(item_id), + changes: { + status, + started_at, + updated_at: updated_at ?? undefined, + completed_at: completed_at ?? undefined, + error_type, + error_message, + error_traceback, + }, + }); + }) + ); + + // Update the queue status (we do not get the processor status here) + dispatch( + queueApi.util.updateQueryData('getQueueStatus', undefined, (draft) => { + if (!draft) { + return; + } + Object.assign(draft.queue, queue_status); + }) + ); + + // Update the batch status + dispatch(queueApi.util.updateQueryData('getBatchStatus', { batch_id: batch_status.batch_id }, () => batch_status)); + + // Invalidate caches for things we cannot update + // TODO: technically, we could possibly update the current session queue item, but feels safer to just request it again + dispatch( + queueApi.util.invalidateTags([ + 'CurrentSessionQueueItem', + 'NextSessionQueueItem', + 'InvocationCacheStatus', + { type: 'SessionQueueItem', id: item_id }, + ]) + ); + + if (status === 'in_progress') { + forEach($nodeExecutionStates.get(), (nes) => { + if (!nes) { + return; + } + const clone = deepClone(nes); + clone.status = zNodeStatus.enum.PENDING; + clone.error = null; + clone.progress = null; + clone.progressImage = null; + clone.outputs = []; + $nodeExecutionStates.setKey(clone.nodeId, clone); + }); + } else if (status === 'failed' && error_type) { + const isLocal = getState().config.isLocal ?? true; + const sessionId = session_id; + $lastProgressEvent.set(null); + + if (origin === 'canvas') { + $lastCanvasProgressEvent.set(null); + } + + toast({ + id: `INVOCATION_ERROR_${error_type}`, + title: getTitleFromErrorType(error_type), + status: 'error', + duration: null, + updateDescription: isLocal, + description: ( + + ), + }); + cancellations.add(session_id); + } else if (status === 'canceled') { + $lastProgressEvent.set(null); + if (origin === 'canvas') { + $lastCanvasProgressEvent.set(null); + } + cancellations.add(session_id); + } else if (status === 'completed') { + $lastProgressEvent.set(null); + cancellations.add(session_id); + } + }); + + socket.on('queue_cleared', (data) => { + log.debug({ data }, 'Queue cleared'); + }); + + socket.on('batch_enqueued', (data) => { + log.debug({ data }, 'Batch enqueued'); + }); + + socket.on('bulk_download_started', (data) => { + log.debug({ data }, 'Bulk gallery download preparation started'); + }); + + socket.on('bulk_download_complete', (data) => { + log.debug({ data }, 'Bulk gallery download ready'); + const { bulk_download_item_name } = data; + + // TODO(psyche): This URL may break in in some environments (e.g. Nvidia workbench) but we need to test it first + const url = `/api/v1/images/download/${bulk_download_item_name}`; + + toast({ + id: bulk_download_item_name, + title: t('gallery.bulkDownloadReady', 'Download ready'), + status: 'success', + description: ( + + ), + duration: null, + }); + }); + + socket.on('bulk_download_error', (data) => { + log.error({ data }, 'Bulk gallery download error'); + + const { bulk_download_item_name, error } = data; + + toast({ + id: bulk_download_item_name, + title: t('gallery.bulkDownloadFailed'), + status: 'error', + description: error, + duration: null, + }); + }); +};