mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): add progress slice and move progress to it
This slice tracks denoising progress for linear, canvas and workflow tabs. `ViewerProgress` now uses it for showing progress images.
This commit is contained in:
@@ -1,25 +1,25 @@
|
||||
import { isAnyOf } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { canvasBatchIdsReset, commitStagingAreaImage, discardStagedImages } from 'features/canvas/store/canvasSlice';
|
||||
import { matchAnyStagingAreaDismissed } from 'features/canvas/store/canvasSlice';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { t } from 'i18next';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
|
||||
import { startAppListening } from '..';
|
||||
|
||||
const matcher = isAnyOf(commitStagingAreaImage, discardStagedImages);
|
||||
|
||||
export const addCommitStagingAreaImageListener = () => {
|
||||
startAppListening({
|
||||
matcher,
|
||||
matcher: matchAnyStagingAreaDismissed,
|
||||
effect: async (_, { dispatch, getState }) => {
|
||||
const log = logger('canvas');
|
||||
const state = getState();
|
||||
const { batchIds } = state.canvas;
|
||||
const { canvasBatchIds } = state.progress;
|
||||
|
||||
try {
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.cancelByBatchIds.initiate({ batch_ids: batchIds }, { fixedCacheKey: 'cancelByBatchIds' })
|
||||
queueApi.endpoints.cancelByBatchIds.initiate(
|
||||
{ batch_ids: canvasBatchIds },
|
||||
{ fixedCacheKey: 'cancelByBatchIds' }
|
||||
)
|
||||
);
|
||||
const { canceled } = await req.unwrap();
|
||||
req.reset();
|
||||
@@ -32,7 +32,6 @@ export const addCommitStagingAreaImageListener = () => {
|
||||
})
|
||||
);
|
||||
}
|
||||
dispatch(canvasBatchIdsReset());
|
||||
} catch {
|
||||
log.error('Failed to cancel canvas batches');
|
||||
dispatch(
|
||||
|
||||
@@ -2,13 +2,14 @@ import { logger } from 'app/logging/logger';
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { canvasBatchIdAdded, stagingAreaInitialized } from 'features/canvas/store/canvasSlice';
|
||||
import { stagingAreaInitialized } from 'features/canvas/store/canvasSlice';
|
||||
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
|
||||
import { getCanvasData } from 'features/canvas/util/getCanvasData';
|
||||
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
|
||||
import { canvasGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { buildCanvasGraph } from 'features/nodes/util/graph/buildCanvasGraph';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { canvasBatchEnqueued } from 'features/progress/store/progressSlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
@@ -121,8 +122,6 @@ export const addEnqueueRequestedCanvasListener = () => {
|
||||
const enqueueResult = await req.unwrap();
|
||||
req.reset();
|
||||
|
||||
const batchId = enqueueResult.batch.batch_id as string; // we know the is a string, backend provides it
|
||||
|
||||
// Prep the canvas staging area if it is not yet initialized
|
||||
if (!state.canvas.layerState.stagingArea.boundingBox) {
|
||||
dispatch(
|
||||
@@ -135,8 +134,9 @@ export const addEnqueueRequestedCanvasListener = () => {
|
||||
);
|
||||
}
|
||||
|
||||
// Associate the session with the canvas session ID
|
||||
dispatch(canvasBatchIdAdded(batchId));
|
||||
if (enqueueResult.batch.batch_id) {
|
||||
dispatch(canvasBatchEnqueued(enqueueResult.batch.batch_id));
|
||||
}
|
||||
} catch {
|
||||
// no-op
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import { buildLinearImageToImageGraph } from 'features/nodes/util/graph/buildLin
|
||||
import { buildLinearSDXLImageToImageGraph } from 'features/nodes/util/graph/buildLinearSDXLImageToImageGraph';
|
||||
import { buildLinearSDXLTextToImageGraph } from 'features/nodes/util/graph/buildLinearSDXLTextToImageGraph';
|
||||
import { buildLinearTextToImageGraph } from 'features/nodes/util/graph/buildLinearTextToImageGraph';
|
||||
import { linearBatchEnqueued } from 'features/progress/store/progressSlice';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
|
||||
import { startAppListening } from '..';
|
||||
@@ -35,12 +36,20 @@ export const addEnqueueRequestedLinear = () => {
|
||||
|
||||
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
|
||||
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||
fixedCacheKey: 'enqueueBatch',
|
||||
})
|
||||
);
|
||||
req.reset();
|
||||
try {
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||
fixedCacheKey: 'enqueueBatch',
|
||||
})
|
||||
);
|
||||
const enqueueResult = await req.unwrap();
|
||||
req.reset();
|
||||
if (enqueueResult.batch.batch_id) {
|
||||
dispatch(linearBatchEnqueued(enqueueResult.batch.batch_id));
|
||||
}
|
||||
} catch {
|
||||
// no-op
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
|
||||
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
|
||||
import { workflowBatchEnqueued } from 'features/progress/store/progressSlice';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import type { BatchConfig } from 'services/api/types';
|
||||
|
||||
@@ -35,12 +36,20 @@ export const addEnqueueRequestedNodes = () => {
|
||||
prepend: action.payload.prepend,
|
||||
};
|
||||
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||
fixedCacheKey: 'enqueueBatch',
|
||||
})
|
||||
);
|
||||
req.reset();
|
||||
try {
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||
fixedCacheKey: 'enqueueBatch',
|
||||
})
|
||||
);
|
||||
const enqueueResult = await req.unwrap();
|
||||
req.reset();
|
||||
if (enqueueResult.batch.batch_id) {
|
||||
dispatch(workflowBatchEnqueued(enqueueResult.batch.batch_id));
|
||||
}
|
||||
} catch {
|
||||
// no-op
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
@@ -5,6 +5,7 @@ import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gal
|
||||
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||
import { isImageOutput } from 'features/nodes/types/common';
|
||||
import { LINEAR_UI_OUTPUT, nodeIDDenyList } from 'features/nodes/util/graph/constants';
|
||||
import { imageInvocationComplete } from 'features/progress/store/progressSlice';
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { imagesAdapter } from 'services/api/util';
|
||||
@@ -29,7 +30,7 @@ export const addInvocationCompleteEventListener = () => {
|
||||
// This complete event has an associated image output
|
||||
if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type) && !nodeIDDenyList.includes(source_node_id)) {
|
||||
const { image_name } = result.image;
|
||||
const { canvas, gallery } = getState();
|
||||
const { gallery, progress } = getState();
|
||||
|
||||
// This populates the `getImageDTO` cache
|
||||
const imageDTORequest = dispatch(
|
||||
@@ -41,8 +42,10 @@ export const addInvocationCompleteEventListener = () => {
|
||||
const imageDTO = await imageDTORequest.unwrap();
|
||||
imageDTORequest.unsubscribe();
|
||||
|
||||
dispatch(imageInvocationComplete({ data, imageDTO }));
|
||||
|
||||
// Add canvas images to the staging area
|
||||
if (canvas.batchIds.includes(queue_batch_id) && [LINEAR_UI_OUTPUT].includes(data.source_node_id)) {
|
||||
if (progress.canvasBatchIds.includes(queue_batch_id) && [LINEAR_UI_OUTPUT].includes(data.source_node_id)) {
|
||||
dispatch(addImageToStagingArea(imageDTO));
|
||||
}
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ import { nodesTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
||||
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
|
||||
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
|
||||
import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice';
|
||||
import { progressPersistConfig, progressSlice } from 'features/progress/store/progressSlice';
|
||||
import { queueSlice } from 'features/queue/store/queueSlice';
|
||||
import { sdxlPersistConfig, sdxlSlice } from 'features/sdxl/store/sdxlSlice';
|
||||
import { configSlice } from 'features/system/store/configSlice';
|
||||
@@ -63,6 +64,7 @@ const allReducers = {
|
||||
[workflowSlice.name]: workflowSlice.reducer,
|
||||
[hrfSlice.name]: hrfSlice.reducer,
|
||||
[viewerSlice.name]: viewerSlice.reducer,
|
||||
[progressSlice.name]: progressSlice.reducer,
|
||||
[api.reducerPath]: api.reducer,
|
||||
};
|
||||
|
||||
@@ -108,6 +110,7 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
|
||||
[modelManagerPersistConfig.name]: modelManagerPersistConfig,
|
||||
[hrfPersistConfig.name]: hrfPersistConfig,
|
||||
[viewerPersistConfig.name]: viewerPersistConfig,
|
||||
[progressPersistConfig.name]: progressPersistConfig,
|
||||
};
|
||||
|
||||
const unserialize: UnserializeFunction = (data, key) => {
|
||||
|
||||
@@ -1,17 +1,13 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectCanvasSlice } from 'features/canvas/store/canvasSlice';
|
||||
import { selectSystemSlice } from 'features/system/store/systemSlice';
|
||||
import { selectProgressSlice } from 'features/progress/store/progressSlice';
|
||||
import { memo, useEffect, useState } from 'react';
|
||||
import { Image as KonvaImage } from 'react-konva';
|
||||
|
||||
const progressImageSelector = createMemoizedSelector([selectSystemSlice, selectCanvasSlice], (system, canvas) => {
|
||||
const { denoiseProgress } = system;
|
||||
const { batchIds } = canvas;
|
||||
|
||||
const progressImageSelector = createMemoizedSelector([selectProgressSlice, selectCanvasSlice], (progress, canvas) => {
|
||||
return {
|
||||
progressImage:
|
||||
denoiseProgress && batchIds.includes(denoiseProgress.batch_id) ? denoiseProgress.progress_image : undefined,
|
||||
progressImage: progress.canvasDenoiseProgress?.progress_image,
|
||||
boundingBox: canvas.layerState.stagingArea.boundingBox,
|
||||
};
|
||||
});
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { selectProgressSlice } from 'features/progress/store/progressSlice';
|
||||
|
||||
import { selectCanvasSlice } from './canvasSlice';
|
||||
import { isCanvasBaseImage } from './canvasTypes';
|
||||
|
||||
export const isStagingSelector = createSelector(
|
||||
selectProgressSlice,
|
||||
selectCanvasSlice,
|
||||
(canvas) => canvas.batchIds.length > 0 || canvas.layerState.stagingArea.images.length > 0
|
||||
(progress, canvas) => progress.canvasBatchIds.length > 0 || canvas.layerState.stagingArea.images.length > 0
|
||||
);
|
||||
|
||||
export const initialCanvasImageSelector = createMemoizedSelector(selectCanvasSlice, (canvas) =>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMultiple';
|
||||
import calculateCoordinates from 'features/canvas/util/calculateCoordinates';
|
||||
@@ -15,9 +15,7 @@ import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/
|
||||
import type { IRect, Vector2d } from 'konva/lib/types';
|
||||
import { clamp, cloneDeep } from 'lodash-es';
|
||||
import type { RgbaColor } from 'react-colorful';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import { socketQueueItemStatusChanged } from 'services/events/actions';
|
||||
|
||||
import type {
|
||||
BoundingBoxScaleMethod,
|
||||
@@ -79,7 +77,6 @@ export const initialCanvasState: CanvasState = {
|
||||
stageCoordinates: { x: 0, y: 0 },
|
||||
stageDimensions: { width: 0, height: 0 },
|
||||
stageScale: 1,
|
||||
batchIds: [],
|
||||
aspectRatio: {
|
||||
id: '1:1',
|
||||
value: 1,
|
||||
@@ -180,7 +177,6 @@ export const canvasSlice = createSlice({
|
||||
],
|
||||
};
|
||||
state.futureLayerStates = [];
|
||||
state.batchIds = [];
|
||||
|
||||
const newScale = calculateScale(
|
||||
stageDimensions.width,
|
||||
@@ -237,12 +233,6 @@ export const canvasSlice = createSlice({
|
||||
setShouldShowBoundingBox: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldShowBoundingBox = action.payload;
|
||||
},
|
||||
canvasBatchIdAdded: (state, action: PayloadAction<string>) => {
|
||||
state.batchIds.push(action.payload);
|
||||
},
|
||||
canvasBatchIdsReset: (state) => {
|
||||
state.batchIds = [];
|
||||
},
|
||||
stagingAreaInitialized: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
@@ -293,7 +283,6 @@ export const canvasSlice = createSlice({
|
||||
state.futureLayerStates = [];
|
||||
state.shouldShowStagingOutline = true;
|
||||
state.shouldShowStagingImage = true;
|
||||
state.batchIds = [];
|
||||
},
|
||||
addFillRect: (state) => {
|
||||
const { boundingBoxCoordinates, boundingBoxDimensions, brushColor } = state;
|
||||
@@ -426,7 +415,6 @@ export const canvasSlice = createSlice({
|
||||
state.pastLayerStates.push(cloneDeep(state.layerState));
|
||||
state.layerState = cloneDeep(initialLayerState);
|
||||
state.futureLayerStates = [];
|
||||
state.batchIds = [];
|
||||
state.boundingBoxCoordinates = {
|
||||
...initialCanvasState.boundingBoxCoordinates,
|
||||
};
|
||||
@@ -536,7 +524,6 @@ export const canvasSlice = createSlice({
|
||||
state.futureLayerStates = [];
|
||||
state.shouldShowStagingOutline = true;
|
||||
state.shouldShowStagingImage = true;
|
||||
state.batchIds = [];
|
||||
},
|
||||
setBoundingBoxScaleMethod: {
|
||||
reducer: (state, action: PayloadActionWithOptimalDimension<BoundingBoxScaleMethod>) => {
|
||||
@@ -644,23 +631,6 @@ export const canvasSlice = createSlice({
|
||||
optimalDimension
|
||||
);
|
||||
});
|
||||
|
||||
builder.addCase(socketQueueItemStatusChanged, (state, action) => {
|
||||
const batch_status = action.payload.data.batch_status;
|
||||
if (!state.batchIds.includes(batch_status.batch_id)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (batch_status.in_progress === 0 && batch_status.pending === 0) {
|
||||
state.batchIds = state.batchIds.filter((id) => id !== batch_status.batch_id);
|
||||
}
|
||||
});
|
||||
builder.addMatcher(queueApi.endpoints.clearQueue.matchFulfilled, (state) => {
|
||||
state.batchIds = [];
|
||||
});
|
||||
builder.addMatcher(queueApi.endpoints.cancelByBatchIds.matchFulfilled, (state, action) => {
|
||||
state.batchIds = state.batchIds.filter((id) => !action.meta.arg.originalArgs.batch_ids.includes(id));
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
@@ -713,8 +683,6 @@ export const {
|
||||
stagingAreaInitialized,
|
||||
setShouldAntialias,
|
||||
canvasResized,
|
||||
canvasBatchIdAdded,
|
||||
canvasBatchIdsReset,
|
||||
aspectRatioChanged,
|
||||
scaledBoundingBoxDimensionsReset,
|
||||
} = canvasSlice.actions;
|
||||
@@ -736,3 +704,5 @@ export const canvasPersistConfig: PersistConfig<CanvasState> = {
|
||||
migrate: migrateCanvasState,
|
||||
persistDenylist: [],
|
||||
};
|
||||
|
||||
export const matchAnyStagingAreaDismissed = isAnyOf(commitStagingAreaImage, discardStagedImages);
|
||||
|
||||
@@ -142,7 +142,6 @@ export interface CanvasState {
|
||||
stageDimensions: Dimensions;
|
||||
stageScale: number;
|
||||
generationMode?: GenerationMode;
|
||||
batchIds: string[];
|
||||
aspectRatio: AspectRatioState;
|
||||
}
|
||||
|
||||
|
||||
13
invokeai/frontend/web/src/features/progress/README.md
Normal file
13
invokeai/frontend/web/src/features/progress/README.md
Normal file
@@ -0,0 +1,13 @@
|
||||
# Progress
|
||||
|
||||
We have 3 different places to display progress images:
|
||||
|
||||
- TextToImage & ImageToImage
|
||||
- Canvas
|
||||
- Workflow
|
||||
|
||||
The progress slice tracks the latest denoising progress events, latest image output, and active batch ids for each of the workspaces.
|
||||
|
||||
Each of these have different requirements for displaying progress images, but much of the logic around tracking progress is the same, so it is consolidated here.
|
||||
|
||||
It also holds the latest progress event separately, which is used for the progress bar.
|
||||
@@ -1,28 +1,24 @@
|
||||
import { Progress } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectSystemSlice } from 'features/system/store/systemSlice';
|
||||
import { memo } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
|
||||
|
||||
const selectProgressValue = createSelector(
|
||||
selectSystemSlice,
|
||||
(system) => (system.denoiseProgress?.percentage ?? 0) * 100
|
||||
);
|
||||
|
||||
const ProgressBar = () => {
|
||||
const { t } = useTranslation();
|
||||
const { data: queueStatus } = useGetQueueStatusQuery();
|
||||
const isConnected = useAppSelector((s) => s.system.isConnected);
|
||||
const hasSteps = useAppSelector((s) => Boolean(s.system.denoiseProgress));
|
||||
const value = useAppSelector(selectProgressValue);
|
||||
const hasSteps = useAppSelector((s) => Boolean(s.progress.latestDenoiseProgress));
|
||||
const value = useAppSelector((s) => (s.progress.latestDenoiseProgress?.percentage ?? 0) * 100);
|
||||
const isIndeterminate = useMemo(() => {
|
||||
return isConnected && Boolean(queueStatus?.queue.in_progress) && !hasSteps;
|
||||
}, [hasSteps, isConnected, queueStatus?.queue.in_progress]);
|
||||
|
||||
return (
|
||||
<Progress
|
||||
value={value}
|
||||
aria-label={t('accessibility.invokeProgressBar')}
|
||||
isIndeterminate={isConnected && Boolean(queueStatus?.queue.in_progress) && !hasSteps}
|
||||
isIndeterminate={isIndeterminate}
|
||||
h={2}
|
||||
w="full"
|
||||
colorScheme="invokeBlue"
|
||||
@@ -0,0 +1,281 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import {
|
||||
addImageToStagingArea,
|
||||
commitStagingAreaImage,
|
||||
discardStagedImages,
|
||||
resetCanvas,
|
||||
setInitialCanvasImage,
|
||||
} from 'features/canvas/store/canvasSlice';
|
||||
import type { DenoiseProgress } from 'features/progress/store/types';
|
||||
import { calculateStepPercentage } from 'features/system/util/calculateStepPercentage';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import {
|
||||
socketConnected,
|
||||
socketDisconnected,
|
||||
socketGeneratorProgress,
|
||||
socketInvocationComplete,
|
||||
socketQueueItemStatusChanged,
|
||||
} from 'services/events/actions';
|
||||
import type { GeneratorProgressEvent, InvocationCompleteEvent } from 'services/events/types';
|
||||
|
||||
export type ProgressTab = 'linear' | 'canvas' | 'workflow';
|
||||
|
||||
export type LatestImageData = InvocationCompleteEvent & {
|
||||
image_name: string;
|
||||
};
|
||||
|
||||
export type ProgressState = {
|
||||
_version: 1;
|
||||
/**
|
||||
* Whether or not the system is currently generating.
|
||||
*/
|
||||
isProcessing: boolean;
|
||||
/**
|
||||
* The batches that are currently being processed in the canvas, if any.
|
||||
*/
|
||||
canvasBatchIds: string[];
|
||||
/**
|
||||
* The current denoise progress of the canvas, if any.
|
||||
*/
|
||||
canvasDenoiseProgress: DenoiseProgress | null;
|
||||
/**
|
||||
* The latest image data for the canvas, if any.
|
||||
*/
|
||||
canvasLatestImageData: LatestImageData | null;
|
||||
/**
|
||||
* The batches that are currently being processed in the linear tabs, if any.
|
||||
*/
|
||||
linearBatchIds: string[];
|
||||
/**
|
||||
* The current denoise progress of the linear tabs, if any.
|
||||
*/
|
||||
linearDenoiseProgress: DenoiseProgress | null;
|
||||
/**
|
||||
* The latest image data for the linear tabs, if any.
|
||||
*/
|
||||
linearLatestImageData: LatestImageData | null;
|
||||
/**
|
||||
* The batches that are currently being processed in the workflow, if any.
|
||||
*/
|
||||
workflowBatchIds: string[];
|
||||
/**
|
||||
* The current denoise progress of the workflow, if any.
|
||||
*/
|
||||
workflowDenoiseProgress: DenoiseProgress | null;
|
||||
/**
|
||||
* The latest image data for the workflow, if any.
|
||||
*/
|
||||
workflowLatestImageData: LatestImageData | null;
|
||||
/**
|
||||
* The latest denoise progress, regardless of tab.
|
||||
*/
|
||||
latestDenoiseProgress: DenoiseProgress | null;
|
||||
};
|
||||
|
||||
export const initialProgressState: ProgressState = {
|
||||
_version: 1,
|
||||
isProcessing: false,
|
||||
canvasBatchIds: [],
|
||||
canvasDenoiseProgress: null,
|
||||
canvasLatestImageData: null,
|
||||
linearBatchIds: [],
|
||||
linearDenoiseProgress: null,
|
||||
linearLatestImageData: null,
|
||||
workflowBatchIds: [],
|
||||
workflowDenoiseProgress: null,
|
||||
workflowLatestImageData: null,
|
||||
latestDenoiseProgress: null,
|
||||
};
|
||||
|
||||
export const progressPersistDenylist: (keyof ProgressState)[] = [
|
||||
'canvasDenoiseProgress',
|
||||
'canvasLatestImageData',
|
||||
'linearDenoiseProgress',
|
||||
'linearLatestImageData',
|
||||
'workflowDenoiseProgress',
|
||||
'workflowLatestImageData',
|
||||
];
|
||||
|
||||
export const progressSlice = createSlice({
|
||||
name: 'progress',
|
||||
initialState: initialProgressState,
|
||||
reducers: {
|
||||
canvasBatchEnqueued: (state, action: PayloadAction<string>) => {
|
||||
state.canvasBatchIds.push(action.payload);
|
||||
},
|
||||
linearBatchEnqueued: (state, action: PayloadAction<string>) => {
|
||||
state.linearBatchIds.push(action.payload);
|
||||
},
|
||||
workflowBatchEnqueued: (state, action: PayloadAction<string>) => {
|
||||
state.workflowBatchIds.push(action.payload);
|
||||
},
|
||||
imageInvocationComplete: (state, action: PayloadAction<{ data: InvocationCompleteEvent; imageDTO: ImageDTO }>) => {
|
||||
const { data, imageDTO } = action.payload;
|
||||
if (state.canvasBatchIds.includes(data.queue_batch_id)) {
|
||||
state.canvasLatestImageData = { ...data, image_name: imageDTO.image_name };
|
||||
}
|
||||
if (state.linearBatchIds.includes(data.queue_batch_id)) {
|
||||
state.linearLatestImageData = { ...data, image_name: imageDTO.image_name };
|
||||
}
|
||||
if (state.workflowBatchIds.includes(data.queue_batch_id)) {
|
||||
state.workflowLatestImageData = { ...data, image_name: imageDTO.image_name };
|
||||
}
|
||||
},
|
||||
latestLinearImageLoaded: (state) => {
|
||||
state.linearDenoiseProgress = null;
|
||||
},
|
||||
},
|
||||
extraReducers(builder) {
|
||||
builder.addCase(socketConnected, (state) => {
|
||||
state.canvasDenoiseProgress = null;
|
||||
state.linearDenoiseProgress = null;
|
||||
state.workflowDenoiseProgress = null;
|
||||
});
|
||||
|
||||
builder.addCase(socketDisconnected, (state) => {
|
||||
state.canvasDenoiseProgress = null;
|
||||
state.linearDenoiseProgress = null;
|
||||
state.workflowDenoiseProgress = null;
|
||||
});
|
||||
|
||||
builder.addCase(socketInvocationComplete, (state) => {
|
||||
state.latestDenoiseProgress = null;
|
||||
});
|
||||
|
||||
builder.addCase(socketGeneratorProgress, (state, action) => {
|
||||
const denoiseProgress = buildDenoiseProgress(action.payload.data);
|
||||
state.latestDenoiseProgress = denoiseProgress;
|
||||
if (state.linearBatchIds.includes(action.payload.data.queue_batch_id)) {
|
||||
state.linearDenoiseProgress = denoiseProgress;
|
||||
}
|
||||
if (state.canvasBatchIds.includes(action.payload.data.queue_batch_id)) {
|
||||
state.canvasDenoiseProgress = denoiseProgress;
|
||||
}
|
||||
if (state.workflowBatchIds.includes(action.payload.data.queue_batch_id)) {
|
||||
state.workflowDenoiseProgress = denoiseProgress;
|
||||
}
|
||||
});
|
||||
|
||||
builder.addCase(socketQueueItemStatusChanged, (state, action) => {
|
||||
// This logic only applies to the linear and workflow views. Canvas progress images are linked to the staging area
|
||||
// and handled separately.
|
||||
|
||||
// When the queue is empty, clear progress and batch ids.
|
||||
if (!action.payload.data.queue_status.in_progress && !action.payload.data.queue_status.pending) {
|
||||
// state.linearBatchIds = [];
|
||||
state.workflowBatchIds = [];
|
||||
// state.canvasDenoiseProgress = null;
|
||||
// state.linearDenoiseProgress = null;
|
||||
state.workflowDenoiseProgress = null;
|
||||
return;
|
||||
}
|
||||
|
||||
// // If the current queue item / session has just finished *and* we are storing its progress, clear the progress.
|
||||
// const { status, session_id } = action.payload.data.queue_item;
|
||||
// if (['completed', 'canceled', 'failed'].includes(status)) {
|
||||
// if (state.canvasDenoiseProgress?.graph_execution_state_id === session_id) {
|
||||
// state.canvasDenoiseProgress = null;
|
||||
// }
|
||||
// if (state.linearDenoiseProgress?.graph_execution_state_id === session_id) {
|
||||
// state.linearDenoiseProgress = null;
|
||||
// }
|
||||
// if (state.workflowDenoiseProgress?.graph_execution_state_id === session_id) {
|
||||
// state.workflowDenoiseProgress = null;
|
||||
// }
|
||||
// }
|
||||
});
|
||||
|
||||
builder.addCase(addImageToStagingArea, (state) => {
|
||||
state.canvasDenoiseProgress = null;
|
||||
});
|
||||
|
||||
// builder.addCase(imageSelected, (state, action) => {
|
||||
// if (
|
||||
// action.payload?.session_id &&
|
||||
// state.linearDenoiseProgress?.graph_execution_state_id === action.payload.session_id
|
||||
// ) {
|
||||
// state.linearDenoiseProgress = null;
|
||||
// }
|
||||
// });
|
||||
|
||||
builder.addMatcher(
|
||||
isAnyOf(commitStagingAreaImage, discardStagedImages, resetCanvas, setInitialCanvasImage),
|
||||
(state) => {
|
||||
// These actions all should result in the canvas progress being cleared.
|
||||
state.canvasDenoiseProgress = null;
|
||||
state.canvasBatchIds = [];
|
||||
}
|
||||
);
|
||||
|
||||
builder.addMatcher(queueApi.endpoints.clearQueue.matchFulfilled, (state) => {
|
||||
// When the queue is cleared, all progress is cleared
|
||||
state.canvasBatchIds = [];
|
||||
state.linearBatchIds = [];
|
||||
state.workflowBatchIds = [];
|
||||
state.canvasDenoiseProgress = null;
|
||||
state.linearDenoiseProgress = null;
|
||||
state.workflowDenoiseProgress = null;
|
||||
});
|
||||
|
||||
builder.addMatcher(queueApi.endpoints.cancelByBatchIds.matchFulfilled, (state, action) => {
|
||||
// When a batch is canceled, remove it from the list of batch ids and clear its progress if it is stored.
|
||||
|
||||
const canceled_batch_ids = action.meta.arg.originalArgs.batch_ids;
|
||||
state.canvasBatchIds = state.canvasBatchIds.filter((id) => !canceled_batch_ids.includes(id));
|
||||
state.linearBatchIds = state.linearBatchIds.filter((id) => !canceled_batch_ids.includes(id));
|
||||
state.workflowBatchIds = state.workflowBatchIds.filter((id) => !canceled_batch_ids.includes(id));
|
||||
|
||||
if (
|
||||
state.canvasDenoiseProgress?.graph_execution_state_id &&
|
||||
canceled_batch_ids.includes(state.canvasDenoiseProgress.graph_execution_state_id)
|
||||
) {
|
||||
state.canvasDenoiseProgress = null;
|
||||
}
|
||||
if (
|
||||
state.linearDenoiseProgress?.graph_execution_state_id &&
|
||||
canceled_batch_ids.includes(state.linearDenoiseProgress.graph_execution_state_id)
|
||||
) {
|
||||
state.linearDenoiseProgress = null;
|
||||
}
|
||||
if (
|
||||
state.workflowDenoiseProgress?.graph_execution_state_id &&
|
||||
canceled_batch_ids.includes(state.workflowDenoiseProgress.graph_execution_state_id)
|
||||
) {
|
||||
state.workflowDenoiseProgress = null;
|
||||
}
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
canvasBatchEnqueued,
|
||||
linearBatchEnqueued,
|
||||
workflowBatchEnqueued,
|
||||
imageInvocationComplete,
|
||||
latestLinearImageLoaded,
|
||||
} = progressSlice.actions;
|
||||
|
||||
export const selectProgressSlice = (state: RootState) => state.progress;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
export const migrateProgressState = (state: any): any => {
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
return state;
|
||||
};
|
||||
|
||||
const buildDenoiseProgress = (data: GeneratorProgressEvent): DenoiseProgress => ({
|
||||
...data,
|
||||
percentage: calculateStepPercentage(data.step, data.total_steps, data.order),
|
||||
});
|
||||
|
||||
export const progressPersistConfig: PersistConfig<ProgressState> = {
|
||||
name: progressSlice.name,
|
||||
initialState: initialProgressState,
|
||||
migrate: migrateProgressState,
|
||||
persistDenylist: progressPersistDenylist,
|
||||
};
|
||||
@@ -0,0 +1,7 @@
|
||||
import type { GeneratorProgressEvent } from 'services/events/types';
|
||||
|
||||
export type SystemStatus = 'CONNECTED' | 'DISCONNECTED' | 'PROCESSING' | 'ERROR' | 'LOADING_MODEL';
|
||||
|
||||
export type DenoiseProgress = GeneratorProgressEvent & {
|
||||
percentage: number;
|
||||
};
|
||||
@@ -1,7 +1,7 @@
|
||||
import { ButtonGroup, Flex, Spacer } from '@invoke-ai/ui-library';
|
||||
import ProgressBar from 'features/progress/components/ProgressBar';
|
||||
import ClearQueueIconButton from 'features/queue/components/ClearQueueIconButton';
|
||||
import QueueFrontButton from 'features/queue/components/QueueFrontButton';
|
||||
import ProgressBar from 'features/system/components/ProgressBar';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { memo } from 'react';
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ export const Viewer = memo(() => {
|
||||
return (
|
||||
<Flex position="relative" flexDirection="column" height="full" width="full" gap={4}>
|
||||
<ViewerToolbar />
|
||||
<Flex height="full" width="full">
|
||||
<Flex height="full" width="full" alignItems="center" justifyContent="center">
|
||||
{viewerMode === 'image' && <ViewerImage />}
|
||||
{viewerMode === 'info' && <ViewerInfo />}
|
||||
{viewerMode === 'progress' && <ViewerProgress />}
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Flex , Image } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { Image } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { latestLinearImageLoaded } from 'features/progress/store/progressSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiHourglassBold } from 'react-icons/pi';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
|
||||
export const ViewerProgress = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const progress_image = useAppSelector((s) => s.system.denoiseProgress?.progress_image);
|
||||
const dispatch = useAppDispatch();
|
||||
const linearDenoiseProgress = useAppSelector((s) => s.progress.linearDenoiseProgress);
|
||||
const linearLatestImageData = useAppSelector((s) => s.progress.linearLatestImageData);
|
||||
const shouldAntialiasProgressImage = useAppSelector((s) => s.system.shouldAntialiasProgressImage);
|
||||
|
||||
const sx = useMemo<SystemStyleObject>(
|
||||
@@ -18,16 +23,53 @@ export const ViewerProgress = memo(() => {
|
||||
[shouldAntialiasProgressImage]
|
||||
);
|
||||
|
||||
if (!progress_image) {
|
||||
return <IAINoContentFallback icon={PiHourglassBold} label={t('viewer.noProgress')} />;
|
||||
const shouldShowOutputImage = useMemo(() => {
|
||||
if (
|
||||
linearDenoiseProgress &&
|
||||
linearLatestImageData &&
|
||||
linearDenoiseProgress.graph_execution_state_id === linearLatestImageData.graph_execution_state_id
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!linearDenoiseProgress?.progress_image && linearLatestImageData) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}, [linearDenoiseProgress, linearLatestImageData]);
|
||||
|
||||
const { data: imageDTO } = useGetImageDTOQuery(linearLatestImageData?.image_name ?? skipToken);
|
||||
|
||||
const onLoad = useCallback(() => {
|
||||
dispatch(latestLinearImageLoaded());
|
||||
}, [dispatch]);
|
||||
|
||||
if (shouldShowOutputImage && imageDTO) {
|
||||
return (
|
||||
<Image
|
||||
src={imageDTO.image_url}
|
||||
width={imageDTO.width}
|
||||
height={imageDTO.height}
|
||||
fallbackSrc={linearDenoiseProgress?.progress_image?.dataURL}
|
||||
draggable={false}
|
||||
data-testid="output-image"
|
||||
objectFit="contain"
|
||||
maxWidth="full"
|
||||
maxHeight="full"
|
||||
position="absolute"
|
||||
borderRadius="base"
|
||||
onLoad={onLoad}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex w="full" h="full" alignItems="center" justifyContent="center">
|
||||
if (linearDenoiseProgress?.progress_image) {
|
||||
return (
|
||||
<Image
|
||||
src={progress_image.dataURL}
|
||||
width={progress_image.width}
|
||||
height={progress_image.height}
|
||||
src={linearDenoiseProgress.progress_image.dataURL}
|
||||
width={linearDenoiseProgress.progress_image.width}
|
||||
height={linearDenoiseProgress.progress_image.height}
|
||||
draggable={false}
|
||||
data-testid="progress-image"
|
||||
objectFit="contain"
|
||||
@@ -37,8 +79,10 @@ export const ViewerProgress = memo(() => {
|
||||
borderRadius="base"
|
||||
sx={sx}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
);
|
||||
}
|
||||
|
||||
return <IAINoContentFallback icon={PiHourglassBold} label={t('viewer.noProgress')} />;
|
||||
});
|
||||
|
||||
ViewerProgress.displayName = 'ViewerProgress';
|
||||
|
||||
@@ -9,13 +9,12 @@ import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
||||
import SingleSelectionMenuItems from 'features/gallery/components/ImageContextMenu/SingleSelectionMenuItems';
|
||||
import { sentImageToImg2Img } from 'features/gallery/store/actions';
|
||||
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
|
||||
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';
|
||||
import ParamUpscalePopover from 'features/parameters/components/Upscale/ParamUpscaleSettings';
|
||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||
import { selectProgressSlice } from 'features/progress/store/progressSlice';
|
||||
import { useIsQueueMutationInProgress } from 'features/queue/hooks/useIsQueueMutationInProgress';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { selectSystemSlice } from 'features/system/store/systemSlice';
|
||||
import { viewerModeChanged } from 'features/viewer/store/viewerSlice';
|
||||
import { useGetAndLoadEmbeddedWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow';
|
||||
import { memo, useCallback } from 'react';
|
||||
@@ -37,11 +36,10 @@ import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
|
||||
|
||||
const selectShouldDisableToolbarButtons = createSelector(
|
||||
selectSystemSlice,
|
||||
selectGallerySlice,
|
||||
selectProgressSlice,
|
||||
selectLastSelectedImage,
|
||||
(system, gallery, lastSelectedImage) => {
|
||||
const hasProgressImage = Boolean(system.denoiseProgress?.progress_image);
|
||||
(progress, lastSelectedImage) => {
|
||||
const hasProgressImage = Boolean(progress.linearDenoiseProgress?.progress_image);
|
||||
return hasProgressImage || !lastSelectedImage;
|
||||
}
|
||||
);
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { imageSelected, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
|
||||
export type ViewerMode = 'image' | 'info' | 'progress';
|
||||
|
||||
@@ -25,6 +26,14 @@ export const viewerSlice = createSlice({
|
||||
state.viewerMode = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder.addMatcher(isAnyOf(imageSelected, selectionChanged), (state) => {
|
||||
// When a gallery image is selected and we are in progress mode, switch to image mode
|
||||
if (state.viewerMode === 'progress') {
|
||||
// state.viewerMode = 'image';
|
||||
}
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const { viewerModeChanged } = viewerSlice.actions;
|
||||
|
||||
Reference in New Issue
Block a user