diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index 2e12f19541..54282f9966 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -9,6 +9,7 @@ import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddlewar import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted'; import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected'; import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload'; +import { addCancellationsListeners } from 'app/store/middleware/listenerMiddleware/listeners/cancellationsListeners'; import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear'; import { addEnqueueRequestedNodes } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes'; import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked'; @@ -119,3 +120,5 @@ addDynamicPromptsListener(startAppListening); addSetDefaultSettingsListener(startAppListening); // addControlAdapterPreprocessor(startAppListening); + +addCancellationsListeners(startAppListening); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts index 178d8df0c7..8ebbea1166 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts @@ -12,7 +12,6 @@ import { imageDTOToImageObject } from 'features/controlLayers/store/types'; import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { queueApi } from 'services/api/endpoints/queue'; -import { $lastCanvasProgressEvent } from 'services/events/setEventListeners'; import { assert } from 'tsafe'; const log = logger('canvas'); @@ -33,8 +32,6 @@ export const addStagingListeners = (startAppListening: AppStartListening) => { const { canceled } = await req.unwrap(); req.reset(); - $lastCanvasProgressEvent.set(null); - if (canceled > 0) { log.debug(`Canceled ${canceled} canvas batches`); toast({ diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/cancellationsListeners.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/cancellationsListeners.ts new file mode 100644 index 0000000000..5542f4b579 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/cancellationsListeners.ts @@ -0,0 +1,137 @@ +import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; +import { queueApi } from 'services/api/endpoints/queue'; +import { $lastCanvasProgressEvent } from 'services/events/setEventListeners'; + +/** + * To prevent a race condition where a progress event arrives after a successful cancellation, we need to keep track of + * cancellations: + * - In the route handlers above, we track and update the cancellations object + * - When the user queues a, we should reset the cancellations, also handled int he route handlers above + * - When we get a progress event, we should check if the event is cancelled before setting the event + * + * We have a few ways that cancellations are effected, so we need to track them all: + * - by queue item id (in this case, we will compare the session_id and not the item_id) + * - by batch id + * - by destination + * - by clearing the queue + */ +type Cancellations = { + sessionIds: Set; + batchIds: Set; + destinations: Set; + clearQueue: boolean; +}; + +const resetCancellations = (): void => { + cancellations.clearQueue = false; + cancellations.sessionIds.clear(); + cancellations.batchIds.clear(); + cancellations.destinations.clear(); +}; + +const cancellations: Cancellations = { + sessionIds: new Set(), + batchIds: new Set(), + destinations: new Set(), + clearQueue: false, +} as Readonly; + +/** + * Checks if an item is cancelled, used to prevent race conditions with event handling. + * + * To use this, provide the session_id, batch_id and destination from the event payload. + */ +export const getIsCancelled = (item: { + session_id: string; + batch_id: string; + destination?: string | null; +}): boolean => { + if (cancellations.clearQueue) { + return true; + } + if (cancellations.sessionIds.has(item.session_id)) { + return true; + } + if (cancellations.batchIds.has(item.batch_id)) { + return true; + } + if (item.destination && cancellations.destinations.has(item.destination)) { + return true; + } + return false; +}; + +export const addCancellationsListeners = (startAppListening: AppStartListening) => { + // When we get a cancellation, we may need to clear the last progress event - next few listeners handle those cases. + // Maybe we could use the `getIsCancelled` util here, but I think that could introduce _another_ race condition... + startAppListening({ + matcher: queueApi.endpoints.enqueueBatch.matchFulfilled, + effect: () => { + resetCancellations(); + }, + }); + + startAppListening({ + matcher: queueApi.endpoints.cancelByBatchDestination.matchFulfilled, + effect: (action) => { + cancellations.destinations.add(action.meta.arg.originalArgs.destination); + + const event = $lastCanvasProgressEvent.get(); + if (!event) { + return; + } + const { session_id, batch_id, destination } = event; + if (getIsCancelled({ session_id, batch_id, destination })) { + $lastCanvasProgressEvent.set(null); + } + }, + }); + + startAppListening({ + matcher: queueApi.endpoints.cancelQueueItem.matchFulfilled, + effect: (action) => { + cancellations.sessionIds.add(action.payload.session_id); + + const event = $lastCanvasProgressEvent.get(); + if (!event) { + return; + } + const { session_id, batch_id, destination } = event; + if (getIsCancelled({ session_id, batch_id, destination })) { + $lastCanvasProgressEvent.set(null); + } + }, + }); + + startAppListening({ + matcher: queueApi.endpoints.cancelByBatchIds.matchFulfilled, + effect: (action) => { + for (const batch_id of action.meta.arg.originalArgs.batch_ids) { + cancellations.batchIds.add(batch_id); + } + const event = $lastCanvasProgressEvent.get(); + if (!event) { + return; + } + const { session_id, batch_id, destination } = event; + if (getIsCancelled({ session_id, batch_id, destination })) { + $lastCanvasProgressEvent.set(null); + } + }, + }); + + startAppListening({ + matcher: queueApi.endpoints.clearQueue.matchFulfilled, + effect: () => { + cancellations.clearQueue = true; + const event = $lastCanvasProgressEvent.get(); + if (!event) { + return; + } + const { session_id, batch_id, destination } = event; + if (getIsCancelled({ session_id, batch_id, destination })) { + $lastCanvasProgressEvent.set(null); + } + }, + }); +}; diff --git a/invokeai/frontend/web/src/services/events/setEventListeners.tsx b/invokeai/frontend/web/src/services/events/setEventListeners.tsx index f0aaeb9ad6..93822ba474 100644 --- a/invokeai/frontend/web/src/services/events/setEventListeners.tsx +++ b/invokeai/frontend/web/src/services/events/setEventListeners.tsx @@ -1,6 +1,7 @@ import { ExternalLink } from '@invoke-ai/ui-library'; import { createAction } from '@reduxjs/toolkit'; import { logger } from 'app/logging/logger'; +import { getIsCancelled } from 'app/store/middleware/listenerMiddleware/listeners/cancellationsListeners'; import { $baseUrl } from 'app/store/nanostores/baseUrl'; import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId'; import { $queueId } from 'app/store/nanostores/queueId'; @@ -39,7 +40,6 @@ export const $lastProgressEvent = atom(null); export const $hasProgress = computed($lastProgressEvent, (val) => Boolean(val)); export const $progressImage = computed($lastProgressEvent, (val) => val?.progress_image ?? null); -const cancellations = new Set(); export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }: SetEventListenersArg) => { socket.on('connect', () => { @@ -54,7 +54,6 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected } } $lastProgressEvent.set(null); $lastCanvasProgressEvent.set(null); - cancellations.clear(); }); socket.on('connect_error', (error) => { @@ -73,7 +72,6 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected } }); } } - cancellations.clear(); }); socket.on('disconnect', () => { @@ -81,7 +79,6 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected } $lastProgressEvent.set(null); $lastCanvasProgressEvent.set(null); setIsConnected(false); - cancellations.clear(); }); socket.on('invocation_started', (data) => { @@ -92,7 +89,6 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected } nes.status = zNodeStatus.enum.IN_PROGRESS; upsertExecutionState(nes.nodeId, nes); } - cancellations.clear(); }); socket.on('invocation_denoise_progress', (data) => { @@ -106,9 +102,10 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected } destination, percentage, session_id, + batch_id, } = data; - if (cancellations.has(session_id)) { + if (getIsCancelled({ session_id, batch_id, destination })) { // Do not update the progress if this session has been cancelled. This prevents a race condition where we get a // progress update after the session has been cancelled. return; @@ -131,7 +128,8 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected } } } - if (origin === 'generation' && destination === 'canvas') { + // This event is only relevant for the canvas + if (destination === 'canvas') { $lastCanvasProgressEvent.set(data); } }); @@ -464,16 +462,13 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected } /> ), }); - cancellations.add(session_id); } else if (status === 'canceled') { $lastProgressEvent.set(null); if (origin === 'canvas') { $lastCanvasProgressEvent.set(null); } - cancellations.add(session_id); } else if (status === 'completed') { $lastProgressEvent.set(null); - cancellations.add(session_id); } });