feat(ui): prompt expansion (#8140)

* initializing prompt expansion and putting response in prompt box working for all methods

* properly disable UI and show loading state on prompt box when there is a pending prompt expansion item

* misc wrapup: disable applying prompt templates, don't block textarea resize handle

* update progress to differentiate between prompt expansion and non

* cleanup

* lint

* more cleanup

* add image to background of loading state

* add allowPromptExpansion for front-end gating

* updated readiness text for needing to accept or discard

* fix tsc

* lint

* lint

* refactor(ui): prompt expansion logic

* tidy(ui): remove unnecessary changes

* revert(ui): unused arg on useImageUploadButton

* feat(ui): simplify prompt expansion state

* set pending for dragndrop and context menu

* add readiness logic for generate tab

* missing translation

* update error handling for prompt expansion

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-Air.lan>
Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
Mary Hipp Rogers
2025-07-02 10:26:48 -04:00
committed by GitHub
parent 2dd1bc54c9
commit 038010a1ca
20 changed files with 739 additions and 71 deletions

View File

@@ -225,7 +225,16 @@
"prompt": {
"addPromptTrigger": "Add Prompt Trigger",
"compatibleEmbeddings": "Compatible Embeddings",
"noMatchingTriggers": "No matching triggers"
"noMatchingTriggers": "No matching triggers",
"generateFromImage": "Generate prompt from image",
"expandCurrentPrompt": "Expand Current Prompt",
"uploadImageForPromptGeneration": "Upload Image for Prompt Generation",
"expandingPrompt": "Expanding prompt...",
"resultTitle": "Prompt Expansion Complete",
"resultSubtitle": "Choose how to handle the expanded prompt:",
"replace": "Replace",
"insert": "Insert",
"discard": "Discard"
},
"queue": {
"queue": "Queue",
@@ -342,7 +351,7 @@
"copy": "Copy",
"currentlyInUse": "This image is currently in use in the following features:",
"drop": "Drop",
"dropOrUpload": "$t(gallery.drop) or Upload",
"dropOrUpload": "Drop or Upload",
"dropToUpload": "$t(gallery.drop) to Upload",
"deleteImage_one": "Delete Image",
"deleteImage_other": "Delete {{count}} Images",
@@ -396,7 +405,8 @@
"compareHelp4": "Press <Kbd>Z</Kbd> or <Kbd>Esc</Kbd> to exit.",
"openViewer": "Open Viewer",
"closeViewer": "Close Viewer",
"move": "Move"
"move": "Move",
"useForPromptGeneration": "Use for Prompt Generation"
},
"hotkeys": {
"hotkeys": "Hotkeys",
@@ -938,7 +948,8 @@
"selectModel": "Select a Model",
"noLoRAsInstalled": "No LoRAs installed",
"noRefinerModelsInstalled": "No SDXL Refiner models installed",
"defaultVAE": "Default VAE"
"defaultVAE": "Default VAE",
"noCompatibleLoRAs": "No Compatible LoRAs"
},
"nodes": {
"arithmeticSequence": "Arithmetic Sequence",
@@ -1188,7 +1199,9 @@
"canvasIsSelectingObject": "Canvas is busy (selecting object)",
"noPrompts": "No prompts generated",
"noNodesInGraph": "No nodes in graph",
"systemDisconnected": "System disconnected"
"systemDisconnected": "System disconnected",
"promptExpansionPending": "Prompt expansion in progress",
"promptExpansionResultPending": "Please accept or discard your prompt expansion result"
},
"maskBlur": "Mask Blur",
"negativePromptPlaceholder": "Negative Prompt",
@@ -1389,7 +1402,12 @@
"fluxKontextIncompatibleGenerationMode": "Flux Kontext supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
"problemUnpublishingWorkflow": "Problem Unpublishing Workflow",
"problemUnpublishingWorkflowDescription": "There was a problem unpublishing the workflow. Please try again.",
"workflowUnpublished": "Workflow Unpublished"
"workflowUnpublished": "Workflow Unpublished",
"sentToCanvas": "Sent to Canvas",
"sentToUpscale": "Sent to Upscale",
"promptGenerationStarted": "Prompt generation started",
"uploadAndPromptGenerationFailed": "Failed to upload image and generate prompt",
"promptExpansionFailed": "Prompt expansion failed"
},
"popovers": {
"clipSkip": {

View File

@@ -78,6 +78,7 @@ export type AppConfig = {
allowPrivateStylePresets: boolean;
allowClientSideUpload: boolean;
allowPublishWorkflows: boolean;
allowPromptExpansion: boolean;
disabledTabs: TabName[];
disabledFeatures: AppFeature[];
disabledSDFeatures: SDFeature[];

View File

@@ -21,11 +21,15 @@ type UseImageUploadButtonArgs =
isDisabled?: boolean;
allowMultiple: false;
onUpload?: (imageDTO: ImageDTO) => void;
onUploadStarted?: (files: File) => void;
onError?: (error: unknown) => void;
}
| {
isDisabled?: boolean;
allowMultiple: true;
onUpload?: (imageDTOs: ImageDTO[]) => void;
onUploadStarted?: (files: File[]) => void;
onError?: (error: unknown) => void;
};
const log = logger('gallery');
@@ -49,7 +53,13 @@ const log = logger('gallery');
* <Button {...getUploadButtonProps()} /> // will open the file dialog on click
* <input {...getUploadInputProps()} /> // hidden, handles native upload functionality
*/
export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: UseImageUploadButtonArgs) => {
export const useImageUploadButton = ({
onUpload,
isDisabled,
allowMultiple,
onUploadStarted,
onError,
}: UseImageUploadButtonArgs) => {
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const isClientSideUploadEnabled = useAppSelector(selectIsClientSideUploadEnabled);
const [uploadImage, request] = useUploadImageMutation();
@@ -71,6 +81,7 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
}
const file = files[0];
assert(file !== undefined); // should never happen
onUploadStarted?.(file);
const imageDTO = await uploadImage({
file,
image_category: 'user',
@@ -82,6 +93,8 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
onUpload(imageDTO);
}
} else {
onUploadStarted?.(files);
let imageDTOs: ImageDTO[] = [];
if (isClientSideUploadEnabled && files.length > 1) {
imageDTOs = await Promise.all(files.map((file, i) => clientSideUpload(file, i)));
@@ -102,6 +115,7 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
}
}
} catch (error) {
onError?.(error);
toast({
id: 'UPLOAD_FAILED',
title: t('toast.imageUploadFailed'),
@@ -109,7 +123,17 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
});
}
},
[allowMultiple, autoAddBoardId, onUpload, uploadImage, isClientSideUploadEnabled, clientSideUpload, t]
[
allowMultiple,
onUploadStarted,
uploadImage,
autoAddBoardId,
onUpload,
isClientSideUploadEnabled,
clientSideUpload,
onError,
t,
]
);
const onDropRejected = useCallback(

View File

@@ -270,7 +270,7 @@ export class CanvasStateApiModule extends CanvasModuleBase {
outputNodeId: string;
options?: RunGraphOptions;
}): Promise<ImageDTO> => {
const dependencies = buildRunGraphDependencies(this.store, this.manager.socket);
const dependencies = buildRunGraphDependencies(this.store.dispatch, this.manager.socket);
const { output } = await runGraph({
dependencies,

View File

@@ -19,6 +19,7 @@ import type {
RgbColor,
T2IAdapterConfig,
} from 'features/controlLayers/store/types';
import type { ImageField } from 'features/nodes/types/common';
import type { ImageDTO } from 'services/api/types';
import { assert } from 'tsafe';
import type { PartialDeep } from 'type-fest';
@@ -59,6 +60,8 @@ export const imageDTOToImageWithDims = ({ image_name, width, height }: ImageDTO)
height,
});
/** Builds an `ImageField` that carries only the `image_name` of the given image DTO. */
export const imageDTOToImageField = ({ image_name }: ImageDTO): ImageField => ({ image_name });
const DEFAULT_RG_MASK_FILL_COLORS: RgbColor[] = [
{ r: 121, g: 157, b: 219 }, // rgb(121, 157, 219)
{ r: 131, g: 214, b: 131 }, // rgb(131, 214, 131)

View File

@@ -22,6 +22,8 @@ import {
import { fieldImageCollectionValueChanged } from 'features/nodes/store/nodesSlice';
import { selectFieldInputInstanceSafe, selectNodesSlice } from 'features/nodes/store/selectors';
import { type FieldIdentifier, isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
import { expandPrompt } from 'features/prompt/PromptExpansion/expand';
import { promptExpansionApi } from 'features/prompt/PromptExpansion/state';
import type { ImageDTO } from 'services/api/types';
import type { JsonObject } from 'type-fest';
@@ -515,23 +517,48 @@ export const removeImageFromBoardDndTarget: DndTarget<
//#endregion
//#region Prompt Generation From Image
const _promptGenerationFromImage = buildTypeAndKey('prompt-generation-from-image');

/** Data attached to the "generate prompt from image" drop target. It carries no payload. */
export type PromptGenerationFromImageDndTargetData = DndData<
  typeof _promptGenerationFromImage.type,
  typeof _promptGenerationFromImage.key,
  void
>;

/**
 * Drop target that starts a prompt expansion from a single dropped image.
 *
 * Only single-image sources are accepted. On drop, the prompt expansion state is marked pending
 * with the dropped image and the expansion request is started.
 */
export const promptGenerationFromImageDndTarget: DndTarget<
  PromptGenerationFromImageDndTargetData,
  SingleImageDndSourceData
> = {
  ..._promptGenerationFromImage,
  typeGuard: buildTypeGuard(_promptGenerationFromImage.key),
  getData: buildGetData(_promptGenerationFromImage.key, _promptGenerationFromImage.type),
  // Valid exactly when the dragged source is a single image.
  isValid: ({ sourceData }) => singleImageDndSource.typeGuard(sourceData),
  handler: ({ sourceData, dispatch, getState }) => {
    const droppedImage = sourceData.payload.imageDTO;
    // Mark pending before starting the request so the UI reflects it right away.
    promptExpansionApi.setPending(droppedImage);
    expandPrompt({ dispatch, getState, imageDTO: droppedImage });
  },
};
//#endregion
export const dndTargets = [
// Single Image
setGlobalReferenceImageDndTarget,
addGlobalReferenceImageDndTarget,
setRegionalGuidanceReferenceImageDndTarget,
setUpscaleInitialImageDndTarget,
setNodeImageFieldImageDndTarget,
addImagesToNodeImageFieldCollectionDndTarget,
setComparisonImageDndTarget,
newCanvasEntityFromImageDndTarget,
newCanvasFromImageDndTarget,
replaceCanvasEntityObjectsWithImageDndTarget,
addImageToBoardDndTarget,
removeImageFromBoardDndTarget,
newCanvasFromImageDndTarget,
addGlobalReferenceImageDndTarget,
// Single or Multiple Image
addImageToBoardDndTarget,
removeImageFromBoardDndTarget,
addImagesToNodeImageFieldCollectionDndTarget,
promptGenerationFromImageDndTarget,
] as const;
export type AnyDndTarget = (typeof dndTargets)[number];

View File

@@ -0,0 +1,46 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
import { expandPrompt } from 'features/prompt/PromptExpansion/expand';
import { promptExpansionApi } from 'features/prompt/PromptExpansion/state';
import { selectAllowPromptExpansion } from 'features/system/store/configSlice';
import { toast } from 'features/toast/toast';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTextTBold } from 'react-icons/pi';
/**
 * Context menu item that starts prompt generation from the image provided by `ImageDTOContext`.
 *
 * Renders nothing when prompt expansion is not allowed by the app config, and is disabled while a
 * prompt expansion request is pending.
 */
export const ImageMenuItemUseForPromptGeneration = memo(() => {
  const { t } = useTranslation();
  const { dispatch, getState } = useAppStore();
  const imageDTO = useImageDTOContext();
  const expansionState = useStore(promptExpansionApi.$state);
  const isFeatureAllowed = useAppSelector(selectAllowPromptExpansion);

  const startPromptGeneration = useCallback(() => {
    // Mark pending first so the prompt box shows its loading state, then start the request.
    promptExpansionApi.setPending(imageDTO);
    expandPrompt({ dispatch, getState, imageDTO });
    toast({
      id: 'PROMPT_GENERATION_STARTED',
      title: t('toast.promptGenerationStarted'),
      status: 'info',
    });
  }, [dispatch, getState, imageDTO, t]);

  return isFeatureAllowed ? (
    <MenuItem
      icon={<PiTextTBold />}
      onClickCapture={startPromptGeneration}
      id="use-for-prompt-generation"
      isDisabled={expansionState.isPending}
    >
      {t('gallery.useForPromptGeneration')}
    </MenuItem>
  ) : null;
});

ImageMenuItemUseForPromptGeneration.displayName = 'ImageMenuItemUseForPromptGeneration';

View File

@@ -14,6 +14,7 @@ import { ImageMenuItemSelectForCompare } from 'features/gallery/components/Image
import { ImageMenuItemSendToUpscale } from 'features/gallery/components/ImageContextMenu/ImageMenuItemSendToUpscale';
import { ImageMenuItemStarUnstar } from 'features/gallery/components/ImageContextMenu/ImageMenuItemStarUnstar';
import { ImageMenuItemUseAsRefImage } from 'features/gallery/components/ImageContextMenu/ImageMenuItemUseAsRefImage';
import { ImageMenuItemUseForPromptGeneration } from 'features/gallery/components/ImageContextMenu/ImageMenuItemUseForPromptGeneration';
import { ImageDTOContextProvider } from 'features/gallery/contexts/ImageDTOContext';
import { memo } from 'react';
import type { ImageDTO } from 'services/api/types';
@@ -38,6 +39,7 @@ const SingleSelectionMenuItems = ({ imageDTO }: SingleSelectionMenuItemsProps) =
<ImageMenuItemMetadataRecallActions />
<MenuDivider />
<ImageMenuItemSendToUpscale />
<ImageMenuItemUseForPromptGeneration />
<ImageMenuItemUseAsRefImage />
<ImageMenuItemNewCanvasFromImageSubMenu />
<ImageMenuItemNewLayerFromImageSubMenu />

View File

@@ -1,4 +1,5 @@
import { Box, Textarea } from '@invoke-ai/ui-library';
import { Box, Flex, Textarea } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { usePersistedTextAreaSize } from 'common/hooks/usePersistedTextareaSize';
import {
@@ -7,12 +8,17 @@ import {
selectModelSupportsNegativePrompt,
selectPositivePrompt,
} from 'features/controlLayers/store/paramsSlice';
import { promptGenerationFromImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { ShowDynamicPromptsPreviewButton } from 'features/dynamicPrompts/components/ShowDynamicPromptsPreviewButton';
import { NegativePromptToggleButton } from 'features/parameters/components/Core/NegativePromptToggleButton';
import { PromptLabel } from 'features/parameters/components/Prompts/PromptLabel';
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
import { ViewModePrompt } from 'features/parameters/components/Prompts/ViewModePrompt';
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
import { PromptExpansionMenu } from 'features/prompt/PromptExpansion/PromptExpansionMenu';
import { PromptExpansionOverlay } from 'features/prompt/PromptExpansion/PromptExpansionOverlay';
import { promptExpansionApi } from 'features/prompt/PromptExpansion/state';
import { PromptPopover } from 'features/prompt/PromptPopover';
import { usePrompt } from 'features/prompt/usePrompt';
import { SDXLConcatButton } from 'features/sdxl/components/SDXLPrompts/SDXLConcatButton';
@@ -21,7 +27,8 @@ import {
selectStylePresetViewMode,
} from 'features/stylePresets/store/stylePresetSlice';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { memo, useCallback, useRef } from 'react';
import { selectAllowPromptExpansion } from 'features/system/store/configSlice';
import { memo, useCallback, useMemo, useRef } from 'react';
import type { HotkeyCallback } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { useListStylePresetsQuery } from 'services/api/endpoints/stylePresets';
@@ -39,6 +46,8 @@ export const ParamPositivePrompt = memo(() => {
const viewMode = useAppSelector(selectStylePresetViewMode);
const activeStylePresetId = useAppSelector(selectStylePresetActivePresetId);
const modelSupportsNegativePrompt = useAppSelector(selectModelSupportsNegativePrompt);
const { isPending: isPromptExpansionPending } = useStore(promptExpansionApi.$state);
const isPromptExpansionEnabled = useAppSelector(selectAllowPromptExpansion);
const textareaRef = useRef<HTMLTextAreaElement>(null);
usePersistedTextAreaSize('positive_prompt', textareaRef, persistOptions);
@@ -64,6 +73,7 @@ export const ParamPositivePrompt = memo(() => {
prompt,
textareaRef: textareaRef,
onChange: handleChange,
isDisabled: isPromptExpansionPending,
});
const focus: HotkeyCallback = useCallback(
@@ -82,41 +92,56 @@ export const ParamPositivePrompt = memo(() => {
dependencies: [focus],
});
const dndTargetData = useMemo(() => promptGenerationFromImageDndTarget.getData(), []);
return (
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
<Box pos="relative">
<Textarea
className="positive-prompt-textarea"
name="prompt"
ref={textareaRef}
value={prompt}
onChange={onChange}
onKeyDown={onKeyDown}
variant="darkFilled"
borderTopWidth={24} // This prevents the prompt from being hidden behind the header
paddingInlineEnd={10}
paddingInlineStart={3}
paddingTop={0}
paddingBottom={3}
resize="vertical"
minH={32}
/>
<PromptOverlayButtonWrapper>
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
{baseModel === 'sdxl' && <SDXLConcatButton />}
<ShowDynamicPromptsPreviewButton />
{modelSupportsNegativePrompt && <NegativePromptToggleButton />}
</PromptOverlayButtonWrapper>
<PromptLabel label="Prompt" />
{viewMode && (
<ViewModePrompt
prompt={prompt}
presetPrompt={activeStylePreset?.preset_data.positive_prompt || ''}
label={`${t('parameters.positivePromptPlaceholder')} (${t('stylePresets.preview')})`}
<Box pos="relative">
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
<Box pos="relative">
<Textarea
className="positive-prompt-textarea"
name="prompt"
ref={textareaRef}
value={prompt}
onChange={onChange}
onKeyDown={onKeyDown}
variant="darkFilled"
borderTopWidth={24} // This prevents the prompt from being hidden behind the header
paddingInlineEnd={10}
paddingInlineStart={3}
paddingTop={0}
paddingBottom={3}
resize="vertical"
minH={isPromptExpansionEnabled ? 44 : 32}
isDisabled={isPromptExpansionPending}
/>
)}
</Box>
</PromptPopover>
<PromptOverlayButtonWrapper>
<Flex flexDir="column" gap={2} justifyContent="flex-start" alignItems="center">
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
{baseModel === 'sdxl' && <SDXLConcatButton />}
<ShowDynamicPromptsPreviewButton />
{modelSupportsNegativePrompt && <NegativePromptToggleButton />}
</Flex>
{isPromptExpansionEnabled && <PromptExpansionMenu />}
</PromptOverlayButtonWrapper>
<PromptLabel label="Prompt" />
{viewMode && (
<ViewModePrompt
prompt={prompt}
presetPrompt={activeStylePreset?.preset_data.positive_prompt || ''}
label={`${t('parameters.positivePromptPlaceholder')} (${t('stylePresets.preview')})`}
/>
)}
<DndDropTarget
dndTarget={promptGenerationFromImageDndTarget}
dndTargetData={dndTargetData}
label={t('prompt.generateFromImage')}
isDisabled={isPromptExpansionPending}
/>
<PromptExpansionOverlay />
</Box>
</PromptPopover>
</Box>
);
});

View File

@@ -10,7 +10,9 @@ export const PromptOverlayButtonWrapper = memo((props: PropsWithChildren) => (
p={2}
gap={2}
alignItems="center"
justifyContent="center"
justifyContent="space-between"
bottom={4} // make sure textarea resize is accessible
pb={0}
>
{props.children}
</Flex>

View File

@@ -0,0 +1,80 @@
import { IconButton, Menu, MenuButton, MenuItem, MenuList, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppStore } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { WrappedError } from 'common/util/result';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiMagicWandBold } from 'react-icons/pi';
import type { ImageDTO } from 'services/api/types';
import { expandPrompt } from './expand';
import { promptExpansionApi } from './state';
/**
 * Menu button offering the two prompt expansion entry points: expanding the current positive prompt,
 * or uploading an image to generate a prompt from. The button and both items are disabled while a
 * prompt expansion request is pending.
 */
export const PromptExpansionMenu = () => {
  const { t } = useTranslation();
  const { dispatch, getState } = useAppStore();
  const { isPending } = useStore(promptExpansionApi.$state);

  // The upload has begun: mark the request pending. No image DTO exists yet.
  const markUploadPending = useCallback(() => {
    promptExpansionApi.setPending();
  }, []);

  // The upload finished: attach the uploaded image to the pending state and start the expansion.
  const startExpansionFromImage = useCallback(
    (imageDTO: ImageDTO) => {
      promptExpansionApi.setPending(imageDTO);
      expandPrompt({ dispatch, getState, imageDTO });
    },
    [dispatch, getState]
  );

  // The upload failed: record the error in the expansion state and notify the user.
  const reportUploadFailure = useCallback(
    (error: unknown) => {
      promptExpansionApi.setError(WrappedError.wrap(error));
      toast({
        id: 'UPLOAD_AND_PROMPT_GENERATION_FAILED',
        title: t('toast.uploadAndPromptGenerationFailed'),
        status: 'error',
      });
    },
    [t]
  );

  const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({
    allowMultiple: false,
    onUpload: startExpansionFromImage,
    onUploadStarted: markUploadPending,
    onError: reportUploadFailure,
  });

  // Expand the current prompt text; no image is involved.
  const expandCurrentPrompt = useCallback(() => {
    promptExpansionApi.setPending();
    expandPrompt({ dispatch, getState });
  }, [dispatch, getState]);

  return (
    <>
      <Menu>
        <MenuButton
          as={IconButton}
          icon={<PiMagicWandBold size={16} />}
          size="sm"
          borderRadius="100%"
          colorScheme="invokeYellow"
          isDisabled={isPending}
        />
        <MenuList>
          <MenuItem onClick={expandCurrentPrompt} isDisabled={isPending}>
            <Text>{t('prompt.expandCurrentPrompt')}</Text>
          </MenuItem>
          <MenuItem {...getUploadButtonProps()} isDisabled={isPending}>
            <Text>{t('prompt.uploadImageForPromptGeneration')}</Text>
          </MenuItem>
        </MenuList>
      </Menu>
      <input {...getUploadInputProps()} />
    </>
  );
};

View File

@@ -0,0 +1,69 @@
import { Box, Flex, Image, Spinner, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { PromptExpansionResultOverlay } from 'features/prompt/PromptExpansion/PromptExpansionResultOverlay';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiMagicWandBold } from 'react-icons/pi';
import { promptExpansionApi } from './state';
/**
 * Overlay rendered on top of the prompt box that reflects the prompt expansion request state:
 * - success: delegates to `PromptExpansionResultOverlay` with the expanded text
 * - pending: shows a pulsing loading overlay, with the source image dimmed behind it when one is set
 * - otherwise: renders nothing
 */
export const PromptExpansionOverlay = memo(() => {
  const expansion = useStore(promptExpansionApi.$state);
  const { t } = useTranslation();

  if (expansion.isSuccess) {
    // The request completed: let the user decide what to do with the result.
    return <PromptExpansionResultOverlay expandedText={expansion.result} />;
  }

  if (!expansion.isPending) {
    // No request in flight and no result to show.
    return null;
  }

  const sourceImage = expansion.imageDTO;

  return (
    <Box
      position="absolute"
      top={0}
      left={0}
      right={0}
      bottom={0}
      bg="base.900"
      opacity={0.8}
      borderRadius="base"
      zIndex={10}
      display="flex"
      alignItems="center"
      justifyContent="center"
      animation="pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite"
    >
      {/* Dimmed source image, only when the request was started from an image */}
      {sourceImage && (
        <Box
          position="absolute"
          top={2}
          left={2}
          right={2}
          bottom={2}
          opacity={0.5}
          borderRadius="base"
          overflow="hidden"
        >
          <Image src={sourceImage.thumbnail_url} objectFit="contain" w="full" h="full" borderRadius="base" />
        </Box>
      )}
      <Flex direction="column" alignItems="center" gap={3} color="invokeYellow.300" position="relative" zIndex={1}>
        <Box position="relative" display="flex" alignItems="center" justifyContent="center">
          <PiMagicWandBold size={24} />
          <Spinner size="sm" position="absolute" color="invokeYellow.400" thickness="2px" />
        </Box>
        <Text fontSize="sm" fontWeight="medium" textAlign="center">
          {t('prompt.expandingPrompt')}
        </Text>
      </Flex>
    </Box>
  );
});

PromptExpansionOverlay.displayName = 'PromptExpansionOverlay';

View File

@@ -0,0 +1,76 @@
import { ButtonGroup, Flex, Icon, IconButton, Text, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { positivePromptChanged, selectPositivePrompt } from 'features/controlLayers/store/paramsSlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCheckBold, PiMagicWandBold, PiPlusBold, PiXBold } from 'react-icons/pi';
import { promptExpansionApi } from './state';
interface PromptExpansionResultOverlayProps {
  /** The expanded prompt text to present to the user. */
  expandedText: string;
}

/**
 * Overlay shown over the prompt box once a prompt expansion has succeeded. It displays the expanded
 * text and lets the user choose what to do with it:
 * - Replace: overwrite the positive prompt with the expanded text
 * - Insert: append the expanded text to the current positive prompt on a new line
 * - Discard: drop the result
 *
 * Every action resets the prompt expansion state afterwards. Tooltip and accessibility labels use
 * the `prompt.replace` / `prompt.insert` / `prompt.discard` translations instead of hard-coded
 * English strings.
 */
export const PromptExpansionResultOverlay = ({ expandedText }: PromptExpansionResultOverlayProps) => {
  const { t } = useTranslation();
  const dispatch = useAppDispatch();
  const positivePrompt = useAppSelector(selectPositivePrompt);

  const handleReplace = useCallback(() => {
    dispatch(positivePromptChanged(expandedText));
    promptExpansionApi.reset();
  }, [dispatch, expandedText]);

  const handleInsert = useCallback(() => {
    // An empty current prompt needs no separator; otherwise append on a new line.
    const newText = positivePrompt ? `${positivePrompt}\n${expandedText}` : expandedText;
    dispatch(positivePromptChanged(newText));
    promptExpansionApi.reset();
  }, [dispatch, expandedText, positivePrompt]);

  const handleDiscard = useCallback(() => {
    promptExpansionApi.reset();
  }, []);

  const replaceLabel = t('prompt.replace');
  const insertLabel = t('prompt.insert');
  const discardLabel = t('prompt.discard');

  return (
    <Flex pos="absolute" inset={0} bg="base.800" backdropFilter="blur(8px)" zIndex={10} direction="column">
      <Flex flex={1} p={2} borderRadius="md" overflowY="auto" minH={0}>
        <Text fontSize="sm" w="full" pr={7}>
          <Icon as={PiMagicWandBold} boxSize={5} display="inline" mr={2} color="invokeYellow.500" />
          {expandedText}
        </Text>
      </Flex>
      <Flex gap={2} p={1} justify="flex-end" pos="absolute" bottom={0} right={0} flexDirection="column">
        <ButtonGroup orientation="vertical">
          <Tooltip label={replaceLabel} placement="right">
            <IconButton
              onClick={handleReplace}
              icon={<PiCheckBold />}
              colorScheme="invokeGreen"
              size="xs"
              aria-label={replaceLabel}
            />
          </Tooltip>
          <Tooltip label={insertLabel} placement="right">
            <IconButton
              onClick={handleInsert}
              icon={<PiPlusBold />}
              colorScheme="invokeBlue"
              size="xs"
              aria-label={insertLabel}
            />
          </Tooltip>
        </ButtonGroup>
        <Tooltip label={discardLabel} placement="right">
          <IconButton
            onClick={handleDiscard}
            icon={<PiXBold />}
            colorScheme="invokeRed"
            size="xs"
            aria-label={discardLabel}
          />
        </Tooltip>
      </Flex>
    </Flex>
  );
};

View File

@@ -0,0 +1,43 @@
import type { AppDispatch, AppGetState } from 'app/store/store';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { buildRunGraphDependencies, runGraph } from 'services/api/run-graph';
import type { ImageDTO } from 'services/api/types';
import { $socket } from 'services/events/stores';
import { assert } from 'tsafe';
import { buildPromptExpansionGraph } from './graph';
import { promptExpansionApi } from './state';
/**
 * Runs the prompt expansion graph and stores the outcome in the prompt expansion state.
 *
 * When `imageDTO` is provided, the graph built from it is run; otherwise the graph is built from
 * state alone. On success the resulting string is stored via `promptExpansionApi.setSuccess`.
 *
 * Callers mark the request as pending before calling this function, so every failure path must
 * clear that pending state or the UI stays locked. For that reason the socket check and graph
 * construction (which asserts that a main model is selected) happen inside the `try` block: a
 * missing socket, a graph-building failure, a failed or timed-out run, and an unexpected output
 * type all reset the state and show the failure toast. This function never rejects.
 *
 * @param arg.dispatch Store dispatch, forwarded to the run-graph dependencies.
 * @param arg.getState Store state getter, used to build the graph.
 * @param arg.imageDTO Optional image to generate the prompt from.
 */
export const expandPrompt = async (arg: { dispatch: AppDispatch; getState: AppGetState; imageDTO?: ImageDTO }) => {
  const { dispatch, getState, imageDTO } = arg;

  try {
    const socket = $socket.get();
    assert(socket, 'Socket is not available');

    const { graph, outputNodeId } = buildPromptExpansionGraph({
      state: getState(),
      imageDTO,
    });
    const dependencies = buildRunGraphDependencies(dispatch, socket);

    const { output } = await runGraph({
      graph,
      outputNodeId,
      dependencies,
      options: {
        prepend: true,
        timeout: 10000,
      },
    });
    assert(output.type === 'string_output');
    promptExpansionApi.setSuccess(output.value);
  } catch {
    // Clear the pending state set by the caller so the prompt box becomes usable again.
    promptExpansionApi.reset();
    toast({
      id: 'PROMPT_EXPANSION_FAILED',
      title: t('toast.promptExpansionFailed'),
      status: 'error',
    });
  }
};

View File

@@ -0,0 +1,43 @@
import type { RootState } from 'app/store/store';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectBase, selectPositivePrompt } from 'features/controlLayers/store/paramsSlice';
import { imageDTOToImageField } from 'features/controlLayers/store/util';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { ImageDTO } from 'services/api/types';
import { assert } from 'tsafe';
/**
 * Builds the single-node graph used for prompt expansion.
 *
 * With an image, the graph contains a `claude_analyze_image` node fed with that image. Without one,
 * it contains a `claude_expand_prompt` node fed with the current positive prompt. In both cases the
 * node receives a `model_architecture` of `tag_based` for the `sdxl` / `sdxl-refiner` bases and
 * `sentence_based` for every other base.
 *
 * @param state Root state; the base model and positive prompt are selected from it.
 * @param imageDTO Optional image to analyze instead of expanding the current prompt.
 * @returns The graph and the id of its only node, which is the output node.
 * @throws If no base model is selected in state.
 */
export const buildPromptExpansionGraph = ({
  state,
  imageDTO,
}: {
  state: RootState;
  imageDTO?: ImageDTO;
}): { graph: Graph; outputNodeId: string } => {
  const base = selectBase(state);
  assert(base, 'No main model found in state');

  const tagBasedBases = ['sdxl', 'sdxl-refiner'];
  const architecture = tagBasedBases.includes(base) ? 'tag_based' : 'sentence_based';

  if (!imageDTO) {
    // No image: expand the current positive prompt text.
    const positivePrompt = selectPositivePrompt(state);
    const expandGraph = new Graph(getPrefixedId('claude-expand-prompt-graph'));
    const expandNode = expandGraph.addNode({
      // @ts-expect-error: These nodes are not available in the OSS application
      type: 'claude_expand_prompt',
      id: getPrefixedId('claude_expand_prompt'),
      model_architecture: architecture,
      prompt: positivePrompt,
    });
    return { graph: expandGraph, outputNodeId: expandNode.id };
  }

  // Image provided: generate a prompt by analyzing it.
  const analyzeGraph = new Graph(getPrefixedId('claude-analyze-image-graph'));
  const analyzeNode = analyzeGraph.addNode({
    // @ts-expect-error: These nodes are not available in the OSS application
    type: 'claude_analyze_image',
    id: getPrefixedId('claude_analyze_image'),
    model_architecture: architecture,
    image: imageDTOToImageField(imageDTO),
  });
  return { graph: analyzeGraph, outputNodeId: analyzeNode.id };
};

View File

@@ -0,0 +1,98 @@
import { deepClone } from 'common/util/deepClone';
import { atom } from 'nanostores';
import type { ImageDTO } from 'services/api/types';
/** The request completed: `result` holds the resulting string and `error` is null. */
type SuccessState = {
  isSuccess: true;
  isError: false;
  isPending: false;
  result: string;
  error: null;
  /** Optional image associated with the request. */
  imageDTO?: ImageDTO;
};
/** The request failed: `error` holds the failure and `result` is null. */
type ErrorState = {
  isSuccess: false;
  isError: true;
  isPending: false;
  result: null;
  error: Error;
  /** Optional image associated with the request. */
  imageDTO?: ImageDTO;
};
/** A request is in flight: there is neither a result nor an error yet. */
type PendingState = {
  isSuccess: false;
  isError: false;
  isPending: true;
  result: null;
  error: null;
  /** Optional image associated with the request. */
  imageDTO?: ImageDTO;
};
/** No request is active: all flags are false and there is no result or error. */
type IdleState = {
  isSuccess: false;
  isError: false;
  isPending: false;
  result: null;
  error: null;
  /** Optional image associated with the request. */
  imageDTO?: ImageDTO;
};
/**
 * State of the prompt expansion request. The literal-typed boolean flags discriminate the union, so
 * checking one of them narrows `result` and `error` accordingly.
 */
export type PromptExpansionRequestState = IdleState | PendingState | SuccessState | ErrorState;
/** Template for the idle state. It is always cloned before being placed in the store. */
const IDLE_STATE: IdleState = {
  isSuccess: false,
  isError: false,
  isPending: false,
  result: null,
  error: null,
  imageDTO: undefined,
};

/** Store holding the current prompt expansion request state; starts out idle. */
const $state = atom<PromptExpansionRequestState>(deepClone(IDLE_STATE));

/** Returns the store to a fresh idle state. */
function reset(): void {
  $state.set(deepClone(IDLE_STATE));
}

/**
 * Marks a request as pending, clearing any previous result or error.
 *
 * @param imageDTO Optional image the request is based on; replaces any previously stored image.
 */
function setPending(imageDTO?: ImageDTO): void {
  const next: PendingState = {
    ...$state.get(),
    isSuccess: false,
    isError: false,
    isPending: true,
    result: null,
    error: null,
    imageDTO,
  };
  $state.set(next);
}

/**
 * Marks the request as succeeded with the given result. The stored image, if any, is kept.
 *
 * @param result The expanded prompt text.
 */
function setSuccess(result: string): void {
  const next: SuccessState = {
    ...$state.get(),
    isSuccess: true,
    isError: false,
    isPending: false,
    result,
    error: null,
  };
  $state.set(next);
}

/**
 * Marks the request as failed with the given error. The stored image, if any, is kept.
 *
 * @param error The failure to record.
 */
function setError(error: Error): void {
  const next: ErrorState = {
    ...$state.get(),
    isSuccess: false,
    isError: true,
    isPending: false,
    result: null,
    error,
  };
  $state.set(next);
}

/** Public API for the prompt expansion request state: the store itself plus its mutators. */
export const promptExpansionApi = {
  $state,
  reset,
  setPending,
  setSuccess,
  setError,
};

View File

@@ -8,9 +8,10 @@ type UseInsertTriggerArg = {
prompt: string;
textareaRef: RefObject<HTMLTextAreaElement>;
onChange: (v: string) => void;
isDisabled?: boolean;
};
export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInsertTriggerArg) => {
export const usePrompt = ({ prompt, textareaRef, onChange: _onChange, isDisabled = false }: UseInsertTriggerArg) => {
const { isOpen, onClose, onOpen } = useDisclosure();
const onChange: ChangeEventHandler<HTMLTextAreaElement> = useCallback(
@@ -73,12 +74,12 @@ export const usePrompt = ({ prompt, textareaRef, onChange: _onChange }: UseInser
const onKeyDown: KeyboardEventHandler<HTMLTextAreaElement> = useCallback(
(e) => {
if (e.key === '<') {
if (e.key === '<' && !isDisabled) {
onOpen();
e.preventDefault();
}
},
[onOpen]
[onOpen, isDisabled]
);
return {

View File

@@ -36,6 +36,7 @@ import type { UpscaleState } from 'features/parameters/store/upscaleSlice';
import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import { getGridSize } from 'features/parameters/util/optimalDimension';
import { promptExpansionApi, type PromptExpansionRequestState } from 'features/prompt/PromptExpansion/state';
import { selectConfigSlice } from 'features/system/store/configSlice';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import type { TabName } from 'features/ui/store/uiTypes';
@@ -89,9 +90,22 @@ const debouncedUpdateReasons = debounce(
config: AppConfig,
store: AppStore,
isInPublishFlow: boolean,
isChatGPT4oHighModelDisabled: (model: ParameterModel) => boolean
isChatGPT4oHighModelDisabled: (model: ParameterModel) => boolean,
promptExpansionRequest: PromptExpansionRequestState
) => {
if (tab === 'canvas') {
if (tab === 'generate') {
const model = selectMainModelConfig(store.getState());
const reasons = await getReasonsWhyCannotEnqueueGenerateTab({
isConnected,
model,
params,
refImages,
dynamicPrompts,
isChatGPT4oHighModelDisabled,
promptExpansionRequest,
});
$reasonsWhyCannotEnqueue.set(reasons);
} else if (tab === 'canvas') {
const model = selectMainModelConfig(store.getState());
const reasons = await getReasonsWhyCannotEnqueueCanvasTab({
isConnected,
@@ -106,6 +120,7 @@ const debouncedUpdateReasons = debounce(
canvasIsCompositing,
canvasIsSelectingObject,
isChatGPT4oHighModelDisabled,
promptExpansionRequest,
});
$reasonsWhyCannotEnqueue.set(reasons);
} else if (tab === 'workflows') {
@@ -124,6 +139,7 @@ const debouncedUpdateReasons = debounce(
upscale,
config,
params,
promptExpansionRequest,
});
$reasonsWhyCannotEnqueue.set(reasons);
} else {
@@ -155,6 +171,7 @@ export const useReadinessWatcher = () => {
const canvasIsCompositing = useStore(canvasManager?.compositor.$isBusy ?? $false);
const isInPublishFlow = useStore($isInPublishFlow);
const { isChatGPT4oHighModelDisabled } = useIsModelDisabled();
const promptExpansionRequest = useStore(promptExpansionApi.$state);
useEffect(() => {
debouncedUpdateReasons(
@@ -176,7 +193,8 @@ export const useReadinessWatcher = () => {
config,
store,
isInPublishFlow,
isChatGPT4oHighModelDisabled
isChatGPT4oHighModelDisabled,
promptExpansionRequest
);
}, [
store,
@@ -198,11 +216,88 @@ export const useReadinessWatcher = () => {
workflowSettings,
isInPublishFlow,
isChatGPT4oHighModelDisabled,
promptExpansionRequest,
]);
};
/**
 * Builds the reason entry used when the app is not connected.
 *
 * @param t - Translate function used to resolve the message text.
 * @returns An object whose `content` is the translated "system disconnected" message.
 */
const disconnectedReason = (t: typeof i18n.t) => {
  const content = t('parameters.invoke.systemDisconnected');
  return { content };
};
/**
 * Collects every reason the Generate tab cannot currently enqueue.
 *
 * Checks run in a fixed order and each failing check appends one entry, so the
 * returned list preserves that order: connection, prompts, model selection,
 * FLUX sub-models, trial-disabled model, prompt expansion state, and finally
 * reference image problems.
 *
 * @param arg - Snapshot of the state the checks read from.
 * @returns The list of blocking reasons; empty when nothing blocks enqueueing.
 */
const getReasonsWhyCannotEnqueueGenerateTab = (arg: {
  isConnected: boolean;
  model: MainModelConfig | null | undefined;
  params: ParamsState;
  refImages: RefImagesState;
  dynamicPrompts: DynamicPromptsState;
  isChatGPT4oHighModelDisabled: (model: ParameterModel) => boolean;
  promptExpansionRequest: PromptExpansionRequestState;
}) => {
  const { model, params, promptExpansionRequest } = arg;
  const blockers: Reason[] = [];

  if (!arg.isConnected) {
    blockers.push(disconnectedReason(i18n.t));
  }

  // Only consult the prompt processor when there are no dynamic prompts at all.
  const hasNoDynamicPrompts = arg.dynamicPrompts.prompts.length === 0;
  if (hasNoDynamicPrompts && getShouldProcessPrompt(params.positivePrompt)) {
    blockers.push({ content: i18n.t('parameters.invoke.noPrompts') });
  }

  if (!model) {
    blockers.push({ content: i18n.t('parameters.invoke.noModelSelected') });
  }

  // FLUX needs its T5 encoder, CLIP embed and VAE models to be selected as well.
  const isFlux = model?.base === 'flux';
  if (isFlux && !params.t5EncoderModel) {
    blockers.push({ content: i18n.t('parameters.invoke.noT5EncoderModelSelected') });
  }
  if (isFlux && !params.clipEmbedModel) {
    blockers.push({ content: i18n.t('parameters.invoke.noCLIPEmbedModelSelected') });
  }
  if (isFlux && !params.fluxVAE) {
    blockers.push({ content: i18n.t('parameters.invoke.noFLUXVAEModelSelected') });
  }

  if (model && arg.isChatGPT4oHighModelDisabled(model)) {
    blockers.push({ content: i18n.t('parameters.invoke.modelDisabledForTrial', { modelName: model.name }) });
  }

  // A prompt expansion that is in flight, or finished with a result still held
  // in the request state, blocks enqueueing.
  if (promptExpansionRequest.isPending) {
    blockers.push({ content: i18n.t('parameters.invoke.promptExpansionPending') });
  } else if (promptExpansionRequest.isSuccess) {
    blockers.push({ content: i18n.t('parameters.invoke.promptExpansionResultPending') });
  }

  const { entities } = arg.refImages;

  // Flux Kontext only supports 1x Reference Image at a time.
  if (model?.base === 'flux-kontext' && entities.length > 1) {
    blockers.push({ content: i18n.t('parameters.invoke.fluxKontextMultipleReferenceImages') });
  }

  // Report each reference image's warnings under a 1-based "<label> #<n>" prefix.
  for (const [index, entity] of entities.entries()) {
    const label = i18n.t(LAYER_TYPE_TO_TKEY['reference_image']);
    const prefix = `${label} #${index + 1}`;
    const warnings = getGlobalReferenceImageWarnings(entity, model);
    if (warnings.length === 0) {
      continue;
    }
    const translated = warnings.map((warning) => i18n.t(warning));
    blockers.push({ prefix, content: upperFirst(translated.join(', ')) });
  }

  return blockers;
};
const getReasonsWhyCannotEnqueueWorkflowsTab = async (arg: {
dispatch: AppDispatch;
nodesState: NodesState;
@@ -293,8 +388,9 @@ const getReasonsWhyCannotEnqueueUpscaleTab = (arg: {
upscale: UpscaleState;
config: AppConfig;
params: ParamsState;
promptExpansionRequest: PromptExpansionRequestState;
}) => {
const { isConnected, upscale, config, params } = arg;
const { isConnected, upscale, config, params, promptExpansionRequest } = arg;
const reasons: Reason[] = [];
if (!isConnected) {
@@ -331,6 +427,12 @@ const getReasonsWhyCannotEnqueueUpscaleTab = (arg: {
}
}
if (promptExpansionRequest.isPending) {
reasons.push({ content: i18n.t('parameters.invoke.promptExpansionPending') });
} else if (promptExpansionRequest.isSuccess) {
reasons.push({ content: i18n.t('parameters.invoke.promptExpansionResultPending') });
}
return reasons;
};
@@ -347,6 +449,7 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
canvasIsCompositing: boolean;
canvasIsSelectingObject: boolean;
isChatGPT4oHighModelDisabled: (model: ParameterModel) => boolean;
promptExpansionRequest: PromptExpansionRequestState;
}) => {
const {
isConnected,
@@ -361,6 +464,7 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
canvasIsCompositing,
canvasIsSelectingObject,
isChatGPT4oHighModelDisabled,
promptExpansionRequest,
} = arg;
const { positivePrompt } = params;
const reasons: Reason[] = [];
@@ -497,6 +601,12 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
reasons.push({ content: i18n.t('parameters.invoke.modelDisabledForTrial', { modelName: model.name }) });
}
if (promptExpansionRequest.isPending) {
reasons.push({ content: i18n.t('parameters.invoke.promptExpansionPending') });
} else if (promptExpansionRequest.isSuccess) {
reasons.push({ content: i18n.t('parameters.invoke.promptExpansionResultPending') });
}
const enabledControlLayers = canvas.controlLayers.entities.filter((controlLayer) => controlLayer.isEnabled);
// FLUX only supports 1x Control LoRA at a time.

View File

@@ -25,6 +25,7 @@ const initialConfigState: AppConfig & { didLoad: boolean } = {
allowPrivateStylePresets: false,
allowClientSideUpload: false,
allowPublishWorkflows: false,
allowPromptExpansion: false,
shouldShowCredits: false,
disabledTabs: [],
disabledFeatures: ['lightbox', 'faceRestore', 'batches'],
@@ -228,6 +229,7 @@ export const selectMetadataFetchDebounce = createConfigSelector((config) => conf
export const selectIsModelsTabDisabled = createConfigSelector((config) => config.disabledTabs.includes('models'));
export const selectIsClientSideUploadEnabled = createConfigSelector((config) => config.allowClientSideUpload);
export const selectAllowPublishWorkflows = createConfigSelector((config) => config.allowPublishWorkflows);
export const selectAllowPromptExpansion = createConfigSelector((config) => config.allowPromptExpansion);
export const selectIsLocal = createSelector(selectConfigSlice, (config) => config.isLocal);
export const selectShouldShowCredits = createConfigSelector((config) => config.shouldShowCredits);
export const selectEnabledTabs = createConfigSelector((config) => {

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store';
import type { AppDispatch } from 'app/store/store';
import { Mutex } from 'async-mutex';
import { withResultAsync, WrappedError } from 'common/util/result';
import { parseify } from 'common/util/serialize';
@@ -135,7 +135,7 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
* Creates production dependencies for runGraph using Redux store and socket.
*/
export const buildRunGraphDependencies = (
store: AppStore,
dispatch: AppDispatch,
socket: {
on: (event: 'queue_item_status_changed', handler: (event: S['QueueItemStatusChangedEvent']) => void) => void;
off: (event: 'queue_item_status_changed', handler: (event: S['QueueItemStatusChangedEvent']) => void) => void;
@@ -143,17 +143,15 @@ export const buildRunGraphDependencies = (
): GraphRunnerDependencies => ({
executor: {
enqueueBatch: (batch) =>
store
.dispatch(
queueApi.endpoints.enqueueBatch.initiate(batch, {
...enqueueMutationFixedCacheKeyOptions,
track: false,
})
)
.unwrap(),
getQueueItem: (id) => store.dispatch(queueApi.endpoints.getQueueItem.initiate(id, { subscribe: false })).unwrap(),
dispatch(
queueApi.endpoints.enqueueBatch.initiate(batch, {
...enqueueMutationFixedCacheKeyOptions,
track: false,
})
).unwrap(),
getQueueItem: (id) => dispatch(queueApi.endpoints.getQueueItem.initiate(id, { subscribe: false })).unwrap(),
cancelQueueItem: (id) =>
store.dispatch(queueApi.endpoints.cancelQueueItem.initiate({ item_id: id }, { track: false })).unwrap(),
dispatch(queueApi.endpoints.cancelQueueItem.initiate({ item_id: id }, { track: false })).unwrap(),
},
eventHandler: {
subscribe: (handler) => socket.on('queue_item_status_changed', handler),