fix(ui): progress image fixes

This commit is contained in:
psychedelicious
2025-06-16 18:25:37 +10:00
parent 2e0824a799
commit 893f7a8744
4 changed files with 104 additions and 53 deletions

View File

@@ -1,6 +1,10 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex } from '@invoke-ai/ui-library';
import { useOutputImageDTO } from 'features/controlLayers/components/SimpleSession/context';
import {
useCanvasSessionContext,
useOutputImageDTO,
useProgressData,
} from 'features/controlLayers/components/SimpleSession/context';
import { ImageActions } from 'features/controlLayers/components/SimpleSession/ImageActions';
import { QueueItemCircularProgress } from 'features/controlLayers/components/SimpleSession/QueueItemCircularProgress';
import { QueueItemNumber } from 'features/controlLayers/components/SimpleSession/QueueItemNumber';
@@ -8,7 +12,7 @@ import { QueueItemProgressImage } from 'features/controlLayers/components/Simple
import { QueueItemStatusLabel } from 'features/controlLayers/components/SimpleSession/QueueItemStatusLabel';
import { getQueueItemElementId } from 'features/controlLayers/components/SimpleSession/shared';
import { DndImage } from 'features/dnd/DndImage';
import { memo, useCallback, useState } from 'react';
import { memo } from 'react';
import type { S } from 'services/api/types';
type Props = {
@@ -27,17 +31,14 @@ const sx = {
} satisfies SystemStyleObject;
export const QueueItemPreviewFull = memo(({ item, number }: Props) => {
const ctx = useCanvasSessionContext();
const imageDTO = useOutputImageDTO(item);
const [imageLoaded, setImageLoaded] = useState(false);
const onLoad = useCallback(() => {
setImageLoaded(true);
}, []);
const { imageLoaded } = useProgressData(ctx.$progressData, item.item_id);
return (
<Flex id={getQueueItemElementId(item.item_id)} sx={sx}>
<QueueItemStatusLabel status={item.status} position="absolute" margin="auto" />
{imageDTO && <DndImage imageDTO={imageDTO} onLoad={onLoad} />}
<QueueItemStatusLabel item={item} position="absolute" margin="auto" />
{imageDTO && <DndImage imageDTO={imageDTO} />}
{!imageLoaded && <QueueItemProgressImage itemId={item.item_id} position="absolute" />}
{imageDTO && <ImageActions imageDTO={imageDTO} position="absolute" top={1} right={2} />}
<QueueItemNumber number={number} position="absolute" top={1} left={2} />

View File

@@ -1,13 +1,17 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex } from '@invoke-ai/ui-library';
import { useCanvasSessionContext, useOutputImageDTO } from 'features/controlLayers/components/SimpleSession/context';
import {
useCanvasSessionContext,
useOutputImageDTO,
useProgressData,
} from 'features/controlLayers/components/SimpleSession/context';
import { QueueItemCircularProgress } from 'features/controlLayers/components/SimpleSession/QueueItemCircularProgress';
import { QueueItemNumber } from 'features/controlLayers/components/SimpleSession/QueueItemNumber';
import { QueueItemProgressImage } from 'features/controlLayers/components/SimpleSession/QueueItemProgressImage';
import { QueueItemStatusLabel } from 'features/controlLayers/components/SimpleSession/QueueItemStatusLabel';
import { getQueueItemElementId } from 'features/controlLayers/components/SimpleSession/shared';
import { DndImage } from 'features/dnd/DndImage';
import { memo, useCallback, useState } from 'react';
import { memo, useCallback } from 'react';
import type { S } from 'services/api/types';
const sx = {
@@ -35,7 +39,7 @@ type Props = {
export const QueueItemPreviewMini = memo(({ item, isSelected, number }: Props) => {
const ctx = useCanvasSessionContext();
const [imageLoaded, setImageLoaded] = useState(false);
const { imageLoaded } = useProgressData(ctx.$progressData, item.item_id);
const imageDTO = useOutputImageDTO(item);
const onClick = useCallback(() => {
@@ -43,15 +47,12 @@ export const QueueItemPreviewMini = memo(({ item, isSelected, number }: Props) =
}, [ctx.$selectedItemId, item.item_id]);
const onLoad = useCallback(() => {
setImageLoaded(true);
if (ctx.$progressData.get()[item.item_id]) {
ctx.$lastLoadedItemId.set(item.item_id);
}
}, [ctx.$lastLoadedItemId, ctx.$progressData, item.item_id]);
ctx.onImageLoad(item.item_id);
}, [ctx, item.item_id]);
return (
<Flex id={getQueueItemElementId(item.item_id)} sx={sx} data-selected={isSelected} onClick={onClick}>
<QueueItemStatusLabel status={item.status} position="absolute" margin="auto" />
<QueueItemStatusLabel item={item} position="absolute" margin="auto" />
{imageDTO && <DndImage imageDTO={imageDTO} onLoad={onLoad} asThumbnail />}
{!imageLoaded && <QueueItemProgressImage itemId={item.item_id} position="absolute" />}
<QueueItemNumber number={number} position="absolute" top={0} left={1} />

View File

@@ -1,27 +1,35 @@
/* eslint-disable i18next/no-literal-string */
import type { TextProps } from '@invoke-ai/ui-library';
import { Text } from '@invoke-ai/ui-library';
import { useCanvasSessionContext, useProgressData } from 'features/controlLayers/components/SimpleSession/context';
import { memo } from 'react';
import type { S } from 'services/api/types';
type Props = { status: S['SessionQueueItem']['status'] } & TextProps;
type Props = { item: S['SessionQueueItem'] } & TextProps;
export const QueueItemStatusLabel = memo(({ status, ...rest }: Props) => {
if (status === 'pending') {
export const QueueItemStatusLabel = memo(({ item, ...rest }: Props) => {
const ctx = useCanvasSessionContext();
const { progressImage, imageLoaded } = useProgressData(ctx.$progressData, item.item_id);
if (progressImage || imageLoaded) {
return null;
}
if (item.status === 'pending') {
return (
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="base.300" {...rest}>
Pending
</Text>
);
}
if (status === 'canceled') {
if (item.status === 'canceled') {
return (
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="warning.300" {...rest}>
Canceled
</Text>
);
}
if (status === 'failed') {
if (item.status === 'failed') {
return (
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="error.300" {...rest}>
Failed
@@ -29,7 +37,7 @@ export const QueueItemStatusLabel = memo(({ status, ...rest }: Props) => {
);
}
if (status === 'in_progress') {
if (item.status === 'in_progress') {
return (
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="invokeBlue.300" {...rest}>
In Progress
@@ -37,6 +45,14 @@ export const QueueItemStatusLabel = memo(({ status, ...rest }: Props) => {
);
}
if (item.status === 'completed') {
return (
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="invokeGreen.300" {...rest}>
Completed
</Text>
);
}
return null;
});
QueueItemStatusLabel.displayName = 'QueueItemStatusLabel';

View File

@@ -4,21 +4,22 @@ import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppStore } from 'app/store/nanostores/store';
import { getOutputImageName } from 'features/controlLayers/components/SimpleSession/shared';
import type { ProgressImage } from 'features/nodes/types/common';
import type { Atom, WritableAtom } from 'nanostores';
import { atom, computed, effect } from 'nanostores';
import type { Atom, MapStore, StoreValue, WritableAtom } from 'nanostores';
import { atom, computed, effect, map, subscribeKeys } from 'nanostores';
import type { PropsWithChildren } from 'react';
import { createContext, memo, useCallback, useContext, useEffect, useMemo, useState } from 'react';
import { getImageDTOSafe } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { ImageDTO, S } from 'services/api/types';
import { $socket } from 'services/events/stores';
import { assert } from 'tsafe';
import { assert, objectEntries } from 'tsafe';
export type ProgressData = {
itemId: number;
progressEvent: S['InvocationProgressEvent'] | null;
progressImage: ProgressImage | null;
imageDTO: ImageDTO | null;
imageLoaded: boolean;
};
const getInitialProgressData = (itemId: number): ProgressData => ({
@@ -26,17 +27,17 @@ const getInitialProgressData = (itemId: number): ProgressData => ({
progressEvent: null,
progressImage: null,
imageDTO: null,
imageLoaded: false,
});
export const useProgressData = (
$progressData: WritableAtom<Record<number, ProgressData>>,
itemId: number
): ProgressData => {
const [value, setValue] = useState<ProgressData>(() => {
return $progressData.get()[itemId] ?? getInitialProgressData(itemId);
});
export const useProgressData = ($progressData: ProgressDataMap, itemId: number): ProgressData => {
const getInitialValue = useCallback(
() => $progressData.get()[itemId] ?? getInitialProgressData(itemId),
[$progressData, itemId]
);
const [value, setValue] = useState(getInitialValue);
useEffect(() => {
const unsub = $progressData.subscribe((data) => {
const unsub = subscribeKeys($progressData, [itemId], (data) => {
const progressData = data[itemId];
if (!progressData) {
return;
@@ -51,7 +52,7 @@ export const useProgressData = (
return value;
};
const setProgress = ($progressData: WritableAtom<Record<number, ProgressData>>, data: S['InvocationProgressEvent']) => {
const setProgress = ($progressData: ProgressDataMap, data: S['InvocationProgressEvent']) => {
const progressData = $progressData.get();
const current = progressData[data.item_id];
if (current) {
@@ -72,27 +73,30 @@ const setProgress = ($progressData: WritableAtom<Record<number, ProgressData>>,
progressEvent: data,
progressImage: data.image ?? null,
imageDTO: null,
imageLoaded: false,
},
});
}
};
type ProgressDataMap = MapStore<Record<number, ProgressData | undefined>>;
type CanvasSessionContextValue = {
session: { id: string; type: 'simple' | 'advanced' };
$items: Atom<S['SessionQueueItem'][]>;
$itemCount: Atom<number>;
$hasItems: Atom<boolean>;
$progressData: WritableAtom<Record<string, ProgressData>>;
$progressData: ProgressDataMap;
$selectedItemId: WritableAtom<number | null>;
$selectedItem: Atom<S['SessionQueueItem'] | null>;
$selectedItemIndex: Atom<number | null>;
$selectedItemOutputImageDTO: Atom<ImageDTO | null>;
$autoSwitch: WritableAtom<boolean>;
$lastLoadedItemId: WritableAtom<number | null>;
selectNext: () => void;
selectPrev: () => void;
selectFirst: () => void;
selectLast: () => void;
onImageLoad: (itemId: number) => void;
};
const CanvasSessionContext = createContext<CanvasSessionContextValue | null>(null);
@@ -112,6 +116,7 @@ export const CanvasSessionContextProvider = memo(
const store = useAppStore();
const socket = useStore($socket);
const $lastCompletedItemId = useState(() => atom<number | null>(null))[0];
/**
* Manually-synced atom containing queue items for the current session. This is populated from the RTK Query cache
@@ -133,7 +138,7 @@ export const CanvasSessionContextProvider = memo(
/**
* An ephemeral store of progress events and images for all items in the current session.
*/
const $progressData = useState(() => atom<Record<number, ProgressData>>({}))[0];
const $progressData = useState(() => map<StoreValue<ProgressDataMap>>({}))[0];
/**
* The currently selected queue item's ID, or null if one is not selected.
@@ -259,6 +264,27 @@ export const CanvasSessionContextProvider = memo(
$selectedItemId.set(last.item_id);
}, [$items, $selectedItemId]);
const onImageLoad = useCallback(
(itemId: number) => {
const progressData = $progressData.get();
const current = progressData[itemId];
if (current) {
const next = { ...current, imageLoaded: true };
$progressData.setKey(itemId, next);
} else {
$progressData.setKey(itemId, {
...getInitialProgressData(itemId),
imageLoaded: true,
});
}
if ($lastCompletedItemId.get() === itemId) {
$selectedItemId.set(itemId);
$lastCompletedItemId.set(null);
}
},
[$lastCompletedItemId, $progressData, $selectedItemId]
);
// Set up socket listeners
useEffect(() => {
if (!socket) {
@@ -272,12 +298,23 @@ export const CanvasSessionContextProvider = memo(
setProgress($progressData, data);
};
const onQueueItemStatusChanged = (data: S['QueueItemStatusChangedEvent']) => {
if (data.destination !== session.id) {
return;
}
if (data.status === 'completed') {
$lastCompletedItemId.set(data.item_id);
}
};
socket.on('invocation_progress', onProgress);
socket.on('queue_item_status_changed', onQueueItemStatusChanged);
return () => {
socket.off('invocation_progress', onProgress);
socket.off('queue_item_status_changed', onQueueItemStatusChanged);
};
}, [$autoSwitch, $progressData, $selectedItemId, session.id, socket]);
}, [$autoSwitch, $lastCompletedItemId, $progressData, $selectedItemId, session.id, socket]);
// Set up state subscriptions and effects
useEffect(() => {
@@ -327,7 +364,11 @@ export const CanvasSessionContextProvider = memo(
const toDelete: number[] = [];
const toUpdate: ProgressData[] = [];
for (const datum of Object.values(progressData)) {
for (const [id, datum] of objectEntries(progressData)) {
if (!datum) {
toDelete.push(id);
continue;
}
const item = items.find(({ item_id }) => item_id === datum.itemId);
if (!item) {
toDelete.push(datum.itemId);
@@ -376,21 +417,13 @@ export const CanvasSessionContextProvider = memo(
}
}
if (toDelete.length === 0 && toUpdate.length === 0) {
return;
}
const newProgressData = { ...progressData };
for (const itemId of toDelete) {
delete newProgressData[itemId];
$progressData.setKey(itemId, undefined);
}
for (const datum of toUpdate) {
newProgressData[datum.itemId] = datum;
$progressData.setKey(datum.itemId, datum);
}
$progressData.set(newProgressData);
});
// We only want to auto-switch to completed queue items once their images have fully loaded to prevent flashes
@@ -440,19 +473,18 @@ export const CanvasSessionContextProvider = memo(
$autoSwitch,
$selectedItem,
$selectedItemIndex,
$lastLoadedItemId,
$selectedItemOutputImageDTO,
$itemCount,
selectNext,
selectPrev,
selectFirst,
selectLast,
onImageLoad,
}),
[
$autoSwitch,
$items,
$hasItems,
$lastLoadedItemId,
$progressData,
$selectedItem,
$selectedItemId,
@@ -464,6 +496,7 @@ export const CanvasSessionContextProvider = memo(
selectPrev,
selectFirst,
selectLast,
onImageLoad,
]
);