From 2e2ac71278f67fe28475612c070c184b14ff7113 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 4 Jun 2025 14:24:03 +1000 Subject: [PATCH] feat: canvas flow rework (wip) --- .../OverlayScrollbars/ScrollableContent.tsx | 5 +- .../components/CanvasMainPanelContent.tsx | 492 +++++++++--------- 2 files changed, 244 insertions(+), 253 deletions(-) diff --git a/invokeai/frontend/web/src/common/components/OverlayScrollbars/ScrollableContent.tsx b/invokeai/frontend/web/src/common/components/OverlayScrollbars/ScrollableContent.tsx index d61a5e498c..5da75b10c6 100644 --- a/invokeai/frontend/web/src/common/components/OverlayScrollbars/ScrollableContent.tsx +++ b/invokeai/frontend/web/src/common/components/OverlayScrollbars/ScrollableContent.tsx @@ -11,13 +11,14 @@ import { memo, useEffect, useMemo, useState } from 'react'; type Props = PropsWithChildren & { maxHeight?: ChakraProps['maxHeight']; + maxWidth?: ChakraProps['maxWidth']; overflowX?: 'hidden' | 'scroll'; overflowY?: 'hidden' | 'scroll'; }; const styles: CSSProperties = { position: 'absolute', top: 0, left: 0, right: 0, bottom: 0 }; -const ScrollableContent = ({ children, maxHeight, overflowX = 'hidden', overflowY = 'scroll' }: Props) => { +const ScrollableContent = ({ children, maxHeight, maxWidth, overflowX = 'hidden', overflowY = 'scroll' }: Props) => { const overlayscrollbarsOptions = useMemo( () => getOverlayScrollbarsParams({ overflowX, overflowY }).options, [overflowX, overflowY] @@ -44,7 +45,7 @@ const ScrollableContent = ({ children, maxHeight, overflowX = 'hidden', overflow }, [os]); return ( - + {children} diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasMainPanelContent.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasMainPanelContent.tsx index 6692ee1c9e..de6bab9717 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/CanvasMainPanelContent.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasMainPanelContent.tsx @@ -1,11 +1,17 @@ /* eslint-disable i18next/no-literal-string */ -import type { ButtonGroupProps, SystemStyleObject, TextProps } from '@invoke-ai/ui-library'; +import type { + ButtonGroupProps, + CircularProgressProps, + ImageProps, + SystemStyleObject, + TextProps, +} from '@invoke-ai/ui-library'; import { - Box, Button, ButtonGroup, CircularProgress, ContextMenu, + Divider, Flex, FormControl, FormLabel, @@ -26,7 +32,7 @@ import { EMPTY_ARRAY } from 'app/store/constants'; import { useAppStore } from 'app/store/nanostores/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { FocusRegionWrapper } from 'common/components/FocusRegionWrapper'; -import { IAINoContentFallbackWithSpinner } from 'common/components/IAIImageFallback'; +import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import { useImageUploadButton } from 'common/hooks/useImageUploadButton'; import { CanvasAlertsPreserveMask } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsPreserveMask'; import { CanvasAlertsSelectedEntityStatus } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsSelectedEntityStatus'; @@ -44,7 +50,6 @@ import { StagingAreaToolbar } from 'features/controlLayers/components/StagingAre import { CanvasToolbar } from 'features/controlLayers/components/Toolbar/CanvasToolbar'; import { Transform } from 'features/controlLayers/components/Transform/Transform'; import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate'; -import { loadImage } from 'features/controlLayers/konva/util'; import { selectDynamicGrid, selectShowHUD } from 'features/controlLayers/store/canvasSettingsSlice'; import { canvasSessionStarted, selectCanvasSession } from 'features/controlLayers/store/canvasStagingAreaSlice'; import { newCanvasFromImageDndTarget } from 'features/dnd/dnd'; @@ -56,7 +61,7 @@ import { isCanvasOutputNodeId } from 'features/nodes/util/graph/graphBuilderUtil import { round } from 'lodash-es'; import { atom, type WritableAtom } from 'nanostores'; import type { ChangeEvent } from 'react'; -import { createContext, memo, useCallback, useContext, useEffect, useMemo, useRef, useState } from 'react'; +import { createContext, memo, useCallback, useContext, useEffect, useMemo, useState } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; import { Trans, useTranslation } from 'react-i18next'; import { PiDotsThreeOutlineVerticalFill, PiUploadBold } from 'react-icons/pi'; @@ -64,7 +69,7 @@ import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { useListAllQueueItemsQuery } from 'services/api/endpoints/queue'; import type { ImageDTO, S } from 'services/api/types'; import type { ProgressData } from 'services/events/stores'; -import { $socket, clearProgressImage, setProgress, useHasProgressImage, useProgressData } from 'services/events/stores'; +import { $socket, setProgress, useProgressData } from 'services/events/stores'; import type { Equals, Param0 } from 'tsafe'; import { assert, objectEntries } from 'tsafe'; @@ -313,31 +318,6 @@ const GenerateWithStartingImageAndInpaintMask = memo(() => { }); GenerateWithStartingImageAndInpaintMask.displayName = 'GenerateWithStartingImageAndInpaintMask'; -const scrollIndicatorBaseSx = { - opacity: 0, - position: 'absolute', - w: 16, - h: 'full', - transitionProperty: 'opacity', - transitionDuration: '0.3s', - pointerEvents: 'none', - '&[data-visible="true"]': { - opacity: 1, - }, -} satisfies SystemStyleObject; - -const scrollIndicatorLeftSx = { - ...scrollIndicatorBaseSx, - left: 0, - bg: 'linear-gradient(to right, var(--invoke-colors-base-900), transparent)', -} satisfies SystemStyleObject; - -const scrollIndicatorRightSx = { - ...scrollIndicatorBaseSx, - right: 0, - bg: 'linear-gradient(to left, var(--invoke-colors-base-900), transparent)', -} satisfies SystemStyleObject; - type StagingContextValue = { session: | { @@ -423,39 +403,25 @@ const StagingArea = memo(() => { const ctx = useStagingContext(); const [selectedItemId, setSelectedItemId] = useState(null); const [autoSwitch, setAutoSwitch] = useState(true); - const [canScrollLeft, setCanScrollLeft] = useState(false); - const [canScrollRight, setCanScrollRight] = useState(false); - const scrollableRef = useRef(null); const { items } = useListAllQueueItemsQuery({ destination: ctx.session.id }, LIST_ALL_OPTIONS); - const selectedItem = useMemo( - () => - items.length > 0 && selectedItemId !== null ? items.find(({ item_id }) => item_id === selectedItemId) : null, - [items, selectedItemId] - ); - const selectedItemIndex = useMemo( - () => - items.length > 0 && selectedItemId !== null ? items.findIndex(({ item_id }) => item_id === selectedItemId) : null, - [items, selectedItemId] - ); - - useEffect(() => { - const el = scrollableRef.current; - if (!el) { - return; + const selectedItem = useMemo(() => { + if (items.length === 0) { + return null; } - const onScroll = () => { - const { scrollLeft, scrollWidth, clientWidth } = el; - setCanScrollLeft(scrollLeft > 0); - setCanScrollRight(scrollLeft + clientWidth < scrollWidth); - }; - el.addEventListener('scroll', onScroll); - const observer = new ResizeObserver(onScroll); - observer.observe(el); - return () => { - el.removeEventListener('scroll', onScroll); - observer.disconnect(); - }; - }, []); + if (selectedItemId === null) { + return null; + } + return items.find(({ item_id }) => item_id === selectedItemId) ?? null; + }, [items, selectedItemId]); + const selectedItemIndex = useMemo(() => { + if (items.length === 0) { + return null; + } + if (selectedItemId === null) { + return null; + } + return items.findIndex(({ item_id }) => item_id === selectedItemId) ?? null; + }, [items, selectedItemId]); const onSelectItemId = useCallback((item_id: number | null) => { setSelectedItemId(item_id); @@ -466,10 +432,6 @@ const StagingArea = memo(() => { useStagingAreaKeyboardNav(items, selectedItemId, onSelectItemId); - const onChangeAutoSwitch = useCallback((autoSwitch: boolean) => { - setAutoSwitch(autoSwitch); - }, []); - useEffect(() => { if (items.length === 0) { onSelectItemId(null); @@ -503,10 +465,6 @@ const StagingArea = memo(() => { }; }, [autoSwitch, ctx.$progressData, ctx.session.id, onSelectItemId, socket]); - const _onChangeAutoSwitch = useCallback((e: ChangeEvent) => { - setAutoSwitch(e.target.checked); - }, []); - useEffect(() => { if (!socket) { return; @@ -526,9 +484,47 @@ const StagingArea = memo(() => { return ( - - - + + + {items.length > 0 && ( + + )} + {items.length === 0 && ( + + No generations + + )} + + ); +}); +StagingArea.displayName = 'StagingArea'; + +const StagingAreaContent = memo( + ({ + items, + selectedItem, + selectedItemId, + selectedItemIndex, + onChangeAutoSwitch, + onSelectItemId, + }: { + items: S['SessionQueueItem'][]; + selectedItem: S['SessionQueueItem'] | null; + selectedItemId: number | null; + selectedItemIndex: number | null; + onChangeAutoSwitch: (autoSwitch: boolean) => void; + onSelectItemId: (itemId: number) => void; + }) => { + return ( + <> + {selectedItem && selectedItemIndex !== null && ( { )} {!selectedItem && No generation selected} - - Auto-switch - - - - - - {items.map((item, i) => ( - - ))} + + + + + {items.map((item, i) => ( + + ))} + + - - + + ); + } +); +StagingAreaContent.displayName = 'StagingAreaContent'; + +const StagingAreaHeader = memo( + ({ autoSwitch, setAutoSwitch }: { autoSwitch: boolean; setAutoSwitch: (autoSwitch: boolean) => void }) => { + const dispatch = useAppDispatch(); + + const startOver = useCallback(() => { + dispatch(canvasSessionStarted({ sessionType: 'simple' })); + }, [dispatch]); + + const onChangeAutoSwitch = useCallback( + (e: ChangeEvent) => { + setAutoSwitch(e.target.checked); + }, + [setAutoSwitch] + ); + + return ( + + + Generations + + + + Auto-switch + + + - - ); -}); -StagingArea.displayName = 'StagingArea'; - -const StagingAreaHeader = memo(() => { - const dispatch = useAppDispatch(); - - const startOver = useCallback(() => { - dispatch(canvasSessionStarted({ sessionType: 'simple' })); - }, [dispatch]); - - return ( - - - Generations - - - - - ); -}); + ); + } +); StagingAreaHeader.displayName = 'StagingAreaHeader'; const miniQueueItemSx = { cursor: 'pointer', + userSelect: 'none', pos: 'relative', alignItems: 'center', justifyContent: 'center', @@ -603,57 +610,34 @@ const miniQueueItemSx = { }, aspectRatio: '1/1', flexShrink: 0, -}; +} satisfies SystemStyleObject; const getCardId = (item_id: number) => `queue-item-status-card-${item_id}`; -const useOutputImageDTO = (item: S['SessionQueueItem']) => { - const ctx = useStagingContext(); - - const outputImageName = useMemo(() => { - const nodeId = Object.entries(item.session.source_prepared_mapping).find(([nodeId]) => - isCanvasOutputNodeId(nodeId) - )?.[1][0]; - const output = nodeId ? item.session.results[nodeId] : undefined; - - if (!output) { - return null; - } - - for (const [_name, value] of objectEntries(output)) { - if (isImageField(value)) { - return value.image_name; - } - } +const getOutputImageName = (item: S['SessionQueueItem']) => { + const nodeId = Object.entries(item.session.source_prepared_mapping).find(([nodeId]) => + isCanvasOutputNodeId(nodeId) + )?.[1][0]; + const output = nodeId ? item.session.results[nodeId] : undefined; + if (!output) { return null; - }, [item.session.results, item.session.source_prepared_mapping]); + } + + for (const [_name, value] of objectEntries(output)) { + if (isImageField(value)) { + return value.image_name; + } + } + + return null; +}; + +const useOutputImageDTO = (item: S['SessionQueueItem']) => { + const outputImageName = useMemo(() => getOutputImageName(item), [item]); const { currentData: imageDTO } = useGetImageDTOQuery(outputImageName ?? skipToken); - const preloadOutputImageAndClearProgress = useCallback( - async (imageDTO: ImageDTO) => { - try { - await loadImage(imageDTO.image_url, true); - clearProgressImage(ctx.$progressData, item.session_id); - return; - } catch { - // noop - but should we do something? means image failed to load... - } - }, - [ctx.$progressData, item.session_id] - ); - - useEffect(() => { - if (!imageDTO) { - return; - } - if (!ctx.$progressData.get()[item.session_id]?.progressImage) { - return; - } - preloadOutputImageAndClearProgress(imageDTO); - }, [ctx.$progressData, imageDTO, item.session_id, preloadOutputImageAndClearProgress]); - return imageDTO; }; @@ -666,8 +650,7 @@ type MiniQueueItemProps = { }; const MiniQueueItem = memo(({ item, isSelected, number, onSelectItemId, onChangeAutoSwitch }: MiniQueueItemProps) => { - const ctx = useStagingContext(); - const hasProgressImage = useHasProgressImage(ctx.$progressData, item.session_id); + const [imageLoaded, setImageLoaded] = useState(false); const imageDTO = useOutputImageDTO(item); const onClick = useCallback(() => { @@ -678,14 +661,9 @@ const MiniQueueItem = memo(({ item, isSelected, number, onSelectItemId, onChange onChangeAutoSwitch(item.status === 'in_progress'); }, [item.status, onChangeAutoSwitch]); - if (imageDTO && !hasProgressImage) { - return ( - - - - - ); - } + const onLoad = useCallback(() => { + setImageLoaded(true); + }, []); return ( - + + {imageDTO && } + {!imageLoaded && } + ); }); @@ -704,16 +685,14 @@ MiniQueueItem.displayName = 'MiniQueueItem'; const fullSizeQueueItemSx = { cursor: 'pointer', + userSelect: 'none', pos: 'relative', alignItems: 'center', justifyContent: 'center', overflow: 'hidden', h: 'full', - maxH: 'full', - maxW: 'full', - minW: 0, - minH: 0, -}; + w: 'full', +} satisfies SystemStyleObject; type FullSizeQueueItemProps = { item: S['SessionQueueItem']; @@ -721,30 +700,48 @@ type FullSizeQueueItemProps = { }; const FullSizeQueueItem = memo(({ item, number }: FullSizeQueueItemProps) => { - const ctx = useStagingContext(); - const hasProgressImage = useHasProgressImage(ctx.$progressData, item.session_id); const imageDTO = useOutputImageDTO(item); + const [imageLoaded, setImageLoaded] = useState(false); - if (imageDTO && !hasProgressImage) { - return ( - - - - - - ); - } + const onLoad = useCallback(() => { + setImageLoaded(true); + }, []); return ( - + + {imageDTO && } + {!imageLoaded && } - + + ); }); FullSizeQueueItem.displayName = 'FullSizeQueueItem'; +const ProgressImage = memo(({ session_id, ...rest }: { session_id: string } & ImageProps) => { + const { $progressData } = useStagingContext(); + const { progressImage } = useProgressData($progressData, session_id); + + if (!progressImage) { + return null; + } + + return ( + + ); +}); +ProgressImage.displayName = 'ProgressImage'; + const getMessage = (data: S['InvocationProgressEvent']) => { let message = data.message; if (data.percentage) { @@ -758,79 +755,58 @@ const ItemNumber = memo(({ number, ...rest }: { number: number } & TextProps) => }); ItemNumber.displayName = 'ItemNumber'; -const ProgressMessage = memo(({ session_id, ...rest }: { session_id: string } & TextProps) => { - const { $progressData } = useStagingContext(); - const { progressEvent } = useProgressData($progressData, session_id); - if (!progressEvent) { - return null; +const ProgressMessage = memo( + ({ session_id, status, ...rest }: { session_id: string; status: S['SessionQueueItem']['status'] } & TextProps) => { + const { $progressData } = useStagingContext(); + const { progressEvent } = useProgressData($progressData, session_id); + + if (status === 'completed' || status === 'failed' || status === 'canceled') { + return null; + } + + return ( + + {progressEvent ? getMessage(progressEvent) : 'Waiting to start...'} + + ); } - return ( - - {getMessage(progressEvent)} - - ); -}); +); ProgressMessage.displayName = 'ProgressMessage'; -const InProgressContent = memo(({ item }: { item: S['SessionQueueItem'] }) => { - const { $progressData } = useStagingContext(); - const { progressEvent, progressImage } = useProgressData($progressData, item.session_id); - - if (item.status === 'pending') { +const ProgressLabel = memo(({ status, ...rest }: { status: S['SessionQueueItem']['status'] } & TextProps) => { + if (status === 'pending') { return ( - + Pending ); } - if (item.status === 'canceled') { + if (status === 'canceled') { return ( - + Canceled ); } - if (item.status === 'failed') { + if (status === 'failed') { return ( - + Failed ); } - if (progressImage) { + if (status === 'in_progress') { return ( - <> - - - + + In Progress + ); } - if (item.status === 'in_progress') { - return ( - <> - - In Progress - - - - ); - } - - if (item.status === 'completed') { - return ; - } - assert>(false); + return null; }); -InProgressContent.displayName = 'InProgressContent'; +ProgressLabel.displayName = 'ProgressLabel'; const circleStyles: SystemStyleObject = { circle: { @@ -842,20 +818,34 @@ const circleStyles: SystemStyleObject = { right: 2, }; -const ProgressCircle = memo(({ data }: { data?: S['InvocationProgressEvent'] | null }) => { - return ( - - - - ); -}); +const ProgressCircle = memo( + ({ + session_id, + status, + ...rest + }: { session_id: string; status: S['SessionQueueItem']['status'] } & CircularProgressProps) => { + const { $progressData } = useStagingContext(); + const { progressEvent } = useProgressData($progressData, session_id); + + if (status !== 'in_progress') { + return null; + } + + return ( + + + + ); + } +); ProgressCircle.displayName = 'ProgressCircle'; const ImageActions = memo(({ imageDTO, ...rest }: { imageDTO: ImageDTO } & ButtonGroupProps) => {