feat: canvas flow rework (wip)

This commit is contained in:
psychedelicious
2025-06-03 13:45:01 +10:00
parent 5e93f58530
commit 8a78e37634
18 changed files with 645 additions and 268 deletions

View File

@@ -20,7 +20,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
RetryItemsResult,
SessionQueueCountsByDestination,
SessionQueueItem,
SessionQueueItemDTO,
SessionQueueStatus,
)
from invokeai.app.services.shared.pagination import CursorPaginatedResults
@@ -68,7 +67,7 @@ async def enqueue_batch(
"/{queue_id}/list",
operation_id="list_queue_items",
responses={
200: {"model": CursorPaginatedResults[SessionQueueItemDTO]},
200: {"model": CursorPaginatedResults[SessionQueueItem]},
},
)
async def list_queue_items(
@@ -77,11 +76,38 @@ async def list_queue_items(
status: Optional[QUEUE_ITEM_STATUS] = Query(default=None, description="The status of items to fetch"),
cursor: Optional[int] = Query(default=None, description="The pagination cursor"),
priority: int = Query(default=0, description="The pagination cursor priority"),
) -> CursorPaginatedResults[SessionQueueItemDTO]:
"""Gets all queue items (without graphs)"""
destination: Optional[str] = Query(default=None, description="The destination of queue items to fetch"),
) -> CursorPaginatedResults[SessionQueueItem]:
"""Gets cursor-paginated queue items"""
return ApiDependencies.invoker.services.session_queue.list_queue_items(
queue_id=queue_id, limit=limit, status=status, cursor=cursor, priority=priority
queue_id=queue_id,
limit=limit,
status=status,
cursor=cursor,
priority=priority,
destination=destination,
)
@session_queue_router.get(
"/{queue_id}/all",
operation_id="list_all_queue_items",
responses={
200: {"model": list[SessionQueueItem]},
},
)
async def list_all_queue_items(
queue_id: str = Path(description="The queue id to perform this operation on"),
status: Optional[QUEUE_ITEM_STATUS] = Query(default=None, description="The status of items to fetch"),
destination: Optional[str] = Query(default=None, description="The destination of queue items to fetch"),
) -> list[SessionQueueItem]:
"""Gets all queue items"""
return ApiDependencies.invoker.services.session_queue.list_all_queue_items(
queue_id=queue_id,
status=status,
destination=destination,
)

View File

@@ -17,7 +17,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
RetryItemsResult,
SessionQueueCountsByDestination,
SessionQueueItem,
SessionQueueItemDTO,
SessionQueueStatus,
)
from invokeai.app.services.shared.graph import GraphExecutionState
@@ -127,10 +126,21 @@ class SessionQueueBase(ABC):
priority: int,
cursor: Optional[int] = None,
status: Optional[QUEUE_ITEM_STATUS] = None,
) -> CursorPaginatedResults[SessionQueueItemDTO]:
destination: Optional[str] = None,
) -> CursorPaginatedResults[SessionQueueItem]:
"""Gets a page of session queue items"""
pass
@abstractmethod
def list_all_queue_items(
self,
queue_id: str,
status: Optional[QUEUE_ITEM_STATUS] = None,
destination: Optional[str] = None,
) -> list[SessionQueueItem]:
"""Gets all queue items that match the given parameters"""
pass
@abstractmethod
def get_queue_item(self, item_id: int) -> SessionQueueItem:
"""Gets a session queue item by ID"""

View File

@@ -208,7 +208,7 @@ class FieldIdentifier(BaseModel):
user_label: str | None = Field(description="The user label of the field, if any")
class SessionQueueItemWithoutGraph(BaseModel):
class SessionQueueItem(BaseModel):
"""Session queue item without the full graph. Used for serialization."""
item_id: int = Field(description="The identifier of the session queue item")
@@ -252,42 +252,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
default=None,
description="The ID of the published workflow associated with this queue item",
)
api_input_fields: Optional[list[FieldIdentifier]] = Field(
default=None, description="The fields that were used as input to the API"
)
api_output_fields: Optional[list[FieldIdentifier]] = Field(
default=None, description="The nodes that were used as output from the API"
)
credits: Optional[float] = Field(default=None, description="The total credits used for this queue item")
@classmethod
def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
# must parse these manually
queue_item_dict["field_values"] = get_field_values(queue_item_dict)
return SessionQueueItemDTO(**queue_item_dict)
model_config = ConfigDict(
json_schema_extra={
"required": [
"item_id",
"status",
"batch_id",
"queue_id",
"session_id",
"priority",
"session_id",
"created_at",
"updated_at",
]
}
)
class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
pass
class SessionQueueItem(SessionQueueItemWithoutGraph):
session: GraphExecutionState = Field(description="The fully-populated session to be executed")
workflow: Optional[WorkflowWithoutID] = Field(
default=None, description="The workflow associated with this queue item"

View File

@@ -24,7 +24,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
RetryItemsResult,
SessionQueueCountsByDestination,
SessionQueueItem,
SessionQueueItemDTO,
SessionQueueItemNotFoundError,
SessionQueueStatus,
ValueToInsertTuple,
@@ -540,26 +539,12 @@ class SqliteSessionQueue(SessionQueueBase):
priority: int,
cursor: Optional[int] = None,
status: Optional[QUEUE_ITEM_STATUS] = None,
) -> CursorPaginatedResults[SessionQueueItemDTO]:
destination: Optional[str] = None,
) -> CursorPaginatedResults[SessionQueueItem]:
cursor_ = self._conn.cursor()
item_id = cursor
query = """--sql
SELECT item_id,
status,
priority,
field_values,
error_type,
error_message,
error_traceback,
created_at,
updated_at,
completed_at,
started_at,
session_id,
batch_id,
queue_id,
origin,
destination
SELECT *
FROM session_queue
WHERE queue_id = ?
"""
@@ -571,6 +556,12 @@ class SqliteSessionQueue(SessionQueueBase):
"""
params.append(status)
if destination is not None:
query += """---sql
AND destination = ?
"""
params.append(destination)
if item_id is not None:
query += """--sql
AND (priority < ?) OR (priority = ? AND item_id > ?)
@@ -586,7 +577,7 @@ class SqliteSessionQueue(SessionQueueBase):
params.append(limit + 1)
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
items = [SessionQueueItemDTO.queue_item_dto_from_dict(dict(result)) for result in results]
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
has_more = False
if len(items) > limit:
# remove the extra item
@@ -594,6 +585,44 @@ class SqliteSessionQueue(SessionQueueBase):
has_more = True
return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)
def list_all_queue_items(
self,
queue_id: str,
status: Optional[QUEUE_ITEM_STATUS] = None,
destination: Optional[str] = None,
) -> list[SessionQueueItem]:
"""Gets all queue items that match the given parameters"""
cursor_ = self._conn.cursor()
query = """--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
"""
params: list[Union[str, int]] = [queue_id]
if status is not None:
query += """--sql
AND status = ?
"""
params.append(status)
if destination is not None:
query += """---sql
AND destination = ?
"""
params.append(destination)
query += """--sql
ORDER BY
priority DESC,
item_id ASC
;
"""
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
return items
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
cursor = self._conn.cursor()
cursor.execute(

View File

@@ -7,6 +7,7 @@ from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type
import networkx as nx
from pydantic import (
BaseModel,
ConfigDict,
GetCoreSchemaHandler,
GetJsonSchemaHandler,
ValidationError,
@@ -787,6 +788,22 @@ class GraphExecutionState(BaseModel):
default_factory=dict,
)
model_config = ConfigDict(
json_schema_extra={
"required": [
"id",
"graph",
"execution_graph",
"executed",
"executed_history",
"results",
"errors",
"prepared_source_mapping",
"source_prepared_mapping",
]
}
)
@field_validator("graph")
def graph_is_valid(cls, v: Graph):
"""Validates that the graph is valid"""

View File

@@ -3,7 +3,6 @@ import { autoBatchEnhancer, combineReducers, configureStore } from '@reduxjs/too
import { logger } from 'app/logging/logger';
import { idbKeyValDriver } from 'app/store/enhancers/reduxRemember/driver';
import { errorHandler } from 'app/store/enhancers/reduxRemember/errors';
import { getDebugLoggerMiddleware } from 'app/store/middleware/debugLoggerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice';
import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
@@ -176,7 +175,7 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
.concat(api.middleware)
.concat(dynamicMiddlewares)
.concat(authToastMiddleware)
.concat(getDebugLoggerMiddleware())
// .concat(getDebugLoggerMiddleware())
.prepend(listenerMiddleware.middleware),
enhancers: (getDefaultEnhancers) => {
const _enhancers = getDefaultEnhancers().concat(autoBatchEnhancer());

View File

@@ -1,10 +1,14 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
/* eslint-disable i18next/no-literal-string */
import type { FlexProps, SystemStyleObject } from '@invoke-ai/ui-library';
import {
Box,
Button,
ButtonGroup,
CircularProgress,
ContextMenu,
Flex,
FormControl,
FormLabel,
Heading,
IconButton,
Image,
@@ -12,9 +16,13 @@ import {
MenuButton,
MenuList,
Spacer,
Switch,
Text,
Tooltip,
} from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/query';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppStore } from 'app/store/nanostores/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { FocusRegionWrapper } from 'common/components/FocusRegionWrapper';
@@ -49,15 +57,21 @@ import { newCanvasFromImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { DndImage } from 'features/dnd/DndImage';
import { newCanvasFromImage } from 'features/imageActions/actions';
import { memo, useCallback, useEffect, useMemo } from 'react';
import type { ProgressImage } from 'features/nodes/types/common';
import { isImageField } from 'features/nodes/types/common';
import { isCanvasOutputNodeId } from 'features/nodes/util/graph/graphBuilderUtils';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { Trans, useTranslation } from 'react-i18next';
import { PiDotsThreeOutlineVerticalFill, PiUploadBold } from 'react-icons/pi';
import type { ImageDTO } from 'services/api/types';
import { getImageDTOSafe, useGetImageDTOQuery } from 'services/api/endpoints/images';
import { queueItemsAdapterSelectors, useListQueueItemsQuery } from 'services/api/endpoints/queue';
import type { ImageDTO, S } from 'services/api/types';
import type { ProgressAndResult } from 'services/events/stores';
import { $progressImages, useMapSelector } from 'services/events/stores';
import { $progressImages, $socket, useMapSelector } from 'services/events/stores';
import type { Equals, Param0 } from 'tsafe';
import { assert } from 'tsafe';
import { assert, objectEntries } from 'tsafe';
import { CanvasAlertsInvocationProgress } from './CanvasAlerts/CanvasAlertsInvocationProgress';
@@ -311,20 +325,405 @@ const SimpleActiveSession = memo(() => {
<Flex flexDir="column" w="full" h="full" alignItems="center" justifyContent="center" gap={2}>
<Flex w="full">
<Text fontSize="lg" fontWeight="bold">
Generations from this Session
Generations
</Text>
<Spacer />
<Button size="sm" onClick={startOver}>
<Button size="sm" variant="ghost" onClick={startOver}>
Start Over
</Button>
</Flex>
<SelectedImageOrProgressImage />
<SessionImages />
<StagingArea />
</Flex>
);
});
SimpleActiveSession.displayName = 'SimpleActiveSession';
const scrollIndicatorSx = {
opacity: 0,
'&[data-visible="true"]': {
opacity: 1,
},
} satisfies SystemStyleObject;
const StagingArea = memo(() => {
const [selectedItemId, setSelectedItemId] = useState<number | null>(null);
const [autoSwitch, setAutoSwitch] = useState(true);
const [canScrollLeft, setCanScrollLeft] = useState(false);
const [canScrollRight, setCanScrollRight] = useState(false);
const scrollableRef = useRef<HTMLDivElement>(null);
const { data } = useListQueueItemsQuery({ destination: 'canvas' });
const items = useMemo(() => {
if (!data) {
return EMPTY_ARRAY;
}
return queueItemsAdapterSelectors.selectAll(data);
}, [data]);
const selectedItem = useMemo(
() =>
data && selectedItemId !== null ? queueItemsAdapterSelectors.selectById(data, String(selectedItemId)) : null,
[data, selectedItemId]
);
useEffect(() => {
if (items.length === 0) {
setSelectedItemId(null);
return;
}
if (selectedItem === null && items.length > 0) {
setSelectedItemId(items[0]?.item_id ?? null);
return;
}
if (selectedItemId === null || items.find((item) => item.item_id === selectedItemId) === undefined) {
return;
}
document.getElementById(`queue-item-status-card-${selectedItemId}`)?.scrollIntoView();
}, [items, selectedItem, selectedItemId]);
useEffect(() => {
const el = scrollableRef.current;
if (!el) {
return;
}
const onScroll = () => {
const { scrollLeft, scrollWidth, clientWidth } = el;
setCanScrollLeft(scrollLeft > 0);
setCanScrollRight(scrollLeft + clientWidth < scrollWidth);
};
el.addEventListener('scroll', onScroll);
const observer = new ResizeObserver(onScroll);
observer.observe(el);
return () => {
el.removeEventListener('scroll', onScroll);
observer.disconnect();
};
}, []);
const onSelectItem = useCallback((item: S['SessionQueueItem']) => {
setSelectedItemId(item.item_id);
if (item.status !== 'in_progress') {
setAutoSwitch(false);
}
}, []);
const onNext = useCallback(() => {
if (selectedItemId === null) {
return;
}
const currentIndex = items.findIndex((item) => item.item_id === selectedItemId);
const nextIndex = (currentIndex + 1) % items.length;
const nextItem = items[nextIndex];
if (!nextItem) {
return;
}
setSelectedItemId(nextItem.item_id);
}, [items, selectedItemId]);
const onPrev = useCallback(() => {
if (selectedItemId === null) {
return;
}
const currentIndex = items.findIndex((item) => item.item_id === selectedItemId);
const prevIndex = (currentIndex - 1 + items.length) % items.length;
const prevItem = items[prevIndex];
if (!prevItem) {
return;
}
setSelectedItemId(prevItem.item_id);
}, [items, selectedItemId]);
useHotkeys('left', onPrev);
useHotkeys('right', onNext);
const socket = useStore($socket);
useEffect(() => {
if (!autoSwitch) {
return;
}
if (!socket) {
return;
}
const onQueueItemStatusChanged = (data: S['QueueItemStatusChangedEvent']) => {
if (data.destination !== 'canvas') {
return;
}
if (data.status === 'in_progress') {
setSelectedItemId(data.item_id);
}
};
socket.on('queue_item_status_changed', onQueueItemStatusChanged);
return () => {
socket.off('queue_item_status_changed', onQueueItemStatusChanged);
};
}, [autoSwitch, socket]);
const onChangeAutoSwitch = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setAutoSwitch(e.target.checked);
}, []);
return (
<Flex position="relative" flexDir="column" gap={2} w="full" h="full" minW={0} minH={0}>
<Flex w="full" h="full" alignItems="center" justifyContent="center" minW={0} minH={0}>
{selectedItem && <QueueItemStatusCard item={selectedItem} minW={0} minH={0} h="full" isSelected={false} />}
{!selectedItem && <Text>No queued generations</Text>}
</Flex>
<FormControl position="absolute" top={2} right={2} w="min-content">
<FormLabel m={0}>Auto-switch</FormLabel>
<Switch size="sm" isChecked={autoSwitch} onChange={onChangeAutoSwitch} />
</FormControl>
<Flex position="relative" w="full" maxW="full">
<Flex ref={scrollableRef} gap={2} h={108} maxW="full" overflowX="scroll" flexShrink={0}>
{items.map((item, i) => (
<QueueItemStatusCard
id={`queue-item-status-card-${item.item_id}`}
key={item.item_id}
item={item}
number={i + 1}
onSelectItem={onSelectItem}
isSelected={selectedItemId === item.item_id}
w={108}
h={108}
flexShrink={0}
/>
))}
</Flex>
<Box
position="absolute"
sx={scrollIndicatorSx}
left={0}
w={16}
h="full"
bg="linear-gradient(to right, var(--invoke-colors-base-900), transparent)"
data-visible={canScrollLeft}
transitionProperty="opacity"
transitionDuration="0.3s"
pointerEvents="none"
/>
<Box
position="absolute"
sx={scrollIndicatorSx}
right={0}
w={16}
h="full"
bg="linear-gradient(to left, var(--invoke-colors-base-900), transparent)"
data-visible={canScrollRight}
transitionProperty="opacity"
transitionDuration="0.3s"
pointerEvents="none"
/>
</Flex>
</Flex>
);
});
StagingArea.displayName = 'StagingArea';
const IMAGE_DTO_ERROR = Symbol('IMAGE_DTO_ERROR');
const useOutputImageDTO = (item: S['SessionQueueItem']) => {
const [imageDTO, setImageDTO] = useState<ImageDTO | typeof IMAGE_DTO_ERROR | null>(null);
const syncImageDTO = useCallback(async (item: S['SessionQueueItem']) => {
const nodeId = Object.entries(item.session.source_prepared_mapping).find(([nodeId]) =>
isCanvasOutputNodeId(nodeId)
)?.[1][0];
const output = nodeId ? item.session.results[nodeId] : undefined;
if (!output) {
return setImageDTO(null);
}
for (const [_name, value] of objectEntries(output)) {
if (isImageField(value)) {
const imageDTO = await getImageDTOSafe(value.image_name);
if (imageDTO) {
setImageDTO(imageDTO);
$progressImages.setKey(item.session_id, undefined);
return;
}
}
}
setImageDTO(IMAGE_DTO_ERROR);
}, []);
useEffect(() => {
syncImageDTO(item);
}, [item, syncImageDTO]);
return imageDTO;
};
const QueueItemStatusCard = memo(
({
item,
isSelected,
number,
onSelectItem,
...rest
}: {
item: S['SessionQueueItem'];
isSelected: boolean;
number?: number;
onSelectItem?: (item: S['SessionQueueItem']) => void;
} & FlexProps) => {
const onClick = useCallback(() => {
onSelectItem?.(item);
}, [item, onSelectItem]);
return (
<Flex
role="button"
pos="relative"
borderWidth={1}
borderRadius="base"
alignItems="center"
justifyContent="center"
overflow="hidden"
onClick={onClick}
aspectRatio="1/1"
borderColor={isSelected ? 'invokeBlue.300' : undefined}
{...rest}
>
<QueueItemStatusCardContent item={item} />
{number !== undefined && <Text position="absolute" top={0} left={1}>{`#${number}`}</Text>}
</Flex>
);
}
);
QueueItemStatusCard.displayName = 'QueueItemStatusCard';
const QueueItemStatusCardContent = memo(({ item }: { item: S['SessionQueueItem'] }) => {
const socket = useStore($socket);
const [progressEvent, setProgressEvent] = useState<S['InvocationProgressEvent'] | null>(null);
const [progressImage, setProgressImage] = useState<ProgressImage | null>(null);
useEffect(() => {
if (!socket) {
return;
}
const onProgress = (data: S['InvocationProgressEvent']) => {
if (data.session_id !== item.session_id) {
return;
}
setProgressEvent(data);
if (data.image) {
setProgressImage(data.image);
}
};
socket.on('invocation_progress', onProgress);
return () => {
socket.off('invocation_progress', onProgress);
};
}, [item.session_id, socket]);
const imageDTO = useOutputImageDTO(item);
if (item.status === 'pending') {
return (
<Text fontWeight="semibold" color="base.300">
Pending
</Text>
);
}
if (item.status === 'canceled') {
return (
<Text fontWeight="semibold" color="warning.300">
Canceled
</Text>
);
}
if (item.status === 'failed') {
return (
<Text fontWeight="semibold" color="error.300">
Failed
</Text>
);
}
if (item.status === 'in_progress' || !imageDTO) {
if (!progressImage) {
return (
<>
<Text fontWeight="semibold" color="invokeBlue.300">
In Progress
</Text>
<ProgressCircle data={progressEvent} />
</>
);
}
return (
<>
<Image objectFit="contain" maxH="full" maxW="full" src={progressImage.dataURL} width={progressImage.width} />
<ProgressCircle data={progressEvent} />
</>
);
}
if (item.status === 'completed' && imageDTO && imageDTO !== IMAGE_DTO_ERROR) {
return <Image objectFit="contain" maxH="full" maxW="full" src={imageDTO.image_url} width={imageDTO.width} />;
}
if (item.status === 'completed') {
return (
<Text fontWeight="semibold" color="error.300">
Unable to get image
</Text>
);
}
assert<Equals<never, typeof item.status>>(false);
});
QueueItemStatusCardContent.displayName = 'QueueItemStatusCardContent';
const circleStyles: SystemStyleObject = {
circle: {
transitionProperty: 'none',
transitionDuration: '0s',
},
position: 'absolute',
top: 2,
right: 2,
};
const ProgressCircle = ({ data }: { data?: S['InvocationProgressEvent'] | null }) => {
return (
<Tooltip label={data?.message ?? 'Generating'}>
<CircularProgress
size="14px"
color="invokeBlue.500"
thickness={14}
isIndeterminate={!data || data.percentage === null}
value={data?.percentage ? data.percentage * 100 : undefined}
sx={circleStyles}
/>
</Tooltip>
);
};
ProgressCircle.displayName = 'ProgressCircle';
const QueueItemResultCard = memo(({ item }: { item: S['SessionQueueItem'] }) => {
const imageName = useMemo(() => {
const nodeId = Object.entries(item.session.source_prepared_mapping).find(([nodeId]) =>
isCanvasOutputNodeId(nodeId)
)?.[1][0];
const output = nodeId ? item.session.results[nodeId] : undefined;
if (!output) {
return;
}
for (const [_name, value] of objectEntries(output)) {
if (isImageField(value)) {
return value.image_name;
}
}
}, [item]);
const { data: imageDTO } = useGetImageDTOQuery(imageName ?? skipToken);
if (!imageDTO) {
return <Text>Unknown output type</Text>;
}
return <Image objectFit="contain" maxH="full" maxW="full" src={imageDTO.image_url} width={imageDTO.width} />;
});
QueueItemResultCard.displayName = 'QueueItemResultCard';
const SelectedImageOrProgressImage = memo(() => {
const selectedImage = useAppSelector(selectSelectedImage);

View File

@@ -159,6 +159,8 @@ export const isMainModelWithoutUnet = (modelLoader: Invocation<MainModelLoaderNo
);
};
export const isCanvasOutputNodeId = (nodeId: string) => nodeId.split(':')[0] === CANVAS_OUTPUT_PREFIX;
export const isCanvasOutputEvent = (data: S['InvocationCompleteEvent']) => {
return data.invocation_source_id.split(':')[0] === CANVAS_OUTPUT_PREFIX;
return isCanvasOutputNodeId(data.invocation_source_id);
};

View File

@@ -13,7 +13,7 @@ import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold, PiXBold } from 'react-icons/pi';
import { useSelector } from 'react-redux';
import type { SessionQueueItemDTO } from 'services/api/types';
import type { S } from 'services/api/types';
import { COLUMN_WIDTHS } from './constants';
import QueueItemDetail from './QueueItemDetail';
@@ -23,7 +23,7 @@ const selectedStyles = { bg: 'base.700' };
type InnerItemProps = {
index: number;
item: SessionQueueItemDTO;
item: S['SessionQueueItem'];
context: ListContext;
};
@@ -155,7 +155,7 @@ const QueueItemComponent = ({ index, item, context }: InnerItemProps) => {
</Flex>
<Collapse in={isOpen} transition={transition} unmountOnExit={true}>
<QueueItemDetail queueItemDTO={item} />
<QueueItemDetail queueItem={item} />
</Collapse>
</Flex>
);

View File

@@ -12,24 +12,20 @@ import type { ReactNode } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold, PiXBold } from 'react-icons/pi';
import { useGetQueueItemQuery } from 'services/api/endpoints/queue';
import type { SessionQueueItemDTO } from 'services/api/types';
import type { S } from 'services/api/types';
type Props = {
queueItemDTO: SessionQueueItemDTO;
queueItem: S['SessionQueueItem'];
};
const QueueItemComponent = ({ queueItemDTO }: Props) => {
const { session_id, batch_id, item_id, origin, destination } = queueItemDTO;
const QueueItemComponent = ({ queueItem }: Props) => {
const { session_id, batch_id, item_id, origin, destination } = queueItem;
const { t } = useTranslation();
const isRetryEnabled = useFeatureStatus('retryQueueItem');
const { cancelBatch, isLoading: isLoadingCancelBatch, isCanceled: isBatchCanceled } = useCancelBatch(batch_id);
const { cancelQueueItem, isLoading: isLoadingCancelQueueItem } = useCancelQueueItem(item_id);
const { retryQueueItem, isLoading: isLoadingRetryQueueItem } = useRetryQueueItem(item_id);
const { data: queueItem } = useGetQueueItemQuery(item_id);
const originText = useOriginText(origin);
const destinationText = useDestinationText(destination);

View File

@@ -14,7 +14,7 @@ import { useTranslation } from 'react-i18next';
import type { Components, ItemContent } from 'react-virtuoso';
import { Virtuoso } from 'react-virtuoso';
import { queueItemsAdapterSelectors, useListQueueItemsQuery } from 'services/api/endpoints/queue';
import type { SessionQueueItemDTO } from 'services/api/types';
import type { S } from 'services/api/types';
import QueueItemComponent from './QueueItemComponent';
import QueueListComponent from './QueueListComponent';
@@ -24,13 +24,13 @@ import type { ListContext } from './types';
// eslint-disable-next-line @typescript-eslint/no-explicit-any
type TableVirtuosoScrollerRef = (ref: HTMLElement | Window | null) => any;
const computeItemKey = (index: number, item: SessionQueueItemDTO): number => item.item_id;
const computeItemKey = (index: number, item: S['SessionQueueItem']): number => item.item_id;
const components: Components<SessionQueueItemDTO, ListContext> = {
const components: Components<S['SessionQueueItem'], ListContext> = {
List: QueueListComponent,
};
const itemContent: ItemContent<SessionQueueItemDTO, ListContext> = (index, item, context) => (
const itemContent: ItemContent<S['SessionQueueItem'], ListContext> = (index, item, context) => (
<QueueItemComponent index={index} item={item} context={context} />
);
@@ -114,7 +114,7 @@ const QueueList = () => {
<Flex w="full" h="full" flexDir="column">
<QueueListHeader />
<Flex ref={rootRef} w="full" h="full" alignItems="center" justifyContent="center">
<Virtuoso<SessionQueueItemDTO, ListContext>
<Virtuoso<S['SessionQueueItem'], ListContext>
data={queueItems}
endReached={handleLoadMore}
scrollerRef={setScroller as TableVirtuosoScrollerRef}

View File

@@ -1,10 +1,10 @@
import { Flex, forwardRef, typedMemo } from '@invoke-ai/ui-library';
import type { Components } from 'react-virtuoso';
import type { SessionQueueItemDTO } from 'services/api/types';
import type { S } from 'services/api/types';
import type { ListContext } from './types';
const QueueListComponent: Components<SessionQueueItemDTO, ListContext>['List'] = typedMemo(
const QueueListComponent: Components<S['SessionQueueItem'], ListContext>['List'] = typedMemo(
forwardRef((props, ref) => {
return (
<Flex {...props} ref={ref} flexDirection="column" gap={0.5}>

View File

@@ -1,7 +1,7 @@
import { useTranslation } from 'react-i18next';
import type { SessionQueueItemDTO } from 'services/api/types';
import type { S } from 'services/api/types';
export const useDestinationText = (destination: SessionQueueItemDTO['destination']) => {
export const useDestinationText = (destination: S['SessionQueueItem']['destination']) => {
const { t } = useTranslation();
if (destination === 'canvas') {

View File

@@ -1,7 +1,7 @@
import { useTranslation } from 'react-i18next';
import type { SessionQueueItemDTO } from 'services/api/types';
import type { S } from 'services/api/types';
export const useOriginText = (origin: SessionQueueItemDTO['origin']) => {
export const useOriginText = (origin: S['SessionQueueItem']['origin']) => {
const { t } = useTranslation();
if (origin === 'generation') {

View File

@@ -5,9 +5,10 @@ import { $queueId } from 'app/store/nanostores/queueId';
import { listParamsReset } from 'features/queue/store/queueSlice';
import queryString from 'query-string';
import type { components, paths } from 'services/api/schema';
import type { S } from 'services/api/types';
import type { ApiTagDescription } from '..';
import { api, buildV1Url } from '..';
import { api, buildV1Url, LIST_TAG } from '..';
/**
* Builds an endpoint URL for the queue router
@@ -35,7 +36,7 @@ export type SessionQueueItemStatus = NonNullable<
NonNullable<paths['/api/v1/queue/{queue_id}/list']['get']['parameters']['query']>['status']
>;
export const queueItemsAdapter = createEntityAdapter<components['schemas']['SessionQueueItemDTO'], string>({
export const queueItemsAdapter = createEntityAdapter<S['SessionQueueItem'], string>({
selectId: (queueItem) => String(queueItem.item_id),
sortComparer: (a, b) => {
// Sort by priority in descending order
@@ -388,10 +389,10 @@ export const queueApi = api.injectEndpoints({
invalidatesTags: ['CurrentSessionQueueItem', 'NextSessionQueueItem', 'QueueCountsByDestination'],
}),
listQueueItems: build.query<
EntityState<components['schemas']['SessionQueueItemDTO'], string> & {
EntityState<S['SessionQueueItem'], string> & {
has_more: boolean;
},
{ cursor?: number; priority?: number } | undefined
{ cursor?: number; priority?: number; destination?: string } | undefined
>({
query: (queryArgs) => ({
url: getListQueueItemsUrl(queryArgs),
@@ -400,20 +401,20 @@ export const queueApi = api.injectEndpoints({
serializeQueryArgs: () => {
return buildQueueUrl('list');
},
transformResponse: (response: components['schemas']['CursorPaginatedResults_SessionQueueItemDTO_']) =>
queueItemsAdapter.addMany(
transformResponse: (response: components['schemas']['CursorPaginatedResults_SessionQueueItem_']) =>
queueItemsAdapter.upsertMany(
queueItemsAdapter.getInitialState({
has_more: response.has_more,
}),
response.items
),
merge: (cache, response) => {
queueItemsAdapter.addMany(cache, queueItemsAdapterSelectors.selectAll(response));
queueItemsAdapter.upsertMany(cache, queueItemsAdapterSelectors.selectAll(response));
cache.has_more = response.has_more;
},
forceRefetch: ({ currentArg, previousArg }) => currentArg !== previousArg,
keepUnusedDataFor: 60 * 5, // 5 minutes
providesTags: ['FetchOnReconnect'],
providesTags: ['FetchOnReconnect', { type: 'SessionQueueItem', id: LIST_TAG }],
}),
getQueueCountsByDestination: build.query<
paths['/api/v1/queue/{queue_id}/counts_by_destination']['get']['responses']['200']['content']['application/json'],

View File

@@ -1153,7 +1153,7 @@ export type paths = {
};
/**
* List Queue Items
* @description Gets all queue items (without graphs)
* @description Gets cursor-paginated queue items
*/
get: operations["list_queue_items"];
put?: never;
@@ -1164,6 +1164,26 @@ export type paths = {
patch?: never;
trace?: never;
};
"/api/v1/queue/{queue_id}/all": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
/**
* List All Queue Items
* @description Gets all queue items
*/
get: operations["list_all_queue_items"];
put?: never;
post?: never;
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/v1/queue/{queue_id}/processor/resume": {
parameters: {
query?: never;
@@ -5715,8 +5735,8 @@ export type components = {
*/
type: "crop_latents";
};
/** CursorPaginatedResults[SessionQueueItemDTO] */
CursorPaginatedResults_SessionQueueItemDTO_: {
/** CursorPaginatedResults[SessionQueueItem] */
CursorPaginatedResults_SessionQueueItem_: {
/**
* Limit
* @description Limit of items to get
@@ -5731,7 +5751,7 @@ export type components = {
* Items
* @description Items
*/
items: components["schemas"]["SessionQueueItemDTO"][];
items: components["schemas"]["SessionQueueItem"][];
};
/**
* OpenCV Inpaint
@@ -8742,47 +8762,47 @@ export type components = {
* Id
* @description The id of the execution state
*/
id?: string;
id: string;
/** @description The graph being executed */
graph: components["schemas"]["Graph"];
/** @description The expanded graph of activated and executed nodes */
execution_graph?: components["schemas"]["Graph"];
execution_graph: components["schemas"]["Graph"];
/**
* Executed
* @description The set of node ids that have been executed
*/
executed?: string[];
executed: string[];
/**
* Executed History
* @description The list of node ids that have been executed, in order of execution
*/
executed_history?: string[];
executed_history: string[];
/**
* Results
* @description The results of node executions
*/
results?: {
results: {
[key: string]: components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["BoundingBoxCollectionOutput"] | components["schemas"]["BoundingBoxOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["CogView4ConditioningOutput"] | components["schemas"]["CogView4ModelLoaderOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["FloatGeneratorOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FluxConditioningOutput"] | components["schemas"]["FluxControlLoRALoaderOutput"] | components["schemas"]["FluxControlNetOutput"] | components["schemas"]["FluxFillOutput"] | components["schemas"]["FluxLoRALoaderOutput"] | components["schemas"]["FluxModelLoaderOutput"] | components["schemas"]["FluxReduxOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ImageGeneratorOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["ImagePanelCoordinateOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IntegerGeneratorOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsMetaOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["MDControlListOutput"] | components["schemas"]["MDIPAdapterListOutput"] | components["schemas"]["MDT2IAdapterListOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["MetadataToLorasCollectionOutput"] | components["schemas"]["MetadataToModelOutput"] | components["schemas"]["MetadataToSDXLModelOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["SD3ConditioningOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["Sd3ModelLoaderOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["String2Output"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringGeneratorOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["VAEOutput"];
};
/**
* Errors
* @description Errors raised when executing nodes
*/
errors?: {
errors: {
[key: string]: string;
};
/**
* Prepared Source Mapping
* @description The map of prepared nodes to original graph nodes
*/
prepared_source_mapping?: {
prepared_source_mapping: {
[key: string]: string;
};
/**
* Source Prepared Mapping
* @description The map of original graph nodes to prepared nodes
*/
source_prepared_mapping?: {
source_prepared_mapping: {
[key: string]: string[];
};
};
@@ -19117,7 +19137,10 @@ export type components = {
*/
total: number;
};
/** SessionQueueItem */
/**
* SessionQueueItem
* @description Session queue item without the full graph. Used for serialization.
*/
SessionQueueItem: {
/**
* Item Id
@@ -19218,16 +19241,6 @@ export type components = {
* @description The ID of the published workflow associated with this queue item
*/
published_workflow_id?: string | null;
/**
* Api Input Fields
* @description The fields that were used as input to the API
*/
api_input_fields?: components["schemas"]["FieldIdentifier"][] | null;
/**
* Api Output Fields
* @description The nodes that were used as output from the API
*/
api_output_fields?: components["schemas"]["FieldIdentifier"][] | null;
/**
* Credits
* @description The total credits used for this queue item
@@ -19238,123 +19251,6 @@ export type components = {
/** @description The workflow associated with this queue item */
workflow?: components["schemas"]["WorkflowWithoutID"] | null;
};
/** SessionQueueItemDTO */
SessionQueueItemDTO: {
/**
* Item Id
* @description The identifier of the session queue item
*/
item_id: number;
/**
* Status
* @description The status of this queue item
* @default pending
* @enum {string}
*/
status: "pending" | "in_progress" | "completed" | "failed" | "canceled";
/**
* Priority
* @description The priority of this queue item
* @default 0
*/
priority: number;
/**
* Batch Id
* @description The ID of the batch associated with this queue item
*/
batch_id: string;
/**
* Origin
* @description The origin of this queue item. This data is used by the frontend to determine how to handle results.
*/
origin?: string | null;
/**
* Destination
* @description The origin of this queue item. This data is used by the frontend to determine how to handle results
*/
destination?: string | null;
/**
* Session Id
* @description The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed.
*/
session_id: string;
/**
* Error Type
* @description The error type if this queue item errored
*/
error_type?: string | null;
/**
* Error Message
* @description The error message if this queue item errored
*/
error_message?: string | null;
/**
* Error Traceback
* @description The error traceback if this queue item errored
*/
error_traceback?: string | null;
/**
* Created At
* @description When this queue item was created
*/
created_at: string;
/**
* Updated At
* @description When this queue item was updated
*/
updated_at: string;
/**
* Started At
* @description When this queue item was started
*/
started_at?: string | null;
/**
* Completed At
* @description When this queue item was completed
*/
completed_at?: string | null;
/**
* Queue Id
* @description The id of the queue with which this item is associated
*/
queue_id: string;
/**
* Field Values
* @description The field values that were used for this queue item
*/
field_values?: components["schemas"]["NodeFieldValue"][] | null;
/**
* Retried From Item Id
* @description The item_id of the queue item that this item was retried from
*/
retried_from_item_id?: number | null;
/**
* Is Api Validation Run
* @description Whether this queue item is an API validation run.
* @default false
*/
is_api_validation_run?: boolean;
/**
* Published Workflow Id
* @description The ID of the published workflow associated with this queue item
*/
published_workflow_id?: string | null;
/**
* Api Input Fields
* @description The fields that were used as input to the API
*/
api_input_fields?: components["schemas"]["FieldIdentifier"][] | null;
/**
* Api Output Fields
* @description The nodes that were used as output from the API
*/
api_output_fields?: components["schemas"]["FieldIdentifier"][] | null;
/**
* Credits
* @description The total credits used for this queue item
*/
credits?: number | null;
};
/** SessionQueueStatus */
SessionQueueStatus: {
/**
@@ -24476,6 +24372,8 @@ export interface operations {
cursor?: number | null;
/** @description The pagination cursor priority */
priority?: number;
/** @description The destination of queue items to fetch */
destination?: string | null;
};
header?: never;
path: {
@@ -24492,7 +24390,44 @@ export interface operations {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["CursorPaginatedResults_SessionQueueItemDTO_"];
"application/json": components["schemas"]["CursorPaginatedResults_SessionQueueItem_"];
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
list_all_queue_items: {
parameters: {
query?: {
/** @description The status of items to fetch */
status?: ("pending" | "in_progress" | "completed" | "failed" | "canceled") | null;
/** @description The destination of queue items to fetch */
destination?: string | null;
};
header?: never;
path: {
/** @description The queue id to perform this operation on */
queue_id: string;
};
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["SessionQueueItem"][];
};
};
/** @description Validation Error */

View File

@@ -291,7 +291,6 @@ export type ModelInstallStatus = S['InstallStatus'];
export type Graph = S['Graph'];
export type NonNullableGraph = SetRequired<Graph, 'nodes' | 'edges'>;
export type Batch = S['Batch'];
export type SessionQueueItemDTO = S['SessionQueueItemDTO'];
export type WorkflowRecordOrderBy = S['WorkflowRecordOrderBy'];
export type SQLiteDirection = S['SQLiteDirection'];
export type WorkflowRecordListItemWithThumbnailDTO = S['WorkflowRecordListItemWithThumbnailDTO'];

View File

@@ -8,9 +8,7 @@ import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId';
import { $queueId } from 'app/store/nanostores/queueId';
import type { AppStore } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import {
stagingAreaGenerationStarted,
} from 'features/controlLayers/store/canvasStagingAreaSlice';
import { stagingAreaGenerationStarted } from 'features/controlLayers/store/canvasStagingAreaSlice';
import {
$isInPublishFlow,
$outputNodeId,
@@ -25,7 +23,7 @@ import { forEach, isNil, round } from 'lodash-es';
import type { ApiTagDescription } from 'services/api';
import { api, LIST_TAG } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models';
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
import { queueApi } from 'services/api/endpoints/queue';
import { workflowsApi } from 'services/api/endpoints/workflows';
import { buildOnInvocationComplete } from 'services/events/onInvocationComplete';
import { buildOnModelInstallError } from 'services/events/onModelInstallError';
@@ -383,24 +381,24 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
log.debug({ data }, `Queue item ${item_id} status updated: ${status}`);
// Update this specific queue item in the list of queue items (this is the queue item DTO, without the session)
dispatch(
queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => {
queueItemsAdapter.updateOne(draft, {
id: String(item_id),
changes: {
status,
started_at,
updated_at: updated_at ?? undefined,
completed_at: completed_at ?? undefined,
error_type,
error_message,
error_traceback,
credits,
},
});
})
);
// // Update this specific queue item in the list of queue items (this is the queue item DTO, without the session)
// dispatch(
// queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => {
// queueItemsAdapter.updateOne(draft, {
// id: String(item_id),
// changes: {
// status,
// started_at,
// updated_at: updated_at ?? undefined,
// completed_at: completed_at ?? undefined,
// error_type,
// error_message,
// error_traceback,
// credits,
// },
// });
// })
// );
// Optimistic update of the queue status. We prefer to do an optimistic update over tag invalidation due to the
// frequency of `queue_item_status_changed` events.
@@ -426,6 +424,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
'NextSessionQueueItem',
'InvocationCacheStatus',
{ type: 'SessionQueueItem', id: item_id },
{ type: 'SessionQueueItem', id: LIST_TAG },
];
if (destination) {
tagsToInvalidate.push({ type: 'QueueCountsByDestination', id: destination });