mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
fix(ui): ensure staging area always has the right state and session association
This commit is contained in:
committed by
Kent Keirsey
parent
bed01941a5
commit
307259f096
@@ -15,7 +15,7 @@ import { selectBboxRect, selectSelectedEntityIdentifier } from 'features/control
|
||||
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
|
||||
import { imageNameToImageObject } from 'features/controlLayers/store/util';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { createContext, memo, useContext, useEffect, useMemo } from 'react';
|
||||
import { createContext, memo, useContext, useEffect, useMemo, useState } from 'react';
|
||||
import { getImageDTOSafe } from 'services/api/endpoints/images';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import type { S } from 'services/api/types';
|
||||
@@ -94,18 +94,24 @@ export const StagingAreaContextProvider = memo(({ children, sessionId }: PropsWi
|
||||
|
||||
return _stagingAreaAppApi;
|
||||
}, [sessionId, socket, store]);
|
||||
const value = useMemo(() => {
|
||||
return new StagingAreaApi(sessionId, stagingAreaAppApi);
|
||||
}, [sessionId, stagingAreaAppApi]);
|
||||
|
||||
const [stagingAreaApi] = useState(() => new StagingAreaApi());
|
||||
|
||||
useEffect(() => {
|
||||
const api = value;
|
||||
return () => {
|
||||
api.cleanup();
|
||||
};
|
||||
}, [value]);
|
||||
stagingAreaApi.connectToApp(sessionId, stagingAreaAppApi);
|
||||
|
||||
return <StagingAreaContext.Provider value={value}>{children}</StagingAreaContext.Provider>;
|
||||
// We need to subscribe to the queue items query manually to ensure the staging area actually gets the items
|
||||
const { unsubscribe: unsubQueueItemsQuery } = store.dispatch(
|
||||
queueApi.endpoints.listAllQueueItems.initiate({ destination: sessionId })
|
||||
);
|
||||
|
||||
return () => {
|
||||
stagingAreaApi.cleanup();
|
||||
unsubQueueItemsQuery();
|
||||
};
|
||||
}, [sessionId, stagingAreaApi, stagingAreaAppApi, store]);
|
||||
|
||||
return <StagingAreaContext.Provider value={stagingAreaApi}>{children}</StagingAreaContext.Provider>;
|
||||
});
|
||||
StagingAreaContextProvider.displayName = 'StagingAreaContextProvider';
|
||||
|
||||
|
||||
@@ -16,7 +16,8 @@ describe('StagingAreaApi', () => {
|
||||
|
||||
beforeEach(() => {
|
||||
mockApp = createMockStagingAreaApp();
|
||||
api = new StagingAreaApi(sessionId, mockApp);
|
||||
api = new StagingAreaApi();
|
||||
api.connectToApp(sessionId, mockApp);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -25,7 +26,7 @@ describe('StagingAreaApi', () => {
|
||||
|
||||
describe('Constructor and Setup', () => {
|
||||
it('should initialize with correct session ID', () => {
|
||||
expect(api.sessionId).toBe(sessionId);
|
||||
expect(api._sessionId).toBe(sessionId);
|
||||
});
|
||||
|
||||
it('should set up event subscriptions', () => {
|
||||
@@ -747,8 +748,10 @@ describe('StagingAreaApi', () => {
|
||||
|
||||
describe('Event Subscription Management', () => {
|
||||
it('should handle multiple subscriptions and unsubscriptions', () => {
|
||||
const api2 = new StagingAreaApi(sessionId, mockApp);
|
||||
const api3 = new StagingAreaApi(sessionId, mockApp);
|
||||
const api2 = new StagingAreaApi();
|
||||
api2.connectToApp(sessionId, mockApp);
|
||||
const api3 = new StagingAreaApi();
|
||||
api3.connectToApp(sessionId, mockApp);
|
||||
|
||||
// All should be subscribed
|
||||
expect(mockApp.onItemsChanged).toHaveBeenCalledTimes(3);
|
||||
|
||||
@@ -61,19 +61,15 @@ type ProgressDataMap = Record<number, ProgressData | undefined>;
|
||||
* and configure auto-switching behavior.
|
||||
*/
|
||||
export class StagingAreaApi {
|
||||
sessionId: string;
|
||||
_app: StagingAreaAppApi;
|
||||
/** The current session ID. */
|
||||
_sessionId: string | null = null;
|
||||
|
||||
/** The app API */
|
||||
_app: StagingAreaAppApi | null = null;
|
||||
|
||||
/** A set of subscriptions to be cleaned up when we are finished with a session */
|
||||
_subscriptions = new Set<() => void>();
|
||||
|
||||
constructor(sessionId: string, app: StagingAreaAppApi) {
|
||||
this.sessionId = sessionId;
|
||||
this._app = app;
|
||||
|
||||
this._subscriptions.add(this._app.onItemsChanged(this.onItemsChangedEvent));
|
||||
this._subscriptions.add(this._app.onQueueItemStatusChanged(this.onQueueItemStatusChangedEvent));
|
||||
this._subscriptions.add(this._app.onInvocationProgress(this.onInvocationProgressEvent));
|
||||
}
|
||||
|
||||
/** Item ID of the last started item. Used for auto-switch on start. */
|
||||
$lastStartedItemId = atom<number | null>(null);
|
||||
|
||||
@@ -136,7 +132,7 @@ export class StagingAreaApi {
|
||||
/** Selects a queue item by ID. */
|
||||
select = (itemId: number) => {
|
||||
this.$selectedItemId.set(itemId);
|
||||
this._app.onSelect?.(itemId);
|
||||
this._app?.onSelect?.(itemId);
|
||||
};
|
||||
|
||||
/** Selects the next item in the queue, wrapping to the first item if at the end. */
|
||||
@@ -152,7 +148,7 @@ export class StagingAreaApi {
|
||||
return;
|
||||
}
|
||||
this.$selectedItemId.set(nextItem.item_id);
|
||||
this._app.onSelectNext?.();
|
||||
this._app?.onSelectNext?.();
|
||||
};
|
||||
|
||||
/** Selects the previous item in the queue, wrapping to the last item if at the beginning. */
|
||||
@@ -168,7 +164,7 @@ export class StagingAreaApi {
|
||||
return;
|
||||
}
|
||||
this.$selectedItemId.set(prevItem.item_id);
|
||||
this._app.onSelectPrev?.();
|
||||
this._app?.onSelectPrev?.();
|
||||
};
|
||||
|
||||
/** Selects the first item in the queue. */
|
||||
@@ -179,7 +175,7 @@ export class StagingAreaApi {
|
||||
return;
|
||||
}
|
||||
this.$selectedItemId.set(first.item_id);
|
||||
this._app.onSelectFirst?.();
|
||||
this._app?.onSelectFirst?.();
|
||||
};
|
||||
|
||||
/** Selects the last item in the queue. */
|
||||
@@ -190,7 +186,7 @@ export class StagingAreaApi {
|
||||
return;
|
||||
}
|
||||
this.$selectedItemId.set(last.item_id);
|
||||
this._app.onSelectLast?.();
|
||||
this._app?.onSelectLast?.();
|
||||
};
|
||||
|
||||
/** Discards the currently selected item and selects the next available item. */
|
||||
@@ -207,7 +203,7 @@ export class StagingAreaApi {
|
||||
} else {
|
||||
this.$selectedItemId.set(null);
|
||||
}
|
||||
this._app.onDiscard?.(selectedItem.item);
|
||||
this._app?.onDiscard?.(selectedItem.item);
|
||||
};
|
||||
|
||||
/** Whether the discard selected action is enabled. */
|
||||
@@ -218,10 +214,23 @@ export class StagingAreaApi {
|
||||
return true;
|
||||
});
|
||||
|
||||
/** Connects to the app, registering listeners and such */
|
||||
connectToApp = (sessionId: string, app: StagingAreaAppApi) => {
|
||||
if (this._sessionId !== sessionId) {
|
||||
this.cleanup();
|
||||
this._sessionId = sessionId;
|
||||
}
|
||||
this._app = app;
|
||||
|
||||
this._subscriptions.add(this._app.onItemsChanged(this.onItemsChangedEvent));
|
||||
this._subscriptions.add(this._app.onQueueItemStatusChanged(this.onQueueItemStatusChangedEvent));
|
||||
this._subscriptions.add(this._app.onInvocationProgress(this.onInvocationProgressEvent));
|
||||
};
|
||||
|
||||
/** Discards all items in the queue. */
|
||||
discardAll = () => {
|
||||
this.$selectedItemId.set(null);
|
||||
this._app.onDiscardAll?.();
|
||||
this._app?.onDiscardAll?.();
|
||||
};
|
||||
|
||||
/** Accepts the currently selected item if an image is available. */
|
||||
@@ -235,7 +244,7 @@ export class StagingAreaApi {
|
||||
if (!datum || !datum.imageDTO) {
|
||||
return;
|
||||
}
|
||||
this._app.onAccept?.(selectedItem.item, datum.imageDTO);
|
||||
this._app?.onAccept?.(selectedItem.item, datum.imageDTO);
|
||||
};
|
||||
|
||||
/** Whether the accept selected action is enabled. */
|
||||
@@ -249,12 +258,12 @@ export class StagingAreaApi {
|
||||
|
||||
/** Sets the auto-switch mode. */
|
||||
setAutoSwitch = (mode: AutoSwitchMode) => {
|
||||
this._app.onAutoSwitchChange?.(mode);
|
||||
this._app?.onAutoSwitchChange?.(mode);
|
||||
};
|
||||
|
||||
/** Handles invocation progress events from the WebSocket. */
|
||||
onInvocationProgressEvent = (data: S['InvocationProgressEvent']) => {
|
||||
if (data.destination !== this.sessionId) {
|
||||
if (data.destination !== this._sessionId) {
|
||||
return;
|
||||
}
|
||||
setProgress(this.$progressData, data);
|
||||
@@ -262,7 +271,7 @@ export class StagingAreaApi {
|
||||
|
||||
/** Handles queue item status change events from the WebSocket. */
|
||||
onQueueItemStatusChangedEvent = (data: S['QueueItemStatusChangedEvent']) => {
|
||||
if (data.destination !== this.sessionId) {
|
||||
if (data.destination !== this._sessionId) {
|
||||
return;
|
||||
}
|
||||
if (data.status === 'completed') {
|
||||
@@ -277,7 +286,7 @@ export class StagingAreaApi {
|
||||
*/
|
||||
this.$lastCompletedItemId.set(data.item_id);
|
||||
}
|
||||
if (data.status === 'in_progress' && this._app.getAutoSwitch() === 'switch_on_start') {
|
||||
if (data.status === 'in_progress' && this._app?.getAutoSwitch() === 'switch_on_start') {
|
||||
this.$lastStartedItemId.set(data.item_id);
|
||||
}
|
||||
};
|
||||
@@ -327,7 +336,7 @@ export class StagingAreaApi {
|
||||
for (const item of items) {
|
||||
const datum = progressData[item.item_id];
|
||||
|
||||
if (this.$lastStartedItemId.get() === item.item_id && this._app.getAutoSwitch() === 'switch_on_start') {
|
||||
if (this.$lastStartedItemId.get() === item.item_id && this._app?.getAutoSwitch() === 'switch_on_start') {
|
||||
this.$selectedItemId.set(item.item_id);
|
||||
this.$lastStartedItemId.set(null);
|
||||
}
|
||||
@@ -339,13 +348,13 @@ export class StagingAreaApi {
|
||||
if (!outputImageName) {
|
||||
continue;
|
||||
}
|
||||
const imageDTO = await this._app.getImageDTO(outputImageName);
|
||||
const imageDTO = await this._app?.getImageDTO(outputImageName);
|
||||
if (!imageDTO) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// This is the load logic mentioned in the comment in the QueueItemStatusChangedEvent handler above.
|
||||
if (this.$lastCompletedItemId.get() === item.item_id && this._app.getAutoSwitch() === 'switch_on_finish') {
|
||||
if (this.$lastCompletedItemId.get() === item.item_id && this._app?.getAutoSwitch() === 'switch_on_finish') {
|
||||
this._app.loadImage(imageDTO.image_url).then(() => {
|
||||
this.$selectedItemId.set(item.item_id);
|
||||
this.$lastCompletedItemId.set(null);
|
||||
|
||||
Reference in New Issue
Block a user