mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Use metadata ip adapter (#4715)
* add control net to useRecallParams * got recall controlnets working * fix metadata viewer controlnet * fix type errors * fix controlnet metadata viewer * add ip adapter to metadata * added ip adapter to recall parameters * got ip adapter recall working, still need to fix type errors * fix type issues * clean up logs * python formatting * cleanup * fix(ui): only store `image_name` as ip adapter image * fix(ui): use nullish coalescing operator for numbers Need to use the nullish coalescing operator `??` instead of false-y coalescing operator `||` when the value being check is a number. This prevents unintended coalescing when the value is zero and therefore false-y. * feat(ui): fall back on default values for ip adapter metadata * fix(ui): remove unused schema * feat(ui): re-use existing schemas in metadata schema * fix(ui): do not disable invocationCache --------- Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
@@ -6,6 +6,7 @@ import {
|
||||
CoreMetadata,
|
||||
LoRAMetadataItem,
|
||||
ControlNetMetadataItem,
|
||||
IPAdapterMetadataItem,
|
||||
} from 'features/nodes/types/types';
|
||||
import {
|
||||
refinerModelChanged,
|
||||
@@ -23,16 +24,22 @@ import { useTranslation } from 'react-i18next';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import {
|
||||
controlNetModelsAdapter,
|
||||
ipAdapterModelsAdapter,
|
||||
useGetIPAdapterModelsQuery,
|
||||
loraModelsAdapter,
|
||||
useGetControlNetModelsQuery,
|
||||
useGetLoRAModelsQuery,
|
||||
} from '../../../services/api/endpoints/models';
|
||||
import {
|
||||
ControlNetConfig,
|
||||
IPAdapterConfig,
|
||||
controlNetEnabled,
|
||||
controlNetRecalled,
|
||||
controlNetReset,
|
||||
initialControlNet,
|
||||
initialIPAdapterState,
|
||||
ipAdapterRecalled,
|
||||
isIPAdapterEnabledChanged,
|
||||
} from '../../controlNet/store/controlNetSlice';
|
||||
import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice';
|
||||
import { initialImageSelected, modelSelected } from '../store/actions';
|
||||
@@ -52,6 +59,7 @@ import {
|
||||
isValidHeight,
|
||||
isValidLoRAModel,
|
||||
isValidControlNetModel,
|
||||
isValidIPAdapterModel,
|
||||
isValidMainModel,
|
||||
isValidNegativePrompt,
|
||||
isValidPositivePrompt,
|
||||
@@ -512,8 +520,6 @@ export const useRecallParameters = () => {
|
||||
})
|
||||
);
|
||||
|
||||
dispatch(controlNetEnabled());
|
||||
|
||||
parameterSetToast();
|
||||
},
|
||||
[
|
||||
@@ -524,6 +530,92 @@ export const useRecallParameters = () => {
|
||||
]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall IP Adapter with toast
|
||||
*/
|
||||
|
||||
const { ipAdapters } = useGetIPAdapterModelsQuery(undefined, {
|
||||
selectFromResult: (result) => ({
|
||||
ipAdapters: result.data
|
||||
? ipAdapterModelsAdapter.getSelectors().selectAll(result.data)
|
||||
: [],
|
||||
}),
|
||||
});
|
||||
|
||||
const prepareIPAdapterMetadataItem = useCallback(
|
||||
(ipAdapterMetadataItem: IPAdapterMetadataItem) => {
|
||||
if (!isValidIPAdapterModel(ipAdapterMetadataItem?.ip_adapter_model)) {
|
||||
return { ipAdapter: null, error: 'Invalid IP Adapter model' };
|
||||
}
|
||||
|
||||
const {
|
||||
image,
|
||||
ip_adapter_model,
|
||||
weight,
|
||||
begin_step_percent,
|
||||
end_step_percent,
|
||||
} = ipAdapterMetadataItem;
|
||||
|
||||
const matchingIPAdapterModel = ipAdapters.find(
|
||||
(c) =>
|
||||
c.base_model === ip_adapter_model?.base_model &&
|
||||
c.model_name === ip_adapter_model?.model_name
|
||||
);
|
||||
|
||||
if (!matchingIPAdapterModel) {
|
||||
return { ipAdapter: null, error: 'IP Adapter model is not installed' };
|
||||
}
|
||||
|
||||
const isCompatibleBaseModel =
|
||||
matchingIPAdapterModel?.base_model === model?.base_model;
|
||||
|
||||
if (!isCompatibleBaseModel) {
|
||||
return {
|
||||
ipAdapter: null,
|
||||
error: 'IP Adapter incompatible with currently-selected model',
|
||||
};
|
||||
}
|
||||
|
||||
const ipAdapter: IPAdapterConfig = {
|
||||
adapterImage: image?.image_name ?? null,
|
||||
model: matchingIPAdapterModel,
|
||||
weight: weight ?? initialIPAdapterState.weight,
|
||||
beginStepPct: begin_step_percent ?? initialIPAdapterState.beginStepPct,
|
||||
endStepPct: end_step_percent ?? initialIPAdapterState.endStepPct,
|
||||
};
|
||||
|
||||
return { ipAdapter, error: null };
|
||||
},
|
||||
[ipAdapters, model?.base_model]
|
||||
);
|
||||
|
||||
const recallIPAdapter = useCallback(
|
||||
(ipAdapterMetadataItem: IPAdapterMetadataItem) => {
|
||||
const result = prepareIPAdapterMetadataItem(ipAdapterMetadataItem);
|
||||
|
||||
if (!result.ipAdapter) {
|
||||
parameterNotSetToast(result.error);
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
ipAdapterRecalled({
|
||||
...result.ipAdapter,
|
||||
})
|
||||
);
|
||||
|
||||
dispatch(isIPAdapterEnabledChanged(true));
|
||||
|
||||
parameterSetToast();
|
||||
},
|
||||
[
|
||||
prepareIPAdapterMetadataItem,
|
||||
dispatch,
|
||||
parameterSetToast,
|
||||
parameterNotSetToast,
|
||||
]
|
||||
);
|
||||
|
||||
/*
|
||||
* Sets image as initial image with toast
|
||||
*/
|
||||
@@ -563,6 +655,7 @@ export const useRecallParameters = () => {
|
||||
refiner_start,
|
||||
loras,
|
||||
controlnets,
|
||||
ipAdapters,
|
||||
} = metadata;
|
||||
|
||||
if (isValidCfgScale(cfg_scale)) {
|
||||
@@ -653,7 +746,9 @@ export const useRecallParameters = () => {
|
||||
});
|
||||
|
||||
dispatch(controlNetReset());
|
||||
dispatch(controlNetEnabled());
|
||||
if (controlnets?.length) {
|
||||
dispatch(controlNetEnabled());
|
||||
}
|
||||
controlnets?.forEach((controlnet) => {
|
||||
const result = prepareControlNetMetadataItem(controlnet);
|
||||
if (result.controlnet) {
|
||||
@@ -661,6 +756,16 @@ export const useRecallParameters = () => {
|
||||
}
|
||||
});
|
||||
|
||||
if (ipAdapters?.length) {
|
||||
dispatch(isIPAdapterEnabledChanged(true));
|
||||
}
|
||||
ipAdapters?.forEach((ipAdapter) => {
|
||||
const result = prepareIPAdapterMetadataItem(ipAdapter);
|
||||
if (result.ipAdapter) {
|
||||
dispatch(ipAdapterRecalled(result.ipAdapter));
|
||||
}
|
||||
});
|
||||
|
||||
allParameterSetToast();
|
||||
},
|
||||
[
|
||||
@@ -669,6 +774,7 @@ export const useRecallParameters = () => {
|
||||
dispatch,
|
||||
prepareLoRAMetadataItem,
|
||||
prepareControlNetMetadataItem,
|
||||
prepareIPAdapterMetadataItem,
|
||||
]
|
||||
);
|
||||
|
||||
@@ -688,6 +794,7 @@ export const useRecallParameters = () => {
|
||||
recallStrength,
|
||||
recallLoRA,
|
||||
recallControlNet,
|
||||
recallIPAdapter,
|
||||
recallAllParameters,
|
||||
sendToImageToImage,
|
||||
};
|
||||
|
||||
@@ -343,6 +343,12 @@ export type IPAdapterModelParam = z.infer<typeof zIPAdapterModel>;
|
||||
/**
|
||||
* Zod schema for l2l strength parameter
|
||||
*/
|
||||
/**
|
||||
* Validates/type-guards a value as a model parameter
|
||||
*/
|
||||
export const isValidIPAdapterModel = (
|
||||
val: unknown
|
||||
): val is IPAdapterModelParam => zIPAdapterModel.safeParse(val).success;
|
||||
export const zStrength = z.number().min(0).max(1);
|
||||
/**
|
||||
* Type alias for l2l strength parameter, inferred from its zod schema
|
||||
|
||||
Reference in New Issue
Block a user