diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json
index 44cba5629d..fff29142fd 100644
--- a/invokeai/frontend/web/public/locales/en.json
+++ b/invokeai/frontend/web/public/locales/en.json
@@ -225,7 +225,16 @@
"prompt": {
"addPromptTrigger": "Add Prompt Trigger",
"compatibleEmbeddings": "Compatible Embeddings",
- "noMatchingTriggers": "No matching triggers"
+ "noMatchingTriggers": "No matching triggers",
+ "generateFromImage": "Generate prompt from image",
+ "expandCurrentPrompt": "Expand Current Prompt",
+ "uploadImageForPromptGeneration": "Upload Image for Prompt Generation",
+ "expandingPrompt": "Expanding prompt...",
+ "resultTitle": "Prompt Expansion Complete",
+ "resultSubtitle": "Choose how to handle the expanded prompt:",
+ "replace": "Replace",
+ "insert": "Insert",
+ "discard": "Discard"
},
"queue": {
"queue": "Queue",
@@ -342,7 +351,7 @@
"copy": "Copy",
"currentlyInUse": "This image is currently in use in the following features:",
"drop": "Drop",
- "dropOrUpload": "$t(gallery.drop) or Upload",
+ "dropOrUpload": "Drop or Upload",
"dropToUpload": "$t(gallery.drop) to Upload",
"deleteImage_one": "Delete Image",
"deleteImage_other": "Delete {{count}} Images",
@@ -396,7 +405,8 @@
"compareHelp4": "Press Z or Esc to exit.",
"openViewer": "Open Viewer",
"closeViewer": "Close Viewer",
- "move": "Move"
+ "move": "Move",
+ "useForPromptGeneration": "Use for Prompt Generation"
},
"hotkeys": {
"hotkeys": "Hotkeys",
@@ -938,7 +948,8 @@
"selectModel": "Select a Model",
"noLoRAsInstalled": "No LoRAs installed",
"noRefinerModelsInstalled": "No SDXL Refiner models installed",
- "defaultVAE": "Default VAE"
+ "defaultVAE": "Default VAE",
+ "noCompatibleLoRAs": "No Compatible LoRAs"
},
"nodes": {
"arithmeticSequence": "Arithmetic Sequence",
@@ -1188,7 +1199,9 @@
"canvasIsSelectingObject": "Canvas is busy (selecting object)",
"noPrompts": "No prompts generated",
"noNodesInGraph": "No nodes in graph",
- "systemDisconnected": "System disconnected"
+ "systemDisconnected": "System disconnected",
+ "promptExpansionPending": "Prompt expansion in progress",
+ "promptExpansionResultPending": "Please accept or discard your prompt expansion result"
},
"maskBlur": "Mask Blur",
"negativePromptPlaceholder": "Negative Prompt",
@@ -1389,7 +1402,12 @@
"fluxKontextIncompatibleGenerationMode": "Flux Kontext supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
"problemUnpublishingWorkflow": "Problem Unpublishing Workflow",
"problemUnpublishingWorkflowDescription": "There was a problem unpublishing the workflow. Please try again.",
- "workflowUnpublished": "Workflow Unpublished"
+ "workflowUnpublished": "Workflow Unpublished",
+ "sentToCanvas": "Sent to Canvas",
+ "sentToUpscale": "Sent to Upscale",
+ "promptGenerationStarted": "Prompt generation started",
+ "uploadAndPromptGenerationFailed": "Failed to upload image and generate prompt",
+ "promptExpansionFailed": "Prompt expansion failed"
},
"popovers": {
"clipSkip": {
diff --git a/invokeai/frontend/web/src/app/types/invokeai.ts b/invokeai/frontend/web/src/app/types/invokeai.ts
index 000b99a1c5..afa3a402aa 100644
--- a/invokeai/frontend/web/src/app/types/invokeai.ts
+++ b/invokeai/frontend/web/src/app/types/invokeai.ts
@@ -78,6 +78,7 @@ export type AppConfig = {
allowPrivateStylePresets: boolean;
allowClientSideUpload: boolean;
allowPublishWorkflows: boolean;
+ allowPromptExpansion: boolean;
disabledTabs: TabName[];
disabledFeatures: AppFeature[];
disabledSDFeatures: SDFeature[];
diff --git a/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx b/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx
index 9c19a31e5f..ff53b8ec79 100644
--- a/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx
+++ b/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx
@@ -21,11 +21,15 @@ type UseImageUploadButtonArgs =
isDisabled?: boolean;
allowMultiple: false;
onUpload?: (imageDTO: ImageDTO) => void;
+ onUploadStarted?: (files: File) => void;
+ onError?: (error: unknown) => void;
}
| {
isDisabled?: boolean;
allowMultiple: true;
onUpload?: (imageDTOs: ImageDTO[]) => void;
+ onUploadStarted?: (files: File[]) => void;
+ onError?: (error: unknown) => void;
};
const log = logger('gallery');
@@ -49,7 +53,13 @@ const log = logger('gallery');
* // will open the file dialog on click
* // hidden, handles native upload functionality
*/
-export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: UseImageUploadButtonArgs) => {
+export const useImageUploadButton = ({
+ onUpload,
+ isDisabled,
+ allowMultiple,
+ onUploadStarted,
+ onError,
+}: UseImageUploadButtonArgs) => {
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const isClientSideUploadEnabled = useAppSelector(selectIsClientSideUploadEnabled);
const [uploadImage, request] = useUploadImageMutation();
@@ -71,6 +81,7 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
}
const file = files[0];
assert(file !== undefined); // should never happen
+ onUploadStarted?.(file);
const imageDTO = await uploadImage({
file,
image_category: 'user',
@@ -82,6 +93,8 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
onUpload(imageDTO);
}
} else {
+ onUploadStarted?.(files);
+
let imageDTOs: ImageDTO[] = [];
if (isClientSideUploadEnabled && files.length > 1) {
imageDTOs = await Promise.all(files.map((file, i) => clientSideUpload(file, i)));
@@ -102,6 +115,7 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
}
}
} catch (error) {
+ onError?.(error);
toast({
id: 'UPLOAD_FAILED',
title: t('toast.imageUploadFailed'),
@@ -109,7 +123,17 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
});
}
},
- [allowMultiple, autoAddBoardId, onUpload, uploadImage, isClientSideUploadEnabled, clientSideUpload, t]
+ [
+ allowMultiple,
+ onUploadStarted,
+ uploadImage,
+ autoAddBoardId,
+ onUpload,
+ isClientSideUploadEnabled,
+ clientSideUpload,
+ onError,
+ t,
+ ]
);
const onDropRejected = useCallback(
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts
index 79f5a956f7..98de4fc76a 100644
--- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts
@@ -270,7 +270,7 @@ export class CanvasStateApiModule extends CanvasModuleBase {
outputNodeId: string;
options?: RunGraphOptions;
}): Promise => {
- const dependencies = buildRunGraphDependencies(this.store, this.manager.socket);
+ const dependencies = buildRunGraphDependencies(this.store.dispatch, this.manager.socket);
const { output } = await runGraph({
dependencies,
diff --git a/invokeai/frontend/web/src/features/controlLayers/store/util.ts b/invokeai/frontend/web/src/features/controlLayers/store/util.ts
index 77519051e3..aad3669df1 100644
--- a/invokeai/frontend/web/src/features/controlLayers/store/util.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/store/util.ts
@@ -19,6 +19,7 @@ import type {
RgbColor,
T2IAdapterConfig,
} from 'features/controlLayers/store/types';
+import type { ImageField } from 'features/nodes/types/common';
import type { ImageDTO } from 'services/api/types';
import { assert } from 'tsafe';
import type { PartialDeep } from 'type-fest';
@@ -59,6 +60,8 @@ export const imageDTOToImageWithDims = ({ image_name, width, height }: ImageDTO)
height,
});
+/** Builds an `ImageField` holding only the `image_name` taken from the given image DTO. */
+export const imageDTOToImageField = (imageDTO: ImageDTO): ImageField => ({ image_name: imageDTO.image_name });
+
const DEFAULT_RG_MASK_FILL_COLORS: RgbColor[] = [
{ r: 121, g: 157, b: 219 }, // rgb(121, 157, 219)
{ r: 131, g: 214, b: 131 }, // rgb(131, 214, 131)
diff --git a/invokeai/frontend/web/src/features/dnd/dnd.ts b/invokeai/frontend/web/src/features/dnd/dnd.ts
index 5247b82b11..452f0494dd 100644
--- a/invokeai/frontend/web/src/features/dnd/dnd.ts
+++ b/invokeai/frontend/web/src/features/dnd/dnd.ts
@@ -22,6 +22,8 @@ import {
import { fieldImageCollectionValueChanged } from 'features/nodes/store/nodesSlice';
import { selectFieldInputInstanceSafe, selectNodesSlice } from 'features/nodes/store/selectors';
import { type FieldIdentifier, isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
+import { expandPrompt } from 'features/prompt/PromptExpansion/expand';
+import { promptExpansionApi } from 'features/prompt/PromptExpansion/state';
import type { ImageDTO } from 'services/api/types';
import type { JsonObject } from 'type-fest';
@@ -515,23 +517,48 @@ export const removeImageFromBoardDndTarget: DndTarget<
//#endregion
+//#region Prompt Generation From Image
+// Type/key pair identifying this drop target, built from the 'prompt-generation-from-image' name.
+const _promptGenerationFromImage = buildTypeAndKey('prompt-generation-from-image');
+
+/** Data carried by the prompt-generation drop target. It has no payload (`void`). */
+export type PromptGenerationFromImageDndTargetData = DndData<
+  typeof _promptGenerationFromImage.type,
+  typeof _promptGenerationFromImage.key,
+  void
+>;
+
+/**
+ * Drop target that starts prompt generation from a dropped image.
+ *
+ * - `isValid` accepts the drop only when the source data passes the single-image type guard.
+ * - `handler` marks the prompt expansion request as pending with the dropped image, then starts
+ *   `expandPrompt` for that image. The returned promise is not awaited.
+ */
+export const promptGenerationFromImageDndTarget: DndTarget<
+  PromptGenerationFromImageDndTargetData,
+  SingleImageDndSourceData
+> = {
+  ..._promptGenerationFromImage,
+  typeGuard: buildTypeGuard(_promptGenerationFromImage.key),
+  getData: buildGetData(_promptGenerationFromImage.key, _promptGenerationFromImage.type),
+  isValid: ({ sourceData }) => singleImageDndSource.typeGuard(sourceData),
+  handler: ({ sourceData, dispatch, getState }) => {
+    const droppedImage = sourceData.payload.imageDTO;
+    promptExpansionApi.setPending(droppedImage);
+    expandPrompt({ dispatch, getState, imageDTO: droppedImage });
+  },
+};
+//#endregion
+
export const dndTargets = [
- // Single Image
setGlobalReferenceImageDndTarget,
+ addGlobalReferenceImageDndTarget,
setRegionalGuidanceReferenceImageDndTarget,
setUpscaleInitialImageDndTarget,
setNodeImageFieldImageDndTarget,
+ addImagesToNodeImageFieldCollectionDndTarget,
setComparisonImageDndTarget,
newCanvasEntityFromImageDndTarget,
+ newCanvasFromImageDndTarget,
replaceCanvasEntityObjectsWithImageDndTarget,
addImageToBoardDndTarget,
removeImageFromBoardDndTarget,
- newCanvasFromImageDndTarget,
- addGlobalReferenceImageDndTarget,
- // Single or Multiple Image
- addImageToBoardDndTarget,
- removeImageFromBoardDndTarget,
- addImagesToNodeImageFieldCollectionDndTarget,
+ promptGenerationFromImageDndTarget,
] as const;
export type AnyDndTarget = (typeof dndTargets)[number];
diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/ImageMenuItemUseForPromptGeneration.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/ImageMenuItemUseForPromptGeneration.tsx
new file mode 100644
index 0000000000..f4e0b64b5f
--- /dev/null
+++ b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/ImageMenuItemUseForPromptGeneration.tsx
@@ -0,0 +1,46 @@
+import { MenuItem } from '@invoke-ai/ui-library';
+import { useStore } from '@nanostores/react';
+import { useAppSelector, useAppStore } from 'app/store/storeHooks';
+import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
+import { expandPrompt } from 'features/prompt/PromptExpansion/expand';
+import { promptExpansionApi } from 'features/prompt/PromptExpansion/state';
+import { selectAllowPromptExpansion } from 'features/system/store/configSlice';
+import { toast } from 'features/toast/toast';
+import { memo, useCallback } from 'react';
+import { useTranslation } from 'react-i18next';
+import { PiTextTBold } from 'react-icons/pi';
+
+/**
+ * Context-menu item that starts prompt generation from the image supplied by `useImageDTOContext`.
+ *
+ * Renders nothing when the `selectAllowPromptExpansion` selector is falsy. The item is disabled
+ * while a prompt expansion request is pending.
+ */
+export const ImageMenuItemUseForPromptGeneration = memo(() => {
+ const { t } = useTranslation();
+ const { dispatch, getState } = useAppStore();
+ const imageDTO = useImageDTOContext();
+ const { isPending } = useStore(promptExpansionApi.$state);
+ const isPromptExpansionEnabled = useAppSelector(selectAllowPromptExpansion);
+
+ // Marks the request as pending with this image, starts the expansion (the promise is not
+ // awaited here) and shows an informational toast.
+ const handleUseForPromptGeneration = useCallback(() => {
+ promptExpansionApi.setPending(imageDTO);
+ expandPrompt({ dispatch, getState, imageDTO });
+ toast({
+ id: 'PROMPT_GENERATION_STARTED',
+ title: t('toast.promptGenerationStarted'),
+ status: 'info',
+ });
+ }, [dispatch, getState, imageDTO, t]);
+
+ // Feature gate: hide the menu item entirely when prompt expansion is not allowed.
+ if (!isPromptExpansionEnabled) {
+ return null;
+ }
+
+ return (
+ }
+ onClickCapture={handleUseForPromptGeneration}
+ id="use-for-prompt-generation"
+ isDisabled={isPending}
+ >
+ {t('gallery.useForPromptGeneration')}
+
+ );
+});
+
+ImageMenuItemUseForPromptGeneration.displayName = 'ImageMenuItemUseForPromptGeneration';
diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx
index 368d265735..d664a5dc74 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx
@@ -14,6 +14,7 @@ import { ImageMenuItemSelectForCompare } from 'features/gallery/components/Image
import { ImageMenuItemSendToUpscale } from 'features/gallery/components/ImageContextMenu/ImageMenuItemSendToUpscale';
import { ImageMenuItemStarUnstar } from 'features/gallery/components/ImageContextMenu/ImageMenuItemStarUnstar';
import { ImageMenuItemUseAsRefImage } from 'features/gallery/components/ImageContextMenu/ImageMenuItemUseAsRefImage';
+import { ImageMenuItemUseForPromptGeneration } from 'features/gallery/components/ImageContextMenu/ImageMenuItemUseForPromptGeneration';
import { ImageDTOContextProvider } from 'features/gallery/contexts/ImageDTOContext';
import { memo } from 'react';
import type { ImageDTO } from 'services/api/types';
@@ -38,6 +39,7 @@ const SingleSelectionMenuItems = ({ imageDTO }: SingleSelectionMenuItemsProps) =
+
diff --git a/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx b/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx
index 65945bc322..90ad4f0efa 100644
--- a/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx
+++ b/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx
@@ -1,4 +1,5 @@
-import { Box, Textarea } from '@invoke-ai/ui-library';
+import { Box, Flex, Textarea } from '@invoke-ai/ui-library';
+import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { usePersistedTextAreaSize } from 'common/hooks/usePersistedTextareaSize';
import {
@@ -7,12 +8,17 @@ import {
selectModelSupportsNegativePrompt,
selectPositivePrompt,
} from 'features/controlLayers/store/paramsSlice';
+import { promptGenerationFromImageDndTarget } from 'features/dnd/dnd';
+import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { ShowDynamicPromptsPreviewButton } from 'features/dynamicPrompts/components/ShowDynamicPromptsPreviewButton';
import { NegativePromptToggleButton } from 'features/parameters/components/Core/NegativePromptToggleButton';
import { PromptLabel } from 'features/parameters/components/Prompts/PromptLabel';
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
import { ViewModePrompt } from 'features/parameters/components/Prompts/ViewModePrompt';
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
+import { PromptExpansionMenu } from 'features/prompt/PromptExpansion/PromptExpansionMenu';
+import { PromptExpansionOverlay } from 'features/prompt/PromptExpansion/PromptExpansionOverlay';
+import { promptExpansionApi } from 'features/prompt/PromptExpansion/state';
import { PromptPopover } from 'features/prompt/PromptPopover';
import { usePrompt } from 'features/prompt/usePrompt';
import { SDXLConcatButton } from 'features/sdxl/components/SDXLPrompts/SDXLConcatButton';
@@ -21,7 +27,8 @@ import {
selectStylePresetViewMode,
} from 'features/stylePresets/store/stylePresetSlice';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
-import { memo, useCallback, useRef } from 'react';
+import { selectAllowPromptExpansion } from 'features/system/store/configSlice';
+import { memo, useCallback, useMemo, useRef } from 'react';
import type { HotkeyCallback } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { useListStylePresetsQuery } from 'services/api/endpoints/stylePresets';
@@ -39,6 +46,8 @@ export const ParamPositivePrompt = memo(() => {
const viewMode = useAppSelector(selectStylePresetViewMode);
const activeStylePresetId = useAppSelector(selectStylePresetActivePresetId);
const modelSupportsNegativePrompt = useAppSelector(selectModelSupportsNegativePrompt);
+ const { isPending: isPromptExpansionPending } = useStore(promptExpansionApi.$state);
+ const isPromptExpansionEnabled = useAppSelector(selectAllowPromptExpansion);
const textareaRef = useRef(null);
usePersistedTextAreaSize('positive_prompt', textareaRef, persistOptions);
@@ -64,6 +73,7 @@ export const ParamPositivePrompt = memo(() => {
prompt,
textareaRef: textareaRef,
onChange: handleChange,
+ isDisabled: isPromptExpansionPending,
});
const focus: HotkeyCallback = useCallback(
@@ -82,41 +92,56 @@ export const ParamPositivePrompt = memo(() => {
dependencies: [focus],
});
+ const dndTargetData = useMemo(() => promptGenerationFromImageDndTarget.getData(), []);
+
return (
-
-
-
-
-
- {baseModel === 'sdxl' && }
-
- {modelSupportsNegativePrompt && }
-
-
- {viewMode && (
-
+
+
+
- )}
-
-
+
+
+
+ {baseModel === 'sdxl' && }
+
+ {modelSupportsNegativePrompt && }
+
+ {isPromptExpansionEnabled && }
+
+
+ {viewMode && (
+
+ )}
+
+
+
+
+
);
});
diff --git a/invokeai/frontend/web/src/features/parameters/components/Prompts/PromptOverlayButtonWrapper.tsx b/invokeai/frontend/web/src/features/parameters/components/Prompts/PromptOverlayButtonWrapper.tsx
index a45279fd34..430805c0ee 100644
--- a/invokeai/frontend/web/src/features/parameters/components/Prompts/PromptOverlayButtonWrapper.tsx
+++ b/invokeai/frontend/web/src/features/parameters/components/Prompts/PromptOverlayButtonWrapper.tsx
@@ -10,7 +10,9 @@ export const PromptOverlayButtonWrapper = memo((props: PropsWithChildren) => (
p={2}
gap={2}
alignItems="center"
- justifyContent="center"
+ justifyContent="space-between"
+ bottom={4} // make sure textarea resize is accessible
+ pb={0}
>
{props.children}
diff --git a/invokeai/frontend/web/src/features/prompt/PromptExpansion/PromptExpansionMenu.tsx b/invokeai/frontend/web/src/features/prompt/PromptExpansion/PromptExpansionMenu.tsx
new file mode 100644
index 0000000000..d337926fdb
--- /dev/null
+++ b/invokeai/frontend/web/src/features/prompt/PromptExpansion/PromptExpansionMenu.tsx
@@ -0,0 +1,80 @@
+import { IconButton, Menu, MenuButton, MenuItem, MenuList, Text } from '@invoke-ai/ui-library';
+import { useStore } from '@nanostores/react';
+import { useAppStore } from 'app/store/storeHooks';
+import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
+import { WrappedError } from 'common/util/result';
+import { toast } from 'features/toast/toast';
+import { useCallback } from 'react';
+import { useTranslation } from 'react-i18next';
+import { PiMagicWandBold } from 'react-icons/pi';
+import type { ImageDTO } from 'services/api/types';
+
+import { expandPrompt } from './expand';
+import { promptExpansionApi } from './state';
+
+/**
+ * Menu component offering two ways to start a prompt expansion: from an uploaded image, or from
+ * the current app state without an image.
+ */
+export const PromptExpansionMenu = () => {
+ const { dispatch, getState } = useAppStore();
+ const { t } = useTranslation();
+ const { isPending } = useStore(promptExpansionApi.$state);
+
+ // Enter the pending state as soon as the upload begins; no image is known yet.
+ const onUploadStarted = useCallback(() => {
+ promptExpansionApi.setPending();
+ }, []);
+
+ // Once the upload has produced an image, attach it to the pending state and start the expansion.
+ const onUpload = useCallback(
+ (imageDTO: ImageDTO) => {
+ promptExpansionApi.setPending(imageDTO);
+ expandPrompt({ dispatch, getState, imageDTO });
+ },
+ [dispatch, getState]
+ );
+
+ // Upload failed: record the wrapped error in the request state and show an error toast.
+ // NOTE(review): the upload hook's catch block shows its own 'UPLOAD_FAILED' toast after calling
+ // onError, so a failed upload appears to produce two toasts - confirm this is intended.
+ const onUploadError = useCallback(
+ (error: unknown) => {
+ const wrappedError = WrappedError.wrap(error);
+ promptExpansionApi.setError(wrappedError);
+ toast({
+ id: 'UPLOAD_AND_PROMPT_GENERATION_FAILED',
+ title: t('toast.uploadAndPromptGenerationFailed'),
+ status: 'error',
+ });
+ },
+ [t]
+ );
+
+ // Single-file upload wired to the three callbacks above.
+ const uploadApi = useImageUploadButton({
+ allowMultiple: false,
+ onUpload,
+ onUploadStarted,
+ onError: onUploadError,
+ });
+
+ // Starts an expansion without an image; the pending state is set first with no image.
+ const onClickExpandPrompt = useCallback(() => {
+ promptExpansionApi.setPending();
+ expandPrompt({ dispatch, getState });
+ }, [dispatch, getState]);
+
+ return (
+ <>
+
+
+ >
+ );
+};
diff --git a/invokeai/frontend/web/src/features/prompt/PromptExpansion/PromptExpansionOverlay.tsx b/invokeai/frontend/web/src/features/prompt/PromptExpansion/PromptExpansionOverlay.tsx
new file mode 100644
index 0000000000..3ba7b7ef87
--- /dev/null
+++ b/invokeai/frontend/web/src/features/prompt/PromptExpansion/PromptExpansionOverlay.tsx
@@ -0,0 +1,69 @@
+import { Box, Flex, Image, Spinner, Text } from '@invoke-ai/ui-library';
+import { useStore } from '@nanostores/react';
+import { PromptExpansionResultOverlay } from 'features/prompt/PromptExpansion/PromptExpansionResultOverlay';
+import { memo } from 'react';
+import { useTranslation } from 'react-i18next';
+import { PiMagicWandBold } from 'react-icons/pi';
+
+import { promptExpansionApi } from './state';
+
+/**
+ * Overlay driven by the prompt expansion request state.
+ *
+ * Returns the result view when the request succeeded, returns `null` when nothing is pending, and
+ * otherwise renders the pending view with the `prompt.expandingPrompt` label.
+ */
+export const PromptExpansionOverlay = memo(() => {
+ const { isSuccess, isPending, result, imageDTO } = useStore(promptExpansionApi.$state);
+ const { t } = useTranslation();
+
+ // A successful request takes priority: show the result overlay
+ if (isSuccess) {
+ return ;
+ }
+
+ // Neither succeeded nor pending: there is nothing to overlay
+ if (!isPending) {
+ return null;
+ }
+
+ return (
+
+ {/* Show dimmed source image if available */}
+ {imageDTO && (
+
+
+
+ )}
+
+
+
+
+
+
+
+ {t('prompt.expandingPrompt')}
+
+
+
+ );
+});
+
+PromptExpansionOverlay.displayName = 'PromptExpansionOverlay';
diff --git a/invokeai/frontend/web/src/features/prompt/PromptExpansion/PromptExpansionResultOverlay.tsx b/invokeai/frontend/web/src/features/prompt/PromptExpansion/PromptExpansionResultOverlay.tsx
new file mode 100644
index 0000000000..015bf5946d
--- /dev/null
+++ b/invokeai/frontend/web/src/features/prompt/PromptExpansion/PromptExpansionResultOverlay.tsx
@@ -0,0 +1,76 @@
+import { ButtonGroup, Flex, Icon, IconButton, Text, Tooltip } from '@invoke-ai/ui-library';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { positivePromptChanged, selectPositivePrompt } from 'features/controlLayers/store/paramsSlice';
+import { useCallback } from 'react';
+import { PiCheckBold, PiMagicWandBold, PiPlusBold, PiXBold } from 'react-icons/pi';
+
+import { promptExpansionApi } from './state';
+
+/** Props for `PromptExpansionResultOverlay`. */
+interface PromptExpansionResultOverlayProps {
+ // The generated text to display and to apply to the positive prompt.
+ expandedText: string;
+}
+
+/**
+ * Shows the expanded prompt text together with three actions: replace the positive prompt,
+ * append to it, or discard the result. Every action resets the prompt expansion request state.
+ */
+export const PromptExpansionResultOverlay = ({ expandedText }: PromptExpansionResultOverlayProps) => {
+ const dispatch = useAppDispatch();
+ const positivePrompt = useAppSelector(selectPositivePrompt);
+
+ // Replace: overwrite the positive prompt with the expanded text.
+ const handleReplace = useCallback(() => {
+ dispatch(positivePromptChanged(expandedText));
+ promptExpansionApi.reset();
+ }, [dispatch, expandedText]);
+
+ // Insert: append the expanded text on a new line after the current prompt. When the current
+ // prompt is empty, the expanded text is used on its own so no leading newline is added.
+ const handleInsert = useCallback(() => {
+ const currentText = positivePrompt;
+ const newText = currentText ? `${currentText}\n${expandedText}` : expandedText;
+ dispatch(positivePromptChanged(newText));
+ promptExpansionApi.reset();
+ }, [dispatch, expandedText, positivePrompt]);
+
+ // Discard: drop the result without touching the positive prompt.
+ const handleDiscard = useCallback(() => {
+ promptExpansionApi.reset();
+ }, []);
+
+ // NOTE(review): the aria-labels below are hard-coded English strings, while this change also
+ // adds the locale keys prompt.replace / prompt.insert / prompt.discard. This file does not
+ // import useTranslation - confirm whether the labels should be translated.
+ return (
+
+
+
+
+ {expandedText}
+
+
+
+
+
+
+ }
+ colorScheme="invokeGreen"
+ size="xs"
+ aria-label="Replace"
+ />
+
+
+
+ }
+ colorScheme="invokeBlue"
+ size="xs"
+ aria-label="Insert"
+ />
+
+
+
+ }
+ colorScheme="invokeRed"
+ size="xs"
+ aria-label="Discard"
+ />
+
+
+
+ );
+};
diff --git a/invokeai/frontend/web/src/features/prompt/PromptExpansion/expand.ts b/invokeai/frontend/web/src/features/prompt/PromptExpansion/expand.ts
new file mode 100644
index 0000000000..a41a2d9259
--- /dev/null
+++ b/invokeai/frontend/web/src/features/prompt/PromptExpansion/expand.ts
@@ -0,0 +1,43 @@
+import type { AppDispatch, AppGetState } from 'app/store/store';
+import { toast } from 'features/toast/toast';
+import { t } from 'i18next';
+import { buildRunGraphDependencies, runGraph } from 'services/api/run-graph';
+import type { ImageDTO } from 'services/api/types';
+import { $socket } from 'services/events/stores';
+import { assert } from 'tsafe';
+
+import { buildPromptExpansionGraph } from './graph';
+import { promptExpansionApi } from './state';
+
+/**
+ * Runs a prompt expansion graph and records the outcome in `promptExpansionApi`.
+ *
+ * The graph is built by `buildPromptExpansionGraph` from the current app state and, when given,
+ * `imageDTO`. On success the node's string output is stored with `setSuccess`. On any failure the
+ * request state is reset and an error toast is shown.
+ *
+ * Callers put the request into the pending state before calling this function, so every exit path
+ * must leave that state. A missing socket and a graph that cannot be built are therefore handled
+ * by the same failure path as a failed graph run, instead of returning or throwing with the
+ * request still marked as pending.
+ *
+ * @param arg.dispatch Store dispatch, passed to `buildRunGraphDependencies`.
+ * @param arg.getState Store state getter, read once to build the graph.
+ * @param arg.imageDTO Optional image to build the graph from.
+ * @returns A promise that resolves once the outcome has been recorded; it does not reject.
+ */
+export const expandPrompt = async (arg: {
+  dispatch: AppDispatch;
+  getState: AppGetState;
+  imageDTO?: ImageDTO;
+}): Promise<void> => {
+  const { dispatch, getState, imageDTO } = arg;
+  try {
+    const socket = $socket.get();
+    // Fail through the catch block so the pending state is cleared when there is no socket.
+    assert(socket, 'Socket is not connected');
+
+    // Graph building asserts on app state and may throw, so it must run inside the try block.
+    const { graph, outputNodeId } = buildPromptExpansionGraph({
+      state: getState(),
+      imageDTO,
+    });
+    const dependencies = buildRunGraphDependencies(dispatch, socket);
+
+    const { output } = await runGraph({
+      graph,
+      outputNodeId,
+      dependencies,
+      options: {
+        prepend: true,
+        timeout: 10000,
+      },
+    });
+    assert(output.type === 'string_output');
+    promptExpansionApi.setSuccess(output.value);
+  } catch (error) {
+    promptExpansionApi.reset();
+    toast({
+      id: 'PROMPT_EXPANSION_FAILED',
+      title: t('toast.promptExpansionFailed'),
+      status: 'error',
+    });
+  }
+};
diff --git a/invokeai/frontend/web/src/features/prompt/PromptExpansion/graph.ts b/invokeai/frontend/web/src/features/prompt/PromptExpansion/graph.ts
new file mode 100644
index 0000000000..dcd4b8ab6d
--- /dev/null
+++ b/invokeai/frontend/web/src/features/prompt/PromptExpansion/graph.ts
@@ -0,0 +1,43 @@
+import type { RootState } from 'app/store/store';
+import { getPrefixedId } from 'features/controlLayers/konva/util';
+import { selectBase, selectPositivePrompt } from 'features/controlLayers/store/paramsSlice';
+import { imageDTOToImageField } from 'features/controlLayers/store/util';
+import { Graph } from 'features/nodes/util/graph/generation/Graph';
+import type { ImageDTO } from 'services/api/types';
+import { assert } from 'tsafe';
+
+/**
+ * Builds a single-node graph whose output node produces the prompt text.
+ *
+ * - With `imageDTO`: a `claude_analyze_image` node that receives the image.
+ * - Without `imageDTO`: a `claude_expand_prompt` node that receives the current positive prompt.
+ *
+ * Both nodes receive `model_architecture`, which is `'tag_based'` when the selected base is
+ * `sdxl` or `sdxl-refiner` and `'sentence_based'` otherwise.
+ *
+ * @returns The graph and the id of the node whose output should be read.
+ * @throws When no base is found in state (assertion failure).
+ */
+export const buildPromptExpansionGraph = ({
+  state,
+  imageDTO,
+}: {
+  state: RootState;
+  imageDTO?: ImageDTO;
+}): { graph: Graph; outputNodeId: string } => {
+  const base = selectBase(state);
+  assert(base, 'No main model found in state');
+
+  const usesTagPrompts = ['sdxl', 'sdxl-refiner'].includes(base);
+  const architecture = usesTagPrompts ? 'tag_based' : 'sentence_based';
+
+  // No image: expand the current positive prompt.
+  if (!imageDTO) {
+    const currentPrompt = selectPositivePrompt(state);
+    const expandGraph = new Graph(getPrefixedId('claude-expand-prompt-graph'));
+    const expandNode = expandGraph.addNode({
+      // @ts-expect-error: These nodes are not available in the OSS application
+      type: 'claude_expand_prompt',
+      id: getPrefixedId('claude_expand_prompt'),
+      model_architecture: architecture,
+      prompt: currentPrompt,
+    });
+    return { graph: expandGraph, outputNodeId: expandNode.id };
+  }
+
+  // Image given: generate the prompt from the image.
+  const analyzeGraph = new Graph(getPrefixedId('claude-analyze-image-graph'));
+  const analyzeNode = analyzeGraph.addNode({
+    // @ts-expect-error: These nodes are not available in the OSS application
+    type: 'claude_analyze_image',
+    id: getPrefixedId('claude_analyze_image'),
+    model_architecture: architecture,
+    image: imageDTOToImageField(imageDTO),
+  });
+  return { graph: analyzeGraph, outputNodeId: analyzeNode.id };
+};
diff --git a/invokeai/frontend/web/src/features/prompt/PromptExpansion/state.ts b/invokeai/frontend/web/src/features/prompt/PromptExpansion/state.ts
new file mode 100644
index 0000000000..14cafe664a
--- /dev/null
+++ b/invokeai/frontend/web/src/features/prompt/PromptExpansion/state.ts
@@ -0,0 +1,98 @@
+import { deepClone } from 'common/util/deepClone';
+import { atom } from 'nanostores';
+import type { ImageDTO } from 'services/api/types';
+
+/** Optional image carried alongside the request flags in every state. */
+type WithOptionalImage = {
+  imageDTO?: ImageDTO;
+};
+
+/** The request finished and produced a string `result`. */
+type SuccessState = WithOptionalImage & {
+  isSuccess: true;
+  isError: false;
+  isPending: false;
+  result: string;
+  error: null;
+};
+
+/** The request failed with `error`. */
+type ErrorState = WithOptionalImage & {
+  isSuccess: false;
+  isError: true;
+  isPending: false;
+  result: null;
+  error: Error;
+};
+
+/** A request is in flight; there is no result or error yet. */
+type PendingState = WithOptionalImage & {
+  isSuccess: false;
+  isError: false;
+  isPending: true;
+  result: null;
+  error: null;
+};
+
+/** No request is in flight and no result or error is held. */
+type IdleState = WithOptionalImage & {
+  isSuccess: false;
+  isError: false;
+  isPending: false;
+  result: null;
+  error: null;
+};
+
+/** Union of the four request states; the boolean flags discriminate between them. */
+export type PromptExpansionRequestState = IdleState | PendingState | SuccessState | ErrorState;
+
+// Template for the idle state: all flags false, no result, no error, no image.
+const IDLE_STATE: IdleState = {
+ isSuccess: false,
+ isError: false,
+ isPending: false,
+ result: null,
+ error: null,
+ imageDTO: undefined,
+};
+
+// Module-level store holding the current request state; seeded with a deep clone of IDLE_STATE
+// so the template object itself is never handed to the store.
+const $state = atom(deepClone(IDLE_STATE));
+
+// Returns the store to the idle state, using a fresh clone of the template on every call.
+const reset = () => {
+ $state.set(deepClone(IDLE_STATE));
+};
+
+/**
+ * Marks a request as in flight and clears any previous result or error.
+ *
+ * `imageDTO` always overwrites the stored image, so calling this without an argument clears an
+ * image stored by an earlier call.
+ */
+const setPending = (imageDTO?: ImageDTO) => {
+  const previous = $state.get();
+  $state.set({
+    ...previous,
+    isSuccess: false,
+    isError: false,
+    isPending: true,
+    result: null,
+    error: null,
+    imageDTO,
+  });
+};
+
+/**
+ * Records a successful request with its string `result`. Fields not listed here (the stored
+ * image) are carried over from the previous state by the spread.
+ */
+const setSuccess = (result: string) => {
+  const previous = $state.get();
+  $state.set({
+    ...previous,
+    isSuccess: true,
+    isError: false,
+    isPending: false,
+    result,
+    error: null,
+  });
+};
+
+/**
+ * Records a failed request with its `error` and clears any result. Fields not listed here (the
+ * stored image) are carried over from the previous state by the spread.
+ */
+const setError = (error: Error) => {
+  const previous = $state.get();
+  $state.set({
+    ...previous,
+    isSuccess: false,
+    isError: true,
+    isPending: false,
+    result: null,
+    error,
+  });
+};
+
+/** Public handle for the prompt expansion request state: the store plus its transition functions. */
+export const promptExpansionApi = {
+  $state: $state,
+  reset: reset,
+  setPending: setPending,
+  setSuccess: setSuccess,
+  setError: setError,
+};
diff --git a/invokeai/frontend/web/src/features/prompt/usePrompt.ts b/invokeai/frontend/web/src/features/prompt/usePrompt.ts
index c01facd693..c98005f19d 100644
--- a/invokeai/frontend/web/src/features/prompt/usePrompt.ts
+++ b/invokeai/frontend/web/src/features/prompt/usePrompt.ts
@@ -8,9 +8,10 @@ type UseInsertTriggerArg = {
prompt: string;
textareaRef: RefObject;
onChange: (v: string) => void;
+ isDisabled?: boolean;
};
-export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInsertTriggerArg) => {
+export const usePrompt = ({ prompt, textareaRef, onChange: _onChange, isDisabled = false }: UseInsertTriggerArg) => {
const { isOpen, onClose, onOpen } = useDisclosure();
const onChange: ChangeEventHandler = useCallback(
@@ -73,12 +74,12 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
const onKeyDown: KeyboardEventHandler = useCallback(
(e) => {
- if (e.key === '<') {
+ if (e.key === '<' && !isDisabled) {
onOpen();
e.preventDefault();
}
},
- [onOpen]
+ [onOpen, isDisabled]
);
return {
diff --git a/invokeai/frontend/web/src/features/queue/store/readiness.ts b/invokeai/frontend/web/src/features/queue/store/readiness.ts
index a5536ed8b8..ae2fb9a204 100644
--- a/invokeai/frontend/web/src/features/queue/store/readiness.ts
+++ b/invokeai/frontend/web/src/features/queue/store/readiness.ts
@@ -36,6 +36,7 @@ import type { UpscaleState } from 'features/parameters/store/upscaleSlice';
import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import { getGridSize } from 'features/parameters/util/optimalDimension';
+import { promptExpansionApi, type PromptExpansionRequestState } from 'features/prompt/PromptExpansion/state';
import { selectConfigSlice } from 'features/system/store/configSlice';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import type { TabName } from 'features/ui/store/uiTypes';
@@ -89,9 +90,22 @@ const debouncedUpdateReasons = debounce(
config: AppConfig,
store: AppStore,
isInPublishFlow: boolean,
- isChatGPT4oHighModelDisabled: (model: ParameterModel) => boolean
+ isChatGPT4oHighModelDisabled: (model: ParameterModel) => boolean,
+ promptExpansionRequest: PromptExpansionRequestState
) => {
- if (tab === 'canvas') {
+ if (tab === 'generate') {
+ const model = selectMainModelConfig(store.getState());
+ const reasons = await getReasonsWhyCannotEnqueueGenerateTab({
+ isConnected,
+ model,
+ params,
+ refImages,
+ dynamicPrompts,
+ isChatGPT4oHighModelDisabled,
+ promptExpansionRequest,
+ });
+ $reasonsWhyCannotEnqueue.set(reasons);
+ } else if (tab === 'canvas') {
const model = selectMainModelConfig(store.getState());
const reasons = await getReasonsWhyCannotEnqueueCanvasTab({
isConnected,
@@ -106,6 +120,7 @@ const debouncedUpdateReasons = debounce(
canvasIsCompositing,
canvasIsSelectingObject,
isChatGPT4oHighModelDisabled,
+ promptExpansionRequest,
});
$reasonsWhyCannotEnqueue.set(reasons);
} else if (tab === 'workflows') {
@@ -124,6 +139,7 @@ const debouncedUpdateReasons = debounce(
upscale,
config,
params,
+ promptExpansionRequest,
});
$reasonsWhyCannotEnqueue.set(reasons);
} else {
@@ -155,6 +171,7 @@ export const useReadinessWatcher = () => {
const canvasIsCompositing = useStore(canvasManager?.compositor.$isBusy ?? $false);
const isInPublishFlow = useStore($isInPublishFlow);
const { isChatGPT4oHighModelDisabled } = useIsModelDisabled();
+ const promptExpansionRequest = useStore(promptExpansionApi.$state);
useEffect(() => {
debouncedUpdateReasons(
@@ -176,7 +193,8 @@ export const useReadinessWatcher = () => {
config,
store,
isInPublishFlow,
- isChatGPT4oHighModelDisabled
+ isChatGPT4oHighModelDisabled,
+ promptExpansionRequest
);
}, [
store,
@@ -198,11 +216,88 @@ export const useReadinessWatcher = () => {
workflowSettings,
isInPublishFlow,
isChatGPT4oHighModelDisabled,
+ promptExpansionRequest,
]);
};
const disconnectedReason = (t: typeof i18n.t) => ({ content: t('parameters.invoke.systemDisconnected') });
+const getReasonsWhyCannotEnqueueGenerateTab = (arg: { // Collects every reason the Generate tab cannot enqueue right now.
+ isConnected: boolean;
+ model: MainModelConfig | null | undefined;
+ params: ParamsState;
+ refImages: RefImagesState;
+ dynamicPrompts: DynamicPromptsState;
+ isChatGPT4oHighModelDisabled: (model: ParameterModel) => boolean;
+ promptExpansionRequest: PromptExpansionRequestState;
+}) => { // Returns Reason[] synchronously; an empty array means no check below failed.
+ const {
+ isConnected,
+ model,
+ params,
+ refImages,
+ dynamicPrompts,
+ isChatGPT4oHighModelDisabled,
+ promptExpansionRequest,
+ } = arg;
+ const { positivePrompt } = params;
+ const reasons: Reason[] = []; // Checks are independent: each failing one appends its own entry.
+
+ if (!isConnected) {
+ reasons.push(disconnectedReason(i18n.t));
+ }
+
+ if (dynamicPrompts.prompts.length === 0 && getShouldProcessPrompt(positivePrompt)) { // No prompts although the positive prompt passes getShouldProcessPrompt.
+ reasons.push({ content: i18n.t('parameters.invoke.noPrompts') });
+ }
+
+ if (!model) {
+ reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') });
+ }
+
+ if (model?.base === 'flux') { // For the 'flux' base, T5 encoder, CLIP embed and FLUX VAE params must all be set.
+ if (!params.t5EncoderModel) {
+ reasons.push({ content: i18n.t('parameters.invoke.noT5EncoderModelSelected') });
+ }
+ if (!params.clipEmbedModel) {
+ reasons.push({ content: i18n.t('parameters.invoke.noCLIPEmbedModelSelected') });
+ }
+ if (!params.fluxVAE) {
+ reasons.push({ content: i18n.t('parameters.invoke.noFLUXVAEModelSelected') });
+ }
+ }
+
+ if (model && isChatGPT4oHighModelDisabled(model)) { // The caller-supplied predicate decides whether this model is disabled.
+ reasons.push({ content: i18n.t('parameters.invoke.modelDisabledForTrial', { modelName: model.name }) });
+ }
+
+ if (promptExpansionRequest.isPending) { // An in-flight prompt expansion request blocks enqueueing.
+ reasons.push({ content: i18n.t('parameters.invoke.promptExpansionPending') });
+ } else if (promptExpansionRequest.isSuccess) { // A succeeded request also blocks; presumably until its result is handled - TODO confirm.
+ reasons.push({ content: i18n.t('parameters.invoke.promptExpansionResultPending') });
+ }
+
+ // Flux Kontext only supports 1x Reference Image at a time.
+ const referenceImageCount = refImages.entities.length; // Counts every entity in refImages.entities; no filtering is applied here.
+
+ if (model?.base === 'flux-kontext' && referenceImageCount > 1) {
+ reasons.push({ content: i18n.t('parameters.invoke.fluxKontextMultipleReferenceImages') });
+ }
+
+ refImages.entities.forEach((entity, i) => { // One reason per reference image that has warnings.
+ const layerNumber = i + 1; // 1-based number used in the reason prefix.
+ const refImageLiteral = i18n.t(LAYER_TYPE_TO_TKEY['reference_image']);
+ const prefix = `${refImageLiteral} #${layerNumber}`;
+ const problems = getGlobalReferenceImageWarnings(entity, model); // Each entry is passed to i18n.t below, i.e. treated as a translation key.
+
+ if (problems.length) {
+ const content = upperFirst(problems.map((p) => i18n.t(p)).join(', ')); // Translate, join with ', ', then apply upperFirst.
+ reasons.push({ prefix, content });
+ }
+ });
+
+ return reasons;
+};
const getReasonsWhyCannotEnqueueWorkflowsTab = async (arg: {
dispatch: AppDispatch;
nodesState: NodesState;
@@ -293,8 +388,9 @@ const getReasonsWhyCannotEnqueueUpscaleTab = (arg: {
upscale: UpscaleState;
config: AppConfig;
params: ParamsState;
+ promptExpansionRequest: PromptExpansionRequestState;
}) => {
- const { isConnected, upscale, config, params } = arg;
+ const { isConnected, upscale, config, params, promptExpansionRequest } = arg;
const reasons: Reason[] = [];
if (!isConnected) {
@@ -331,6 +427,12 @@ const getReasonsWhyCannotEnqueueUpscaleTab = (arg: {
}
}
+ if (promptExpansionRequest.isPending) {
+ reasons.push({ content: i18n.t('parameters.invoke.promptExpansionPending') });
+ } else if (promptExpansionRequest.isSuccess) {
+ reasons.push({ content: i18n.t('parameters.invoke.promptExpansionResultPending') });
+ }
+
return reasons;
};
@@ -347,6 +449,7 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
canvasIsCompositing: boolean;
canvasIsSelectingObject: boolean;
isChatGPT4oHighModelDisabled: (model: ParameterModel) => boolean;
+ promptExpansionRequest: PromptExpansionRequestState;
}) => {
const {
isConnected,
@@ -361,6 +464,7 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
canvasIsCompositing,
canvasIsSelectingObject,
isChatGPT4oHighModelDisabled,
+ promptExpansionRequest,
} = arg;
const { positivePrompt } = params;
const reasons: Reason[] = [];
@@ -497,6 +601,12 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
reasons.push({ content: i18n.t('parameters.invoke.modelDisabledForTrial', { modelName: model.name }) });
}
+ if (promptExpansionRequest.isPending) {
+ reasons.push({ content: i18n.t('parameters.invoke.promptExpansionPending') });
+ } else if (promptExpansionRequest.isSuccess) {
+ reasons.push({ content: i18n.t('parameters.invoke.promptExpansionResultPending') });
+ }
+
const enabledControlLayers = canvas.controlLayers.entities.filter((controlLayer) => controlLayer.isEnabled);
// FLUX only supports 1x Control LoRA at a time.
diff --git a/invokeai/frontend/web/src/features/system/store/configSlice.ts b/invokeai/frontend/web/src/features/system/store/configSlice.ts
index 71f17d14c2..e2f9660d9b 100644
--- a/invokeai/frontend/web/src/features/system/store/configSlice.ts
+++ b/invokeai/frontend/web/src/features/system/store/configSlice.ts
@@ -25,6 +25,7 @@ const initialConfigState: AppConfig & { didLoad: boolean } = {
allowPrivateStylePresets: false,
allowClientSideUpload: false,
allowPublishWorkflows: false,
+ allowPromptExpansion: false,
shouldShowCredits: false,
disabledTabs: [],
disabledFeatures: ['lightbox', 'faceRestore', 'batches'],
@@ -228,6 +229,7 @@ export const selectMetadataFetchDebounce = createConfigSelector((config) => conf
export const selectIsModelsTabDisabled = createConfigSelector((config) => config.disabledTabs.includes('models'));
export const selectIsClientSideUploadEnabled = createConfigSelector((config) => config.allowClientSideUpload);
export const selectAllowPublishWorkflows = createConfigSelector((config) => config.allowPublishWorkflows);
+export const selectAllowPromptExpansion = createConfigSelector((config) => config.allowPromptExpansion); // Reads the allowPromptExpansion config flag.
export const selectIsLocal = createSelector(selectConfigSlice, (config) => config.isLocal);
export const selectShouldShowCredits = createConfigSelector((config) => config.shouldShowCredits);
export const selectEnabledTabs = createConfigSelector((config) => {
diff --git a/invokeai/frontend/web/src/services/api/run-graph.ts b/invokeai/frontend/web/src/services/api/run-graph.ts
index 00d47bb01c..583255a81d 100644
--- a/invokeai/frontend/web/src/services/api/run-graph.ts
+++ b/invokeai/frontend/web/src/services/api/run-graph.ts
@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger';
-import type { AppStore } from 'app/store/store';
+import type { AppDispatch } from 'app/store/store';
import { Mutex } from 'async-mutex';
import { withResultAsync, WrappedError } from 'common/util/result';
import { parseify } from 'common/util/serialize';
@@ -135,7 +135,7 @@ export const runGraph = (arg: RunGraphArg): Promise => {
* Creates production dependencies for runGraph using Redux store and socket.
*/
export const buildRunGraphDependencies = (
- store: AppStore,
+ dispatch: AppDispatch,
socket: {
on: (event: 'queue_item_status_changed', handler: (event: S['QueueItemStatusChangedEvent']) => void) => void;
off: (event: 'queue_item_status_changed', handler: (event: S['QueueItemStatusChangedEvent']) => void) => void;
@@ -143,17 +143,15 @@ export const buildRunGraphDependencies = (
): GraphRunnerDependencies => ({
executor: {
enqueueBatch: (batch) =>
- store
- .dispatch(
- queueApi.endpoints.enqueueBatch.initiate(batch, {
- ...enqueueMutationFixedCacheKeyOptions,
- track: false,
- })
- )
- .unwrap(),
- getQueueItem: (id) => store.dispatch(queueApi.endpoints.getQueueItem.initiate(id, { subscribe: false })).unwrap(),
+ dispatch(
+ queueApi.endpoints.enqueueBatch.initiate(batch, {
+ ...enqueueMutationFixedCacheKeyOptions,
+ track: false,
+ })
+ ).unwrap(),
+ getQueueItem: (id) => dispatch(queueApi.endpoints.getQueueItem.initiate(id, { subscribe: false })).unwrap(),
cancelQueueItem: (id) =>
- store.dispatch(queueApi.endpoints.cancelQueueItem.initiate({ item_id: id }, { track: false })).unwrap(),
+ dispatch(queueApi.endpoints.cancelQueueItem.initiate({ item_id: id }, { track: false })).unwrap(),
},
eventHandler: {
subscribe: (handler) => socket.on('queue_item_status_changed', handler),