feat: canvas flow rework (wip)

This commit is contained in:
psychedelicious
2025-06-05 11:54:06 +10:00
parent c8df7cd2c0
commit ea34690709
10 changed files with 156 additions and 258 deletions

View File

@@ -1,6 +1,6 @@
import type { CircularProgressProps, SystemStyleObject } from '@invoke-ai/ui-library';
import { CircularProgress, Tooltip } from '@invoke-ai/ui-library';
import { useCanvasSessionContext,useProgressData } from 'features/controlLayers/components/SimpleSession/context';
import { useCanvasSessionContext, useProgressData } from 'features/controlLayers/components/SimpleSession/context';
import { getProgressMessage } from 'features/controlLayers/components/SimpleSession/shared';
import { memo } from 'react';
import type { S } from 'services/api/types';
@@ -15,32 +15,28 @@ const circleStyles: SystemStyleObject = {
right: 2,
};
export const QueueItemCircularProgress = memo(
({
session_id,
status,
...rest
}: { session_id: string; status: S['SessionQueueItem']['status'] } & CircularProgressProps) => {
const { $progressData } = useCanvasSessionContext();
const { progressEvent } = useProgressData($progressData, session_id);
type Props = { itemId: number; status: S['SessionQueueItem']['status'] } & CircularProgressProps;
if (status !== 'in_progress') {
return null;
}
export const QueueItemCircularProgress = memo(({ itemId, status, ...rest }: Props) => {
const { $progressData } = useCanvasSessionContext();
const { progressEvent } = useProgressData($progressData, itemId);
return (
<Tooltip label={getProgressMessage(progressEvent)}>
<CircularProgress
size="14px"
color="invokeBlue.500"
thickness={14}
isIndeterminate={!progressEvent || progressEvent.percentage === null}
value={progressEvent?.percentage ? progressEvent.percentage * 100 : undefined}
sx={circleStyles}
{...rest}
/>
</Tooltip>
);
if (status !== 'in_progress') {
return null;
}
);
return (
<Tooltip label={getProgressMessage(progressEvent)}>
<CircularProgress
size="14px"
color="invokeBlue.500"
thickness={14}
isIndeterminate={!progressEvent || progressEvent.percentage === null}
value={progressEvent?.percentage ? progressEvent.percentage * 100 : undefined}
sx={circleStyles}
{...rest}
/>
</Tooltip>
);
});
QueueItemCircularProgress.displayName = 'QueueItemCircularProgress';

View File

@@ -38,23 +38,11 @@ export const QueueItemPreviewFull = memo(({ item, number }: Props) => {
<Flex id={getQueueItemElementId(item.item_id)} sx={sx}>
<QueueItemStatusLabel status={item.status} position="absolute" margin="auto" />
{imageDTO && <DndImage imageDTO={imageDTO} onLoad={onLoad} />}
{!imageLoaded && <QueueItemProgressImage session_id={item.session_id} position="absolute" />}
{!imageLoaded && <QueueItemProgressImage itemId={item.item_id} position="absolute" />}
{imageDTO && <ImageActions imageDTO={imageDTO} position="absolute" top={1} right={2} />}
<QueueItemNumber number={number} position="absolute" top={1} left={2} />
<QueueItemProgressMessage
session_id={item.session_id}
status={item.status}
position="absolute"
bottom={1}
left={2}
/>
<QueueItemCircularProgress
session_id={item.session_id}
status={item.status}
position="absolute"
top={1}
right={2}
/>
<QueueItemProgressMessage itemId={item.item_id} status={item.status} position="absolute" bottom={1} left={2} />
<QueueItemCircularProgress itemId={item.item_id} status={item.status} position="absolute" top={1} right={2} />
</Flex>
);
});

View File

@@ -1,5 +1,6 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex } from '@invoke-ai/ui-library';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { QueueItemCircularProgress } from 'features/controlLayers/components/SimpleSession/QueueItemCircularProgress';
import { QueueItemNumber } from 'features/controlLayers/components/SimpleSession/QueueItemNumber';
import { QueueItemProgressImage } from 'features/controlLayers/components/SimpleSession/QueueItemProgressImage';
@@ -34,25 +35,25 @@ type Props = {
item: S['SessionQueueItem'];
number: number;
isSelected: boolean;
onSelectItemId: (item_id: number) => void;
onChangeAutoSwitch: (autoSwitch: boolean) => void;
};
export const QueueItemPreviewMini = memo(({ item, isSelected, number, onSelectItemId, onChangeAutoSwitch }: Props) => {
export const QueueItemPreviewMini = memo(({ item, isSelected, number }: Props) => {
const ctx = useCanvasSessionContext();
const [imageLoaded, setImageLoaded] = useState(false);
const imageDTO = useOutputImageDTO(item);
const onClick = useCallback(() => {
onSelectItemId(item.item_id);
}, [item.item_id, onSelectItemId]);
ctx.$selectedItemId.set(item.item_id);
}, [ctx.$selectedItemId, item.item_id]);
const onDoubleClick = useCallback(() => {
onChangeAutoSwitch(item.status === 'in_progress');
}, [item.status, onChangeAutoSwitch]);
ctx.$autoSwitch.set(item.status === 'in_progress');
}, [ctx.$autoSwitch, item.status]);
const onLoad = useCallback(() => {
setImageLoaded(true);
}, []);
ctx.$lastLoadedItemId.set(item.item_id);
}, [ctx.$lastLoadedItemId, item.item_id]);
return (
<Flex
@@ -63,16 +64,10 @@ export const QueueItemPreviewMini = memo(({ item, isSelected, number, onSelectIt
onDoubleClick={onDoubleClick}
>
<QueueItemStatusLabel status={item.status} position="absolute" margin="auto" />
{imageDTO && <DndImage imageDTO={imageDTO} asThumbnail onLoad={onLoad} />}
{!imageLoaded && <QueueItemProgressImage session_id={item.session_id} position="absolute" />}
{imageDTO && <DndImage imageDTO={imageDTO} onLoad={onLoad} />}
{!imageLoaded && <QueueItemProgressImage itemId={item.item_id} position="absolute" />}
<QueueItemNumber number={number} position="absolute" top={0} left={1} />
<QueueItemCircularProgress
session_id={item.session_id}
status={item.status}
position="absolute"
top={1}
right={2}
/>
<QueueItemCircularProgress itemId={item.item_id} status={item.status} position="absolute" top={1} right={2} />
</Flex>
);
});

View File

@@ -1,12 +1,13 @@
import type { ImageProps } from '@invoke-ai/ui-library';
import { Image } from '@invoke-ai/ui-library';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { useCanvasSessionContext, useProgressData } from 'features/controlLayers/components/SimpleSession/context';
import { memo } from 'react';
import { useProgressData } from 'services/events/stores';
export const QueueItemProgressImage = memo(({ session_id, ...rest }: { session_id: string } & ImageProps) => {
const { $progressData } = useCanvasSessionContext();
const { progressImage } = useProgressData($progressData, session_id);
type Props = { itemId: number } & ImageProps;
export const QueueItemProgressImage = memo(({ itemId, ...rest }: Props) => {
const ctx = useCanvasSessionContext();
const { progressImage } = useProgressData(ctx.$progressData, itemId);
if (!progressImage) {
return null;

View File

@@ -1,34 +1,33 @@
/* eslint-disable i18next/no-literal-string */
import type { TextProps } from '@invoke-ai/ui-library';
import { Text } from '@invoke-ai/ui-library';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { useCanvasSessionContext, useProgressData } from 'features/controlLayers/components/SimpleSession/context';
import { DROP_SHADOW, getProgressMessage } from 'features/controlLayers/components/SimpleSession/shared';
import { memo } from 'react';
import type { S } from 'services/api/types';
import { useProgressData } from 'services/events/stores';
export const QueueItemProgressMessage = memo(
({ session_id, status, ...rest }: { session_id: string; status: S['SessionQueueItem']['status'] } & TextProps) => {
const { $progressData } = useCanvasSessionContext();
const { progressEvent } = useProgressData($progressData, session_id);
type Props = { itemId: number; status: S['SessionQueueItem']['status'] } & TextProps;
if (status === 'completed' || status === 'failed' || status === 'canceled') {
return null;
}
export const QueueItemProgressMessage = memo(({ itemId, status, ...rest }: Props) => {
const ctx = useCanvasSessionContext();
const { progressEvent } = useProgressData(ctx.$progressData, itemId);
if (status === 'pending') {
return (
<Text pointerEvents="none" userSelect="none" filter={DROP_SHADOW} {...rest}>
Waiting to start...
</Text>
);
}
if (status === 'completed' || status === 'failed' || status === 'canceled') {
return null;
}
if (status === 'pending') {
return (
<Text pointerEvents="none" userSelect="none" filter={DROP_SHADOW} {...rest}>
{getProgressMessage(progressEvent)}
Waiting to start...
</Text>
);
}
);
return (
<Text pointerEvents="none" userSelect="none" filter={DROP_SHADOW} {...rest}>
{getProgressMessage(progressEvent)}
</Text>
);
});
QueueItemProgressMessage.displayName = 'QueueItemProgressMessage';

View File

@@ -4,39 +4,39 @@ import { Text } from '@invoke-ai/ui-library';
import { memo } from 'react';
import type { S } from 'services/api/types';
export const QueueItemStatusLabel = memo(
({ status, ...rest }: { status: S['SessionQueueItem']['status'] } & TextProps) => {
if (status === 'pending') {
return (
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="base.300" {...rest}>
Pending
</Text>
);
}
if (status === 'canceled') {
return (
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="warning.300" {...rest}>
Canceled
</Text>
);
}
if (status === 'failed') {
return (
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="error.300" {...rest}>
Failed
</Text>
);
}
type Props = { status: S['SessionQueueItem']['status'] } & TextProps;
if (status === 'in_progress') {
return (
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="invokeBlue.300" {...rest}>
In Progress
</Text>
);
}
return null;
export const QueueItemStatusLabel = memo(({ status, ...rest }: Props) => {
if (status === 'pending') {
return (
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="base.300" {...rest}>
Pending
</Text>
);
}
);
if (status === 'canceled') {
return (
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="warning.300" {...rest}>
Canceled
</Text>
);
}
if (status === 'failed') {
return (
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="error.300" {...rest}>
Failed
</Text>
);
}
if (status === 'in_progress') {
return (
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="invokeBlue.300" {...rest}>
In Progress
</Text>
);
}
return null;
});
QueueItemStatusLabel.displayName = 'QueueItemStatusLabel';

View File

@@ -20,8 +20,6 @@ export const StagingAreaItemsList = memo(() => {
item={item}
number={i + 1}
isSelected={selectedItemId === item.item_id}
onSelectItemId={ctx.$selectedItemId.set}
onChangeAutoSwitch={ctx.$autoSwitch.set}
/>
))}
</Flex>

View File

@@ -17,23 +17,21 @@ import { $socket } from 'services/events/stores';
import { assert } from 'tsafe';
export type ProgressData = {
sessionId: string;
itemId: number;
progressEvent: S['InvocationProgressEvent'] | null;
progressImage: ProgressImage | null;
};
export const buildProgressDataAtom = () => atom<Record<string, ProgressData>>({});
export const useProgressData = (
$progressData: WritableAtom<Record<string, ProgressData>>,
sessionId: string
$progressData: WritableAtom<Record<number, ProgressData>>,
itemId: number
): ProgressData => {
const [value, setValue] = useState<ProgressData>(() => {
return $progressData.get()[sessionId] ?? { sessionId, progressEvent: null, progressImage: null };
return $progressData.get()[itemId] ?? { itemId, progressEvent: null, progressImage: null };
});
useEffect(() => {
const unsub = $progressData.subscribe((data) => {
const progressData = data[sessionId];
const progressData = data[itemId];
if (!progressData) {
return;
}
@@ -42,35 +40,35 @@ export const useProgressData = (
return () => {
unsub();
};
}, [$progressData, sessionId]);
}, [$progressData, itemId]);
return value;
};
export const useHasProgressImage = (
$progressData: WritableAtom<Record<string, ProgressData>>,
sessionId: string
$progressData: WritableAtom<Record<number, ProgressData>>,
itemId: number
): boolean => {
const [value, setValue] = useState(false);
useEffect(() => {
const unsub = $progressData.subscribe((data) => {
const progressData = data[sessionId];
const progressData = data[itemId];
setValue(Boolean(progressData?.progressImage));
});
return () => {
unsub();
};
}, [$progressData, sessionId]);
}, [$progressData, itemId]);
return value;
};
export const setProgress = (
$progressData: WritableAtom<Record<string, ProgressData>>,
$progressData: WritableAtom<Record<number, ProgressData>>,
data: S['InvocationProgressEvent']
) => {
const progressData = $progressData.get();
const current = progressData[data.session_id];
const current = progressData[data.item_id];
if (current) {
const next = { ...current };
next.progressEvent = data;
@@ -79,13 +77,13 @@ export const setProgress = (
}
$progressData.set({
...progressData,
[data.session_id]: next,
[data.item_id]: next,
});
} else {
$progressData.set({
...progressData,
[data.session_id]: {
sessionId: data.session_id,
[data.item_id]: {
itemId: data.item_id,
progressEvent: data,
progressImage: data.image ?? null,
},
@@ -93,9 +91,9 @@ export const setProgress = (
}
};
export const clearProgressEvent = ($progressData: WritableAtom<Record<string, ProgressData>>, sessionId: string) => {
export const clearProgressEvent = ($progressData: WritableAtom<Record<number, ProgressData>>, itemId: number) => {
const progressData = $progressData.get();
const current = progressData[sessionId];
const current = progressData[itemId];
if (!current) {
return;
}
@@ -103,13 +101,13 @@ export const clearProgressEvent = ($progressData: WritableAtom<Record<string, Pr
next.progressEvent = null;
$progressData.set({
...progressData,
[sessionId]: next,
[itemId]: next,
});
};
export const clearProgressImage = ($progressData: WritableAtom<Record<string, ProgressData>>, sessionId: string) => {
export const clearProgressImage = ($progressData: WritableAtom<Record<number, ProgressData>>, itemId: number) => {
const progressData = $progressData.get();
const current = progressData[sessionId];
const current = progressData[itemId];
if (!current) {
return;
}
@@ -117,7 +115,7 @@ export const clearProgressImage = ($progressData: WritableAtom<Record<string, Pr
next.progressImage = null;
$progressData.set({
...progressData,
[sessionId]: next,
[itemId]: next,
});
};
@@ -130,6 +128,7 @@ export type CanvasSessionContextValue = {
$selectedItem: Atom<S['SessionQueueItem'] | null>;
$selectedItemIndex: Atom<number | null>;
$autoSwitch: WritableAtom<boolean>;
$lastLoadedItemId: WritableAtom<number | null>;
};
const CanvasSessionContext = createContext<CanvasSessionContextValue | null>(null);
@@ -159,10 +158,16 @@ export const CanvasSessionContextProvider = memo(
*/
const $autoSwitch = useState(() => atom(true))[0];
/**
* An internal flag used to work around race conditions with auto-switch switching to queue items before their
* output images have fully loaded.
*/
const $lastLoadedItemId = useState(() => atom<number | null>(null))[0];
/**
* An ephemeral store of progress events and images for all items in the current session.
*/
const $progressData = useState(() => atom<Record<string, ProgressData>>({}))[0];
const $progressData = useState(() => atom<Record<number, ProgressData>>({}))[0];
/**
* The currently selected queue item's ID, or null if one is not selected.
@@ -231,21 +236,10 @@ export const CanvasSessionContextProvider = memo(
setProgress($progressData, data);
};
const onQueueItemStatusChanged = (data: S['QueueItemStatusChangedEvent']) => {
if (data.destination !== session.id) {
return;
}
if (data.status === 'completed' && $autoSwitch.get()) {
$selectedItemId.set(data.item_id);
}
};
socket.on('invocation_progress', onProgress);
socket.on('queue_item_status_changed', onQueueItemStatusChanged);
return () => {
socket.off('invocation_progress', onProgress);
socket.off('queue_item_status_changed', onQueueItemStatusChanged);
};
}, [$autoSwitch, $progressData, $selectedItemId, session.id, socket]);
@@ -285,23 +279,37 @@ export const CanvasSessionContextProvider = memo(
// Clean up the progress data when a queue item is discarded.
const unsubCleanUpProgressData = effect([$items, $progressData], (items, progressData) => {
const toDelete: string[] = [];
const toDelete: number[] = [];
for (const datum of Object.values(progressData)) {
if (items.findIndex(({ session_id }) => session_id === datum.sessionId) === -1) {
toDelete.push(datum.sessionId);
if (items.findIndex(({ item_id }) => item_id === datum.itemId) === -1) {
toDelete.push(datum.itemId);
}
}
if (toDelete.length === 0) {
return;
}
const newProgressData = { ...progressData };
for (const sessionId of toDelete) {
delete newProgressData[sessionId];
for (const itemId of toDelete) {
delete newProgressData[itemId];
}
// This will re-trigger the effect - maybe this could just be a listener on $items? Brain hurt
$progressData.set(newProgressData);
});
// We only want to auto-switch to completed queue items once their images have fully loaded to prevent flashes
// of fallback content and/or progress images. The only surefire way to determine when images have fully loaded
// is via the image elements' `onLoad` callback. Images set `$lastLoadedItemId` to their queue item ID in their
// `onLoad` handler, and we listen for that here. If auto-switch is enabled, we then switch the to the item.
const unsubHandleAutoSwitch = $lastLoadedItemId.listen((lastLoadedItemId) => {
if (lastLoadedItemId === null) {
return;
}
if ($autoSwitch.get()) {
$selectedItemId.set(lastLoadedItemId);
}
$lastLoadedItemId.set(null);
});
// Create an RTK Query subscription. Without this, the query cache selector will never return anything bc RTK
// doesn't know we care about it.
const { unsubscribe: unsubQueueItemsQuery } = store.dispatch(
@@ -310,6 +318,7 @@ export const CanvasSessionContextProvider = memo(
// Clean up all subscriptions and top-level (i.e. non-computed/derived state)
return () => {
unsubHandleAutoSwitch();
unsubQueueItemsQuery();
unsubReduxSyncToItemsAtom();
unsubEnsureSelectedItemIdExists();
@@ -318,7 +327,7 @@ export const CanvasSessionContextProvider = memo(
$progressData.set({});
$selectedItemId.set(null);
};
}, [$items, $progressData, $selectedItemId, selectQueueItems, session.id, store]);
}, [$autoSwitch, $items, $lastLoadedItemId, $progressData, $selectedItemId, selectQueueItems, session.id, store]);
const value = useMemo<CanvasSessionContextValue>(
() => ({
@@ -330,8 +339,19 @@ export const CanvasSessionContextProvider = memo(
$autoSwitch,
$selectedItem,
$selectedItemIndex,
$lastLoadedItemId,
}),
[$autoSwitch, $hasItems, $items, $progressData, $selectedItem, $selectedItemId, $selectedItemIndex, session]
[
$autoSwitch,
$hasItems,
$items,
$lastLoadedItemId,
$progressData,
$selectedItem,
$selectedItemId,
$selectedItemIndex,
session,
]
);
return <CanvasSessionContext.Provider value={value}>{children}</CanvasSessionContext.Provider>;

View File

@@ -21,7 +21,7 @@ export const getProgressMessage = (data?: S['InvocationProgressEvent'] | null) =
export const DROP_SHADOW = 'drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 4px rgba(0, 0, 0, 0.3))';
export const getQueueItemElementId = (item_id: number) => `queue-item-status-card-${item_id}`;
export const getQueueItemElementId = (itemId: number) => `queue-item-status-card-${itemId}`;
const getOutputImageName = (item: S['SessionQueueItem']) => {
const nodeId = Object.entries(item.session.source_prepared_mapping).find(([nodeId]) =>

View File

@@ -1,9 +1,7 @@
import type { EphemeralProgressImage } from 'features/controlLayers/store/types';
import type { ProgressImage } from 'features/nodes/types/common';
import { round } from 'lodash-es';
import type { WritableAtom } from 'nanostores';
import { atom, computed, map } from 'nanostores';
import { useEffect, useState } from 'react';
import type { ImageDTO, S } from 'services/api/types';
import type { AppSocket } from 'services/events/types';
import type { ManagerOptions, SocketOptions } from 'socket.io-client';
@@ -27,103 +25,6 @@ export type ProgressData = {
progressImage: ProgressImage | null;
};
export const useProgressData = (
$progressData: WritableAtom<Record<string, ProgressData>>,
sessionId: string
): ProgressData => {
const [value, setValue] = useState<ProgressData>(() => {
return $progressData.get()[sessionId] ?? { sessionId, progressEvent: null, progressImage: null };
});
useEffect(() => {
const unsub = $progressData.subscribe((data) => {
const progressData = data[sessionId];
if (!progressData) {
return;
}
setValue(progressData);
});
return () => {
unsub();
};
}, [$progressData, sessionId]);
return value;
};
export const useHasProgressImage = (
$progressData: WritableAtom<Record<string, ProgressData>>,
sessionId: string
): boolean => {
const [value, setValue] = useState(false);
useEffect(() => {
const unsub = $progressData.subscribe((data) => {
const progressData = data[sessionId];
setValue(Boolean(progressData?.progressImage));
});
return () => {
unsub();
};
}, [$progressData, sessionId]);
return value;
};
export const setProgress = (
$progressData: WritableAtom<Record<string, ProgressData>>,
data: S['InvocationProgressEvent']
) => {
const progressData = $progressData.get();
const current = progressData[data.session_id];
if (current) {
const next = { ...current };
next.progressEvent = data;
if (data.image) {
next.progressImage = data.image;
}
$progressData.set({
...progressData,
[data.session_id]: next,
});
} else {
$progressData.set({
...progressData,
[data.session_id]: {
sessionId: data.session_id,
progressEvent: data,
progressImage: data.image ?? null,
},
});
}
};
export const clearProgressEvent = ($progressData: WritableAtom<Record<string, ProgressData>>, sessionId: string) => {
const progressData = $progressData.get();
const current = progressData[sessionId];
if (!current) {
return;
}
const next = { ...current };
next.progressEvent = null;
$progressData.set({
...progressData,
[sessionId]: next,
});
};
export const clearProgressImage = ($progressData: WritableAtom<Record<string, ProgressData>>, sessionId: string) => {
const progressData = $progressData.get();
const current = progressData[sessionId];
if (!current) {
return;
}
const next = { ...current };
next.progressImage = null;
$progressData.set({
...progressData,
[sessionId]: next,
});
};
export const $lastCanvasProgressEvent = atom<S['InvocationProgressEvent'] | null>(null);
export const $lastCanvasProgressImage = atom<EphemeralProgressImage | null>(null);
export const $lastWorkflowsProgressEvent = atom<S['InvocationProgressEvent'] | null>(null);