feat(ui): add progress slice and move progress to it

This slice tracks denoising progress for linear, canvas and workflow tabs.

`ViewerProgress` now uses it for showing progress images.
This commit is contained in:
psychedelicious
2024-02-02 22:05:00 +11:00
parent 24d67c77e1
commit b02e11d2b5
19 changed files with 440 additions and 102 deletions

View File

@@ -1,25 +1,25 @@
import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import { canvasBatchIdsReset, commitStagingAreaImage, discardStagedImages } from 'features/canvas/store/canvasSlice';
import { matchAnyStagingAreaDismissed } from 'features/canvas/store/canvasSlice';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue';
import { startAppListening } from '..';
const matcher = isAnyOf(commitStagingAreaImage, discardStagedImages);
export const addCommitStagingAreaImageListener = () => {
startAppListening({
matcher,
matcher: matchAnyStagingAreaDismissed,
effect: async (_, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
const { batchIds } = state.canvas;
const { canvasBatchIds } = state.progress;
try {
const req = dispatch(
queueApi.endpoints.cancelByBatchIds.initiate({ batch_ids: batchIds }, { fixedCacheKey: 'cancelByBatchIds' })
queueApi.endpoints.cancelByBatchIds.initiate(
{ batch_ids: canvasBatchIds },
{ fixedCacheKey: 'cancelByBatchIds' }
)
);
const { canceled } = await req.unwrap();
req.reset();
@@ -32,7 +32,6 @@ export const addCommitStagingAreaImageListener = () => {
})
);
}
dispatch(canvasBatchIdsReset());
} catch {
log.error('Failed to cancel canvas batches');
dispatch(

View File

@@ -2,13 +2,14 @@ import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import { parseify } from 'common/util/serialize';
import { canvasBatchIdAdded, stagingAreaInitialized } from 'features/canvas/store/canvasSlice';
import { stagingAreaInitialized } from 'features/canvas/store/canvasSlice';
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
import { canvasGraphBuilt } from 'features/nodes/store/actions';
import { buildCanvasGraph } from 'features/nodes/util/graph/buildCanvasGraph';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { canvasBatchEnqueued } from 'features/progress/store/progressSlice';
import { imagesApi } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { ImageDTO } from 'services/api/types';
@@ -121,8 +122,6 @@ export const addEnqueueRequestedCanvasListener = () => {
const enqueueResult = await req.unwrap();
req.reset();
const batchId = enqueueResult.batch.batch_id as string; // we know the is a string, backend provides it
// Prep the canvas staging area if it is not yet initialized
if (!state.canvas.layerState.stagingArea.boundingBox) {
dispatch(
@@ -135,8 +134,9 @@ export const addEnqueueRequestedCanvasListener = () => {
);
}
// Associate the session with the canvas session ID
dispatch(canvasBatchIdAdded(batchId));
if (enqueueResult.batch.batch_id) {
dispatch(canvasBatchEnqueued(enqueueResult.batch.batch_id));
}
} catch {
// no-op
}

View File

@@ -4,6 +4,7 @@ import { buildLinearImageToImageGraph } from 'features/nodes/util/graph/buildLin
import { buildLinearSDXLImageToImageGraph } from 'features/nodes/util/graph/buildLinearSDXLImageToImageGraph';
import { buildLinearSDXLTextToImageGraph } from 'features/nodes/util/graph/buildLinearSDXLTextToImageGraph';
import { buildLinearTextToImageGraph } from 'features/nodes/util/graph/buildLinearTextToImageGraph';
import { linearBatchEnqueued } from 'features/progress/store/progressSlice';
import { queueApi } from 'services/api/endpoints/queue';
import { startAppListening } from '..';
@@ -35,12 +36,20 @@ export const addEnqueueRequestedLinear = () => {
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
fixedCacheKey: 'enqueueBatch',
})
);
req.reset();
try {
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
fixedCacheKey: 'enqueueBatch',
})
);
const enqueueResult = await req.unwrap();
req.reset();
if (enqueueResult.batch.batch_id) {
dispatch(linearBatchEnqueued(enqueueResult.batch.batch_id));
}
} catch {
// no-op
}
},
});
};

View File

@@ -1,6 +1,7 @@
import { enqueueRequested } from 'app/store/actions';
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
import { workflowBatchEnqueued } from 'features/progress/store/progressSlice';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig } from 'services/api/types';
@@ -35,12 +36,20 @@ export const addEnqueueRequestedNodes = () => {
prepend: action.payload.prepend,
};
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
fixedCacheKey: 'enqueueBatch',
})
);
req.reset();
try {
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
fixedCacheKey: 'enqueueBatch',
})
);
const enqueueResult = await req.unwrap();
req.reset();
if (enqueueResult.batch.batch_id) {
dispatch(workflowBatchEnqueued(enqueueResult.batch.batch_id));
}
} catch {
// no-op
}
},
});
};

View File

@@ -5,6 +5,7 @@ import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gal
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { isImageOutput } from 'features/nodes/types/common';
import { LINEAR_UI_OUTPUT, nodeIDDenyList } from 'features/nodes/util/graph/constants';
import { imageInvocationComplete } from 'features/progress/store/progressSlice';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import { imagesAdapter } from 'services/api/util';
@@ -29,7 +30,7 @@ export const addInvocationCompleteEventListener = () => {
// This complete event has an associated image output
if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type) && !nodeIDDenyList.includes(source_node_id)) {
const { image_name } = result.image;
const { canvas, gallery } = getState();
const { gallery, progress } = getState();
// This populates the `getImageDTO` cache
const imageDTORequest = dispatch(
@@ -41,8 +42,10 @@ export const addInvocationCompleteEventListener = () => {
const imageDTO = await imageDTORequest.unwrap();
imageDTORequest.unsubscribe();
dispatch(imageInvocationComplete({ data, imageDTO }));
// Add canvas images to the staging area
if (canvas.batchIds.includes(queue_batch_id) && [LINEAR_UI_OUTPUT].includes(data.source_node_id)) {
if (progress.canvasBatchIds.includes(queue_batch_id) && [LINEAR_UI_OUTPUT].includes(data.source_node_id)) {
dispatch(addImageToStagingArea(imageDTO));
}

View File

@@ -20,6 +20,7 @@ import { nodesTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice';
import { progressPersistConfig, progressSlice } from 'features/progress/store/progressSlice';
import { queueSlice } from 'features/queue/store/queueSlice';
import { sdxlPersistConfig, sdxlSlice } from 'features/sdxl/store/sdxlSlice';
import { configSlice } from 'features/system/store/configSlice';
@@ -63,6 +64,7 @@ const allReducers = {
[workflowSlice.name]: workflowSlice.reducer,
[hrfSlice.name]: hrfSlice.reducer,
[viewerSlice.name]: viewerSlice.reducer,
[progressSlice.name]: progressSlice.reducer,
[api.reducerPath]: api.reducer,
};
@@ -108,6 +110,7 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[modelManagerPersistConfig.name]: modelManagerPersistConfig,
[hrfPersistConfig.name]: hrfPersistConfig,
[viewerPersistConfig.name]: viewerPersistConfig,
[progressPersistConfig.name]: progressPersistConfig,
};
const unserialize: UnserializeFunction = (data, key) => {

View File

@@ -1,17 +1,13 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectCanvasSlice } from 'features/canvas/store/canvasSlice';
import { selectSystemSlice } from 'features/system/store/systemSlice';
import { selectProgressSlice } from 'features/progress/store/progressSlice';
import { memo, useEffect, useState } from 'react';
import { Image as KonvaImage } from 'react-konva';
const progressImageSelector = createMemoizedSelector([selectSystemSlice, selectCanvasSlice], (system, canvas) => {
const { denoiseProgress } = system;
const { batchIds } = canvas;
const progressImageSelector = createMemoizedSelector([selectProgressSlice, selectCanvasSlice], (progress, canvas) => {
return {
progressImage:
denoiseProgress && batchIds.includes(denoiseProgress.batch_id) ? denoiseProgress.progress_image : undefined,
progressImage: progress.canvasDenoiseProgress?.progress_image,
boundingBox: canvas.layerState.stagingArea.boundingBox,
};
});

View File

@@ -1,12 +1,14 @@
import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { selectProgressSlice } from 'features/progress/store/progressSlice';
import { selectCanvasSlice } from './canvasSlice';
import { isCanvasBaseImage } from './canvasTypes';
export const isStagingSelector = createSelector(
selectProgressSlice,
selectCanvasSlice,
(canvas) => canvas.batchIds.length > 0 || canvas.layerState.stagingArea.images.length > 0
(progress, canvas) => progress.canvasBatchIds.length > 0 || canvas.layerState.stagingArea.images.length > 0
);
export const initialCanvasImageSelector = createMemoizedSelector(selectCanvasSlice, (canvas) =>

View File

@@ -1,5 +1,5 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMultiple';
import calculateCoordinates from 'features/canvas/util/calculateCoordinates';
@@ -15,9 +15,7 @@ import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/
import type { IRect, Vector2d } from 'konva/lib/types';
import { clamp, cloneDeep } from 'lodash-es';
import type { RgbaColor } from 'react-colorful';
import { queueApi } from 'services/api/endpoints/queue';
import type { ImageDTO } from 'services/api/types';
import { socketQueueItemStatusChanged } from 'services/events/actions';
import type {
BoundingBoxScaleMethod,
@@ -79,7 +77,6 @@ export const initialCanvasState: CanvasState = {
stageCoordinates: { x: 0, y: 0 },
stageDimensions: { width: 0, height: 0 },
stageScale: 1,
batchIds: [],
aspectRatio: {
id: '1:1',
value: 1,
@@ -180,7 +177,6 @@ export const canvasSlice = createSlice({
],
};
state.futureLayerStates = [];
state.batchIds = [];
const newScale = calculateScale(
stageDimensions.width,
@@ -237,12 +233,6 @@ export const canvasSlice = createSlice({
setShouldShowBoundingBox: (state, action: PayloadAction<boolean>) => {
state.shouldShowBoundingBox = action.payload;
},
canvasBatchIdAdded: (state, action: PayloadAction<string>) => {
state.batchIds.push(action.payload);
},
canvasBatchIdsReset: (state) => {
state.batchIds = [];
},
stagingAreaInitialized: (
state,
action: PayloadAction<{
@@ -293,7 +283,6 @@ export const canvasSlice = createSlice({
state.futureLayerStates = [];
state.shouldShowStagingOutline = true;
state.shouldShowStagingImage = true;
state.batchIds = [];
},
addFillRect: (state) => {
const { boundingBoxCoordinates, boundingBoxDimensions, brushColor } = state;
@@ -426,7 +415,6 @@ export const canvasSlice = createSlice({
state.pastLayerStates.push(cloneDeep(state.layerState));
state.layerState = cloneDeep(initialLayerState);
state.futureLayerStates = [];
state.batchIds = [];
state.boundingBoxCoordinates = {
...initialCanvasState.boundingBoxCoordinates,
};
@@ -536,7 +524,6 @@ export const canvasSlice = createSlice({
state.futureLayerStates = [];
state.shouldShowStagingOutline = true;
state.shouldShowStagingImage = true;
state.batchIds = [];
},
setBoundingBoxScaleMethod: {
reducer: (state, action: PayloadActionWithOptimalDimension<BoundingBoxScaleMethod>) => {
@@ -644,23 +631,6 @@ export const canvasSlice = createSlice({
optimalDimension
);
});
builder.addCase(socketQueueItemStatusChanged, (state, action) => {
const batch_status = action.payload.data.batch_status;
if (!state.batchIds.includes(batch_status.batch_id)) {
return;
}
if (batch_status.in_progress === 0 && batch_status.pending === 0) {
state.batchIds = state.batchIds.filter((id) => id !== batch_status.batch_id);
}
});
builder.addMatcher(queueApi.endpoints.clearQueue.matchFulfilled, (state) => {
state.batchIds = [];
});
builder.addMatcher(queueApi.endpoints.cancelByBatchIds.matchFulfilled, (state, action) => {
state.batchIds = state.batchIds.filter((id) => !action.meta.arg.originalArgs.batch_ids.includes(id));
});
},
});
@@ -713,8 +683,6 @@ export const {
stagingAreaInitialized,
setShouldAntialias,
canvasResized,
canvasBatchIdAdded,
canvasBatchIdsReset,
aspectRatioChanged,
scaledBoundingBoxDimensionsReset,
} = canvasSlice.actions;
@@ -736,3 +704,5 @@ export const canvasPersistConfig: PersistConfig<CanvasState> = {
migrate: migrateCanvasState,
persistDenylist: [],
};
export const matchAnyStagingAreaDismissed = isAnyOf(commitStagingAreaImage, discardStagedImages);

View File

@@ -142,7 +142,6 @@ export interface CanvasState {
stageDimensions: Dimensions;
stageScale: number;
generationMode?: GenerationMode;
batchIds: string[];
aspectRatio: AspectRatioState;
}

View File

@@ -0,0 +1,13 @@
# Progress
We have 3 different places to display progress images:
- TextToImage & ImageToImage
- Canvas
- Workflow
The progress slice tracks the latest denoising progress events, latest image output, and active batch ids for each of the workspaces.
Each of these have different requirements for displaying progress images, but much of the logic around tracking progress is the same, so it is consolidated here.
It also holds the latest progress event separately, which is used for the progress bar.

View File

@@ -1,28 +1,24 @@
import { Progress } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectSystemSlice } from 'features/system/store/systemSlice';
import { memo } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
const selectProgressValue = createSelector(
selectSystemSlice,
(system) => (system.denoiseProgress?.percentage ?? 0) * 100
);
const ProgressBar = () => {
const { t } = useTranslation();
const { data: queueStatus } = useGetQueueStatusQuery();
const isConnected = useAppSelector((s) => s.system.isConnected);
const hasSteps = useAppSelector((s) => Boolean(s.system.denoiseProgress));
const value = useAppSelector(selectProgressValue);
const hasSteps = useAppSelector((s) => Boolean(s.progress.latestDenoiseProgress));
const value = useAppSelector((s) => (s.progress.latestDenoiseProgress?.percentage ?? 0) * 100);
const isIndeterminate = useMemo(() => {
return isConnected && Boolean(queueStatus?.queue.in_progress) && !hasSteps;
}, [hasSteps, isConnected, queueStatus?.queue.in_progress]);
return (
<Progress
value={value}
aria-label={t('accessibility.invokeProgressBar')}
isIndeterminate={isConnected && Boolean(queueStatus?.queue.in_progress) && !hasSteps}
isIndeterminate={isIndeterminate}
h={2}
w="full"
colorScheme="invokeBlue"

View File

@@ -0,0 +1,281 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import {
addImageToStagingArea,
commitStagingAreaImage,
discardStagedImages,
resetCanvas,
setInitialCanvasImage,
} from 'features/canvas/store/canvasSlice';
import type { DenoiseProgress } from 'features/progress/store/types';
import { calculateStepPercentage } from 'features/system/util/calculateStepPercentage';
import { queueApi } from 'services/api/endpoints/queue';
import type { ImageDTO } from 'services/api/types';
import {
socketConnected,
socketDisconnected,
socketGeneratorProgress,
socketInvocationComplete,
socketQueueItemStatusChanged,
} from 'services/events/actions';
import type { GeneratorProgressEvent, InvocationCompleteEvent } from 'services/events/types';
export type ProgressTab = 'linear' | 'canvas' | 'workflow';
export type LatestImageData = InvocationCompleteEvent & {
image_name: string;
};
export type ProgressState = {
_version: 1;
/**
* Whether or not the system is currently generating.
*/
isProcessing: boolean;
/**
* The batches that are currently being processed in the canvas, if any.
*/
canvasBatchIds: string[];
/**
* The current denoise progress of the canvas, if any.
*/
canvasDenoiseProgress: DenoiseProgress | null;
/**
* The latest image data for the canvas, if any.
*/
canvasLatestImageData: LatestImageData | null;
/**
* The batches that are currently being processed in the linear tabs, if any.
*/
linearBatchIds: string[];
/**
* The current denoise progress of the linear tabs, if any.
*/
linearDenoiseProgress: DenoiseProgress | null;
/**
* The latest image data for the linear tabs, if any.
*/
linearLatestImageData: LatestImageData | null;
/**
* The batches that are currently being processed in the workflow, if any.
*/
workflowBatchIds: string[];
/**
* The current denoise progress of the workflow, if any.
*/
workflowDenoiseProgress: DenoiseProgress | null;
/**
* The latest image data for the workflow, if any.
*/
workflowLatestImageData: LatestImageData | null;
/**
* The latest denoise progress, regardless of tab.
*/
latestDenoiseProgress: DenoiseProgress | null;
};
export const initialProgressState: ProgressState = {
_version: 1,
isProcessing: false,
canvasBatchIds: [],
canvasDenoiseProgress: null,
canvasLatestImageData: null,
linearBatchIds: [],
linearDenoiseProgress: null,
linearLatestImageData: null,
workflowBatchIds: [],
workflowDenoiseProgress: null,
workflowLatestImageData: null,
latestDenoiseProgress: null,
};
export const progressPersistDenylist: (keyof ProgressState)[] = [
'canvasDenoiseProgress',
'canvasLatestImageData',
'linearDenoiseProgress',
'linearLatestImageData',
'workflowDenoiseProgress',
'workflowLatestImageData',
];
export const progressSlice = createSlice({
name: 'progress',
initialState: initialProgressState,
reducers: {
canvasBatchEnqueued: (state, action: PayloadAction<string>) => {
state.canvasBatchIds.push(action.payload);
},
linearBatchEnqueued: (state, action: PayloadAction<string>) => {
state.linearBatchIds.push(action.payload);
},
workflowBatchEnqueued: (state, action: PayloadAction<string>) => {
state.workflowBatchIds.push(action.payload);
},
imageInvocationComplete: (state, action: PayloadAction<{ data: InvocationCompleteEvent; imageDTO: ImageDTO }>) => {
const { data, imageDTO } = action.payload;
if (state.canvasBatchIds.includes(data.queue_batch_id)) {
state.canvasLatestImageData = { ...data, image_name: imageDTO.image_name };
}
if (state.linearBatchIds.includes(data.queue_batch_id)) {
state.linearLatestImageData = { ...data, image_name: imageDTO.image_name };
}
if (state.workflowBatchIds.includes(data.queue_batch_id)) {
state.workflowLatestImageData = { ...data, image_name: imageDTO.image_name };
}
},
latestLinearImageLoaded: (state) => {
state.linearDenoiseProgress = null;
},
},
extraReducers(builder) {
builder.addCase(socketConnected, (state) => {
state.canvasDenoiseProgress = null;
state.linearDenoiseProgress = null;
state.workflowDenoiseProgress = null;
});
builder.addCase(socketDisconnected, (state) => {
state.canvasDenoiseProgress = null;
state.linearDenoiseProgress = null;
state.workflowDenoiseProgress = null;
});
builder.addCase(socketInvocationComplete, (state) => {
state.latestDenoiseProgress = null;
});
builder.addCase(socketGeneratorProgress, (state, action) => {
const denoiseProgress = buildDenoiseProgress(action.payload.data);
state.latestDenoiseProgress = denoiseProgress;
if (state.linearBatchIds.includes(action.payload.data.queue_batch_id)) {
state.linearDenoiseProgress = denoiseProgress;
}
if (state.canvasBatchIds.includes(action.payload.data.queue_batch_id)) {
state.canvasDenoiseProgress = denoiseProgress;
}
if (state.workflowBatchIds.includes(action.payload.data.queue_batch_id)) {
state.workflowDenoiseProgress = denoiseProgress;
}
});
builder.addCase(socketQueueItemStatusChanged, (state, action) => {
// This logic only applies to the linear and workflow views. Canvas progress images are linked to the staging area
// and handled separately.
// When the queue is empty, clear progress and batch ids.
if (!action.payload.data.queue_status.in_progress && !action.payload.data.queue_status.pending) {
// state.linearBatchIds = [];
state.workflowBatchIds = [];
// state.canvasDenoiseProgress = null;
// state.linearDenoiseProgress = null;
state.workflowDenoiseProgress = null;
return;
}
// // If the current queue item / session has just finished *and* we are storing its progress, clear the progress.
// const { status, session_id } = action.payload.data.queue_item;
// if (['completed', 'canceled', 'failed'].includes(status)) {
// if (state.canvasDenoiseProgress?.graph_execution_state_id === session_id) {
// state.canvasDenoiseProgress = null;
// }
// if (state.linearDenoiseProgress?.graph_execution_state_id === session_id) {
// state.linearDenoiseProgress = null;
// }
// if (state.workflowDenoiseProgress?.graph_execution_state_id === session_id) {
// state.workflowDenoiseProgress = null;
// }
// }
});
builder.addCase(addImageToStagingArea, (state) => {
state.canvasDenoiseProgress = null;
});
// builder.addCase(imageSelected, (state, action) => {
// if (
// action.payload?.session_id &&
// state.linearDenoiseProgress?.graph_execution_state_id === action.payload.session_id
// ) {
// state.linearDenoiseProgress = null;
// }
// });
builder.addMatcher(
isAnyOf(commitStagingAreaImage, discardStagedImages, resetCanvas, setInitialCanvasImage),
(state) => {
// These actions all should result in the canvas progress being cleared.
state.canvasDenoiseProgress = null;
state.canvasBatchIds = [];
}
);
builder.addMatcher(queueApi.endpoints.clearQueue.matchFulfilled, (state) => {
// When the queue is cleared, all progress is cleared
state.canvasBatchIds = [];
state.linearBatchIds = [];
state.workflowBatchIds = [];
state.canvasDenoiseProgress = null;
state.linearDenoiseProgress = null;
state.workflowDenoiseProgress = null;
});
builder.addMatcher(queueApi.endpoints.cancelByBatchIds.matchFulfilled, (state, action) => {
// When a batch is canceled, remove it from the list of batch ids and clear its progress if it is stored.
const canceled_batch_ids = action.meta.arg.originalArgs.batch_ids;
state.canvasBatchIds = state.canvasBatchIds.filter((id) => !canceled_batch_ids.includes(id));
state.linearBatchIds = state.linearBatchIds.filter((id) => !canceled_batch_ids.includes(id));
state.workflowBatchIds = state.workflowBatchIds.filter((id) => !canceled_batch_ids.includes(id));
if (
state.canvasDenoiseProgress?.graph_execution_state_id &&
canceled_batch_ids.includes(state.canvasDenoiseProgress.graph_execution_state_id)
) {
state.canvasDenoiseProgress = null;
}
if (
state.linearDenoiseProgress?.graph_execution_state_id &&
canceled_batch_ids.includes(state.linearDenoiseProgress.graph_execution_state_id)
) {
state.linearDenoiseProgress = null;
}
if (
state.workflowDenoiseProgress?.graph_execution_state_id &&
canceled_batch_ids.includes(state.workflowDenoiseProgress.graph_execution_state_id)
) {
state.workflowDenoiseProgress = null;
}
});
},
});
export const {
canvasBatchEnqueued,
linearBatchEnqueued,
workflowBatchEnqueued,
imageInvocationComplete,
latestLinearImageLoaded,
} = progressSlice.actions;
export const selectProgressSlice = (state: RootState) => state.progress;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export const migrateProgressState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
const buildDenoiseProgress = (data: GeneratorProgressEvent): DenoiseProgress => ({
...data,
percentage: calculateStepPercentage(data.step, data.total_steps, data.order),
});
export const progressPersistConfig: PersistConfig<ProgressState> = {
name: progressSlice.name,
initialState: initialProgressState,
migrate: migrateProgressState,
persistDenylist: progressPersistDenylist,
};

View File

@@ -0,0 +1,7 @@
import type { GeneratorProgressEvent } from 'services/events/types';
export type SystemStatus = 'CONNECTED' | 'DISCONNECTED' | 'PROCESSING' | 'ERROR' | 'LOADING_MODEL';
export type DenoiseProgress = GeneratorProgressEvent & {
percentage: number;
};

View File

@@ -1,7 +1,7 @@
import { ButtonGroup, Flex, Spacer } from '@invoke-ai/ui-library';
import ProgressBar from 'features/progress/components/ProgressBar';
import ClearQueueIconButton from 'features/queue/components/ClearQueueIconButton';
import QueueFrontButton from 'features/queue/components/QueueFrontButton';
import ProgressBar from 'features/system/components/ProgressBar';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';

View File

@@ -13,7 +13,7 @@ export const Viewer = memo(() => {
return (
<Flex position="relative" flexDirection="column" height="full" width="full" gap={4}>
<ViewerToolbar />
<Flex height="full" width="full">
<Flex height="full" width="full" alignItems="center" justifyContent="center">
{viewerMode === 'image' && <ViewerImage />}
{viewerMode === 'info' && <ViewerInfo />}
{viewerMode === 'progress' && <ViewerProgress />}

View File

@@ -1,14 +1,19 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex , Image } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { Image } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { memo, useMemo } from 'react';
import { latestLinearImageLoaded } from 'features/progress/store/progressSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiHourglassBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
export const ViewerProgress = memo(() => {
const { t } = useTranslation();
const progress_image = useAppSelector((s) => s.system.denoiseProgress?.progress_image);
const dispatch = useAppDispatch();
const linearDenoiseProgress = useAppSelector((s) => s.progress.linearDenoiseProgress);
const linearLatestImageData = useAppSelector((s) => s.progress.linearLatestImageData);
const shouldAntialiasProgressImage = useAppSelector((s) => s.system.shouldAntialiasProgressImage);
const sx = useMemo<SystemStyleObject>(
@@ -18,16 +23,53 @@ export const ViewerProgress = memo(() => {
[shouldAntialiasProgressImage]
);
if (!progress_image) {
return <IAINoContentFallback icon={PiHourglassBold} label={t('viewer.noProgress')} />;
const shouldShowOutputImage = useMemo(() => {
if (
linearDenoiseProgress &&
linearLatestImageData &&
linearDenoiseProgress.graph_execution_state_id === linearLatestImageData.graph_execution_state_id
) {
return true;
}
if (!linearDenoiseProgress?.progress_image && linearLatestImageData) {
return true;
}
return false;
}, [linearDenoiseProgress, linearLatestImageData]);
const { data: imageDTO } = useGetImageDTOQuery(linearLatestImageData?.image_name ?? skipToken);
const onLoad = useCallback(() => {
dispatch(latestLinearImageLoaded());
}, [dispatch]);
if (shouldShowOutputImage && imageDTO) {
return (
<Image
src={imageDTO.image_url}
width={imageDTO.width}
height={imageDTO.height}
fallbackSrc={linearDenoiseProgress?.progress_image?.dataURL}
draggable={false}
data-testid="output-image"
objectFit="contain"
maxWidth="full"
maxHeight="full"
position="absolute"
borderRadius="base"
onLoad={onLoad}
/>
);
}
return (
<Flex w="full" h="full" alignItems="center" justifyContent="center">
if (linearDenoiseProgress?.progress_image) {
return (
<Image
src={progress_image.dataURL}
width={progress_image.width}
height={progress_image.height}
src={linearDenoiseProgress.progress_image.dataURL}
width={linearDenoiseProgress.progress_image.width}
height={linearDenoiseProgress.progress_image.height}
draggable={false}
data-testid="progress-image"
objectFit="contain"
@@ -37,8 +79,10 @@ export const ViewerProgress = memo(() => {
borderRadius="base"
sx={sx}
/>
</Flex>
);
);
}
return <IAINoContentFallback icon={PiHourglassBold} label={t('viewer.noProgress')} />;
});
ViewerProgress.displayName = 'ViewerProgress';

View File

@@ -9,13 +9,12 @@ import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
import SingleSelectionMenuItems from 'features/gallery/components/ImageContextMenu/SingleSelectionMenuItems';
import { sentImageToImg2Img } from 'features/gallery/store/actions';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';
import ParamUpscalePopover from 'features/parameters/components/Upscale/ParamUpscaleSettings';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import { selectProgressSlice } from 'features/progress/store/progressSlice';
import { useIsQueueMutationInProgress } from 'features/queue/hooks/useIsQueueMutationInProgress';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { selectSystemSlice } from 'features/system/store/systemSlice';
import { viewerModeChanged } from 'features/viewer/store/viewerSlice';
import { useGetAndLoadEmbeddedWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow';
import { memo, useCallback } from 'react';
@@ -37,11 +36,10 @@ import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
const selectShouldDisableToolbarButtons = createSelector(
selectSystemSlice,
selectGallerySlice,
selectProgressSlice,
selectLastSelectedImage,
(system, gallery, lastSelectedImage) => {
const hasProgressImage = Boolean(system.denoiseProgress?.progress_image);
(progress, lastSelectedImage) => {
const hasProgressImage = Boolean(progress.linearDenoiseProgress?.progress_image);
return hasProgressImage || !lastSelectedImage;
}
);

View File

@@ -1,6 +1,7 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { imageSelected, selectionChanged } from 'features/gallery/store/gallerySlice';
export type ViewerMode = 'image' | 'info' | 'progress';
@@ -25,6 +26,14 @@ export const viewerSlice = createSlice({
state.viewerMode = action.payload;
},
},
extraReducers: (builder) => {
builder.addMatcher(isAnyOf(imageSelected, selectionChanged), (state) => {
// When a gallery image is selected and we are in progress mode, switch to image mode
if (state.viewerMode === 'progress') {
// state.viewerMode = 'image';
}
});
},
});
export const { viewerModeChanged } = viewerSlice.actions;