mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-13 06:14:58 -05:00
wip progress events
This commit is contained in:
@@ -3,6 +3,7 @@ import { autoBatchEnhancer, combineReducers, configureStore } from '@reduxjs/too
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { idbKeyValDriver } from 'app/store/enhancers/reduxRemember/driver';
|
||||
import { errorHandler } from 'app/store/enhancers/reduxRemember/errors';
|
||||
import { getDebugLoggerMiddleware } from 'app/store/middleware/debugLoggerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice';
|
||||
import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
@@ -175,6 +176,7 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
|
||||
.concat(api.middleware)
|
||||
.concat(dynamicMiddlewares)
|
||||
.concat(authToastMiddleware)
|
||||
.concat(getDebugLoggerMiddleware())
|
||||
.prepend(listenerMiddleware.middleware),
|
||||
enhancers: (getDefaultEnhancers) => {
|
||||
const _enhancers = getDefaultEnhancers().concat(autoBatchEnhancer());
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import {
|
||||
Box,
|
||||
Button,
|
||||
ButtonGroup,
|
||||
ContextMenu,
|
||||
@@ -40,12 +41,10 @@ import {
|
||||
selectCanvasSessionType,
|
||||
selectSelectedImage,
|
||||
selectStagedImageIndex,
|
||||
selectStagedImages,
|
||||
stagingAreaImageSelected,
|
||||
stagingAreaNextStagedImageSelected,
|
||||
stagingAreaPrevStagedImageSelected,
|
||||
} from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import type { EphemeralProgressImage } from 'features/controlLayers/store/types';
|
||||
import { newCanvasFromImageDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
import { DndImage } from 'features/dnd/DndImage';
|
||||
@@ -55,7 +54,8 @@ import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { Trans, useTranslation } from 'react-i18next';
|
||||
import { PiDotsThreeOutlineVerticalFill, PiUploadBold } from 'react-icons/pi';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import { $lastCanvasProgressImage } from 'services/events/stores';
|
||||
import type { ProgressAndResult } from 'services/events/stores';
|
||||
import { $progressImages, useMapSelector } from 'services/events/stores';
|
||||
import type { Equals, Param0 } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
@@ -288,6 +288,7 @@ const SimpleActiveSession = memo(() => {
|
||||
|
||||
const startOver = useCallback(() => {
|
||||
dispatch(canvasSessionStarted({ sessionType: null }));
|
||||
$progressImages.set({});
|
||||
}, [dispatch]);
|
||||
|
||||
const goAdvanced = useCallback(() => {
|
||||
@@ -325,15 +326,10 @@ const SimpleActiveSession = memo(() => {
|
||||
SimpleActiveSession.displayName = 'SimpleActiveSession';
|
||||
|
||||
const SelectedImageOrProgressImage = memo(() => {
|
||||
const progressImage = useStore($lastCanvasProgressImage);
|
||||
const selectedImage = useAppSelector(selectSelectedImage);
|
||||
|
||||
if (progressImage) {
|
||||
return <ProgressImage progressImage={progressImage} />;
|
||||
}
|
||||
|
||||
if (selectedImage) {
|
||||
return <SelectedImage imageDTO={selectedImage.imageDTO} />;
|
||||
return <FullSizeImage sessionId={selectedImage.sessionId} />;
|
||||
}
|
||||
|
||||
return (
|
||||
@@ -397,36 +393,107 @@ const SelectedImage = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
|
||||
});
|
||||
SelectedImage.displayName = 'SelectedImage';
|
||||
|
||||
const ProgressImage = memo(({ progressImage }: { progressImage: EphemeralProgressImage }) => {
|
||||
const FullSizeImage = memo(({ sessionId }: { sessionId: string }) => {
|
||||
const _progressImage = useMapSelector(sessionId, $progressImages);
|
||||
|
||||
if (!_progressImage) {
|
||||
return (
|
||||
<Flex alignItems="center" justifyContent="center" minH={0} minW={0} h="full">
|
||||
<Text>Pending</Text>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
if (_progressImage.resultImage) {
|
||||
return <SelectedImage imageDTO={_progressImage.resultImage} />;
|
||||
}
|
||||
|
||||
if (_progressImage.progressImage) {
|
||||
return (
|
||||
<Flex alignItems="center" justifyContent="center" minH={0} minW={0} h="full">
|
||||
<Image
|
||||
objectFit="contain"
|
||||
maxH="full"
|
||||
maxW="full"
|
||||
src={_progressImage.progressImage.dataURL}
|
||||
width={_progressImage.progressImage.width}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex alignItems="center" justifyContent="center" minH={0} minW={0} h="full">
|
||||
<Image
|
||||
objectFit="contain"
|
||||
maxH="full"
|
||||
maxW="full"
|
||||
src={progressImage.image.dataURL}
|
||||
width={progressImage.image.width}
|
||||
/>
|
||||
<Text>No progress yet</Text>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
ProgressImage.displayName = 'ProgressImage';
|
||||
FullSizeImage.displayName = 'FullSizeImage';
|
||||
|
||||
const SessionImages = memo(() => {
|
||||
const stagedImages = useAppSelector(selectStagedImages);
|
||||
const progressImages = useStore($progressImages);
|
||||
return (
|
||||
<Flex position="relative" gap={2} h={108} maxW="full" overflow="scroll">
|
||||
<Spacer />
|
||||
{stagedImages.map(({ imageDTO }, index) => (
|
||||
<SessionImage key={imageDTO.image_name} index={index} imageDTO={imageDTO} />
|
||||
))}
|
||||
{Object.values(progressImages).map((data, index) => {
|
||||
if (data.type === 'staged') {
|
||||
return <SessionImage key={data.sessionId} index={index} data={data} />;
|
||||
} else {
|
||||
return <ProgressImagePreview key={data.sessionId} index={index} data={data} />;
|
||||
}
|
||||
})}
|
||||
<Spacer />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
SessionImages.displayName = 'SessionImages';
|
||||
|
||||
const getStagingImageId = (imageDTO: ImageDTO) => `staging-image-${imageDTO.image_name}`;
|
||||
const ProgressImagePreview = ({ index, data }: { index: number; data: ProgressAndResult }) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const selectedImageIndex = useAppSelector(selectStagedImageIndex);
|
||||
const onClick = useCallback(() => {
|
||||
dispatch(stagingAreaImageSelected({ index }));
|
||||
}, [dispatch, index]);
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedImageIndex === index) {
|
||||
// this doesn't work when the DndImage is in a popover... why
|
||||
document.getElementById(getStagingImageId(data.sessionId))?.scrollIntoView();
|
||||
}
|
||||
}, [data.sessionId, index, selectedImageIndex]);
|
||||
|
||||
if (data.resultImage) {
|
||||
return (
|
||||
<Image
|
||||
id={getStagingImageId(data.sessionId)}
|
||||
objectFit="contain"
|
||||
maxH="full"
|
||||
maxW="full"
|
||||
src={data.resultImage.thumbnail_url}
|
||||
width={data.resultImage.width}
|
||||
onClick={onClick}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (data.progressImage) {
|
||||
return (
|
||||
<Image
|
||||
id={getStagingImageId(data.sessionId)}
|
||||
objectFit="contain"
|
||||
maxH="full"
|
||||
maxW="full"
|
||||
src={data.progressImage.dataURL}
|
||||
width={data.progressImage.width}
|
||||
onClick={onClick}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return <Box id={getStagingImageId(data.sessionId)} bg="blue" h="full" w={108} borderWidth={1} onClick={onClick} />;
|
||||
};
|
||||
|
||||
const getStagingImageId = (session_id: string) => `staging-image-${session_id}`;
|
||||
|
||||
const sx = {
|
||||
objectFit: 'contain',
|
||||
@@ -442,7 +509,7 @@ const sx = {
|
||||
opacity: 0.5,
|
||||
},
|
||||
} satisfies SystemStyleObject;
|
||||
const SessionImage = memo(({ index, imageDTO }: { index: number; imageDTO: ImageDTO }) => {
|
||||
const SessionImage = memo(({ index, data }: { index: number; data: ProgressAndResult }) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const selectedImageIndex = useAppSelector(selectStagedImageIndex);
|
||||
const onClick = useCallback(() => {
|
||||
@@ -451,17 +518,19 @@ const SessionImage = memo(({ index, imageDTO }: { index: number; imageDTO: Image
|
||||
useEffect(() => {
|
||||
if (selectedImageIndex === index) {
|
||||
// this doesn't work when the DndImage is in a popover... why
|
||||
document.getElementById(getStagingImageId(imageDTO))?.scrollIntoView();
|
||||
document.getElementById(getStagingImageId(data.sessionId))?.scrollIntoView();
|
||||
}
|
||||
}, [imageDTO, index, selectedImageIndex]);
|
||||
}, [data.sessionId, index, selectedImageIndex]);
|
||||
return (
|
||||
<DndImage
|
||||
id={getStagingImageId(imageDTO)}
|
||||
imageDTO={imageDTO}
|
||||
id={getStagingImageId(data.sessionId)}
|
||||
imageDTO={data.imageDTO}
|
||||
asThumbnail
|
||||
onClick={onClick}
|
||||
data-is-selected={selectedImageIndex === index}
|
||||
w={data.imageDTO.width}
|
||||
sx={sx}
|
||||
borderWidth={1}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -2,12 +2,12 @@ import { createSelector, createSlice, type PayloadAction } from '@reduxjs/toolki
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { canvasReset } from 'features/controlLayers/store/actions';
|
||||
import type { StagingAreaImage } from 'features/controlLayers/store/types';
|
||||
import type { StagingAreaImage, StagingAreaProgressImage } from 'features/controlLayers/store/types';
|
||||
import { selectCanvasQueueCounts } from 'services/api/endpoints/queue';
|
||||
|
||||
type CanvasStagingAreaState = {
|
||||
sessionType: 'simple' | 'advanced' | null;
|
||||
images: StagingAreaImage[];
|
||||
images: (StagingAreaImage | StagingAreaProgressImage)[];
|
||||
selectedImageIndex: number;
|
||||
};
|
||||
|
||||
@@ -25,8 +25,28 @@ export const canvasSessionSlice = createSlice({
|
||||
reducers: {
|
||||
stagingAreaImageStaged: (state, action: PayloadAction<{ stagingAreaImage: StagingAreaImage }>) => {
|
||||
const { stagingAreaImage } = action.payload;
|
||||
state.images.push(stagingAreaImage);
|
||||
state.selectedImageIndex = state.images.length - 1;
|
||||
let didReplace = false;
|
||||
const newImages = [];
|
||||
for (const i of state.images) {
|
||||
if (i.sessionId === stagingAreaImage.sessionId) {
|
||||
newImages.push(stagingAreaImage);
|
||||
didReplace = true;
|
||||
} else {
|
||||
newImages.push(i);
|
||||
}
|
||||
}
|
||||
if (!didReplace) {
|
||||
newImages.push(stagingAreaImage);
|
||||
}
|
||||
state.images = newImages;
|
||||
},
|
||||
stagingAreaGenerationStarted: (state, action: PayloadAction<{ sessionId: string }>) => {
|
||||
const { sessionId } = action.payload;
|
||||
state.images.push({ type: 'progress', sessionId });
|
||||
},
|
||||
stagingAreaGenerationFinished: (state, action: PayloadAction<{ sessionId: string }>) => {
|
||||
const { sessionId } = action.payload;
|
||||
state.images = state.images.filter((data) => data.sessionId !== sessionId);
|
||||
},
|
||||
stagingAreaImageSelected: (state, action: PayloadAction<{ index: number }>) => {
|
||||
const { index } = action.payload;
|
||||
@@ -61,6 +81,8 @@ export const canvasSessionSlice = createSlice({
|
||||
|
||||
export const {
|
||||
stagingAreaImageStaged,
|
||||
stagingAreaGenerationStarted,
|
||||
stagingAreaGenerationFinished,
|
||||
stagingAreaStagedImageDiscarded,
|
||||
stagingAreaReset,
|
||||
stagingAreaImageSelected,
|
||||
|
||||
@@ -438,10 +438,16 @@ export type LoRA = {
|
||||
};
|
||||
|
||||
export type StagingAreaImage = {
|
||||
type: 'staged';
|
||||
sessionId: string;
|
||||
imageDTO: ImageDTO;
|
||||
offsetX: number;
|
||||
offsetY: number;
|
||||
};
|
||||
export type StagingAreaProgressImage = {
|
||||
type: 'progress';
|
||||
sessionId: string;
|
||||
};
|
||||
export type EphemeralProgressImage = { sessionId: string; image: ProgressImage };
|
||||
|
||||
export const zAspectRatioID = z.enum(['Free', '21:9', '9:21', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);
|
||||
|
||||
@@ -13,7 +13,11 @@ import { boardsApi } from 'services/api/endpoints/boards';
|
||||
import { getImageDTOSafe, imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO, S } from 'services/api/types';
|
||||
import { getCategories, getListImagesUrl } from 'services/api/util';
|
||||
import { $lastCanvasProgressImage, $lastProgressEvent } from 'services/events/stores';
|
||||
import {
|
||||
$lastCanvasProgressImage,
|
||||
$lastProgressEvent,
|
||||
$progressImages,
|
||||
} from 'services/events/stores';
|
||||
import type { Param0 } from 'tsafe';
|
||||
import { objectEntries } from 'tsafe';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
@@ -184,9 +188,20 @@ export const buildOnInvocationComplete = (getState: () => RootState, dispatch: A
|
||||
}
|
||||
|
||||
flushSync(() => {
|
||||
dispatch(stagingAreaImageStaged({ stagingAreaImage: { imageDTO, offsetX: 0, offsetY: 0 } }));
|
||||
dispatch(
|
||||
stagingAreaImageStaged({
|
||||
stagingAreaImage: { type: 'staged', sessionId: data.session_id, imageDTO, offsetX: 0, offsetY: 0 },
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
const progressData = $progressImages.get()[data.session_id];
|
||||
if (progressData) {
|
||||
$progressImages.setKey(data.session_id, { ...progressData, isFinished: true, resultImage: imageDTO });
|
||||
} else {
|
||||
$progressImages.setKey(data.session_id, { sessionId: data.session_id, isFinished: true, resultImage: imageDTO });
|
||||
}
|
||||
|
||||
$lastCanvasProgressImage.set(null);
|
||||
};
|
||||
|
||||
|
||||
@@ -8,6 +8,9 @@ import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId';
|
||||
import { $queueId } from 'app/store/nanostores/queueId';
|
||||
import type { AppStore } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import {
|
||||
stagingAreaGenerationStarted,
|
||||
} from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import {
|
||||
$isInPublishFlow,
|
||||
$outputNodeId,
|
||||
@@ -38,6 +41,7 @@ import {
|
||||
$lastUpscalingProgressImage,
|
||||
$lastWorkflowsProgressEvent,
|
||||
$lastWorkflowsProgressImage,
|
||||
$progressImages,
|
||||
} from './stores';
|
||||
|
||||
const log = logger('events');
|
||||
@@ -115,6 +119,15 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
|
||||
|
||||
$lastProgressEvent.set(data);
|
||||
|
||||
if (data.image) {
|
||||
const progressData = $progressImages.get()[session_id];
|
||||
if (progressData) {
|
||||
$progressImages.setKey(session_id, { ...progressData, progressImage: data.image });
|
||||
} else {
|
||||
$progressImages.setKey(session_id, { sessionId: session_id, isFinished: false, progressImage: data.image });
|
||||
}
|
||||
}
|
||||
|
||||
if (origin === 'canvas') {
|
||||
$lastCanvasProgressEvent.set(data);
|
||||
if (image) {
|
||||
@@ -432,6 +445,10 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
|
||||
clone.outputs = [];
|
||||
$nodeExecutionStates.setKey(clone.nodeId, clone);
|
||||
});
|
||||
if (data.origin === 'canvas') {
|
||||
store.dispatch(stagingAreaGenerationStarted({ sessionId: session_id }));
|
||||
$progressImages.setKey(session_id, { sessionId: session_id, isFinished: false });
|
||||
}
|
||||
} else if (status === 'completed' || status === 'failed' || status === 'canceled') {
|
||||
if (status === 'failed' && error_type) {
|
||||
const isLocal = getState().config.isLocal ?? true;
|
||||
@@ -455,13 +472,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
|
||||
}
|
||||
// If the queue item is completed, failed, or cancelled, we want to clear the last progress event
|
||||
$lastProgressEvent.set(null);
|
||||
|
||||
if (data.origin === 'canvas') {
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
if (status === 'canceled' || status === 'failed') {
|
||||
$lastCanvasProgressImage.set(null);
|
||||
}
|
||||
}
|
||||
$progressImages.setKey(session_id, undefined);
|
||||
|
||||
// When a validation run is completed, we want to clear the validation run batch ID & set the workflow as published
|
||||
const validationRunData = $validationRunData.get();
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import type { EphemeralProgressImage } from 'features/controlLayers/store/types';
|
||||
import type { ProgressImage } from 'features/nodes/types/common';
|
||||
import { round } from 'lodash-es';
|
||||
import type { MapStore } from 'nanostores';
|
||||
import { atom, computed, map } from 'nanostores';
|
||||
import type { S } from 'services/api/types';
|
||||
import { useEffect, useState } from 'react';
|
||||
import type { ImageDTO, S } from 'services/api/types';
|
||||
import type { AppSocket } from 'services/events/types';
|
||||
import type { ManagerOptions, SocketOptions } from 'socket.io-client';
|
||||
|
||||
@@ -10,6 +13,28 @@ export const $socketOptions = map<Partial<ManagerOptions & SocketOptions>>({});
|
||||
export const $isConnected = atom<boolean>(false);
|
||||
export const $lastProgressEvent = atom<S['InvocationProgressEvent'] | null>(null);
|
||||
|
||||
export type ProgressAndResult = {
|
||||
sessionId: string;
|
||||
isFinished: boolean;
|
||||
progressImage?: ProgressImage;
|
||||
resultImage?: ImageDTO;
|
||||
};
|
||||
export const $progressImages = map({} as Record<string, ProgressAndResult>);
|
||||
|
||||
export const useMapSelector = <T extends object>(id: string, map: MapStore<Record<string, T>>): T | undefined => {
|
||||
const [value, setValue] = useState<T | undefined>();
|
||||
useEffect(() => {
|
||||
const unsub = map.subscribe((data) => {
|
||||
setValue(data[id]);
|
||||
});
|
||||
return () => {
|
||||
unsub();
|
||||
};
|
||||
}, [id, map]);
|
||||
|
||||
return value;
|
||||
};
|
||||
|
||||
export const $lastCanvasProgressEvent = atom<S['InvocationProgressEvent'] | null>(null);
|
||||
export const $lastCanvasProgressImage = atom<EphemeralProgressImage | null>(null);
|
||||
export const $lastWorkflowsProgressEvent = atom<S['InvocationProgressEvent'] | null>(null);
|
||||
|
||||
Reference in New Issue
Block a user