feat: canvas flow rework (wip)

This commit is contained in:
psychedelicious
2025-06-03 20:54:19 +10:00
parent 0e9b71801a
commit ad736bc190
8 changed files with 266 additions and 96 deletions

View File

@@ -5,7 +5,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import { withResult, withResultAsync } from 'common/util/result';
import { parseify } from 'common/util/serialize';
import { canvasSessionStarted, selectCanvasSessionType } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { canvasSessionStarted, selectCanvasSession } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { $canvasManager } from 'features/controlLayers/store/ephemeral';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
@@ -32,6 +32,11 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
actionCreator: enqueueRequestedCanvas,
effect: async (action, { getState, dispatch }) => {
log.debug('Enqueue requested');
if (!selectCanvasSession(getState())) {
dispatch(canvasSessionStarted({ sessionType: 'simple' }));
}
const state = getState();
const { prepend } = action.payload;
@@ -91,7 +96,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = buildGraphResult.value;
// const destination = state.canvasSettings.sendToCanvas ? 'canvas' : 'gallery';
const destination = state.canvasSession.session?.id ?? 'canvas';
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
@@ -101,7 +106,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
seedFieldIdentifier,
positivePromptFieldIdentifier,
origin: 'canvas',
destination: 'canvas',
destination,
})
);
@@ -116,9 +121,6 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
try {
await req.unwrap();
if (!selectCanvasSessionType(state)) {
dispatch(canvasSessionStarted({ sessionType: 'simple' }));
}
log.debug(parseify({ batchConfig: prepareBatchResult.value }), 'Enqueued batch');
} catch (error) {
log.error({ error: serializeError(error as Error) }, 'Failed to enqueue batch');

View File

@@ -45,22 +45,25 @@ import { Transform } from 'features/controlLayers/components/Transform/Transform
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { loadImage } from 'features/controlLayers/konva/util';
import { selectDynamicGrid, selectShowHUD } from 'features/controlLayers/store/canvasSettingsSlice';
import { canvasSessionStarted, selectCanvasSessionType } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { canvasSessionStarted, selectCanvasSession } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { newCanvasFromImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { DndImage } from 'features/dnd/DndImage';
import { newCanvasFromImage } from 'features/imageActions/actions';
import { isImageField } from 'features/nodes/types/common';
import { isCanvasOutputNodeId } from 'features/nodes/util/graph/graphBuilderUtils';
import { round } from 'lodash-es';
import { atom, type WritableAtom } from 'nanostores';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { createContext, memo, useCallback, useContext, useEffect, useMemo, useRef, useState } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { Trans, useTranslation } from 'react-i18next';
import { PiDotsThreeOutlineVerticalFill, PiUploadBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { useListAllQueueItemsQuery } from 'services/api/endpoints/queue';
import type { ImageDTO, S } from 'services/api/types';
import { $socket, setProgress, useProgressData } from 'services/events/stores';
import type { ProgressData } from 'services/events/stores';
import { $socket, clearProgressEvent, setProgress, useHasProgressImage, useProgressData } from 'services/events/stores';
import type { Equals, Param0 } from 'tsafe';
import { assert, objectEntries } from 'tsafe';
@@ -84,25 +87,45 @@ const MenuContent = memo(() => {
MenuContent.displayName = 'MenuContent';
export const CanvasMainPanelContent = memo(() => {
const sessionType = useAppSelector(selectCanvasSessionType);
const session = useAppSelector(selectCanvasSession);
if (sessionType === null) {
if (session === null) {
return <NoActiveSession />;
}
if (sessionType === 'simple') {
return <StagingArea />;
if (session.type === 'simple') {
return <StagingAreaWrapper id={session.id} />;
}
if (sessionType === 'advanced') {
if (session.type === 'advanced') {
return <CanvasActiveSession />;
}
assert<Equals<never, typeof sessionType>>(false, 'Unexpected sessionType');
assert<Equals<never, typeof session>>(false, 'Unexpected session');
});
CanvasMainPanelContent.displayName = 'CanvasMainPanelContent';
const StagingAreaWrapper = memo(({ id }: { id: string }) => {
const ctx = useMemo(
() =>
({
session: {
type: 'simple',
id,
},
$progressData: atom<Record<string, ProgressData>>({}),
}) as const,
[id]
);
return (
<StagingContext.Provider value={ctx}>
<StagingArea />
</StagingContext.Provider>
);
});
StagingAreaWrapper.displayName = 'StagingAreaWrapper';
const generateWithStartingImageDndTargetData = newCanvasFromImageDndTarget.getData({
type: 'raster_layer',
withResize: true,
@@ -116,6 +139,8 @@ const generateWithControlImageDndTargetData = newCanvasFromImageDndTarget.getDat
withResize: true,
});
const DROP_SHADOW = 'drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 4px rgba(0, 0, 0, 0.3))';
const NoActiveSession = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
@@ -294,14 +319,36 @@ const scrollIndicatorSx = {
},
} satisfies SystemStyleObject;
type StagingContextValue = {
session:
| {
type: 'simple';
id: string;
}
| {
type: 'advanced';
id: string;
};
$progressData: WritableAtom<Record<string, ProgressData>>;
};
const StagingContext = createContext<StagingContextValue | null>(null);
const useStagingContext = () => {
const ctx = useContext(StagingContext);
assert(ctx !== null, 'use in stg prov');
return ctx;
};
const StagingArea = memo(() => {
const ctx = useStagingContext();
const dispatch = useAppDispatch();
const [selectedItemId, setSelectedItemId] = useState<number | null>(null);
const [autoSwitch, setAutoSwitch] = useState(true);
const [canScrollLeft, setCanScrollLeft] = useState(false);
const [canScrollRight, setCanScrollRight] = useState(false);
const scrollableRef = useRef<HTMLDivElement>(null);
const { data } = useListAllQueueItemsQuery({ destination: 'canvas' });
const { data } = useListAllQueueItemsQuery({ destination: ctx.session.id });
const items = useMemo(() => data?.filter(({ status }) => status !== 'canceled') ?? EMPTY_ARRAY, [data]);
const selectedItem = useMemo(
() =>
@@ -315,7 +362,7 @@ const StagingArea = memo(() => {
);
const startOver = useCallback(() => {
dispatch(canvasSessionStarted({ sessionType: null }));
dispatch(canvasSessionStarted({ sessionType: 'simple' }));
}, [dispatch]);
useEffect(() => {
@@ -353,13 +400,10 @@ const StagingArea = memo(() => {
onSelectItemId(null);
return;
}
if (selectedItem === null && items.length > 0) {
if (selectedItemId === null && items.length > 0) {
onSelectItemId(items[0]?.item_id ?? null);
return;
}
if (selectedItemId === null || items.find((item) => item.item_id === selectedItemId) === undefined) {
return;
}
}, [items, onSelectItemId, selectedItem, selectedItemId]);
const onNext = useCallback(() => {
@@ -387,26 +431,42 @@ const StagingArea = memo(() => {
onSelectItemId(prevItem.item_id);
}, [items, onSelectItemId, selectedItemId]);
useHotkeys('left', onPrev);
useHotkeys('right', onNext);
const onFirst = useCallback(() => {
const first = items.at(0);
if (!first) {
return;
}
onSelectItemId(first.item_id);
}, [items, onSelectItemId]);
const onLast = useCallback(() => {
const last = items.at(-1);
if (!last) {
return;
}
onSelectItemId(last.item_id);
}, [items, onSelectItemId]);
useHotkeys('left', onPrev, { preventDefault: true });
useHotkeys('right', onNext, { preventDefault: true });
useHotkeys('meta+left', onFirst, { preventDefault: true });
useHotkeys('meta+right', onLast, { preventDefault: true });
const socket = useStore($socket);
useEffect(() => {
if (!autoSwitch) {
return;
}
if (!socket) {
return;
}
const onQueueItemStatusChanged = (data: S['QueueItemStatusChangedEvent']) => {
if (data.destination !== 'canvas') {
if (data.destination !== ctx.session.id) {
return;
}
if (data.status === 'in_progress') {
if (data.status === 'in_progress' && autoSwitch) {
onSelectItemId(data.item_id);
}
if (data.status === 'completed' || data.status === 'canceled' || data.status === 'failed') {
clearProgressEvent(ctx.$progressData, data.session_id);
}
};
socket.on('queue_item_status_changed', onQueueItemStatusChanged);
@@ -414,7 +474,7 @@ const StagingArea = memo(() => {
return () => {
socket.off('queue_item_status_changed', onQueueItemStatusChanged);
};
}, [autoSwitch, onSelectItemId, socket]);
}, [autoSwitch, ctx.$progressData, ctx.session.id, onSelectItemId, socket]);
const _onChangeAutoSwitch = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setAutoSwitch(e.target.checked);
@@ -425,17 +485,17 @@ const StagingArea = memo(() => {
return;
}
const onProgress = (data: S['InvocationProgressEvent']) => {
if (data.destination !== 'canvas') {
if (data.destination !== ctx.session.id) {
return;
}
setProgress(data);
setProgress(ctx.$progressData, data);
};
socket.on('invocation_progress', onProgress);
return () => {
socket.off('invocation_progress', onProgress);
};
}, [socket]);
}, [ctx.$progressData, ctx.session.id, socket]);
return (
<Flex flexDir="column" gap={2} w="full" h="full" minW={0} minH={0}>
@@ -467,8 +527,8 @@ const StagingArea = memo(() => {
<Switch size="sm" isChecked={autoSwitch} onChange={_onChangeAutoSwitch} />
</FormControl>
</Flex>
<Flex position="relative" maxW="full" h={108} justifyContent="center">
<Flex ref={scrollableRef} gap={2} maxW="full" overflowX="scroll" flexShrink={0}>
<Flex position="relative" maxW="full" w="full" h={108}>
<Flex ref={scrollableRef} gap={2} maxW="full" overflowX="scroll">
{items.map((item, i) => (
<QueueItemCard
key={item.item_id}
@@ -511,7 +571,7 @@ const StagingArea = memo(() => {
});
StagingArea.displayName = 'StagingArea';
const queueItemStatusCardMiniSx = {
const queueItemCardSx = {
cursor: 'pointer',
pos: 'relative',
borderWidth: 1,
@@ -528,6 +588,9 @@ const queueItemStatusCardMiniSx = {
'&[data-size="mini"]': {
flexShrink: 0,
},
'&[data-size="full"]&[data-has-progress-image="false"]': {
w: 1024,
},
};
const getCardId = (item_id: number) => `queue-item-status-card-${item_id}`;
@@ -543,7 +606,9 @@ type QueueItemStatusCardMiniProps = {
const QueueItemCard = memo(
({ item, isSelected, number, onSelectItemId, onChangeAutoSwitch, size }: QueueItemStatusCardMiniProps) => {
const ctx = useStagingContext();
const [isImageLoaded, setIsImageLoaded] = useState(false);
const hasProgressImage = useHasProgressImage(ctx.$progressData, item.session_id);
const outputImageName = useMemo(() => {
const nodeId = Object.entries(item.session.source_prepared_mapping).find(([nodeId]) =>
@@ -566,13 +631,23 @@ const QueueItemCard = memo(
const { currentData: imageDTO } = useGetImageDTOQuery(outputImageName ?? skipToken);
useEffect(() => {
if (imageDTO) {
loadImage(imageDTO.thumbnail_url, true).then(() => {
setIsImageLoaded(true);
});
const syncIsReady = useCallback(async () => {
if (!imageDTO) {
setIsImageLoaded(false);
return;
}
}, [imageDTO, item.session_id]);
try {
const _ = await loadImage(size === 'mini' ? imageDTO.thumbnail_url : imageDTO.image_url, true);
setIsImageLoaded(true);
return;
} catch {
setIsImageLoaded(false);
}
}, [imageDTO, size]);
useEffect(() => {
syncIsReady();
}, [syncIsReady]);
const onClick = useCallback(() => {
onSelectItemId(item.item_id);
@@ -584,9 +659,16 @@ const QueueItemCard = memo(
if (imageDTO && isImageLoaded) {
return (
<Flex id={getCardId(item.item_id)} sx={queueItemStatusCardMiniSx} data-selected={isSelected} data-size={size}>
<DndImage imageDTO={imageDTO} onClick={onClick} onDoubleClick={onDoubleClick} />
<Text position="absolute" top={0} left={1} pointerEvents="none" userSelect="none">{`#${number}`}</Text>
<Flex id={getCardId(item.item_id)} sx={queueItemCardSx} data-selected={isSelected} data-size={size}>
<DndImage imageDTO={imageDTO} onClick={onClick} onDoubleClick={onDoubleClick} asThumbnail={size === 'mini'} />
<Text
position="absolute"
top={0}
left={1}
pointerEvents="none"
userSelect="none"
filter={DROP_SHADOW}
>{`#${number}`}</Text>
{size === 'full' && (
<Flex position="absolute" top={2} right={2}>
<ImageActions imageDTO={imageDTO} />
@@ -599,40 +681,72 @@ const QueueItemCard = memo(
return (
<Flex
id={getCardId(item.item_id)}
sx={queueItemStatusCardMiniSx}
sx={queueItemCardSx}
data-selected={isSelected}
data-size={size}
data-has-progress-image={hasProgressImage}
onClick={onClick}
onDoubleClick={onDoubleClick}
>
<InProgressContent item={item} />
<Text position="absolute" top={0} left={1} pointerEvents="none" userSelect="none">{`#${number}`}</Text>
<Text
position="absolute"
top={0}
left={1}
pointerEvents="none"
userSelect="none"
filter={DROP_SHADOW}
>{`#${number}`}</Text>
{size === 'full' && <ProgressMessage key={item.session_id} session_id={item.session_id} />}
</Flex>
);
}
);
QueueItemCard.displayName = 'QueueItemStatusCard';
const getMessage = (data: S['InvocationProgressEvent']) => {
let message = data.message;
if (data.percentage) {
message += ` (${round(data.percentage * 100)}%)`;
}
return message;
};
const ProgressMessage = memo(({ session_id }: { session_id: string }) => {
const { $progressData } = useStagingContext();
const { progressEvent } = useProgressData($progressData, session_id);
if (!progressEvent) {
return null;
}
return (
<Text position="absolute" bottom={2} left={2} pointerEvents="none" userSelect="none" filter={DROP_SHADOW}>
{getMessage(progressEvent)}
</Text>
);
});
ProgressMessage.displayName = 'ProgressMessage';
const InProgressContent = memo(({ item }: { item: S['SessionQueueItem'] }) => {
const { progressEvent, progressImage } = useProgressData(item.session_id);
const { $progressData } = useStagingContext();
const { progressEvent, progressImage } = useProgressData($progressData, item.session_id);
if (item.status === 'pending') {
return (
<Text fontWeight="semibold" color="base.300">
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="base.300">
Pending
</Text>
);
}
if (item.status === 'canceled') {
return (
<Text fontWeight="semibold" color="warning.300">
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="warning.300">
Canceled
</Text>
);
}
if (item.status === 'failed') {
return (
<Text fontWeight="semibold" color="error.300">
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="error.300">
Failed
</Text>
);
@@ -650,7 +764,7 @@ const InProgressContent = memo(({ item }: { item: S['SessionQueueItem'] }) => {
if (item.status === 'in_progress') {
return (
<>
<Text fontWeight="semibold" color="invokeBlue.300">
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="invokeBlue.300">
In Progress
</Text>
<ProgressCircle data={progressEvent} />
@@ -660,7 +774,7 @@ const InProgressContent = memo(({ item }: { item: S['SessionQueueItem'] }) => {
if (item.status === 'completed') {
return (
<Text fontWeight="semibold" color="error.300">
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="error.300">
Unable to get image
</Text>
);

View File

@@ -20,7 +20,7 @@ export const useNewGallerySession = () => {
const newSessionDialog = useNewGallerySessionDialog();
const newGallerySessionImmediate = useCallback(() => {
dispatch(canvasSessionStarted({ sessionType: null }));
dispatch(canvasSessionStarted({ sessionType: 'simple' }));
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}, [dispatch]);

View File

@@ -1,17 +1,20 @@
import { createSelector, createSlice, type PayloadAction } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { canvasReset } from 'features/controlLayers/store/actions';
import type { StagingAreaImage, StagingAreaProgressImage } from 'features/controlLayers/store/types';
import { selectCanvasQueueCounts } from 'services/api/endpoints/queue';
type CanvasStagingAreaState = {
session: { type: 'simple'; id: string } | { type: 'advanced'; id: string } | null;
sessionType: 'simple' | 'advanced' | null;
images: (StagingAreaImage | StagingAreaProgressImage)[];
selectedImageIndex: number;
};
const INITIAL_STATE: CanvasStagingAreaState = {
session: null,
sessionType: null,
images: [],
selectedImageIndex: 0,
@@ -23,6 +26,10 @@ export const canvasSessionSlice = createSlice({
name: 'canvasSession',
initialState: getInitialState(),
reducers: {
sessionChanged: (state, action: PayloadAction<{ session: CanvasStagingAreaState['session'] }>) => {
const { session } = action.payload;
state.session = session;
},
stagingAreaImageStaged: (state, action: PayloadAction<{ stagingAreaImage: StagingAreaImage }>) => {
const { stagingAreaImage } = action.payload;
let didReplace = false;
@@ -67,11 +74,19 @@ export const canvasSessionSlice = createSlice({
state.images = [];
state.selectedImageIndex = 0;
},
canvasSessionStarted: (_, action: PayloadAction<{ sessionType: CanvasStagingAreaState['sessionType'] }>) => {
const { sessionType } = action.payload;
const state = getInitialState();
state.sessionType = sessionType;
return state;
canvasSessionStarted: {
reducer: (state, action: PayloadAction<{ session: CanvasStagingAreaState['session'] }>) => {
const { session } = action.payload;
state.session = session;
},
prepare: (payload: { sessionType: 'simple' | 'advanced' }) => ({
payload: {
session: {
type: payload.sessionType,
id: getPrefixedId(`canvas:${payload.sessionType}`),
},
},
}),
},
},
extraReducers(builder) {
@@ -80,6 +95,7 @@ export const canvasSessionSlice = createSlice({
});
export const {
sessionChanged,
stagingAreaImageStaged,
stagingAreaGenerationStarted,
stagingAreaGenerationFinished,
@@ -140,3 +156,7 @@ export const selectCanvasSessionType = createSelector(
selectCanvasStagingAreaSlice,
(canvasSession) => canvasSession.sessionType
);
export const selectCanvasSession = createSelector(
selectCanvasStagingAreaSlice,
(canvasSession) => canvasSession.session
);

View File

@@ -34,7 +34,7 @@ export const prepareLinearUIBatch = (arg: {
seedFieldIdentifier?: FieldIdentifier;
positivePromptFieldIdentifier: FieldIdentifier;
origin: 'canvas' | 'workflows' | 'upscaling';
destination: 'canvas' | 'gallery';
destination: string;
}): EnqueueBatchArg => {
const { state, g, prepend, seedFieldIdentifier, positivePromptFieldIdentifier, origin, destination } = arg;
const { iterations, model, shouldRandomizeSeed, seed, shouldConcatPrompts } = state.params;

View File

@@ -1,22 +1,18 @@
import { logger } from 'app/logging/logger';
import type { AppDispatch, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { stagingAreaImageStaged } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { boardIdSelected, galleryViewChanged, imageSelected, offsetChanged } from 'features/gallery/store/gallerySlice';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
import { isImageField, isImageFieldCollection } from 'features/nodes/types/common';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { isCanvasOutputEvent } from 'features/nodes/util/graph/graphBuilderUtils';
import { flushSync } from 'react-dom';
import type { ApiTagDescription } from 'services/api';
import { boardsApi } from 'services/api/endpoints/boards';
import { getImageDTOSafe, imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO, S } from 'services/api/types';
import { getCategories, getListImagesUrl } from 'services/api/util';
import {
$lastCanvasProgressImage,
$lastProgressEvent,
$progressImages,
} from 'services/events/stores';
import type { Param0 } from 'tsafe';
import { objectEntries } from 'tsafe';
@@ -180,29 +176,29 @@ export const buildOnInvocationComplete = (getState: () => RootState, dispatch: A
await addImagesToGallery(data);
// We expect only a single image in the canvas output
const imageDTO = (await getResultImageDTOs(data))[0];
// // We expect only a single image in the canvas output
// const imageDTO = (await getResultImageDTOs(data))[0];
if (!imageDTO) {
return;
}
// if (!imageDTO) {
// return;
// }
flushSync(() => {
dispatch(
stagingAreaImageStaged({
stagingAreaImage: { type: 'staged', sessionId: data.session_id, imageDTO, offsetX: 0, offsetY: 0 },
})
);
});
// flushSync(() => {
// dispatch(
// stagingAreaImageStaged({
// stagingAreaImage: { type: 'staged', sessionId: data.session_id, imageDTO, offsetX: 0, offsetY: 0 },
// })
// );
// });
const progressData = $progressImages.get()[data.session_id];
if (progressData) {
$progressImages.setKey(data.session_id, { ...progressData, isFinished: true, resultImage: imageDTO });
} else {
$progressImages.setKey(data.session_id, { sessionId: data.session_id, isFinished: true, resultImage: imageDTO });
}
// const progressData = $progressImages.get()[data.session_id];
// if (progressData) {
// $progressImages.setKey(data.session_id, { ...progressData, isFinished: true, resultImage: imageDTO });
// } else {
// $progressImages.setKey(data.session_id, { sessionId: data.session_id, isFinished: true, resultImage: imageDTO });
// }
$lastCanvasProgressImage.set(null);
// $lastCanvasProgressImage.set(null);
};
const handleOriginOther = async (data: S['InvocationCompleteEvent']) => {

View File

@@ -8,7 +8,6 @@ import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId';
import { $queueId } from 'app/store/nanostores/queueId';
import type { AppStore } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { stagingAreaGenerationStarted } from 'features/controlLayers/store/canvasStagingAreaSlice';
import {
$isInPublishFlow,
$outputNodeId,
@@ -397,8 +396,8 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
$nodeExecutionStates.setKey(clone.nodeId, clone);
});
if (data.origin === 'canvas') {
store.dispatch(stagingAreaGenerationStarted({ sessionId: session_id }));
$progressImages.setKey(session_id, { sessionId: session_id, isFinished: false });
// store.dispatch(stagingAreaGenerationStarted({ sessionId: session_id }));
// $progressImages.setKey(session_id, { sessionId: session_id, isFinished: false });
}
} else if (status === 'completed' || status === 'failed' || status === 'canceled') {
if (status === 'failed' && error_type) {
@@ -423,7 +422,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
}
// If the queue item is completed, failed, or cancelled, we want to clear the last progress event
$lastProgressEvent.set(null);
$progressImages.setKey(session_id, undefined);
// $progressImages.setKey(session_id, undefined);
// When a validation run is completed, we want to clear the validation run batch ID & set the workflow as published
const validationRunData = $validationRunData.get();

View File

@@ -1,6 +1,7 @@
import type { EphemeralProgressImage } from 'features/controlLayers/store/types';
import type { ProgressImage } from 'features/nodes/types/common';
import { round } from 'lodash-es';
import type { WritableAtom } from 'nanostores';
import { atom, computed, map } from 'nanostores';
import { useEffect, useState } from 'react';
import type { ImageDTO, S } from 'services/api/types';
@@ -20,16 +21,19 @@ export type ProgressAndResult = {
};
export const $progressImages = map({} as Record<string, ProgressAndResult>);
type ProgressData = {
export type ProgressData = {
sessionId: string;
progressEvent: S['InvocationProgressEvent'] | null;
progressImage: ProgressImage | null;
};
export const $progressData = atom<Record<string, ProgressData>>({});
export const useProgressData = (sessionId: string): ProgressData => {
const [value, setValue] = useState<ProgressData>({ sessionId, progressEvent: null, progressImage: null });
export const useProgressData = (
$progressData: WritableAtom<Record<string, ProgressData>>,
sessionId: string
): ProgressData => {
const [value, setValue] = useState<ProgressData>(() => {
return $progressData.get()[sessionId] ?? { sessionId, progressEvent: null, progressImage: null };
});
useEffect(() => {
const unsub = $progressData.subscribe((data) => {
const progressData = data[sessionId];
@@ -41,12 +45,33 @@ export const useProgressData = (sessionId: string): ProgressData => {
return () => {
unsub();
};
}, [sessionId]);
}, [$progressData, sessionId]);
return value;
};
export const setProgress = (data: S['InvocationProgressEvent']) => {
export const useHasProgressImage = (
$progressData: WritableAtom<Record<string, ProgressData>>,
sessionId: string
): boolean => {
const [value, setValue] = useState(false);
useEffect(() => {
const unsub = $progressData.subscribe((data) => {
const progressData = data[sessionId];
setValue(Boolean(progressData?.progressImage));
});
return () => {
unsub();
};
}, [$progressData, sessionId]);
return value;
};
export const setProgress = (
$progressData: WritableAtom<Record<string, ProgressData>>,
data: S['InvocationProgressEvent']
) => {
const progressData = $progressData.get();
const current = progressData[data.session_id];
if (current) {
@@ -71,7 +96,21 @@ export const setProgress = (data: S['InvocationProgressEvent']) => {
}
};
export const clearProgressImage = (sessionId: string) => {
export const clearProgressEvent = ($progressData: WritableAtom<Record<string, ProgressData>>, sessionId: string) => {
const progressData = $progressData.get();
const current = progressData[sessionId];
if (!current) {
return;
}
const next = { ...current };
next.progressEvent = null;
$progressData.set({
...progressData,
[sessionId]: next,
});
};
export const clearProgressImage = ($progressData: WritableAtom<Record<string, ProgressData>>, sessionId: string) => {
const progressData = $progressData.get();
const current = progressData[sessionId];
if (!current) {