mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 06:18:03 -05:00
fix(ui): progress image fixes
This commit is contained in:
@@ -1,6 +1,10 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { useOutputImageDTO } from 'features/controlLayers/components/SimpleSession/context';
|
||||
import {
|
||||
useCanvasSessionContext,
|
||||
useOutputImageDTO,
|
||||
useProgressData,
|
||||
} from 'features/controlLayers/components/SimpleSession/context';
|
||||
import { ImageActions } from 'features/controlLayers/components/SimpleSession/ImageActions';
|
||||
import { QueueItemCircularProgress } from 'features/controlLayers/components/SimpleSession/QueueItemCircularProgress';
|
||||
import { QueueItemNumber } from 'features/controlLayers/components/SimpleSession/QueueItemNumber';
|
||||
@@ -8,7 +12,7 @@ import { QueueItemProgressImage } from 'features/controlLayers/components/Simple
|
||||
import { QueueItemStatusLabel } from 'features/controlLayers/components/SimpleSession/QueueItemStatusLabel';
|
||||
import { getQueueItemElementId } from 'features/controlLayers/components/SimpleSession/shared';
|
||||
import { DndImage } from 'features/dnd/DndImage';
|
||||
import { memo, useCallback, useState } from 'react';
|
||||
import { memo } from 'react';
|
||||
import type { S } from 'services/api/types';
|
||||
|
||||
type Props = {
|
||||
@@ -27,17 +31,14 @@ const sx = {
|
||||
} satisfies SystemStyleObject;
|
||||
|
||||
export const QueueItemPreviewFull = memo(({ item, number }: Props) => {
|
||||
const ctx = useCanvasSessionContext();
|
||||
const imageDTO = useOutputImageDTO(item);
|
||||
const [imageLoaded, setImageLoaded] = useState(false);
|
||||
|
||||
const onLoad = useCallback(() => {
|
||||
setImageLoaded(true);
|
||||
}, []);
|
||||
const { imageLoaded } = useProgressData(ctx.$progressData, item.item_id);
|
||||
|
||||
return (
|
||||
<Flex id={getQueueItemElementId(item.item_id)} sx={sx}>
|
||||
<QueueItemStatusLabel status={item.status} position="absolute" margin="auto" />
|
||||
{imageDTO && <DndImage imageDTO={imageDTO} onLoad={onLoad} />}
|
||||
<QueueItemStatusLabel item={item} position="absolute" margin="auto" />
|
||||
{imageDTO && <DndImage imageDTO={imageDTO} />}
|
||||
{!imageLoaded && <QueueItemProgressImage itemId={item.item_id} position="absolute" />}
|
||||
{imageDTO && <ImageActions imageDTO={imageDTO} position="absolute" top={1} right={2} />}
|
||||
<QueueItemNumber number={number} position="absolute" top={1} left={2} />
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { useCanvasSessionContext, useOutputImageDTO } from 'features/controlLayers/components/SimpleSession/context';
|
||||
import {
|
||||
useCanvasSessionContext,
|
||||
useOutputImageDTO,
|
||||
useProgressData,
|
||||
} from 'features/controlLayers/components/SimpleSession/context';
|
||||
import { QueueItemCircularProgress } from 'features/controlLayers/components/SimpleSession/QueueItemCircularProgress';
|
||||
import { QueueItemNumber } from 'features/controlLayers/components/SimpleSession/QueueItemNumber';
|
||||
import { QueueItemProgressImage } from 'features/controlLayers/components/SimpleSession/QueueItemProgressImage';
|
||||
import { QueueItemStatusLabel } from 'features/controlLayers/components/SimpleSession/QueueItemStatusLabel';
|
||||
import { getQueueItemElementId } from 'features/controlLayers/components/SimpleSession/shared';
|
||||
import { DndImage } from 'features/dnd/DndImage';
|
||||
import { memo, useCallback, useState } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import type { S } from 'services/api/types';
|
||||
|
||||
const sx = {
|
||||
@@ -35,7 +39,7 @@ type Props = {
|
||||
|
||||
export const QueueItemPreviewMini = memo(({ item, isSelected, number }: Props) => {
|
||||
const ctx = useCanvasSessionContext();
|
||||
const [imageLoaded, setImageLoaded] = useState(false);
|
||||
const { imageLoaded } = useProgressData(ctx.$progressData, item.item_id);
|
||||
const imageDTO = useOutputImageDTO(item);
|
||||
|
||||
const onClick = useCallback(() => {
|
||||
@@ -43,15 +47,12 @@ export const QueueItemPreviewMini = memo(({ item, isSelected, number }: Props) =
|
||||
}, [ctx.$selectedItemId, item.item_id]);
|
||||
|
||||
const onLoad = useCallback(() => {
|
||||
setImageLoaded(true);
|
||||
if (ctx.$progressData.get()[item.item_id]) {
|
||||
ctx.$lastLoadedItemId.set(item.item_id);
|
||||
}
|
||||
}, [ctx.$lastLoadedItemId, ctx.$progressData, item.item_id]);
|
||||
ctx.onImageLoad(item.item_id);
|
||||
}, [ctx, item.item_id]);
|
||||
|
||||
return (
|
||||
<Flex id={getQueueItemElementId(item.item_id)} sx={sx} data-selected={isSelected} onClick={onClick}>
|
||||
<QueueItemStatusLabel status={item.status} position="absolute" margin="auto" />
|
||||
<QueueItemStatusLabel item={item} position="absolute" margin="auto" />
|
||||
{imageDTO && <DndImage imageDTO={imageDTO} onLoad={onLoad} asThumbnail />}
|
||||
{!imageLoaded && <QueueItemProgressImage itemId={item.item_id} position="absolute" />}
|
||||
<QueueItemNumber number={number} position="absolute" top={0} left={1} />
|
||||
|
||||
@@ -1,27 +1,35 @@
|
||||
/* eslint-disable i18next/no-literal-string */
|
||||
import type { TextProps } from '@invoke-ai/ui-library';
|
||||
import { Text } from '@invoke-ai/ui-library';
|
||||
import { useCanvasSessionContext, useProgressData } from 'features/controlLayers/components/SimpleSession/context';
|
||||
import { memo } from 'react';
|
||||
import type { S } from 'services/api/types';
|
||||
|
||||
type Props = { status: S['SessionQueueItem']['status'] } & TextProps;
|
||||
type Props = { item: S['SessionQueueItem'] } & TextProps;
|
||||
|
||||
export const QueueItemStatusLabel = memo(({ status, ...rest }: Props) => {
|
||||
if (status === 'pending') {
|
||||
export const QueueItemStatusLabel = memo(({ item, ...rest }: Props) => {
|
||||
const ctx = useCanvasSessionContext();
|
||||
const { progressImage, imageLoaded } = useProgressData(ctx.$progressData, item.item_id);
|
||||
|
||||
if (progressImage || imageLoaded) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (item.status === 'pending') {
|
||||
return (
|
||||
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="base.300" {...rest}>
|
||||
Pending
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
if (status === 'canceled') {
|
||||
if (item.status === 'canceled') {
|
||||
return (
|
||||
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="warning.300" {...rest}>
|
||||
Canceled
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
if (status === 'failed') {
|
||||
if (item.status === 'failed') {
|
||||
return (
|
||||
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="error.300" {...rest}>
|
||||
Failed
|
||||
@@ -29,7 +37,7 @@ export const QueueItemStatusLabel = memo(({ status, ...rest }: Props) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (status === 'in_progress') {
|
||||
if (item.status === 'in_progress') {
|
||||
return (
|
||||
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="invokeBlue.300" {...rest}>
|
||||
In Progress
|
||||
@@ -37,6 +45,14 @@ export const QueueItemStatusLabel = memo(({ status, ...rest }: Props) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (item.status === 'completed') {
|
||||
return (
|
||||
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="invokeGreen.300" {...rest}>
|
||||
Completed
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
return null;
|
||||
});
|
||||
QueueItemStatusLabel.displayName = 'QueueItemStatusLabel';
|
||||
|
||||
@@ -4,21 +4,22 @@ import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { useAppStore } from 'app/store/nanostores/store';
|
||||
import { getOutputImageName } from 'features/controlLayers/components/SimpleSession/shared';
|
||||
import type { ProgressImage } from 'features/nodes/types/common';
|
||||
import type { Atom, WritableAtom } from 'nanostores';
|
||||
import { atom, computed, effect } from 'nanostores';
|
||||
import type { Atom, MapStore, StoreValue, WritableAtom } from 'nanostores';
|
||||
import { atom, computed, effect, map, subscribeKeys } from 'nanostores';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { createContext, memo, useCallback, useContext, useEffect, useMemo, useState } from 'react';
|
||||
import { getImageDTOSafe } from 'services/api/endpoints/images';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import type { ImageDTO, S } from 'services/api/types';
|
||||
import { $socket } from 'services/events/stores';
|
||||
import { assert } from 'tsafe';
|
||||
import { assert, objectEntries } from 'tsafe';
|
||||
|
||||
export type ProgressData = {
|
||||
itemId: number;
|
||||
progressEvent: S['InvocationProgressEvent'] | null;
|
||||
progressImage: ProgressImage | null;
|
||||
imageDTO: ImageDTO | null;
|
||||
imageLoaded: boolean;
|
||||
};
|
||||
|
||||
const getInitialProgressData = (itemId: number): ProgressData => ({
|
||||
@@ -26,17 +27,17 @@ const getInitialProgressData = (itemId: number): ProgressData => ({
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTO: null,
|
||||
imageLoaded: false,
|
||||
});
|
||||
|
||||
export const useProgressData = (
|
||||
$progressData: WritableAtom<Record<number, ProgressData>>,
|
||||
itemId: number
|
||||
): ProgressData => {
|
||||
const [value, setValue] = useState<ProgressData>(() => {
|
||||
return $progressData.get()[itemId] ?? getInitialProgressData(itemId);
|
||||
});
|
||||
export const useProgressData = ($progressData: ProgressDataMap, itemId: number): ProgressData => {
|
||||
const getInitialValue = useCallback(
|
||||
() => $progressData.get()[itemId] ?? getInitialProgressData(itemId),
|
||||
[$progressData, itemId]
|
||||
);
|
||||
const [value, setValue] = useState(getInitialValue);
|
||||
useEffect(() => {
|
||||
const unsub = $progressData.subscribe((data) => {
|
||||
const unsub = subscribeKeys($progressData, [itemId], (data) => {
|
||||
const progressData = data[itemId];
|
||||
if (!progressData) {
|
||||
return;
|
||||
@@ -51,7 +52,7 @@ export const useProgressData = (
|
||||
return value;
|
||||
};
|
||||
|
||||
const setProgress = ($progressData: WritableAtom<Record<number, ProgressData>>, data: S['InvocationProgressEvent']) => {
|
||||
const setProgress = ($progressData: ProgressDataMap, data: S['InvocationProgressEvent']) => {
|
||||
const progressData = $progressData.get();
|
||||
const current = progressData[data.item_id];
|
||||
if (current) {
|
||||
@@ -72,27 +73,30 @@ const setProgress = ($progressData: WritableAtom<Record<number, ProgressData>>,
|
||||
progressEvent: data,
|
||||
progressImage: data.image ?? null,
|
||||
imageDTO: null,
|
||||
imageLoaded: false,
|
||||
},
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
type ProgressDataMap = MapStore<Record<number, ProgressData | undefined>>;
|
||||
|
||||
type CanvasSessionContextValue = {
|
||||
session: { id: string; type: 'simple' | 'advanced' };
|
||||
$items: Atom<S['SessionQueueItem'][]>;
|
||||
$itemCount: Atom<number>;
|
||||
$hasItems: Atom<boolean>;
|
||||
$progressData: WritableAtom<Record<string, ProgressData>>;
|
||||
$progressData: ProgressDataMap;
|
||||
$selectedItemId: WritableAtom<number | null>;
|
||||
$selectedItem: Atom<S['SessionQueueItem'] | null>;
|
||||
$selectedItemIndex: Atom<number | null>;
|
||||
$selectedItemOutputImageDTO: Atom<ImageDTO | null>;
|
||||
$autoSwitch: WritableAtom<boolean>;
|
||||
$lastLoadedItemId: WritableAtom<number | null>;
|
||||
selectNext: () => void;
|
||||
selectPrev: () => void;
|
||||
selectFirst: () => void;
|
||||
selectLast: () => void;
|
||||
onImageLoad: (itemId: number) => void;
|
||||
};
|
||||
|
||||
const CanvasSessionContext = createContext<CanvasSessionContextValue | null>(null);
|
||||
@@ -112,6 +116,7 @@ export const CanvasSessionContextProvider = memo(
|
||||
const store = useAppStore();
|
||||
|
||||
const socket = useStore($socket);
|
||||
const $lastCompletedItemId = useState(() => atom<number | null>(null))[0];
|
||||
|
||||
/**
|
||||
* Manually-synced atom containing queue items for the current session. This is populated from the RTK Query cache
|
||||
@@ -133,7 +138,7 @@ export const CanvasSessionContextProvider = memo(
|
||||
/**
|
||||
* An ephemeral store of progress events and images for all items in the current session.
|
||||
*/
|
||||
const $progressData = useState(() => atom<Record<number, ProgressData>>({}))[0];
|
||||
const $progressData = useState(() => map<StoreValue<ProgressDataMap>>({}))[0];
|
||||
|
||||
/**
|
||||
* The currently selected queue item's ID, or null if one is not selected.
|
||||
@@ -259,6 +264,27 @@ export const CanvasSessionContextProvider = memo(
|
||||
$selectedItemId.set(last.item_id);
|
||||
}, [$items, $selectedItemId]);
|
||||
|
||||
const onImageLoad = useCallback(
|
||||
(itemId: number) => {
|
||||
const progressData = $progressData.get();
|
||||
const current = progressData[itemId];
|
||||
if (current) {
|
||||
const next = { ...current, imageLoaded: true };
|
||||
$progressData.setKey(itemId, next);
|
||||
} else {
|
||||
$progressData.setKey(itemId, {
|
||||
...getInitialProgressData(itemId),
|
||||
imageLoaded: true,
|
||||
});
|
||||
}
|
||||
if ($lastCompletedItemId.get() === itemId) {
|
||||
$selectedItemId.set(itemId);
|
||||
$lastCompletedItemId.set(null);
|
||||
}
|
||||
},
|
||||
[$lastCompletedItemId, $progressData, $selectedItemId]
|
||||
);
|
||||
|
||||
// Set up socket listeners
|
||||
useEffect(() => {
|
||||
if (!socket) {
|
||||
@@ -272,12 +298,23 @@ export const CanvasSessionContextProvider = memo(
|
||||
setProgress($progressData, data);
|
||||
};
|
||||
|
||||
const onQueueItemStatusChanged = (data: S['QueueItemStatusChangedEvent']) => {
|
||||
if (data.destination !== session.id) {
|
||||
return;
|
||||
}
|
||||
if (data.status === 'completed') {
|
||||
$lastCompletedItemId.set(data.item_id);
|
||||
}
|
||||
};
|
||||
|
||||
socket.on('invocation_progress', onProgress);
|
||||
socket.on('queue_item_status_changed', onQueueItemStatusChanged);
|
||||
|
||||
return () => {
|
||||
socket.off('invocation_progress', onProgress);
|
||||
socket.off('queue_item_status_changed', onQueueItemStatusChanged);
|
||||
};
|
||||
}, [$autoSwitch, $progressData, $selectedItemId, session.id, socket]);
|
||||
}, [$autoSwitch, $lastCompletedItemId, $progressData, $selectedItemId, session.id, socket]);
|
||||
|
||||
// Set up state subscriptions and effects
|
||||
useEffect(() => {
|
||||
@@ -327,7 +364,11 @@ export const CanvasSessionContextProvider = memo(
|
||||
const toDelete: number[] = [];
|
||||
const toUpdate: ProgressData[] = [];
|
||||
|
||||
for (const datum of Object.values(progressData)) {
|
||||
for (const [id, datum] of objectEntries(progressData)) {
|
||||
if (!datum) {
|
||||
toDelete.push(id);
|
||||
continue;
|
||||
}
|
||||
const item = items.find(({ item_id }) => item_id === datum.itemId);
|
||||
if (!item) {
|
||||
toDelete.push(datum.itemId);
|
||||
@@ -376,21 +417,13 @@ export const CanvasSessionContextProvider = memo(
|
||||
}
|
||||
}
|
||||
|
||||
if (toDelete.length === 0 && toUpdate.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const newProgressData = { ...progressData };
|
||||
|
||||
for (const itemId of toDelete) {
|
||||
delete newProgressData[itemId];
|
||||
$progressData.setKey(itemId, undefined);
|
||||
}
|
||||
|
||||
for (const datum of toUpdate) {
|
||||
newProgressData[datum.itemId] = datum;
|
||||
$progressData.setKey(datum.itemId, datum);
|
||||
}
|
||||
|
||||
$progressData.set(newProgressData);
|
||||
});
|
||||
|
||||
// We only want to auto-switch to completed queue items once their images have fully loaded to prevent flashes
|
||||
@@ -440,19 +473,18 @@ export const CanvasSessionContextProvider = memo(
|
||||
$autoSwitch,
|
||||
$selectedItem,
|
||||
$selectedItemIndex,
|
||||
$lastLoadedItemId,
|
||||
$selectedItemOutputImageDTO,
|
||||
$itemCount,
|
||||
selectNext,
|
||||
selectPrev,
|
||||
selectFirst,
|
||||
selectLast,
|
||||
onImageLoad,
|
||||
}),
|
||||
[
|
||||
$autoSwitch,
|
||||
$items,
|
||||
$hasItems,
|
||||
$lastLoadedItemId,
|
||||
$progressData,
|
||||
$selectedItem,
|
||||
$selectedItemId,
|
||||
@@ -464,6 +496,7 @@ export const CanvasSessionContextProvider = memo(
|
||||
selectPrev,
|
||||
selectFirst,
|
||||
selectLast,
|
||||
onImageLoad,
|
||||
]
|
||||
);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user