fix(ui): race conditions with progress events

There's a race condition where we sometimes get progress events from canceled queue items, depending on the timing of the cancellation request and last event or two from the queue item.

I can't imagine how to resolve this except by tracking all cancellations and ignoring events for cancelled items, which is implemented in this change.
This commit is contained in:
psychedelicious
2024-09-07 21:33:04 +10:00
parent 2a1bc3e044
commit 2bdfc340aa
4 changed files with 145 additions and 13 deletions

View File

@@ -9,6 +9,7 @@ import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddlewar
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
import { addCancellationsListeners } from 'app/store/middleware/listenerMiddleware/listeners/cancellationsListeners';
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
import { addEnqueueRequestedNodes } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes';
import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
@@ -119,3 +120,5 @@ addDynamicPromptsListener(startAppListening);
addSetDefaultSettingsListener(startAppListening);
// addControlAdapterPreprocessor(startAppListening);
addCancellationsListeners(startAppListening);

View File

@@ -12,7 +12,6 @@ import { imageDTOToImageObject } from 'features/controlLayers/store/types';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue';
import { $lastCanvasProgressEvent } from 'services/events/setEventListeners';
import { assert } from 'tsafe';
const log = logger('canvas');
@@ -33,8 +32,6 @@ export const addStagingListeners = (startAppListening: AppStartListening) => {
const { canceled } = await req.unwrap();
req.reset();
$lastCanvasProgressEvent.set(null);
if (canceled > 0) {
log.debug(`Canceled ${canceled} canvas batches`);
toast({

View File

@@ -0,0 +1,137 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { queueApi } from 'services/api/endpoints/queue';
import { $lastCanvasProgressEvent } from 'services/events/setEventListeners';
/**
* To prevent a race condition where a progress event arrives after a successful cancellation, we need to keep track of
* cancellations:
* - In the route handlers above, we track and update the cancellations object
* - When the user queues a, we should reset the cancellations, also handled int he route handlers above
* - When we get a progress event, we should check if the event is cancelled before setting the event
*
* We have a few ways that cancellations are effected, so we need to track them all:
* - by queue item id (in this case, we will compare the session_id and not the item_id)
* - by batch id
* - by destination
* - by clearing the queue
*/
type Cancellations = {
sessionIds: Set<string>;
batchIds: Set<string>;
destinations: Set<string>;
clearQueue: boolean;
};
const resetCancellations = (): void => {
cancellations.clearQueue = false;
cancellations.sessionIds.clear();
cancellations.batchIds.clear();
cancellations.destinations.clear();
};
const cancellations: Cancellations = {
sessionIds: new Set(),
batchIds: new Set(),
destinations: new Set(),
clearQueue: false,
} as Readonly<Cancellations>;
/**
* Checks if an item is cancelled, used to prevent race conditions with event handling.
*
* To use this, provide the session_id, batch_id and destination from the event payload.
*/
export const getIsCancelled = (item: {
session_id: string;
batch_id: string;
destination?: string | null;
}): boolean => {
if (cancellations.clearQueue) {
return true;
}
if (cancellations.sessionIds.has(item.session_id)) {
return true;
}
if (cancellations.batchIds.has(item.batch_id)) {
return true;
}
if (item.destination && cancellations.destinations.has(item.destination)) {
return true;
}
return false;
};
export const addCancellationsListeners = (startAppListening: AppStartListening) => {
// When we get a cancellation, we may need to clear the last progress event - next few listeners handle those cases.
// Maybe we could use the `getIsCancelled` util here, but I think that could introduce _another_ race condition...
startAppListening({
matcher: queueApi.endpoints.enqueueBatch.matchFulfilled,
effect: () => {
resetCancellations();
},
});
startAppListening({
matcher: queueApi.endpoints.cancelByBatchDestination.matchFulfilled,
effect: (action) => {
cancellations.destinations.add(action.meta.arg.originalArgs.destination);
const event = $lastCanvasProgressEvent.get();
if (!event) {
return;
}
const { session_id, batch_id, destination } = event;
if (getIsCancelled({ session_id, batch_id, destination })) {
$lastCanvasProgressEvent.set(null);
}
},
});
startAppListening({
matcher: queueApi.endpoints.cancelQueueItem.matchFulfilled,
effect: (action) => {
cancellations.sessionIds.add(action.payload.session_id);
const event = $lastCanvasProgressEvent.get();
if (!event) {
return;
}
const { session_id, batch_id, destination } = event;
if (getIsCancelled({ session_id, batch_id, destination })) {
$lastCanvasProgressEvent.set(null);
}
},
});
startAppListening({
matcher: queueApi.endpoints.cancelByBatchIds.matchFulfilled,
effect: (action) => {
for (const batch_id of action.meta.arg.originalArgs.batch_ids) {
cancellations.batchIds.add(batch_id);
}
const event = $lastCanvasProgressEvent.get();
if (!event) {
return;
}
const { session_id, batch_id, destination } = event;
if (getIsCancelled({ session_id, batch_id, destination })) {
$lastCanvasProgressEvent.set(null);
}
},
});
startAppListening({
matcher: queueApi.endpoints.clearQueue.matchFulfilled,
effect: () => {
cancellations.clearQueue = true;
const event = $lastCanvasProgressEvent.get();
if (!event) {
return;
}
const { session_id, batch_id, destination } = event;
if (getIsCancelled({ session_id, batch_id, destination })) {
$lastCanvasProgressEvent.set(null);
}
},
});
};

View File

@@ -1,6 +1,7 @@
import { ExternalLink } from '@invoke-ai/ui-library';
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import { getIsCancelled } from 'app/store/middleware/listenerMiddleware/listeners/cancellationsListeners';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId';
import { $queueId } from 'app/store/nanostores/queueId';
@@ -39,7 +40,6 @@ export const $lastProgressEvent = atom<S['InvocationDenoiseProgressEvent'] | nul
export const $lastCanvasProgressEvent = atom<S['InvocationDenoiseProgressEvent'] | null>(null);
export const $hasProgress = computed($lastProgressEvent, (val) => Boolean(val));
export const $progressImage = computed($lastProgressEvent, (val) => val?.progress_image ?? null);
const cancellations = new Set<string>();
export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }: SetEventListenersArg) => {
socket.on('connect', () => {
@@ -54,7 +54,6 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
}
$lastProgressEvent.set(null);
$lastCanvasProgressEvent.set(null);
cancellations.clear();
});
socket.on('connect_error', (error) => {
@@ -73,7 +72,6 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
});
}
}
cancellations.clear();
});
socket.on('disconnect', () => {
@@ -81,7 +79,6 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
$lastProgressEvent.set(null);
$lastCanvasProgressEvent.set(null);
setIsConnected(false);
cancellations.clear();
});
socket.on('invocation_started', (data) => {
@@ -92,7 +89,6 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
nes.status = zNodeStatus.enum.IN_PROGRESS;
upsertExecutionState(nes.nodeId, nes);
}
cancellations.clear();
});
socket.on('invocation_denoise_progress', (data) => {
@@ -106,9 +102,10 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
destination,
percentage,
session_id,
batch_id,
} = data;
if (cancellations.has(session_id)) {
if (getIsCancelled({ session_id, batch_id, destination })) {
// Do not update the progress if this session has been cancelled. This prevents a race condition where we get a
// progress update after the session has been cancelled.
return;
@@ -131,7 +128,8 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
}
}
if (origin === 'generation' && destination === 'canvas') {
// This event is only relevant for the canvas
if (destination === 'canvas') {
$lastCanvasProgressEvent.set(data);
}
});
@@ -464,16 +462,13 @@ export const setEventListeners = ({ socket, dispatch, getState, setIsConnected }
/>
),
});
cancellations.add(session_id);
} else if (status === 'canceled') {
$lastProgressEvent.set(null);
if (origin === 'canvas') {
$lastCanvasProgressEvent.set(null);
}
cancellations.add(session_id);
} else if (status === 'completed') {
$lastProgressEvent.set(null);
cancellations.add(session_id);
}
});