From 01784fb3bf2da1cb9ff5082b382ef0cc1581ffb3 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 11 Jun 2025 11:44:31 +1000 Subject: [PATCH] feat(ui): store output image DTO in session context instead of just the name --- .../components/SimpleSession/context.tsx | 46 ++++++++++++------- .../components/SimpleSession/shared.ts | 15 +++--- 2 files changed, 39 insertions(+), 22 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/context.tsx b/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/context.tsx index f89d76fe8d..0f79457f0d 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/context.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/context.tsx @@ -8,8 +8,9 @@ import type { Atom, WritableAtom } from 'nanostores'; import { atom, computed, effect } from 'nanostores'; import type { PropsWithChildren } from 'react'; import { createContext, memo, useCallback, useContext, useEffect, useMemo, useState } from 'react'; +import { getImageDTOSafe } from 'services/api/endpoints/images'; import { queueApi } from 'services/api/endpoints/queue'; -import type { S } from 'services/api/types'; +import type { ImageDTO, S } from 'services/api/types'; import { $socket } from 'services/events/stores'; import { assert } from 'tsafe'; import { z } from 'zod'; @@ -21,14 +22,14 @@ export type ProgressData = { itemId: number; progressEvent: S['InvocationProgressEvent'] | null; progressImage: ProgressImage | null; - outputImageName: string | null; + imageDTO: ImageDTO | null; }; const getInitialProgressData = (itemId: number): ProgressData => ({ itemId, progressEvent: null, progressImage: null, - outputImageName: null, + imageDTO: null, }); export const useProgressData = ( @@ -74,7 +75,7 @@ const setProgress = ($progressData: WritableAtom>, itemId: data.item_id, progressEvent: data, progressImage: data.image ?? null, - outputImageName: null, + imageDTO: null, }, }); } @@ -89,7 +90,7 @@ type CanvasSessionContextValue = { $selectedItemId: WritableAtom; $selectedItem: Atom; $selectedItemIndex: Atom; - $selectedItemOutputImageName: Atom; + $selectedItemOutputImageDTO: Atom; $autoSwitch: WritableAtom; $lastLoadedItemId: WritableAtom; selectNext: () => void; @@ -187,7 +188,7 @@ export const CanvasSessionContextProvider = memo( * The currently selected queue item's output image name, or null if one is not selected or there is no output * image recorded. */ - const $selectedItemOutputImageName = useState(() => + const $selectedItemOutputImageDTO = useState(() => computed([$selectedItemId, $progressData], (selectedItemId, progressData) => { if (selectedItemId === null) { return null; @@ -196,7 +197,7 @@ export const CanvasSessionContextProvider = memo( if (!datum) { return null; } - return datum.outputImageName; + return datum.imageDTO; }) )[0]; @@ -328,7 +329,7 @@ export const CanvasSessionContextProvider = memo( }); // Clean up the progress data when a queue item is discarded. - const unsubCleanUpProgressData = $items.listen((items) => { + const unsubCleanUpProgressData = $items.listen(async (items) => { const progressData = $progressData.get(); const toDelete: number[] = []; @@ -343,7 +344,7 @@ export const CanvasSessionContextProvider = memo( ...datum, progressEvent: null, progressImage: null, - outputImageName: null, + imageDTO: null, }; } } @@ -352,21 +353,34 @@ export const CanvasSessionContextProvider = memo( const datum = progressData[item.item_id]; if (datum) { - if (datum.outputImageName) { + if (datum.imageDTO) { continue; } const outputImageName = getOutputImageName(item); if (!outputImageName) { continue; } + const imageDTO = await getImageDTOSafe(outputImageName); + if (!imageDTO) { + continue; + } toUpdate.push({ ...datum, - outputImageName, + imageDTO, }); } else { - const _datum = getInitialProgressData(item.item_id); - _datum.outputImageName = getOutputImageName(item); - toUpdate.push(_datum); + const outputImageName = getOutputImageName(item); + if (!outputImageName) { + continue; + } + const imageDTO = await getImageDTOSafe(outputImageName); + if (!imageDTO) { + continue; + } + toUpdate.push({ + ...getInitialProgressData(item.item_id), + imageDTO, + }); } } @@ -435,7 +449,7 @@ export const CanvasSessionContextProvider = memo( $selectedItem, $selectedItemIndex, $lastLoadedItemId, - $selectedItemOutputImageName, + $selectedItemOutputImageDTO, $itemCount, selectNext, selectPrev, @@ -452,7 +466,7 @@ export const CanvasSessionContextProvider = memo( $selectedItemId, $selectedItemIndex, session, - $selectedItemOutputImageName, + $selectedItemOutputImageDTO, $itemCount, selectNext, selectPrev, diff --git a/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/shared.ts b/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/shared.ts index d8b7ebc7b1..c0eee11713 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/shared.ts +++ b/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/shared.ts @@ -1,9 +1,10 @@ -import { skipToken } from '@reduxjs/toolkit/query'; +import { useStore } from '@nanostores/react'; +import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context'; import { isImageField } from 'features/nodes/types/common'; import { isCanvasOutputNodeId } from 'features/nodes/util/graph/graphBuilderUtils'; import { round } from 'lodash-es'; -import { useMemo } from 'react'; -import { useGetImageDTOQuery } from 'services/api/endpoints/images'; +import { computed } from 'nanostores'; +import { useState } from 'react'; import type { S } from 'services/api/types'; import { objectEntries } from 'tsafe'; @@ -43,9 +44,11 @@ export const getOutputImageName = (item: S['SessionQueueItem']) => { }; export const useOutputImageDTO = (item: S['SessionQueueItem']) => { - const outputImageName = useMemo(() => getOutputImageName(item), [item]); - - const { currentData: imageDTO } = useGetImageDTOQuery(outputImageName ?? skipToken); + const ctx = useCanvasSessionContext(); + const $imageDTO = useState(() => + computed([ctx.$progressData], (progressData) => progressData[item.item_id]?.imageDTO ?? null) + )[0]; + const imageDTO = useStore($imageDTO); return imageDTO; };