mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-13 13:25:48 -05:00
feat(ui): store output image DTO in session context instead of just the name
This commit is contained in:
@@ -8,8 +8,9 @@ import type { Atom, WritableAtom } from 'nanostores';
|
||||
import { atom, computed, effect } from 'nanostores';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { createContext, memo, useCallback, useContext, useEffect, useMemo, useState } from 'react';
|
||||
import { getImageDTOSafe } from 'services/api/endpoints/images';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import type { S } from 'services/api/types';
|
||||
import type { ImageDTO, S } from 'services/api/types';
|
||||
import { $socket } from 'services/events/stores';
|
||||
import { assert } from 'tsafe';
|
||||
import { z } from 'zod';
|
||||
@@ -21,14 +22,14 @@ export type ProgressData = {
|
||||
itemId: number;
|
||||
progressEvent: S['InvocationProgressEvent'] | null;
|
||||
progressImage: ProgressImage | null;
|
||||
outputImageName: string | null;
|
||||
imageDTO: ImageDTO | null;
|
||||
};
|
||||
|
||||
const getInitialProgressData = (itemId: number): ProgressData => ({
|
||||
itemId,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
outputImageName: null,
|
||||
imageDTO: null,
|
||||
});
|
||||
|
||||
export const useProgressData = (
|
||||
@@ -74,7 +75,7 @@ const setProgress = ($progressData: WritableAtom<Record<number, ProgressData>>,
|
||||
itemId: data.item_id,
|
||||
progressEvent: data,
|
||||
progressImage: data.image ?? null,
|
||||
outputImageName: null,
|
||||
imageDTO: null,
|
||||
},
|
||||
});
|
||||
}
|
||||
@@ -89,7 +90,7 @@ type CanvasSessionContextValue = {
|
||||
$selectedItemId: WritableAtom<number | null>;
|
||||
$selectedItem: Atom<S['SessionQueueItem'] | null>;
|
||||
$selectedItemIndex: Atom<number | null>;
|
||||
$selectedItemOutputImageName: Atom<string | null>;
|
||||
$selectedItemOutputImageDTO: Atom<ImageDTO | null>;
|
||||
$autoSwitch: WritableAtom<AutoSwitchMode>;
|
||||
$lastLoadedItemId: WritableAtom<number | null>;
|
||||
selectNext: () => void;
|
||||
@@ -187,7 +188,7 @@ export const CanvasSessionContextProvider = memo(
|
||||
* The currently selected queue item's output image name, or null if one is not selected or there is no output
|
||||
* image recorded.
|
||||
*/
|
||||
const $selectedItemOutputImageName = useState(() =>
|
||||
const $selectedItemOutputImageDTO = useState(() =>
|
||||
computed([$selectedItemId, $progressData], (selectedItemId, progressData) => {
|
||||
if (selectedItemId === null) {
|
||||
return null;
|
||||
@@ -196,7 +197,7 @@ export const CanvasSessionContextProvider = memo(
|
||||
if (!datum) {
|
||||
return null;
|
||||
}
|
||||
return datum.outputImageName;
|
||||
return datum.imageDTO;
|
||||
})
|
||||
)[0];
|
||||
|
||||
@@ -328,7 +329,7 @@ export const CanvasSessionContextProvider = memo(
|
||||
});
|
||||
|
||||
// Clean up the progress data when a queue item is discarded.
|
||||
const unsubCleanUpProgressData = $items.listen((items) => {
|
||||
const unsubCleanUpProgressData = $items.listen(async (items) => {
|
||||
const progressData = $progressData.get();
|
||||
|
||||
const toDelete: number[] = [];
|
||||
@@ -343,7 +344,7 @@ export const CanvasSessionContextProvider = memo(
|
||||
...datum,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
outputImageName: null,
|
||||
imageDTO: null,
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -352,21 +353,34 @@ export const CanvasSessionContextProvider = memo(
|
||||
const datum = progressData[item.item_id];
|
||||
|
||||
if (datum) {
|
||||
if (datum.outputImageName) {
|
||||
if (datum.imageDTO) {
|
||||
continue;
|
||||
}
|
||||
const outputImageName = getOutputImageName(item);
|
||||
if (!outputImageName) {
|
||||
continue;
|
||||
}
|
||||
const imageDTO = await getImageDTOSafe(outputImageName);
|
||||
if (!imageDTO) {
|
||||
continue;
|
||||
}
|
||||
toUpdate.push({
|
||||
...datum,
|
||||
outputImageName,
|
||||
imageDTO,
|
||||
});
|
||||
} else {
|
||||
const _datum = getInitialProgressData(item.item_id);
|
||||
_datum.outputImageName = getOutputImageName(item);
|
||||
toUpdate.push(_datum);
|
||||
const outputImageName = getOutputImageName(item);
|
||||
if (!outputImageName) {
|
||||
continue;
|
||||
}
|
||||
const imageDTO = await getImageDTOSafe(outputImageName);
|
||||
if (!imageDTO) {
|
||||
continue;
|
||||
}
|
||||
toUpdate.push({
|
||||
...getInitialProgressData(item.item_id),
|
||||
imageDTO,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -435,7 +449,7 @@ export const CanvasSessionContextProvider = memo(
|
||||
$selectedItem,
|
||||
$selectedItemIndex,
|
||||
$lastLoadedItemId,
|
||||
$selectedItemOutputImageName,
|
||||
$selectedItemOutputImageDTO,
|
||||
$itemCount,
|
||||
selectNext,
|
||||
selectPrev,
|
||||
@@ -452,7 +466,7 @@ export const CanvasSessionContextProvider = memo(
|
||||
$selectedItemId,
|
||||
$selectedItemIndex,
|
||||
session,
|
||||
$selectedItemOutputImageName,
|
||||
$selectedItemOutputImageDTO,
|
||||
$itemCount,
|
||||
selectNext,
|
||||
selectPrev,
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
|
||||
import { isImageField } from 'features/nodes/types/common';
|
||||
import { isCanvasOutputNodeId } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { round } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { computed } from 'nanostores';
|
||||
import { useState } from 'react';
|
||||
import type { S } from 'services/api/types';
|
||||
import { objectEntries } from 'tsafe';
|
||||
|
||||
@@ -43,9 +44,11 @@ export const getOutputImageName = (item: S['SessionQueueItem']) => {
|
||||
};
|
||||
|
||||
export const useOutputImageDTO = (item: S['SessionQueueItem']) => {
|
||||
const outputImageName = useMemo(() => getOutputImageName(item), [item]);
|
||||
|
||||
const { currentData: imageDTO } = useGetImageDTOQuery(outputImageName ?? skipToken);
|
||||
const ctx = useCanvasSessionContext();
|
||||
const $imageDTO = useState(() =>
|
||||
computed([ctx.$progressData], (progressData) => progressData[item.item_id]?.imageDTO ?? null)
|
||||
)[0];
|
||||
const imageDTO = useStore($imageDTO);
|
||||
|
||||
return imageDTO;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user