mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat: canvas flow rework (wip)
This commit is contained in:
@@ -5,7 +5,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
|
||||
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
|
||||
import { withResult, withResultAsync } from 'common/util/result';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { canvasSessionStarted, selectCanvasSessionType } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { canvasSessionStarted, selectCanvasSession } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { $canvasManager } from 'features/controlLayers/store/ephemeral';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
|
||||
@@ -32,6 +32,11 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
actionCreator: enqueueRequestedCanvas,
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
log.debug('Enqueue requested');
|
||||
|
||||
if (!selectCanvasSession(getState())) {
|
||||
dispatch(canvasSessionStarted({ sessionType: 'simple' }));
|
||||
}
|
||||
|
||||
const state = getState();
|
||||
const { prepend } = action.payload;
|
||||
|
||||
@@ -91,7 +96,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
|
||||
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = buildGraphResult.value;
|
||||
|
||||
// const destination = state.canvasSettings.sendToCanvas ? 'canvas' : 'gallery';
|
||||
const destination = state.canvasSession.session?.id ?? 'canvas';
|
||||
|
||||
const prepareBatchResult = withResult(() =>
|
||||
prepareLinearUIBatch({
|
||||
@@ -101,7 +106,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
seedFieldIdentifier,
|
||||
positivePromptFieldIdentifier,
|
||||
origin: 'canvas',
|
||||
destination: 'canvas',
|
||||
destination,
|
||||
})
|
||||
);
|
||||
|
||||
@@ -116,9 +121,6 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
|
||||
try {
|
||||
await req.unwrap();
|
||||
if (!selectCanvasSessionType(state)) {
|
||||
dispatch(canvasSessionStarted({ sessionType: 'simple' }));
|
||||
}
|
||||
log.debug(parseify({ batchConfig: prepareBatchResult.value }), 'Enqueued batch');
|
||||
} catch (error) {
|
||||
log.error({ error: serializeError(error as Error) }, 'Failed to enqueue batch');
|
||||
|
||||
@@ -45,22 +45,25 @@ import { Transform } from 'features/controlLayers/components/Transform/Transform
|
||||
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { loadImage } from 'features/controlLayers/konva/util';
|
||||
import { selectDynamicGrid, selectShowHUD } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { canvasSessionStarted, selectCanvasSessionType } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { canvasSessionStarted, selectCanvasSession } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { newCanvasFromImageDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
import { DndImage } from 'features/dnd/DndImage';
|
||||
import { newCanvasFromImage } from 'features/imageActions/actions';
|
||||
import { isImageField } from 'features/nodes/types/common';
|
||||
import { isCanvasOutputNodeId } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { round } from 'lodash-es';
|
||||
import { atom, type WritableAtom } from 'nanostores';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { createContext, memo, useCallback, useContext, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { Trans, useTranslation } from 'react-i18next';
|
||||
import { PiDotsThreeOutlineVerticalFill, PiUploadBold } from 'react-icons/pi';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { useListAllQueueItemsQuery } from 'services/api/endpoints/queue';
|
||||
import type { ImageDTO, S } from 'services/api/types';
|
||||
import { $socket, setProgress, useProgressData } from 'services/events/stores';
|
||||
import type { ProgressData } from 'services/events/stores';
|
||||
import { $socket, clearProgressEvent, setProgress, useHasProgressImage, useProgressData } from 'services/events/stores';
|
||||
import type { Equals, Param0 } from 'tsafe';
|
||||
import { assert, objectEntries } from 'tsafe';
|
||||
|
||||
@@ -84,25 +87,45 @@ const MenuContent = memo(() => {
|
||||
MenuContent.displayName = 'MenuContent';
|
||||
|
||||
export const CanvasMainPanelContent = memo(() => {
|
||||
const sessionType = useAppSelector(selectCanvasSessionType);
|
||||
const session = useAppSelector(selectCanvasSession);
|
||||
|
||||
if (sessionType === null) {
|
||||
if (session === null) {
|
||||
return <NoActiveSession />;
|
||||
}
|
||||
|
||||
if (sessionType === 'simple') {
|
||||
return <StagingArea />;
|
||||
if (session.type === 'simple') {
|
||||
return <StagingAreaWrapper id={session.id} />;
|
||||
}
|
||||
|
||||
if (sessionType === 'advanced') {
|
||||
if (session.type === 'advanced') {
|
||||
return <CanvasActiveSession />;
|
||||
}
|
||||
|
||||
assert<Equals<never, typeof sessionType>>(false, 'Unexpected sessionType');
|
||||
assert<Equals<never, typeof session>>(false, 'Unexpected session');
|
||||
});
|
||||
|
||||
CanvasMainPanelContent.displayName = 'CanvasMainPanelContent';
|
||||
|
||||
const StagingAreaWrapper = memo(({ id }: { id: string }) => {
|
||||
const ctx = useMemo(
|
||||
() =>
|
||||
({
|
||||
session: {
|
||||
type: 'simple',
|
||||
id,
|
||||
},
|
||||
$progressData: atom<Record<string, ProgressData>>({}),
|
||||
}) as const,
|
||||
[id]
|
||||
);
|
||||
|
||||
return (
|
||||
<StagingContext.Provider value={ctx}>
|
||||
<StagingArea />
|
||||
</StagingContext.Provider>
|
||||
);
|
||||
});
|
||||
StagingAreaWrapper.displayName = 'StagingAreaWrapper';
|
||||
|
||||
const generateWithStartingImageDndTargetData = newCanvasFromImageDndTarget.getData({
|
||||
type: 'raster_layer',
|
||||
withResize: true,
|
||||
@@ -116,6 +139,8 @@ const generateWithControlImageDndTargetData = newCanvasFromImageDndTarget.getDat
|
||||
withResize: true,
|
||||
});
|
||||
|
||||
const DROP_SHADOW = 'drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 4px rgba(0, 0, 0, 0.3))';
|
||||
|
||||
const NoActiveSession = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
@@ -294,14 +319,36 @@ const scrollIndicatorSx = {
|
||||
},
|
||||
} satisfies SystemStyleObject;
|
||||
|
||||
type StagingContextValue = {
|
||||
session:
|
||||
| {
|
||||
type: 'simple';
|
||||
id: string;
|
||||
}
|
||||
| {
|
||||
type: 'advanced';
|
||||
id: string;
|
||||
};
|
||||
$progressData: WritableAtom<Record<string, ProgressData>>;
|
||||
};
|
||||
|
||||
const StagingContext = createContext<StagingContextValue | null>(null);
|
||||
|
||||
const useStagingContext = () => {
|
||||
const ctx = useContext(StagingContext);
|
||||
assert(ctx !== null, 'use in stg prov');
|
||||
return ctx;
|
||||
};
|
||||
|
||||
const StagingArea = memo(() => {
|
||||
const ctx = useStagingContext();
|
||||
const dispatch = useAppDispatch();
|
||||
const [selectedItemId, setSelectedItemId] = useState<number | null>(null);
|
||||
const [autoSwitch, setAutoSwitch] = useState(true);
|
||||
const [canScrollLeft, setCanScrollLeft] = useState(false);
|
||||
const [canScrollRight, setCanScrollRight] = useState(false);
|
||||
const scrollableRef = useRef<HTMLDivElement>(null);
|
||||
const { data } = useListAllQueueItemsQuery({ destination: 'canvas' });
|
||||
const { data } = useListAllQueueItemsQuery({ destination: ctx.session.id });
|
||||
const items = useMemo(() => data?.filter(({ status }) => status !== 'canceled') ?? EMPTY_ARRAY, [data]);
|
||||
const selectedItem = useMemo(
|
||||
() =>
|
||||
@@ -315,7 +362,7 @@ const StagingArea = memo(() => {
|
||||
);
|
||||
|
||||
const startOver = useCallback(() => {
|
||||
dispatch(canvasSessionStarted({ sessionType: null }));
|
||||
dispatch(canvasSessionStarted({ sessionType: 'simple' }));
|
||||
}, [dispatch]);
|
||||
|
||||
useEffect(() => {
|
||||
@@ -353,13 +400,10 @@ const StagingArea = memo(() => {
|
||||
onSelectItemId(null);
|
||||
return;
|
||||
}
|
||||
if (selectedItem === null && items.length > 0) {
|
||||
if (selectedItemId === null && items.length > 0) {
|
||||
onSelectItemId(items[0]?.item_id ?? null);
|
||||
return;
|
||||
}
|
||||
if (selectedItemId === null || items.find((item) => item.item_id === selectedItemId) === undefined) {
|
||||
return;
|
||||
}
|
||||
}, [items, onSelectItemId, selectedItem, selectedItemId]);
|
||||
|
||||
const onNext = useCallback(() => {
|
||||
@@ -387,26 +431,42 @@ const StagingArea = memo(() => {
|
||||
onSelectItemId(prevItem.item_id);
|
||||
}, [items, onSelectItemId, selectedItemId]);
|
||||
|
||||
useHotkeys('left', onPrev);
|
||||
useHotkeys('right', onNext);
|
||||
const onFirst = useCallback(() => {
|
||||
const first = items.at(0);
|
||||
if (!first) {
|
||||
return;
|
||||
}
|
||||
onSelectItemId(first.item_id);
|
||||
}, [items, onSelectItemId]);
|
||||
const onLast = useCallback(() => {
|
||||
const last = items.at(-1);
|
||||
if (!last) {
|
||||
return;
|
||||
}
|
||||
onSelectItemId(last.item_id);
|
||||
}, [items, onSelectItemId]);
|
||||
|
||||
useHotkeys('left', onPrev, { preventDefault: true });
|
||||
useHotkeys('right', onNext, { preventDefault: true });
|
||||
useHotkeys('meta+left', onFirst, { preventDefault: true });
|
||||
useHotkeys('meta+right', onLast, { preventDefault: true });
|
||||
|
||||
const socket = useStore($socket);
|
||||
useEffect(() => {
|
||||
if (!autoSwitch) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!socket) {
|
||||
return;
|
||||
}
|
||||
|
||||
const onQueueItemStatusChanged = (data: S['QueueItemStatusChangedEvent']) => {
|
||||
if (data.destination !== 'canvas') {
|
||||
if (data.destination !== ctx.session.id) {
|
||||
return;
|
||||
}
|
||||
if (data.status === 'in_progress') {
|
||||
if (data.status === 'in_progress' && autoSwitch) {
|
||||
onSelectItemId(data.item_id);
|
||||
}
|
||||
if (data.status === 'completed' || data.status === 'canceled' || data.status === 'failed') {
|
||||
clearProgressEvent(ctx.$progressData, data.session_id);
|
||||
}
|
||||
};
|
||||
|
||||
socket.on('queue_item_status_changed', onQueueItemStatusChanged);
|
||||
@@ -414,7 +474,7 @@ const StagingArea = memo(() => {
|
||||
return () => {
|
||||
socket.off('queue_item_status_changed', onQueueItemStatusChanged);
|
||||
};
|
||||
}, [autoSwitch, onSelectItemId, socket]);
|
||||
}, [autoSwitch, ctx.$progressData, ctx.session.id, onSelectItemId, socket]);
|
||||
|
||||
const _onChangeAutoSwitch = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setAutoSwitch(e.target.checked);
|
||||
@@ -425,17 +485,17 @@ const StagingArea = memo(() => {
|
||||
return;
|
||||
}
|
||||
const onProgress = (data: S['InvocationProgressEvent']) => {
|
||||
if (data.destination !== 'canvas') {
|
||||
if (data.destination !== ctx.session.id) {
|
||||
return;
|
||||
}
|
||||
setProgress(data);
|
||||
setProgress(ctx.$progressData, data);
|
||||
};
|
||||
socket.on('invocation_progress', onProgress);
|
||||
|
||||
return () => {
|
||||
socket.off('invocation_progress', onProgress);
|
||||
};
|
||||
}, [socket]);
|
||||
}, [ctx.$progressData, ctx.session.id, socket]);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={2} w="full" h="full" minW={0} minH={0}>
|
||||
@@ -467,8 +527,8 @@ const StagingArea = memo(() => {
|
||||
<Switch size="sm" isChecked={autoSwitch} onChange={_onChangeAutoSwitch} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex position="relative" maxW="full" h={108} justifyContent="center">
|
||||
<Flex ref={scrollableRef} gap={2} maxW="full" overflowX="scroll" flexShrink={0}>
|
||||
<Flex position="relative" maxW="full" w="full" h={108}>
|
||||
<Flex ref={scrollableRef} gap={2} maxW="full" overflowX="scroll">
|
||||
{items.map((item, i) => (
|
||||
<QueueItemCard
|
||||
key={item.item_id}
|
||||
@@ -511,7 +571,7 @@ const StagingArea = memo(() => {
|
||||
});
|
||||
StagingArea.displayName = 'StagingArea';
|
||||
|
||||
const queueItemStatusCardMiniSx = {
|
||||
const queueItemCardSx = {
|
||||
cursor: 'pointer',
|
||||
pos: 'relative',
|
||||
borderWidth: 1,
|
||||
@@ -528,6 +588,9 @@ const queueItemStatusCardMiniSx = {
|
||||
'&[data-size="mini"]': {
|
||||
flexShrink: 0,
|
||||
},
|
||||
'&[data-size="full"]&[data-has-progress-image="false"]': {
|
||||
w: 1024,
|
||||
},
|
||||
};
|
||||
|
||||
const getCardId = (item_id: number) => `queue-item-status-card-${item_id}`;
|
||||
@@ -543,7 +606,9 @@ type QueueItemStatusCardMiniProps = {
|
||||
|
||||
const QueueItemCard = memo(
|
||||
({ item, isSelected, number, onSelectItemId, onChangeAutoSwitch, size }: QueueItemStatusCardMiniProps) => {
|
||||
const ctx = useStagingContext();
|
||||
const [isImageLoaded, setIsImageLoaded] = useState(false);
|
||||
const hasProgressImage = useHasProgressImage(ctx.$progressData, item.session_id);
|
||||
|
||||
const outputImageName = useMemo(() => {
|
||||
const nodeId = Object.entries(item.session.source_prepared_mapping).find(([nodeId]) =>
|
||||
@@ -566,13 +631,23 @@ const QueueItemCard = memo(
|
||||
|
||||
const { currentData: imageDTO } = useGetImageDTOQuery(outputImageName ?? skipToken);
|
||||
|
||||
useEffect(() => {
|
||||
if (imageDTO) {
|
||||
loadImage(imageDTO.thumbnail_url, true).then(() => {
|
||||
setIsImageLoaded(true);
|
||||
});
|
||||
const syncIsReady = useCallback(async () => {
|
||||
if (!imageDTO) {
|
||||
setIsImageLoaded(false);
|
||||
return;
|
||||
}
|
||||
}, [imageDTO, item.session_id]);
|
||||
try {
|
||||
const _ = await loadImage(size === 'mini' ? imageDTO.thumbnail_url : imageDTO.image_url, true);
|
||||
setIsImageLoaded(true);
|
||||
return;
|
||||
} catch {
|
||||
setIsImageLoaded(false);
|
||||
}
|
||||
}, [imageDTO, size]);
|
||||
|
||||
useEffect(() => {
|
||||
syncIsReady();
|
||||
}, [syncIsReady]);
|
||||
|
||||
const onClick = useCallback(() => {
|
||||
onSelectItemId(item.item_id);
|
||||
@@ -584,9 +659,16 @@ const QueueItemCard = memo(
|
||||
|
||||
if (imageDTO && isImageLoaded) {
|
||||
return (
|
||||
<Flex id={getCardId(item.item_id)} sx={queueItemStatusCardMiniSx} data-selected={isSelected} data-size={size}>
|
||||
<DndImage imageDTO={imageDTO} onClick={onClick} onDoubleClick={onDoubleClick} />
|
||||
<Text position="absolute" top={0} left={1} pointerEvents="none" userSelect="none">{`#${number}`}</Text>
|
||||
<Flex id={getCardId(item.item_id)} sx={queueItemCardSx} data-selected={isSelected} data-size={size}>
|
||||
<DndImage imageDTO={imageDTO} onClick={onClick} onDoubleClick={onDoubleClick} asThumbnail={size === 'mini'} />
|
||||
<Text
|
||||
position="absolute"
|
||||
top={0}
|
||||
left={1}
|
||||
pointerEvents="none"
|
||||
userSelect="none"
|
||||
filter={DROP_SHADOW}
|
||||
>{`#${number}`}</Text>
|
||||
{size === 'full' && (
|
||||
<Flex position="absolute" top={2} right={2}>
|
||||
<ImageActions imageDTO={imageDTO} />
|
||||
@@ -599,40 +681,72 @@ const QueueItemCard = memo(
|
||||
return (
|
||||
<Flex
|
||||
id={getCardId(item.item_id)}
|
||||
sx={queueItemStatusCardMiniSx}
|
||||
sx={queueItemCardSx}
|
||||
data-selected={isSelected}
|
||||
data-size={size}
|
||||
data-has-progress-image={hasProgressImage}
|
||||
onClick={onClick}
|
||||
onDoubleClick={onDoubleClick}
|
||||
>
|
||||
<InProgressContent item={item} />
|
||||
<Text position="absolute" top={0} left={1} pointerEvents="none" userSelect="none">{`#${number}`}</Text>
|
||||
<Text
|
||||
position="absolute"
|
||||
top={0}
|
||||
left={1}
|
||||
pointerEvents="none"
|
||||
userSelect="none"
|
||||
filter={DROP_SHADOW}
|
||||
>{`#${number}`}</Text>
|
||||
{size === 'full' && <ProgressMessage key={item.session_id} session_id={item.session_id} />}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
QueueItemCard.displayName = 'QueueItemStatusCard';
|
||||
|
||||
const getMessage = (data: S['InvocationProgressEvent']) => {
|
||||
let message = data.message;
|
||||
if (data.percentage) {
|
||||
message += ` (${round(data.percentage * 100)}%)`;
|
||||
}
|
||||
return message;
|
||||
};
|
||||
|
||||
const ProgressMessage = memo(({ session_id }: { session_id: string }) => {
|
||||
const { $progressData } = useStagingContext();
|
||||
const { progressEvent } = useProgressData($progressData, session_id);
|
||||
if (!progressEvent) {
|
||||
return null;
|
||||
}
|
||||
return (
|
||||
<Text position="absolute" bottom={2} left={2} pointerEvents="none" userSelect="none" filter={DROP_SHADOW}>
|
||||
{getMessage(progressEvent)}
|
||||
</Text>
|
||||
);
|
||||
});
|
||||
ProgressMessage.displayName = 'ProgressMessage';
|
||||
|
||||
const InProgressContent = memo(({ item }: { item: S['SessionQueueItem'] }) => {
|
||||
const { progressEvent, progressImage } = useProgressData(item.session_id);
|
||||
const { $progressData } = useStagingContext();
|
||||
const { progressEvent, progressImage } = useProgressData($progressData, item.session_id);
|
||||
|
||||
if (item.status === 'pending') {
|
||||
return (
|
||||
<Text fontWeight="semibold" color="base.300">
|
||||
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="base.300">
|
||||
Pending
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
if (item.status === 'canceled') {
|
||||
return (
|
||||
<Text fontWeight="semibold" color="warning.300">
|
||||
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="warning.300">
|
||||
Canceled
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
if (item.status === 'failed') {
|
||||
return (
|
||||
<Text fontWeight="semibold" color="error.300">
|
||||
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="error.300">
|
||||
Failed
|
||||
</Text>
|
||||
);
|
||||
@@ -650,7 +764,7 @@ const InProgressContent = memo(({ item }: { item: S['SessionQueueItem'] }) => {
|
||||
if (item.status === 'in_progress') {
|
||||
return (
|
||||
<>
|
||||
<Text fontWeight="semibold" color="invokeBlue.300">
|
||||
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="invokeBlue.300">
|
||||
In Progress
|
||||
</Text>
|
||||
<ProgressCircle data={progressEvent} />
|
||||
@@ -660,7 +774,7 @@ const InProgressContent = memo(({ item }: { item: S['SessionQueueItem'] }) => {
|
||||
|
||||
if (item.status === 'completed') {
|
||||
return (
|
||||
<Text fontWeight="semibold" color="error.300">
|
||||
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="error.300">
|
||||
Unable to get image
|
||||
</Text>
|
||||
);
|
||||
|
||||
@@ -20,7 +20,7 @@ export const useNewGallerySession = () => {
|
||||
const newSessionDialog = useNewGallerySessionDialog();
|
||||
|
||||
const newGallerySessionImmediate = useCallback(() => {
|
||||
dispatch(canvasSessionStarted({ sessionType: null }));
|
||||
dispatch(canvasSessionStarted({ sessionType: 'simple' }));
|
||||
dispatch(activeTabCanvasRightPanelChanged('gallery'));
|
||||
}, [dispatch]);
|
||||
|
||||
|
||||
@@ -1,17 +1,20 @@
|
||||
import { createSelector, createSlice, type PayloadAction } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { canvasReset } from 'features/controlLayers/store/actions';
|
||||
import type { StagingAreaImage, StagingAreaProgressImage } from 'features/controlLayers/store/types';
|
||||
import { selectCanvasQueueCounts } from 'services/api/endpoints/queue';
|
||||
|
||||
type CanvasStagingAreaState = {
|
||||
session: { type: 'simple'; id: string } | { type: 'advanced'; id: string } | null;
|
||||
sessionType: 'simple' | 'advanced' | null;
|
||||
images: (StagingAreaImage | StagingAreaProgressImage)[];
|
||||
selectedImageIndex: number;
|
||||
};
|
||||
|
||||
const INITIAL_STATE: CanvasStagingAreaState = {
|
||||
session: null,
|
||||
sessionType: null,
|
||||
images: [],
|
||||
selectedImageIndex: 0,
|
||||
@@ -23,6 +26,10 @@ export const canvasSessionSlice = createSlice({
|
||||
name: 'canvasSession',
|
||||
initialState: getInitialState(),
|
||||
reducers: {
|
||||
sessionChanged: (state, action: PayloadAction<{ session: CanvasStagingAreaState['session'] }>) => {
|
||||
const { session } = action.payload;
|
||||
state.session = session;
|
||||
},
|
||||
stagingAreaImageStaged: (state, action: PayloadAction<{ stagingAreaImage: StagingAreaImage }>) => {
|
||||
const { stagingAreaImage } = action.payload;
|
||||
let didReplace = false;
|
||||
@@ -67,11 +74,19 @@ export const canvasSessionSlice = createSlice({
|
||||
state.images = [];
|
||||
state.selectedImageIndex = 0;
|
||||
},
|
||||
canvasSessionStarted: (_, action: PayloadAction<{ sessionType: CanvasStagingAreaState['sessionType'] }>) => {
|
||||
const { sessionType } = action.payload;
|
||||
const state = getInitialState();
|
||||
state.sessionType = sessionType;
|
||||
return state;
|
||||
canvasSessionStarted: {
|
||||
reducer: (state, action: PayloadAction<{ session: CanvasStagingAreaState['session'] }>) => {
|
||||
const { session } = action.payload;
|
||||
state.session = session;
|
||||
},
|
||||
prepare: (payload: { sessionType: 'simple' | 'advanced' }) => ({
|
||||
payload: {
|
||||
session: {
|
||||
type: payload.sessionType,
|
||||
id: getPrefixedId(`canvas:${payload.sessionType}`),
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
extraReducers(builder) {
|
||||
@@ -80,6 +95,7 @@ export const canvasSessionSlice = createSlice({
|
||||
});
|
||||
|
||||
export const {
|
||||
sessionChanged,
|
||||
stagingAreaImageStaged,
|
||||
stagingAreaGenerationStarted,
|
||||
stagingAreaGenerationFinished,
|
||||
@@ -140,3 +156,7 @@ export const selectCanvasSessionType = createSelector(
|
||||
selectCanvasStagingAreaSlice,
|
||||
(canvasSession) => canvasSession.sessionType
|
||||
);
|
||||
export const selectCanvasSession = createSelector(
|
||||
selectCanvasStagingAreaSlice,
|
||||
(canvasSession) => canvasSession.session
|
||||
);
|
||||
|
||||
@@ -34,7 +34,7 @@ export const prepareLinearUIBatch = (arg: {
|
||||
seedFieldIdentifier?: FieldIdentifier;
|
||||
positivePromptFieldIdentifier: FieldIdentifier;
|
||||
origin: 'canvas' | 'workflows' | 'upscaling';
|
||||
destination: 'canvas' | 'gallery';
|
||||
destination: string;
|
||||
}): EnqueueBatchArg => {
|
||||
const { state, g, prepend, seedFieldIdentifier, positivePromptFieldIdentifier, origin, destination } = arg;
|
||||
const { iterations, model, shouldRandomizeSeed, seed, shouldConcatPrompts } = state.params;
|
||||
|
||||
@@ -1,22 +1,18 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { stagingAreaImageStaged } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { boardIdSelected, galleryViewChanged, imageSelected, offsetChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
|
||||
import { isImageField, isImageFieldCollection } from 'features/nodes/types/common';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { isCanvasOutputEvent } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { flushSync } from 'react-dom';
|
||||
import type { ApiTagDescription } from 'services/api';
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
import { getImageDTOSafe, imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO, S } from 'services/api/types';
|
||||
import { getCategories, getListImagesUrl } from 'services/api/util';
|
||||
import {
|
||||
$lastCanvasProgressImage,
|
||||
$lastProgressEvent,
|
||||
$progressImages,
|
||||
} from 'services/events/stores';
|
||||
import type { Param0 } from 'tsafe';
|
||||
import { objectEntries } from 'tsafe';
|
||||
@@ -180,29 +176,29 @@ export const buildOnInvocationComplete = (getState: () => RootState, dispatch: A
|
||||
|
||||
await addImagesToGallery(data);
|
||||
|
||||
// We expect only a single image in the canvas output
|
||||
const imageDTO = (await getResultImageDTOs(data))[0];
|
||||
// // We expect only a single image in the canvas output
|
||||
// const imageDTO = (await getResultImageDTOs(data))[0];
|
||||
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
// if (!imageDTO) {
|
||||
// return;
|
||||
// }
|
||||
|
||||
flushSync(() => {
|
||||
dispatch(
|
||||
stagingAreaImageStaged({
|
||||
stagingAreaImage: { type: 'staged', sessionId: data.session_id, imageDTO, offsetX: 0, offsetY: 0 },
|
||||
})
|
||||
);
|
||||
});
|
||||
// flushSync(() => {
|
||||
// dispatch(
|
||||
// stagingAreaImageStaged({
|
||||
// stagingAreaImage: { type: 'staged', sessionId: data.session_id, imageDTO, offsetX: 0, offsetY: 0 },
|
||||
// })
|
||||
// );
|
||||
// });
|
||||
|
||||
const progressData = $progressImages.get()[data.session_id];
|
||||
if (progressData) {
|
||||
$progressImages.setKey(data.session_id, { ...progressData, isFinished: true, resultImage: imageDTO });
|
||||
} else {
|
||||
$progressImages.setKey(data.session_id, { sessionId: data.session_id, isFinished: true, resultImage: imageDTO });
|
||||
}
|
||||
// const progressData = $progressImages.get()[data.session_id];
|
||||
// if (progressData) {
|
||||
// $progressImages.setKey(data.session_id, { ...progressData, isFinished: true, resultImage: imageDTO });
|
||||
// } else {
|
||||
// $progressImages.setKey(data.session_id, { sessionId: data.session_id, isFinished: true, resultImage: imageDTO });
|
||||
// }
|
||||
|
||||
$lastCanvasProgressImage.set(null);
|
||||
// $lastCanvasProgressImage.set(null);
|
||||
};
|
||||
|
||||
const handleOriginOther = async (data: S['InvocationCompleteEvent']) => {
|
||||
|
||||
@@ -8,7 +8,6 @@ import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId';
|
||||
import { $queueId } from 'app/store/nanostores/queueId';
|
||||
import type { AppStore } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { stagingAreaGenerationStarted } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import {
|
||||
$isInPublishFlow,
|
||||
$outputNodeId,
|
||||
@@ -397,8 +396,8 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
|
||||
$nodeExecutionStates.setKey(clone.nodeId, clone);
|
||||
});
|
||||
if (data.origin === 'canvas') {
|
||||
store.dispatch(stagingAreaGenerationStarted({ sessionId: session_id }));
|
||||
$progressImages.setKey(session_id, { sessionId: session_id, isFinished: false });
|
||||
// store.dispatch(stagingAreaGenerationStarted({ sessionId: session_id }));
|
||||
// $progressImages.setKey(session_id, { sessionId: session_id, isFinished: false });
|
||||
}
|
||||
} else if (status === 'completed' || status === 'failed' || status === 'canceled') {
|
||||
if (status === 'failed' && error_type) {
|
||||
@@ -423,7 +422,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
|
||||
}
|
||||
// If the queue item is completed, failed, or cancelled, we want to clear the last progress event
|
||||
$lastProgressEvent.set(null);
|
||||
$progressImages.setKey(session_id, undefined);
|
||||
// $progressImages.setKey(session_id, undefined);
|
||||
|
||||
// When a validation run is completed, we want to clear the validation run batch ID & set the workflow as published
|
||||
const validationRunData = $validationRunData.get();
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { EphemeralProgressImage } from 'features/controlLayers/store/types';
|
||||
import type { ProgressImage } from 'features/nodes/types/common';
|
||||
import { round } from 'lodash-es';
|
||||
import type { WritableAtom } from 'nanostores';
|
||||
import { atom, computed, map } from 'nanostores';
|
||||
import { useEffect, useState } from 'react';
|
||||
import type { ImageDTO, S } from 'services/api/types';
|
||||
@@ -20,16 +21,19 @@ export type ProgressAndResult = {
|
||||
};
|
||||
export const $progressImages = map({} as Record<string, ProgressAndResult>);
|
||||
|
||||
type ProgressData = {
|
||||
export type ProgressData = {
|
||||
sessionId: string;
|
||||
progressEvent: S['InvocationProgressEvent'] | null;
|
||||
progressImage: ProgressImage | null;
|
||||
};
|
||||
|
||||
export const $progressData = atom<Record<string, ProgressData>>({});
|
||||
|
||||
export const useProgressData = (sessionId: string): ProgressData => {
|
||||
const [value, setValue] = useState<ProgressData>({ sessionId, progressEvent: null, progressImage: null });
|
||||
export const useProgressData = (
|
||||
$progressData: WritableAtom<Record<string, ProgressData>>,
|
||||
sessionId: string
|
||||
): ProgressData => {
|
||||
const [value, setValue] = useState<ProgressData>(() => {
|
||||
return $progressData.get()[sessionId] ?? { sessionId, progressEvent: null, progressImage: null };
|
||||
});
|
||||
useEffect(() => {
|
||||
const unsub = $progressData.subscribe((data) => {
|
||||
const progressData = data[sessionId];
|
||||
@@ -41,12 +45,33 @@ export const useProgressData = (sessionId: string): ProgressData => {
|
||||
return () => {
|
||||
unsub();
|
||||
};
|
||||
}, [sessionId]);
|
||||
}, [$progressData, sessionId]);
|
||||
|
||||
return value;
|
||||
};
|
||||
|
||||
export const setProgress = (data: S['InvocationProgressEvent']) => {
|
||||
export const useHasProgressImage = (
|
||||
$progressData: WritableAtom<Record<string, ProgressData>>,
|
||||
sessionId: string
|
||||
): boolean => {
|
||||
const [value, setValue] = useState(false);
|
||||
useEffect(() => {
|
||||
const unsub = $progressData.subscribe((data) => {
|
||||
const progressData = data[sessionId];
|
||||
setValue(Boolean(progressData?.progressImage));
|
||||
});
|
||||
return () => {
|
||||
unsub();
|
||||
};
|
||||
}, [$progressData, sessionId]);
|
||||
|
||||
return value;
|
||||
};
|
||||
|
||||
export const setProgress = (
|
||||
$progressData: WritableAtom<Record<string, ProgressData>>,
|
||||
data: S['InvocationProgressEvent']
|
||||
) => {
|
||||
const progressData = $progressData.get();
|
||||
const current = progressData[data.session_id];
|
||||
if (current) {
|
||||
@@ -71,7 +96,21 @@ export const setProgress = (data: S['InvocationProgressEvent']) => {
|
||||
}
|
||||
};
|
||||
|
||||
export const clearProgressImage = (sessionId: string) => {
|
||||
export const clearProgressEvent = ($progressData: WritableAtom<Record<string, ProgressData>>, sessionId: string) => {
|
||||
const progressData = $progressData.get();
|
||||
const current = progressData[sessionId];
|
||||
if (!current) {
|
||||
return;
|
||||
}
|
||||
const next = { ...current };
|
||||
next.progressEvent = null;
|
||||
$progressData.set({
|
||||
...progressData,
|
||||
[sessionId]: next,
|
||||
});
|
||||
};
|
||||
|
||||
export const clearProgressImage = ($progressData: WritableAtom<Record<string, ProgressData>>, sessionId: string) => {
|
||||
const progressData = $progressData.get();
|
||||
const current = progressData[sessionId];
|
||||
if (!current) {
|
||||
|
||||
Reference in New Issue
Block a user