feat(ui): modularize all staging area logic so it can be shared w/ canvas more easily

This commit is contained in:
psychedelicious
2025-06-04 23:17:33 +10:00
parent b05de8634d
commit 002816653e
10 changed files with 303 additions and 222 deletions

View File

@@ -1,20 +1,11 @@
import type { CanvasSessionContextValue } from 'features/controlLayers/components/SimpleSession/context';
import {
buildProgressDataAtom,
CanvasSessionContextProvider,
} from 'features/controlLayers/components/SimpleSession/context';
import { CanvasSessionContextProvider } from 'features/controlLayers/components/SimpleSession/context';
import { StagingArea } from 'features/controlLayers/components/SimpleSession/StagingArea';
import type { SimpleSessionIdentifier } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { memo, useMemo } from 'react';
import { memo } from 'react';
export const SimpleSession = memo(({ session }: { session: SimpleSessionIdentifier }) => {
const ctx = useMemo(
() => ({ session, $progressData: buildProgressDataAtom() }) satisfies CanvasSessionContextValue,
[session]
);
return (
<CanvasSessionContextProvider value={ctx}>
<CanvasSessionContextProvider session={session}>
<StagingArea />
</CanvasSessionContextProvider>
);

View File

@@ -1,129 +1,36 @@
/* eslint-disable i18next/no-literal-string */
import { Divider, Flex, Text } from '@invoke-ai/ui-library';
import { Divider, Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { getQueueItemElementId } from 'features/controlLayers/components/SimpleSession/shared';
import { StagingAreaContent } from 'features/controlLayers/components/SimpleSession/StagingAreaContent';
import { StagingAreaHeader } from 'features/controlLayers/components/SimpleSession/StagingAreaHeader';
import { StagingAreaNoItems } from 'features/controlLayers/components/SimpleSession/StagingAreaNoItems';
import { useProgressEvents } from 'features/controlLayers/components/SimpleSession/use-progress-events';
import { useStagingAreaKeyboardNav } from 'features/controlLayers/components/SimpleSession/use-staging-keyboard-nav';
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
import { useListAllQueueItemsQuery } from 'services/api/endpoints/queue';
import type { S } from 'services/api/types';
import { $socket, setProgress } from 'services/events/stores';
const LIST_ALL_OPTIONS = {
selectFromResult: ({ data }) => {
if (!data) {
return { items: EMPTY_ARRAY };
}
return { items: data.filter(({ status }) => status !== 'canceled') };
},
} satisfies Parameters<typeof useListAllQueueItemsQuery>[1];
import { memo, useEffect } from 'react';
export const StagingArea = memo(() => {
const ctx = useCanvasSessionContext();
const [selectedItemId, setSelectedItemId] = useState<number | null>(null);
const [autoSwitch, setAutoSwitch] = useState(true);
const { items } = useListAllQueueItemsQuery({ destination: ctx.session.id }, LIST_ALL_OPTIONS);
const selectedItem = useMemo(() => {
if (items.length === 0) {
return null;
}
if (selectedItemId === null) {
return null;
}
return items.find(({ item_id }) => item_id === selectedItemId) ?? null;
}, [items, selectedItemId]);
const selectedItemIndex = useMemo(() => {
if (items.length === 0) {
return null;
}
if (selectedItemId === null) {
return null;
}
return items.findIndex(({ item_id }) => item_id === selectedItemId) ?? null;
}, [items, selectedItemId]);
const onSelectItemId = useCallback((item_id: number | null) => {
setSelectedItemId(item_id);
if (item_id !== null) {
document.getElementById(getQueueItemElementId(item_id))?.scrollIntoView();
}
}, []);
useStagingAreaKeyboardNav(items, selectedItemId, onSelectItemId);
const hasItems = useStore(ctx.$hasItems);
useProgressEvents();
useStagingAreaKeyboardNav();
useEffect(() => {
if (items.length === 0) {
onSelectItemId(null);
return;
}
if (selectedItemId === null && items.length > 0) {
onSelectItemId(items[0]?.item_id ?? null);
return;
}
}, [items, onSelectItemId, selectedItem, selectedItemId]);
const socket = useStore($socket);
useEffect(() => {
if (!socket) {
return;
}
const onQueueItemStatusChanged = (data: S['QueueItemStatusChangedEvent']) => {
if (data.destination !== ctx.session.id) {
return;
return ctx.$selectedItemId.listen((id) => {
if (id !== null) {
document.getElementById(getQueueItemElementId(id))?.scrollIntoView();
}
if (data.status === 'in_progress' && autoSwitch) {
onSelectItemId(data.item_id);
}
};
socket.on('queue_item_status_changed', onQueueItemStatusChanged);
return () => {
socket.off('queue_item_status_changed', onQueueItemStatusChanged);
};
}, [autoSwitch, ctx.$progressData, ctx.session.id, onSelectItemId, socket]);
useEffect(() => {
if (!socket) {
return;
}
const onProgress = (data: S['InvocationProgressEvent']) => {
if (data.destination !== ctx.session.id) {
return;
}
setProgress(ctx.$progressData, data);
};
socket.on('invocation_progress', onProgress);
return () => {
socket.off('invocation_progress', onProgress);
};
}, [ctx.$progressData, ctx.session.id, socket]);
});
}, [ctx.$selectedItemId]);
return (
<Flex flexDir="column" gap={2} w="full" h="full" minW={0} minH={0}>
<StagingAreaHeader autoSwitch={autoSwitch} setAutoSwitch={setAutoSwitch} />
<StagingAreaHeader />
<Divider />
{items.length > 0 && (
<StagingAreaContent
items={items}
selectedItem={selectedItem}
selectedItemId={selectedItemId}
selectedItemIndex={selectedItemIndex}
onChangeAutoSwitch={setAutoSwitch}
onSelectItemId={onSelectItemId}
/>
)}
{items.length === 0 && (
<Flex w="full" h="full" alignItems="center" justifyContent="center">
<Text>No generations</Text>
</Flex>
)}
{hasItems && <StagingAreaContent />}
{!hasItems && <StagingAreaNoItems />}
</Flex>
);
});

View File

@@ -1,58 +1,20 @@
/* eslint-disable i18next/no-literal-string */
import { Divider, Flex, Text } from '@invoke-ai/ui-library';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { QueueItemPreviewFull } from 'features/controlLayers/components/SimpleSession/QueueItemPreviewFull';
import { QueueItemPreviewMini } from 'features/controlLayers/components/SimpleSession/QueueItemPreviewMini';
import { Divider, Flex } from '@invoke-ai/ui-library';
import { StagingAreaItemsList } from 'features/controlLayers/components/SimpleSession/StagingAreaItemsList';
import { StagingAreaSelectedItem } from 'features/controlLayers/components/SimpleSession/StagingAreaSelectedItem';
import { memo } from 'react';
import type { S } from 'services/api/types';
export const StagingAreaContent = memo(
({
items,
selectedItem,
selectedItemId,
selectedItemIndex,
onChangeAutoSwitch,
onSelectItemId,
}: {
items: S['SessionQueueItem'][];
selectedItem: S['SessionQueueItem'] | null;
selectedItemId: number | null;
selectedItemIndex: number | null;
onChangeAutoSwitch: (autoSwitch: boolean) => void;
onSelectItemId: (itemId: number) => void;
}) => {
return (
<>
<Flex position="relative" w="full" h="full" maxH="full" alignItems="center" justifyContent="center" minH={0}>
{selectedItem && selectedItemIndex !== null && (
<QueueItemPreviewFull
key={`${selectedItem.item_id}-full`}
item={selectedItem}
number={selectedItemIndex + 1}
/>
)}
{!selectedItem && <Text>No generation selected</Text>}
</Flex>
<Divider />
<Flex position="relative" maxW="full" w="full" h={108}>
<ScrollableContent overflowX="scroll" overflowY="hidden">
<Flex gap={2} w="full" h="full">
{items.map((item, i) => (
<QueueItemPreviewMini
key={`${item.item_id}-mini`}
item={item}
number={i + 1}
isSelected={selectedItemId === item.item_id}
onSelectItemId={onSelectItemId}
onChangeAutoSwitch={onChangeAutoSwitch}
/>
))}
</Flex>
</ScrollableContent>
</Flex>
</>
);
}
);
export const StagingAreaContent = memo(() => {
return (
<>
<Flex position="relative" w="full" h="full" maxH="full" alignItems="center" justifyContent="center" minH={0}>
<StagingAreaSelectedItem />
</Flex>
<Divider />
<Flex position="relative" maxW="full" w="full" h={108}>
<StagingAreaItemsList />
</Flex>
</>
);
});
StagingAreaContent.displayName = 'StagingAreaContent';

View File

@@ -1,40 +1,42 @@
/* eslint-disable i18next/no-literal-string */
import { Button, Flex, FormControl, FormLabel, Spacer, Switch, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { canvasSessionStarted } from 'features/controlLayers/store/canvasStagingAreaSlice';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
export const StagingAreaHeader = memo(
({ autoSwitch, setAutoSwitch }: { autoSwitch: boolean; setAutoSwitch: (autoSwitch: boolean) => void }) => {
const dispatch = useAppDispatch();
export const StagingAreaHeader = memo(() => {
const ctx = useCanvasSessionContext();
const autoSwitch = useStore(ctx.$autoSwitch);
const dispatch = useAppDispatch();
const startOver = useCallback(() => {
dispatch(canvasSessionStarted({ sessionType: 'simple' }));
}, [dispatch]);
const startOver = useCallback(() => {
dispatch(canvasSessionStarted({ sessionType: 'simple' }));
}, [dispatch]);
const onChangeAutoSwitch = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
setAutoSwitch(e.target.checked);
},
[setAutoSwitch]
);
const onChangeAutoSwitch = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
ctx.$autoSwitch.set(e.target.checked);
},
[ctx.$autoSwitch]
);
return (
<Flex gap={2} w="full" alignItems="center">
<Text fontSize="lg" fontWeight="bold">
Generations
</Text>
<Spacer />
<FormControl w="min-content">
<FormLabel m={0}>Auto-switch</FormLabel>
<Switch size="sm" isChecked={autoSwitch} onChange={onChangeAutoSwitch} />
</FormControl>
<Button size="sm" variant="ghost" onClick={startOver}>
Start Over
</Button>
</Flex>
);
}
);
return (
<Flex gap={2} w="full" alignItems="center">
<Text fontSize="lg" fontWeight="bold">
Generations
</Text>
<Spacer />
<FormControl w="min-content">
<FormLabel m={0}>Auto-switch</FormLabel>
<Switch size="sm" isChecked={autoSwitch} onChange={onChangeAutoSwitch} />
</FormControl>
<Button size="sm" variant="ghost" onClick={startOver}>
Start Over
</Button>
</Flex>
);
});
StagingAreaHeader.displayName = 'StagingAreaHeader';

View File

@@ -0,0 +1,31 @@
/* eslint-disable i18next/no-literal-string */
import { Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { QueueItemPreviewMini } from 'features/controlLayers/components/SimpleSession/QueueItemPreviewMini';
import { memo } from 'react';
export const StagingAreaItemsList = memo(() => {
const ctx = useCanvasSessionContext();
const items = useStore(ctx.$items);
const selectedItemId = useStore(ctx.$selectedItemId);
return (
<ScrollableContent overflowX="scroll" overflowY="hidden">
<Flex gap={2} w="full" h="full">
{items.map((item, i) => (
<QueueItemPreviewMini
key={`${item.item_id}-mini`}
item={item}
number={i + 1}
isSelected={selectedItemId === item.item_id}
onSelectItemId={ctx.$selectedItemId.set}
onChangeAutoSwitch={ctx.$autoSwitch.set}
/>
))}
</Flex>
</ScrollableContent>
);
});
StagingAreaItemsList.displayName = 'StagingAreaItemsList';

View File

@@ -0,0 +1,13 @@
/* eslint-disable i18next/no-literal-string */
import { Flex, Text } from '@invoke-ai/ui-library';
import { memo } from 'react';
export const StagingAreaNoItems = memo(() => {
return (
<Flex w="full" h="full" alignItems="center" justifyContent="center">
<Text>No generations</Text>
</Flex>
);
});
StagingAreaNoItems.displayName = 'StagingAreaNoItems';

View File

@@ -0,0 +1,21 @@
/* eslint-disable i18next/no-literal-string */
import { Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { QueueItemPreviewFull } from 'features/controlLayers/components/SimpleSession/QueueItemPreviewFull';
import { memo } from 'react';
export const StagingAreaSelectedItem = memo(() => {
const ctx = useCanvasSessionContext();
const selectedItem = useStore(ctx.$selectedItem);
const selectedItemIndex = useStore(ctx.$selectedItemIndex);
if (selectedItem && selectedItemIndex !== null) {
return (
<QueueItemPreviewFull key={`${selectedItem.item_id}-full`} item={selectedItem} number={selectedItemIndex + 1} />
);
}
return <Text>No generation selected</Text>;
});
StagingAreaSelectedItem.displayName = 'StagingAreaSelectedItem';

View File

@@ -1,11 +1,16 @@
import { createSelector } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppStore } from 'app/store/nanostores/store';
import type {
AdvancedSessionIdentifier,
SimpleSessionIdentifier,
} from 'features/controlLayers/store/canvasStagingAreaSlice';
import type { ProgressImage } from 'features/nodes/types/common';
import { atom, type WritableAtom } from 'nanostores';
import type { Atom, WritableAtom } from 'nanostores';
import { atom, computed, effect } from 'nanostores';
import type { PropsWithChildren } from 'react';
import { createContext, memo, useContext, useEffect, useState } from 'react';
import { createContext, memo, useContext, useEffect, useMemo, useState } from 'react';
import { queueApi } from 'services/api/endpoints/queue';
import type { S } from 'services/api/types';
import { assert } from 'tsafe';
@@ -116,15 +121,113 @@ export const clearProgressImage = ($progressData: WritableAtom<Record<string, Pr
export type CanvasSessionContextValue = {
session: SimpleSessionIdentifier | AdvancedSessionIdentifier;
$items: Atom<S['SessionQueueItem'][]>;
$hasItems: Atom<boolean>;
$progressData: WritableAtom<Record<string, ProgressData>>;
$selectedItemId: WritableAtom<number | null>;
$selectedItem: Atom<S['SessionQueueItem'] | null>;
$selectedItemIndex: Atom<number | null>;
$autoSwitch: WritableAtom<boolean>;
};
const CanvasSessionContext = createContext<CanvasSessionContextValue | null>(null);
export const CanvasSessionContextProvider = memo(
({ value, children }: PropsWithChildren<{ value: CanvasSessionContextValue }>) => (
<CanvasSessionContext.Provider value={value}>{children}</CanvasSessionContext.Provider>
)
({ session, children }: PropsWithChildren<{ session: SimpleSessionIdentifier | AdvancedSessionIdentifier }>) => {
const store = useAppStore();
const [$items] = useState(() => atom<S['SessionQueueItem'][]>([]));
const [$hasItems] = useState(() => computed([$items], (items) => items.length > 0));
const [$autoSwitch] = useState(() => atom(true));
const [$selectedItemId] = useState(() => atom<number | null>(null));
const [$progressData] = useState(() => atom<Record<string, ProgressData>>({}));
const [$selectedItem] = useState(() =>
computed([$items, $selectedItemId], (items, selectedItemId) => {
if (items.length === 0) {
return null;
}
if (selectedItemId === null) {
return null;
}
return items.find(({ item_id }) => item_id === selectedItemId) ?? null;
})
);
const [$selectedItemIndex] = useState(() =>
computed([$items, $selectedItemId], (items, selectedItemId) => {
if (items.length === 0) {
return null;
}
if (selectedItemId === null) {
return null;
}
return items.findIndex(({ item_id }) => item_id === selectedItemId) ?? null;
})
);
const selectQueueItems = useMemo(
() =>
createSelector(
queueApi.endpoints.listAllQueueItems.select({ destination: session.id }),
({ data }) => data?.filter((item) => item.status !== 'canceled') ?? EMPTY_ARRAY
),
[session.id]
);
useEffect(() => {
$items.set(selectQueueItems(store.getState()));
const unsubReduxSyncToItemsAtom = store.subscribe(() => {
const prevItems = $items.get();
const items = selectQueueItems(store.getState());
if (items !== prevItems) {
$items.set(items);
}
});
const unsubEnsureSelectedItemIdExists = effect([$items, $selectedItemId], (items, selectedItemId) => {
if (items.length === 0) {
$selectedItemId.set(null);
return;
}
if (selectedItemId === null && items.length > 0) {
$selectedItemId.set(items[0]?.item_id ?? null);
return;
}
if (selectedItemId !== null && items.findIndex(({ item_id }) => item_id === selectedItemId) === -1) {
$selectedItemId.set(null);
return;
}
});
const { unsubscribe: unsubQueueItemsQuery } = store.dispatch(
queueApi.endpoints.listAllQueueItems.initiate({ destination: session.id })
);
return () => {
unsubQueueItemsQuery();
unsubReduxSyncToItemsAtom();
unsubEnsureSelectedItemIdExists();
$items.set([]);
$progressData.set({});
$selectedItemId.set(null);
};
}, [$items, $progressData, $selectedItemId, selectQueueItems, session.id, store]);
const value = useMemo<CanvasSessionContextValue>(
() => ({
session,
$items,
$hasItems,
$progressData,
$selectedItemId,
$autoSwitch,
$selectedItem,
$selectedItemIndex,
}),
[$autoSwitch, $hasItems, $items, $progressData, $selectedItem, $selectedItemId, $selectedItemIndex, session]
);
return <CanvasSessionContext.Provider value={value}>{children}</CanvasSessionContext.Provider>;
}
);
CanvasSessionContextProvider.displayName = 'CanvasSessionContextProvider';

View File

@@ -0,0 +1,48 @@
import { useStore } from '@nanostores/react';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { useEffect } from 'react';
import type { S } from 'services/api/types';
import { $socket, setProgress } from 'services/events/stores';
export const useProgressEvents = () => {
const ctx = useCanvasSessionContext();
const socket = useStore($socket);
useEffect(() => {
if (!socket) {
return;
}
const onQueueItemStatusChanged = (data: S['QueueItemStatusChangedEvent']) => {
if (data.destination !== ctx.session.id) {
return;
}
if (data.status === 'completed' && ctx.$autoSwitch.get()) {
ctx.$selectedItemId.set(data.item_id);
}
};
socket.on('queue_item_status_changed', onQueueItemStatusChanged);
return () => {
socket.off('queue_item_status_changed', onQueueItemStatusChanged);
};
}, [ctx.$autoSwitch, ctx.$progressData, ctx.$selectedItemId, ctx.session.id, socket]);
useEffect(() => {
if (!socket) {
return;
}
const onProgress = (data: S['InvocationProgressEvent']) => {
if (data.destination !== ctx.session.id) {
return;
}
// TODO: clear progress when done w/ it memory leak
setProgress(ctx.$progressData, data);
};
socket.on('invocation_progress', onProgress);
return () => {
socket.off('invocation_progress', onProgress);
};
}, [ctx.$progressData, ctx.session.id, socket]);
};

View File

@@ -1,51 +1,54 @@
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import type { S } from 'services/api/types';
export const useStagingAreaKeyboardNav = (
items: S['SessionQueueItem'][],
selectedItemId: number | null,
onSelectItemId: (item_id: number) => void
) => {
export const useStagingAreaKeyboardNav = () => {
const ctx = useCanvasSessionContext();
const onNext = useCallback(() => {
const selectedItemId = ctx.$selectedItemId.get();
if (selectedItemId === null) {
return;
}
const items = ctx.$items.get();
const currentIndex = items.findIndex((item) => item.item_id === selectedItemId);
const nextIndex = (currentIndex + 1) % items.length;
const nextItem = items[nextIndex];
if (!nextItem) {
return;
}
onSelectItemId(nextItem.item_id);
}, [items, onSelectItemId, selectedItemId]);
ctx.$selectedItemId.set(nextItem.item_id);
}, [ctx.$items, ctx.$selectedItemId]);
const onPrev = useCallback(() => {
const selectedItemId = ctx.$selectedItemId.get();
if (selectedItemId === null) {
return;
}
const items = ctx.$items.get();
const currentIndex = items.findIndex((item) => item.item_id === selectedItemId);
const prevIndex = (currentIndex - 1 + items.length) % items.length;
const prevItem = items[prevIndex];
if (!prevItem) {
return;
}
onSelectItemId(prevItem.item_id);
}, [items, onSelectItemId, selectedItemId]);
ctx.$selectedItemId.set(prevItem.item_id);
}, [ctx.$items, ctx.$selectedItemId]);
const onFirst = useCallback(() => {
const items = ctx.$items.get();
const first = items.at(0);
if (!first) {
return;
}
onSelectItemId(first.item_id);
}, [items, onSelectItemId]);
ctx.$selectedItemId.set(first.item_id);
}, [ctx.$items, ctx.$selectedItemId]);
const onLast = useCallback(() => {
const items = ctx.$items.get();
const last = items.at(-1);
if (!last) {
return;
}
onSelectItemId(last.item_id);
}, [items, onSelectItemId]);
ctx.$selectedItemId.set(last.item_id);
}, [ctx.$items, ctx.$selectedItemId]);
useHotkeys('left', onPrev, { preventDefault: true });
useHotkeys('right', onNext, { preventDefault: true });