mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): use the new get_queue_counts_by_destination to control staging area
This commit is contained in:
committed by
Kent Keirsey
parent
bf3891092d
commit
7b9d8df1a7
@@ -5,11 +5,6 @@ import type { SerializableObject } from 'common/types';
|
||||
import type { Result } from 'common/util/result';
|
||||
import { withResult, withResultAsync } from 'common/util/result';
|
||||
import { $canvasManager } from 'features/controlLayers/store/canvasSlice';
|
||||
import {
|
||||
selectIsStaging,
|
||||
stagingAreaReset,
|
||||
stagingAreaStartedStaging,
|
||||
} from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
|
||||
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
|
||||
@@ -34,19 +29,6 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
const manager = $canvasManager.get();
|
||||
assert(manager, 'No model found in state');
|
||||
|
||||
let didStartStaging = false;
|
||||
|
||||
if (!selectIsStaging(state) && state.canvasSettings.sendToCanvas) {
|
||||
dispatch(stagingAreaStartedStaging());
|
||||
didStartStaging = true;
|
||||
}
|
||||
|
||||
const abortStaging = () => {
|
||||
if (didStartStaging && selectIsStaging(getState())) {
|
||||
dispatch(stagingAreaReset());
|
||||
}
|
||||
};
|
||||
|
||||
let buildGraphResult: Result<
|
||||
{
|
||||
g: Graph;
|
||||
@@ -76,7 +58,6 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
|
||||
if (buildGraphResult.isErr()) {
|
||||
log.error({ error: serializeError(buildGraphResult.error) }, 'Failed to build graph');
|
||||
abortStaging();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -90,7 +71,6 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
|
||||
if (prepareBatchResult.isErr()) {
|
||||
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
|
||||
abortStaging();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -105,7 +85,6 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
|
||||
if (enqueueResult.isErr()) {
|
||||
log.error({ error: serializeError(enqueueResult.error) }, 'Failed to enqueue batch');
|
||||
abortStaging();
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -2,8 +2,16 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { memo } from 'react';
|
||||
import { useGetQueueCountsByDestinationQuery } from 'services/api/endpoints/queue';
|
||||
|
||||
// This hook just serves as a persistent subscriber for the queue count query.
|
||||
const queueCountArg = { destination: 'canvas' };
|
||||
const useCanvasQueueCountWatcher = () => {
|
||||
useGetQueueCountsByDestinationQuery(queueCountArg);
|
||||
};
|
||||
|
||||
export const StagingAreaIsStagingGate = memo((props: PropsWithChildren) => {
|
||||
useCanvasQueueCountWatcher();
|
||||
const isStaging = useAppSelector(selectIsStaging);
|
||||
|
||||
if (!isStaging) {
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
import { addAppListener } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
|
||||
import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import {
|
||||
selectCanvasStagingAreaSlice,
|
||||
stagingAreaStartedStaging,
|
||||
} from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { selectCanvasStagingAreaSlice, selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import type { StagingAreaImage } from 'features/controlLayers/store/types';
|
||||
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
|
||||
import Konva from 'konva';
|
||||
@@ -43,17 +39,23 @@ export class CanvasStagingAreaModule extends CanvasModuleBase {
|
||||
this.image = null;
|
||||
this.selectedImage = null;
|
||||
|
||||
/**
|
||||
* When we change this flag, we need to re-render the staging area, which hides or shows the staged image.
|
||||
*/
|
||||
this.subscriptions.add(this.$shouldShowStagedImage.listen(this.render));
|
||||
/**
|
||||
* When the staging redux state changes (i.e. when the selected staged image is changed, or we add/discard a staged
|
||||
* image), we need to re-render the staging area.
|
||||
*/
|
||||
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectCanvasStagingAreaSlice, this.render));
|
||||
/**
|
||||
* Sync the $isStaging flag with the redux state. $isStaging is used by the manager to determine the global busy
|
||||
* state of the canvas.
|
||||
*/
|
||||
this.subscriptions.add(
|
||||
this.manager.stateApi.store.dispatch(
|
||||
addAppListener({
|
||||
actionCreator: stagingAreaStartedStaging,
|
||||
effect: () => {
|
||||
this.$shouldShowStagedImage.set(true);
|
||||
},
|
||||
})
|
||||
)
|
||||
this.manager.stateApi.createStoreSubscription(selectIsStaging, (isStaging) => {
|
||||
this.$isStaging.set(isStaging);
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
@@ -65,7 +67,6 @@ export class CanvasStagingAreaModule extends CanvasModuleBase {
|
||||
render = async () => {
|
||||
this.log.trace('Rendering staging area');
|
||||
const stagingArea = this.manager.stateApi.runSelector(selectCanvasStagingAreaSlice);
|
||||
this.$isStaging.set(stagingArea.isStaging);
|
||||
|
||||
const { x, y, width, height } = this.manager.stateApi.getBbox().rect;
|
||||
const shouldShowStagedImage = this.$shouldShowStagedImage.get();
|
||||
|
||||
@@ -3,15 +3,14 @@ import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { canvasReset } from 'features/controlLayers/store/canvasSlice';
|
||||
import type { StagingAreaImage } from 'features/controlLayers/store/types';
|
||||
import { selectCanvasQueueCounts } from 'services/api/endpoints/queue';
|
||||
|
||||
type CanvasStagingAreaState = {
|
||||
isStaging: boolean;
|
||||
stagedImages: StagingAreaImage[];
|
||||
selectedStagedImageIndex: number;
|
||||
};
|
||||
|
||||
const initialState: CanvasStagingAreaState = {
|
||||
isStaging: false,
|
||||
stagedImages: [],
|
||||
selectedStagedImageIndex: 0,
|
||||
};
|
||||
@@ -20,13 +19,8 @@ export const canvasStagingAreaSlice = createSlice({
|
||||
name: 'canvasStagingArea',
|
||||
initialState,
|
||||
reducers: {
|
||||
stagingAreaStartedStaging: (state) => {
|
||||
state.isStaging = true;
|
||||
state.selectedStagedImageIndex = 0;
|
||||
},
|
||||
stagingAreaImageStaged: (state, action: PayloadAction<{ stagingAreaImage: StagingAreaImage }>) => {
|
||||
const { stagingAreaImage } = action.payload;
|
||||
state.isStaging = true;
|
||||
state.stagedImages.push(stagingAreaImage);
|
||||
state.selectedStagedImageIndex = state.stagedImages.length - 1;
|
||||
},
|
||||
@@ -41,12 +35,8 @@ export const canvasStagingAreaSlice = createSlice({
|
||||
const { index } = action.payload;
|
||||
state.stagedImages.splice(index, 1);
|
||||
state.selectedStagedImageIndex = Math.min(state.selectedStagedImageIndex, state.stagedImages.length - 1);
|
||||
if (state.stagedImages.length === 0) {
|
||||
state.isStaging = false;
|
||||
}
|
||||
},
|
||||
stagingAreaReset: (state) => {
|
||||
state.isStaging = false;
|
||||
state.stagedImages = [];
|
||||
state.selectedStagedImageIndex = 0;
|
||||
},
|
||||
@@ -60,7 +50,6 @@ export const canvasStagingAreaSlice = createSlice({
|
||||
});
|
||||
|
||||
export const {
|
||||
stagingAreaStartedStaging,
|
||||
stagingAreaImageStaged,
|
||||
stagingAreaStagedImageDiscarded,
|
||||
stagingAreaReset,
|
||||
@@ -83,4 +72,21 @@ export const canvasStagingAreaPersistConfig: PersistConfig<CanvasStagingAreaStat
|
||||
|
||||
export const selectCanvasStagingAreaSlice = (s: RootState) => s.canvasStagingArea;
|
||||
|
||||
export const selectIsStaging = createSelector(selectCanvasStagingAreaSlice, (stagingaArea) => stagingaArea.isStaging);
|
||||
/**
|
||||
* Selects if we should be staging images. This is true if:
|
||||
* - There are staged images.
|
||||
* - There are any in-progress or pending canvas queue items.
|
||||
*/
|
||||
export const selectIsStaging = createSelector(
|
||||
selectCanvasQueueCounts,
|
||||
selectCanvasStagingAreaSlice,
|
||||
({ data }, staging) => {
|
||||
if (staging.stagedImages.length > 0) {
|
||||
return true;
|
||||
}
|
||||
if (!data) {
|
||||
return false;
|
||||
}
|
||||
return data.in_progress > 0 || data.pending > 0;
|
||||
}
|
||||
);
|
||||
|
||||
@@ -70,7 +70,7 @@ export const queueApi = api.injectEndpoints({
|
||||
body: arg,
|
||||
method: 'POST',
|
||||
}),
|
||||
invalidatesTags: ['CurrentSessionQueueItem', 'NextSessionQueueItem'],
|
||||
invalidatesTags: ['CurrentSessionQueueItem', 'NextSessionQueueItem', 'QueueCountsByDestination'],
|
||||
onQueryStarted: async (arg, api) => {
|
||||
const { dispatch, queryFulfilled } = api;
|
||||
try {
|
||||
@@ -163,6 +163,7 @@ export const queueApi = api.injectEndpoints({
|
||||
'BatchStatus',
|
||||
'CurrentSessionQueueItem',
|
||||
'NextSessionQueueItem',
|
||||
'QueueCountsByDestination',
|
||||
],
|
||||
onQueryStarted: async (arg, api) => {
|
||||
const { dispatch, queryFulfilled } = api;
|
||||
@@ -279,10 +280,14 @@ export const queueApi = api.injectEndpoints({
|
||||
if (!result) {
|
||||
return [];
|
||||
}
|
||||
return [
|
||||
const tags: ApiTagDescription[] = [
|
||||
{ type: 'SessionQueueItem', id: result.item_id },
|
||||
{ type: 'BatchStatus', id: result.batch_id },
|
||||
];
|
||||
if (result.destination) {
|
||||
tags.push({ type: 'QueueCountsByDestination', id: result.destination });
|
||||
}
|
||||
return tags;
|
||||
},
|
||||
}),
|
||||
cancelByBatchIds: build.mutation<
|
||||
@@ -303,7 +308,7 @@ export const queueApi = api.injectEndpoints({
|
||||
// no-op
|
||||
}
|
||||
},
|
||||
invalidatesTags: ['SessionQueueStatus', 'BatchStatus'],
|
||||
invalidatesTags: ['SessionQueueStatus', 'BatchStatus', 'QueueCountsByDestination'],
|
||||
}),
|
||||
cancelByBatchDestination: build.mutation<
|
||||
paths['/api/v1/queue/{queue_id}/cancel_by_destination']['put']['responses']['200']['content']['application/json'],
|
||||
@@ -323,7 +328,12 @@ export const queueApi = api.injectEndpoints({
|
||||
// no-op
|
||||
}
|
||||
},
|
||||
invalidatesTags: ['SessionQueueStatus', 'BatchStatus'],
|
||||
invalidatesTags: (result, error, { destination }) => {
|
||||
if (!result) {
|
||||
return [];
|
||||
}
|
||||
return ['SessionQueueStatus', 'BatchStatus', { type: 'QueueCountsByDestination', id: destination }];
|
||||
},
|
||||
}),
|
||||
listQueueItems: build.query<
|
||||
EntityState<components['schemas']['SessionQueueItemDTO'], string> & {
|
||||
@@ -353,6 +363,16 @@ export const queueApi = api.injectEndpoints({
|
||||
keepUnusedDataFor: 60 * 5, // 5 minutes
|
||||
providesTags: ['FetchOnReconnect'],
|
||||
}),
|
||||
getQueueCountsByDestination: build.query<
|
||||
paths['/api/v1/queue/{queue_id}/counts_by_destination']['get']['responses']['200']['content']['application/json'],
|
||||
paths['/api/v1/queue/{queue_id}/counts_by_destination']['get']['parameters']['query']
|
||||
>({
|
||||
query: (params) => ({ url: buildQueueUrl('counts_by_destination'), method: 'GET', params }),
|
||||
providesTags: (result, error, { destination }) => [
|
||||
'FetchOnReconnect',
|
||||
{ type: 'QueueCountsByDestination', id: destination },
|
||||
],
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
@@ -369,9 +389,11 @@ export const {
|
||||
useCancelQueueItemMutation,
|
||||
useGetBatchStatusQuery,
|
||||
useGetCurrentQueueItemQuery,
|
||||
useGetQueueCountsByDestinationQuery,
|
||||
} = queueApi;
|
||||
|
||||
export const selectQueueStatus = queueApi.endpoints.getQueueStatus.select();
|
||||
export const selectCanvasQueueCounts = queueApi.endpoints.getQueueCountsByDestination.select({ destination: 'canvas' });
|
||||
|
||||
const resetListQueryData = (
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
|
||||
@@ -42,6 +42,7 @@ const tagTypes = [
|
||||
'WorkflowsRecent',
|
||||
'StylePreset',
|
||||
'Schema',
|
||||
'QueueCountsByDestination',
|
||||
// This is invalidated on reconnect. It should be used for queries that have changing data,
|
||||
// especially related to the queue and generation.
|
||||
'FetchOnReconnect',
|
||||
|
||||
Reference in New Issue
Block a user