Use metadata ip adapter (#4715)

* add control net to useRecallParams

* got recall controlnets working

* fix metadata viewer controlnet

* fix type errors

* fix controlnet metadata viewer

* add ip adapter to metadata

* added ip adapter to recall parameters

* got ip adapter recall working, still need to fix type errors

* fix type issues

* clean up logs

* python formatting

* cleanup

* fix(ui): only store `image_name` as ip adapter image

* fix(ui): use nullish coalescing operator for numbers

Need to use the nullish coalescing operator `??` instead of false-y coalescing operator `||` when the value being check is a number. This prevents unintended coalescing when the value is zero and therefore false-y.

* feat(ui): fall back on default values for ip adapter metadata

* fix(ui): remove unused schema

* feat(ui): re-use existing schemas in metadata schema

* fix(ui): do not disable invocationCache

---------

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
chainchompa
2023-09-28 05:05:32 -04:00
committed by GitHub
parent 309e2414ce
commit c7f80cd163
21 changed files with 225 additions and 40 deletions

View File

@@ -412,8 +412,9 @@ export type IPAdapterModel = z.infer<typeof zIPAdapterModel>;
export const zIPAdapterField = z.object({
image: zImageField,
ip_adapter_model: zIPAdapterModel,
image_encoder_model: z.string().trim().min(1),
weight: z.number(),
begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(),
});
export type IPAdapterField = z.infer<typeof zIPAdapterField>;
@@ -1145,6 +1146,10 @@ const zControlNetMetadataItem = zControlField.deepPartial();
export type ControlNetMetadataItem = z.infer<typeof zControlNetMetadataItem>;
const zIPAdapterMetadataItem = zIPAdapterField.deepPartial();
export type IPAdapterMetadataItem = z.infer<typeof zIPAdapterMetadataItem>;
export const zCoreMetadata = z
.object({
app_version: z.string().nullish().catch(null),
@@ -1164,16 +1169,9 @@ export const zCoreMetadata = z
.union([zMainModel.deepPartial(), zOnnxModel.deepPartial()])
.nullish()
.catch(null),
controlnets: z.array(zControlField.deepPartial()).nullish().catch(null),
loras: z
.array(
z.object({
lora: zLoRAModelField.deepPartial(),
weight: z.number(),
})
)
.nullish()
.catch(null),
controlnets: z.array(zControlNetMetadataItem).nullish().catch(null),
ipAdapters: z.array(zIPAdapterMetadataItem).nullish().catch(null),
loras: z.array(zLoRAMetadataItem).nullish().catch(null),
vae: zVaeModelField.nullish().catch(null),
strength: z.number().nullish().catch(null),
init_image: z.string().nullish().catch(null),

View File

@@ -1,7 +1,14 @@
import { RootState } from 'app/store/store';
import { IPAdapterInvocation } from 'services/api/types';
import {
IPAdapterInvocation,
MetadataAccumulatorInvocation,
} from 'services/api/types';
import { NonNullableGraph } from '../../types/types';
import { CANVAS_COHERENCE_DENOISE_LATENTS, IP_ADAPTER } from './constants';
import {
CANVAS_COHERENCE_DENOISE_LATENTS,
IP_ADAPTER,
METADATA_ACCUMULATOR,
} from './constants';
export const addIPAdapterToLinearGraph = (
state: RootState,
@@ -10,9 +17,9 @@ export const addIPAdapterToLinearGraph = (
): void => {
const { isIPAdapterEnabled, ipAdapterInfo } = state.controlNet;
// const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
// | MetadataAccumulatorInvocation
// | undefined;
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (isIPAdapterEnabled && ipAdapterInfo.model) {
const ipAdapterNode: IPAdapterInvocation = {
@@ -30,23 +37,29 @@ export const addIPAdapterToLinearGraph = (
if (ipAdapterInfo.adapterImage) {
ipAdapterNode.image = {
image_name: ipAdapterInfo.adapterImage.image_name,
image_name: ipAdapterInfo.adapterImage,
};
} else {
return;
}
graph.nodes[ipAdapterNode.id] = ipAdapterNode as IPAdapterInvocation;
if (metadataAccumulator?.ipAdapters) {
const ipAdapterField = {
image: {
image_name: ipAdapterInfo.adapterImage,
},
ip_adapter_model: {
base_model: ipAdapterInfo.model?.base_model,
model_name: ipAdapterInfo.model?.model_name,
},
weight: ipAdapterInfo.weight,
begin_step_percent: ipAdapterInfo.beginStepPct,
end_step_percent: ipAdapterInfo.endStepPct,
};
// if (metadataAccumulator?.ip_adapters) {
// // metadata accumulator only needs the ip_adapter field - not the whole node
// // extract what we need and add to the accumulator
// const ipAdapterField = omit(ipAdapterNode, [
// 'id',
// 'type',
// ]) as IPAdapterField;
// metadataAccumulator.ip_adapters.push(ipAdapterField);
// }
metadataAccumulator.ipAdapters.push(ipAdapterField);
}
graph.edges.push({
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },

View File

@@ -327,6 +327,7 @@ export const buildCanvasImageToImageGraph = (
vae: undefined, // option; set in addVAEToGraph
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
ipAdapters: [], // populated in addIPAdapterToLinearGraph
clip_skip: clipSkip,
strength,
init_image: initialImage.image_name,

View File

@@ -338,6 +338,7 @@ export const buildCanvasSDXLImageToImageGraph = (
vae: undefined, // option; set in addVAEToGraph
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
ipAdapters: [], // populated in addIPAdapterToLinearGraph
strength,
init_image: initialImage.image_name,
};

View File

@@ -320,6 +320,7 @@ export const buildCanvasSDXLTextToImageGraph = (
vae: undefined, // option; set in addVAEToGraph
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
ipAdapters: [], // populated in addIPAdapterToLinearGraph
};
graph.edges.push({

View File

@@ -308,6 +308,7 @@ export const buildCanvasTextToImageGraph = (
vae: undefined, // option; set in addVAEToGraph
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
ipAdapters: [], // populated in addIPAdapterToLinearGraph
clip_skip: clipSkip,
};

View File

@@ -328,6 +328,7 @@ export const buildLinearImageToImageGraph = (
vae: undefined, // option; set in addVAEToGraph
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
ipAdapters: [], // populated in addIPAdapterToLinearGraph
clip_skip: clipSkip,
strength,
init_image: initialImage.imageName,

View File

@@ -348,6 +348,7 @@ export const buildLinearSDXLImageToImageGraph = (
vae: undefined,
controlnets: [],
loras: [],
ipAdapters: [],
strength: strength,
init_image: initialImage.imageName,
positive_style_prompt: positiveStylePrompt,

View File

@@ -242,6 +242,7 @@ export const buildLinearSDXLTextToImageGraph = (
vae: undefined,
controlnets: [],
loras: [],
ipAdapters: [],
positive_style_prompt: positiveStylePrompt,
negative_style_prompt: negativeStylePrompt,
};

View File

@@ -250,6 +250,7 @@ export const buildLinearTextToImageGraph = (
vae: undefined, // option; set in addVAEToGraph
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
ipAdapters: [], // populated in addIPAdapterToLinearGraph
clip_skip: clipSkip,
};