wip progress events

This commit is contained in:
psychedelicious
2025-05-31 00:01:57 +10:00
parent a3851e0b08
commit 5e93f58530
7 changed files with 192 additions and 42 deletions

View File

@@ -3,6 +3,7 @@ import { autoBatchEnhancer, combineReducers, configureStore } from '@reduxjs/too
import { logger } from 'app/logging/logger';
import { idbKeyValDriver } from 'app/store/enhancers/reduxRemember/driver';
import { errorHandler } from 'app/store/enhancers/reduxRemember/errors';
import { getDebugLoggerMiddleware } from 'app/store/middleware/debugLoggerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice';
import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
@@ -175,6 +176,7 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
.concat(api.middleware)
.concat(dynamicMiddlewares)
.concat(authToastMiddleware)
.concat(getDebugLoggerMiddleware())
.prepend(listenerMiddleware.middleware),
enhancers: (getDefaultEnhancers) => {
const _enhancers = getDefaultEnhancers().concat(autoBatchEnhancer());

View File

@@ -1,5 +1,6 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import {
Box,
Button,
ButtonGroup,
ContextMenu,
@@ -40,12 +41,10 @@ import {
selectCanvasSessionType,
selectSelectedImage,
selectStagedImageIndex,
selectStagedImages,
stagingAreaImageSelected,
stagingAreaNextStagedImageSelected,
stagingAreaPrevStagedImageSelected,
} from 'features/controlLayers/store/canvasStagingAreaSlice';
import type { EphemeralProgressImage } from 'features/controlLayers/store/types';
import { newCanvasFromImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { DndImage } from 'features/dnd/DndImage';
@@ -55,7 +54,8 @@ import { useHotkeys } from 'react-hotkeys-hook';
import { Trans, useTranslation } from 'react-i18next';
import { PiDotsThreeOutlineVerticalFill, PiUploadBold } from 'react-icons/pi';
import type { ImageDTO } from 'services/api/types';
import { $lastCanvasProgressImage } from 'services/events/stores';
import type { ProgressAndResult } from 'services/events/stores';
import { $progressImages, useMapSelector } from 'services/events/stores';
import type { Equals, Param0 } from 'tsafe';
import { assert } from 'tsafe';
@@ -288,6 +288,7 @@ const SimpleActiveSession = memo(() => {
const startOver = useCallback(() => {
dispatch(canvasSessionStarted({ sessionType: null }));
$progressImages.set({});
}, [dispatch]);
const goAdvanced = useCallback(() => {
@@ -325,15 +326,10 @@ const SimpleActiveSession = memo(() => {
SimpleActiveSession.displayName = 'SimpleActiveSession';
const SelectedImageOrProgressImage = memo(() => {
const progressImage = useStore($lastCanvasProgressImage);
const selectedImage = useAppSelector(selectSelectedImage);
if (progressImage) {
return <ProgressImage progressImage={progressImage} />;
}
if (selectedImage) {
return <SelectedImage imageDTO={selectedImage.imageDTO} />;
return <FullSizeImage sessionId={selectedImage.sessionId} />;
}
return (
@@ -397,36 +393,107 @@ const SelectedImage = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
});
SelectedImage.displayName = 'SelectedImage';
const ProgressImage = memo(({ progressImage }: { progressImage: EphemeralProgressImage }) => {
const FullSizeImage = memo(({ sessionId }: { sessionId: string }) => {
const _progressImage = useMapSelector(sessionId, $progressImages);
if (!_progressImage) {
return (
<Flex alignItems="center" justifyContent="center" minH={0} minW={0} h="full">
<Text>Pending</Text>
</Flex>
);
}
if (_progressImage.resultImage) {
return <SelectedImage imageDTO={_progressImage.resultImage} />;
}
if (_progressImage.progressImage) {
return (
<Flex alignItems="center" justifyContent="center" minH={0} minW={0} h="full">
<Image
objectFit="contain"
maxH="full"
maxW="full"
src={_progressImage.progressImage.dataURL}
width={_progressImage.progressImage.width}
/>
</Flex>
);
}
return (
<Flex alignItems="center" justifyContent="center" minH={0} minW={0} h="full">
<Image
objectFit="contain"
maxH="full"
maxW="full"
src={progressImage.image.dataURL}
width={progressImage.image.width}
/>
<Text>No progress yet</Text>
</Flex>
);
});
ProgressImage.displayName = 'ProgressImage';
FullSizeImage.displayName = 'FullSizeImage';
const SessionImages = memo(() => {
const stagedImages = useAppSelector(selectStagedImages);
const progressImages = useStore($progressImages);
return (
<Flex position="relative" gap={2} h={108} maxW="full" overflow="scroll">
<Spacer />
{stagedImages.map(({ imageDTO }, index) => (
<SessionImage key={imageDTO.image_name} index={index} imageDTO={imageDTO} />
))}
{Object.values(progressImages).map((data, index) => {
if (data.type === 'staged') {
return <SessionImage key={data.sessionId} index={index} data={data} />;
} else {
return <ProgressImagePreview key={data.sessionId} index={index} data={data} />;
}
})}
<Spacer />
</Flex>
);
});
SessionImages.displayName = 'SessionImages';
const getStagingImageId = (imageDTO: ImageDTO) => `staging-image-${imageDTO.image_name}`;
const ProgressImagePreview = ({ index, data }: { index: number; data: ProgressAndResult }) => {
const dispatch = useAppDispatch();
const selectedImageIndex = useAppSelector(selectStagedImageIndex);
const onClick = useCallback(() => {
dispatch(stagingAreaImageSelected({ index }));
}, [dispatch, index]);
useEffect(() => {
if (selectedImageIndex === index) {
// this doesn't work when the DndImage is in a popover... why
document.getElementById(getStagingImageId(data.sessionId))?.scrollIntoView();
}
}, [data.sessionId, index, selectedImageIndex]);
if (data.resultImage) {
return (
<Image
id={getStagingImageId(data.sessionId)}
objectFit="contain"
maxH="full"
maxW="full"
src={data.resultImage.thumbnail_url}
width={data.resultImage.width}
onClick={onClick}
/>
);
}
if (data.progressImage) {
return (
<Image
id={getStagingImageId(data.sessionId)}
objectFit="contain"
maxH="full"
maxW="full"
src={data.progressImage.dataURL}
width={data.progressImage.width}
onClick={onClick}
/>
);
}
return <Box id={getStagingImageId(data.sessionId)} bg="blue" h="full" w={108} borderWidth={1} onClick={onClick} />;
};
const getStagingImageId = (session_id: string) => `staging-image-${session_id}`;
const sx = {
objectFit: 'contain',
@@ -442,7 +509,7 @@ const sx = {
opacity: 0.5,
},
} satisfies SystemStyleObject;
const SessionImage = memo(({ index, imageDTO }: { index: number; imageDTO: ImageDTO }) => {
const SessionImage = memo(({ index, data }: { index: number; data: ProgressAndResult }) => {
const dispatch = useAppDispatch();
const selectedImageIndex = useAppSelector(selectStagedImageIndex);
const onClick = useCallback(() => {
@@ -451,17 +518,19 @@ const SessionImage = memo(({ index, imageDTO }: { index: number; imageDTO: Image
useEffect(() => {
if (selectedImageIndex === index) {
// this doesn't work when the DndImage is in a popover... why
document.getElementById(getStagingImageId(imageDTO))?.scrollIntoView();
document.getElementById(getStagingImageId(data.sessionId))?.scrollIntoView();
}
}, [imageDTO, index, selectedImageIndex]);
}, [data.sessionId, index, selectedImageIndex]);
return (
<DndImage
id={getStagingImageId(imageDTO)}
imageDTO={imageDTO}
id={getStagingImageId(data.sessionId)}
imageDTO={data.imageDTO}
asThumbnail
onClick={onClick}
data-is-selected={selectedImageIndex === index}
w={data.imageDTO.width}
sx={sx}
borderWidth={1}
/>
);
});

View File

@@ -2,12 +2,12 @@ import { createSelector, createSlice, type PayloadAction } from '@reduxjs/toolki
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { canvasReset } from 'features/controlLayers/store/actions';
import type { StagingAreaImage } from 'features/controlLayers/store/types';
import type { StagingAreaImage, StagingAreaProgressImage } from 'features/controlLayers/store/types';
import { selectCanvasQueueCounts } from 'services/api/endpoints/queue';
type CanvasStagingAreaState = {
sessionType: 'simple' | 'advanced' | null;
images: StagingAreaImage[];
images: (StagingAreaImage | StagingAreaProgressImage)[];
selectedImageIndex: number;
};
@@ -25,8 +25,28 @@ export const canvasSessionSlice = createSlice({
reducers: {
stagingAreaImageStaged: (state, action: PayloadAction<{ stagingAreaImage: StagingAreaImage }>) => {
const { stagingAreaImage } = action.payload;
state.images.push(stagingAreaImage);
state.selectedImageIndex = state.images.length - 1;
let didReplace = false;
const newImages = [];
for (const i of state.images) {
if (i.sessionId === stagingAreaImage.sessionId) {
newImages.push(stagingAreaImage);
didReplace = true;
} else {
newImages.push(i);
}
}
if (!didReplace) {
newImages.push(stagingAreaImage);
}
state.images = newImages;
},
stagingAreaGenerationStarted: (state, action: PayloadAction<{ sessionId: string }>) => {
const { sessionId } = action.payload;
state.images.push({ type: 'progress', sessionId });
},
stagingAreaGenerationFinished: (state, action: PayloadAction<{ sessionId: string }>) => {
const { sessionId } = action.payload;
state.images = state.images.filter((data) => data.sessionId !== sessionId);
},
stagingAreaImageSelected: (state, action: PayloadAction<{ index: number }>) => {
const { index } = action.payload;
@@ -61,6 +81,8 @@ export const canvasSessionSlice = createSlice({
export const {
stagingAreaImageStaged,
stagingAreaGenerationStarted,
stagingAreaGenerationFinished,
stagingAreaStagedImageDiscarded,
stagingAreaReset,
stagingAreaImageSelected,

View File

@@ -438,10 +438,16 @@ export type LoRA = {
};
export type StagingAreaImage = {
type: 'staged';
sessionId: string;
imageDTO: ImageDTO;
offsetX: number;
offsetY: number;
};
export type StagingAreaProgressImage = {
type: 'progress';
sessionId: string;
};
export type EphemeralProgressImage = { sessionId: string; image: ProgressImage };
export const zAspectRatioID = z.enum(['Free', '21:9', '9:21', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);

View File

@@ -13,7 +13,11 @@ import { boardsApi } from 'services/api/endpoints/boards';
import { getImageDTOSafe, imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO, S } from 'services/api/types';
import { getCategories, getListImagesUrl } from 'services/api/util';
import { $lastCanvasProgressImage, $lastProgressEvent } from 'services/events/stores';
import {
$lastCanvasProgressImage,
$lastProgressEvent,
$progressImages,
} from 'services/events/stores';
import type { Param0 } from 'tsafe';
import { objectEntries } from 'tsafe';
import type { JsonObject } from 'type-fest';
@@ -184,9 +188,20 @@ export const buildOnInvocationComplete = (getState: () => RootState, dispatch: A
}
flushSync(() => {
dispatch(stagingAreaImageStaged({ stagingAreaImage: { imageDTO, offsetX: 0, offsetY: 0 } }));
dispatch(
stagingAreaImageStaged({
stagingAreaImage: { type: 'staged', sessionId: data.session_id, imageDTO, offsetX: 0, offsetY: 0 },
})
);
});
const progressData = $progressImages.get()[data.session_id];
if (progressData) {
$progressImages.setKey(data.session_id, { ...progressData, isFinished: true, resultImage: imageDTO });
} else {
$progressImages.setKey(data.session_id, { sessionId: data.session_id, isFinished: true, resultImage: imageDTO });
}
$lastCanvasProgressImage.set(null);
};

View File

@@ -8,6 +8,9 @@ import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId';
import { $queueId } from 'app/store/nanostores/queueId';
import type { AppStore } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import {
stagingAreaGenerationStarted,
} from 'features/controlLayers/store/canvasStagingAreaSlice';
import {
$isInPublishFlow,
$outputNodeId,
@@ -38,6 +41,7 @@ import {
$lastUpscalingProgressImage,
$lastWorkflowsProgressEvent,
$lastWorkflowsProgressImage,
$progressImages,
} from './stores';
const log = logger('events');
@@ -115,6 +119,15 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
$lastProgressEvent.set(data);
if (data.image) {
const progressData = $progressImages.get()[session_id];
if (progressData) {
$progressImages.setKey(session_id, { ...progressData, progressImage: data.image });
} else {
$progressImages.setKey(session_id, { sessionId: session_id, isFinished: false, progressImage: data.image });
}
}
if (origin === 'canvas') {
$lastCanvasProgressEvent.set(data);
if (image) {
@@ -432,6 +445,10 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
clone.outputs = [];
$nodeExecutionStates.setKey(clone.nodeId, clone);
});
if (data.origin === 'canvas') {
store.dispatch(stagingAreaGenerationStarted({ sessionId: session_id }));
$progressImages.setKey(session_id, { sessionId: session_id, isFinished: false });
}
} else if (status === 'completed' || status === 'failed' || status === 'canceled') {
if (status === 'failed' && error_type) {
const isLocal = getState().config.isLocal ?? true;
@@ -455,13 +472,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
}
// If the queue item is completed, failed, or cancelled, we want to clear the last progress event
$lastProgressEvent.set(null);
if (data.origin === 'canvas') {
$lastCanvasProgressEvent.set(null);
if (status === 'canceled' || status === 'failed') {
$lastCanvasProgressImage.set(null);
}
}
$progressImages.setKey(session_id, undefined);
// When a validation run is completed, we want to clear the validation run batch ID & set the workflow as published
const validationRunData = $validationRunData.get();

View File

@@ -1,7 +1,10 @@
import type { EphemeralProgressImage } from 'features/controlLayers/store/types';
import type { ProgressImage } from 'features/nodes/types/common';
import { round } from 'lodash-es';
import type { MapStore } from 'nanostores';
import { atom, computed, map } from 'nanostores';
import type { S } from 'services/api/types';
import { useEffect, useState } from 'react';
import type { ImageDTO, S } from 'services/api/types';
import type { AppSocket } from 'services/events/types';
import type { ManagerOptions, SocketOptions } from 'socket.io-client';
@@ -10,6 +13,28 @@ export const $socketOptions = map<Partial<ManagerOptions & SocketOptions>>({});
export const $isConnected = atom<boolean>(false);
export const $lastProgressEvent = atom<S['InvocationProgressEvent'] | null>(null);
export type ProgressAndResult = {
sessionId: string;
isFinished: boolean;
progressImage?: ProgressImage;
resultImage?: ImageDTO;
};
export const $progressImages = map({} as Record<string, ProgressAndResult>);
export const useMapSelector = <T extends object>(id: string, map: MapStore<Record<string, T>>): T | undefined => {
const [value, setValue] = useState<T | undefined>();
useEffect(() => {
const unsub = map.subscribe((data) => {
setValue(data[id]);
});
return () => {
unsub();
};
}, [id, map]);
return value;
};
export const $lastCanvasProgressEvent = atom<S['InvocationProgressEvent'] | null>(null);
export const $lastCanvasProgressImage = atom<EphemeralProgressImage | null>(null);
export const $lastWorkflowsProgressEvent = atom<S['InvocationProgressEvent'] | null>(null);