mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): rework progress event handling
- Canvas manages its own progress socket event listeners and progress event data. - Remove cancellations listener jank. - Dip into low-level redux subscription API to watch for queue status changes, clearing the last "global" progress event when the queue has nothing in progress. Could also do this in a useEffect I guess. - Had to shuffle some things around to prevent circular imports, so there are a lot of tiny changes here.
This commit is contained in:
committed by
Kent Keirsey
parent
b08a66ecaf
commit
7db4d26837
@@ -1,5 +1,4 @@
|
||||
import { Box, useGlobalModifiersInit } from '@invoke-ai/ui-library';
|
||||
import { useSocketIO } from 'app/hooks/useSocketIO';
|
||||
import { useSyncQueueStatus } from 'app/hooks/useSyncQueueStatus';
|
||||
import { useLogger } from 'app/logging/useLogger';
|
||||
import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
|
||||
@@ -31,6 +30,7 @@ import { size } from 'lodash-es';
|
||||
import { memo, useCallback, useEffect } from 'react';
|
||||
import { ErrorBoundary } from 'react-error-boundary';
|
||||
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
|
||||
import { useSocketIO } from 'services/events/useSocketIO';
|
||||
|
||||
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
|
||||
import PreselectedImage from './PreselectedImage';
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import 'i18n';
|
||||
|
||||
import type { Middleware } from '@reduxjs/toolkit';
|
||||
import { $socketOptions } from 'app/hooks/useSocketIO';
|
||||
import { $authToken } from 'app/store/nanostores/authToken';
|
||||
import { $baseUrl } from 'app/store/nanostores/baseUrl';
|
||||
import { $customNavComponent } from 'app/store/nanostores/customNavComponent';
|
||||
@@ -24,6 +23,7 @@ import type { PropsWithChildren, ReactNode } from 'react';
|
||||
import React, { lazy, memo, useEffect, useMemo } from 'react';
|
||||
import { Provider } from 'react-redux';
|
||||
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
|
||||
import { $socketOptions } from 'services/events/stores';
|
||||
import type { ManagerOptions, SocketOptions } from 'socket.io-client';
|
||||
|
||||
const App = lazy(() => import('./App'));
|
||||
|
||||
@@ -9,7 +9,6 @@ import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddlewar
|
||||
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
|
||||
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
|
||||
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
|
||||
import { addCancellationsListeners } from 'app/store/middleware/listenerMiddleware/listeners/cancellationsListeners';
|
||||
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
|
||||
import { addEnqueueRequestedNodes } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes';
|
||||
import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
|
||||
@@ -73,15 +72,6 @@ addAnyEnqueuedListener(startAppListening);
|
||||
addBatchEnqueuedListener(startAppListening);
|
||||
|
||||
// Canvas actions
|
||||
// addCanvasSavedToGalleryListener(startAppListening);
|
||||
// addCanvasMaskSavedToGalleryListener(startAppListening);
|
||||
// addCanvasImageToControlNetListener(startAppListening);
|
||||
// addCanvasMaskToControlNetListener(startAppListening);
|
||||
// addCanvasDownloadedAsImageListener(startAppListening);
|
||||
// addCanvasCopiedToClipboardListener(startAppListening);
|
||||
// addCanvasMergedListener(startAppListening);
|
||||
// addStagingAreaImageSavedListener(startAppListening);
|
||||
// addCommitStagingAreaImageListener(startAppListening);
|
||||
addStagingListeners(startAppListening);
|
||||
|
||||
// Socket.IO
|
||||
@@ -121,6 +111,3 @@ addAdHocPostProcessingRequestedListener(startAppListening);
|
||||
addDynamicPromptsListener(startAppListening);
|
||||
|
||||
addSetDefaultSettingsListener(startAppListening);
|
||||
// addControlAdapterPreprocessor(startAppListening);
|
||||
|
||||
addCancellationsListeners(startAppListening);
|
||||
|
||||
@@ -1,137 +0,0 @@
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { $lastCanvasProgressEvent } from 'features/controlLayers/store/canvasSlice';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
|
||||
/**
|
||||
* To prevent a race condition where a progress event arrives after a successful cancellation, we need to keep track of
|
||||
* cancellations:
|
||||
* - In the route handlers above, we track and update the cancellations object
|
||||
* - When the user queues a, we should reset the cancellations, also handled int he route handlers above
|
||||
* - When we get a progress event, we should check if the event is cancelled before setting the event
|
||||
*
|
||||
* We have a few ways that cancellations are effected, so we need to track them all:
|
||||
* - by queue item id (in this case, we will compare the session_id and not the item_id)
|
||||
* - by batch id
|
||||
* - by destination
|
||||
* - by clearing the queue
|
||||
*/
|
||||
type Cancellations = {
|
||||
sessionIds: Set<string>;
|
||||
batchIds: Set<string>;
|
||||
destinations: Set<string>;
|
||||
clearQueue: boolean;
|
||||
};
|
||||
|
||||
const resetCancellations = (): void => {
|
||||
cancellations.clearQueue = false;
|
||||
cancellations.sessionIds.clear();
|
||||
cancellations.batchIds.clear();
|
||||
cancellations.destinations.clear();
|
||||
};
|
||||
|
||||
const cancellations: Cancellations = {
|
||||
sessionIds: new Set(),
|
||||
batchIds: new Set(),
|
||||
destinations: new Set(),
|
||||
clearQueue: false,
|
||||
} as Readonly<Cancellations>;
|
||||
|
||||
/**
|
||||
* Checks if an item is cancelled, used to prevent race conditions with event handling.
|
||||
*
|
||||
* To use this, provide the session_id, batch_id and destination from the event payload.
|
||||
*/
|
||||
export const getIsCancelled = (item: {
|
||||
session_id: string;
|
||||
batch_id: string;
|
||||
destination?: string | null;
|
||||
}): boolean => {
|
||||
if (cancellations.clearQueue) {
|
||||
return true;
|
||||
}
|
||||
if (cancellations.sessionIds.has(item.session_id)) {
|
||||
return true;
|
||||
}
|
||||
if (cancellations.batchIds.has(item.batch_id)) {
|
||||
return true;
|
||||
}
|
||||
if (item.destination && cancellations.destinations.has(item.destination)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
export const addCancellationsListeners = (startAppListening: AppStartListening) => {
|
||||
// When we get a cancellation, we may need to clear the last progress event - next few listeners handle those cases.
|
||||
// Maybe we could use the `getIsCancelled` util here, but I think that could introduce _another_ race condition...
|
||||
startAppListening({
|
||||
matcher: queueApi.endpoints.enqueueBatch.matchFulfilled,
|
||||
effect: () => {
|
||||
resetCancellations();
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
matcher: queueApi.endpoints.cancelByBatchDestination.matchFulfilled,
|
||||
effect: (action) => {
|
||||
cancellations.destinations.add(action.meta.arg.originalArgs.destination);
|
||||
|
||||
const event = $lastCanvasProgressEvent.get();
|
||||
if (!event) {
|
||||
return;
|
||||
}
|
||||
const { session_id, batch_id, destination } = event;
|
||||
if (getIsCancelled({ session_id, batch_id, destination })) {
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
matcher: queueApi.endpoints.cancelQueueItem.matchFulfilled,
|
||||
effect: (action) => {
|
||||
cancellations.sessionIds.add(action.payload.session_id);
|
||||
|
||||
const event = $lastCanvasProgressEvent.get();
|
||||
if (!event) {
|
||||
return;
|
||||
}
|
||||
const { session_id, batch_id, destination } = event;
|
||||
if (getIsCancelled({ session_id, batch_id, destination })) {
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
matcher: queueApi.endpoints.cancelByBatchIds.matchFulfilled,
|
||||
effect: (action) => {
|
||||
for (const batch_id of action.meta.arg.originalArgs.batch_ids) {
|
||||
cancellations.batchIds.add(batch_id);
|
||||
}
|
||||
const event = $lastCanvasProgressEvent.get();
|
||||
if (!event) {
|
||||
return;
|
||||
}
|
||||
const { session_id, batch_id, destination } = event;
|
||||
if (getIsCancelled({ session_id, batch_id, destination })) {
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
matcher: queueApi.endpoints.clearQueue.matchFulfilled,
|
||||
effect: () => {
|
||||
cancellations.clearQueue = true;
|
||||
const event = $lastCanvasProgressEvent.get();
|
||||
if (!event) {
|
||||
return;
|
||||
}
|
||||
const { session_id, batch_id, destination } = event;
|
||||
if (getIsCancelled({ session_id, batch_id, destination })) {
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -66,7 +66,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
const destination = state.canvasSettings.sendToCanvas ? 'canvas' : 'gallery';
|
||||
|
||||
const prepareBatchResult = withResult(() =>
|
||||
prepareLinearUIBatch(state, g, prepend, noise, posCond, 'generation', destination)
|
||||
prepareLinearUIBatch(state, g, prepend, noise, posCond, 'canvas', destination)
|
||||
);
|
||||
|
||||
if (prepareBatchResult.isErr()) {
|
||||
|
||||
@@ -15,7 +15,8 @@ import { getPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilder
|
||||
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
|
||||
import { stylePresetsApi } from 'services/api/endpoints/stylePresets';
|
||||
import { utilitiesApi } from 'services/api/endpoints/utilities';
|
||||
import { socketConnected } from 'services/events/setEventListeners';
|
||||
|
||||
import { socketConnected } from './socketConnected';
|
||||
|
||||
const matcher = isAnyOf(
|
||||
positivePromptChanged,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { $baseUrl } from 'app/store/nanostores/baseUrl';
|
||||
@@ -6,11 +7,11 @@ import { atom } from 'nanostores';
|
||||
import { api } from 'services/api';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
import { queueApi, selectQueueStatus } from 'services/api/endpoints/queue';
|
||||
import { socketConnected } from 'services/events/setEventListeners';
|
||||
|
||||
const log = logger('events');
|
||||
|
||||
const $isFirstConnection = atom(true);
|
||||
export const socketConnected = createAction('socket/connected');
|
||||
|
||||
export const addSocketConnectedEventListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { $true } from 'app/store/nanostores/util';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
@@ -21,6 +20,7 @@ import i18n from 'i18next';
|
||||
import { forEach, upperFirst } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import { getConnectedEdges } from 'reactflow';
|
||||
import { $isConnected } from 'services/events/stores';
|
||||
|
||||
const LAYER_TYPE_TO_TKEY = {
|
||||
reference_image: 'controlLayers.referenceImage',
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { Flex, useShiftModifier } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
|
||||
@@ -16,6 +15,7 @@ import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowCounterClockwiseBold, PiRulerBold } from 'react-icons/pi';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO, PostUploadAction } from 'services/api/types';
|
||||
import { $isConnected } from 'services/events/stores';
|
||||
|
||||
type Props = {
|
||||
image: ImageWithDims | null;
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $socket } from 'app/hooks/useSocketIO';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { useAppStore } from 'app/store/nanostores/store';
|
||||
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
|
||||
@@ -7,6 +6,7 @@ import { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { $canvasManager } from 'features/controlLayers/store/canvasSlice';
|
||||
import Konva from 'konva';
|
||||
import { useLayoutEffect, useState } from 'react';
|
||||
import { $socket } from 'services/events/stores';
|
||||
import { useDevicePixelRatio } from 'use-device-pixel-ratio';
|
||||
|
||||
const log = logger('canvas');
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import type { AppSocket } from 'app/hooks/useSocketIO';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStore } from 'app/store/store';
|
||||
import type { SerializableObject } from 'common/types';
|
||||
@@ -31,6 +30,7 @@ import Konva from 'konva';
|
||||
import type { Atom } from 'nanostores';
|
||||
import { computed } from 'nanostores';
|
||||
import type { Logger } from 'roarr';
|
||||
import type { AppSocket } from 'services/events/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { CanvasBackgroundModule } from './CanvasBackgroundModule';
|
||||
|
||||
@@ -4,7 +4,10 @@ import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'
|
||||
import { getPrefixedId, loadImage } from 'features/controlLayers/konva/util';
|
||||
import { selectShowProgressOnCanvas } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import Konva from 'konva';
|
||||
import { atom } from 'nanostores';
|
||||
import type { Logger } from 'roarr';
|
||||
import { selectCanvasQueueCounts } from 'services/api/endpoints/queue';
|
||||
import type { S } from 'services/api/types';
|
||||
|
||||
export class CanvasProgressImageModule extends CanvasModuleBase {
|
||||
readonly type = 'progress_image';
|
||||
@@ -23,7 +26,8 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
|
||||
imageElement: HTMLImageElement | null = null;
|
||||
|
||||
subscriptions = new Set<() => void>();
|
||||
|
||||
$lastProgressEvent = atom<S['InvocationDenoiseProgressEvent'] | null>(null);
|
||||
hasActiveGeneration: boolean = false;
|
||||
mutex: Mutex = new Mutex();
|
||||
|
||||
constructor(manager: CanvasManager) {
|
||||
@@ -41,11 +45,50 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
|
||||
image: null,
|
||||
};
|
||||
|
||||
this.subscriptions.add(this.manager.stateApi.$lastCanvasProgressEvent.listen(this.render));
|
||||
this.subscriptions.add(this.manager.stagingArea.$shouldShowStagedImage.listen(this.render));
|
||||
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectShowProgressOnCanvas, this.render));
|
||||
this.subscriptions.add(this.setSocketEventListeners());
|
||||
this.subscriptions.add(
|
||||
this.manager.stateApi.createStoreSubscription(selectCanvasQueueCounts, ({ data }) => {
|
||||
if (data && (data.in_progress > 0 || data.pending > 0)) {
|
||||
this.hasActiveGeneration = true;
|
||||
} else {
|
||||
this.hasActiveGeneration = false;
|
||||
this.$lastProgressEvent.set(null);
|
||||
}
|
||||
})
|
||||
);
|
||||
this.subscriptions.add(this.$lastProgressEvent.listen(this.render));
|
||||
}
|
||||
|
||||
setSocketEventListeners = (): (() => void) => {
|
||||
const progressListener = (data: S['InvocationDenoiseProgressEvent']) => {
|
||||
if (data.destination !== 'canvas') {
|
||||
return;
|
||||
}
|
||||
if (!this.hasActiveGeneration) {
|
||||
return;
|
||||
}
|
||||
this.$lastProgressEvent.set(data);
|
||||
};
|
||||
|
||||
const clearProgress = () => {
|
||||
this.$lastProgressEvent.set(null);
|
||||
};
|
||||
|
||||
this.manager.socket.on('invocation_denoise_progress', progressListener);
|
||||
this.manager.socket.on('connect', clearProgress);
|
||||
this.manager.socket.on('connect_error', clearProgress);
|
||||
this.manager.socket.on('disconnect', clearProgress);
|
||||
|
||||
return () => {
|
||||
this.manager.socket.off('invocation_denoise_progress', progressListener);
|
||||
this.manager.socket.off('connect', clearProgress);
|
||||
this.manager.socket.off('connect_error', clearProgress);
|
||||
this.manager.socket.off('disconnect', clearProgress);
|
||||
};
|
||||
};
|
||||
|
||||
getNodes = () => {
|
||||
return [this.konva.group];
|
||||
};
|
||||
@@ -53,7 +96,7 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
|
||||
render = async () => {
|
||||
const release = await this.mutex.acquire();
|
||||
|
||||
const event = this.manager.stateApi.$lastCanvasProgressEvent.get();
|
||||
const event = this.$lastProgressEvent.get();
|
||||
const showProgressOnCanvas = this.manager.stateApi.runSelector(selectShowProgressOnCanvas);
|
||||
|
||||
if (!event || !showProgressOnCanvas) {
|
||||
|
||||
@@ -96,7 +96,7 @@ export class CanvasStagingAreaModule extends CanvasModuleBase {
|
||||
|
||||
if (!this.image.isLoading && !this.image.isError) {
|
||||
await this.image.update({ ...this.image.state, image: imageDTOToImageWithDims(imageDTO) }, true);
|
||||
this.manager.stateApi.$lastCanvasProgressEvent.set(null);
|
||||
this.manager.progressImage.$lastProgressEvent.set(null);
|
||||
}
|
||||
this.image.konva.group.visible(shouldShowStagedImage);
|
||||
} else {
|
||||
|
||||
@@ -15,7 +15,6 @@ import {
|
||||
settingsEraserWidthChanged,
|
||||
} from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import {
|
||||
$lastCanvasProgressEvent,
|
||||
bboxChangedFromCanvas,
|
||||
entityBrushLineAdded,
|
||||
entityEraserLineAdded,
|
||||
@@ -382,12 +381,6 @@ export class CanvasStateApiModule extends CanvasModuleBase {
|
||||
*/
|
||||
$isRasterizing = computed(this.$rasterizingAdapter, (rasterizingAdapter) => Boolean(rasterizingAdapter));
|
||||
|
||||
/**
|
||||
* The last canvas progress event. This is set in a global event listener. The staging area may set it to null when it
|
||||
* consumes the event.
|
||||
*/
|
||||
$lastCanvasProgressEvent = $lastCanvasProgressEvent;
|
||||
|
||||
/**
|
||||
* Whether the space key is currently pressed.
|
||||
*/
|
||||
|
||||
@@ -30,13 +30,7 @@ import type { IRect } from 'konva/lib/types';
|
||||
import { merge, omit } from 'lodash-es';
|
||||
import { atom } from 'nanostores';
|
||||
import type { UndoableOptions } from 'redux-undo';
|
||||
import type {
|
||||
ControlNetModelConfig,
|
||||
ImageDTO,
|
||||
IPAdapterModelConfig,
|
||||
S,
|
||||
T2IAdapterModelConfig,
|
||||
} from 'services/api/types';
|
||||
import type { ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import type {
|
||||
@@ -1236,7 +1230,6 @@ function actionsThrottlingFilter(action: UnknownAction) {
|
||||
return true;
|
||||
}
|
||||
|
||||
export const $lastCanvasProgressEvent = atom<S['InvocationDenoiseProgressEvent'] | null>(null);
|
||||
/**
|
||||
* The global canvas manager instance.
|
||||
*/
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import type { IconButtonProps } from '@invoke-ai/ui-library';
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectSelectionCount } from 'features/gallery/store/gallerySelectors';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiTrashSimpleBold } from 'react-icons/pi';
|
||||
import { $isConnected } from 'services/events/stores';
|
||||
|
||||
type DeleteImageButtonProps = Omit<IconButtonProps, 'aria-label'> & {
|
||||
onClick: () => void;
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { ButtonGroup, IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { adHocPostProcessingRequested } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { INTERACTION_SCOPES } from 'common/hooks/interactionScopes';
|
||||
@@ -30,7 +29,7 @@ import {
|
||||
PiRulerBold,
|
||||
} from 'react-icons/pi';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { $progressImage } from 'services/events/setEventListeners';
|
||||
import { $isConnected, $progressImage } from 'services/events/stores';
|
||||
|
||||
const CurrentImageButtons = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
@@ -15,7 +15,7 @@ import { memo, useCallback, useMemo, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiImageBold } from 'react-icons/pi';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { $hasProgress, $isProgressFromCanvas } from 'services/events/setEventListeners';
|
||||
import { $hasProgress, $isProgressFromCanvas } from 'services/events/stores';
|
||||
|
||||
import ProgressImage from './ProgressImage';
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectSystemSlice } from 'features/system/store/systemSlice';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { $isProgressFromCanvas, $progressImage } from 'services/events/setEventListeners';
|
||||
import { $isProgressFromCanvas, $progressImage } from 'services/events/stores';
|
||||
|
||||
const selectShouldAntialiasProgressImage = createSelector(
|
||||
selectSystemSlice,
|
||||
|
||||
@@ -13,7 +13,7 @@ import type { CSSProperties, PropsWithChildren } from 'react';
|
||||
import { memo, useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { NodeProps } from 'reactflow';
|
||||
import { $lastProgressEvent } from 'services/events/setEventListeners';
|
||||
import { $lastProgressEvent } from 'services/events/stores';
|
||||
|
||||
const CurrentImageNode = (props: NodeProps) => {
|
||||
const imageDTO = useAppSelector(selectLastSelectedImage);
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
|
||||
@@ -13,6 +12,7 @@ import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import type { PostUploadAction } from 'services/api/types';
|
||||
import { $isConnected } from 'services/events/stores';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ export const prepareLinearUIBatch = (
|
||||
prepend: boolean,
|
||||
noise: Invocation<'noise' | 'flux_denoise'>,
|
||||
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>,
|
||||
origin: 'generation' | 'workflows' | 'upscaling',
|
||||
origin: 'canvas' | 'workflows' | 'upscaling',
|
||||
destination: 'canvas' | 'gallery'
|
||||
): BatchConfig => {
|
||||
const { iterations, model, shouldRandomizeSeed, seed, shouldConcatPrompts } = state.params;
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { ConfirmationAlertDialog, Text } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { buildUseBoolean } from 'common/hooks/useBoolean';
|
||||
import { listCursorChanged, listPriorityChanged } from 'features/queue/store/queueSlice';
|
||||
@@ -8,6 +7,7 @@ import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useClearQueueMutation, useGetQueueStatusQuery } from 'services/api/endpoints/queue';
|
||||
import { $isConnected } from 'services/events/stores';
|
||||
|
||||
const [useClearQueueConfirmationAlertDialog] = buildUseBoolean(false);
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useCancelByBatchIdsMutation, useGetBatchStatusQuery } from 'services/api/endpoints/queue';
|
||||
import { $isConnected } from 'services/events/stores';
|
||||
|
||||
export const useCancelBatch = (batch_id: string) => {
|
||||
const isConnected = useStore($isConnected);
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useCancelQueueItemMutation, useGetQueueStatusQuery } from 'services/api/endpoints/queue';
|
||||
import { $isConnected } from 'services/events/stores';
|
||||
|
||||
export const useCancelCurrentQueueItem = () => {
|
||||
const isConnected = useStore($isConnected);
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useCancelQueueItemMutation } from 'services/api/endpoints/queue';
|
||||
import { $isConnected } from 'services/events/stores';
|
||||
|
||||
export const useCancelQueueItem = (item_id: number) => {
|
||||
const isConnected = useStore($isConnected);
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useClearInvocationCacheMutation, useGetInvocationCacheStatusQuery } from 'services/api/endpoints/appInfo';
|
||||
import { $isConnected } from 'services/events/stores';
|
||||
|
||||
export const useClearInvocationCache = () => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDisableInvocationCacheMutation, useGetInvocationCacheStatusQuery } from 'services/api/endpoints/appInfo';
|
||||
import { $isConnected } from 'services/events/stores';
|
||||
|
||||
export const useDisableInvocationCache = () => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useEnableInvocationCacheMutation, useGetInvocationCacheStatusQuery } from 'services/api/endpoints/appInfo';
|
||||
import { $isConnected } from 'services/events/stores';
|
||||
|
||||
export const useEnableInvocationCache = () => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetQueueStatusQuery, usePauseProcessorMutation } from 'services/api/endpoints/queue';
|
||||
import { $isConnected } from 'services/events/stores';
|
||||
|
||||
export const usePauseProcessor = () => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { listCursorChanged, listPriorityChanged } from 'features/queue/store/queueSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetQueueStatusQuery, usePruneQueueMutation } from 'services/api/endpoints/queue';
|
||||
import { $isConnected } from 'services/events/stores';
|
||||
|
||||
export const usePruneQueue = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetQueueStatusQuery, useResumeProcessorMutation } from 'services/api/endpoints/queue';
|
||||
import { $isConnected } from 'services/events/stores';
|
||||
|
||||
export const useResumeProcessor = () => {
|
||||
const isConnected = useStore($isConnected);
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import { Progress } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { useCurrentDestination } from 'features/queue/hooks/useCurrentDestination';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
|
||||
import { $lastProgressEvent } from 'services/events/setEventListeners';
|
||||
import { $isConnected, $lastProgressEvent } from 'services/events/stores';
|
||||
|
||||
const ProgressBar = () => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { Icon, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiWarningBold } from 'react-icons/pi';
|
||||
import { $isConnected } from 'services/events/stores';
|
||||
|
||||
const StatusIndicator = () => {
|
||||
const isConnected = useStore($isConnected);
|
||||
|
||||
@@ -78,15 +78,17 @@ export const queueApi = api.injectEndpoints({
|
||||
resetListQueryData(dispatch);
|
||||
/**
|
||||
* When a batch is enqueued, we need to update the queue status. While it might be templting to invalidate the
|
||||
* `SessionQueueStatus` tag here, this can introduce a race condition:
|
||||
* `SessionQueueStatus` tag here, this can introduce a race condition when the queue item executes quickly:
|
||||
*
|
||||
* - Enqueue batch via this query
|
||||
* - Enqueue via this query
|
||||
* - On success, we invalidate `SessionQueueStatus` tag - network request sent to server
|
||||
* - Network request received, response preparing/sending
|
||||
* - A queue item status changes and we receive a socket event w/ updated status
|
||||
* - Update status optimistically in socket handler
|
||||
* - Tag invalidation response received, but by now its payload has stale data
|
||||
* - Stale data is written to the cache
|
||||
* - The server gets the queue status request and responds, but this takes some time... in the meantime:
|
||||
* - The new queue item starts executing, and we receive a socket queue item status changed event
|
||||
* - We optimistically update the queue status in the queue item status changed socket handler
|
||||
* - At this point, the queue status is correct
|
||||
* - Finally, we get the queue status from the tag invalidation request - but it's reporting the queue status
|
||||
* from _before_ the last queue event
|
||||
* - The queue status is now incorrect!
|
||||
*
|
||||
* Ok, what if we just never did optimistic updates and invalidated the tag in the queue event handlers instead?
|
||||
* It's much simpler that way, but it causes a lot of network requests - 3 per queue item, as it moves from
|
||||
@@ -94,7 +96,18 @@ export const queueApi = api.injectEndpoints({
|
||||
*
|
||||
* We can do a bit of extra work here, incrementing the pending and total counts in the queue status, and do
|
||||
* similar optimistic updates in the socket handler. Because this optimistic update runs immediately after the
|
||||
* enqueue network request, it should always occur _before_ the next queue event, so no race condition.
|
||||
* enqueue network request, it should always occur _before_ the next queue event, so no race condition:
|
||||
*
|
||||
* - Enqueue batch via this query
|
||||
* - On success, optimistically update - this happens immediately on the HTTP OK - before the next queue event
|
||||
* - At this point, the queue status is correct
|
||||
* - A queue item status changes and we receive a socket event w/ updated status
|
||||
* - Update status optimistically in socket handler
|
||||
* - Queue status is still correct
|
||||
*
|
||||
* This problem occurs most commonly with canvas filters like Canny edge detection, which are single-node
|
||||
* graphs that execute very quickly. Image generation graphs take long enough to not trigger this race
|
||||
* condition - even when all nodes are cached on the server.
|
||||
*/
|
||||
dispatch(
|
||||
queueApi.util.updateQueryData('getQueueStatus', undefined, (draft) => {
|
||||
|
||||
@@ -17,13 +17,9 @@ const isCanvasOutputNode = (data: S['InvocationCompleteEvent']) => {
|
||||
return data.invocation_source_id.split(':')[0] === 'canvas_output';
|
||||
};
|
||||
|
||||
export const buildOnInvocationComplete = (
|
||||
getState: () => RootState,
|
||||
dispatch: AppDispatch,
|
||||
nodeTypeDenylist: string[],
|
||||
setLastProgressEvent: (event: S['InvocationDenoiseProgressEvent'] | null) => void,
|
||||
setLastCanvasProgressEvent: (event: S['InvocationDenoiseProgressEvent'] | null) => void
|
||||
) => {
|
||||
const nodeTypeDenylist = ['load_image', 'image'];
|
||||
|
||||
export const buildOnInvocationComplete = (getState: () => RootState, dispatch: AppDispatch) => {
|
||||
const addImageToGallery = (imageDTO: ImageDTO) => {
|
||||
if (imageDTO.is_intermediate) {
|
||||
return;
|
||||
@@ -113,7 +109,7 @@ export const buildOnInvocationComplete = (
|
||||
}
|
||||
};
|
||||
|
||||
const handleOriginGeneration = async (data: S['InvocationCompleteEvent']) => {
|
||||
const handleOriginCanvas = async (data: S['InvocationCompleteEvent']) => {
|
||||
const imageDTO = await getResultImageDTO(data);
|
||||
|
||||
if (!imageDTO) {
|
||||
@@ -121,6 +117,7 @@ export const buildOnInvocationComplete = (
|
||||
}
|
||||
|
||||
if (data.destination === 'canvas') {
|
||||
// TODO(psyche): Can/should we let canvas handle this itself?
|
||||
if (isCanvasOutputNode(data)) {
|
||||
if (data.result.type === 'canvas_v2_mask_and_crop_output') {
|
||||
const { offset_x, offset_y } = data.result;
|
||||
@@ -131,8 +128,7 @@ export const buildOnInvocationComplete = (
|
||||
addImageToGallery(imageDTO);
|
||||
}
|
||||
} else if (!imageDTO.is_intermediate) {
|
||||
// session.mode === 'generate'
|
||||
setLastCanvasProgressEvent(null);
|
||||
// Desintaion is gallery
|
||||
addImageToGallery(imageDTO);
|
||||
}
|
||||
};
|
||||
@@ -151,15 +147,17 @@ export const buildOnInvocationComplete = (
|
||||
`Invocation complete (${data.invocation.type}, ${data.invocation_source_id})`
|
||||
);
|
||||
|
||||
// Update the node execution states - the image output is handled below
|
||||
if (nodeTypeDenylist.includes(data.invocation.type)) {
|
||||
log.trace('Skipping node type denylisted');
|
||||
return;
|
||||
}
|
||||
|
||||
if (data.origin === 'workflows') {
|
||||
await handleOriginWorkflows(data);
|
||||
} else if (data.origin === 'generation') {
|
||||
await handleOriginGeneration(data);
|
||||
} else if (data.origin === 'canvas') {
|
||||
await handleOriginCanvas(data);
|
||||
} else {
|
||||
await handleOriginOther(data);
|
||||
}
|
||||
|
||||
setLastProgressEvent(null);
|
||||
};
|
||||
};
|
||||
|
||||
@@ -1,48 +1,40 @@
|
||||
import { ExternalLink } from '@invoke-ai/ui-library';
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { getIsCancelled } from 'app/store/middleware/listenerMiddleware/listeners/cancellationsListeners';
|
||||
import { socketConnected } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
|
||||
import { $baseUrl } from 'app/store/nanostores/baseUrl';
|
||||
import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId';
|
||||
import { $queueId } from 'app/store/nanostores/queueId';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
import type { AppStore } from 'app/store/store';
|
||||
import type { SerializableObject } from 'common/types';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { $lastCanvasProgressEvent } from 'features/controlLayers/store/canvasSlice';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { atom, computed } from 'nanostores';
|
||||
import { api, LIST_TAG } from 'services/api';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
|
||||
import type { S } from 'services/api/types';
|
||||
import { buildOnInvocationComplete } from 'services/events/onInvocationComplete';
|
||||
import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types';
|
||||
import type { Socket } from 'socket.io-client';
|
||||
|
||||
export const socketConnected = createAction('socket/connected');
|
||||
import { $lastProgressEvent } from './stores';
|
||||
|
||||
const log = logger('events');
|
||||
|
||||
type SetEventListenersArg = {
|
||||
socket: Socket<ServerToClientEvents, ClientToServerEvents>;
|
||||
dispatch: AppDispatch;
|
||||
getState: () => RootState;
|
||||
store: AppStore;
|
||||
setIsConnected: (isConnected: boolean) => void;
|
||||
};
|
||||
|
||||
const selectModelInstalls = modelsApi.endpoints.listModelInstalls.select();
|
||||
const nodeTypeDenylist = ['load_image', 'image'];
|
||||
export const $lastProgressEvent = atom<S['InvocationDenoiseProgressEvent'] | null>(null);
|
||||
export const $hasProgress = computed($lastProgressEvent, (val) => Boolean(val));
|
||||
export const $progressImage = computed($lastProgressEvent, (val) => val?.progress_image ?? null);
|
||||
export const $isProgressFromCanvas = computed($lastProgressEvent, (val) => val?.destination === 'canvas');
|
||||
|
||||
export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }: SetEventListenersArg) => {
|
||||
export const setEventListeners = ({ socket, store, setIsConnected }: SetEventListenersArg) => {
|
||||
const { dispatch, getState } = store;
|
||||
|
||||
socket.on('connect', () => {
|
||||
log.debug('Connected');
|
||||
setIsConnected(true);
|
||||
@@ -54,14 +46,12 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
|
||||
socket.emit('subscribe_bulk_download', { bulk_download_id });
|
||||
}
|
||||
$lastProgressEvent.set(null);
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
});
|
||||
|
||||
socket.on('connect_error', (error) => {
|
||||
log.debug('Connect error');
|
||||
setIsConnected(false);
|
||||
$lastProgressEvent.set(null);
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
if (error && error.message) {
|
||||
const data: string | undefined = (error as unknown as { data: string | undefined }).data;
|
||||
if (data === 'ERR_UNAUTHENTICATED') {
|
||||
@@ -78,7 +68,6 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
|
||||
socket.on('disconnect', () => {
|
||||
log.debug('Disconnected');
|
||||
$lastProgressEvent.set(null);
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
setIsConnected(false);
|
||||
});
|
||||
|
||||
@@ -93,24 +82,7 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
|
||||
});
|
||||
|
||||
socket.on('invocation_denoise_progress', (data) => {
|
||||
const {
|
||||
invocation_source_id,
|
||||
invocation,
|
||||
step,
|
||||
total_steps,
|
||||
progress_image,
|
||||
origin,
|
||||
destination,
|
||||
percentage,
|
||||
session_id,
|
||||
batch_id,
|
||||
} = data;
|
||||
|
||||
if (getIsCancelled({ session_id, batch_id, destination })) {
|
||||
// Do not update the progress if this session has been cancelled. This prevents a race condition where we get a
|
||||
// progress update after the session has been cancelled.
|
||||
return;
|
||||
}
|
||||
const { invocation_source_id, invocation, step, total_steps, progress_image, origin, percentage } = data;
|
||||
|
||||
log.trace(
|
||||
{ data } as SerializableObject,
|
||||
@@ -128,11 +100,6 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
|
||||
upsertExecutionState(nes.nodeId, nes);
|
||||
}
|
||||
}
|
||||
|
||||
// This event is only relevant for the canvas
|
||||
if (destination === 'canvas') {
|
||||
$lastCanvasProgressEvent.set(data);
|
||||
}
|
||||
});
|
||||
|
||||
socket.on('invocation_error', (data) => {
|
||||
@@ -152,13 +119,7 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
|
||||
}
|
||||
});
|
||||
|
||||
const onInvocationComplete = buildOnInvocationComplete(
|
||||
getState,
|
||||
dispatch,
|
||||
nodeTypeDenylist,
|
||||
$lastProgressEvent.set,
|
||||
$lastCanvasProgressEvent.set
|
||||
);
|
||||
const onInvocationComplete = buildOnInvocationComplete(getState, dispatch);
|
||||
socket.on('invocation_complete', onInvocationComplete);
|
||||
|
||||
socket.on('model_load_started', (data) => {
|
||||
@@ -379,7 +340,6 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
|
||||
error_type,
|
||||
error_message,
|
||||
error_traceback,
|
||||
origin,
|
||||
} = data;
|
||||
|
||||
log.debug({ data }, `Queue item ${item_id} status updated: ${status}`);
|
||||
@@ -402,12 +362,17 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
|
||||
})
|
||||
);
|
||||
|
||||
// Update the queue status (we do not get the processor status here)
|
||||
// Optimistic update of the queue status. We prefer to do an optimistic update over tag invalidation due to the
|
||||
// frequency of `queue_item_status_changed` events.
|
||||
dispatch(
|
||||
queueApi.util.updateQueryData('getQueueStatus', undefined, (draft) => {
|
||||
if (!draft) {
|
||||
return;
|
||||
}
|
||||
/**
|
||||
* Update the queue status - though the getQueueStatus query response contains the processor status (i.e. running
|
||||
* or paused), that data is not provided in the event we are handling. So we can only update `draft.queue` here.
|
||||
*/
|
||||
Object.assign(draft.queue, queue_status);
|
||||
})
|
||||
);
|
||||
@@ -442,11 +407,6 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
|
||||
} else if (status === 'failed' && error_type) {
|
||||
const isLocal = getState().config.isLocal ?? true;
|
||||
const sessionId = session_id;
|
||||
$lastProgressEvent.set(null);
|
||||
|
||||
if (origin === 'canvas') {
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
}
|
||||
|
||||
toast({
|
||||
id: `INVOCATION_ERROR_${error_type}`,
|
||||
@@ -463,13 +423,6 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
|
||||
/>
|
||||
),
|
||||
});
|
||||
} else if (status === 'canceled') {
|
||||
$lastProgressEvent.set(null);
|
||||
if (origin === 'canvas') {
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
}
|
||||
} else if (status === 'completed') {
|
||||
$lastProgressEvent.set(null);
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
12
invokeai/frontend/web/src/services/events/stores.ts
Normal file
12
invokeai/frontend/web/src/services/events/stores.ts
Normal file
@@ -0,0 +1,12 @@
|
||||
import { atom, computed, map } from 'nanostores';
|
||||
import type { S } from 'services/api/types';
|
||||
import type { AppSocket } from 'services/events/types';
|
||||
import type { ManagerOptions, SocketOptions } from 'socket.io-client';
|
||||
|
||||
export const $socket = atom<AppSocket | null>(null);
|
||||
export const $socketOptions = map<Partial<ManagerOptions & SocketOptions>>({});
|
||||
export const $isConnected = atom<boolean>(false);
|
||||
export const $lastProgressEvent = atom<S['InvocationDenoiseProgressEvent'] | null>(null);
|
||||
export const $hasProgress = computed($lastProgressEvent, (val) => Boolean(val));
|
||||
export const $progressImage = computed($lastProgressEvent, (val) => val?.progress_image ?? null);
|
||||
export const $isProgressFromCanvas = computed($lastProgressEvent, (val) => val?.destination === 'canvas');
|
||||
@@ -1,4 +1,5 @@
|
||||
import type { S } from 'services/api/types';
|
||||
import type { Socket } from 'socket.io-client';
|
||||
|
||||
type ClientEmitSubscribeQueue = { queue_id: string };
|
||||
type ClientEmitUnsubscribeQueue = ClientEmitSubscribeQueue;
|
||||
@@ -40,3 +41,5 @@ export type ClientToServerEvents = {
|
||||
subscribe_bulk_download: (payload: ClientEmitSubscribeBulkDownload) => void;
|
||||
unsubscribe_bulk_download: (payload: ClientEmitUnsubscribeBulkDownload) => void;
|
||||
};
|
||||
|
||||
export type AppSocket = Socket<ServerToClientEvents, ClientToServerEvents>;
|
||||
|
||||
@@ -3,14 +3,17 @@ import { $authToken } from 'app/store/nanostores/authToken';
|
||||
import { $baseUrl } from 'app/store/nanostores/baseUrl';
|
||||
import { $isDebugging } from 'app/store/nanostores/isDebugging';
|
||||
import { useAppStore } from 'app/store/nanostores/store';
|
||||
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
|
||||
import type { MapStore } from 'nanostores';
|
||||
import { atom, map } from 'nanostores';
|
||||
import { useEffect, useMemo } from 'react';
|
||||
import { selectQueueStatus } from 'services/api/endpoints/queue';
|
||||
import { setEventListeners } from 'services/events/setEventListeners';
|
||||
import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types';
|
||||
import type { ManagerOptions, Socket, SocketOptions } from 'socket.io-client';
|
||||
import type { AppSocket } from 'services/events/types';
|
||||
import type { ManagerOptions, SocketOptions } from 'socket.io-client';
|
||||
import { io } from 'socket.io-client';
|
||||
|
||||
import { $isConnected, $lastProgressEvent, $socket, $socketOptions } from './stores';
|
||||
|
||||
// Inject socket options and url into window for debugging
|
||||
declare global {
|
||||
interface Window {
|
||||
@@ -18,19 +21,12 @@ declare global {
|
||||
}
|
||||
}
|
||||
|
||||
export type AppSocket = Socket<ServerToClientEvents, ClientToServerEvents>;
|
||||
|
||||
export const $socket = atom<AppSocket | null>(null);
|
||||
export const $socketOptions = map<Partial<ManagerOptions & SocketOptions>>({});
|
||||
|
||||
const $isSocketInitialized = atom<boolean>(false);
|
||||
export const $isConnected = atom<boolean>(false);
|
||||
|
||||
/**
|
||||
* Initializes the socket.io connection and sets up event listeners.
|
||||
*/
|
||||
export const useSocketIO = () => {
|
||||
const { dispatch, getState } = useAppStore();
|
||||
useAssertSingleton('useSocketIO');
|
||||
const store = useAppStore();
|
||||
const baseUrl = useStore($baseUrl);
|
||||
const authToken = useStore($authToken);
|
||||
const addlSocketOptions = useStore($socketOptions);
|
||||
@@ -61,14 +57,11 @@ export const useSocketIO = () => {
|
||||
}, [authToken, addlSocketOptions, baseUrl]);
|
||||
|
||||
useEffect(() => {
|
||||
if ($isSocketInitialized.get()) {
|
||||
// Singleton!
|
||||
return;
|
||||
}
|
||||
|
||||
const socket: AppSocket = io(socketUrl, socketOptions);
|
||||
$socket.set(socket);
|
||||
setEventListeners({ socket, dispatch, getState, setIsConnected: $isConnected.set });
|
||||
|
||||
setEventListeners({ socket, store, setIsConnected: $isConnected.set });
|
||||
|
||||
socket.connect();
|
||||
|
||||
if ($isDebugging.get() || import.meta.env.MODE === 'development') {
|
||||
@@ -78,7 +71,12 @@ export const useSocketIO = () => {
|
||||
console.log('Socket initialized', socket);
|
||||
}
|
||||
|
||||
$isSocketInitialized.set(true);
|
||||
const unsubscribeQueueStatusListener = store.subscribe(() => {
|
||||
const queueStatusData = selectQueueStatus(store.getState()).data;
|
||||
if (!queueStatusData || queueStatusData.queue.in_progress === 0) {
|
||||
$lastProgressEvent.set(null);
|
||||
}
|
||||
});
|
||||
|
||||
return () => {
|
||||
if ($isDebugging.get() || import.meta.env.MODE === 'development') {
|
||||
@@ -87,8 +85,8 @@ export const useSocketIO = () => {
|
||||
/* eslint-disable-next-line no-console */
|
||||
console.log('Socket teardown', socket);
|
||||
}
|
||||
unsubscribeQueueStatusListener();
|
||||
socket.disconnect();
|
||||
$isSocketInitialized.set(false);
|
||||
};
|
||||
}, [dispatch, getState, socketOptions, socketUrl]);
|
||||
}, [socketOptions, socketUrl, store]);
|
||||
};
|
||||
Reference in New Issue
Block a user