diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py index a84befcb2e..a1e19614fe 100644 --- a/invokeai/app/invocations/metadata.py +++ b/invokeai/app/invocations/metadata.py @@ -12,7 +12,9 @@ from invokeai.app.invocations.baseinvocation import ( invocation_output, ) from invokeai.app.invocations.controlnet_image_processors import ControlField +from invokeai.app.invocations.ip_adapter import IPAdapterModelField from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField +from invokeai.app.invocations.primitives import ImageField from invokeai.app.util.model_exclude_null import BaseModelExcludeNull from ...version import __version__ @@ -25,6 +27,18 @@ class LoRAMetadataField(BaseModelExcludeNull): weight: float = Field(description="The weight of the LoRA model") +class IPAdapterMetadataField(BaseModelExcludeNull): + image: ImageField = Field(description="The IP-Adapter image prompt.") + ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.") + weight: float = Field(description="The weight of the IP-Adapter model") + begin_step_percent: float = Field( + default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)" + ) + end_step_percent: float = Field( + default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)" + ) + + class CoreMetadata(BaseModelExcludeNull): """Core generation metadata for an image generated in InvokeAI.""" @@ -48,6 +62,7 @@ class CoreMetadata(BaseModelExcludeNull): ) model: MainModelField = Field(description="The main model used for inference") controlnets: list[ControlField] = Field(description="The ControlNets used for inference") + ipAdapters: list[IPAdapterMetadataField] = Field(description="The IP Adapters used for inference") loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference") vae: Optional[VAEModelField] = Field( default=None, @@ -123,6 +138,7 @@ class MetadataAccumulatorInvocation(BaseInvocation): ) model: MainModelField = InputField(description="The main model used for inference") controlnets: list[ControlField] = InputField(description="The ControlNets used for inference") + ipAdapters: list[IPAdapterMetadataField] = InputField(description="The IP Adapters used for inference") loras: list[LoRAMetadataField] = InputField(description="The LoRAs used for inference") strength: Optional[float] = InputField( default=None, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts index 7c51b44aa2..0861bba333 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts @@ -113,7 +113,7 @@ export const addRequestedSingleImageDeletionListener = () => { // Remove IP Adapter Set Image if image is deleted. if ( - getState().controlNet.ipAdapterInfo.adapterImage?.image_name === + getState().controlNet.ipAdapterInfo.adapterImage === imageDTO.image_name ) { dispatch(ipAdapterImageChanged(null)); @@ -238,7 +238,7 @@ export const addRequestedMultipleImageDeletionListener = () => { // Remove IP Adapter Set Image if image is deleted. if ( - getState().controlNet.ipAdapterInfo.adapterImage?.image_name === + getState().controlNet.ipAdapterInfo.adapterImage === imageDTO.image_name ) { dispatch(ipAdapterImageChanged(null)); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts index d38a20a917..133c826186 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts @@ -118,7 +118,7 @@ export const addImageDroppedListener = () => { activeData.payloadType === 'IMAGE_DTO' && activeData.payload.imageDTO ) { - dispatch(ipAdapterImageChanged(activeData.payload.imageDTO)); + dispatch(ipAdapterImageChanged(activeData.payload.imageDTO.image_name)); dispatch(isIPAdapterEnabledChanged(true)); return; } diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts index b27c922342..831f732e0f 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts @@ -111,7 +111,7 @@ export const addImageUploadedFulfilledListener = () => { } if (postUploadAction?.type === 'SET_IP_ADAPTER_IMAGE') { - dispatch(ipAdapterImageChanged(imageDTO)); + dispatch(ipAdapterImageChanged(imageDTO.image_name)); dispatch(isIPAdapterEnabledChanged(true)); dispatch( addToast({ diff --git a/invokeai/frontend/web/src/features/controlNet/components/ipAdapter/ParamIPAdapterImage.tsx b/invokeai/frontend/web/src/features/controlNet/components/ipAdapter/ParamIPAdapterImage.tsx index 21229a81a5..701f8ad7a9 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ipAdapter/ParamIPAdapterImage.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ipAdapter/ParamIPAdapterImage.tsx @@ -33,7 +33,7 @@ const ParamIPAdapterImage = () => { const { t } = useTranslation(); const { currentData: imageDTO } = useGetImageDTOQuery( - ipAdapterInfo.adapterImage?.image_name ?? skipToken + ipAdapterInfo.adapterImage ?? skipToken ); const draggableData = useMemo(() => { diff --git a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts index 9b5dec68f3..a82692e70e 100644 --- a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts +++ b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts @@ -6,7 +6,6 @@ import { import { cloneDeep, forEach } from 'lodash-es'; import { imagesApi } from 'services/api/endpoints/images'; import { components } from 'services/api/schema'; -import { ImageDTO } from 'services/api/types'; import { appSocketInvocationError } from 'services/events/actions'; import { controlNetImageProcessed } from './actions'; import { @@ -60,7 +59,7 @@ export type ControlNetConfig = { }; export type IPAdapterConfig = { - adapterImage: ImageDTO | null; + adapterImage: string | null; model: IPAdapterModelParam | null; weight: number; beginStepPct: number; @@ -388,7 +387,10 @@ export const controlNetSlice = createSlice({ isIPAdapterEnabledChanged: (state, action: PayloadAction) => { state.isIPAdapterEnabled = action.payload; }, - ipAdapterImageChanged: (state, action: PayloadAction) => { + ipAdapterRecalled: (state, action: PayloadAction) => { + state.ipAdapterInfo = action.payload; + }, + ipAdapterImageChanged: (state, action: PayloadAction) => { state.ipAdapterInfo.adapterImage = action.payload; }, ipAdapterWeightChanged: (state, action: PayloadAction) => { @@ -471,6 +473,7 @@ export const { controlNetReset, controlNetAutoConfigToggled, isIPAdapterEnabledChanged, + ipAdapterRecalled, ipAdapterImageChanged, ipAdapterWeightChanged, ipAdapterModelChanged, diff --git a/invokeai/frontend/web/src/features/deleteImageModal/store/selectors.ts b/invokeai/frontend/web/src/features/deleteImageModal/store/selectors.ts index 151e975634..51808de2cd 100644 --- a/invokeai/frontend/web/src/features/deleteImageModal/store/selectors.ts +++ b/invokeai/frontend/web/src/features/deleteImageModal/store/selectors.ts @@ -27,8 +27,7 @@ export const getImageUsage = (state: RootState, image_name: string) => { c.controlImage === image_name || c.processedControlImage === image_name ); - const isIPAdapterImage = - controlNet.ipAdapterInfo.adapterImage?.image_name === image_name; + const isIPAdapterImage = controlNet.ipAdapterInfo.adapterImage === image_name; const imageUsage: ImageUsage = { isInitialImage, diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx index 25d8e1e5ac..f43db5f22f 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx @@ -2,6 +2,7 @@ import { ControlNetMetadataItem, CoreMetadata, LoRAMetadataItem, + IPAdapterMetadataItem, } from 'features/nodes/types/types'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { memo, useMemo, useCallback } from 'react'; @@ -34,6 +35,7 @@ const ImageMetadataActions = (props: Props) => { recallStrength, recallLoRA, recallControlNet, + recallIPAdapter, } = useRecallParameters(); const handleRecallPositivePrompt = useCallback(() => { @@ -90,6 +92,13 @@ const ImageMetadataActions = (props: Props) => { [recallControlNet] ); + const handleRecallIPAdapter = useCallback( + (ipAdapter: IPAdapterMetadataItem) => { + recallIPAdapter(ipAdapter); + }, + [recallIPAdapter] + ); + const validControlNets: ControlNetMetadataItem[] = useMemo(() => { return metadata?.controlnets ? metadata.controlnets.filter((controlnet) => @@ -98,6 +107,14 @@ const ImageMetadataActions = (props: Props) => { : []; }, [metadata?.controlnets]); + const validIPAdapters: IPAdapterMetadataItem[] = useMemo(() => { + return metadata?.ipAdapters + ? metadata.ipAdapters.filter((ipAdapter) => + isValidControlNetModel(ipAdapter.ip_adapter_model) + ) + : []; + }, [metadata?.ipAdapters]); + if (!metadata || Object.keys(metadata).length === 0) { return null; } @@ -211,6 +228,14 @@ const ImageMetadataActions = (props: Props) => { onClick={() => handleRecallControlNet(controlnet)} /> ))} + {validIPAdapters.map((ipAdapter, index) => ( + handleRecallIPAdapter(ipAdapter)} + /> + ))} ); }; diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index eb8baf513e..2e64a36926 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -412,8 +412,9 @@ export type IPAdapterModel = z.infer; export const zIPAdapterField = z.object({ image: zImageField, ip_adapter_model: zIPAdapterModel, - image_encoder_model: z.string().trim().min(1), weight: z.number(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), }); export type IPAdapterField = z.infer; @@ -1145,6 +1146,10 @@ const zControlNetMetadataItem = zControlField.deepPartial(); export type ControlNetMetadataItem = z.infer; +const zIPAdapterMetadataItem = zIPAdapterField.deepPartial(); + +export type IPAdapterMetadataItem = z.infer; + export const zCoreMetadata = z .object({ app_version: z.string().nullish().catch(null), @@ -1164,16 +1169,9 @@ export const zCoreMetadata = z .union([zMainModel.deepPartial(), zOnnxModel.deepPartial()]) .nullish() .catch(null), - controlnets: z.array(zControlField.deepPartial()).nullish().catch(null), - loras: z - .array( - z.object({ - lora: zLoRAModelField.deepPartial(), - weight: z.number(), - }) - ) - .nullish() - .catch(null), + controlnets: z.array(zControlNetMetadataItem).nullish().catch(null), + ipAdapters: z.array(zIPAdapterMetadataItem).nullish().catch(null), + loras: z.array(zLoRAMetadataItem).nullish().catch(null), vae: zVaeModelField.nullish().catch(null), strength: z.number().nullish().catch(null), init_image: z.string().nullish().catch(null), diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addIPAdapterToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addIPAdapterToLinearGraph.ts index d645b274ec..a55de8e301 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addIPAdapterToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addIPAdapterToLinearGraph.ts @@ -1,7 +1,14 @@ import { RootState } from 'app/store/store'; -import { IPAdapterInvocation } from 'services/api/types'; +import { + IPAdapterInvocation, + MetadataAccumulatorInvocation, +} from 'services/api/types'; import { NonNullableGraph } from '../../types/types'; -import { CANVAS_COHERENCE_DENOISE_LATENTS, IP_ADAPTER } from './constants'; +import { + CANVAS_COHERENCE_DENOISE_LATENTS, + IP_ADAPTER, + METADATA_ACCUMULATOR, +} from './constants'; export const addIPAdapterToLinearGraph = ( state: RootState, @@ -10,9 +17,9 @@ export const addIPAdapterToLinearGraph = ( ): void => { const { isIPAdapterEnabled, ipAdapterInfo } = state.controlNet; - // const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as - // | MetadataAccumulatorInvocation - // | undefined; + const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as + | MetadataAccumulatorInvocation + | undefined; if (isIPAdapterEnabled && ipAdapterInfo.model) { const ipAdapterNode: IPAdapterInvocation = { @@ -30,23 +37,29 @@ export const addIPAdapterToLinearGraph = ( if (ipAdapterInfo.adapterImage) { ipAdapterNode.image = { - image_name: ipAdapterInfo.adapterImage.image_name, + image_name: ipAdapterInfo.adapterImage, }; } else { return; } graph.nodes[ipAdapterNode.id] = ipAdapterNode as IPAdapterInvocation; + if (metadataAccumulator?.ipAdapters) { + const ipAdapterField = { + image: { + image_name: ipAdapterInfo.adapterImage, + }, + ip_adapter_model: { + base_model: ipAdapterInfo.model?.base_model, + model_name: ipAdapterInfo.model?.model_name, + }, + weight: ipAdapterInfo.weight, + begin_step_percent: ipAdapterInfo.beginStepPct, + end_step_percent: ipAdapterInfo.endStepPct, + }; - // if (metadataAccumulator?.ip_adapters) { - // // metadata accumulator only needs the ip_adapter field - not the whole node - // // extract what we need and add to the accumulator - // const ipAdapterField = omit(ipAdapterNode, [ - // 'id', - // 'type', - // ]) as IPAdapterField; - // metadataAccumulator.ip_adapters.push(ipAdapterField); - // } + metadataAccumulator.ipAdapters.push(ipAdapterField); + } graph.edges.push({ source: { node_id: ipAdapterNode.id, field: 'ip_adapter' }, diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts index 4dbbac9f96..aed3478f57 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts @@ -327,6 +327,7 @@ export const buildCanvasImageToImageGraph = ( vae: undefined, // option; set in addVAEToGraph controlnets: [], // populated in addControlNetToLinearGraph loras: [], // populated in addLoRAsToGraph + ipAdapters: [], // populated in addIPAdapterToLinearGraph clip_skip: clipSkip, strength, init_image: initialImage.image_name, diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLImageToImageGraph.ts index d958b78a90..e2e9bcc58d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLImageToImageGraph.ts @@ -338,6 +338,7 @@ export const buildCanvasSDXLImageToImageGraph = ( vae: undefined, // option; set in addVAEToGraph controlnets: [], // populated in addControlNetToLinearGraph loras: [], // populated in addLoRAsToGraph + ipAdapters: [], // populated in addIPAdapterToLinearGraph strength, init_image: initialImage.image_name, }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLTextToImageGraph.ts index 9f9a442b99..2ee6844918 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLTextToImageGraph.ts @@ -320,6 +320,7 @@ export const buildCanvasSDXLTextToImageGraph = ( vae: undefined, // option; set in addVAEToGraph controlnets: [], // populated in addControlNetToLinearGraph loras: [], // populated in addLoRAsToGraph + ipAdapters: [], // populated in addIPAdapterToLinearGraph }; graph.edges.push({ diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts index 2aa0b2b47d..29cfd86cb4 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts @@ -308,6 +308,7 @@ export const buildCanvasTextToImageGraph = ( vae: undefined, // option; set in addVAEToGraph controlnets: [], // populated in addControlNetToLinearGraph loras: [], // populated in addLoRAsToGraph + ipAdapters: [], // populated in addIPAdapterToLinearGraph clip_skip: clipSkip, }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts index bf8d9ea314..a2f73d8dba 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts @@ -328,6 +328,7 @@ export const buildLinearImageToImageGraph = ( vae: undefined, // option; set in addVAEToGraph controlnets: [], // populated in addControlNetToLinearGraph loras: [], // populated in addLoRAsToGraph + ipAdapters: [], // populated in addIPAdapterToLinearGraph clip_skip: clipSkip, strength, init_image: initialImage.imageName, diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph.ts index bc02fb5a66..1fe894488c 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph.ts @@ -348,6 +348,7 @@ export const buildLinearSDXLImageToImageGraph = ( vae: undefined, controlnets: [], loras: [], + ipAdapters: [], strength: strength, init_image: initialImage.imageName, positive_style_prompt: positiveStylePrompt, diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph.ts index 22a7dd4192..cf92b16044 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph.ts @@ -242,6 +242,7 @@ export const buildLinearSDXLTextToImageGraph = ( vae: undefined, controlnets: [], loras: [], + ipAdapters: [], positive_style_prompt: positiveStylePrompt, negative_style_prompt: negativeStylePrompt, }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts index d7af045803..f658cfd182 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts @@ -250,6 +250,7 @@ export const buildLinearTextToImageGraph = ( vae: undefined, // option; set in addVAEToGraph controlnets: [], // populated in addControlNetToLinearGraph loras: [], // populated in addLoRAsToGraph + ipAdapters: [], // populated in addIPAdapterToLinearGraph clip_skip: clipSkip, }; diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts index d8561ab122..71d074afce 100644 --- a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts +++ b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts @@ -6,6 +6,7 @@ import { CoreMetadata, LoRAMetadataItem, ControlNetMetadataItem, + IPAdapterMetadataItem, } from 'features/nodes/types/types'; import { refinerModelChanged, @@ -23,16 +24,22 @@ import { useTranslation } from 'react-i18next'; import { ImageDTO } from 'services/api/types'; import { controlNetModelsAdapter, + ipAdapterModelsAdapter, + useGetIPAdapterModelsQuery, loraModelsAdapter, useGetControlNetModelsQuery, useGetLoRAModelsQuery, } from '../../../services/api/endpoints/models'; import { ControlNetConfig, + IPAdapterConfig, controlNetEnabled, controlNetRecalled, controlNetReset, initialControlNet, + initialIPAdapterState, + ipAdapterRecalled, + isIPAdapterEnabledChanged, } from '../../controlNet/store/controlNetSlice'; import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice'; import { initialImageSelected, modelSelected } from '../store/actions'; @@ -52,6 +59,7 @@ import { isValidHeight, isValidLoRAModel, isValidControlNetModel, + isValidIPAdapterModel, isValidMainModel, isValidNegativePrompt, isValidPositivePrompt, @@ -512,8 +520,6 @@ export const useRecallParameters = () => { }) ); - dispatch(controlNetEnabled()); - parameterSetToast(); }, [ @@ -524,6 +530,92 @@ export const useRecallParameters = () => { ] ); + /** + * Recall IP Adapter with toast + */ + + const { ipAdapters } = useGetIPAdapterModelsQuery(undefined, { + selectFromResult: (result) => ({ + ipAdapters: result.data + ? ipAdapterModelsAdapter.getSelectors().selectAll(result.data) + : [], + }), + }); + + const prepareIPAdapterMetadataItem = useCallback( + (ipAdapterMetadataItem: IPAdapterMetadataItem) => { + if (!isValidIPAdapterModel(ipAdapterMetadataItem?.ip_adapter_model)) { + return { ipAdapter: null, error: 'Invalid IP Adapter model' }; + } + + const { + image, + ip_adapter_model, + weight, + begin_step_percent, + end_step_percent, + } = ipAdapterMetadataItem; + + const matchingIPAdapterModel = ipAdapters.find( + (c) => + c.base_model === ip_adapter_model?.base_model && + c.model_name === ip_adapter_model?.model_name + ); + + if (!matchingIPAdapterModel) { + return { ipAdapter: null, error: 'IP Adapter model is not installed' }; + } + + const isCompatibleBaseModel = + matchingIPAdapterModel?.base_model === model?.base_model; + + if (!isCompatibleBaseModel) { + return { + ipAdapter: null, + error: 'IP Adapter incompatible with currently-selected model', + }; + } + + const ipAdapter: IPAdapterConfig = { + adapterImage: image?.image_name ?? null, + model: matchingIPAdapterModel, + weight: weight ?? initialIPAdapterState.weight, + beginStepPct: begin_step_percent ?? initialIPAdapterState.beginStepPct, + endStepPct: end_step_percent ?? initialIPAdapterState.endStepPct, + }; + + return { ipAdapter, error: null }; + }, + [ipAdapters, model?.base_model] + ); + + const recallIPAdapter = useCallback( + (ipAdapterMetadataItem: IPAdapterMetadataItem) => { + const result = prepareIPAdapterMetadataItem(ipAdapterMetadataItem); + + if (!result.ipAdapter) { + parameterNotSetToast(result.error); + return; + } + + dispatch( + ipAdapterRecalled({ + ...result.ipAdapter, + }) + ); + + dispatch(isIPAdapterEnabledChanged(true)); + + parameterSetToast(); + }, + [ + prepareIPAdapterMetadataItem, + dispatch, + parameterSetToast, + parameterNotSetToast, + ] + ); + /* * Sets image as initial image with toast */ @@ -563,6 +655,7 @@ export const useRecallParameters = () => { refiner_start, loras, controlnets, + ipAdapters, } = metadata; if (isValidCfgScale(cfg_scale)) { @@ -653,7 +746,9 @@ export const useRecallParameters = () => { }); dispatch(controlNetReset()); - dispatch(controlNetEnabled()); + if (controlnets?.length) { + dispatch(controlNetEnabled()); + } controlnets?.forEach((controlnet) => { const result = prepareControlNetMetadataItem(controlnet); if (result.controlnet) { @@ -661,6 +756,16 @@ export const useRecallParameters = () => { } }); + if (ipAdapters?.length) { + dispatch(isIPAdapterEnabledChanged(true)); + } + ipAdapters?.forEach((ipAdapter) => { + const result = prepareIPAdapterMetadataItem(ipAdapter); + if (result.ipAdapter) { + dispatch(ipAdapterRecalled(result.ipAdapter)); + } + }); + allParameterSetToast(); }, [ @@ -669,6 +774,7 @@ export const useRecallParameters = () => { dispatch, prepareLoRAMetadataItem, prepareControlNetMetadataItem, + prepareIPAdapterMetadataItem, ] ); @@ -688,6 +794,7 @@ export const useRecallParameters = () => { recallStrength, recallLoRA, recallControlNet, + recallIPAdapter, recallAllParameters, sendToImageToImage, }; diff --git a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts index 1b29993712..532bdb92f3 100644 --- a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts @@ -343,6 +343,12 @@ export type IPAdapterModelParam = z.infer; /** * Zod schema for l2l strength parameter */ +/** + * Validates/type-guards a value as a model parameter + */ +export const isValidIPAdapterModel = ( + val: unknown +): val is IPAdapterModelParam => zIPAdapterModel.safeParse(val).success; export const zStrength = z.number().min(0).max(1); /** * Type alias for l2l strength parameter, inferred from its zod schema diff --git a/invokeai/frontend/web/src/services/api/schema.d.ts b/invokeai/frontend/web/src/services/api/schema.d.ts index cc47c0766d..30f365951b 100644 --- a/invokeai/frontend/web/src/services/api/schema.d.ts +++ b/invokeai/frontend/web/src/services/api/schema.d.ts @@ -2115,6 +2115,11 @@ export type components = { * @description The ControlNets used for inference */ controlnets: components["schemas"]["ControlField"][]; + /** + * Loras + * @description The LoRAs used for inference + */ + ipAdapters: components["schemas"]["IPAdapterField"][]; /** * Loras * @description The LoRAs used for inference @@ -3178,7 +3183,7 @@ export type components = { * Image Encoder Model * @description The name of the CLIP image encoder model. */ - image_encoder_model: components["schemas"]["CLIPVisionModelField"]; + image_encoder_model?: components["schemas"]["CLIPVisionModelField"]; /** * Weight * @description The weight given to the ControlNet @@ -5814,6 +5819,11 @@ export type components = { * @description The LoRAs used for inference */ loras?: components["schemas"]["LoRAMetadataField"][]; + /** + * Strength + * @description The strength used for latents-to-latents + */ + ipAdapters?: components["schemas"]["IPAdapterField"][]; /** * Strength * @description The strength used for latents-to-latents