feat: canvas flow rework (wip)

This commit is contained in:
psychedelicious
2025-06-03 13:45:01 +10:00
parent 5e93f58530
commit 8a78e37634
18 changed files with 645 additions and 268 deletions

View File

@@ -1,10 +1,14 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
/* eslint-disable i18next/no-literal-string */
import type { FlexProps, SystemStyleObject } from '@invoke-ai/ui-library';
import {
Box,
Button,
ButtonGroup,
CircularProgress,
ContextMenu,
Flex,
FormControl,
FormLabel,
Heading,
IconButton,
Image,
@@ -12,9 +16,13 @@ import {
MenuButton,
MenuList,
Spacer,
Switch,
Text,
Tooltip,
} from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/query';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppStore } from 'app/store/nanostores/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { FocusRegionWrapper } from 'common/components/FocusRegionWrapper';
@@ -49,15 +57,21 @@ import { newCanvasFromImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { DndImage } from 'features/dnd/DndImage';
import { newCanvasFromImage } from 'features/imageActions/actions';
import { memo, useCallback, useEffect, useMemo } from 'react';
import type { ProgressImage } from 'features/nodes/types/common';
import { isImageField } from 'features/nodes/types/common';
import { isCanvasOutputNodeId } from 'features/nodes/util/graph/graphBuilderUtils';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { Trans, useTranslation } from 'react-i18next';
import { PiDotsThreeOutlineVerticalFill, PiUploadBold } from 'react-icons/pi';
import type { ImageDTO } from 'services/api/types';
import { getImageDTOSafe, useGetImageDTOQuery } from 'services/api/endpoints/images';
import { queueItemsAdapterSelectors, useListQueueItemsQuery } from 'services/api/endpoints/queue';
import type { ImageDTO, S } from 'services/api/types';
import type { ProgressAndResult } from 'services/events/stores';
import { $progressImages, useMapSelector } from 'services/events/stores';
import { $progressImages, $socket, useMapSelector } from 'services/events/stores';
import type { Equals, Param0 } from 'tsafe';
import { assert } from 'tsafe';
import { assert, objectEntries } from 'tsafe';
import { CanvasAlertsInvocationProgress } from './CanvasAlerts/CanvasAlertsInvocationProgress';
@@ -311,20 +325,405 @@ const SimpleActiveSession = memo(() => {
<Flex flexDir="column" w="full" h="full" alignItems="center" justifyContent="center" gap={2}>
<Flex w="full">
<Text fontSize="lg" fontWeight="bold">
Generations from this Session
Generations
</Text>
<Spacer />
<Button size="sm" onClick={startOver}>
<Button size="sm" variant="ghost" onClick={startOver}>
Start Over
</Button>
</Flex>
<SelectedImageOrProgressImage />
<SessionImages />
<StagingArea />
</Flex>
);
});
SimpleActiveSession.displayName = 'SimpleActiveSession';
const scrollIndicatorSx = {
opacity: 0,
'&[data-visible="true"]': {
opacity: 1,
},
} satisfies SystemStyleObject;
const StagingArea = memo(() => {
const [selectedItemId, setSelectedItemId] = useState<number | null>(null);
const [autoSwitch, setAutoSwitch] = useState(true);
const [canScrollLeft, setCanScrollLeft] = useState(false);
const [canScrollRight, setCanScrollRight] = useState(false);
const scrollableRef = useRef<HTMLDivElement>(null);
const { data } = useListQueueItemsQuery({ destination: 'canvas' });
const items = useMemo(() => {
if (!data) {
return EMPTY_ARRAY;
}
return queueItemsAdapterSelectors.selectAll(data);
}, [data]);
const selectedItem = useMemo(
() =>
data && selectedItemId !== null ? queueItemsAdapterSelectors.selectById(data, String(selectedItemId)) : null,
[data, selectedItemId]
);
useEffect(() => {
if (items.length === 0) {
setSelectedItemId(null);
return;
}
if (selectedItem === null && items.length > 0) {
setSelectedItemId(items[0]?.item_id ?? null);
return;
}
if (selectedItemId === null || items.find((item) => item.item_id === selectedItemId) === undefined) {
return;
}
document.getElementById(`queue-item-status-card-${selectedItemId}`)?.scrollIntoView();
}, [items, selectedItem, selectedItemId]);
useEffect(() => {
const el = scrollableRef.current;
if (!el) {
return;
}
const onScroll = () => {
const { scrollLeft, scrollWidth, clientWidth } = el;
setCanScrollLeft(scrollLeft > 0);
setCanScrollRight(scrollLeft + clientWidth < scrollWidth);
};
el.addEventListener('scroll', onScroll);
const observer = new ResizeObserver(onScroll);
observer.observe(el);
return () => {
el.removeEventListener('scroll', onScroll);
observer.disconnect();
};
}, []);
const onSelectItem = useCallback((item: S['SessionQueueItem']) => {
setSelectedItemId(item.item_id);
if (item.status !== 'in_progress') {
setAutoSwitch(false);
}
}, []);
const onNext = useCallback(() => {
if (selectedItemId === null) {
return;
}
const currentIndex = items.findIndex((item) => item.item_id === selectedItemId);
const nextIndex = (currentIndex + 1) % items.length;
const nextItem = items[nextIndex];
if (!nextItem) {
return;
}
setSelectedItemId(nextItem.item_id);
}, [items, selectedItemId]);
const onPrev = useCallback(() => {
if (selectedItemId === null) {
return;
}
const currentIndex = items.findIndex((item) => item.item_id === selectedItemId);
const prevIndex = (currentIndex - 1 + items.length) % items.length;
const prevItem = items[prevIndex];
if (!prevItem) {
return;
}
setSelectedItemId(prevItem.item_id);
}, [items, selectedItemId]);
useHotkeys('left', onPrev);
useHotkeys('right', onNext);
const socket = useStore($socket);
useEffect(() => {
if (!autoSwitch) {
return;
}
if (!socket) {
return;
}
const onQueueItemStatusChanged = (data: S['QueueItemStatusChangedEvent']) => {
if (data.destination !== 'canvas') {
return;
}
if (data.status === 'in_progress') {
setSelectedItemId(data.item_id);
}
};
socket.on('queue_item_status_changed', onQueueItemStatusChanged);
return () => {
socket.off('queue_item_status_changed', onQueueItemStatusChanged);
};
}, [autoSwitch, socket]);
const onChangeAutoSwitch = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setAutoSwitch(e.target.checked);
}, []);
return (
<Flex position="relative" flexDir="column" gap={2} w="full" h="full" minW={0} minH={0}>
<Flex w="full" h="full" alignItems="center" justifyContent="center" minW={0} minH={0}>
{selectedItem && <QueueItemStatusCard item={selectedItem} minW={0} minH={0} h="full" isSelected={false} />}
{!selectedItem && <Text>No queued generations</Text>}
</Flex>
<FormControl position="absolute" top={2} right={2} w="min-content">
<FormLabel m={0}>Auto-switch</FormLabel>
<Switch size="sm" isChecked={autoSwitch} onChange={onChangeAutoSwitch} />
</FormControl>
<Flex position="relative" w="full" maxW="full">
<Flex ref={scrollableRef} gap={2} h={108} maxW="full" overflowX="scroll" flexShrink={0}>
{items.map((item, i) => (
<QueueItemStatusCard
id={`queue-item-status-card-${item.item_id}`}
key={item.item_id}
item={item}
number={i + 1}
onSelectItem={onSelectItem}
isSelected={selectedItemId === item.item_id}
w={108}
h={108}
flexShrink={0}
/>
))}
</Flex>
<Box
position="absolute"
sx={scrollIndicatorSx}
left={0}
w={16}
h="full"
bg="linear-gradient(to right, var(--invoke-colors-base-900), transparent)"
data-visible={canScrollLeft}
transitionProperty="opacity"
transitionDuration="0.3s"
pointerEvents="none"
/>
<Box
position="absolute"
sx={scrollIndicatorSx}
right={0}
w={16}
h="full"
bg="linear-gradient(to left, var(--invoke-colors-base-900), transparent)"
data-visible={canScrollRight}
transitionProperty="opacity"
transitionDuration="0.3s"
pointerEvents="none"
/>
</Flex>
</Flex>
);
});
StagingArea.displayName = 'StagingArea';
const IMAGE_DTO_ERROR = Symbol('IMAGE_DTO_ERROR');
const useOutputImageDTO = (item: S['SessionQueueItem']) => {
const [imageDTO, setImageDTO] = useState<ImageDTO | typeof IMAGE_DTO_ERROR | null>(null);
const syncImageDTO = useCallback(async (item: S['SessionQueueItem']) => {
const nodeId = Object.entries(item.session.source_prepared_mapping).find(([nodeId]) =>
isCanvasOutputNodeId(nodeId)
)?.[1][0];
const output = nodeId ? item.session.results[nodeId] : undefined;
if (!output) {
return setImageDTO(null);
}
for (const [_name, value] of objectEntries(output)) {
if (isImageField(value)) {
const imageDTO = await getImageDTOSafe(value.image_name);
if (imageDTO) {
setImageDTO(imageDTO);
$progressImages.setKey(item.session_id, undefined);
return;
}
}
}
setImageDTO(IMAGE_DTO_ERROR);
}, []);
useEffect(() => {
syncImageDTO(item);
}, [item, syncImageDTO]);
return imageDTO;
};
const QueueItemStatusCard = memo(
({
item,
isSelected,
number,
onSelectItem,
...rest
}: {
item: S['SessionQueueItem'];
isSelected: boolean;
number?: number;
onSelectItem?: (item: S['SessionQueueItem']) => void;
} & FlexProps) => {
const onClick = useCallback(() => {
onSelectItem?.(item);
}, [item, onSelectItem]);
return (
<Flex
role="button"
pos="relative"
borderWidth={1}
borderRadius="base"
alignItems="center"
justifyContent="center"
overflow="hidden"
onClick={onClick}
aspectRatio="1/1"
borderColor={isSelected ? 'invokeBlue.300' : undefined}
{...rest}
>
<QueueItemStatusCardContent item={item} />
{number !== undefined && <Text position="absolute" top={0} left={1}>{`#${number}`}</Text>}
</Flex>
);
}
);
QueueItemStatusCard.displayName = 'QueueItemStatusCard';
const QueueItemStatusCardContent = memo(({ item }: { item: S['SessionQueueItem'] }) => {
const socket = useStore($socket);
const [progressEvent, setProgressEvent] = useState<S['InvocationProgressEvent'] | null>(null);
const [progressImage, setProgressImage] = useState<ProgressImage | null>(null);
useEffect(() => {
if (!socket) {
return;
}
const onProgress = (data: S['InvocationProgressEvent']) => {
if (data.session_id !== item.session_id) {
return;
}
setProgressEvent(data);
if (data.image) {
setProgressImage(data.image);
}
};
socket.on('invocation_progress', onProgress);
return () => {
socket.off('invocation_progress', onProgress);
};
}, [item.session_id, socket]);
const imageDTO = useOutputImageDTO(item);
if (item.status === 'pending') {
return (
<Text fontWeight="semibold" color="base.300">
Pending
</Text>
);
}
if (item.status === 'canceled') {
return (
<Text fontWeight="semibold" color="warning.300">
Canceled
</Text>
);
}
if (item.status === 'failed') {
return (
<Text fontWeight="semibold" color="error.300">
Failed
</Text>
);
}
if (item.status === 'in_progress' || !imageDTO) {
if (!progressImage) {
return (
<>
<Text fontWeight="semibold" color="invokeBlue.300">
In Progress
</Text>
<ProgressCircle data={progressEvent} />
</>
);
}
return (
<>
<Image objectFit="contain" maxH="full" maxW="full" src={progressImage.dataURL} width={progressImage.width} />
<ProgressCircle data={progressEvent} />
</>
);
}
if (item.status === 'completed' && imageDTO && imageDTO !== IMAGE_DTO_ERROR) {
return <Image objectFit="contain" maxH="full" maxW="full" src={imageDTO.image_url} width={imageDTO.width} />;
}
if (item.status === 'completed') {
return (
<Text fontWeight="semibold" color="error.300">
Unable to get image
</Text>
);
}
assert<Equals<never, typeof item.status>>(false);
});
QueueItemStatusCardContent.displayName = 'QueueItemStatusCardContent';
const circleStyles: SystemStyleObject = {
circle: {
transitionProperty: 'none',
transitionDuration: '0s',
},
position: 'absolute',
top: 2,
right: 2,
};
const ProgressCircle = ({ data }: { data?: S['InvocationProgressEvent'] | null }) => {
return (
<Tooltip label={data?.message ?? 'Generating'}>
<CircularProgress
size="14px"
color="invokeBlue.500"
thickness={14}
isIndeterminate={!data || data.percentage === null}
value={data?.percentage ? data.percentage * 100 : undefined}
sx={circleStyles}
/>
</Tooltip>
);
};
ProgressCircle.displayName = 'ProgressCircle';
const QueueItemResultCard = memo(({ item }: { item: S['SessionQueueItem'] }) => {
const imageName = useMemo(() => {
const nodeId = Object.entries(item.session.source_prepared_mapping).find(([nodeId]) =>
isCanvasOutputNodeId(nodeId)
)?.[1][0];
const output = nodeId ? item.session.results[nodeId] : undefined;
if (!output) {
return;
}
for (const [_name, value] of objectEntries(output)) {
if (isImageField(value)) {
return value.image_name;
}
}
}, [item]);
const { data: imageDTO } = useGetImageDTOQuery(imageName ?? skipToken);
if (!imageDTO) {
return <Text>Unknown output type</Text>;
}
return <Image objectFit="contain" maxH="full" maxW="full" src={imageDTO.image_url} width={imageDTO.width} />;
});
QueueItemResultCard.displayName = 'QueueItemResultCard';
const SelectedImageOrProgressImage = memo(() => {
const selectedImage = useAppSelector(selectSelectedImage);

View File

@@ -159,6 +159,8 @@ export const isMainModelWithoutUnet = (modelLoader: Invocation<MainModelLoaderNo
);
};
export const isCanvasOutputNodeId = (nodeId: string) => nodeId.split(':')[0] === CANVAS_OUTPUT_PREFIX;
export const isCanvasOutputEvent = (data: S['InvocationCompleteEvent']) => {
return data.invocation_source_id.split(':')[0] === CANVAS_OUTPUT_PREFIX;
return isCanvasOutputNodeId(data.invocation_source_id);
};

View File

@@ -13,7 +13,7 @@ import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold, PiXBold } from 'react-icons/pi';
import { useSelector } from 'react-redux';
import type { SessionQueueItemDTO } from 'services/api/types';
import type { S } from 'services/api/types';
import { COLUMN_WIDTHS } from './constants';
import QueueItemDetail from './QueueItemDetail';
@@ -23,7 +23,7 @@ const selectedStyles = { bg: 'base.700' };
type InnerItemProps = {
index: number;
item: SessionQueueItemDTO;
item: S['SessionQueueItem'];
context: ListContext;
};
@@ -155,7 +155,7 @@ const QueueItemComponent = ({ index, item, context }: InnerItemProps) => {
</Flex>
<Collapse in={isOpen} transition={transition} unmountOnExit={true}>
<QueueItemDetail queueItemDTO={item} />
<QueueItemDetail queueItem={item} />
</Collapse>
</Flex>
);

View File

@@ -12,24 +12,20 @@ import type { ReactNode } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold, PiXBold } from 'react-icons/pi';
import { useGetQueueItemQuery } from 'services/api/endpoints/queue';
import type { SessionQueueItemDTO } from 'services/api/types';
import type { S } from 'services/api/types';
type Props = {
queueItemDTO: SessionQueueItemDTO;
queueItem: S['SessionQueueItem'];
};
const QueueItemComponent = ({ queueItemDTO }: Props) => {
const { session_id, batch_id, item_id, origin, destination } = queueItemDTO;
const QueueItemComponent = ({ queueItem }: Props) => {
const { session_id, batch_id, item_id, origin, destination } = queueItem;
const { t } = useTranslation();
const isRetryEnabled = useFeatureStatus('retryQueueItem');
const { cancelBatch, isLoading: isLoadingCancelBatch, isCanceled: isBatchCanceled } = useCancelBatch(batch_id);
const { cancelQueueItem, isLoading: isLoadingCancelQueueItem } = useCancelQueueItem(item_id);
const { retryQueueItem, isLoading: isLoadingRetryQueueItem } = useRetryQueueItem(item_id);
const { data: queueItem } = useGetQueueItemQuery(item_id);
const originText = useOriginText(origin);
const destinationText = useDestinationText(destination);

View File

@@ -14,7 +14,7 @@ import { useTranslation } from 'react-i18next';
import type { Components, ItemContent } from 'react-virtuoso';
import { Virtuoso } from 'react-virtuoso';
import { queueItemsAdapterSelectors, useListQueueItemsQuery } from 'services/api/endpoints/queue';
import type { SessionQueueItemDTO } from 'services/api/types';
import type { S } from 'services/api/types';
import QueueItemComponent from './QueueItemComponent';
import QueueListComponent from './QueueListComponent';
@@ -24,13 +24,13 @@ import type { ListContext } from './types';
// eslint-disable-next-line @typescript-eslint/no-explicit-any
type TableVirtuosoScrollerRef = (ref: HTMLElement | Window | null) => any;
const computeItemKey = (index: number, item: SessionQueueItemDTO): number => item.item_id;
const computeItemKey = (index: number, item: S['SessionQueueItem']): number => item.item_id;
const components: Components<SessionQueueItemDTO, ListContext> = {
const components: Components<S['SessionQueueItem'], ListContext> = {
List: QueueListComponent,
};
const itemContent: ItemContent<SessionQueueItemDTO, ListContext> = (index, item, context) => (
const itemContent: ItemContent<S['SessionQueueItem'], ListContext> = (index, item, context) => (
<QueueItemComponent index={index} item={item} context={context} />
);
@@ -114,7 +114,7 @@ const QueueList = () => {
<Flex w="full" h="full" flexDir="column">
<QueueListHeader />
<Flex ref={rootRef} w="full" h="full" alignItems="center" justifyContent="center">
<Virtuoso<SessionQueueItemDTO, ListContext>
<Virtuoso<S['SessionQueueItem'], ListContext>
data={queueItems}
endReached={handleLoadMore}
scrollerRef={setScroller as TableVirtuosoScrollerRef}

View File

@@ -1,10 +1,10 @@
import { Flex, forwardRef, typedMemo } from '@invoke-ai/ui-library';
import type { Components } from 'react-virtuoso';
import type { SessionQueueItemDTO } from 'services/api/types';
import type { S } from 'services/api/types';
import type { ListContext } from './types';
const QueueListComponent: Components<SessionQueueItemDTO, ListContext>['List'] = typedMemo(
const QueueListComponent: Components<S['SessionQueueItem'], ListContext>['List'] = typedMemo(
forwardRef((props, ref) => {
return (
<Flex {...props} ref={ref} flexDirection="column" gap={0.5}>

View File

@@ -1,7 +1,7 @@
import { useTranslation } from 'react-i18next';
import type { SessionQueueItemDTO } from 'services/api/types';
import type { S } from 'services/api/types';
export const useDestinationText = (destination: SessionQueueItemDTO['destination']) => {
export const useDestinationText = (destination: S['SessionQueueItem']['destination']) => {
const { t } = useTranslation();
if (destination === 'canvas') {

View File

@@ -1,7 +1,7 @@
import { useTranslation } from 'react-i18next';
import type { SessionQueueItemDTO } from 'services/api/types';
import type { S } from 'services/api/types';
export const useOriginText = (origin: SessionQueueItemDTO['origin']) => {
export const useOriginText = (origin: S['SessionQueueItem']['origin']) => {
const { t } = useTranslation();
if (origin === 'generation') {