feat(ui): rework progress event handling

- Canvas manages its own progress socket event listeners and progress event data.
- Remove cancellations listener jank.
- Dip into low-level redux subscription API to watch for queue status changes, clearing the last "global" progress event when the queue has nothing in progress. Could also do this in a useEffect I guess.
- Had to shuffle some things around to prevent circular imports, so there are a lot of tiny changes here.
This commit is contained in:
psychedelicious
2024-09-17 22:42:53 +10:00
committed by Kent Keirsey
parent b08a66ecaf
commit 7db4d26837
40 changed files with 161 additions and 305 deletions

View File

@@ -1,5 +1,4 @@
import { Box, useGlobalModifiersInit } from '@invoke-ai/ui-library';
import { useSocketIO } from 'app/hooks/useSocketIO';
import { useSyncQueueStatus } from 'app/hooks/useSyncQueueStatus';
import { useLogger } from 'app/logging/useLogger';
import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
@@ -31,6 +30,7 @@ import { size } from 'lodash-es';
import { memo, useCallback, useEffect } from 'react';
import { ErrorBoundary } from 'react-error-boundary';
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
import { useSocketIO } from 'services/events/useSocketIO';
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
import PreselectedImage from './PreselectedImage';

View File

@@ -1,7 +1,6 @@
import 'i18n';
import type { Middleware } from '@reduxjs/toolkit';
import { $socketOptions } from 'app/hooks/useSocketIO';
import { $authToken } from 'app/store/nanostores/authToken';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $customNavComponent } from 'app/store/nanostores/customNavComponent';
@@ -24,6 +23,7 @@ import type { PropsWithChildren, ReactNode } from 'react';
import React, { lazy, memo, useEffect, useMemo } from 'react';
import { Provider } from 'react-redux';
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
import { $socketOptions } from 'services/events/stores';
import type { ManagerOptions, SocketOptions } from 'socket.io-client';
const App = lazy(() => import('./App'));

View File

@@ -9,7 +9,6 @@ import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddlewar
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
import { addCancellationsListeners } from 'app/store/middleware/listenerMiddleware/listeners/cancellationsListeners';
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
import { addEnqueueRequestedNodes } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes';
import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
@@ -73,15 +72,6 @@ addAnyEnqueuedListener(startAppListening);
addBatchEnqueuedListener(startAppListening);
// Canvas actions
// addCanvasSavedToGalleryListener(startAppListening);
// addCanvasMaskSavedToGalleryListener(startAppListening);
// addCanvasImageToControlNetListener(startAppListening);
// addCanvasMaskToControlNetListener(startAppListening);
// addCanvasDownloadedAsImageListener(startAppListening);
// addCanvasCopiedToClipboardListener(startAppListening);
// addCanvasMergedListener(startAppListening);
// addStagingAreaImageSavedListener(startAppListening);
// addCommitStagingAreaImageListener(startAppListening);
addStagingListeners(startAppListening);
// Socket.IO
@@ -121,6 +111,3 @@ addAdHocPostProcessingRequestedListener(startAppListening);
addDynamicPromptsListener(startAppListening);
addSetDefaultSettingsListener(startAppListening);
// addControlAdapterPreprocessor(startAppListening);
addCancellationsListeners(startAppListening);

View File

@@ -1,137 +0,0 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { $lastCanvasProgressEvent } from 'features/controlLayers/store/canvasSlice';
import { queueApi } from 'services/api/endpoints/queue';
/**
* To prevent a race condition where a progress event arrives after a successful cancellation, we need to keep track of
* cancellations:
* - In the route handlers above, we track and update the cancellations object
* - When the user queues a, we should reset the cancellations, also handled int he route handlers above
* - When we get a progress event, we should check if the event is cancelled before setting the event
*
* We have a few ways that cancellations are effected, so we need to track them all:
* - by queue item id (in this case, we will compare the session_id and not the item_id)
* - by batch id
* - by destination
* - by clearing the queue
*/
type Cancellations = {
sessionIds: Set<string>;
batchIds: Set<string>;
destinations: Set<string>;
clearQueue: boolean;
};
const resetCancellations = (): void => {
cancellations.clearQueue = false;
cancellations.sessionIds.clear();
cancellations.batchIds.clear();
cancellations.destinations.clear();
};
const cancellations: Cancellations = {
sessionIds: new Set(),
batchIds: new Set(),
destinations: new Set(),
clearQueue: false,
} as Readonly<Cancellations>;
/**
* Checks if an item is cancelled, used to prevent race conditions with event handling.
*
* To use this, provide the session_id, batch_id and destination from the event payload.
*/
export const getIsCancelled = (item: {
session_id: string;
batch_id: string;
destination?: string | null;
}): boolean => {
if (cancellations.clearQueue) {
return true;
}
if (cancellations.sessionIds.has(item.session_id)) {
return true;
}
if (cancellations.batchIds.has(item.batch_id)) {
return true;
}
if (item.destination && cancellations.destinations.has(item.destination)) {
return true;
}
return false;
};
export const addCancellationsListeners = (startAppListening: AppStartListening) => {
// When we get a cancellation, we may need to clear the last progress event - next few listeners handle those cases.
// Maybe we could use the `getIsCancelled` util here, but I think that could introduce _another_ race condition...
startAppListening({
matcher: queueApi.endpoints.enqueueBatch.matchFulfilled,
effect: () => {
resetCancellations();
},
});
startAppListening({
matcher: queueApi.endpoints.cancelByBatchDestination.matchFulfilled,
effect: (action) => {
cancellations.destinations.add(action.meta.arg.originalArgs.destination);
const event = $lastCanvasProgressEvent.get();
if (!event) {
return;
}
const { session_id, batch_id, destination } = event;
if (getIsCancelled({ session_id, batch_id, destination })) {
$lastCanvasProgressEvent.set(null);
}
},
});
startAppListening({
matcher: queueApi.endpoints.cancelQueueItem.matchFulfilled,
effect: (action) => {
cancellations.sessionIds.add(action.payload.session_id);
const event = $lastCanvasProgressEvent.get();
if (!event) {
return;
}
const { session_id, batch_id, destination } = event;
if (getIsCancelled({ session_id, batch_id, destination })) {
$lastCanvasProgressEvent.set(null);
}
},
});
startAppListening({
matcher: queueApi.endpoints.cancelByBatchIds.matchFulfilled,
effect: (action) => {
for (const batch_id of action.meta.arg.originalArgs.batch_ids) {
cancellations.batchIds.add(batch_id);
}
const event = $lastCanvasProgressEvent.get();
if (!event) {
return;
}
const { session_id, batch_id, destination } = event;
if (getIsCancelled({ session_id, batch_id, destination })) {
$lastCanvasProgressEvent.set(null);
}
},
});
startAppListening({
matcher: queueApi.endpoints.clearQueue.matchFulfilled,
effect: () => {
cancellations.clearQueue = true;
const event = $lastCanvasProgressEvent.get();
if (!event) {
return;
}
const { session_id, batch_id, destination } = event;
if (getIsCancelled({ session_id, batch_id, destination })) {
$lastCanvasProgressEvent.set(null);
}
},
});
};

View File

@@ -66,7 +66,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
const destination = state.canvasSettings.sendToCanvas ? 'canvas' : 'gallery';
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch(state, g, prepend, noise, posCond, 'generation', destination)
prepareLinearUIBatch(state, g, prepend, noise, posCond, 'canvas', destination)
);
if (prepareBatchResult.isErr()) {

View File

@@ -15,7 +15,8 @@ import { getPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilder
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import { stylePresetsApi } from 'services/api/endpoints/stylePresets';
import { utilitiesApi } from 'services/api/endpoints/utilities';
import { socketConnected } from 'services/events/setEventListeners';
import { socketConnected } from './socketConnected';
const matcher = isAnyOf(
positivePromptChanged,

View File

@@ -1,3 +1,4 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
@@ -6,11 +7,11 @@ import { atom } from 'nanostores';
import { api } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models';
import { queueApi, selectQueueStatus } from 'services/api/endpoints/queue';
import { socketConnected } from 'services/events/setEventListeners';
const log = logger('events');
const $isFirstConnection = atom(true);
export const socketConnected = createAction('socket/connected');
export const addSocketConnectedEventListener = (startAppListening: AppStartListening) => {
startAppListening({

View File

@@ -1,5 +1,4 @@
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { $true } from 'app/store/nanostores/util';
import { useAppSelector } from 'app/store/storeHooks';
@@ -21,6 +20,7 @@ import i18n from 'i18next';
import { forEach, upperFirst } from 'lodash-es';
import { useMemo } from 'react';
import { getConnectedEdges } from 'reactflow';
import { $isConnected } from 'services/events/stores';
const LAYER_TYPE_TO_TKEY = {
reference_image: 'controlLayers.referenceImage',

View File

@@ -1,7 +1,6 @@
import { Flex, useShiftModifier } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/query';
import { $isConnected } from 'app/hooks/useSocketIO';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
@@ -16,6 +15,7 @@ import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold, PiRulerBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import type { ImageDTO, PostUploadAction } from 'services/api/types';
import { $isConnected } from 'services/events/stores';
type Props = {
image: ImageWithDims | null;

View File

@@ -1,5 +1,4 @@
import { useStore } from '@nanostores/react';
import { $socket } from 'app/hooks/useSocketIO';
import { logger } from 'app/logging/logger';
import { useAppStore } from 'app/store/nanostores/store';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
@@ -7,6 +6,7 @@ import { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { $canvasManager } from 'features/controlLayers/store/canvasSlice';
import Konva from 'konva';
import { useLayoutEffect, useState } from 'react';
import { $socket } from 'services/events/stores';
import { useDevicePixelRatio } from 'use-device-pixel-ratio';
const log = logger('canvas');

View File

@@ -1,4 +1,3 @@
import type { AppSocket } from 'app/hooks/useSocketIO';
import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store';
import type { SerializableObject } from 'common/types';
@@ -31,6 +30,7 @@ import Konva from 'konva';
import type { Atom } from 'nanostores';
import { computed } from 'nanostores';
import type { Logger } from 'roarr';
import type { AppSocket } from 'services/events/types';
import { assert } from 'tsafe';
import { CanvasBackgroundModule } from './CanvasBackgroundModule';

View File

@@ -4,7 +4,10 @@ import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'
import { getPrefixedId, loadImage } from 'features/controlLayers/konva/util';
import { selectShowProgressOnCanvas } from 'features/controlLayers/store/canvasSettingsSlice';
import Konva from 'konva';
import { atom } from 'nanostores';
import type { Logger } from 'roarr';
import { selectCanvasQueueCounts } from 'services/api/endpoints/queue';
import type { S } from 'services/api/types';
export class CanvasProgressImageModule extends CanvasModuleBase {
readonly type = 'progress_image';
@@ -23,7 +26,8 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
imageElement: HTMLImageElement | null = null;
subscriptions = new Set<() => void>();
$lastProgressEvent = atom<S['InvocationDenoiseProgressEvent'] | null>(null);
hasActiveGeneration: boolean = false;
mutex: Mutex = new Mutex();
constructor(manager: CanvasManager) {
@@ -41,11 +45,50 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
image: null,
};
this.subscriptions.add(this.manager.stateApi.$lastCanvasProgressEvent.listen(this.render));
this.subscriptions.add(this.manager.stagingArea.$shouldShowStagedImage.listen(this.render));
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectShowProgressOnCanvas, this.render));
this.subscriptions.add(this.setSocketEventListeners());
this.subscriptions.add(
this.manager.stateApi.createStoreSubscription(selectCanvasQueueCounts, ({ data }) => {
if (data && (data.in_progress > 0 || data.pending > 0)) {
this.hasActiveGeneration = true;
} else {
this.hasActiveGeneration = false;
this.$lastProgressEvent.set(null);
}
})
);
this.subscriptions.add(this.$lastProgressEvent.listen(this.render));
}
setSocketEventListeners = (): (() => void) => {
const progressListener = (data: S['InvocationDenoiseProgressEvent']) => {
if (data.destination !== 'canvas') {
return;
}
if (!this.hasActiveGeneration) {
return;
}
this.$lastProgressEvent.set(data);
};
const clearProgress = () => {
this.$lastProgressEvent.set(null);
};
this.manager.socket.on('invocation_denoise_progress', progressListener);
this.manager.socket.on('connect', clearProgress);
this.manager.socket.on('connect_error', clearProgress);
this.manager.socket.on('disconnect', clearProgress);
return () => {
this.manager.socket.off('invocation_denoise_progress', progressListener);
this.manager.socket.off('connect', clearProgress);
this.manager.socket.off('connect_error', clearProgress);
this.manager.socket.off('disconnect', clearProgress);
};
};
getNodes = () => {
return [this.konva.group];
};
@@ -53,7 +96,7 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
render = async () => {
const release = await this.mutex.acquire();
const event = this.manager.stateApi.$lastCanvasProgressEvent.get();
const event = this.$lastProgressEvent.get();
const showProgressOnCanvas = this.manager.stateApi.runSelector(selectShowProgressOnCanvas);
if (!event || !showProgressOnCanvas) {

View File

@@ -96,7 +96,7 @@ export class CanvasStagingAreaModule extends CanvasModuleBase {
if (!this.image.isLoading && !this.image.isError) {
await this.image.update({ ...this.image.state, image: imageDTOToImageWithDims(imageDTO) }, true);
this.manager.stateApi.$lastCanvasProgressEvent.set(null);
this.manager.progressImage.$lastProgressEvent.set(null);
}
this.image.konva.group.visible(shouldShowStagedImage);
} else {

View File

@@ -15,7 +15,6 @@ import {
settingsEraserWidthChanged,
} from 'features/controlLayers/store/canvasSettingsSlice';
import {
$lastCanvasProgressEvent,
bboxChangedFromCanvas,
entityBrushLineAdded,
entityEraserLineAdded,
@@ -382,12 +381,6 @@ export class CanvasStateApiModule extends CanvasModuleBase {
*/
$isRasterizing = computed(this.$rasterizingAdapter, (rasterizingAdapter) => Boolean(rasterizingAdapter));
/**
* The last canvas progress event. This is set in a global event listener. The staging area may set it to null when it
* consumes the event.
*/
$lastCanvasProgressEvent = $lastCanvasProgressEvent;
/**
* Whether the space key is currently pressed.
*/

View File

@@ -30,13 +30,7 @@ import type { IRect } from 'konva/lib/types';
import { merge, omit } from 'lodash-es';
import { atom } from 'nanostores';
import type { UndoableOptions } from 'redux-undo';
import type {
ControlNetModelConfig,
ImageDTO,
IPAdapterModelConfig,
S,
T2IAdapterModelConfig,
} from 'services/api/types';
import type { ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import type {
@@ -1236,7 +1230,6 @@ function actionsThrottlingFilter(action: UnknownAction) {
return true;
}
export const $lastCanvasProgressEvent = atom<S['InvocationDenoiseProgressEvent'] | null>(null);
/**
* The global canvas manager instance.
*/

View File

@@ -1,12 +1,12 @@
import type { IconButtonProps } from '@invoke-ai/ui-library';
import { IconButton } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { useAppSelector } from 'app/store/storeHooks';
import { selectSelectionCount } from 'features/gallery/store/gallerySelectors';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold } from 'react-icons/pi';
import { $isConnected } from 'services/events/stores';
type DeleteImageButtonProps = Omit<IconButtonProps, 'aria-label'> & {
onClick: () => void;

View File

@@ -1,7 +1,6 @@
import { ButtonGroup, IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/query';
import { $isConnected } from 'app/hooks/useSocketIO';
import { adHocPostProcessingRequested } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { INTERACTION_SCOPES } from 'common/hooks/interactionScopes';
@@ -30,7 +29,7 @@ import {
PiRulerBold,
} from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { $progressImage } from 'services/events/setEventListeners';
import { $isConnected, $progressImage } from 'services/events/stores';
const CurrentImageButtons = () => {
const dispatch = useAppDispatch();

View File

@@ -15,7 +15,7 @@ import { memo, useCallback, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiImageBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { $hasProgress, $isProgressFromCanvas } from 'services/events/setEventListeners';
import { $hasProgress, $isProgressFromCanvas } from 'services/events/stores';
import ProgressImage from './ProgressImage';

View File

@@ -5,7 +5,7 @@ import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectSystemSlice } from 'features/system/store/systemSlice';
import { memo, useMemo } from 'react';
import { $isProgressFromCanvas, $progressImage } from 'services/events/setEventListeners';
import { $isProgressFromCanvas, $progressImage } from 'services/events/stores';
const selectShouldAntialiasProgressImage = createSelector(
selectSystemSlice,

View File

@@ -13,7 +13,7 @@ import type { CSSProperties, PropsWithChildren } from 'react';
import { memo, useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import type { NodeProps } from 'reactflow';
import { $lastProgressEvent } from 'services/events/setEventListeners';
import { $lastProgressEvent } from 'services/events/stores';
const CurrentImageNode = (props: NodeProps) => {
const imageDTO = useAppSelector(selectLastSelectedImage);

View File

@@ -1,7 +1,6 @@
import { Flex, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/query';
import { $isConnected } from 'app/hooks/useSocketIO';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
@@ -13,6 +12,7 @@ import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import type { PostUploadAction } from 'services/api/types';
import { $isConnected } from 'services/events/stores';
import type { FieldComponentProps } from './types';

View File

@@ -11,7 +11,7 @@ export const prepareLinearUIBatch = (
prepend: boolean,
noise: Invocation<'noise' | 'flux_denoise'>,
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>,
origin: 'generation' | 'workflows' | 'upscaling',
origin: 'canvas' | 'workflows' | 'upscaling',
destination: 'canvas' | 'gallery'
): BatchConfig => {
const { iterations, model, shouldRandomizeSeed, seed, shouldConcatPrompts } = state.params;

View File

@@ -1,6 +1,5 @@
import { ConfirmationAlertDialog, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { useAppDispatch } from 'app/store/storeHooks';
import { buildUseBoolean } from 'common/hooks/useBoolean';
import { listCursorChanged, listPriorityChanged } from 'features/queue/store/queueSlice';
@@ -8,6 +7,7 @@ import { toast } from 'features/toast/toast';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useClearQueueMutation, useGetQueueStatusQuery } from 'services/api/endpoints/queue';
import { $isConnected } from 'services/events/stores';
const [useClearQueueConfirmationAlertDialog] = buildUseBoolean(false);

View File

@@ -1,9 +1,9 @@
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCancelByBatchIdsMutation, useGetBatchStatusQuery } from 'services/api/endpoints/queue';
import { $isConnected } from 'services/events/stores';
export const useCancelBatch = (batch_id: string) => {
const isConnected = useStore($isConnected);

View File

@@ -1,10 +1,10 @@
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { toast } from 'features/toast/toast';
import { isNil } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useCancelQueueItemMutation, useGetQueueStatusQuery } from 'services/api/endpoints/queue';
import { $isConnected } from 'services/events/stores';
export const useCancelCurrentQueueItem = () => {
const isConnected = useStore($isConnected);

View File

@@ -1,9 +1,9 @@
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCancelQueueItemMutation } from 'services/api/endpoints/queue';
import { $isConnected } from 'services/events/stores';
export const useCancelQueueItem = (item_id: number) => {
const isConnected = useStore($isConnected);

View File

@@ -1,9 +1,9 @@
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useClearInvocationCacheMutation, useGetInvocationCacheStatusQuery } from 'services/api/endpoints/appInfo';
import { $isConnected } from 'services/events/stores';
export const useClearInvocationCache = () => {
const { t } = useTranslation();

View File

@@ -1,9 +1,9 @@
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useDisableInvocationCacheMutation, useGetInvocationCacheStatusQuery } from 'services/api/endpoints/appInfo';
import { $isConnected } from 'services/events/stores';
export const useDisableInvocationCache = () => {
const { t } = useTranslation();

View File

@@ -1,9 +1,9 @@
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useEnableInvocationCacheMutation, useGetInvocationCacheStatusQuery } from 'services/api/endpoints/appInfo';
import { $isConnected } from 'services/events/stores';
export const useEnableInvocationCache = () => {
const { t } = useTranslation();

View File

@@ -1,9 +1,9 @@
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetQueueStatusQuery, usePauseProcessorMutation } from 'services/api/endpoints/queue';
import { $isConnected } from 'services/events/stores';
export const usePauseProcessor = () => {
const { t } = useTranslation();

View File

@@ -1,11 +1,11 @@
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { useAppDispatch } from 'app/store/storeHooks';
import { listCursorChanged, listPriorityChanged } from 'features/queue/store/queueSlice';
import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetQueueStatusQuery, usePruneQueueMutation } from 'services/api/endpoints/queue';
import { $isConnected } from 'services/events/stores';
export const usePruneQueue = () => {
const dispatch = useAppDispatch();

View File

@@ -1,9 +1,9 @@
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetQueueStatusQuery, useResumeProcessorMutation } from 'services/api/endpoints/queue';
import { $isConnected } from 'services/events/stores';
export const useResumeProcessor = () => {
const isConnected = useStore($isConnected);

View File

@@ -1,11 +1,10 @@
import { Progress } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { useCurrentDestination } from 'features/queue/hooks/useCurrentDestination';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
import { $lastProgressEvent } from 'services/events/setEventListeners';
import { $isConnected, $lastProgressEvent } from 'services/events/stores';
const ProgressBar = () => {
const { t } = useTranslation();

View File

@@ -1,9 +1,9 @@
import { Icon, Tooltip } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiWarningBold } from 'react-icons/pi';
import { $isConnected } from 'services/events/stores';
const StatusIndicator = () => {
const isConnected = useStore($isConnected);

View File

@@ -78,15 +78,17 @@ export const queueApi = api.injectEndpoints({
resetListQueryData(dispatch);
/**
* When a batch is enqueued, we need to update the queue status. While it might be templting to invalidate the
* `SessionQueueStatus` tag here, this can introduce a race condition:
* `SessionQueueStatus` tag here, this can introduce a race condition when the queue item executes quickly:
*
* - Enqueue batch via this query
* - Enqueue via this query
* - On success, we invalidate `SessionQueueStatus` tag - network request sent to server
* - Network request received, response preparing/sending
* - A queue item status changes and we receive a socket event w/ updated status
* - Update status optimistically in socket handler
* - Tag invalidation response received, but by now its payload has stale data
* - Stale data is written to the cache
* - The server gets the queue status request and responds, but this takes some time... in the meantime:
* - The new queue item starts executing, and we receive a socket queue item status changed event
* - We optimistically update the queue status in the queue item status changed socket handler
* - At this point, the queue status is correct
* - Finally, we get the queue status from the tag invalidation request - but it's reporting the queue status
* from _before_ the last queue event
* - The queue status is now incorrect!
*
* Ok, what if we just never did optimistic updates and invalidated the tag in the queue event handlers instead?
* It's much simpler that way, but it causes a lot of network requests - 3 per queue item, as it moves from
@@ -94,7 +96,18 @@ export const queueApi = api.injectEndpoints({
*
* We can do a bit of extra work here, incrementing the pending and total counts in the queue status, and do
* similar optimistic updates in the socket handler. Because this optimistic update runs immediately after the
* enqueue network request, it should always occur _before_ the next queue event, so no race condition.
* enqueue network request, it should always occur _before_ the next queue event, so no race condition:
*
* - Enqueue batch via this query
* - On success, optimistically update - this happens immediately on the HTTP OK - before the next queue event
* - At this point, the queue status is correct
* - A queue item status changes and we receive a socket event w/ updated status
* - Update status optimistically in socket handler
* - Queue status is still correct
*
* This problem occurs most commonly with canvas filters like Canny edge detection, which are single-node
* graphs that execute very quickly. Image generation graphs take long enough to not trigger this race
* condition - even when all nodes are cached on the server.
*/
dispatch(
queueApi.util.updateQueryData('getQueueStatus', undefined, (draft) => {

View File

@@ -17,13 +17,9 @@ const isCanvasOutputNode = (data: S['InvocationCompleteEvent']) => {
return data.invocation_source_id.split(':')[0] === 'canvas_output';
};
export const buildOnInvocationComplete = (
getState: () => RootState,
dispatch: AppDispatch,
nodeTypeDenylist: string[],
setLastProgressEvent: (event: S['InvocationDenoiseProgressEvent'] | null) => void,
setLastCanvasProgressEvent: (event: S['InvocationDenoiseProgressEvent'] | null) => void
) => {
const nodeTypeDenylist = ['load_image', 'image'];
export const buildOnInvocationComplete = (getState: () => RootState, dispatch: AppDispatch) => {
const addImageToGallery = (imageDTO: ImageDTO) => {
if (imageDTO.is_intermediate) {
return;
@@ -113,7 +109,7 @@ export const buildOnInvocationComplete = (
}
};
const handleOriginGeneration = async (data: S['InvocationCompleteEvent']) => {
const handleOriginCanvas = async (data: S['InvocationCompleteEvent']) => {
const imageDTO = await getResultImageDTO(data);
if (!imageDTO) {
@@ -121,6 +117,7 @@ export const buildOnInvocationComplete = (
}
if (data.destination === 'canvas') {
// TODO(psyche): Can/should we let canvas handle this itself?
if (isCanvasOutputNode(data)) {
if (data.result.type === 'canvas_v2_mask_and_crop_output') {
const { offset_x, offset_y } = data.result;
@@ -131,8 +128,7 @@ export const buildOnInvocationComplete = (
addImageToGallery(imageDTO);
}
} else if (!imageDTO.is_intermediate) {
// session.mode === 'generate'
setLastCanvasProgressEvent(null);
// Desintaion is gallery
addImageToGallery(imageDTO);
}
};
@@ -151,15 +147,17 @@ export const buildOnInvocationComplete = (
`Invocation complete (${data.invocation.type}, ${data.invocation_source_id})`
);
// Update the node execution states - the image output is handled below
if (nodeTypeDenylist.includes(data.invocation.type)) {
log.trace('Skipping node type denylisted');
return;
}
if (data.origin === 'workflows') {
await handleOriginWorkflows(data);
} else if (data.origin === 'generation') {
await handleOriginGeneration(data);
} else if (data.origin === 'canvas') {
await handleOriginCanvas(data);
} else {
await handleOriginOther(data);
}
setLastProgressEvent(null);
};
};

View File

@@ -1,48 +1,40 @@
import { ExternalLink } from '@invoke-ai/ui-library';
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import { getIsCancelled } from 'app/store/middleware/listenerMiddleware/listeners/cancellationsListeners';
import { socketConnected } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId';
import { $queueId } from 'app/store/nanostores/queueId';
import type { AppDispatch, RootState } from 'app/store/store';
import type { AppStore } from 'app/store/store';
import type { SerializableObject } from 'common/types';
import { deepClone } from 'common/util/deepClone';
import { $lastCanvasProgressEvent } from 'features/controlLayers/store/canvasSlice';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { forEach } from 'lodash-es';
import { atom, computed } from 'nanostores';
import { api, LIST_TAG } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models';
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
import type { S } from 'services/api/types';
import { buildOnInvocationComplete } from 'services/events/onInvocationComplete';
import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types';
import type { Socket } from 'socket.io-client';
export const socketConnected = createAction('socket/connected');
import { $lastProgressEvent } from './stores';
const log = logger('events');
type SetEventListenersArg = {
socket: Socket<ServerToClientEvents, ClientToServerEvents>;
dispatch: AppDispatch;
getState: () => RootState;
store: AppStore;
setIsConnected: (isConnected: boolean) => void;
};
const selectModelInstalls = modelsApi.endpoints.listModelInstalls.select();
const nodeTypeDenylist = ['load_image', 'image'];
export const $lastProgressEvent = atom<S['InvocationDenoiseProgressEvent'] | null>(null);
export const $hasProgress = computed($lastProgressEvent, (val) => Boolean(val));
export const $progressImage = computed($lastProgressEvent, (val) => val?.progress_image ?? null);
export const $isProgressFromCanvas = computed($lastProgressEvent, (val) => val?.destination === 'canvas');
export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }: SetEventListenersArg) => {
export const setEventListeners = ({ socket, store, setIsConnected }: SetEventListenersArg) => {
const { dispatch, getState } = store;
socket.on('connect', () => {
log.debug('Connected');
setIsConnected(true);
@@ -54,14 +46,12 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
socket.emit('subscribe_bulk_download', { bulk_download_id });
}
$lastProgressEvent.set(null);
$lastCanvasProgressEvent.set(null);
});
socket.on('connect_error', (error) => {
log.debug('Connect error');
setIsConnected(false);
$lastProgressEvent.set(null);
$lastCanvasProgressEvent.set(null);
if (error && error.message) {
const data: string | undefined = (error as unknown as { data: string | undefined }).data;
if (data === 'ERR_UNAUTHENTICATED') {
@@ -78,7 +68,6 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
socket.on('disconnect', () => {
log.debug('Disconnected');
$lastProgressEvent.set(null);
$lastCanvasProgressEvent.set(null);
setIsConnected(false);
});
@@ -93,24 +82,7 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
});
socket.on('invocation_denoise_progress', (data) => {
const {
invocation_source_id,
invocation,
step,
total_steps,
progress_image,
origin,
destination,
percentage,
session_id,
batch_id,
} = data;
if (getIsCancelled({ session_id, batch_id, destination })) {
// Do not update the progress if this session has been cancelled. This prevents a race condition where we get a
// progress update after the session has been cancelled.
return;
}
const { invocation_source_id, invocation, step, total_steps, progress_image, origin, percentage } = data;
log.trace(
{ data } as SerializableObject,
@@ -128,11 +100,6 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
upsertExecutionState(nes.nodeId, nes);
}
}
// This event is only relevant for the canvas
if (destination === 'canvas') {
$lastCanvasProgressEvent.set(data);
}
});
socket.on('invocation_error', (data) => {
@@ -152,13 +119,7 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
}
});
const onInvocationComplete = buildOnInvocationComplete(
getState,
dispatch,
nodeTypeDenylist,
$lastProgressEvent.set,
$lastCanvasProgressEvent.set
);
const onInvocationComplete = buildOnInvocationComplete(getState, dispatch);
socket.on('invocation_complete', onInvocationComplete);
socket.on('model_load_started', (data) => {
@@ -379,7 +340,6 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
error_type,
error_message,
error_traceback,
origin,
} = data;
log.debug({ data }, `Queue item ${item_id} status updated: ${status}`);
@@ -402,12 +362,17 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
})
);
// Update the queue status (we do not get the processor status here)
// Optimistic update of the queue status. We prefer to do an optimistic update over tag invalidation due to the
// frequency of `queue_item_status_changed` events.
dispatch(
queueApi.util.updateQueryData('getQueueStatus', undefined, (draft) => {
if (!draft) {
return;
}
/**
* Update the queue status - though the getQueueStatus query response contains the processor status (i.e. running
* or paused), that data is not provided in the event we are handling. So we can only update `draft.queue` here.
*/
Object.assign(draft.queue, queue_status);
})
);
@@ -442,11 +407,6 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
} else if (status === 'failed' && error_type) {
const isLocal = getState().config.isLocal ?? true;
const sessionId = session_id;
$lastProgressEvent.set(null);
if (origin === 'canvas') {
$lastCanvasProgressEvent.set(null);
}
toast({
id: `INVOCATION_ERROR_${error_type}`,
@@ -463,13 +423,6 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
/>
),
});
} else if (status === 'canceled') {
$lastProgressEvent.set(null);
if (origin === 'canvas') {
$lastCanvasProgressEvent.set(null);
}
} else if (status === 'completed') {
$lastProgressEvent.set(null);
}
});

View File

@@ -0,0 +1,12 @@
import { atom, computed, map } from 'nanostores';
import type { S } from 'services/api/types';
import type { AppSocket } from 'services/events/types';
import type { ManagerOptions, SocketOptions } from 'socket.io-client';
export const $socket = atom<AppSocket | null>(null);
export const $socketOptions = map<Partial<ManagerOptions & SocketOptions>>({});
export const $isConnected = atom<boolean>(false);
export const $lastProgressEvent = atom<S['InvocationDenoiseProgressEvent'] | null>(null);
export const $hasProgress = computed($lastProgressEvent, (val) => Boolean(val));
export const $progressImage = computed($lastProgressEvent, (val) => val?.progress_image ?? null);
export const $isProgressFromCanvas = computed($lastProgressEvent, (val) => val?.destination === 'canvas');

View File

@@ -1,4 +1,5 @@
import type { S } from 'services/api/types';
import type { Socket } from 'socket.io-client';
type ClientEmitSubscribeQueue = { queue_id: string };
type ClientEmitUnsubscribeQueue = ClientEmitSubscribeQueue;
@@ -40,3 +41,5 @@ export type ClientToServerEvents = {
subscribe_bulk_download: (payload: ClientEmitSubscribeBulkDownload) => void;
unsubscribe_bulk_download: (payload: ClientEmitUnsubscribeBulkDownload) => void;
};
export type AppSocket = Socket<ServerToClientEvents, ClientToServerEvents>;

View File

@@ -3,14 +3,17 @@ import { $authToken } from 'app/store/nanostores/authToken';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $isDebugging } from 'app/store/nanostores/isDebugging';
import { useAppStore } from 'app/store/nanostores/store';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import type { MapStore } from 'nanostores';
import { atom, map } from 'nanostores';
import { useEffect, useMemo } from 'react';
import { selectQueueStatus } from 'services/api/endpoints/queue';
import { setEventListeners } from 'services/events/setEventListeners';
import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types';
import type { ManagerOptions, Socket, SocketOptions } from 'socket.io-client';
import type { AppSocket } from 'services/events/types';
import type { ManagerOptions, SocketOptions } from 'socket.io-client';
import { io } from 'socket.io-client';
import { $isConnected, $lastProgressEvent, $socket, $socketOptions } from './stores';
// Inject socket options and url into window for debugging
declare global {
interface Window {
@@ -18,19 +21,12 @@ declare global {
}
}
export type AppSocket = Socket<ServerToClientEvents, ClientToServerEvents>;
export const $socket = atom<AppSocket | null>(null);
export const $socketOptions = map<Partial<ManagerOptions & SocketOptions>>({});
const $isSocketInitialized = atom<boolean>(false);
export const $isConnected = atom<boolean>(false);
/**
* Initializes the socket.io connection and sets up event listeners.
*/
export const useSocketIO = () => {
const { dispatch, getState } = useAppStore();
useAssertSingleton('useSocketIO');
const store = useAppStore();
const baseUrl = useStore($baseUrl);
const authToken = useStore($authToken);
const addlSocketOptions = useStore($socketOptions);
@@ -61,14 +57,11 @@ export const useSocketIO = () => {
}, [authToken, addlSocketOptions, baseUrl]);
useEffect(() => {
if ($isSocketInitialized.get()) {
// Singleton!
return;
}
const socket: AppSocket = io(socketUrl, socketOptions);
$socket.set(socket);
setEventListeners({ socket, dispatch, getState, setIsConnected: $isConnected.set });
setEventListeners({ socket, store, setIsConnected: $isConnected.set });
socket.connect();
if ($isDebugging.get() || import.meta.env.MODE === 'development') {
@@ -78,7 +71,12 @@ export const useSocketIO = () => {
console.log('Socket initialized', socket);
}
$isSocketInitialized.set(true);
const unsubscribeQueueStatusListener = store.subscribe(() => {
const queueStatusData = selectQueueStatus(store.getState()).data;
if (!queueStatusData || queueStatusData.queue.in_progress === 0) {
$lastProgressEvent.set(null);
}
});
return () => {
if ($isDebugging.get() || import.meta.env.MODE === 'development') {
@@ -87,8 +85,8 @@ export const useSocketIO = () => {
/* eslint-disable-next-line no-console */
console.log('Socket teardown', socket);
}
unsubscribeQueueStatusListener();
socket.disconnect();
$isSocketInitialized.set(false);
};
}, [dispatch, getState, socketOptions, socketUrl]);
}, [socketOptions, socketUrl, store]);
};