Compare commits

...

48 Commits

Author SHA1 Message Date
Mary Hipp
b5d7471326 add tests and readme 2025-09-17 13:53:08 -04:00
Mary Hipp
ae8bb9a9a7 enqueue refactor 2025-09-17 11:40:11 -04:00
psychedelicious
efcb1bea7f chore: bump version to v6.8.0rc1 2025-09-17 13:57:43 +10:00
psychedelicious
e0d7a401f3 feat(ui): make ref images croppable 2025-09-17 13:43:13 +10:00
psychedelicious
aac979e9a4 fix(ui): issue w/ setting initial aspect ratio in cropper 2025-09-17 13:43:13 +10:00
psychedelicious
3b0d7f076d tidy(ui): rename from "editor" to "cropper", minor cleanup 2025-09-17 13:43:13 +10:00
psychedelicious
e1acbcdbd5 fix(ui): store floats for box 2025-09-17 13:43:13 +10:00
psychedelicious
7d9b81550b feat(ui): revert to original image when crop discarded 2025-09-17 13:43:13 +10:00
psychedelicious
6a447dd1fe refactor(ui): remove "apply", "start" and "cancel" concepts from editor 2025-09-17 13:43:13 +10:00
psychedelicious
c2dc63ddbc fix(ui): video graphs 2025-09-17 13:43:13 +10:00
psychedelicious
1bc689d531 docs(ui): add comments to startingframeimage 2025-09-17 13:43:13 +10:00
psychedelicious
4829975827 feat(ui): make the editor components not care about the image 2025-09-17 13:43:13 +10:00
psychedelicious
49da4e00c3 feat(ui): add concept for editable image state 2025-09-17 13:43:13 +10:00
psychedelicious
89dfe5e729 docs(ui): add comments to editor 2025-09-17 13:43:13 +10:00
psychedelicious
6816d366df tidy(ui): editor misc 2025-09-17 13:43:13 +10:00
psychedelicious
9d3d2a36c9 tidy(ui): editor listeners 2025-09-17 13:43:13 +10:00
psychedelicious
ed231044c8 refactor(ui): simplify crop constraints 2025-09-17 13:43:13 +10:00
psychedelicious
b51a232794 feat(ui): extract config to own obj 2025-09-17 13:43:13 +10:00
psychedelicious
4412143a6e feat(ui): clean up editor 2025-09-17 13:43:13 +10:00
psychedelicious
de11cafdb3 refactor(ui): editor (wip) 2025-09-17 13:43:13 +10:00
psychedelicious
4d9114aa7d refactor(ui): editor (wip) 2025-09-17 13:43:13 +10:00
psychedelicious
67e2da1ebf refactor(ui): editor (wip) 2025-09-17 13:43:13 +10:00
psychedelicious
33ecc591c3 refactor(ui): editor init 2025-09-17 13:43:13 +10:00
psychedelicious
b57459a226 chore(ui): lint 2025-09-17 13:43:13 +10:00
psychedelicious
01282b1c90 feat(ui): do not clear crop when canceling 2025-09-17 13:43:13 +10:00
psychedelicious
3f302906dc feat(ui): crop doesn't hide outside cropped region 2025-09-17 13:43:13 +10:00
psychedelicious
81d56596fb tidy(ui): cleanup 2025-09-17 13:43:13 +10:00
psychedelicious
b536b0df0c feat(ui): misc iterate on editor 2025-09-17 13:43:13 +10:00
psychedelicious
692af1d93d feat(ui): type narrowing for editor output types 2025-09-17 13:43:13 +10:00
psychedelicious
bb7ef77b50 tidy(ui): lint/react conventions for editor component 2025-09-17 13:43:13 +10:00
psychedelicious
1862548573 feat(ui): image editor bg checkerboard pattern 2025-09-17 13:43:13 +10:00
psychedelicious
242c1b6350 feat(ui): tweak editor konva styles 2025-09-17 13:43:13 +10:00
psychedelicious
fc6e4bb04e tidy(ui): editor component cleanup 2025-09-17 13:43:13 +10:00
psychedelicious
20841abca6 tidy(ui): editor cleanup 2025-09-17 13:43:13 +10:00
psychedelicious
e8b69d99a4 chore(ui): lint 2025-09-17 13:43:13 +10:00
Mary Hipp
d6eaff8237 create editImageModal that takes an imageDTO, loads blob onto canvas, and allows cropping. cropped blob is uploaded as new asset 2025-09-17 13:43:13 +10:00
Mary Hipp
068b095956 show warning state with tooltip if starting frame image aspect ratio does not match the video output aspect ratio 2025-09-17 13:43:13 +10:00
psychedelicious
f795a47340 tidy(ui): remove unused translation string 2025-09-16 15:04:03 +10:00
psychedelicious
df47345eb0 feat(ui): add translation strings for prompt history 2025-09-16 15:04:03 +10:00
psychedelicious
def04095a4 feat(ui): tweak prompt history styling 2025-09-16 15:04:03 +10:00
psychedelicious
28be8f0911 refactor(ui): simplify prompt history shortcuts 2025-09-16 15:04:03 +10:00
Kent Keirsey
b50c44bac0 handle potential for invalid list item 2025-09-16 15:04:03 +10:00
Kent Keirsey
b4ce0e02fc lint 2025-09-16 15:04:03 +10:00
Kent Keirsey
d6442d9a34 Prompt history shortcuts 2025-09-16 15:04:03 +10:00
Josh Corbett
4528bcafaf feat(model manager): add ModelFooter component and reusable ModelDeleteButton 2025-09-16 12:29:57 +10:00
Josh Corbett
8b82b81ee2 fix(ModelImage): change MODEL_IMAGE_THUMBNAIL_SIZE to a local constant 2025-09-16 12:29:57 +10:00
Josh Corbett
757acdd49e feat(model manager): 💄 update model manager ui, initial commit 2025-09-16 12:29:57 +10:00
psychedelicious
94b7cc583a fix(ui): do not reset params state on studio init nav to generate tab 2025-09-16 12:25:25 +10:00
79 changed files with 3674 additions and 740 deletions

View File

@@ -104,6 +104,7 @@
"copy": "Copy",
"copyError": "$t(gallery.copy) Error",
"clipboard": "Clipboard",
"crop": "Crop",
"on": "On",
"off": "Off",
"or": "or",
@@ -242,7 +243,10 @@
"resultSubtitle": "Choose how to handle the expanded prompt:",
"replace": "Replace",
"insert": "Insert",
"discard": "Discard"
"discard": "Discard",
"noPromptHistory": "No prompt history recorded.",
"noMatchingPrompts": "No matching prompts in history.",
"toSwitchBetweenPrompts": "to switch between prompts."
},
"queue": {
"queue": "Queue",
@@ -480,6 +484,14 @@
"title": "Focus Prompt",
"desc": "Move cursor focus to the positive prompt."
},
"promptHistoryPrev": {
"title": "Previous Prompt in History",
"desc": "When the prompt is focused, move to the previous (older) prompt in your history."
},
"promptHistoryNext": {
"title": "Next Prompt in History",
"desc": "When the prompt is focused, move to the next (newer) prompt in your history."
},
"toggleLeftPanel": {
"title": "Toggle Left Panel",
"desc": "Show or hide the left panel."
@@ -1258,6 +1270,7 @@
"infillColorValue": "Fill Color",
"info": "Info",
"startingFrameImage": "Start Frame",
"startingFrameImageAspectRatioWarning": "Image aspect ratio does not match the video aspect ratio ({{videoAspectRatio}}). This could lead to unexpected cropping during video generation.",
"invoke": {
"addingImagesTo": "Adding images to",
"modelDisabledForTrial": "Generating with {{modelName}} is not available on trial accounts. Visit your account settings to upgrade.",

View File

@@ -2,6 +2,7 @@ import { GlobalImageHotkeys } from 'app/components/GlobalImageHotkeys';
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
import { CanvasPasteModal } from 'features/controlLayers/components/CanvasPasteModal';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { CropImageModal } from 'features/cropper/components/CropImageModal';
import { DeleteImageModal } from 'features/deleteImageModal/components/DeleteImageModal';
import { DeleteVideoModal } from 'features/deleteVideoModal/components/DeleteVideoModal';
import { FullscreenDropzone } from 'features/dnd/FullscreenDropzone';
@@ -58,6 +59,7 @@ export const GlobalModalIsolator = memo(() => {
<CanvasPasteModal />
</CanvasManagerProviderGate>
<LoadWorkflowFromGraphModal />
<CropImageModal />
</>
);
});

View File

@@ -4,7 +4,6 @@ import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { withResultAsync } from 'common/util/result';
import { canvasReset } from 'features/controlLayers/store/actions';
import { rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
import { paramsReset } from 'features/controlLayers/store/paramsSlice';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { sentImageToCanvas } from 'features/gallery/store/actions';
@@ -164,7 +163,6 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
case 'generation':
// Go to the generate tab, open the launchpad
await navigationApi.focusPanel('generate', LAUNCHPAD_PANEL_ID);
store.dispatch(paramsReset());
break;
case 'canvas':
// Go to the canvas tab, open the launchpad

View File

@@ -27,6 +27,7 @@ export const zLogNamespace = z.enum([
'queue',
'workflows',
'video',
'enqueue',
]);
export type LogNamespace = z.infer<typeof zLogNamespace>;

View File

@@ -12,7 +12,13 @@ import {
} from 'features/controlLayers/store/paramsSlice';
import { refImageModelChanged, selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { getEntityIdentifier, isFLUXReduxConfig, isIPAdapterConfig } from 'features/controlLayers/store/types';
import {
getEntityIdentifier,
isFLUXReduxConfig,
isIPAdapterConfig,
isRegionalGuidanceFLUXReduxConfig,
isRegionalGuidanceIPAdapterConfig,
} from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { modelSelected } from 'features/parameters/store/actions';
import {
@@ -252,7 +258,7 @@ const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
entity.referenceImages.forEach(({ id: referenceImageId, config }) => {
if (!isIPAdapterConfig(config)) {
if (!isRegionalGuidanceIPAdapterConfig(config)) {
return;
}
@@ -295,7 +301,7 @@ const handleFLUXReduxModels: ModelHandler = (models, state, dispatch, log) => {
selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
entity.referenceImages.forEach(({ id: referenceImageId, config }) => {
if (!isFLUXReduxConfig(config)) {
if (!isRegionalGuidanceFLUXReduxConfig(config)) {
return;
}

View File

@@ -1,12 +1,16 @@
import { Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { objectEquals } from '@observ33r/object-equals';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { UploadImageIconButton } from 'common/hooks/useImageUploadButton';
import { bboxSizeOptimized, bboxSizeRecalled } from 'features/controlLayers/store/canvasSlice';
import { useCanvasIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { sizeOptimized, sizeRecalled } from 'features/controlLayers/store/paramsSlice';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import type { CroppableImageWithDims } from 'features/controlLayers/store/types';
import { imageDTOToCroppableImage, imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { Editor } from 'features/cropper/lib/editor';
import { cropImageModalApi } from 'features/cropper/store';
import type { setGlobalReferenceImageDndTarget, setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { DndImage } from 'features/dnd/DndImage';
@@ -14,14 +18,14 @@ import { DndImageIcon } from 'features/dnd/DndImageIcon';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold, PiRulerBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { PiArrowCounterClockwiseBold, PiCropBold, PiRulerBold } from 'react-icons/pi';
import { useGetImageDTOQuery, useUploadImageMutation } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { $isConnected } from 'services/events/stores';
type Props<T extends typeof setGlobalReferenceImageDndTarget | typeof setRegionalGuidanceReferenceImageDndTarget> = {
image: ImageWithDims | null;
onChangeImage: (imageDTO: ImageDTO | null) => void;
image: CroppableImageWithDims | null;
onChangeImage: (croppableImage: CroppableImageWithDims | null) => void;
dndTarget: T;
dndTargetData: ReturnType<T['getData']>;
};
@@ -38,20 +42,28 @@ export const RefImageImage = memo(
const isConnected = useStore($isConnected);
const tab = useAppSelector(selectActiveTab);
const isStaging = useCanvasIsStaging();
const { currentData: imageDTO, isError } = useGetImageDTOQuery(image?.image_name ?? skipToken);
const imageWithDims = image?.crop?.image ?? image?.original.image ?? null;
const croppedImageDTOReq = useGetImageDTOQuery(image?.crop?.image?.image_name ?? skipToken);
const originalImageDTOReq = useGetImageDTOQuery(image?.original.image.image_name ?? skipToken);
const [uploadImage] = useUploadImageMutation();
const originalImageDTO = originalImageDTOReq.currentData;
const croppedImageDTO = croppedImageDTOReq.currentData;
const imageDTO = croppedImageDTO ?? originalImageDTO;
const handleResetControlImage = useCallback(() => {
onChangeImage(null);
}, [onChangeImage]);
useEffect(() => {
if (isConnected && isError) {
if ((isConnected && croppedImageDTOReq.isError) || originalImageDTOReq.isError) {
handleResetControlImage();
}
}, [handleResetControlImage, isError, isConnected]);
}, [handleResetControlImage, isConnected, croppedImageDTOReq.isError, originalImageDTOReq.isError]);
const onUpload = useCallback(
(imageDTO: ImageDTO) => {
onChangeImage(imageDTO);
onChangeImage(imageDTOToCroppableImage(imageDTO));
},
[onChangeImage]
);
@@ -70,13 +82,67 @@ export const RefImageImage = memo(
}
}, [imageDTO, isStaging, store, tab]);
const edit = useCallback(() => {
if (!originalImageDTO) {
return;
}
// We will create a new editor instance each time the user wants to edit
const editor = new Editor();
// When the user applies the crop, we will upload the cropped image and store the applied crop box so if the user
// re-opens the editor they see the same crop
const onApplyCrop = async () => {
const box = editor.getCropBox();
if (objectEquals(box, image?.crop?.box)) {
// If the box hasn't changed, don't do anything
return;
}
if (!box || objectEquals(box, { x: 0, y: 0, width: originalImageDTO.width, height: originalImageDTO.height })) {
// There is a crop applied but it is the whole image - revert to original image
onChangeImage(imageDTOToCroppableImage(originalImageDTO));
return;
}
const blob = await editor.exportImage('blob');
const file = new File([blob], 'image.png', { type: 'image/png' });
const newCroppedImageDTO = await uploadImage({
file,
is_intermediate: true,
image_category: 'user',
}).unwrap();
onChangeImage(
imageDTOToCroppableImage(originalImageDTO, {
image: imageDTOToImageWithDims(newCroppedImageDTO),
box,
ratio: editor.getCropAspectRatio(),
})
);
};
const onReady = async () => {
const initial = image?.crop ? { cropBox: image.crop.box, aspectRatio: image.crop.ratio } : undefined;
// Load the image into the editor and open the modal once it's ready
await editor.loadImage(originalImageDTO.image_url, initial);
};
cropImageModalApi.open({ editor, onApplyCrop, onReady });
}, [image?.crop, onChangeImage, originalImageDTO, uploadImage]);
return (
<Flex position="relative" w="full" h="full" alignItems="center" data-error={!imageDTO && !image?.image_name}>
<Flex
position="relative"
w="full"
h="full"
alignItems="center"
data-error={!imageDTO && !imageWithDims?.image_name}
>
{!imageDTO && (
<UploadImageIconButton
w="full"
h="full"
isError={!imageDTO && !image?.image_name}
isError={!imageDTO && !imageWithDims?.image_name}
onUpload={onUpload}
fontSize={36}
/>
@@ -99,6 +165,15 @@ export const RefImageImage = memo(
isDisabled={!imageDTO || (tab === 'canvas' && isStaging)}
/>
</Flex>
<Flex position="absolute" flexDir="column" top={2} insetInlineStart={2} gap={1}>
<DndImageIcon
onClick={edit}
icon={<PiCropBold size={16} />}
tooltip={t('common.crop')}
isDisabled={!imageDTO}
/>
</Flex>
</>
)}
<DndDropTarget dndTarget={dndTarget} dndTargetData={dndTargetData} label={t('gallery.drop')} />

View File

@@ -13,7 +13,7 @@ import {
selectRefImageEntityIds,
selectSelectedRefEntityId,
} from 'features/controlLayers/store/refImagesSlice';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { imageDTOToCroppableImage } from 'features/controlLayers/store/util';
import { addGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
@@ -92,7 +92,7 @@ const AddRefImageDropTargetAndButton = memo(() => {
({
onUpload: (imageDTO: ImageDTO) => {
const config = getDefaultRefImageConfig(getState);
config.image = imageDTOToImageWithDims(imageDTO);
config.image = imageDTOToCroppableImage(imageDTO);
dispatch(refImageAdded({ overrides: { config } }));
},
allowMultiple: false,

View File

@@ -1,6 +1,5 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, Icon, IconButton, Image, Skeleton, Text, Tooltip } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { round } from 'es-toolkit/compat';
import { useRefImageEntity } from 'features/controlLayers/components/RefImage/useRefImageEntity';
@@ -15,7 +14,7 @@ import { isIPAdapterConfig } from 'features/controlLayers/store/types';
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
import { PiExclamationMarkBold, PiEyeSlashBold, PiImageBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { useImageDTOFromCroppableImage } from 'services/api/endpoints/images';
import { RefImageWarningTooltipContent } from './RefImageWarningTooltipContent';
@@ -72,7 +71,8 @@ export const RefImagePreview = memo(() => {
const selectedEntityId = useAppSelector(selectSelectedRefEntityId);
const isPanelOpen = useAppSelector(selectIsRefImagePanelOpen);
const [showWeightDisplay, setShowWeightDisplay] = useState(false);
const { data: imageDTO } = useGetImageDTOQuery(entity.config.image?.image_name ?? skipToken);
const imageDTO = useImageDTOFromCroppableImage(entity.config.image);
const sx = useMemo(() => {
if (!isIPAdapterConfig(entity.config)) {
@@ -145,7 +145,7 @@ export const RefImagePreview = memo(() => {
overflow="hidden"
>
<Image
src={imageDTO?.thumbnail_url}
src={imageDTO?.image_url}
objectFit="contain"
aspectRatio="1/1"
height={imageDTO?.height}

View File

@@ -30,6 +30,7 @@ import {
} from 'features/controlLayers/store/refImagesSlice';
import type {
CLIPVisionModelV2,
CroppableImageWithDims,
FLUXReduxImageInfluence as FLUXReduxImageInfluenceType,
IPMethodV2,
} from 'features/controlLayers/store/types';
@@ -42,7 +43,6 @@ import type {
ChatGPT4oModelConfig,
FLUXKontextModelConfig,
FLUXReduxModelConfig,
ImageDTO,
IPAdapterModelConfig,
} from 'services/api/types';
@@ -104,15 +104,19 @@ const RefImageSettingsContent = memo(() => {
);
const onChangeImage = useCallback(
(imageDTO: ImageDTO | null) => {
dispatch(refImageImageChanged({ id, imageDTO }));
(croppableImage: CroppableImageWithDims | null) => {
dispatch(refImageImageChanged({ id, croppableImage }));
},
[dispatch, id]
);
const dndTargetData = useMemo<SetGlobalReferenceImageDndTargetData>(
() => setGlobalReferenceImageDndTarget.getData({ id }, config.image?.image_name),
[id, config.image?.image_name]
() =>
setGlobalReferenceImageDndTarget.getData(
{ id },
config.image?.crop?.image.image_name ?? config.image?.original.image.image_name
),
[id, config.image?.crop?.image.image_name, config.image?.original.image.image_name]
);
const isFLUX = useAppSelector(selectIsFLUX);

View File

@@ -6,7 +6,6 @@ import { FLUXReduxImageInfluence } from 'features/controlLayers/components/commo
import { IPAdapterCLIPVisionModel } from 'features/controlLayers/components/common/IPAdapterCLIPVisionModel';
import { Weight } from 'features/controlLayers/components/common/Weight';
import { IPAdapterMethod } from 'features/controlLayers/components/RefImage/IPAdapterMethod';
import { RefImageImage } from 'features/controlLayers/components/RefImage/RefImageImage';
import { RegionalGuidanceIPAdapterSettingsEmptyState } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceIPAdapterSettingsEmptyState';
import { RegionalReferenceImageModel } from 'features/controlLayers/components/RegionalGuidance/RegionalReferenceImageModel';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
@@ -37,6 +36,8 @@ import { PiBoundingBoxBold, PiXBold } from 'react-icons/pi';
import type { FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { RegionalGuidanceRefImageImage } from './RegionalGuidanceRefImageImage';
type Props = {
referenceImageId: string;
};
@@ -114,7 +115,7 @@ const RegionalGuidanceIPAdapterSettingsContent = memo(({ referenceImageId }: Pro
{ entityIdentifier, referenceImageId },
config.image?.image_name
),
[entityIdentifier, config.image?.image_name, referenceImageId]
[entityIdentifier, config.image, referenceImageId]
);
const pullBboxIntoIPAdapter = usePullBboxIntoRegionalGuidanceReferenceImage(entityIdentifier, referenceImageId);
@@ -170,7 +171,7 @@ const RegionalGuidanceIPAdapterSettingsContent = memo(({ referenceImageId }: Pro
</Flex>
)}
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1" flexGrow={1}>
<RefImageImage
<RegionalGuidanceRefImageImage
image={config.image}
onChangeImage={onChangeImage}
dndTarget={setRegionalGuidanceReferenceImageDndTarget}

View File

@@ -0,0 +1,103 @@
import { Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { UploadImageIconButton } from 'common/hooks/useImageUploadButton';
import { bboxSizeOptimized, bboxSizeRecalled } from 'features/controlLayers/store/canvasSlice';
import { useCanvasIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { sizeOptimized, sizeRecalled } from 'features/controlLayers/store/paramsSlice';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import type { setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { DndImage } from 'features/dnd/DndImage';
import { DndImageIcon } from 'features/dnd/DndImageIcon';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold, PiRulerBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { $isConnected } from 'services/events/stores';
type Props = {
image: ImageWithDims | null;
onChangeImage: (imageDTO: ImageDTO | null) => void;
dndTarget: typeof setRegionalGuidanceReferenceImageDndTarget;
dndTargetData: ReturnType<(typeof setRegionalGuidanceReferenceImageDndTarget)['getData']>;
};
export const RegionalGuidanceRefImageImage = memo(({ image, onChangeImage, dndTarget, dndTargetData }: Props) => {
const { t } = useTranslation();
const store = useAppStore();
const isConnected = useStore($isConnected);
const tab = useAppSelector(selectActiveTab);
const isStaging = useCanvasIsStaging();
const { currentData: imageDTO, isError } = useGetImageDTOQuery(image?.image_name ?? skipToken);
const handleResetControlImage = useCallback(() => {
onChangeImage(null);
}, [onChangeImage]);
useEffect(() => {
if (isConnected && isError) {
handleResetControlImage();
}
}, [handleResetControlImage, isError, isConnected]);
const onUpload = useCallback(
(imageDTO: ImageDTO) => {
onChangeImage(imageDTO);
},
[onChangeImage]
);
const recallSizeAndOptimize = useCallback(() => {
if (!imageDTO || (tab === 'canvas' && isStaging)) {
return;
}
const { width, height } = imageDTO;
if (tab === 'canvas') {
store.dispatch(bboxSizeRecalled({ width, height }));
store.dispatch(bboxSizeOptimized());
} else if (tab === 'generate') {
store.dispatch(sizeRecalled({ width, height }));
store.dispatch(sizeOptimized());
}
}, [imageDTO, isStaging, store, tab]);
return (
<Flex position="relative" w="full" h="full" alignItems="center" data-error={!imageDTO && !image?.image_name}>
{!imageDTO && (
<UploadImageIconButton
w="full"
h="full"
isError={!imageDTO && !image?.image_name}
onUpload={onUpload}
fontSize={36}
/>
)}
{imageDTO && (
<>
<DndImage imageDTO={imageDTO} borderRadius="base" borderWidth={1} borderStyle="solid" w="full" />
<Flex position="absolute" flexDir="column" top={2} insetInlineEnd={2} gap={1}>
<DndImageIcon
onClick={handleResetControlImage}
icon={<PiArrowCounterClockwiseBold size={16} />}
tooltip={t('common.reset')}
/>
</Flex>
<Flex position="absolute" flexDir="column" bottom={2} insetInlineEnd={2} gap={1}>
<DndImageIcon
onClick={recallSizeAndOptimize}
icon={<PiRulerBold size={16} />}
tooltip={t('parameters.useSize')}
isDisabled={!imageDTO || (tab === 'canvas' && isStaging)}
/>
</Flex>
</>
)}
<DndDropTarget dndTarget={dndTarget} dndTargetData={dndTargetData} label={t('gallery.drop')} />
</Flex>
);
});
RegionalGuidanceRefImageImage.displayName = 'RegionalGuidanceRefImageImage';

View File

@@ -30,6 +30,7 @@ import type {
FluxKontextReferenceImageConfig,
Gemini2_5ReferenceImageConfig,
IPAdapterConfig,
RegionalGuidanceIPAdapterConfig,
T2IAdapterConfig,
} from 'features/controlLayers/store/types';
import {
@@ -38,6 +39,7 @@ import {
initialFluxKontextReferenceImage,
initialGemini2_5ReferenceImage,
initialIPAdapter,
initialRegionalGuidanceIPAdapter,
initialT2IAdapter,
} from 'features/controlLayers/store/util';
import { zModelIdentifierField } from 'features/nodes/types/common';
@@ -125,7 +127,7 @@ export const getDefaultRefImageConfig = (
return config;
};
export const getDefaultRegionalGuidanceRefImageConfig = (getState: AppGetState): IPAdapterConfig => {
export const getDefaultRegionalGuidanceRefImageConfig = (getState: AppGetState): RegionalGuidanceIPAdapterConfig => {
// Regional guidance ref images do not support ChatGPT-4o, so we always return the IP Adapter config.
const state = getState();
@@ -138,7 +140,7 @@ export const getDefaultRegionalGuidanceRefImageConfig = (getState: AppGetState):
const modelConfig = ipAdapterModelConfigs.find((m) => m.base === base);
// Clone the initial IP Adapter config and set the model if available.
const config = deepClone(initialIPAdapter);
const config = deepClone(initialRegionalGuidanceIPAdapter);
if (modelConfig) {
config.model = zModelIdentifierField.parse(modelConfig);

View File

@@ -32,7 +32,12 @@ import type {
RefImageState,
RegionalGuidanceRefImageState,
} from 'features/controlLayers/store/types';
import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } from 'features/controlLayers/store/util';
import {
imageDTOToCroppableImage,
imageDTOToImageObject,
imageDTOToImageWithDims,
initialControlNet,
} from 'features/controlLayers/store/util';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import type { BoardId } from 'features/gallery/store/types';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
@@ -209,7 +214,7 @@ export const useNewGlobalReferenceImageFromBbox = () => {
const overrides: Partial<RefImageState> = {
config: {
...getDefaultRefImageConfig(getState),
image: imageDTOToImageWithDims(imageDTO),
image: imageDTOToCroppableImage(imageDTO),
},
};
dispatch(refImageAdded({ overrides }));
@@ -312,7 +317,7 @@ export const usePullBboxIntoGlobalReferenceImage = (id: string) => {
const arg = useMemo<UseSaveCanvasArg>(() => {
const onSave = (imageDTO: ImageDTO, _: Rect) => {
dispatch(refImageImageChanged({ id, imageDTO }));
dispatch(refImageImageChanged({ id, croppableImage: imageDTOToCroppableImage(imageDTO) }));
};
return {

View File

@@ -82,10 +82,10 @@ import {
IMAGEN_ASPECT_RATIOS,
isChatGPT4oAspectRatioID,
isFluxKontextAspectRatioID,
isFLUXReduxConfig,
isGemini2_5AspectRatioID,
isImagenAspectRatioID,
isIPAdapterConfig,
isRegionalGuidanceFLUXReduxConfig,
isRegionalGuidanceIPAdapterConfig,
zCanvasState,
} from './types';
import {
@@ -99,6 +99,7 @@ import {
initialControlNet,
initialFLUXRedux,
initialIPAdapter,
initialRegionalGuidanceIPAdapter,
initialT2IAdapter,
makeDefaultRasterLayerAdjustments,
} from './util';
@@ -804,7 +805,7 @@ const slice = createSlice({
if (!entity) {
return;
}
const config = { id: referenceImageId, config: deepClone(initialIPAdapter) };
const config = { id: referenceImageId, config: deepClone(initialRegionalGuidanceIPAdapter) };
merge(config, overrides);
entity.referenceImages.push(config);
},
@@ -847,7 +848,7 @@ const slice = createSlice({
if (!referenceImage) {
return;
}
if (!isIPAdapterConfig(referenceImage.config)) {
if (!isRegionalGuidanceIPAdapterConfig(referenceImage.config)) {
return;
}
@@ -864,7 +865,7 @@ const slice = createSlice({
if (!referenceImage) {
return;
}
if (!isIPAdapterConfig(referenceImage.config)) {
if (!isRegionalGuidanceIPAdapterConfig(referenceImage.config)) {
return;
}
referenceImage.config.beginEndStepPct = beginEndStepPct;
@@ -880,7 +881,7 @@ const slice = createSlice({
if (!referenceImage) {
return;
}
if (!isIPAdapterConfig(referenceImage.config)) {
if (!isRegionalGuidanceIPAdapterConfig(referenceImage.config)) {
return;
}
referenceImage.config.method = method;
@@ -899,7 +900,7 @@ const slice = createSlice({
if (!referenceImage) {
return;
}
if (!isFLUXReduxConfig(referenceImage.config)) {
if (!isRegionalGuidanceFLUXReduxConfig(referenceImage.config)) {
return;
}
@@ -928,7 +929,7 @@ const slice = createSlice({
return;
}
if (isIPAdapterConfig(referenceImage.config) && isFluxReduxModelConfig(modelConfig)) {
if (isRegionalGuidanceIPAdapterConfig(referenceImage.config) && isFluxReduxModelConfig(modelConfig)) {
// Switching from ip_adapter to flux_redux
referenceImage.config = {
...initialFLUXRedux,
@@ -938,7 +939,7 @@ const slice = createSlice({
return;
}
if (isFLUXReduxConfig(referenceImage.config) && isIPAdapterModelConfig(modelConfig)) {
if (isRegionalGuidanceFLUXReduxConfig(referenceImage.config) && isIPAdapterModelConfig(modelConfig)) {
// Switching from flux_redux to ip_adapter
referenceImage.config = {
...initialIPAdapter,
@@ -948,7 +949,7 @@ const slice = createSlice({
return;
}
if (isIPAdapterConfig(referenceImage.config)) {
if (isRegionalGuidanceIPAdapterConfig(referenceImage.config)) {
referenceImage.config.model = zModelIdentifierField.parse(modelConfig);
// Ensure that the IP Adapter model is compatible with the CLIP Vision model
@@ -971,7 +972,7 @@ const slice = createSlice({
if (!referenceImage) {
return;
}
if (!isIPAdapterConfig(referenceImage.config)) {
if (!isRegionalGuidanceIPAdapterConfig(referenceImage.config)) {
return;
}
referenceImage.config.clipVisionModel = clipVisionModel;

View File

@@ -199,11 +199,7 @@ const slice = createSlice({
return;
}
if (state.positivePromptHistory.includes(prompt)) {
return;
}
state.positivePromptHistory.unshift(prompt);
state.positivePromptHistory = [prompt, ...state.positivePromptHistory.filter((p) => p !== prompt)];
if (state.positivePromptHistory.length > MAX_POSITIVE_PROMPT_HISTORY) {
state.positivePromptHistory = state.positivePromptHistory.slice(0, MAX_POSITIVE_PROMPT_HISTORY);

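The reducer change above replaces "skip duplicate prompts" with "move the repeated prompt to the front of the history". A minimal sketch of the resulting behavior on a plain array (the history cap value here is assumed, not taken from the slice):

const MAX_POSITIVE_PROMPT_HISTORY = 100; // assumed cap; the real constant lives in paramsSlice

// Re-submitting an existing prompt now bubbles it to the front instead of being ignored.
const pushPrompt = (history: string[], prompt: string): string[] => {
  const next = [prompt, ...history.filter((p) => p !== prompt)];
  return next.length > MAX_POSITIVE_PROMPT_HISTORY ? next.slice(0, MAX_POSITIVE_PROMPT_HISTORY) : next;
};

console.log(pushPrompt(['a cat', 'a dog', 'a bird'], 'a dog'));
// -> ['a dog', 'a cat', 'a bird']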
View File

@@ -6,13 +6,16 @@ import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { clamp } from 'es-toolkit/compat';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { FLUXReduxImageInfluence, RefImagesState } from 'features/controlLayers/store/types';
import type {
CroppableImageWithDims,
FLUXReduxImageInfluence,
RefImagesState,
} from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type {
ChatGPT4oModelConfig,
FLUXKontextModelConfig,
FLUXReduxModelConfig,
ImageDTO,
IPAdapterModelConfig,
} from 'services/api/types';
import { assert } from 'tsafe';
@@ -22,7 +25,6 @@ import type { CLIPVisionModelV2, IPMethodV2, RefImageState } from './types';
import { getInitialRefImagesState, isFLUXReduxConfig, isIPAdapterConfig, zRefImagesState } from './types';
import {
getReferenceImageState,
imageDTOToImageWithDims,
initialChatGPT4oReferenceImage,
initialFluxKontextReferenceImage,
initialFLUXRedux,
@@ -65,13 +67,13 @@ const slice = createSlice({
state.entities.push(...entities);
}
},
refImageImageChanged: (state, action: PayloadActionWithId<{ imageDTO: ImageDTO | null }>) => {
const { id, imageDTO } = action.payload;
refImageImageChanged: (state, action: PayloadActionWithId<{ croppableImage: CroppableImageWithDims | null }>) => {
const { id, croppableImage } = action.payload;
const entity = selectRefImageEntity(state, id);
if (!entity) {
return;
}
entity.config.image = imageDTO ? imageDTOToImageWithDims(imageDTO) : null;
entity.config.image = croppableImage;
},
refImageIPAdapterMethodChanged: (state, action: PayloadActionWithId<{ method: IPMethodV2 }>) => {
const { id, method } = action.payload;

View File

@@ -37,6 +37,45 @@ export const zImageWithDims = z.object({
});
export type ImageWithDims = z.infer<typeof zImageWithDims>;
const zCropBox = z.object({
x: z.number().min(0),
y: z.number().min(0),
width: z.number().positive(),
height: z.number().positive(),
});
// This new schema is an extension of zImageWithDims, with an optional crop field.
//
// When we added cropping support to certain entities (e.g. Ref Images, video Starting Frame Image), we changed
// their schemas from using zImageWithDims to this new schema. To support loading pre-existing entities that
// were created before cropping was supported, we can use zod's preprocess to transform old data into the new format.
// It's essentially a data migration step.
//
// Currently, this parsing happens in two places:
// - Recalling metadata.
// - Loading/rehydrating persisted client state from storage.
export const zCroppableImageWithDims = z.preprocess(
(val) => {
try {
const imageWithDims = zImageWithDims.parse(val);
const migrated = { original: { image: deepClone(imageWithDims) } };
return migrated;
} catch {
return val;
}
},
z.object({
original: z.object({ image: zImageWithDims }),
crop: z
.object({
box: zCropBox,
ratio: z.number().gt(0).nullable(),
image: zImageWithDims,
})
.optional(),
})
);
export type CroppableImageWithDims = z.infer<typeof zCroppableImageWithDims>;
const zImageWithDimsDataURL = z.object({
dataURL: z.string(),
width: z.number().int().positive(),
@@ -235,7 +274,7 @@ export type CanvasObjectState = z.infer<typeof zCanvasObjectState>;
const zIPAdapterConfig = z.object({
type: z.literal('ip_adapter'),
image: zImageWithDims.nullable(),
image: zCroppableImageWithDims.nullable(),
model: zModelIdentifierField.nullable(),
weight: z.number().gte(-1).lte(2),
beginEndStepPct: zBeginEndStepPct,
@@ -244,21 +283,39 @@ const zIPAdapterConfig = z.object({
});
export type IPAdapterConfig = z.infer<typeof zIPAdapterConfig>;
const zRegionalGuidanceIPAdapterConfig = z.object({
type: z.literal('ip_adapter'),
image: zImageWithDims.nullable(),
model: zModelIdentifierField.nullable(),
weight: z.number().gte(-1).lte(2),
beginEndStepPct: zBeginEndStepPct,
method: zIPMethodV2,
clipVisionModel: zCLIPVisionModelV2,
});
export type RegionalGuidanceIPAdapterConfig = z.infer<typeof zRegionalGuidanceIPAdapterConfig>;
const zFLUXReduxImageInfluence = z.enum(['lowest', 'low', 'medium', 'high', 'highest']);
export const isFLUXReduxImageInfluence = (v: unknown): v is FLUXReduxImageInfluence =>
zFLUXReduxImageInfluence.safeParse(v).success;
export type FLUXReduxImageInfluence = z.infer<typeof zFLUXReduxImageInfluence>;
const zFLUXReduxConfig = z.object({
type: z.literal('flux_redux'),
image: zImageWithDims.nullable(),
image: zCroppableImageWithDims.nullable(),
model: zModelIdentifierField.nullable(),
imageInfluence: zFLUXReduxImageInfluence.default('highest'),
});
export type FLUXReduxConfig = z.infer<typeof zFLUXReduxConfig>;
const zRegionalGuidanceFLUXReduxConfig = z.object({
type: z.literal('flux_redux'),
image: zImageWithDims.nullable(),
model: zModelIdentifierField.nullable(),
imageInfluence: zFLUXReduxImageInfluence.default('highest'),
});
type RegionalGuidanceFLUXReduxConfig = z.infer<typeof zRegionalGuidanceFLUXReduxConfig>;
const zChatGPT4oReferenceImageConfig = z.object({
type: z.literal('chatgpt_4o_reference_image'),
image: zImageWithDims.nullable(),
image: zCroppableImageWithDims.nullable(),
/**
* TODO(psyche): Technically there is no model for ChatGPT 4o reference images - it's just a field in the API call.
* But we use a model drop down to switch between different ref image types, so there needs to be a model here else
@@ -270,14 +327,14 @@ export type ChatGPT4oReferenceImageConfig = z.infer<typeof zChatGPT4oReferenceIm
const zGemini2_5ReferenceImageConfig = z.object({
type: z.literal('gemini_2_5_reference_image'),
image: zImageWithDims.nullable(),
image: zCroppableImageWithDims.nullable(),
model: zModelIdentifierField.nullable(),
});
export type Gemini2_5ReferenceImageConfig = z.infer<typeof zGemini2_5ReferenceImageConfig>;
const zFluxKontextReferenceImageConfig = z.object({
type: z.literal('flux_kontext_reference_image'),
image: zImageWithDims.nullable(),
image: zCroppableImageWithDims.nullable(),
model: zModelIdentifierField.nullable(),
});
export type FluxKontextReferenceImageConfig = z.infer<typeof zFluxKontextReferenceImageConfig>;
@@ -307,6 +364,7 @@ export const isIPAdapterConfig = (config: RefImageState['config']): config is IP
export const isFLUXReduxConfig = (config: RefImageState['config']): config is FLUXReduxConfig =>
config.type === 'flux_redux';
export const isChatGPT4oReferenceImageConfig = (
config: RefImageState['config']
): config is ChatGPT4oReferenceImageConfig => config.type === 'chatgpt_4o_reference_image';
@@ -326,10 +384,18 @@ const zFill = z.object({ style: zFillStyle, color: zRgbColor });
const zRegionalGuidanceRefImageState = z.object({
id: zId,
config: z.discriminatedUnion('type', [zIPAdapterConfig, zFLUXReduxConfig]),
config: z.discriminatedUnion('type', [zRegionalGuidanceIPAdapterConfig, zRegionalGuidanceFLUXReduxConfig]),
});
export type RegionalGuidanceRefImageState = z.infer<typeof zRegionalGuidanceRefImageState>;
export const isRegionalGuidanceIPAdapterConfig = (
config: RegionalGuidanceRefImageState['config']
): config is RegionalGuidanceIPAdapterConfig => config.type === 'ip_adapter';
export const isRegionalGuidanceFLUXReduxConfig = (
config: RegionalGuidanceRefImageState['config']
): config is RegionalGuidanceFLUXReduxConfig => config.type === 'flux_redux';
const zCanvasRegionalGuidanceState = zCanvasEntityBase.extend({
type: z.literal('regional_guidance'),
position: zCoordinate,

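The zCroppableImageWithDims schema added above migrates legacy data via zod's preprocess. As a standalone illustration, here is a minimal sketch with simplified stand-ins for the real schemas (the actual code uses parse inside a try/catch plus deepClone; safeParse is used here for brevity):

import { z } from 'zod';

// Simplified stand-in for zImageWithDims.
const zImageWithDims = z.object({
  image_name: z.string(),
  width: z.number().int().positive(),
  height: z.number().int().positive(),
});

const zCropBox = z.object({
  x: z.number().min(0),
  y: z.number().min(0),
  width: z.number().positive(),
  height: z.number().positive(),
});

// Legacy data (a bare ImageWithDims) is wrapped into { original: { image } };
// anything already in the new shape fails the inner parse and passes through untouched.
const zCroppableImageWithDims = z.preprocess(
  (val) => {
    const legacy = zImageWithDims.safeParse(val);
    return legacy.success ? { original: { image: legacy.data } } : val;
  },
  z.object({
    original: z.object({ image: zImageWithDims }),
    crop: z.object({ box: zCropBox, ratio: z.number().gt(0).nullable(), image: zImageWithDims }).optional(),
  })
);

// A pre-cropping ref image persisted by an older client:
const persisted = { image_name: 'abc.png', width: 512, height: 512 };
console.log(zCroppableImageWithDims.parse(persisted));
// -> { original: { image: { image_name: 'abc.png', width: 512, height: 512 } } }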
View File

@@ -10,6 +10,7 @@ import type {
ChatGPT4oReferenceImageConfig,
ControlLoRAConfig,
ControlNetConfig,
CroppableImageWithDims,
FluxKontextReferenceImageConfig,
FLUXReduxConfig,
Gemini2_5ReferenceImageConfig,
@@ -17,6 +18,7 @@ import type {
IPAdapterConfig,
RasterLayerAdjustments,
RefImageState,
RegionalGuidanceIPAdapterConfig,
RgbColor,
T2IAdapterConfig,
} from 'features/controlLayers/store/types';
@@ -45,6 +47,21 @@ export const imageDTOToImageWithDims = ({ image_name, width, height }: ImageDTO)
height,
});
export const imageDTOToCroppableImage = (
originalImageDTO: ImageDTO,
crop?: CroppableImageWithDims['crop']
): CroppableImageWithDims => {
const { image_name, width, height } = originalImageDTO;
const val: CroppableImageWithDims = {
original: { image: { image_name, width, height } },
};
if (crop) {
val.crop = deepClone(crop);
}
return val;
};
export const imageDTOToImageField = ({ image_name }: ImageDTO): ImageField => ({ image_name });
const DEFAULT_RG_MASK_FILL_COLORS: RgbColor[] = [
@@ -79,6 +96,15 @@ export const initialIPAdapter: IPAdapterConfig = {
clipVisionModel: 'ViT-H',
weight: 1,
};
export const initialRegionalGuidanceIPAdapter: RegionalGuidanceIPAdapterConfig = {
type: 'ip_adapter',
image: null,
model: null,
beginEndStepPct: [0, 1],
method: 'full',
clipVisionModel: 'ViT-H',
weight: 1,
};
export const initialFLUXRedux: FLUXReduxConfig = {
type: 'flux_redux',
image: null,

View File

@@ -0,0 +1,215 @@
import {
Button,
ButtonGroup,
Divider,
Flex,
FormControl,
FormLabel,
Select,
Spacer,
Text,
} from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import type { AspectRatioID } from 'features/controlLayers/store/types';
import { ASPECT_RATIO_MAP, isAspectRatioID } from 'features/controlLayers/store/types';
import type { CropBox } from 'features/cropper/lib/editor';
import { cropImageModalApi, type CropImageModalState } from 'features/cropper/store';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import React, { memo, useCallback, useEffect, useRef, useState } from 'react';
import { useUploadImageMutation } from 'services/api/endpoints/images';
import { objectEntries } from 'tsafe';
type Props = {
editor: CropImageModalState['editor'];
onApplyCrop: CropImageModalState['onApplyCrop'];
onReady: CropImageModalState['onReady'];
};
const getAspectRatioString = (ratio: number | null): AspectRatioID => {
if (!ratio) {
return 'Free';
}
const entries = objectEntries(ASPECT_RATIO_MAP);
for (const [key, value] of entries) {
if (value.ratio === ratio) {
return key;
}
}
return 'Free';
};
export const CropImageEditor = memo(({ editor, onApplyCrop, onReady }: Props) => {
const containerRef = useRef<HTMLDivElement>(null);
const [zoom, setZoom] = useState(100);
const [cropBox, setCropBox] = useState<CropBox | null>(null);
const [aspectRatio, setAspectRatio] = useState<string>('free');
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const [uploadImage] = useUploadImageMutation({ fixedCacheKey: 'editorContainer' });
const setup = useCallback(
async (container: HTMLDivElement) => {
editor.init(container);
editor.onZoomChange((zoom) => {
setZoom(zoom);
});
editor.onCropBoxChange((crop) => {
setCropBox(crop);
});
editor.onAspectRatioChange((ratio) => {
setAspectRatio(getAspectRatioString(ratio));
});
await onReady();
editor.fitToContainer();
},
[editor, onReady]
);
useEffect(() => {
const container = containerRef.current;
if (!container) {
return;
}
setup(container);
const handleResize = () => {
editor.resize(container.clientWidth, container.clientHeight);
};
const resizeObserver = new ResizeObserver(handleResize);
resizeObserver.observe(container);
return () => {
resizeObserver.disconnect();
};
}, [editor, setup]);
const handleAspectRatioChange = useCallback(
(e: React.ChangeEvent<HTMLSelectElement>) => {
const newRatio = e.target.value;
if (!isAspectRatioID(newRatio)) {
return;
}
setAspectRatio(newRatio);
if (newRatio === 'Free') {
editor.setCropAspectRatio(null);
} else {
editor.setCropAspectRatio(ASPECT_RATIO_MAP[newRatio]?.ratio ?? null);
}
},
[editor]
);
const handleResetCrop = useCallback(() => {
editor.resetCrop();
}, [editor]);
const handleApplyCrop = useCallback(async () => {
await onApplyCrop();
cropImageModalApi.close();
}, [onApplyCrop]);
const handleCancelCrop = useCallback(() => {
cropImageModalApi.close();
}, []);
const handleExport = useCallback(async () => {
try {
const blob = await editor.exportImage('blob');
const file = new File([blob], 'image.png', { type: 'image/png' });
await uploadImage({
file,
is_intermediate: false,
image_category: 'user',
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
}).unwrap();
} catch (err) {
if (err instanceof Error && err.message.includes('tainted')) {
alert(
'Cannot export image: The image is from a different domain (CORS issue). To fix this:\n\n1. Load images from the same domain\n2. Use images from CORS-enabled sources\n3. Upload a local image file instead'
);
} else {
alert(`Export failed: ${err instanceof Error ? err.message : String(err)}`);
}
}
}, [autoAddBoardId, editor, uploadImage]);
const zoomIn = useCallback(() => {
editor.zoomIn();
}, [editor]);
const zoomOut = useCallback(() => {
editor.zoomOut();
}, [editor]);
const fitToContainer = useCallback(() => {
editor.fitToContainer();
}, [editor]);
const resetView = useCallback(() => {
editor.resetView();
}, [editor]);
return (
<Flex w="full" h="full" flexDir="column" gap={4}>
<Flex gap={2} alignItems="center">
<FormControl flex={1}>
<FormLabel>Aspect Ratio:</FormLabel>
<Select size="sm" value={aspectRatio} onChange={handleAspectRatioChange} w={32}>
<option value="Free">Free</option>
<option value="16:9">16:9</option>
<option value="3:2">3:2</option>
<option value="4:3">4:3</option>
<option value="1:1">1:1</option>
<option value="3:4">3:4</option>
<option value="2:3">2:3</option>
<option value="9:16">9:16</option>
</Select>
</FormControl>
<Spacer />
<ButtonGroup size="sm" isAttached={false}>
<Button onClick={fitToContainer}>Fit View</Button>
<Button onClick={resetView}>Reset View</Button>
<Button onClick={zoomIn}>Zoom In</Button>
<Button onClick={zoomOut}>Zoom Out</Button>
</ButtonGroup>
<Spacer />
<ButtonGroup size="sm" isAttached={false}>
<Button onClick={handleApplyCrop}>Apply</Button>
<Button onClick={handleResetCrop}>Reset</Button>
<Button onClick={handleCancelCrop}>Cancel</Button>
<Button onClick={handleExport}>Save to Assets</Button>
</ButtonGroup>
</Flex>
<Flex position="relative" w="full" h="full" bg="base.900">
<Flex position="absolute" inset={0} ref={containerRef} />
</Flex>
<Flex gap={2} color="base.300">
<Text>Mouse wheel: Zoom</Text>
<Divider orientation="vertical" />
<Text>Space + Drag: Pan</Text>
<Divider orientation="vertical" />
<Text>Drag crop box or handles to adjust</Text>
{cropBox && (
<>
<Divider orientation="vertical" />
<Text>
X: {Math.round(cropBox.x)}, Y: {Math.round(cropBox.y)}, Width: {Math.round(cropBox.width)}, Height:{' '}
{Math.round(cropBox.height)}
</Text>
</>
)}
<Spacer key="help-spacer" />
<Text key="help-zoom">Zoom: {Math.round(zoom * 100)}%</Text>
</Flex>
</Flex>
);
});
CropImageEditor.displayName = 'CropImageEditor';

View File

@@ -0,0 +1,29 @@
import { Modal, ModalBody, ModalContent, ModalHeader, ModalOverlay } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { cropImageModalApi } from 'features/cropper/store';
import { memo } from 'react';
import { CropImageEditor } from './CropImageEditor';
export const CropImageModal = memo(() => {
const state = useStore(cropImageModalApi.$state);
if (!state) {
return null;
}
return (
// This modal is always open when this component is rendered
<Modal isOpen={true} onClose={cropImageModalApi.close} isCentered useInert={false} size="full">
<ModalOverlay />
<ModalContent minH="unset" minW="unset" maxH="90vh" maxW="90vw" w="full" h="full" borderRadius="base">
<ModalHeader>Crop Image</ModalHeader>
<ModalBody px={4} pb={4} pt={0}>
<CropImageEditor editor={state.editor} onApplyCrop={state.onApplyCrop} onReady={state.onReady} />
</ModalBody>
</ModalContent>
</Modal>
);
});
CropImageModal.displayName = 'CropImageModal';

File diff suppressed because it is too large

View File

@@ -0,0 +1,26 @@
import type { Editor } from 'features/cropper/lib/editor';
import { atom } from 'nanostores';
export type CropImageModalState = {
editor: Editor;
onApplyCrop: () => Promise<void> | void;
onReady: () => Promise<void> | void;
};
const $state = atom<CropImageModalState | null>(null);
const open = (state: CropImageModalState) => {
$state.set(state);
};
const close = () => {
const state = $state.get();
state?.editor.destroy();
$state.set(null);
};
export const cropImageModalApi = {
$state,
open,
close,
};

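For reference, a caller-side sketch of this store's lifecycle, mirroring the RefImageImage edit callback shown earlier (the image URL is a placeholder):

import { Editor } from 'features/cropper/lib/editor';
import { cropImageModalApi } from 'features/cropper/store';

// A fresh Editor is created for each edit session.
const editor = new Editor();

cropImageModalApi.open({
  editor,
  onReady: async () => {
    // Load the source image once the modal has mounted the editor container.
    await editor.loadImage('https://example.com/source.png'); // placeholder URL
  },
  onApplyCrop: async () => {
    // The real callback uploads the cropped blob; here we just read the final crop box.
    console.log('applied crop', editor.getCropBox());
  },
});

// Apply, Cancel, and the modal's onClose all route through close(), which
// destroys the editor and resets $state to null.
cropImageModalApi.close();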
View File

@@ -236,8 +236,11 @@ const deleteControlLayerImages = (state: RootState, dispatch: AppDispatch, image
const deleteReferenceImages = (state: RootState, dispatch: AppDispatch, image_name: string) => {
selectReferenceImageEntities(state).forEach((entity) => {
if (entity.config.image?.image_name === image_name) {
dispatch(refImageImageChanged({ id: entity.id, imageDTO: null }));
if (
entity.config.image?.original.image.image_name === image_name ||
entity.config.image?.crop?.image.image_name === image_name
) {
dispatch(refImageImageChanged({ id: entity.id, croppableImage: null }));
}
});
};
@@ -284,7 +287,10 @@ export const getImageUsage = (
const isUpscaleImage = upscale.upscaleInitialImage?.image_name === image_name;
const isReferenceImage = refImages.entities.some(({ config }) => config.image?.image_name === image_name);
const isReferenceImage = refImages.entities.some(
({ config }) =>
config.image?.original.image.image_name === image_name || config.image?.crop?.image.image_name === image_name
);
const isRasterLayerImage = canvas.rasterLayers.entities.some(({ objects }) =>
objects.some((obj) => obj.type === 'image' && 'image_name' in obj.image && obj.image.image_name === image_name)

View File

@@ -3,7 +3,7 @@ import { IconButton } from '@invoke-ai/ui-library';
import type { MouseEvent } from 'react';
import { memo } from 'react';
const sx: SystemStyleObject = {
export const imageButtonSx: SystemStyleObject = {
minW: 0,
svg: {
transitionProperty: 'common',
@@ -31,7 +31,7 @@ export const DndImageIcon = memo((props: Props) => {
aria-label={tooltip}
icon={icon}
variant="link"
sx={sx}
sx={imageButtonSx}
data-testid={tooltip}
{...rest}
/>

View File

@@ -4,7 +4,7 @@ import { getDefaultRefImageConfig } from 'features/controlLayers/hooks/addLayerH
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { refImageAdded } from 'features/controlLayers/store/refImagesSlice';
import type { CanvasEntityIdentifier, CanvasEntityType } from 'features/controlLayers/store/types';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { imageDTOToCroppableImage } from 'features/controlLayers/store/util';
import { selectComparisonImages } from 'features/gallery/components/ImageViewer/common';
import type { BoardId } from 'features/gallery/store/types';
import {
@@ -211,7 +211,7 @@ export const addGlobalReferenceImageDndTarget: DndTarget<
handler: ({ sourceData, dispatch, getState }) => {
const { imageDTO } = sourceData.payload;
const config = getDefaultRefImageConfig(getState);
config.image = imageDTOToImageWithDims(imageDTO);
config.image = imageDTOToCroppableImage(imageDTO);
dispatch(refImageAdded({ overrides: { config } }));
},
};
@@ -641,7 +641,7 @@ export const videoFrameFromImageDndTarget: DndTarget<VideoFrameFromImageDndTarge
},
handler: ({ sourceData, dispatch }) => {
const { imageDTO } = sourceData.payload;
dispatch(startingFrameImageChanged(imageDTOToImageWithDims(imageDTO)));
dispatch(startingFrameImageChanged(imageDTOToCroppableImage(imageDTO)));
},
};
//#endregion

View File

@@ -1,4 +1,5 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { imageDTOToCroppableImage } from 'features/controlLayers/store/util';
import { useItemDTOContextImageOnly } from 'features/gallery/contexts/ItemDTOContext';
import { startingFrameImageChanged } from 'features/parameters/store/videoSlice';
import { navigationApi } from 'features/ui/layouts/navigation-api';
@@ -13,7 +14,7 @@ export const ContextMenuItemSendToVideo = memo(() => {
const dispatch = useDispatch();
const onClick = useCallback(() => {
dispatch(startingFrameImageChanged(imageDTO));
dispatch(startingFrameImageChanged(imageDTOToCroppableImage(imageDTO)));
navigationApi.switchToTab('video');
}, [imageDTO, dispatch]);

View File

@@ -2,7 +2,7 @@ import { MenuItem } from '@invoke-ai/ui-library';
import { useAppStore } from 'app/store/storeHooks';
import { getDefaultRefImageConfig } from 'features/controlLayers/hooks/addLayerHooks';
import { refImageAdded } from 'features/controlLayers/store/refImagesSlice';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { imageDTOToCroppableImage } from 'features/controlLayers/store/util';
import { useItemDTOContextImageOnly } from 'features/gallery/contexts/ItemDTOContext';
import { toast } from 'features/toast/toast';
import { memo, useCallback } from 'react';
@@ -17,7 +17,7 @@ export const ContextMenuItemUseAsRefImage = memo(() => {
const onClickNewGlobalReferenceImageFromImage = useCallback(() => {
const { dispatch, getState } = store;
const config = getDefaultRefImageConfig(getState);
config.image = imageDTOToImageWithDims(imageDTO);
config.image = imageDTOToCroppableImage(imageDTO);
dispatch(refImageAdded({ overrides: { config } }));
toast({
id: 'SENT_TO_CANVAS',

View File

@@ -26,7 +26,12 @@ import type {
CanvasRasterLayerState,
CanvasRegionalGuidanceState,
} from 'features/controlLayers/store/types';
import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } from 'features/controlLayers/store/util';
import {
imageDTOToCroppableImage,
imageDTOToImageObject,
imageDTOToImageWithDims,
initialControlNet,
} from 'features/controlLayers/store/util';
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import type { BoardId } from 'features/gallery/store/types';
@@ -44,7 +49,7 @@ import { assert } from 'tsafe';
export const setGlobalReferenceImage = (arg: { imageDTO: ImageDTO; id: string; dispatch: AppDispatch }) => {
const { imageDTO, id, dispatch } = arg;
dispatch(refImageImageChanged({ id, imageDTO }));
dispatch(refImageImageChanged({ id, croppableImage: imageDTOToCroppableImage(imageDTO) }));
};
export const setRegionalGuidanceReferenceImage = (arg: {

View File

@@ -975,7 +975,7 @@ const RefImages: CollectionMetadataHandler<RefImageState[]> = {
for (const refImage of parsed) {
if (refImage.config.image) {
await throwIfImageDoesNotExist(refImage.config.image.image_name, store);
await throwIfImageDoesNotExist(refImage.config.image.original.image.image_name, store);
}
if (refImage.config.model) {
await throwIfModelDoesNotExist(refImage.config.model.key, store);

View File

@@ -35,7 +35,7 @@ export const LaunchpadForm = memo(() => {
return (
<Flex flexDir="column" height="100%" gap={3}>
<ScrollableContent>
<Flex flexDir="column" gap={6} p={3}>
<Flex flexDir="column" gap={6} py={2}>
{/* Welcome Section */}
<Flex flexDir="column" gap={2} alignItems="flex-start">
<Heading size="md">{t('modelManager.launchpad.welcome')}</Heading>

View File

@@ -0,0 +1,45 @@
import { Badge, Button, Flex } from '@invoke-ai/ui-library';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCheckBold, PiPlusBold } from 'react-icons/pi';
type Props = {
handleInstall: () => void;
isInstalled: boolean;
};
export const ModelResultItemActions = memo(({ handleInstall, isInstalled }: Props) => {
const { t } = useTranslation();
return (
<Flex gap={2} shrink={0} pt={1}>
{isInstalled ? (
// TODO: Add a link button to navigate to model
<Badge
variant="subtle"
colorScheme="green"
display="flex"
gap={1}
alignItems="center"
borderRadius="base"
h="24px"
>
<PiCheckBold size="14px" />
</Badge>
) : (
<Button
onClick={handleInstall}
rightIcon={<PiPlusBold size="14px" />}
textTransform="uppercase"
letterSpacing="wider"
fontSize="9px"
size="sm"
>
{t('modelManager.install')}
</Button>
)}
</Flex>
);
});
ModelResultItemActions.displayName = 'ModelResultItemActions';

View File

@@ -1,33 +1,56 @@
import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, Text } from '@invoke-ai/ui-library';
import { ModelResultItemActions } from 'features/modelManagerV2/subpanels/AddModelPanel/ModelResultItemActions';
import { memo, useCallback, useMemo } from 'react';
import type { ScanFolderResponse } from 'services/api/endpoints/models';
type Props = {
result: ScanFolderResponse[number];
installModel: (source: string) => void;
};
export const ScanModelResultItem = memo(({ result, installModel }: Props) => {
const { t } = useTranslation();
const scanFolderResultItemSx: SystemStyleObject = {
alignItems: 'center',
justifyContent: 'space-between',
w: '100%',
py: 2,
px: 1,
gap: 3,
borderBottomWidth: '1px',
borderColor: 'base.700',
};
export const ScanModelResultItem = memo(({ result, installModel }: Props) => {
const handleInstall = useCallback(() => {
installModel(result.path);
}, [installModel, result]);
const modelDisplayName = useMemo(() => {
const normalizedPath = result.path.replace(/\\/g, '/').replace(/\/+$/, '');
// Extract filename/folder name from path
const lastSlashIndex = normalizedPath.lastIndexOf('/');
return lastSlashIndex === -1 ? normalizedPath : normalizedPath.slice(lastSlashIndex + 1);
}, [result.path]);
const modelPathParts = result.path.split(/[/\\]/);
return (
<Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}>
<Flex sx={scanFolderResultItemSx}>
<Flex fontSize="sm" flexDir="column">
<Text fontWeight="semibold">{result.path.split('\\').slice(-1)[0]}</Text>
<Text variant="subtext">{result.path}</Text>
{/* Model Title */}
<Text fontWeight="semibold">{modelDisplayName}</Text>
{/* Model Path */}
<Flex flexWrap="wrap" color="base.200">
{modelPathParts.map((part, index) => (
<Text key={index} variant="subtext">
{part}
{index < modelPathParts.length - 1 && '/'}
</Text>
))}
</Flex>
</Flex>
<Box>
{result.is_installed ? (
<Badge>{t('common.installed')}</Badge>
) : (
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={handleInstall} size="sm" />
)}
</Box>
<ModelResultItemActions handleInstall={handleInstall} isInstalled={result.is_installed} />
</Flex>
);
});

View File

@@ -113,9 +113,9 @@ export const ScanModelsResults = memo(({ results }: ScanModelResultsProps) => {
</InputGroup>
</Flex>
</Flex>
<Flex height="100%" layerStyle="third" borderRadius="base" p={3}>
<Flex height="100%" layerStyle="second" borderRadius="base" px={2}>
<ScrollableContent>
<Flex flexDir="column" gap={3}>
<Flex flexDir="column">
{filteredResults.map((result) => (
<ScanModelResultItem key={result.path} result={result} installModel={handleInstallOne} />
))}

View File

@@ -13,6 +13,7 @@ import { useStarterBundleInstallStatus } from 'features/modelManagerV2/hooks/use
import { t } from 'i18next';
import type { MouseEvent } from 'react';
import { useCallback } from 'react';
import { PiDownloadSimpleBold } from 'react-icons/pi';
import type { S } from 'services/api/types';
export const StarterBundleButton = ({ bundle, ...rest }: { bundle: S['StarterModelBundle'] } & ButtonProps) => {
@@ -33,8 +34,16 @@ export const StarterBundleButton = ({ bundle, ...rest }: { bundle: S['StarterMod
return (
<>
<Button onClick={onClickBundle} isDisabled={install.length === 0} {...rest}>
{bundle.name}
<Button
display="flex"
justifyContent="space-between"
gap={2}
onClick={onClickBundle}
isDisabled={install.length === 0}
{...rest}
>
<span>{bundle.name}</span>
<PiDownloadSimpleBold size="16px" />
</Button>
<ConfirmationAlertDialog
isOpen={isOpen}

View File

@@ -1,17 +1,30 @@
import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Badge, Flex, Text } from '@invoke-ai/ui-library';
import { negate } from 'es-toolkit/compat';
import { flattenStarterModel, useBuildModelInstallArg } from 'features/modelManagerV2/hooks/useBuildModelsToInstall';
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
import { ModelResultItemActions } from 'features/modelManagerV2/subpanels/AddModelPanel/ModelResultItemActions';
import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
import { toast } from 'features/toast/toast';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
import type { StarterModel } from 'services/api/types';
const starterModelResultItemSx: SystemStyleObject = {
alignItems: 'start',
justifyContent: 'space-between',
w: '100%',
py: 2,
px: 1,
gap: 2,
borderBottomWidth: '1px',
borderColor: 'base.700',
};
type Props = {
starterModel: StarterModel;
};
export const StarterModelsResultItem = memo(({ starterModel }: Props) => {
const { t } = useTranslation();
const { getIsInstalled, buildModelInstallArg } = useBuildModelInstallArg();
@@ -40,22 +53,16 @@ export const StarterModelsResultItem = memo(({ starterModel }: Props) => {
}, [modelsToInstall, installModel, t]);
return (
<Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}>
<Flex sx={starterModelResultItemSx}>
<Flex fontSize="sm" flexDir="column">
<Flex gap={3}>
<Text fontWeight="semibold">{starterModel.name}</Text>
<Text variant="subtext">{starterModel.description}</Text>
<Flex gap={1} py={1} alignItems="center">
<Badge h="min-content">{starterModel.type.replaceAll('_', ' ')}</Badge>
<ModelBaseBadge base={starterModel.base} />
<Text fontWeight="semibold">{starterModel.name}</Text>
</Flex>
<Text variant="subtext">{starterModel.description}</Text>
</Flex>
<Box>
{isInstalled ? (
<Badge>{t('common.installed')}</Badge>
) : (
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={onClick} size="sm" />
)}
</Box>
<ModelResultItemActions handleInstall={onClick} isInstalled={isInstalled} />
</Flex>
);
});

View File

@@ -48,9 +48,9 @@ export const StarterModelsResults = memo(({ results }: StarterModelsResultsProps
return (
<Flex flexDir="column" gap={3} height="100%">
<Flex justifyContent="space-between" alignItems="center">
<Flex gap={3} direction="column">
{size(results.starter_bundles) > 0 && (
<Flex gap={4} alignItems="center">
<Flex gap={4} alignItems="center" justifyContent="space-between" p={4} borderWidth="1px" rounded="base">
<Flex gap={2} alignItems="center">
<Text color="base.200" fontWeight="semibold">
{t('modelManager.starterBundles')}
@@ -73,7 +73,8 @@ export const StarterModelsResults = memo(({ results }: StarterModelsResultsProps
</Flex>
</Flex>
)}
<InputGroup w={64} size="xs">
<InputGroup w="100%" size="xs">
<Input
placeholder={t('modelManager.search')}
value={searchTerm}
@@ -96,9 +97,10 @@ export const StarterModelsResults = memo(({ results }: StarterModelsResultsProps
)}
</InputGroup>
</Flex>
<Flex height="100%" layerStyle="third" borderRadius="base" p={3}>
<Flex height="100%" layerStyle="second" borderRadius="base" px={2}>
<ScrollableContent>
<Flex flexDir="column" gap={3}>
<Flex flexDir="column">
{filteredResults.map((result) => (
<StarterModelsResultItem key={result.source} starterModel={result} />
))}

View File

@@ -1,10 +1,12 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Button, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { $installModelsTabIndex } from 'features/modelManagerV2/store/installModelsStore';
import { StarterModelsForm } from 'features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterModelsForm';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiInfoBold } from 'react-icons/pi';
import { PiCubeBold, PiFolderOpenBold, PiInfoBold, PiLinkSimpleBold, PiShootingStarBold } from 'react-icons/pi';
import { SiHuggingface } from 'react-icons/si';
import { HuggingFaceForm } from './AddModelPanel/HuggingFaceFolder/HuggingFaceForm';
import { InstallModelForm } from './AddModelPanel/InstallModelForm';
@@ -12,6 +14,12 @@ import { LaunchpadForm } from './AddModelPanel/LaunchpadForm/LaunchpadForm';
import { ModelInstallQueue } from './AddModelPanel/ModelInstallQueue/ModelInstallQueue';
import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
const installModelsTabSx: SystemStyleObject = {
display: 'flex',
gap: 2,
px: 2,
};
export const InstallModels = memo(() => {
const { t } = useTranslation();
const tabIndex = useStore($installModelsTabIndex);
@@ -29,21 +37,36 @@ export const InstallModels = memo(() => {
</Button>
</Flex>
<Tabs
variant="collapse"
height="50%"
variant="line"
height="100%"
display="flex"
flexDir="column"
index={tabIndex}
onChange={$installModelsTabIndex.set}
>
<TabList>
<Tab>{t('modelManager.launchpadTab')}</Tab>
<Tab>{t('modelManager.urlOrLocalPath')}</Tab>
<Tab>{t('modelManager.huggingFace')}</Tab>
<Tab>{t('modelManager.scanFolder')}</Tab>
<Tab>{t('modelManager.starterModels')}</Tab>
<Tab sx={installModelsTabSx}>
<PiCubeBold />
{t('modelManager.launchpadTab')}
</Tab>
<Tab sx={installModelsTabSx}>
<PiLinkSimpleBold />
{t('modelManager.urlOrLocalPath')}
</Tab>
<Tab sx={installModelsTabSx}>
<SiHuggingface />
{t('modelManager.huggingFace')}
</Tab>
<Tab sx={installModelsTabSx}>
<PiFolderOpenBold />
{t('modelManager.scanFolder')}
</Tab>
<Tab sx={installModelsTabSx}>
<PiShootingStarBold />
{t('modelManager.starterModels')}
</Tab>
</TabList>
<TabPanels p={3} height="100%">
<TabPanels height="100%">
<TabPanel height="100%">
<LaunchpadForm />
</TabPanel>

View File

@@ -1,3 +1,4 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Button, Flex, Heading } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectSelectedModelKey, setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
@@ -8,6 +9,16 @@ import { PiPlusBold } from 'react-icons/pi';
import ModelList from './ModelManagerPanel/ModelList';
import { ModelListNavigation } from './ModelManagerPanel/ModelListNavigation';
const modelManagerSx: SystemStyleObject = {
flexDir: 'column',
p: 4,
gap: 4,
borderRadius: 'base',
w: '50%',
minWidth: '360px',
h: 'full',
};
export const ModelManager = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
@@ -17,7 +28,7 @@ export const ModelManager = memo(() => {
const selectedModelKey = useAppSelector(selectSelectedModelKey);
return (
<Flex flexDir="column" layerStyle="first" p={4} gap={4} borderRadius="base" w="50%" h="full">
<Flex sx={modelManagerSx}>
<Flex w="full" gap={4} justifyContent="space-between" alignItems="center">
<Heading fontSize="xl" py={1}>
{t('common.modelManager')}
@@ -28,7 +39,7 @@ export const ModelManager = memo(() => {
</Button>
)}
</Flex>
<Flex flexDir="column" layerStyle="second" p={4} gap={4} borderRadius="base" w="full" h="full">
<Flex flexDir="column" gap={4} w="full" h="full">
<ModelListNavigation />
<ModelList />
</Flex>

View File

@@ -1,3 +1,4 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, Icon, Image } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { PiImage } from 'react-icons/pi';
@@ -6,19 +7,23 @@ type Props = {
image_url?: string | null;
};
export const MODEL_IMAGE_THUMBNAIL_SIZE = '40px';
const FALLBACK_ICON_SIZE = '24px';
const MODEL_IMAGE_THUMBNAIL_SIZE = '54px';
const FALLBACK_ICON_SIZE = '28px';
const sharedSx: SystemStyleObject = {
rounded: 'base',
height: MODEL_IMAGE_THUMBNAIL_SIZE,
minWidth: MODEL_IMAGE_THUMBNAIL_SIZE,
bg: 'base.850',
borderWidth: '1px',
borderColor: 'base.750',
borderStyle: 'solid',
};
const ModelImage = ({ image_url }: Props) => {
if (!image_url) {
return (
<Flex
height={MODEL_IMAGE_THUMBNAIL_SIZE}
minWidth={MODEL_IMAGE_THUMBNAIL_SIZE}
borderRadius="base"
alignItems="center"
justifyContent="center"
>
<Flex alignItems="center" justifyContent="center" sx={sharedSx}>
<Icon color="base.500" as={PiImage} boxSize={FALLBACK_ICON_SIZE} />
</Flex>
);
@@ -29,16 +34,14 @@ const ModelImage = ({ image_url }: Props) => {
src={image_url}
objectFit="cover"
objectPosition="50% 50%"
height={MODEL_IMAGE_THUMBNAIL_SIZE}
width={MODEL_IMAGE_THUMBNAIL_SIZE}
minHeight={MODEL_IMAGE_THUMBNAIL_SIZE}
minWidth={MODEL_IMAGE_THUMBNAIL_SIZE}
borderRadius="base"
sx={sharedSx}
fallback={
<Flex
height={MODEL_IMAGE_THUMBNAIL_SIZE}
minWidth={MODEL_IMAGE_THUMBNAIL_SIZE}
borderRadius="base"
sx={sharedSx}
alignItems="center"
justifyContent="center"
>

View File

@@ -1,32 +1,57 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { ConfirmationAlertDialog, Flex, IconButton, Spacer, Text, useDisclosure } from '@invoke-ai/ui-library';
import { Flex, Spacer, Text } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectModelManagerV2Slice, setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
import ModelFormatBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge';
import { toast } from 'features/toast/toast';
import { ModelDeleteButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelDeleteButton';
import { filesize } from 'filesize';
import type { MouseEvent } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold } from 'react-icons/pi';
import { useDeleteModelsMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import ModelImage, { MODEL_IMAGE_THUMBNAIL_SIZE } from './ModelImage';
import ModelImage from './ModelImage';
type ModelListItemProps = {
model: AnyModelConfig;
};
const sx: SystemStyleObject = {
_hover: { bg: 'base.700' },
"&[aria-selected='true']": { bg: 'base.700' },
paddingInline: 3,
paddingBlock: 2,
position: 'relative',
rounded: 'base',
'&:after,&:before': {
content: `''`,
position: 'absolute',
pointerEvents: 'none',
},
'&:after': {
h: '1px',
bottom: '-0.5px',
insetInline: 3,
bg: 'base.850',
},
'&:before': {
left: 1,
w: 1,
insetBlock: 2,
rounded: 'base',
},
_hover: {
bg: 'base.850',
'& .delete-button': { opacity: 1 },
},
'& .delete-button': { opacity: 0 },
"&[aria-selected='false']:hover:before": { bg: 'base.750' },
"&[aria-selected='true']": {
bg: 'base.800',
'& .delete-button': { opacity: 1 },
},
"&[aria-selected='true']:before": { bg: 'invokeBlue.300' },
};
const ModelListItem = ({ model }: ModelListItemProps) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const selectIsSelected = useMemo(
() =>
@@ -37,58 +62,25 @@ const ModelListItem = ({ model }: ModelListItemProps) => {
[model.key]
);
const isSelected = useAppSelector(selectIsSelected);
const [deleteModel] = useDeleteModelsMutation();
const { isOpen, onOpen, onClose } = useDisclosure();
const handleSelectModel = useCallback(() => {
dispatch(setSelectedModelKey(model.key));
}, [model.key, dispatch]);
const onClickDeleteButton = useCallback(
(e: MouseEvent<HTMLButtonElement>) => {
e.stopPropagation();
onOpen();
},
[onOpen]
);
const handleModelDelete = useCallback(() => {
deleteModel({ key: model.key })
.unwrap()
.then((_) => {
toast({
id: 'MODEL_DELETED',
title: `${t('modelManager.modelDeleted')}: ${model.name}`,
status: 'success',
});
})
.catch((error) => {
if (error) {
toast({
id: 'MODEL_DELETE_FAILED',
title: `${t('modelManager.modelDeleteFailed')}: ${model.name}`,
status: 'error',
});
}
});
dispatch(setSelectedModelKey(null));
}, [deleteModel, model, dispatch, t]);
return (
<Flex
sx={sx}
aria-selected={isSelected}
justifyContent="flex-start"
p={2}
borderRadius="base"
w="full"
alignItems="center"
alignItems="flex-start"
gap={2}
cursor="pointer"
onClick={handleSelectModel}
>
<Flex gap={2} w="full" h="full" minW={0}>
<ModelImage image_url={model.cover_image} />
<Flex gap={1} alignItems="flex-start" flexDir="column" w="full" minW={0}>
<Flex alignItems="flex-start" flexDir="column" w="full" minW={0}>
<Flex gap={2} w="full" alignItems="flex-start">
<Text fontWeight="semibold" noOfLines={1} wordBreak="break-all">
{model.name}
@@ -101,39 +93,15 @@ const ModelListItem = ({ model }: ModelListItemProps) => {
<Text variant="subtext" noOfLines={1}>
{model.description || 'No Description'}
</Text>
</Flex>
<Flex
h={MODEL_IMAGE_THUMBNAIL_SIZE}
flexDir="column"
alignItems="flex-end"
justifyContent="space-between"
gap={2}
>
<ModelBaseBadge base={model.base} />
<ModelFormatBadge format={model.format} />
<Flex gap={1} mt={1}>
<ModelBaseBadge base={model.base} />
<ModelFormatBadge format={model.format} />
</Flex>
</Flex>
</Flex>
<IconButton
onClick={onClickDeleteButton}
icon={<PiTrashSimpleBold size={16} />}
aria-label={t('modelManager.deleteConfig')}
colorScheme="error"
h={MODEL_IMAGE_THUMBNAIL_SIZE}
w={MODEL_IMAGE_THUMBNAIL_SIZE}
/>
<ConfirmationAlertDialog
isOpen={isOpen}
onClose={onClose}
title={t('modelManager.deleteModel')}
acceptCallback={handleModelDelete}
acceptButtonText={t('modelManager.delete')}
useInert={false}
>
<Flex rowGap={4} flexDirection="column">
<Text fontWeight="bold">{t('modelManager.deleteMsg1')}</Text>
<Text>{t('modelManager.deleteMsg2')}</Text>
</Flex>
</ConfirmationAlertDialog>
<Flex mt={1}>
<ModelDeleteButton modelConfig={model} showLabel={false} />
</Flex>
</Flex>
);
};

View File

@@ -1,4 +1,4 @@
import { Flex, IconButton, Input, InputGroup, InputRightElement, Spacer } from '@invoke-ai/ui-library';
import { Flex, IconButton, Input, InputGroup, InputRightElement } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectSearchTerm, setSearchTerm } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { t } from 'i18next';
@@ -25,9 +25,7 @@ export const ModelListNavigation = memo(() => {
return (
<Flex gap={2} alignItems="center" justifyContent="space-between">
<ModelTypeFilter />
<Spacer />
<InputGroup maxW="400px">
<InputGroup>
<Input
placeholder={t('modelManager.search')}
value={searchTerm || ''}
@@ -47,6 +45,9 @@ export const ModelListNavigation = memo(() => {
</InputRightElement>
)}
</InputGroup>
<Flex shrink={0}>
<ModelTypeFilter />
</Flex>
</Flex>
);
});

View File

@@ -1,3 +1,4 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { StickyScrollable } from 'features/system/components/StickyScrollable';
import { memo } from 'react';
import type { AnyModelConfig } from 'services/api/types';
@@ -9,10 +10,23 @@ type ModelListWrapperProps = {
modelList: AnyModelConfig[];
};
const headingSx = {
bg: 'base.900',
pb: 3,
pl: 3,
} satisfies SystemStyleObject;
const contentSx = {
gap: 0,
p: 0,
bg: 'base.900',
borderRadius: '0',
} satisfies SystemStyleObject;
export const ModelListWrapper = memo((props: ModelListWrapperProps) => {
const { title, modelList } = props;
return (
<StickyScrollable title={title} contentSx={{ gap: 1, p: 2 }}>
<StickyScrollable title={title} contentSx={contentSx} headingSx={headingSx}>
{modelList.map((model) => (
<ModelListItem key={model.key} model={model} />
))}

View File

@@ -1,3 +1,4 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { selectSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
@@ -6,13 +7,22 @@ import { memo } from 'react';
import { InstallModels } from './InstallModels';
import { Model } from './ModelPanel/Model';
const modelPaneSx: SystemStyleObject = {
layerStyle: 'first',
p: 4,
borderRadius: 'base',
w: {
base: '50%',
lg: '75%',
'2xl': '85%',
},
h: 'full',
minWidth: '300px',
};
export const ModelPane = memo(() => {
const selectedModelKey = useAppSelector(selectSelectedModelKey);
return (
<Box layerStyle="first" p={4} borderRadius="base" w="50%" h="full">
{selectedModelKey ? <Model key={selectedModelKey} /> : <InstallModels />}
</Box>
);
return <Box sx={modelPaneSx}>{selectedModelKey ? <Model key={selectedModelKey} /> : <InstallModels />}</Box>;
});
ModelPane.displayName = 'ModelPane';

View File

@@ -1,3 +1,4 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, IconButton, Image } from '@invoke-ai/ui-library';
import { dropzoneAccept } from 'common/hooks/useImageUploadButton';
import { typedMemo } from 'common/util/typedMemo';
@@ -8,6 +9,21 @@ import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold, PiUploadBold } from 'react-icons/pi';
import { useDeleteModelImageMutation, useUpdateModelImageMutation } from 'services/api/endpoints/models';
const sharedSx: SystemStyleObject = {
w: 108,
h: 108,
fontSize: 36,
borderRadius: 'base',
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
bg: 'base.800',
borderWidth: '1px',
borderStyle: 'solid',
borderColor: 'base.700',
flexShrink: 0,
};
type Props = {
model_key: string | null;
model_image?: string | null;
@@ -86,10 +102,9 @@ const ModelImageUpload = ({ model_key, model_image }: Props) => {
src={image}
objectFit="cover"
objectPosition="50% 50%"
height={108}
width={108}
minWidth={108}
borderRadius="base"
sx={sharedSx}
/>
<IconButton
position="absolute"
@@ -112,10 +127,9 @@ const ModelImageUpload = ({ model_key, model_image }: Props) => {
variant="ghost"
aria-label={t('modelManager.uploadImage')}
tooltip={t('modelManager.uploadImage')}
w={108}
h={108}
fontSize={36}
icon={<PiUploadBold />}
sx={sharedSx}
isLoading={request.isLoading}
{...getRootProps()}
/>

View File

@@ -52,6 +52,7 @@ export const ModelConvertButton = memo(({ modelConfig }: ModelConvertProps) => {
return (
<>
<Button
variant="outline"
onClick={onOpen}
size="sm"
aria-label={t('modelManager.convertToDiffusers')}

View File

@@ -0,0 +1,95 @@
import { Button, ConfirmationAlertDialog, Flex, IconButton, Text, useDisclosure } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger';
import { useAppDispatch } from 'app/store/storeHooks';
import { setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { toast } from 'features/toast/toast';
import { memo, type MouseEvent, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold } from 'react-icons/pi';
import { useDeleteModelsMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
type Props = {
showLabel?: boolean;
modelConfig: AnyModelConfig;
};
export const ModelDeleteButton = memo(({ showLabel = true, modelConfig }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const log = logger('models');
const [deleteModel] = useDeleteModelsMutation();
const { isOpen, onOpen, onClose } = useDisclosure();
const onClickDeleteButton = useCallback(
(e: MouseEvent<HTMLButtonElement>) => {
e.stopPropagation();
onOpen();
},
[onOpen]
);
const handleModelDelete = useCallback(() => {
deleteModel({ key: modelConfig.key })
.unwrap()
.then(() => {
dispatch(setSelectedModelKey(null));
toast({
id: 'MODEL_DELETED',
title: `${t('modelManager.modelDeleted')}: ${modelConfig.name}`,
status: 'success',
});
})
.catch((error) => {
log.error('Error deleting model', error);
toast({
id: 'MODEL_DELETE_FAILED',
title: `${t('modelManager.modelDeleteFailed')}: ${modelConfig.name}`,
status: 'error',
});
});
}, [deleteModel, modelConfig.key, modelConfig.name, dispatch, t, log]);
return (
<>
{showLabel ? (
<Button
className="delete-button"
size="sm"
leftIcon={<PiTrashSimpleBold />}
colorScheme="error"
onClick={onClickDeleteButton}
flexShrink={0}
>
{t('modelManager.delete')}
</Button>
) : (
<IconButton
className="delete-button"
onClick={onClickDeleteButton}
icon={<PiTrashSimpleBold size={16} />}
aria-label={t('modelManager.deleteConfig')}
colorScheme="error"
/>
)}
<ConfirmationAlertDialog
isOpen={isOpen}
onClose={onClose}
title={t('modelManager.deleteModel')}
acceptCallback={handleModelDelete}
acceptButtonText={t('modelManager.delete')}
useInert={false}
>
<Flex rowGap={4} flexDirection="column">
<Text fontWeight="bold">{t('modelManager.deleteMsg1')}</Text>
<Text>{t('modelManager.deleteMsg2')}</Text>
</Flex>
</ConfirmationAlertDialog>
</>
);
});
ModelDeleteButton.displayName = 'ModelDeleteButton';

View File

@@ -24,6 +24,7 @@ import type { AnyModelConfig } from 'services/api/types';
import BaseModelSelect from './Fields/BaseModelSelect';
import ModelVariantSelect from './Fields/ModelVariantSelect';
import PredictionTypeSelect from './Fields/PredictionTypeSelect';
import { ModelFooter } from './ModelFooter';
type Props = {
modelConfig: AnyModelConfig;
@@ -158,6 +159,7 @@ export const ModelEdit = memo(({ modelConfig }: Props) => {
</Flex>
</form>
</Flex>
<ModelFooter modelConfig={modelConfig} isEditing={true} />
</Flex>
);
});

View File

@@ -0,0 +1,66 @@
import { Flex, Heading, type SystemStyleObject } from '@invoke-ai/ui-library';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';
import { ModelConvertButton } from './ModelConvertButton';
import { ModelDeleteButton } from './ModelDeleteButton';
import { ModelEditButton } from './ModelEditButton';
const footerRowSx: SystemStyleObject = {
justifyContent: 'space-between',
alignItems: 'center',
gap: 3,
'&:not(:last-of-type)': {
borderBottomWidth: '1px',
borderBottomStyle: 'solid',
borderBottomColor: 'border',
},
p: 3,
};
type Props = {
modelConfig: AnyModelConfig;
isEditing: boolean;
};
export const ModelFooter = memo(({ modelConfig, isEditing }: Props) => {
const { t } = useTranslation();
const shouldShowConvertOption = !isEditing && modelConfig.format === 'checkpoint' && modelConfig.type === 'main';
return (
<Flex flexDirection="column" borderWidth="1px" borderRadius="base">
{shouldShowConvertOption && (
<Flex sx={footerRowSx}>
<Heading size="sm" color="base.100">
{t('modelManager.convertToDiffusers')}
</Heading>
<Flex py={1}>
<ModelConvertButton modelConfig={modelConfig} />
</Flex>
</Flex>
)}
{!isEditing && (
<Flex sx={footerRowSx}>
<Heading size="sm" color="base.100">
{t('modelManager.edit')}
</Heading>
<Flex py={1}>
<ModelEditButton />
</Flex>
</Flex>
)}
<Flex sx={footerRowSx}>
<Heading size="sm" color="error.200">
{t('modelManager.deleteModel')}
</Heading>
<Flex py={1}>
<ModelDeleteButton modelConfig={modelConfig} />
</Flex>
</Flex>
</Flex>
);
});
ModelFooter.displayName = 'ModelFooter';

View File

@@ -1,4 +1,4 @@
import { Box, Flex, SimpleGrid } from '@invoke-ai/ui-library';
import { Box, Divider, Flex, SimpleGrid } from '@invoke-ai/ui-library';
import { ControlAdapterModelDefaultSettings } from 'features/modelManagerV2/subpanels/ModelPanel/ControlAdapterModelDefaultSettings/ControlAdapterModelDefaultSettings';
import { LoRAModelDefaultSettings } from 'features/modelManagerV2/subpanels/ModelPanel/LoRAModelDefaultSettings/LoRAModelDefaultSettings';
import { ModelConvertButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton';
@@ -12,6 +12,7 @@ import type { AnyModelConfig } from 'services/api/types';
import { MainModelDefaultSettings } from './MainModelDefaultSettings/MainModelDefaultSettings';
import { ModelAttrView } from './ModelAttrView';
import { ModelFooter } from './ModelFooter';
import { RelatedModels } from './RelatedModels';
type Props = {
@@ -46,8 +47,9 @@ export const ModelView = memo(({ modelConfig }: Props) => {
)}
<ModelEditButton />
</ModelHeader>
<Divider />
<Flex flexDir="column" h="full" gap={4}>
<Box layerStyle="second" borderRadius="base" p={4}>
<Box>
<SimpleGrid columns={2} gap={4}>
<ModelAttrView label={t('modelManager.baseModel')} value={modelConfig.base} />
<ModelAttrView label={t('modelManager.modelType')} value={modelConfig.type} />
@@ -73,26 +75,33 @@ export const ModelView = memo(({ modelConfig }: Props) => {
</SimpleGrid>
</Box>
{withSettings && (
<Box layerStyle="second" borderRadius="base" p={4}>
{modelConfig.type === 'main' && modelConfig.base !== 'sdxl-refiner' && (
<MainModelDefaultSettings modelConfig={modelConfig} />
)}
{(modelConfig.type === 'controlnet' ||
modelConfig.type === 't2i_adapter' ||
modelConfig.type === 'control_lora') && <ControlAdapterModelDefaultSettings modelConfig={modelConfig} />}
{modelConfig.type === 'lora' && (
<>
<LoRAModelDefaultSettings modelConfig={modelConfig} />
<TriggerPhrases modelConfig={modelConfig} />
</>
)}
{modelConfig.type === 'main' && <TriggerPhrases modelConfig={modelConfig} />}
</Box>
<>
<Divider />
<Box>
{modelConfig.type === 'main' && modelConfig.base !== 'sdxl-refiner' && (
<MainModelDefaultSettings modelConfig={modelConfig} />
)}
{(modelConfig.type === 'controlnet' ||
modelConfig.type === 't2i_adapter' ||
modelConfig.type === 'control_lora') && (
<ControlAdapterModelDefaultSettings modelConfig={modelConfig} />
)}
{modelConfig.type === 'lora' && (
<>
<LoRAModelDefaultSettings modelConfig={modelConfig} />
<TriggerPhrases modelConfig={modelConfig} />
</>
)}
{modelConfig.type === 'main' && <TriggerPhrases modelConfig={modelConfig} />}
</Box>
</>
)}
<Box overflowY="auto" layerStyle="second" borderRadius="base" p={4}>
<Divider />
<Box overflowY="auto">
<RelatedModels modelConfig={modelConfig} />
</Box>
</Flex>
<ModelFooter modelConfig={modelConfig} isEditing={false} />
</Flex>
);
});

View File

@@ -74,6 +74,8 @@ export const TriggerPhrases = memo(({ modelConfig }: Props) => {
[addTriggerPhrase]
);
const hasTriggerPhrases = triggerPhrases.length > 0;
return (
<Flex flexDir="column" w="full" gap="5">
<form onSubmit={onTriggerPhraseAddFormSubmit}>
@@ -99,14 +101,16 @@ export const TriggerPhrases = memo(({ modelConfig }: Props) => {
</FormControl>
</form>
<Flex gap="4" flexWrap="wrap">
{triggerPhrases.map((phrase, index) => (
<Tag size="md" key={index} py={2} px={4} bg="base.700">
<TagLabel>{phrase}</TagLabel>
<TagCloseButton onClick={removeTriggerPhrase.bind(null, phrase)} isDisabled={isLoading} />
</Tag>
))}
</Flex>
{hasTriggerPhrases && (
<Flex gap="4" flexWrap="wrap">
{triggerPhrases.map((phrase, index) => (
<Tag size="md" key={index} py={2} px={4} bg="base.700">
<TagLabel>{phrase}</TagLabel>
<TagCloseButton onClick={removeTriggerPhrase.bind(null, phrase)} isDisabled={isLoading} />
</Tag>
))}
</Flex>
)}
</Flex>
);
});

View File

@@ -253,11 +253,11 @@ const PublishWorkflowButton = memo(() => {
),
duration: null,
});
assert(result.value.enqueueResult.batch.batch_id);
assert(result.value.batchConfig.validation_run_data);
assert(result.value?.enqueueResult.batch.batch_id);
assert(result.value?.batchConfig.validation_run_data);
$validationRunData.set({
batchId: result.value.enqueueResult.batch.batch_id,
workflowId: result.value.batchConfig.validation_run_data.workflow_id,
batchId: result.value?.enqueueResult.batch.batch_id,
workflowId: result.value?.batchConfig.validation_run_data.workflow_id,
});
log.debug(parseify(result.value), 'Enqueued batch');
}

View File

@@ -87,7 +87,7 @@ const addFLUXRedux = (id: string, ipAdapter: FLUXReduxConfig, g: Graph, collecto
type: 'flux_redux',
redux_model: fluxReduxModel,
image: {
image_name: image.image_name,
image_name: image.crop?.image.image_name ?? image.original.image.image_name,
},
...IMAGE_INFLUENCE_TO_SETTINGS[ipAdapter.imageInfluence ?? 'highest'],
});

View File

@@ -58,7 +58,7 @@ const addIPAdapter = (id: string, ipAdapter: IPAdapterConfig, g: Graph, collecto
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
image: {
image_name: image.image_name,
image_name: image.crop?.image.image_name ?? image.original.image.image_name,
},
});
} else {
@@ -77,7 +77,7 @@ const addIPAdapter = (id: string, ipAdapter: IPAdapterConfig, g: Graph, collecto
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
image: {
image_name: image.image_name,
image_name: image.crop?.image.image_name ?? image.original.image.image_name,
},
});
}

View File

@@ -5,8 +5,8 @@ import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import {
type CanvasRegionalGuidanceState,
isFLUXReduxConfig,
isIPAdapterConfig,
isRegionalGuidanceFLUXReduxConfig,
isRegionalGuidanceIPAdapterConfig,
type Rect,
} from 'features/controlLayers/store/types';
import { getRegionalGuidanceWarnings } from 'features/controlLayers/store/validators';
@@ -279,7 +279,7 @@ export const addRegions = async ({
}
for (const { id, config } of region.referenceImages) {
if (isIPAdapterConfig(config)) {
if (isRegionalGuidanceIPAdapterConfig(config)) {
assert(!isFLUX, 'Regional IP adapters are not supported for FLUX.');
result.addedIPAdapters++;
@@ -304,7 +304,7 @@ export const addRegions = async ({
// Connect the mask to the conditioning
g.addEdge(maskToTensor, 'mask', ipAdapterNode, 'mask');
g.addEdge(ipAdapterNode, 'ip_adapter', ipAdapterCollect, 'item');
} else if (isFLUXReduxConfig(config)) {
} else if (isRegionalGuidanceFLUXReduxConfig(config)) {
assert(isFLUX, 'Regional FLUX Redux requires FLUX.');
assert(fluxReduxCollect !== null, 'FLUX Redux collector is required.');
result.addedFLUXReduxes++;

View File

@@ -50,7 +50,7 @@ export const buildChatGPT4oGraph = async (arg: GraphBuilderArg): Promise<GraphBu
for (const entity of validRefImages) {
assert(entity.config.image, 'Image is required for reference image');
reference_images.push({
image_name: entity.config.image.image_name,
image_name: entity.config.image.crop?.image.image_name ?? entity.config.image.original.image.image_name,
});
}
}

View File

@@ -61,7 +61,7 @@ export const buildFluxKontextGraph = (arg: GraphBuilderArg): GraphBuilderReturn
aspect_ratio: aspectRatio.id,
prompt_upsampling: true,
input_image: {
image_name: firstImage.image_name,
image_name: firstImage.crop?.image.image_name ?? firstImage.original.image.image_name,
},
...selectCanvasOutputFields(state),
});

View File

@@ -45,7 +45,7 @@ export const buildGemini2_5Graph = (arg: GraphBuilderArg): GraphBuilderReturn =>
for (const entity of validRefImages) {
assert(entity.config.image, 'Image is required for reference image');
reference_images.push({
image_name: entity.config.image.image_name,
image_name: entity.config.image.crop?.image.image_name ?? entity.config.image.original.image.image_name,
});
}
}

View File

@@ -38,7 +38,7 @@ export const buildRunwayVideoGraph = (arg: GraphBuilderArg): GraphBuilderReturn
const startingFrameImage = selectStartingFrameImage(state);
assert(startingFrameImage, 'Video starting frame is required for runway video generation');
const firstFrameImageField = zImageField.parse(startingFrameImage);
const firstFrameImageField = zImageField.parse(startingFrameImage.crop?.image ?? startingFrameImage.original);
const { seed, shouldRandomizeSeed } = params;
const { videoDuration, videoAspectRatio, videoResolution } = videoParams;

View File

@@ -61,7 +61,7 @@ export const buildVeo3VideoGraph = (arg: GraphBuilderArg): GraphBuilderReturn =>
const startingFrameImage = selectStartingFrameImage(state);
if (startingFrameImage) {
const startingFrameImageField = zImageField.parse(startingFrameImage);
const startingFrameImageField = zImageField.parse(startingFrameImage.crop?.image ?? startingFrameImage.original);
// @ts-expect-error: This node is not available in the OSS application
veo3VideoNode.starting_image = startingFrameImageField;
}

View File

@@ -1,11 +1,12 @@
import { Box, Flex, Textarea } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
import { usePersistedTextAreaSize } from 'common/hooks/usePersistedTextareaSize';
import {
positivePromptChanged,
selectModelSupportsNegativePrompt,
selectPositivePrompt,
selectPositivePromptHistory,
} from 'features/controlLayers/store/paramsSlice';
import { promptGenerationFromImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
@@ -27,9 +28,10 @@ import {
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { selectAllowPromptExpansion } from 'features/system/store/configSlice';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useMemo, useRef } from 'react';
import React, { memo, useCallback, useMemo, useRef } from 'react';
import type { HotkeyCallback } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { useClickAway } from 'react-use';
import { useListStylePresetsQuery } from 'services/api/endpoints/stylePresets';
import { PositivePromptHistoryIconButton } from './PositivePromptHistory';
@@ -40,6 +42,78 @@ const persistOptions: Parameters<typeof usePersistedTextAreaSize>[2] = {
initialHeight: 120,
};
const usePromptHistory = () => {
const store = useAppStore();
const history = useAppSelector(selectPositivePromptHistory);
/**
* This ref is populated only when the user navigates back in history. In other words, its presence is a proxy
* for "are we currently browsing history?"
*
* When we are moving thru history, we will always have a stashedPrompt (the prompt before we started browsing)
* and a historyIdx which is an index into the history array (0 = most recent, 1 = previous, etc).
*/
const stateRef = useRef<{ stashedPrompt: string; historyIdx: number } | null>(null);
const prev = useCallback(() => {
if (history.length === 0) {
// No history, nothing to do
return;
}
let state = stateRef.current;
if (!state) {
// First time going "back" in history, init state
state = { stashedPrompt: selectPositivePrompt(store.getState()), historyIdx: 0 };
stateRef.current = state;
} else {
// Subsequent "back" in history, increment index
if (state.historyIdx === history.length - 1) {
// Already at the end of history, nothing to do
return;
}
state.historyIdx = state.historyIdx + 1;
}
// We should go "back" in history
const newPrompt = history[state.historyIdx];
if (newPrompt === undefined) {
// Shouldn't happen
return;
}
store.dispatch(positivePromptChanged(newPrompt));
}, [history, store]);
const next = useCallback(() => {
if (history.length === 0) {
// No history, nothing to do
return;
}
let state = stateRef.current;
if (!state) {
// If the user hasn't gone "back" in history, "forward" does nothing
return;
}
state.historyIdx = state.historyIdx - 1;
if (state.historyIdx < 0) {
// Overshot to the "current" stashed prompt
store.dispatch(positivePromptChanged(state.stashedPrompt));
// Clear state bc we're back to current prompt
stateRef.current = null;
return;
}
// We should go "forward" in history
const newPrompt = history[state.historyIdx];
if (newPrompt === undefined) {
// Shouldn't happen
return;
}
store.dispatch(positivePromptChanged(newPrompt));
}, [history, store]);
const reset = useCallback(() => {
// Clear stashed state - used when user clicks away or types in the prompt box
stateRef.current = null;
}, []);
return { prev, next, reset };
};
export const ParamPositivePrompt = memo(() => {
const dispatch = useAppDispatch();
const prompt = useAppSelector(selectPositivePrompt);
@@ -50,6 +124,8 @@ export const ParamPositivePrompt = memo(() => {
const isPromptExpansionEnabled = useAppSelector(selectAllowPromptExpansion);
const activeTab = useAppSelector(selectActiveTab);
const promptHistoryApi = usePromptHistory();
const textareaRef = useRef<HTMLTextAreaElement>(null);
usePersistedTextAreaSize('positive_prompt', textareaRef, persistOptions);
@@ -67,8 +143,11 @@ export const ParamPositivePrompt = memo(() => {
const handleChange = useCallback(
(v: string) => {
dispatch(positivePromptChanged(v));
// When the user changes the prompt, reset the prompt history state. This event is not fired when the prompt is
// changed via the prompt history navigation.
promptHistoryApi.reset();
},
[dispatch]
[dispatch, promptHistoryApi]
);
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown, onFocus } = usePrompt({
prompt,
@@ -77,6 +156,9 @@ export const ParamPositivePrompt = memo(() => {
isDisabled: isPromptExpansionPending,
});
// When the user clicks away from the textarea, reset the prompt history state.
useClickAway(textareaRef, promptHistoryApi.reset);
const focus: HotkeyCallback = useCallback(
(e) => {
onFocus();
@@ -93,6 +175,35 @@ export const ParamPositivePrompt = memo(() => {
dependencies: [focus],
});
// Helper: check if prompt textarea is focused
const isPromptFocused = useCallback(() => document.activeElement === textareaRef.current, []);
// Register hotkeys for browsing
useRegisteredHotkeys({
id: 'promptHistoryPrev',
category: 'app',
callback: (e) => {
if (isPromptFocused()) {
e.preventDefault();
promptHistoryApi.prev();
}
},
options: { preventDefault: true, enableOnFormTags: ['INPUT', 'SELECT', 'TEXTAREA'] },
dependencies: [promptHistoryApi.prev, isPromptFocused],
});
useRegisteredHotkeys({
id: 'promptHistoryNext',
category: 'app',
callback: (e) => {
if (isPromptFocused()) {
e.preventDefault();
promptHistoryApi.next();
}
},
options: { preventDefault: true, enableOnFormTags: ['INPUT', 'SELECT', 'TEXTAREA'] },
dependencies: [promptHistoryApi.next, isPromptFocused],
});
const dndTargetData = useMemo(() => promptGenerationFromImageDndTarget.getData(), []);
return (

View File

@@ -4,6 +4,7 @@ import {
Flex,
IconButton,
Input,
Kbd,
Popover,
PopoverBody,
PopoverContent,
@@ -22,6 +23,7 @@ import {
} from 'features/controlLayers/store/paramsSlice';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowArcLeftBold, PiClockCounterClockwise, PiTrashBold, PiTrashSimpleBold } from 'react-icons/pi';
export const PositivePromptHistoryIconButton = memo(() => {
@@ -50,6 +52,7 @@ export const PositivePromptHistoryIconButton = memo(() => {
PositivePromptHistoryIconButton.displayName = 'PositivePromptHistoryIconButton';
const PromptHistoryContent = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const positivePromptHistory = useAppSelector(selectPositivePromptHistory);
const [searchTerm, setSearchTerm] = useState('');
@@ -96,25 +99,32 @@ const PromptHistoryContent = memo(() => {
</Button>
</Flex>
<Divider />
{positivePromptHistory.length === 0 && (
<Flex w="full" h="full" alignItems="center" justifyContent="center">
<Text color="base.300">No prompt history recorded.</Text>
</Flex>
)}
{positivePromptHistory.length !== 0 && filteredPrompts.length === 0 && (
<Flex w="full" h="full" alignItems="center" justifyContent="center">
<Text color="base.300">No matching prompts in history.</Text>{' '}
</Flex>
)}
{filteredPrompts.length > 0 && (
<ScrollableContent>
<Flex flexDir="column">
{filteredPrompts.map((prompt, index) => (
<PromptItem key={`${prompt}-${index}`} prompt={prompt} />
))}
<Flex flexDir="column" flexGrow={1} minH={0}>
{positivePromptHistory.length === 0 && (
<Flex w="full" h="full" alignItems="center" justifyContent="center">
<Text color="base.300">{t('prompt.noPromptHistory')}</Text>
</Flex>
</ScrollableContent>
)}
)}
{positivePromptHistory.length !== 0 && filteredPrompts.length === 0 && (
<Flex w="full" h="full" alignItems="center" justifyContent="center">
<Text color="base.300">{t('prompt.noMatchingPrompts')}</Text>{' '}
</Flex>
)}
{filteredPrompts.length > 0 && (
<ScrollableContent>
<Flex flexDir="column">
{filteredPrompts.map((prompt, index) => (
<PromptItem key={`${prompt}-${index}`} prompt={prompt} />
))}
</Flex>
</ScrollableContent>
)}
</Flex>
<Flex alignItems="center" justifyContent="center" pt={1}>
<Text color="base.300" textAlign="center">
<Kbd textTransform="lowercase">alt+up/down</Kbd> {t('prompt.toSwitchBetweenPrompts')}
</Text>
</Flex>
</Flex>
);
});

View File

@@ -4,7 +4,7 @@ import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import type {
ImageWithDims,
CroppableImageWithDims,
VideoAspectRatio,
VideoDuration,
VideoResolution,
@@ -16,7 +16,7 @@ import {
isVeo3AspectRatioID,
isVeo3DurationID,
isVeo3Resolution,
zImageWithDims,
zCroppableImageWithDims,
zVideoAspectRatio,
zVideoDuration,
zVideoResolution,
@@ -30,8 +30,8 @@ import { assert } from 'tsafe';
import z from 'zod';
const zVideoState = z.object({
_version: z.literal(1),
startingFrameImage: zImageWithDims.nullable(),
_version: z.literal(2),
startingFrameImage: zCroppableImageWithDims.nullable(),
videoModel: zModelIdentifierField.nullable(),
videoResolution: zVideoResolution,
videoDuration: zVideoDuration,
@@ -42,7 +42,7 @@ export type VideoState = z.infer<typeof zVideoState>;
const getInitialState = (): VideoState => {
return {
_version: 1,
_version: 2,
startingFrameImage: null,
videoModel: null,
videoResolution: '1080p',
@@ -55,7 +55,7 @@ const slice = createSlice({
name: 'video',
initialState: getInitialState(),
reducers: {
startingFrameImageChanged: (state, action: PayloadAction<ImageWithDims | null>) => {
startingFrameImageChanged: (state, action: PayloadAction<CroppableImageWithDims | null>) => {
state.startingFrameImage = action.payload;
},
@@ -119,6 +119,13 @@ export const videoSliceConfig: SliceConfig<typeof slice> = {
if (!('_version' in state)) {
state._version = 1;
}
if (state._version === 1) {
state._version = 2;
if (state.startingFrameImage) {
// startingFrameImage changed from ImageWithDims to CroppableImageWithDims
state.startingFrameImage = zCroppableImageWithDims.parse({ original: state.startingFrameImage });
}
}
return zVideoState.parse(state);
},
},

View File

@@ -0,0 +1,51 @@
# Queue Enqueue Patterns
This directory contains the hooks and utilities that translate UI actions into queue batches. The flow is intentionally
modular so adding a new enqueue type (e.g. a new generation mode) follows a predictable recipe.
## Key building blocks
- `hooks/useEnqueue*.ts`: Feature-specific hooks (generate, canvas, upscaling, video, workflows). Each hook wires local state to the shared enqueue utilities.
- `hooks/utils/graphBuilders.ts`: Maps base models (sdxl, flux, etc.) to their graph builder functions and normalizes synchronous vs. asynchronous builders.
- `hooks/utils/executeEnqueue.ts`: Orchestrates the enqueue lifecycle (sketched after this list):
1. dispatch the `enqueueRequested*` action
2. build the graph/batch data
3. call `queueApi.endpoints.enqueueBatch`
4. run success/error callbacks
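A minimal sketch of that lifecycle, assuming the option shape visible in the refactored `useEnqueueCanvas`/`useEnqueueGenerate` hooks later in this change; the real implementation lives in `hooks/utils/executeEnqueue.ts` and also wires up `onError` and logging:
```ts
// Sketch only: simplified types and no error handling; the real helper is hooks/utils/executeEnqueue.ts.
import type { AppStore } from 'app/store/store';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';

// For this sketch, derive the batch-config type from the endpoint; the real code exports EnqueueBatchArg.
type BatchConfig = Parameters<typeof queueApi.endpoints.enqueueBatch.initiate>[0];
type BuildCtx = { store: AppStore; options: { prepend: boolean } };

export const executeEnqueueSketch = async <TBuild>(arg: {
  store: AppStore;
  options: { prepend: boolean };
  requestedAction: () => { type: string };
  build: (ctx: BuildCtx) => Promise<TBuild | null>;
  prepareBatch: (ctx: BuildCtx & { buildResult: TBuild }) => BatchConfig;
  onSuccess?: (ctx: BuildCtx & { buildResult: TBuild }) => void;
}) => {
  const { store, options, requestedAction, build, prepareBatch, onSuccess } = arg;
  store.dispatch(requestedAction()); // 1. dispatch the `enqueueRequested*` action
  const buildResult = await build({ store, options }); // 2. build the graph/batch data
  if (buildResult === null) {
    return null; // prerequisites missing; the feature's build callback already logged why
  }
  const batchConfig = prepareBatch({ store, options, buildResult });
  const req = store.dispatch(
    queueApi.endpoints.enqueueBatch.initiate(batchConfig, { ...enqueueMutationFixedCacheKeyOptions, track: false })
  ); // 3. call `queueApi.endpoints.enqueueBatch`
  const enqueueResult = await req.unwrap();
  onSuccess?.({ store, options, buildResult }); // 4. run success callbacks (onError handles failures in the real helper)
  return { batchConfig, enqueueResult };
};
```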
## Adding a new enqueue type
1. **Implement the graph builder (if needed).**
- Create the graph construction logic in `features/nodes/util/graph/generation/...` so it returns a
`GraphBuilderReturn`.
- If the builder reuses existing primitives, consider wiring it into `graphBuilders.ts` by extending the `graphBuilderMap`.
2. **Create the enqueue hook.**
- Add `useEnqueue<Feature>.ts` mirroring the existing hooks (a skeleton is sketched after this list). Import `executeEnqueue` and supply feature-specific `build`, `prepareBatch`, and `onSuccess` callbacks.
- If the feature depends on a new base model, add it to `graphBuilders.ts`.
3. **Register the tab in `useInvoke`.**
- `useInvoke.ts` looks up handlers based on the active tab. Import your new hook and call it inside the `switch`
(or future registry) so the UI can enqueue from the feature.
4. **Add Redux action (optional).**
- Most enqueue hooks dispatch an `enqueueRequested*` action for devtools visibility. Create one with `createAction` if you want similar tracing.
5. **Cover with tests.**
- Unit-test feature-specific behavior (graph selection, batch tweaks). The shared helpers already have coverage in
`hooks/utils/`.
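As a rough skeleton for those steps (the feature name, action name, and `build` body below are placeholders; the callback shapes mirror the refactored `useEnqueueGenerate` hook in this change):
```ts
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import { useAppStore } from 'app/store/storeHooks';
import { useCallback } from 'react';

import type { EnqueueBatchArg } from './utils/executeEnqueue';
import { executeEnqueue } from './utils/executeEnqueue';

const log = logger('generation');

// Placeholder action; dispatched first so the enqueue request is visible in devtools.
export const enqueueRequestedMyFeature = createAction('app/enqueueRequestedMyFeature');

type MyFeatureBuildResult = {
  batchConfig: EnqueueBatchArg;
};

export const useEnqueueMyFeature = () => {
  const store = useAppStore();
  return useCallback(
    (prepend: boolean) =>
      executeEnqueue({
        store,
        options: { prepend },
        requestedAction: enqueueRequestedMyFeature,
        log,
        build: async ({ store: innerStore }) => {
          // Build the graph from state and wrap it in a batch config; return null to skip enqueueing
          // when prerequisites (e.g. a selected model) are missing.
          const state = innerStore.getState();
          void state; // placeholder: the real hook derives graph/batch data from state here
          return null as MyFeatureBuildResult | null;
        },
        prepareBatch: ({ buildResult }) => buildResult.batchConfig,
        onSuccess: () => {
          // e.g. push the positive prompt to history, as the canvas/generate hooks do.
        },
      }),
    [store]
  );
};
```
The returned callback is what `useInvoke` calls for the feature's tab, per step 3 above.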
## Tips
- Keep `build` lean: fetch state, compose graph/batch data, and return `null` when prerequisites are missing. The shared
helper will skip enqueueing and your `onError` will handle logging.
- Use the shared `prepareLinearUIBatch` for single-graph UI workflows. For advanced cases (multi-run batches, workflow
validation runs), supply a custom `prepareBatch` function.
- Prefer updating `graphBuilders.ts` when adding a new base model so every image-based enqueue automatically benefits.
With this structure, the main task when introducing a new enqueue type is describing how to build its graph and how to
massage the batch payload—everything else (dispatching, API calls, history updates) is handled by the utilities.
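For orientation, the base-to-builder mapping consumed by the hooks is roughly shaped like this. This is a sketch of `hooks/utils/graphBuilders.ts` using the `graphBuilderMap`/`buildGraphForBase` names referenced above and builders that appear elsewhere in this change; the exact map contents and typing are assumptions:
```ts
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import type { GraphBuilderArg } from 'features/nodes/util/graph/types';
import { assert } from 'tsafe';

// Sketch: one entry per supported base model; extend this map when adding a new base.
const graphBuilderMap = {
  'sd-1': buildSD1Graph,
  'sd-2': buildSD1Graph,
  sdxl: buildSDXLGraph,
  flux: buildFLUXGraph,
  // ...the remaining bases (sd-3, cogview4, imagen3/4, chatgpt-4o, flux-kontext, gemini-2.5) follow the same pattern
} as const;

// Normalizes sync and async builders behind a single awaited call.
export const buildGraphForBase = async (base: string, arg: GraphBuilderArg) => {
  const builder = graphBuilderMap[base as keyof typeof graphBuilderMap];
  assert(builder, `No graph builders for base ${base}`);
  return await builder(arg);
};
```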

View File

@@ -1,154 +1,114 @@
import type { AlertStatus } from '@invoke-ai/ui-library';
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import { withResult, withResultAsync } from 'common/util/result';
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { positivePromptAddedToHistory, selectPositivePrompt } from 'features/controlLayers/store/paramsSlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph';
import { buildGemini2_5Graph } from 'features/nodes/util/graph/generation/buildGemini2_5Graph';
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import { selectCanvasDestination } from 'features/nodes/util/graph/graphBuilderUtils';
import type { GraphBuilderArg } from 'features/nodes/util/graph/types';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import { assert, AssertionError } from 'tsafe';
import { AssertionError } from 'tsafe';
import type { EnqueueBatchArg } from './utils/executeEnqueue';
import { executeEnqueue } from './utils/executeEnqueue';
import { buildGraphForBase } from './utils/graphBuilders';
const log = logger('generation');
export const enqueueRequestedCanvas = createAction('app/enqueueRequestedCanvas');
const enqueueCanvas = async (store: AppStore, canvasManager: CanvasManager, prepend: boolean) => {
const { dispatch, getState } = store;
dispatch(enqueueRequestedCanvas());
const state = getState();
const destination = selectCanvasDestination(state);
const model = state.params.model;
if (!model) {
log.error('No model found in state');
return;
}
const base = model.base;
const buildGraphResult = await withResultAsync(async () => {
const generationMode = await canvasManager.compositor.getGenerationMode();
const graphBuilderArg: GraphBuilderArg = { generationMode, state, manager: canvasManager };
switch (base) {
case 'sdxl':
return await buildSDXLGraph(graphBuilderArg);
case 'sd-1':
case `sd-2`:
return await buildSD1Graph(graphBuilderArg);
case `sd-3`:
return await buildSD3Graph(graphBuilderArg);
case `flux`:
return await buildFLUXGraph(graphBuilderArg);
case 'cogview4':
return await buildCogView4Graph(graphBuilderArg);
case 'imagen3':
return buildImagen3Graph(graphBuilderArg);
case 'imagen4':
return buildImagen4Graph(graphBuilderArg);
case 'chatgpt-4o':
return await buildChatGPT4oGraph(graphBuilderArg);
case 'flux-kontext':
return buildFluxKontextGraph(graphBuilderArg);
case 'gemini-2.5':
return buildGemini2_5Graph(graphBuilderArg);
default:
assert(false, `No graph builders for base ${base}`);
}
});
if (buildGraphResult.isErr()) {
let title = 'Failed to build graph';
let status: AlertStatus = 'error';
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
title = 'Unsupported generation mode';
description = buildGraphResult.error.message;
status = 'warning';
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({
status,
title,
description,
});
return;
}
const { g, seed, positivePrompt } = buildGraphResult.value;
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
base,
prepend,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'canvas',
destination,
})
);
if (prepareBatchResult.isErr()) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
return;
}
const batchConfig = prepareBatchResult.value;
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
...enqueueMutationFixedCacheKeyOptions,
track: false,
})
);
const enqueueResult = await req.unwrap();
// Push to prompt history on successful enqueue
dispatch(positivePromptAddedToHistory(selectPositivePrompt(state)));
return { batchConfig, enqueueResult };
type CanvasBuildResult = {
batchConfig: EnqueueBatchArg;
};
export const useEnqueueCanvas = () => {
const store = useAppStore();
const canvasManager = useCanvasManagerSafe();
const enqueue = useCallback(
(prepend: boolean) => {
if (!canvasManager) {
log.error('Canvas manager is not available');
return;
return null;
}
return enqueueCanvas(store, canvasManager, prepend);
return executeEnqueue({
store,
options: { prepend },
requestedAction: enqueueRequestedCanvas,
log,
build: async ({ store: innerStore, options }) => {
const state = innerStore.getState();
const destination = selectCanvasDestination(state);
const model = state.params.model;
if (!model) {
log.error('No model found in state');
return null;
}
const generationMode = await canvasManager.compositor.getGenerationMode();
const graphBuilderArg: GraphBuilderArg = { generationMode, state, manager: canvasManager };
const buildGraphResult = await withResultAsync(
async () => await buildGraphForBase(model.base, graphBuilderArg)
);
if (buildGraphResult.isErr()) {
let title = 'Failed to build graph';
let status: AlertStatus = 'error';
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
title = 'Unsupported generation mode';
description = buildGraphResult.error.message;
status = 'warning';
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({ status, title, description });
return null;
}
const { g, seed, positivePrompt } = buildGraphResult.value;
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
base: model.base,
prepend: options.prepend,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'canvas',
destination,
})
);
if (prepareBatchResult.isErr()) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
return null;
}
return {
batchConfig: prepareBatchResult.value,
} satisfies CanvasBuildResult;
},
prepareBatch: ({ buildResult }) => buildResult.batchConfig,
onSuccess: ({ store: innerStore }) => {
const state = innerStore.getState();
innerStore.dispatch(positivePromptAddedToHistory(selectPositivePrompt(state)));
},
});
},
[canvasManager, store]
);
return enqueue;
};

View File

@@ -1,143 +1,103 @@
import type { AlertStatus } from '@invoke-ai/ui-library';
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import { withResult, withResultAsync } from 'common/util/result';
import { positivePromptAddedToHistory, selectPositivePrompt } from 'features/controlLayers/store/paramsSlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph';
import { buildGemini2_5Graph } from 'features/nodes/util/graph/generation/buildGemini2_5Graph';
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import type { GraphBuilderArg } from 'features/nodes/util/graph/types';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import { assert, AssertionError } from 'tsafe';
import { AssertionError } from 'tsafe';
import type { EnqueueBatchArg } from './utils/executeEnqueue';
import { executeEnqueue } from './utils/executeEnqueue';
import { buildGraphForBase } from './utils/graphBuilders';
const log = logger('generation');
export const enqueueRequestedGenerate = createAction('app/enqueueRequestedGenerate');
const enqueueGenerate = async (store: AppStore, prepend: boolean) => {
const { dispatch, getState } = store;
dispatch(enqueueRequestedGenerate());
const state = getState();
const model = state.params.model;
if (!model) {
log.error('No model found in state');
return;
}
const base = model.base;
const buildGraphResult = await withResultAsync(async () => {
const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state, manager: null };
switch (base) {
case 'sdxl':
return await buildSDXLGraph(graphBuilderArg);
case 'sd-1':
case `sd-2`:
return await buildSD1Graph(graphBuilderArg);
case `sd-3`:
return await buildSD3Graph(graphBuilderArg);
case `flux`:
return await buildFLUXGraph(graphBuilderArg);
case 'cogview4':
return await buildCogView4Graph(graphBuilderArg);
case 'imagen3':
return buildImagen3Graph(graphBuilderArg);
case 'imagen4':
return buildImagen4Graph(graphBuilderArg);
case 'chatgpt-4o':
return await buildChatGPT4oGraph(graphBuilderArg);
case 'flux-kontext':
return buildFluxKontextGraph(graphBuilderArg);
case 'gemini-2.5':
return buildGemini2_5Graph(graphBuilderArg);
default:
assert(false, `No graph builders for base ${base}`);
}
});
if (buildGraphResult.isErr()) {
let title = 'Failed to build graph';
let status: AlertStatus = 'error';
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
title = 'Unsupported generation mode';
description = buildGraphResult.error.message;
status = 'warning';
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({
status,
title,
description,
});
return;
}
const { g, seed, positivePrompt } = buildGraphResult.value;
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
base,
prepend,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'generate',
destination: 'generate',
})
);
if (prepareBatchResult.isErr()) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
return;
}
const batchConfig = prepareBatchResult.value;
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
...enqueueMutationFixedCacheKeyOptions,
track: false,
})
);
const enqueueResult = await req.unwrap();
// Push to prompt history on successful enqueue
dispatch(positivePromptAddedToHistory(selectPositivePrompt(state)));
return { batchConfig, enqueueResult };
type GenerateBuildResult = {
batchConfig: EnqueueBatchArg;
};
export const useEnqueueGenerate = () => {
const store = useAppStore();
const enqueue = useCallback(
(prepend: boolean) => {
return enqueueGenerate(store, prepend);
return executeEnqueue({
store,
options: { prepend },
requestedAction: enqueueRequestedGenerate,
log,
build: async ({ store: innerStore, options }) => {
const state = innerStore.getState();
const model = state.params.model;
if (!model) {
log.error('No model found in state');
return null;
}
const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state, manager: null };
const buildGraphResult = await withResultAsync(
async () => await buildGraphForBase(model.base, graphBuilderArg)
);
if (buildGraphResult.isErr()) {
let title = 'Failed to build graph';
let status: AlertStatus = 'error';
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
title = 'Unsupported generation mode';
description = buildGraphResult.error.message;
status = 'warning';
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({ status, title, description });
return null;
}
const { g, seed, positivePrompt } = buildGraphResult.value;
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
base: model.base,
prepend: options.prepend,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'generate',
destination: 'generate',
})
);
if (prepareBatchResult.isErr()) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
return null;
}
return {
batchConfig: prepareBatchResult.value,
} satisfies GenerateBuildResult;
},
prepareBatch: ({ buildResult }) => buildResult.batchConfig,
onSuccess: ({ store: innerStore }) => {
const state = innerStore.getState();
innerStore.dispatch(positivePromptAddedToHistory(selectPositivePrompt(state)));
},
});
},
[store]
);
return enqueue;
};

View File

@@ -1,62 +1,64 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { positivePromptAddedToHistory, selectPositivePrompt } from 'features/controlLayers/store/paramsSlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildMultidiffusionUpscaleGraph } from 'features/nodes/util/graph/buildMultidiffusionUpscaleGraph';
import { useCallback } from 'react';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import type { EnqueueBatchArg } from './utils/executeEnqueue';
import { executeEnqueue } from './utils/executeEnqueue';
export const enqueueRequestedUpscaling = createAction('app/enqueueRequestedUpscaling');
const log = logger('generation');
const enqueueUpscaling = async (store: AppStore, prepend: boolean) => {
const { dispatch, getState } = store;
dispatch(enqueueRequestedUpscaling());
const state = getState();
const model = state.params.model;
if (!model) {
log.error('No model found in state');
return;
}
const base = model.base;
const { g, seed, positivePrompt } = await buildMultidiffusionUpscaleGraph(state);
const batchConfig = prepareLinearUIBatch({
state,
g,
base,
prepend,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'upscaling',
destination: 'gallery',
});
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, { ...enqueueMutationFixedCacheKeyOptions, track: false })
);
const enqueueResult = await req.unwrap();
// Push to prompt history on successful enqueue
dispatch(positivePromptAddedToHistory(selectPositivePrompt(state)));
return { batchConfig, enqueueResult };
type UpscaleBuildResult = {
batchConfig: EnqueueBatchArg;
};
export const useEnqueueUpscaling = () => {
const store = useAppStore();
const enqueue = useCallback(
(prepend: boolean) => {
return enqueueUpscaling(store, prepend);
return executeEnqueue({
store,
options: { prepend },
requestedAction: enqueueRequestedUpscaling,
log,
build: async ({ store: innerStore, options }) => {
const state = innerStore.getState();
const model = state.params.model;
if (!model) {
log.error('No model found in state');
return null;
}
const { g, seed, positivePrompt } = await buildMultidiffusionUpscaleGraph(state);
const batchConfig = prepareLinearUIBatch({
state,
g,
base: model.base,
prepend: options.prepend,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'upscaling',
destination: 'gallery',
});
return { batchConfig } satisfies UpscaleBuildResult;
},
prepareBatch: ({ buildResult }) => buildResult.batchConfig,
onSuccess: ({ store: innerStore }) => {
const state = innerStore.getState();
innerStore.dispatch(positivePromptAddedToHistory(selectPositivePrompt(state)));
},
});
},
[store]
);
return enqueue;
};

View File

@@ -1,7 +1,6 @@
import type { AlertStatus } from '@invoke-ai/ui-library';
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import { withResult, withResultAsync } from 'common/util/result';
@@ -14,114 +13,107 @@ import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types'
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import { assert, AssertionError } from 'tsafe';
import { AssertionError } from 'tsafe';
import type { EnqueueBatchArg } from './utils/executeEnqueue';
import { executeEnqueue } from './utils/executeEnqueue';
const log = logger('generation');
export const enqueueRequestedVideos = createAction('app/enqueueRequestedVideos');
const enqueueVideo = async (store: AppStore, prepend: boolean) => {
const { dispatch, getState } = store;
type VideoBuildResult = {
batchConfig: EnqueueBatchArg;
};
dispatch(enqueueRequestedVideos());
const state = getState();
const model = state.video.videoModel;
if (!model) {
log.error('No model found in state');
return;
const getVideoGraphBuilder = (base: string) => {
switch (base) {
case 'veo3':
return buildVeo3VideoGraph;
case 'runway':
return buildRunwayVideoGraph;
default:
return null;
}
const base = model.base;
const buildGraphResult = await withResultAsync(async () => {
const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state, manager: null };
switch (base) {
case 'veo3':
return await buildVeo3VideoGraph(graphBuilderArg);
case 'runway':
return await buildRunwayVideoGraph(graphBuilderArg);
default:
assert(false, `No graph builders for base ${base}`);
}
});
if (buildGraphResult.isErr()) {
let title = 'Failed to build graph';
let status: AlertStatus = 'error';
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
title = 'Unsupported generation mode';
description = buildGraphResult.error.message;
status = 'warning';
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({
status,
title,
description,
});
return;
}
const { g, positivePrompt, seed } = buildGraphResult.value;
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
base,
prepend,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'videos',
destination: 'gallery',
})
);
if (prepareBatchResult.isErr()) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
return;
}
const batchConfig = prepareBatchResult.value;
// const batchConfig = {
// prepend,
// batch: {
// graph: g.getGraph(),
// runs: 1,
// origin,
// destination,
// },
// };
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
...enqueueMutationFixedCacheKeyOptions,
track: false,
})
);
const enqueueResult = await req.unwrap();
// Push to prompt history on successful enqueue
dispatch(positivePromptAddedToHistory(selectPositivePrompt(state)));
return { batchConfig, enqueueResult };
};
export const useEnqueueVideo = () => {
const store = useAppStore();
const enqueue = useCallback(
(prepend: boolean) => {
return enqueueVideo(store, prepend);
return executeEnqueue({
store,
options: { prepend },
requestedAction: enqueueRequestedVideos,
log,
build: async ({ store: innerStore, options }) => {
const state = innerStore.getState();
const model = state.video.videoModel;
if (!model) {
log.error('No model found in state');
return null;
}
const builder = getVideoGraphBuilder(model.base);
if (!builder) {
log.error({ base: model.base }, 'No graph builders for base');
return null;
}
const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state, manager: null };
const buildGraphResult = await withResultAsync(async () => await builder(graphBuilderArg));
if (buildGraphResult.isErr()) {
let title = 'Failed to build graph';
let status: AlertStatus = 'error';
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
title = 'Unsupported generation mode';
description = buildGraphResult.error.message;
status = 'warning';
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({ status, title, description });
return null;
}
const { g, positivePrompt, seed } = buildGraphResult.value;
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
base: model.base,
prepend: options.prepend,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'videos',
destination: 'gallery',
})
);
if (prepareBatchResult.isErr()) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
return null;
}
return {
batchConfig: prepareBatchResult.value,
} satisfies VideoBuildResult;
},
prepareBatch: ({ buildResult }) => buildResult.batchConfig,
onSuccess: ({ store: innerStore }) => {
const state = innerStore.getState();
innerStore.dispatch(positivePromptAddedToHistory(selectPositivePrompt(state)));
},
});
},
[store]
);
return enqueue;
};

View File

@@ -1,5 +1,5 @@
import { createAction } from '@reduxjs/toolkit';
import type { AppDispatch, AppStore, RootState } from 'app/store/store';
import type { AppDispatch, RootState } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { groupBy } from 'es-toolkit/compat';
import {
@@ -15,10 +15,11 @@ import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
import { resolveBatchValue } from 'features/nodes/util/node/resolveBatchValue';
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
import { useCallback } from 'react';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import type { Batch, EnqueueBatchArg, S } from 'services/api/types';
import { assert } from 'tsafe';
import { executeEnqueue } from './utils/executeEnqueue';
export const enqueueRequestedWorkflows = createAction('app/enqueueRequestedWorkflows');
const getBatchDataForWorkflowGeneration = async (state: RootState, dispatch: AppDispatch): Promise<Batch['data']> => {
@@ -119,60 +120,50 @@ const getValidationRunData = (state: RootState, templates: Templates): S['Valida
};
};
const enqueueWorkflows = async (
store: AppStore,
templates: Templates,
prepend: boolean,
isApiValidationRun: boolean
) => {
const { dispatch, getState } = store;
dispatch(enqueueRequestedWorkflows());
const state = getState();
const nodesState = selectNodesSlice(state);
const graph = buildNodesGraph(state, templates);
const workflow = buildWorkflowWithValidation(nodesState);
if (workflow) {
// embedded workflows don't have an id
delete workflow.id;
}
const runs = state.params.iterations;
const data = await getBatchDataForWorkflowGeneration(state, dispatch);
const batchConfig: EnqueueBatchArg = {
batch: {
graph,
workflow,
runs,
origin: 'workflows',
destination: 'gallery',
data,
},
prepend,
};
if (isApiValidationRun) {
batchConfig.validation_run_data = getValidationRunData(state, templates);
// If the batch is an API validation run, we only want to run it once
batchConfig.batch.runs = 1;
}
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, { ...enqueueMutationFixedCacheKeyOptions, track: false })
);
const enqueueResult = await req.unwrap();
return { batchConfig, enqueueResult };
};
export const useEnqueueWorkflows = () => {
const store = useAppStore();
const enqueue = useCallback(
(prepend: boolean, isApiValidationRun: boolean) => {
return enqueueWorkflows(store, $templates.get(), prepend, isApiValidationRun);
return executeEnqueue({
store,
options: { prepend, isApiValidationRun },
requestedAction: enqueueRequestedWorkflows,
build: async ({ store: innerStore, options }) => {
const { dispatch, getState } = innerStore;
const state = getState();
const nodesState = selectNodesSlice(state);
const templates = $templates.get();
const graph = buildNodesGraph(state, templates);
const workflow = buildWorkflowWithValidation(nodesState);
if (workflow) {
// embedded workflows don't have an id
delete workflow.id;
}
const data = await getBatchDataForWorkflowGeneration(state, dispatch);
const batchConfig: EnqueueBatchArg = {
batch: {
graph,
workflow,
runs: state.params.iterations,
origin: 'workflows',
destination: 'gallery',
data,
},
prepend: options.prepend,
};
if (options.isApiValidationRun) {
batchConfig.validation_run_data = getValidationRunData(state, templates);
batchConfig.batch.runs = 1;
}
return { batchConfig } satisfies { batchConfig: EnqueueBatchArg };
},
prepareBatch: ({ buildResult }) => buildResult.batchConfig,
});
},
[store]
);

View File

@@ -0,0 +1,107 @@
import { createAction } from '@reduxjs/toolkit';
import type { AppStore, RootState } from 'app/store/store';
import type { EnqueueBatchArg, EnqueueBatchResponse } from './executeEnqueue';
import { executeEnqueue } from './executeEnqueue';
import { describe, expect, it, vi } from 'vitest';
const createTestStore = () => {
const state = {} as RootState;
const dispatch = vi.fn<(action: unknown) => unknown>((action) => {
if (typeof action === 'object' && action !== null && 'type' in action) {
return undefined;
}
const unwrap = vi.fn<() => Promise<EnqueueBatchResponse>>().mockResolvedValue({
batch_id: 'batch-1',
item_ids: ['item-1'],
} as EnqueueBatchResponse);
return { unwrap };
});
const getState = vi.fn(() => state);
return { dispatch, getState } as unknown as AppStore;
};
const createBatchArg = (prepend: boolean): EnqueueBatchArg => ({
prepend,
batch: {
graph: {} as EnqueueBatchArg['batch']['graph'],
runs: 1,
data: [],
origin: 'test',
destination: 'test',
},
});
describe('executeEnqueue', () => {
it('dispatches enqueue flow and invokes success callback', async () => {
const store = createTestStore();
const requestedAction = createAction('test/enqueue');
const options = { prepend: false } as const;
const batchConfig = createBatchArg(options.prepend);
const onSuccess = vi.fn();
const build = vi.fn(async () => ({ batchConfig }));
const prepareBatch = vi.fn(() => batchConfig);
const result = await executeEnqueue({
store,
options,
requestedAction,
build,
prepareBatch,
onSuccess,
log: { error: vi.fn() },
});
expect(store.dispatch).toHaveBeenCalledWith(requestedAction());
expect(build).toHaveBeenCalledWith({ store, options });
expect(prepareBatch).toHaveBeenCalledWith({ store, options, buildResult: { batchConfig } });
expect(onSuccess).toHaveBeenCalled();
expect(result?.batchConfig).toBe(batchConfig);
});
it('stops when build returns null', async () => {
const store = createTestStore();
const requestedAction = createAction('test/enqueue');
const options = { prepend: true } as const;
const build = vi.fn(async () => null);
const prepareBatch = vi.fn();
const result = await executeEnqueue({
store,
options,
requestedAction,
build,
prepareBatch,
log: { error: vi.fn() },
});
expect(result).toBeNull();
expect(build).toHaveBeenCalled();
expect(prepareBatch).not.toHaveBeenCalled();
});
it('invokes onError when build throws', async () => {
const store = createTestStore();
const requestedAction = createAction('test/enqueue');
const options = { prepend: false } as const;
const error = new Error('boom');
const build = vi.fn(async () => {
throw error;
});
const onError = vi.fn();
const logError = vi.fn();
const result = await executeEnqueue({
store,
options,
requestedAction,
build,
prepareBatch: vi.fn(),
onError,
log: { error: logError },
});
expect(result).toBeNull();
expect(onError).toHaveBeenCalledWith({ store, options, error });
expect(logError).toHaveBeenCalled();
});
});

View File

@@ -0,0 +1,70 @@
import type { ActionCreatorWithoutPayload } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import type { paths } from 'services/api/schema';
export type EnqueueBatchArg =
paths['/api/v1/queue/{queue_id}/enqueue_batch']['post']['requestBody']['content']['application/json'];
export type EnqueueBatchResponse =
paths['/api/v1/queue/{queue_id}/enqueue_batch']['post']['responses']['201']['content']['application/json'];
export type EnqueueOptionsBase = { prepend: boolean };
interface ExecuteEnqueueConfig<TOptions extends EnqueueOptionsBase, TBuildResult> {
store: AppStore;
options: TOptions;
requestedAction: ActionCreatorWithoutPayload<string>;
build: (context: { store: AppStore; options: TOptions }) => Promise<TBuildResult | null>;
prepareBatch: (context: { store: AppStore; options: TOptions; buildResult: TBuildResult }) => EnqueueBatchArg;
onSuccess?: (context: {
store: AppStore;
options: TOptions;
buildResult: TBuildResult;
batch: EnqueueBatchArg;
response: EnqueueBatchResponse;
}) => void;
onError?: (context: { store: AppStore; options: TOptions; error: unknown }) => void;
log?: ReturnType<typeof logger>;
}
export const executeEnqueue = async <TOptions extends EnqueueOptionsBase, TBuildResult>({
store,
options,
requestedAction,
build,
prepareBatch,
onSuccess,
onError,
log = logger('enqueue'),
}: ExecuteEnqueueConfig<TOptions, TBuildResult>) => {
const { dispatch } = store;
dispatch(requestedAction());
try {
const buildResult = await build({ store, options });
if (!buildResult) {
return null;
}
const batchConfig = prepareBatch({ store, options, buildResult });
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
...enqueueMutationFixedCacheKeyOptions,
track: false,
})
);
const enqueueResult = await req.unwrap();
onSuccess?.({ store, options, buildResult, batch: batchConfig, response: enqueueResult });
return { batchConfig, enqueueResult };
} catch (error) {
log.error({ error: serializeError(error as Error) }, 'Failed to enqueue batch');
onError?.({ store, options, error });
return null;
}
};
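
The `executeEnqueue` helper above centralizes the shared flow (dispatch the requested action, build, prepare the batch, enqueue, handle errors) that the generate/upscaling/video hooks earlier in this compare now delegate to. For orientation, a minimal hypothetical hook built on it might look like the sketch below; `useEnqueueExample`, `enqueueRequestedExample`, and the stubbed graph are illustrative only and not part of the diff.

import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import { useAppStore } from 'app/store/storeHooks';
import { useCallback } from 'react';
import type { EnqueueBatchArg } from './executeEnqueue';
import { executeEnqueue } from './executeEnqueue';

const log = logger('generation');
// Illustrative action - each real hook defines its own "requested" action.
export const enqueueRequestedExample = createAction('app/enqueueRequestedExample');

type ExampleBuildResult = { batchConfig: EnqueueBatchArg };

export const useEnqueueExample = () => {
  const store = useAppStore();
  return useCallback(
    (prepend: boolean) =>
      executeEnqueue({
        store,
        options: { prepend },
        requestedAction: enqueueRequestedExample,
        log,
        // Returning null from build aborts without enqueueing, e.g. when no model is selected.
        build: async ({ store: innerStore, options }) => {
          const state = innerStore.getState();
          if (!state.params.model) {
            log.error('No model found in state');
            return null;
          }
          const batchConfig: EnqueueBatchArg = {
            prepend: options.prepend,
            batch: {
              // Stub graph for the sketch - a real hook would call a graph builder here.
              graph: {} as EnqueueBatchArg['batch']['graph'],
              runs: 1,
              origin: 'example',
              destination: 'gallery',
            },
          };
          return { batchConfig } satisfies ExampleBuildResult;
        },
        prepareBatch: ({ buildResult }) => buildResult.batchConfig,
        // Any error thrown in build/prepare/enqueue is logged by executeEnqueue and passed to onError.
      }),
    [store]
  );
};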

View File

@@ -0,0 +1,81 @@
import { describe, expect, it, vi } from 'vitest';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { GraphBuilderArg } from 'features/nodes/util/graph/types';
import type { Invocation } from 'services/api/types';
import type { RootState } from 'app/store/store';
const mocks = vi.hoisted(() => {
const mockGraph: Graph = {} as Graph;
const mockPrompt = { id: 'prompt-node' } as Invocation<'string'>;
const asyncReturnValue = { g: mockGraph, positivePrompt: mockPrompt };
const syncReturnValue = { g: mockGraph, positivePrompt: mockPrompt };
return {
asyncReturnValue,
syncReturnValue,
buildSDXLGraphMock: vi.fn().mockResolvedValue(asyncReturnValue),
buildImagen3GraphMock: vi.fn().mockReturnValue(syncReturnValue),
createDefaultBuilder: () => vi.fn().mockResolvedValue(asyncReturnValue),
};
});
vi.mock('features/nodes/util/graph/generation/buildSDXLGraph', () => ({
buildSDXLGraph: mocks.buildSDXLGraphMock,
}));
vi.mock('features/nodes/util/graph/generation/buildSD1Graph', () => ({
buildSD1Graph: mocks.createDefaultBuilder(),
}));
vi.mock('features/nodes/util/graph/generation/buildSD3Graph', () => ({
buildSD3Graph: mocks.createDefaultBuilder(),
}));
vi.mock('features/nodes/util/graph/generation/buildFLUXGraph', () => ({
buildFLUXGraph: mocks.createDefaultBuilder(),
}));
vi.mock('features/nodes/util/graph/generation/buildFluxKontextGraph', () => ({
buildFluxKontextGraph: mocks.createDefaultBuilder(),
}));
vi.mock('features/nodes/util/graph/generation/buildCogView4Graph', () => ({
buildCogView4Graph: mocks.createDefaultBuilder(),
}));
vi.mock('features/nodes/util/graph/generation/buildImagen3Graph', () => ({
buildImagen3Graph: mocks.buildImagen3GraphMock,
}));
vi.mock('features/nodes/util/graph/generation/buildImagen4Graph', () => ({
buildImagen4Graph: mocks.createDefaultBuilder(),
}));
vi.mock('features/nodes/util/graph/generation/buildChatGPT4oGraph', () => ({
buildChatGPT4oGraph: mocks.createDefaultBuilder(),
}));
vi.mock('features/nodes/util/graph/generation/buildGemini2_5Graph', () => ({
buildGemini2_5Graph: mocks.createDefaultBuilder(),
}));
import { buildGraphForBase } from './graphBuilders';
describe('buildGraphForBase', () => {
const baseArg: GraphBuilderArg = {
generationMode: 'txt2img',
state: {} as RootState,
manager: null,
};
it('awaits asynchronous graph builders', async () => {
const result = await buildGraphForBase('sdxl', baseArg);
expect(result).toBe(mocks.asyncReturnValue);
expect(mocks.buildSDXLGraphMock).toHaveBeenCalledWith(baseArg);
});
it('supports synchronous graph builders', async () => {
const result = await buildGraphForBase('imagen3', baseArg);
expect(result).toBe(mocks.syncReturnValue);
expect(mocks.buildImagen3GraphMock).toHaveBeenCalledWith(baseArg);
});
it('throws for unknown base models', async () => {
await expect(buildGraphForBase('unknown-model', baseArg)).rejects.toThrow(
'No graph builders for base unknown-model'
);
});
});

View File

@@ -0,0 +1,34 @@
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph';
import { buildGemini2_5Graph } from 'features/nodes/util/graph/generation/buildGemini2_5Graph';
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import type { GraphBuilderArg, GraphBuilderReturn } from 'features/nodes/util/graph/types';
import { assert } from 'tsafe';
type GraphBuilderFn = (arg: GraphBuilderArg) => GraphBuilderReturn | Promise<GraphBuilderReturn>;
const graphBuilderMap: Record<string, GraphBuilderFn> = {
sdxl: buildSDXLGraph,
'sd-1': buildSD1Graph,
'sd-2': buildSD1Graph,
'sd-3': buildSD3Graph,
flux: buildFLUXGraph,
'flux-kontext': buildFluxKontextGraph,
cogview4: buildCogView4Graph,
imagen3: buildImagen3Graph,
imagen4: buildImagen4Graph,
'chatgpt-4o': buildChatGPT4oGraph,
'gemini-2.5': buildGemini2_5Graph,
};
export const buildGraphForBase = async (base: string, arg: GraphBuilderArg) => {
const builder = graphBuilderMap[base];
assert(builder, `No graph builders for base ${base}`);
return await builder(arg);
};
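
`buildGraphForBase` replaces the per-hook `switch` statements with a single lookup over the builder map. A hedged sketch of a caller, assuming a model is selected in state; `buildTxt2ImgGraphForSelectedModel` is a hypothetical helper, not part of the diff.

import { getStore } from 'app/store/nanostores/store';
import type { GraphBuilderArg } from 'features/nodes/util/graph/types';
import { buildGraphForBase } from './graphBuilders';

// Hypothetical helper: resolve a txt2img graph for whatever model is currently selected.
const buildTxt2ImgGraphForSelectedModel = async () => {
  const state = getStore().getState();
  const base = state.params.model?.base;
  if (!base) {
    return null;
  }
  const arg: GraphBuilderArg = { generationMode: 'txt2img', state, manager: null };
  // Throws an AssertionError ("No graph builders for base ...") for an unmapped base;
  // the feature hooks wrap this in withResultAsync and surface the message via toast.
  return await buildGraphForBase(base, arg);
};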

View File

@@ -309,7 +309,7 @@ const getReasonsWhyCannotEnqueueVideoTab = (arg: {
reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') });
}
if (video.videoModel?.base === 'runway' && !video.startingFrameImage?.image_name) {
if (video.videoModel?.base === 'runway' && !video.startingFrameImage?.original.image.image_name) {
reasons.push({ content: i18n.t('parameters.invoke.noStartingFrameImage') });
}

View File

@@ -1,20 +1,25 @@
import { Flex, FormLabel, Text } from '@invoke-ai/ui-library';
import { Flex, FormLabel, Icon, IconButton, Text, Tooltip } from '@invoke-ai/ui-library';
import { objectEquals } from '@observ33r/object-equals';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { UploadImageIconButton } from 'common/hooks/useImageUploadButton';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { ASPECT_RATIO_MAP } from 'features/controlLayers/store/types';
import { imageDTOToCroppableImage, imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { Editor } from 'features/cropper/lib/editor';
import { cropImageModalApi } from 'features/cropper/store';
import { videoFrameFromImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { DndImage } from 'features/dnd/DndImage';
import { DndImageIcon } from 'features/dnd/DndImageIcon';
import { DndImageIcon, imageButtonSx } from 'features/dnd/DndImageIcon';
import {
selectStartingFrameImage,
selectVideoAspectRatio,
selectVideoModelRequiresStartingFrame,
startingFrameImageChanged,
} from 'features/parameters/store/videoSlice';
import { t } from 'i18next';
import { useCallback } from 'react';
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
import { useImageDTO } from 'services/api/endpoints/images';
import { useCallback, useMemo } from 'react';
import { PiArrowCounterClockwiseBold, PiCropBold, PiWarningBold } from 'react-icons/pi';
import { useImageDTO, useUploadImageMutation } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
const dndTargetData = videoFrameFromImageDndTarget.getData({ frame: 'start' });
@@ -23,7 +28,10 @@ export const StartingFrameImage = () => {
const dispatch = useAppDispatch();
const requiresStartingFrame = useAppSelector(selectVideoModelRequiresStartingFrame);
const startingFrameImage = useAppSelector(selectStartingFrameImage);
const imageDTO = useImageDTO(startingFrameImage?.image_name);
const originalImageDTO = useImageDTO(startingFrameImage?.original.image.image_name);
const croppedImageDTO = useImageDTO(startingFrameImage?.crop?.image.image_name);
const videoAspectRatio = useAppSelector(selectVideoAspectRatio);
const [uploadImage] = useUploadImageMutation();
const onReset = useCallback(() => {
dispatch(startingFrameImageChanged(null));
@@ -31,27 +39,106 @@ export const StartingFrameImage = () => {
const onUpload = useCallback(
(imageDTO: ImageDTO) => {
dispatch(startingFrameImageChanged(imageDTOToImageWithDims(imageDTO)));
dispatch(startingFrameImageChanged(imageDTOToCroppableImage(imageDTO)));
},
[dispatch]
);
const edit = useCallback(() => {
if (!originalImageDTO) {
return;
}
// We will create a new editor instance each time the user wants to edit
const editor = new Editor();
// When the user applies the crop, we will upload the cropped image and store the applied crop box so if the user
// re-opens the editor they see the same crop
const onApplyCrop = async () => {
const box = editor.getCropBox();
if (objectEquals(box, startingFrameImage?.crop?.box)) {
// If the box hasn't changed, don't do anything
return;
}
if (!box || objectEquals(box, { x: 0, y: 0, width: originalImageDTO.width, height: originalImageDTO.height })) {
// No crop box, or the crop covers the whole image - revert to the original image
dispatch(startingFrameImageChanged(imageDTOToCroppableImage(originalImageDTO)));
return;
}
const blob = await editor.exportImage('blob');
const file = new File([blob], 'image.png', { type: 'image/png' });
const newCroppedImageDTO = await uploadImage({
file,
is_intermediate: true,
image_category: 'user',
}).unwrap();
dispatch(
startingFrameImageChanged(
imageDTOToCroppableImage(originalImageDTO, {
image: imageDTOToImageWithDims(newCroppedImageDTO),
box,
ratio: editor.getCropAspectRatio(),
})
)
);
};
const onReady = async () => {
const initial = startingFrameImage?.crop
? { cropBox: startingFrameImage.crop.box, aspectRatio: startingFrameImage.crop.ratio }
: undefined;
// Load the image into the editor and open the modal once it's ready
await editor.loadImage(originalImageDTO.image_url, initial);
};
cropImageModalApi.open({ editor, onApplyCrop, onReady });
}, [dispatch, originalImageDTO, startingFrameImage?.crop, uploadImage]);
const fitsCurrentAspectRatio = useMemo(() => {
const imageDTO = croppedImageDTO ?? originalImageDTO;
if (!imageDTO) {
return true;
}
const imageRatio = imageDTO.width / imageDTO.height;
const targetRatio = ASPECT_RATIO_MAP[videoAspectRatio].ratio;
// Call it a fit if the image is within 10% of the target aspect ratio
return Math.abs((imageRatio - targetRatio) / targetRatio) < 0.1;
}, [croppedImageDTO, originalImageDTO, videoAspectRatio]);
return (
<Flex justifyContent="flex-start" flexDir="column" gap={2}>
<FormLabel>{t('parameters.startingFrameImage')}</FormLabel>
<FormLabel display="flex" alignItems="center" gap={2}>
<Text>{t('parameters.startingFrameImage')}</Text>
{!fitsCurrentAspectRatio && (
<Tooltip label={t('parameters.startingFrameImageAspectRatioWarning', { videoAspectRatio: videoAspectRatio })}>
<Flex alignItems="center">
<Icon as={PiWarningBold} size={16} color="warning.300" />
</Flex>
</Tooltip>
)}
</FormLabel>
<Flex position="relative" w={36} h={36} alignItems="center" justifyContent="center">
{!imageDTO && (
{!originalImageDTO && (
<UploadImageIconButton
w="full"
h="full"
isError={requiresStartingFrame && !imageDTO}
isError={requiresStartingFrame && !originalImageDTO}
onUpload={onUpload}
fontSize={36}
/>
)}
{imageDTO && (
{originalImageDTO && (
<>
<DndImage imageDTO={imageDTO} borderRadius="base" borderWidth={1} borderStyle="solid" />
<DndImage
imageDTO={croppedImageDTO ?? originalImageDTO}
borderRadius="base"
borderWidth={1}
borderStyle="solid"
/>
<Flex position="absolute" flexDir="column" top={1} insetInlineEnd={1} gap={1}>
<DndImageIcon
onClick={onReset}
@@ -59,6 +146,18 @@ export const StartingFrameImage = () => {
tooltip={t('common.reset')}
/>
</Flex>
<Flex position="absolute" flexDir="column" top={1} insetInlineStart={1} gap={1}>
<IconButton
variant="link"
sx={imageButtonSx}
aria-label={t('common.crop')}
onClick={edit}
icon={<PiCropBold size={16} />}
tooltip={t('common.crop')}
/>
</Flex>
<Text
position="absolute"
background="base.900"
@@ -73,7 +172,7 @@ export const StartingFrameImage = () => {
borderTopEndRadius="base"
borderBottomStartRadius="base"
pointerEvents="none"
>{`${imageDTO.width}x${imageDTO.height}`}</Text>
>{`${croppedImageDTO?.width ?? originalImageDTO.width}x${croppedImageDTO?.height ?? originalImageDTO.height}`}</Text>
</>
)}
<DndDropTarget label="Drop" dndTarget={videoFrameFromImageDndTarget} dndTargetData={dndTargetData} />
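
The `fitsCurrentAspectRatio` check above flags a starting frame whose ratio differs from the video's target ratio by 10% or more. A quick illustrative sketch of the tolerance arithmetic, pulled out of the component for clarity (not part of the diff):

// Same relative-error test as fitsCurrentAspectRatio, shown standalone.
const fitsAspectRatio = (width: number, height: number, targetRatio: number) =>
  Math.abs((width / height - targetRatio) / targetRatio) < 0.1;

fitsAspectRatio(1024, 576, 16 / 9); // true  - exactly 16:9
fitsAspectRatio(1920, 1088, 16 / 9); // true  - ~0.7% off, within the 10% tolerance
fitsAspectRatio(1024, 1024, 16 / 9); // false - a square image is ~44% off a 16:9 target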

View File

@@ -81,6 +81,10 @@ export const useHotkeyData = (): HotkeysData => {
addHotkey('app', 'selectGenerateTab', ['1']);
addHotkey('app', 'selectCanvasTab', ['2']);
addHotkey('app', 'selectUpscalingTab', ['3']);
// Prompt/history navigation (when prompt textarea is focused)
addHotkey('app', 'promptHistoryPrev', ['alt+up']);
addHotkey('app', 'promptHistoryNext', ['alt+down']);
if (isVideoEnabled) {
addHotkey('app', 'selectVideoTab', ['4']);
addHotkey('app', 'selectWorkflowsTab', ['5']);

View File

@@ -3,7 +3,7 @@ import { useAppStore } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { getDefaultRefImageConfig } from 'features/controlLayers/hooks/addLayerHooks';
import { refImageAdded } from 'features/controlLayers/store/refImagesSlice';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { imageDTOToCroppableImage } from 'features/controlLayers/store/util';
import { addGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { LaunchpadButton } from 'features/ui/layouts/LaunchpadButton';
@@ -23,7 +23,7 @@ export const LaunchpadAddStyleReference = memo((props: { extraAction?: () => voi
({
onUpload: (imageDTO: ImageDTO) => {
const config = getDefaultRefImageConfig(getState);
config.image = imageDTOToImageWithDims(imageDTO);
config.image = imageDTOToCroppableImage(imageDTO);
dispatch(refImageAdded({ overrides: { config } }));
props.extraAction?.();
},

View File

@@ -1,7 +1,7 @@
import { Flex, Heading, Icon, Text } from '@invoke-ai/ui-library';
import { useAppStore } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { imageDTOToCroppableImage } from 'features/controlLayers/store/util';
import { videoFrameFromImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { startingFrameImageChanged } from 'features/parameters/store/videoSlice';
@@ -21,7 +21,7 @@ export const LaunchpadStartingFrameButton = memo((props: { extraAction?: () => v
() =>
({
onUpload: (imageDTO: ImageDTO) => {
dispatch(startingFrameImageChanged(imageDTOToImageWithDims(imageDTO)));
dispatch(startingFrameImageChanged(imageDTOToCroppableImage(imageDTO)));
props.extraAction?.();
},
allowMultiple: false,

View File

@@ -1,6 +1,7 @@
import { skipToken } from '@reduxjs/toolkit/query';
import { $authToken } from 'app/store/nanostores/authToken';
import { getStore } from 'app/store/nanostores/store';
import type { CroppableImageWithDims } from 'features/controlLayers/store/types';
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types';
import type { components, paths } from 'services/api/schema';
import type {
@@ -593,3 +594,10 @@ export const useImageDTO = (imageName: string | null | undefined) => {
const { currentData: imageDTO } = useGetImageDTOQuery(imageName ?? skipToken);
return imageDTO ?? null;
};
export const useImageDTOFromCroppableImage = (croppableImage: CroppableImageWithDims | null) => {
const { currentData: imageDTO } = useGetImageDTOQuery(
croppableImage?.crop?.image.image_name ?? croppableImage?.original.image.image_name ?? skipToken
);
return imageDTO ?? null;
};
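
`useImageDTOFromCroppableImage` resolves the image that should actually be shown for a croppable image: the cropped version when a crop exists, otherwise the original. A hedged usage sketch, assuming `selectStartingFrameImage` now yields a `CroppableImageWithDims | null` as the component changes above suggest; `useStartingFrameDisplayImage` is a hypothetical wrapper, not part of the diff.

import { useAppSelector } from 'app/store/storeHooks';
import { selectStartingFrameImage } from 'features/parameters/store/videoSlice';
import { useImageDTOFromCroppableImage } from 'services/api/endpoints/images';

// Hypothetical hook: returns the ImageDTO to display for the starting frame -
// the cropped image if a crop has been applied, otherwise the original, else null.
export const useStartingFrameDisplayImage = () => {
  const startingFrameImage = useAppSelector(selectStartingFrameImage);
  return useImageDTOFromCroppableImage(startingFrameImage);
};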

View File

@@ -1 +1 @@
__version__ = "6.7.0"
__version__ = "6.8.0rc1"