Use metadata ip adapter (#4715)

* add control net to useRecallParams

* got recall controlnets working

* fix metadata viewer controlnet

* fix type errors

* fix controlnet metadata viewer

* add ip adapter to metadata

* added ip adapter to recall parameters

* got ip adapter recall working, still need to fix type errors

* fix type issues

* clean up logs

* python formatting

* cleanup

* fix(ui): only store `image_name` as ip adapter image

* fix(ui): use nullish coalescing operator for numbers

Need to use the nullish coalescing operator `??` instead of false-y coalescing operator `||` when the value being check is a number. This prevents unintended coalescing when the value is zero and therefore false-y.

* feat(ui): fall back on default values for ip adapter metadata

* fix(ui): remove unused schema

* feat(ui): re-use existing schemas in metadata schema

* fix(ui): do not disable invocationCache

---------

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
chainchompa
2023-09-28 05:05:32 -04:00
committed by GitHub
parent 309e2414ce
commit c7f80cd163
21 changed files with 225 additions and 40 deletions

View File

@@ -6,6 +6,7 @@ import {
CoreMetadata,
LoRAMetadataItem,
ControlNetMetadataItem,
IPAdapterMetadataItem,
} from 'features/nodes/types/types';
import {
refinerModelChanged,
@@ -23,16 +24,22 @@ import { useTranslation } from 'react-i18next';
import { ImageDTO } from 'services/api/types';
import {
controlNetModelsAdapter,
ipAdapterModelsAdapter,
useGetIPAdapterModelsQuery,
loraModelsAdapter,
useGetControlNetModelsQuery,
useGetLoRAModelsQuery,
} from '../../../services/api/endpoints/models';
import {
ControlNetConfig,
IPAdapterConfig,
controlNetEnabled,
controlNetRecalled,
controlNetReset,
initialControlNet,
initialIPAdapterState,
ipAdapterRecalled,
isIPAdapterEnabledChanged,
} from '../../controlNet/store/controlNetSlice';
import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice';
import { initialImageSelected, modelSelected } from '../store/actions';
@@ -52,6 +59,7 @@ import {
isValidHeight,
isValidLoRAModel,
isValidControlNetModel,
isValidIPAdapterModel,
isValidMainModel,
isValidNegativePrompt,
isValidPositivePrompt,
@@ -512,8 +520,6 @@ export const useRecallParameters = () => {
})
);
dispatch(controlNetEnabled());
parameterSetToast();
},
[
@@ -524,6 +530,92 @@ export const useRecallParameters = () => {
]
);
/**
* Recall IP Adapter with toast
*/
const { ipAdapters } = useGetIPAdapterModelsQuery(undefined, {
selectFromResult: (result) => ({
ipAdapters: result.data
? ipAdapterModelsAdapter.getSelectors().selectAll(result.data)
: [],
}),
});
const prepareIPAdapterMetadataItem = useCallback(
(ipAdapterMetadataItem: IPAdapterMetadataItem) => {
if (!isValidIPAdapterModel(ipAdapterMetadataItem?.ip_adapter_model)) {
return { ipAdapter: null, error: 'Invalid IP Adapter model' };
}
const {
image,
ip_adapter_model,
weight,
begin_step_percent,
end_step_percent,
} = ipAdapterMetadataItem;
const matchingIPAdapterModel = ipAdapters.find(
(c) =>
c.base_model === ip_adapter_model?.base_model &&
c.model_name === ip_adapter_model?.model_name
);
if (!matchingIPAdapterModel) {
return { ipAdapter: null, error: 'IP Adapter model is not installed' };
}
const isCompatibleBaseModel =
matchingIPAdapterModel?.base_model === model?.base_model;
if (!isCompatibleBaseModel) {
return {
ipAdapter: null,
error: 'IP Adapter incompatible with currently-selected model',
};
}
const ipAdapter: IPAdapterConfig = {
adapterImage: image?.image_name ?? null,
model: matchingIPAdapterModel,
weight: weight ?? initialIPAdapterState.weight,
beginStepPct: begin_step_percent ?? initialIPAdapterState.beginStepPct,
endStepPct: end_step_percent ?? initialIPAdapterState.endStepPct,
};
return { ipAdapter, error: null };
},
[ipAdapters, model?.base_model]
);
const recallIPAdapter = useCallback(
(ipAdapterMetadataItem: IPAdapterMetadataItem) => {
const result = prepareIPAdapterMetadataItem(ipAdapterMetadataItem);
if (!result.ipAdapter) {
parameterNotSetToast(result.error);
return;
}
dispatch(
ipAdapterRecalled({
...result.ipAdapter,
})
);
dispatch(isIPAdapterEnabledChanged(true));
parameterSetToast();
},
[
prepareIPAdapterMetadataItem,
dispatch,
parameterSetToast,
parameterNotSetToast,
]
);
/*
* Sets image as initial image with toast
*/
@@ -563,6 +655,7 @@ export const useRecallParameters = () => {
refiner_start,
loras,
controlnets,
ipAdapters,
} = metadata;
if (isValidCfgScale(cfg_scale)) {
@@ -653,7 +746,9 @@ export const useRecallParameters = () => {
});
dispatch(controlNetReset());
dispatch(controlNetEnabled());
if (controlnets?.length) {
dispatch(controlNetEnabled());
}
controlnets?.forEach((controlnet) => {
const result = prepareControlNetMetadataItem(controlnet);
if (result.controlnet) {
@@ -661,6 +756,16 @@ export const useRecallParameters = () => {
}
});
if (ipAdapters?.length) {
dispatch(isIPAdapterEnabledChanged(true));
}
ipAdapters?.forEach((ipAdapter) => {
const result = prepareIPAdapterMetadataItem(ipAdapter);
if (result.ipAdapter) {
dispatch(ipAdapterRecalled(result.ipAdapter));
}
});
allParameterSetToast();
},
[
@@ -669,6 +774,7 @@ export const useRecallParameters = () => {
dispatch,
prepareLoRAMetadataItem,
prepareControlNetMetadataItem,
prepareIPAdapterMetadataItem,
]
);
@@ -688,6 +794,7 @@ export const useRecallParameters = () => {
recallStrength,
recallLoRA,
recallControlNet,
recallIPAdapter,
recallAllParameters,
sendToImageToImage,
};

View File

@@ -343,6 +343,12 @@ export type IPAdapterModelParam = z.infer<typeof zIPAdapterModel>;
/**
* Zod schema for l2l strength parameter
*/
/**
* Validates/type-guards a value as a model parameter
*/
export const isValidIPAdapterModel = (
val: unknown
): val is IPAdapterModelParam => zIPAdapterModel.safeParse(val).success;
export const zStrength = z.number().min(0).max(1);
/**
* Type alias for l2l strength parameter, inferred from its zod schema