mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-01 21:05:05 -05:00
fix(ui): race conditions with progress events
There's a race condition where we sometimes get progress events from canceled queue items, depending on the timing of the cancellation request and last event or two from the queue item. I can't imagine how to resolve this except by tracking all cancellations and ignoring events for cancelled items, which is implemented in this change.
This commit is contained in:
@@ -9,6 +9,7 @@ import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddlewar
|
||||
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
|
||||
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
|
||||
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
|
||||
import { addCancellationsListeners } from 'app/store/middleware/listenerMiddleware/listeners/cancellationsListeners';
|
||||
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
|
||||
import { addEnqueueRequestedNodes } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes';
|
||||
import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
|
||||
@@ -119,3 +120,5 @@ addDynamicPromptsListener(startAppListening);
|
||||
|
||||
addSetDefaultSettingsListener(startAppListening);
|
||||
// addControlAdapterPreprocessor(startAppListening);
|
||||
|
||||
addCancellationsListeners(startAppListening);
|
||||
|
||||
@@ -12,7 +12,6 @@ import { imageDTOToImageObject } from 'features/controlLayers/store/types';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import { $lastCanvasProgressEvent } from 'services/events/setEventListeners';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const log = logger('canvas');
|
||||
@@ -33,8 +32,6 @@ export const addStagingListeners = (startAppListening: AppStartListening) => {
|
||||
const { canceled } = await req.unwrap();
|
||||
req.reset();
|
||||
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
|
||||
if (canceled > 0) {
|
||||
log.debug(`Canceled ${canceled} canvas batches`);
|
||||
toast({
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import { $lastCanvasProgressEvent } from 'services/events/setEventListeners';
|
||||
|
||||
/**
|
||||
* To prevent a race condition where a progress event arrives after a successful cancellation, we need to keep track of
|
||||
* cancellations:
|
||||
* - In the route handlers above, we track and update the cancellations object
|
||||
* - When the user queues a, we should reset the cancellations, also handled int he route handlers above
|
||||
* - When we get a progress event, we should check if the event is cancelled before setting the event
|
||||
*
|
||||
* We have a few ways that cancellations are effected, so we need to track them all:
|
||||
* - by queue item id (in this case, we will compare the session_id and not the item_id)
|
||||
* - by batch id
|
||||
* - by destination
|
||||
* - by clearing the queue
|
||||
*/
|
||||
type Cancellations = {
|
||||
sessionIds: Set<string>;
|
||||
batchIds: Set<string>;
|
||||
destinations: Set<string>;
|
||||
clearQueue: boolean;
|
||||
};
|
||||
|
||||
const resetCancellations = (): void => {
|
||||
cancellations.clearQueue = false;
|
||||
cancellations.sessionIds.clear();
|
||||
cancellations.batchIds.clear();
|
||||
cancellations.destinations.clear();
|
||||
};
|
||||
|
||||
const cancellations: Cancellations = {
|
||||
sessionIds: new Set(),
|
||||
batchIds: new Set(),
|
||||
destinations: new Set(),
|
||||
clearQueue: false,
|
||||
} as Readonly<Cancellations>;
|
||||
|
||||
/**
|
||||
* Checks if an item is cancelled, used to prevent race conditions with event handling.
|
||||
*
|
||||
* To use this, provide the session_id, batch_id and destination from the event payload.
|
||||
*/
|
||||
export const getIsCancelled = (item: {
|
||||
session_id: string;
|
||||
batch_id: string;
|
||||
destination?: string | null;
|
||||
}): boolean => {
|
||||
if (cancellations.clearQueue) {
|
||||
return true;
|
||||
}
|
||||
if (cancellations.sessionIds.has(item.session_id)) {
|
||||
return true;
|
||||
}
|
||||
if (cancellations.batchIds.has(item.batch_id)) {
|
||||
return true;
|
||||
}
|
||||
if (item.destination && cancellations.destinations.has(item.destination)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
export const addCancellationsListeners = (startAppListening: AppStartListening) => {
|
||||
// When we get a cancellation, we may need to clear the last progress event - next few listeners handle those cases.
|
||||
// Maybe we could use the `getIsCancelled` util here, but I think that could introduce _another_ race condition...
|
||||
startAppListening({
|
||||
matcher: queueApi.endpoints.enqueueBatch.matchFulfilled,
|
||||
effect: () => {
|
||||
resetCancellations();
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
matcher: queueApi.endpoints.cancelByBatchDestination.matchFulfilled,
|
||||
effect: (action) => {
|
||||
cancellations.destinations.add(action.meta.arg.originalArgs.destination);
|
||||
|
||||
const event = $lastCanvasProgressEvent.get();
|
||||
if (!event) {
|
||||
return;
|
||||
}
|
||||
const { session_id, batch_id, destination } = event;
|
||||
if (getIsCancelled({ session_id, batch_id, destination })) {
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
matcher: queueApi.endpoints.cancelQueueItem.matchFulfilled,
|
||||
effect: (action) => {
|
||||
cancellations.sessionIds.add(action.payload.session_id);
|
||||
|
||||
const event = $lastCanvasProgressEvent.get();
|
||||
if (!event) {
|
||||
return;
|
||||
}
|
||||
const { session_id, batch_id, destination } = event;
|
||||
if (getIsCancelled({ session_id, batch_id, destination })) {
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
matcher: queueApi.endpoints.cancelByBatchIds.matchFulfilled,
|
||||
effect: (action) => {
|
||||
for (const batch_id of action.meta.arg.originalArgs.batch_ids) {
|
||||
cancellations.batchIds.add(batch_id);
|
||||
}
|
||||
const event = $lastCanvasProgressEvent.get();
|
||||
if (!event) {
|
||||
return;
|
||||
}
|
||||
const { session_id, batch_id, destination } = event;
|
||||
if (getIsCancelled({ session_id, batch_id, destination })) {
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
matcher: queueApi.endpoints.clearQueue.matchFulfilled,
|
||||
effect: () => {
|
||||
cancellations.clearQueue = true;
|
||||
const event = $lastCanvasProgressEvent.get();
|
||||
if (!event) {
|
||||
return;
|
||||
}
|
||||
const { session_id, batch_id, destination } = event;
|
||||
if (getIsCancelled({ session_id, batch_id, destination })) {
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -1,6 +1,7 @@
|
||||
import { ExternalLink } from '@invoke-ai/ui-library';
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { getIsCancelled } from 'app/store/middleware/listenerMiddleware/listeners/cancellationsListeners';
|
||||
import { $baseUrl } from 'app/store/nanostores/baseUrl';
|
||||
import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId';
|
||||
import { $queueId } from 'app/store/nanostores/queueId';
|
||||
@@ -39,7 +40,6 @@ export const $lastProgressEvent = atom<S['InvocationDenoiseProgressEvent'] | nul
|
||||
export const $lastCanvasProgressEvent = atom<S['InvocationDenoiseProgressEvent'] | null>(null);
|
||||
export const $hasProgress = computed($lastProgressEvent, (val) => Boolean(val));
|
||||
export const $progressImage = computed($lastProgressEvent, (val) => val?.progress_image ?? null);
|
||||
const cancellations = new Set<string>();
|
||||
|
||||
export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }: SetEventListenersArg) => {
|
||||
socket.on('connect', () => {
|
||||
@@ -54,7 +54,6 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
|
||||
}
|
||||
$lastProgressEvent.set(null);
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
cancellations.clear();
|
||||
});
|
||||
|
||||
socket.on('connect_error', (error) => {
|
||||
@@ -73,7 +72,6 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
|
||||
});
|
||||
}
|
||||
}
|
||||
cancellations.clear();
|
||||
});
|
||||
|
||||
socket.on('disconnect', () => {
|
||||
@@ -81,7 +79,6 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
|
||||
$lastProgressEvent.set(null);
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
setIsConnected(false);
|
||||
cancellations.clear();
|
||||
});
|
||||
|
||||
socket.on('invocation_started', (data) => {
|
||||
@@ -92,7 +89,6 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
|
||||
nes.status = zNodeStatus.enum.IN_PROGRESS;
|
||||
upsertExecutionState(nes.nodeId, nes);
|
||||
}
|
||||
cancellations.clear();
|
||||
});
|
||||
|
||||
socket.on('invocation_denoise_progress', (data) => {
|
||||
@@ -106,9 +102,10 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
|
||||
destination,
|
||||
percentage,
|
||||
session_id,
|
||||
batch_id,
|
||||
} = data;
|
||||
|
||||
if (cancellations.has(session_id)) {
|
||||
if (getIsCancelled({ session_id, batch_id, destination })) {
|
||||
// Do not update the progress if this session has been cancelled. This prevents a race condition where we get a
|
||||
// progress update after the session has been cancelled.
|
||||
return;
|
||||
@@ -131,7 +128,8 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
|
||||
}
|
||||
}
|
||||
|
||||
if (origin === 'generation' && destination === 'canvas') {
|
||||
// This event is only relevant for the canvas
|
||||
if (destination === 'canvas') {
|
||||
$lastCanvasProgressEvent.set(data);
|
||||
}
|
||||
});
|
||||
@@ -464,16 +462,13 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
|
||||
/>
|
||||
),
|
||||
});
|
||||
cancellations.add(session_id);
|
||||
} else if (status === 'canceled') {
|
||||
$lastProgressEvent.set(null);
|
||||
if (origin === 'canvas') {
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
}
|
||||
cancellations.add(session_id);
|
||||
} else if (status === 'completed') {
|
||||
$lastProgressEvent.set(null);
|
||||
cancellations.add(session_id);
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user