feat(ui): use the new get_queue_counts_by_destination to control staging area

This commit is contained in:
psychedelicious
2024-09-17 18:40:09 +10:00
committed by Kent Keirsey
parent bf3891092d
commit 7b9d8df1a7
6 changed files with 69 additions and 52 deletions

View File

@@ -5,11 +5,6 @@ import type { SerializableObject } from 'common/types';
import type { Result } from 'common/util/result';
import { withResult, withResultAsync } from 'common/util/result';
import { $canvasManager } from 'features/controlLayers/store/canvasSlice';
import {
selectIsStaging,
stagingAreaReset,
stagingAreaStartedStaging,
} from 'features/controlLayers/store/canvasStagingAreaSlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
@@ -34,19 +29,6 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
const manager = $canvasManager.get();
assert(manager, 'No model found in state');
let didStartStaging = false;
if (!selectIsStaging(state) && state.canvasSettings.sendToCanvas) {
dispatch(stagingAreaStartedStaging());
didStartStaging = true;
}
const abortStaging = () => {
if (didStartStaging && selectIsStaging(getState())) {
dispatch(stagingAreaReset());
}
};
let buildGraphResult: Result<
{
g: Graph;
@@ -76,7 +58,6 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
if (buildGraphResult.isErr()) {
log.error({ error: serializeError(buildGraphResult.error) }, 'Failed to build graph');
abortStaging();
return;
}
@@ -90,7 +71,6 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
if (prepareBatchResult.isErr()) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
abortStaging();
return;
}
@@ -105,7 +85,6 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
if (enqueueResult.isErr()) {
log.error({ error: serializeError(enqueueResult.error) }, 'Failed to enqueue batch');
abortStaging();
return;
}

View File

@@ -2,8 +2,16 @@ import { useAppSelector } from 'app/store/storeHooks';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import type { PropsWithChildren } from 'react';
import { memo } from 'react';
import { useGetQueueCountsByDestinationQuery } from 'services/api/endpoints/queue';
// This hook just serves as a persistent subscriber for the queue count query.
const queueCountArg = { destination: 'canvas' };
const useCanvasQueueCountWatcher = () => {
useGetQueueCountsByDestinationQuery(queueCountArg);
};
export const StagingAreaIsStagingGate = memo((props: PropsWithChildren) => {
useCanvasQueueCountWatcher();
const isStaging = useAppSelector(selectIsStaging);
if (!isStaging) {

View File

@@ -1,12 +1,8 @@
import { addAppListener } from 'app/store/middleware/listenerMiddleware';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import {
selectCanvasStagingAreaSlice,
stagingAreaStartedStaging,
} from 'features/controlLayers/store/canvasStagingAreaSlice';
import { selectCanvasStagingAreaSlice, selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import type { StagingAreaImage } from 'features/controlLayers/store/types';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import Konva from 'konva';
@@ -43,17 +39,23 @@ export class CanvasStagingAreaModule extends CanvasModuleBase {
this.image = null;
this.selectedImage = null;
/**
* When we change this flag, we need to re-render the staging area, which hides or shows the staged image.
*/
this.subscriptions.add(this.$shouldShowStagedImage.listen(this.render));
/**
* When the staging redux state changes (i.e. when the selected staged image is changed, or we add/discard a staged
* image), we need to re-render the staging area.
*/
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectCanvasStagingAreaSlice, this.render));
/**
* Sync the $isStaging flag with the redux state. $isStaging is used by the manager to determine the global busy
* state of the canvas.
*/
this.subscriptions.add(
this.manager.stateApi.store.dispatch(
addAppListener({
actionCreator: stagingAreaStartedStaging,
effect: () => {
this.$shouldShowStagedImage.set(true);
},
})
)
this.manager.stateApi.createStoreSubscription(selectIsStaging, (isStaging) => {
this.$isStaging.set(isStaging);
})
);
}
@@ -65,7 +67,6 @@ export class CanvasStagingAreaModule extends CanvasModuleBase {
render = async () => {
this.log.trace('Rendering staging area');
const stagingArea = this.manager.stateApi.runSelector(selectCanvasStagingAreaSlice);
this.$isStaging.set(stagingArea.isStaging);
const { x, y, width, height } = this.manager.stateApi.getBbox().rect;
const shouldShowStagedImage = this.$shouldShowStagedImage.get();

View File

@@ -3,15 +3,14 @@ import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { canvasReset } from 'features/controlLayers/store/canvasSlice';
import type { StagingAreaImage } from 'features/controlLayers/store/types';
import { selectCanvasQueueCounts } from 'services/api/endpoints/queue';
type CanvasStagingAreaState = {
isStaging: boolean;
stagedImages: StagingAreaImage[];
selectedStagedImageIndex: number;
};
const initialState: CanvasStagingAreaState = {
isStaging: false,
stagedImages: [],
selectedStagedImageIndex: 0,
};
@@ -20,13 +19,8 @@ export const canvasStagingAreaSlice = createSlice({
name: 'canvasStagingArea',
initialState,
reducers: {
stagingAreaStartedStaging: (state) => {
state.isStaging = true;
state.selectedStagedImageIndex = 0;
},
stagingAreaImageStaged: (state, action: PayloadAction<{ stagingAreaImage: StagingAreaImage }>) => {
const { stagingAreaImage } = action.payload;
state.isStaging = true;
state.stagedImages.push(stagingAreaImage);
state.selectedStagedImageIndex = state.stagedImages.length - 1;
},
@@ -41,12 +35,8 @@ export const canvasStagingAreaSlice = createSlice({
const { index } = action.payload;
state.stagedImages.splice(index, 1);
state.selectedStagedImageIndex = Math.min(state.selectedStagedImageIndex, state.stagedImages.length - 1);
if (state.stagedImages.length === 0) {
state.isStaging = false;
}
},
stagingAreaReset: (state) => {
state.isStaging = false;
state.stagedImages = [];
state.selectedStagedImageIndex = 0;
},
@@ -60,7 +50,6 @@ export const canvasStagingAreaSlice = createSlice({
});
export const {
stagingAreaStartedStaging,
stagingAreaImageStaged,
stagingAreaStagedImageDiscarded,
stagingAreaReset,
@@ -83,4 +72,21 @@ export const canvasStagingAreaPersistConfig: PersistConfig<CanvasStagingAreaStat
export const selectCanvasStagingAreaSlice = (s: RootState) => s.canvasStagingArea;
export const selectIsStaging = createSelector(selectCanvasStagingAreaSlice, (stagingaArea) => stagingaArea.isStaging);
/**
* Selects if we should be staging images. This is true if:
* - There are staged images.
* - There are any in-progress or pending canvas queue items.
*/
export const selectIsStaging = createSelector(
selectCanvasQueueCounts,
selectCanvasStagingAreaSlice,
({ data }, staging) => {
if (staging.stagedImages.length > 0) {
return true;
}
if (!data) {
return false;
}
return data.in_progress > 0 || data.pending > 0;
}
);

View File

@@ -70,7 +70,7 @@ export const queueApi = api.injectEndpoints({
body: arg,
method: 'POST',
}),
invalidatesTags: ['CurrentSessionQueueItem', 'NextSessionQueueItem'],
invalidatesTags: ['CurrentSessionQueueItem', 'NextSessionQueueItem', 'QueueCountsByDestination'],
onQueryStarted: async (arg, api) => {
const { dispatch, queryFulfilled } = api;
try {
@@ -163,6 +163,7 @@ export const queueApi = api.injectEndpoints({
'BatchStatus',
'CurrentSessionQueueItem',
'NextSessionQueueItem',
'QueueCountsByDestination',
],
onQueryStarted: async (arg, api) => {
const { dispatch, queryFulfilled } = api;
@@ -279,10 +280,14 @@ export const queueApi = api.injectEndpoints({
if (!result) {
return [];
}
return [
const tags: ApiTagDescription[] = [
{ type: 'SessionQueueItem', id: result.item_id },
{ type: 'BatchStatus', id: result.batch_id },
];
if (result.destination) {
tags.push({ type: 'QueueCountsByDestination', id: result.destination });
}
return tags;
},
}),
cancelByBatchIds: build.mutation<
@@ -303,7 +308,7 @@ export const queueApi = api.injectEndpoints({
// no-op
}
},
invalidatesTags: ['SessionQueueStatus', 'BatchStatus'],
invalidatesTags: ['SessionQueueStatus', 'BatchStatus', 'QueueCountsByDestination'],
}),
cancelByBatchDestination: build.mutation<
paths['/api/v1/queue/{queue_id}/cancel_by_destination']['put']['responses']['200']['content']['application/json'],
@@ -323,7 +328,12 @@ export const queueApi = api.injectEndpoints({
// no-op
}
},
invalidatesTags: ['SessionQueueStatus', 'BatchStatus'],
invalidatesTags: (result, error, { destination }) => {
if (!result) {
return [];
}
return ['SessionQueueStatus', 'BatchStatus', { type: 'QueueCountsByDestination', id: destination }];
},
}),
listQueueItems: build.query<
EntityState<components['schemas']['SessionQueueItemDTO'], string> & {
@@ -353,6 +363,16 @@ export const queueApi = api.injectEndpoints({
keepUnusedDataFor: 60 * 5, // 5 minutes
providesTags: ['FetchOnReconnect'],
}),
getQueueCountsByDestination: build.query<
paths['/api/v1/queue/{queue_id}/counts_by_destination']['get']['responses']['200']['content']['application/json'],
paths['/api/v1/queue/{queue_id}/counts_by_destination']['get']['parameters']['query']
>({
query: (params) => ({ url: buildQueueUrl('counts_by_destination'), method: 'GET', params }),
providesTags: (result, error, { destination }) => [
'FetchOnReconnect',
{ type: 'QueueCountsByDestination', id: destination },
],
}),
}),
});
@@ -369,9 +389,11 @@ export const {
useCancelQueueItemMutation,
useGetBatchStatusQuery,
useGetCurrentQueueItemQuery,
useGetQueueCountsByDestinationQuery,
} = queueApi;
export const selectQueueStatus = queueApi.endpoints.getQueueStatus.select();
export const selectCanvasQueueCounts = queueApi.endpoints.getQueueCountsByDestination.select({ destination: 'canvas' });
const resetListQueryData = (
// eslint-disable-next-line @typescript-eslint/no-explicit-any

View File

@@ -42,6 +42,7 @@ const tagTypes = [
'WorkflowsRecent',
'StylePreset',
'Schema',
'QueueCountsByDestination',
// This is invalidated on reconnect. It should be used for queries that have changing data,
// especially related to the queue and generation.
'FetchOnReconnect',