Files
InvokeAI/invokeai/app/invocations/metadata_linked.py
skunkworxdark 566282bff0 Update metadata_linked.py
added metadata_to_string_collection, metadata_to_integer_collection, metadata_to_float_collection, metadata_to_bool_collection
2025-04-16 06:28:22 +10:00

1305 lines
44 KiB
Python

# Adopted from @skunkworxdark's metadata nodes (MIT License)
# https://github.com/skunkworxdark/metadata-linked-nodes
# Thanks to @skunkworxdark for the original implementation!
import copy
from typing import Any, Dict, Literal, Optional, TypeVar, Union
from pydantic import model_validator
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.controlnet import ControlField, ControlNetInvocation
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
Input,
InputField,
MetadataField,
OutputField,
UIType,
WithMetadata,
)
from invokeai.app.invocations.flux_denoise import FluxDenoiseInvocation
from invokeai.app.invocations.ip_adapter import IPAdapterField, IPAdapterInvocation
from invokeai.app.invocations.metadata import LoRAMetadataField, MetadataOutput
from invokeai.app.invocations.model import (
CLIPField,
LoRAField,
LoRALoaderOutput,
ModelIdentifierField,
SDXLLoRALoaderOutput,
UNetField,
VAEField,
VAEOutput,
)
from invokeai.app.invocations.primitives import (
BooleanCollectionOutput,
BooleanOutput,
FloatCollectionOutput,
FloatOutput,
IntegerCollectionOutput,
IntegerOutput,
LatentsOutput,
StringCollectionOutput,
StringOutput,
)
from invokeai.app.invocations.scheduler import SchedulerOutput
from invokeai.app.invocations.t2i_adapter import T2IAdapterField, T2IAdapterInvocation
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import ModelType, SubModelType
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.version import __version__
# Sentinel entry shared by every label list below; the nodes in this file compare
# their `label` field against it to decide whether `custom_label` is used instead.
CUSTOM_LABEL: str = "* CUSTOM LABEL *"

# Every label offered by the generic "Metadata Item Linked" node.
CORE_LABELS = Literal[
    f"{CUSTOM_LABEL}",
    "positive_prompt",
    "positive_style_prompt",
    "negative_prompt",
    "negative_style_prompt",
    "width",
    "height",
    "seed",
    "cfg_scale",
    "cfg_rescale_multiplier",
    "steps",
    "scheduler",
    "clip_skip",
    "model",
    "vae",
    "seamless_x",
    "seamless_y",
    "guidance",
    "cfg_scale_start_step",
    "cfg_scale_end_step",
]

# Labels offered by the string extraction nodes.
CORE_LABELS_STRING = Literal[
    f"{CUSTOM_LABEL}",
    "positive_prompt",
    "positive_style_prompt",
    "negative_prompt",
    "negative_style_prompt",
]

# Labels offered by the integer extraction node.
CORE_LABELS_INTEGER = Literal[
    f"{CUSTOM_LABEL}",
    "width",
    "height",
    "seed",
    "steps",
    "clip_skip",
    "cfg_scale_start_step",
    "cfg_scale_end_step",
]

# Labels offered by the float extraction node.
CORE_LABELS_FLOAT = Literal[
    f"{CUSTOM_LABEL}",
    "cfg_scale",
    "cfg_rescale_multiplier",
    "guidance",
]

# Labels offered by the boolean extraction node.
CORE_LABELS_BOOL = Literal[
    f"{CUSTOM_LABEL}",
    "seamless_x",
    "seamless_y",
]

# Labels offered by the scheduler extraction node.
CORE_LABELS_SCHEDULER = Literal[
    f"{CUSTOM_LABEL}",
    "scheduler",
]

# Labels offered by the main-model / SDXL-model extraction nodes.
CORE_LABELS_MODEL = Literal[
    f"{CUSTOM_LABEL}",
    "model",
]

# Labels offered by the VAE extraction node.
CORE_LABELS_VAE = Literal[
    f"{CUSTOM_LABEL}",
    "vae",
]
T = TypeVar("T")


def append_list(item_cls: type[T], new_item: T, items: Union[T, list[T], None] = None) -> list[T]:
    """Return a new list holding the existing ``items`` followed by ``new_item``.

    Args:
        item_cls: The type every element must be an instance of.
        new_item: The element placed at the end of the returned list.
        items: Nothing, one existing element, or a list of existing elements. (default=None)

    Returns:
        A freshly built list; the ``items`` argument itself is not modified.

    Raises:
        ValueError: If ``new_item``, ``items``, or any member of ``items`` is not
            an instance of ``item_cls``.
    """
    if not isinstance(new_item, item_cls):
        raise ValueError(f"Invalid new_item type in: {new_item}, expected {item_cls}")

    if items is None:
        return [new_item]

    # A lone element is tested before the list case, so it is wrapped rather than iterated.
    if isinstance(items, item_cls):
        return [items, new_item]

    if not isinstance(items, list) or any(not isinstance(entry, item_cls) for entry in items):
        raise ValueError(f"Invalid items type in: {items}, expected {item_cls}")

    return [*items, new_item]
def validate_custom_label(
    model: Union[
        "MetadataItemLinkedInvocation",
        "MetadataToStringInvocation",
        "MetadataToIntegerInvocation",
        "MetadataToFloatInvocation",
        "MetadataToBoolInvocation",
        "MetadataToSchedulerInvocation",
        "MetadataToModelInvocation",
        "MetadataToSDXLModelInvocation",
        "MetadataToVAEInvocation",
    ],
):
    """Reject a node that selected the custom-label entry without naming a label.

    The nodes in this file register this function as a pydantic "after" model validator.

    Args:
        model: The node instance; its ``label`` and ``custom_label`` attributes are read.

    Returns:
        The same ``model`` object, unchanged.

    Raises:
        ValueError: If ``label`` equals ``CUSTOM_LABEL`` and ``custom_label`` is
            ``None`` or contains only whitespace.
    """
    # Any predefined label needs no further checking.
    if model.label != CUSTOM_LABEL:
        return model

    custom = model.custom_label
    if custom is None or custom.strip() == "":
        raise ValueError("You must enter a Custom Label")
    return model
def extract_model_key(
    metadata: dict[str, Any],
    label: Union[str, None],
    default_key: str,
    model_type: ModelType,
    context: InvocationContext,
) -> str:
    """Resolve the model key stored under ``label`` in ``metadata``.

    Lookup order for the entry found at ``metadata[label]``:

    1. its ``"key"`` value, when ``context.models.exists`` accepts it;
    2. the first model returned by ``context.models.search_by_attrs`` for its ``"name"`` value;
    3. the same search for its ``"model_name"`` value.

    Args:
        metadata: The metadata root dictionary (or one nested entry of it).
        label: The label to look up.
        default_key: The key returned when nothing can be resolved.
        model_type: The model type passed to the name search.
        context: The invocation context giving access to ``context.models``.

    Returns:
        The resolved model key, or ``default_key``.
    """
    # Entries that are not dictionaries (e.g. a plain string or None) cannot carry
    # key/name fields; fall back to the default instead of raising TypeError.
    entry = metadata.get(label) if isinstance(metadata, dict) else None
    if not isinstance(entry, dict):
        return default_key

    if "key" in entry and context.models.exists(entry["key"]):
        return entry["key"]

    # "name" is tried before "model_name"; the first search with a hit wins.
    for name_field in ("name", "model_name"):
        if name_field not in entry:
            continue
        matches = context.models.search_by_attrs(name=entry[name_field], type=model_type)
        if len(matches) > 0:
            return matches[0].key

    return default_key
def get_model(
    model_key: str,
    context: InvocationContext,
) -> ModelIdentifierField:
    """Build a ``ModelIdentifierField`` for the model registered under ``model_key``.

    Args:
        model_key: The model key to look up.
        context: The invocation context giving access to ``context.models``.

    Returns:
        The identifier created with ``ModelIdentifierField.from_config`` from the model's config.

    Raises:
        Exception: If ``context.models.exists`` reports the key as unknown.
    """
    models = context.models
    if models.exists(model_key):
        return ModelIdentifierField.from_config(models.get_config(model_key))
    raise Exception(f"Unknown model: {model_key}")
@invocation(
    "metadata_item_linked",
    title="Metadata Item Linked",
    tags=["metadata"],
    category="metadata",
    version="1.0.1",
    classification=Classification.Beta,
)
class MetadataItemLinkedInvocation(BaseInvocation, WithMetadata):
    """Used to Create/Add/Update a value into a metadata label"""

    # Predefined label, or CUSTOM_LABEL to use `custom_label` instead.
    label: CORE_LABELS = InputField(
        default=CUSTOM_LABEL,
        description=FieldDescriptions.metadata_item_label,
        input=Input.Direct,
    )
    # Only read when `label` is CUSTOM_LABEL; the validator below then requires it.
    custom_label: Optional[str] = InputField(
        default=None,
        description=FieldDescriptions.metadata_item_label,
        input=Input.Direct,
    )
    value: Any = InputField(description=FieldDescriptions.metadata_item_value, ui_type=UIType.Any)

    _validate_custom_label = model_validator(mode="after")(validate_custom_label)

    def invoke(self, context: InvocationContext) -> MetadataOutput:
        """Return the incoming metadata with the labelled value and app version set."""
        key = self.custom_label if self.label == CUSTOM_LABEL else self.label
        # A VAEField is stored as the model identifier it wraps.
        value = self.value.vae if isinstance(self.value, VAEField) else self.value

        # Copy the incoming dict so this node's input metadata is not modified in place.
        data: Dict[str, Any] = {} if self.metadata is None else dict(self.metadata.root)
        data[str(key)] = value
        # Written last, so it wins even when the chosen label is "app_version".
        data["app_version"] = __version__

        return MetadataOutput(metadata=MetadataField.model_validate(data))
@invocation(
    "metadata_from_image",
    title="Metadata From Image",
    tags=["metadata"],
    category="metadata",
    version="1.0.1",
    classification=Classification.Beta,
)
class MetadataFromImageInvocation(BaseInvocation):
    """Used to load the metadata stored with an image and output it as metadata"""

    image: ImageField = InputField(description=FieldDescriptions.image)

    def invoke(self, context: InvocationContext) -> MetadataOutput:
        """Return the image's metadata, or empty metadata when the image has none."""
        image_metadata = context.images.get_metadata(self.image.image_name)
        # Copy the entries so the output is built from a separate dictionary.
        data: Dict[str, Any] = {} if image_metadata is None else dict(image_metadata.root)
        return MetadataOutput(metadata=MetadataField.model_validate(data))
@invocation(
    "metadata_to_string",
    title="Metadata To String",
    tags=["metadata"],
    category="metadata",
    version="1.0.0",
    classification=Classification.Beta,
)
class MetadataToStringInvocation(BaseInvocation, WithMetadata):
    """Extracts a string value of a label from metadata"""

    # Predefined label, or CUSTOM_LABEL to use `custom_label` instead.
    label: CORE_LABELS_STRING = InputField(
        description=FieldDescriptions.metadata_item_label,
        default=CUSTOM_LABEL,
        input=Input.Direct,
    )
    custom_label: Optional[str] = InputField(
        description=FieldDescriptions.metadata_item_label,
        default=None,
        input=Input.Direct,
    )
    default_value: str = InputField(description="The default string to use if not found in the metadata")

    _validate_custom_label = model_validator(mode="after")(validate_custom_label)

    def invoke(self, context: InvocationContext) -> StringOutput:
        """Look the label up in the metadata and return it converted with ``str``."""
        lookup_key = str(self.custom_label if self.label == CUSTOM_LABEL else self.label)
        if self.metadata is None:
            found = self.default_value
        else:
            found = self.metadata.root.get(lookup_key, self.default_value)
        return StringOutput(value=str(found))
@invocation(
    "metadata_to_integer",
    title="Metadata To Integer",
    tags=["metadata"],
    category="metadata",
    version="1.0.0",
    classification=Classification.Beta,
)
class MetadataToIntegerInvocation(BaseInvocation, WithMetadata):
    """Extracts an integer value of a label from metadata"""

    # Predefined label, or CUSTOM_LABEL to use `custom_label` instead.
    label: CORE_LABELS_INTEGER = InputField(
        description=FieldDescriptions.metadata_item_label,
        default=CUSTOM_LABEL,
        input=Input.Direct,
    )
    custom_label: Optional[str] = InputField(
        description=FieldDescriptions.metadata_item_label,
        default=None,
        input=Input.Direct,
    )
    default_value: int = InputField(description="The default integer to use if not found in the metadata")

    _validate_custom_label = model_validator(mode="after")(validate_custom_label)

    def invoke(self, context: InvocationContext) -> IntegerOutput:
        """Look the label up in the metadata and return it converted with ``int``."""
        lookup_key = str(self.custom_label if self.label == CUSTOM_LABEL else self.label)
        if self.metadata is None:
            found = self.default_value
        else:
            found = self.metadata.root.get(lookup_key, self.default_value)
        return IntegerOutput(value=int(found))
@invocation(
    "metadata_to_float",
    title="Metadata To Float",
    tags=["metadata"],
    category="metadata",
    version="1.1.0",
    classification=Classification.Beta,
)
class MetadataToFloatInvocation(BaseInvocation, WithMetadata):
    """Extracts a Float value of a label from metadata"""

    # Predefined label, or CUSTOM_LABEL to use `custom_label` instead.
    label: CORE_LABELS_FLOAT = InputField(
        description=FieldDescriptions.metadata_item_label,
        default=CUSTOM_LABEL,
        input=Input.Direct,
    )
    custom_label: Optional[str] = InputField(
        description=FieldDescriptions.metadata_item_label,
        default=None,
        input=Input.Direct,
    )
    default_value: float = InputField(description="The default float to use if not found in the metadata")

    _validate_custom_label = model_validator(mode="after")(validate_custom_label)

    def invoke(self, context: InvocationContext) -> FloatOutput:
        """Look the label up in the metadata and return it converted with ``float``."""
        lookup_key = str(self.custom_label if self.label == CUSTOM_LABEL else self.label)
        if self.metadata is None:
            found = self.default_value
        else:
            found = self.metadata.root.get(lookup_key, self.default_value)
        return FloatOutput(value=float(found))
@invocation(
    "metadata_to_bool",
    title="Metadata To Bool",
    tags=["metadata"],
    category="metadata",
    version="1.0.0",
    classification=Classification.Beta,
)
class MetadataToBoolInvocation(BaseInvocation, WithMetadata):
    """Extracts a Boolean value of a label from metadata"""

    # Predefined label, or CUSTOM_LABEL to use `custom_label` instead.
    label: CORE_LABELS_BOOL = InputField(
        description=FieldDescriptions.metadata_item_label,
        default=CUSTOM_LABEL,
        input=Input.Direct,
    )
    custom_label: Optional[str] = InputField(
        description=FieldDescriptions.metadata_item_label,
        default=None,
        input=Input.Direct,
    )
    default_value: bool = InputField(description="The default bool to use if not found in the metadata")

    _validate_custom_label = model_validator(mode="after")(validate_custom_label)

    def invoke(self, context: InvocationContext) -> BooleanOutput:
        """Look the label up in the metadata and return its truthiness via ``bool``."""
        lookup_key = str(self.custom_label if self.label == CUSTOM_LABEL else self.label)
        if self.metadata is None:
            found = self.default_value
        else:
            found = self.metadata.root.get(lookup_key, self.default_value)
        # NOTE(review): bool() of any non-empty string (including "false") is True.
        return BooleanOutput(value=bool(found))
@invocation(
    "metadata_to_scheduler",
    title="Metadata To Scheduler",
    tags=["metadata"],
    category="metadata",
    version="1.0.1",
    classification=Classification.Beta,
)
class MetadataToSchedulerInvocation(BaseInvocation, WithMetadata):
    """Extracts a Scheduler value of a label from metadata"""

    # Defaults to the predefined "scheduler" label rather than the custom entry.
    label: CORE_LABELS_SCHEDULER = InputField(
        description=FieldDescriptions.metadata_item_label,
        default="scheduler",
        input=Input.Direct,
    )
    custom_label: Optional[str] = InputField(
        description=FieldDescriptions.metadata_item_label,
        default=None,
        input=Input.Direct,
    )
    default_value: SCHEDULER_NAME_VALUES = InputField(
        description="The default scheduler to use if not found in the metadata",
        default="euler",
        ui_type=UIType.Scheduler,
    )

    _validate_custom_label = model_validator(mode="after")(validate_custom_label)

    def invoke(self, context: InvocationContext) -> SchedulerOutput:
        """Look the label up in the metadata and pass the value to ``SchedulerOutput`` as-is."""
        lookup_key = str(self.custom_label if self.label == CUSTOM_LABEL else self.label)
        if self.metadata is None:
            found = self.default_value
        else:
            found = self.metadata.root.get(lookup_key, self.default_value)
        return SchedulerOutput(scheduler=found)
@invocation_output("metadata_to_model_output")
class MetadataToModelOutput(BaseInvocationOutput):
    """String to main model output"""

    # Identifier of the resolved main model.
    model: ModelIdentifierField = OutputField(
        title="Model",
        description=FieldDescriptions.main_model,
        ui_type=UIType.MainModel,
    )
    # Display name of the model.
    name: str = OutputField(title="Name", description="Model Name")
    # Sub-model handles derived from `model`.
    unet: UNetField = OutputField(title="UNet", description=FieldDescriptions.unet)
    vae: VAEField = OutputField(title="VAE", description=FieldDescriptions.vae)
    clip: CLIPField = OutputField(title="CLIP", description=FieldDescriptions.clip)
@invocation_output("metadata_to_sdxl_model_output")
class MetadataToSDXLModelOutput(BaseInvocationOutput):
    """String to SDXL main model output"""

    # Identifier of the resolved SDXL main model.
    model: ModelIdentifierField = OutputField(
        title="Model",
        description=FieldDescriptions.main_model,
        ui_type=UIType.SDXLMainModel,
    )
    # Display name of the model.
    name: str = OutputField(title="Name", description="Model Name")
    # Sub-model handles derived from `model`; SDXL exposes two CLIP outputs.
    unet: UNetField = OutputField(title="UNet", description=FieldDescriptions.unet)
    clip: CLIPField = OutputField(title="CLIP 1", description=FieldDescriptions.clip)
    clip2: CLIPField = OutputField(title="CLIP 2", description=FieldDescriptions.clip)
    vae: VAEField = OutputField(title="VAE", description=FieldDescriptions.vae)
@invocation(
    "metadata_to_model",
    title="Metadata To Model",
    tags=["metadata"],
    category="metadata",
    version="1.3.0",
    classification=Classification.Beta,
)
class MetadataToModelInvocation(BaseInvocation, WithMetadata):
    """Extracts a Model value of a label from metadata"""

    # Defaults to the predefined "model" label rather than the custom entry.
    label: CORE_LABELS_MODEL = InputField(
        description=FieldDescriptions.metadata_item_label,
        default="model",
        input=Input.Direct,
    )
    custom_label: Optional[str] = InputField(
        description=FieldDescriptions.metadata_item_label,
        default=None,
        input=Input.Direct,
    )
    default_value: ModelIdentifierField = InputField(
        description="The default model to use if not found in the metadata",
        ui_type=UIType.MainModel,
    )

    _validate_custom_label = model_validator(mode="after")(validate_custom_label)

    def invoke(self, context: InvocationContext) -> MetadataToModelOutput:
        """Resolve the main model named in the metadata and expose its sub-models."""
        metadata_root = {} if self.metadata is None else self.metadata.root
        chosen_label = self.custom_label if self.label == CUSTOM_LABEL else self.label
        # Falls back to the key of `default_value` when the metadata yields no known model.
        resolved_key = extract_model_key(
            metadata_root, chosen_label, self.default_value.key, ModelType.Main, context
        )
        model = get_model(resolved_key, context)

        def part(submodel_type: SubModelType) -> ModelIdentifierField:
            # Each sub-model is a copy of the main identifier with its submodel_type set.
            return model.model_copy(update={"submodel_type": submodel_type})

        unet_field = UNetField(
            unet=part(SubModelType.UNet),
            scheduler=part(SubModelType.Scheduler),
            loras=[],
        )
        clip_field = CLIPField(
            tokenizer=part(SubModelType.Tokenizer),
            text_encoder=part(SubModelType.TextEncoder),
            loras=[],
            skipped_layers=0,
        )
        vae_field = VAEField(vae=part(SubModelType.VAE))

        return MetadataToModelOutput(
            model=model,
            name=f"{model.base}: {model.name}",
            unet=unet_field,
            clip=clip_field,
            vae=vae_field,
        )
@invocation(
    "metadata_to_sdxl_model",
    title="Metadata To SDXL Model",
    tags=["metadata"],
    category="metadata",
    version="1.3.0",
    classification=Classification.Beta,
)
class MetadataToSDXLModelInvocation(BaseInvocation, WithMetadata):
    """Extracts a SDXL Model value of a label from metadata"""

    # Defaults to the predefined "model" label rather than the custom entry.
    label: CORE_LABELS_MODEL = InputField(
        description=FieldDescriptions.metadata_item_label,
        default="model",
        input=Input.Direct,
    )
    custom_label: Optional[str] = InputField(
        description=FieldDescriptions.metadata_item_label,
        default=None,
        input=Input.Direct,
    )
    default_value: ModelIdentifierField = InputField(
        description="The default SDXL Model to use if not found in the metadata",
        ui_type=UIType.SDXLMainModel,
    )

    _validate_custom_label = model_validator(mode="after")(validate_custom_label)

    def invoke(self, context: InvocationContext) -> MetadataToSDXLModelOutput:
        """Resolve the SDXL main model named in the metadata and expose its sub-models."""
        metadata_root = {} if self.metadata is None else self.metadata.root
        chosen_label = self.custom_label if self.label == CUSTOM_LABEL else self.label
        # Falls back to the key of `default_value` when the metadata yields no known model.
        resolved_key = extract_model_key(
            metadata_root, chosen_label, self.default_value.key, ModelType.Main, context
        )
        model = get_model(resolved_key, context)

        def part(submodel_type: SubModelType) -> ModelIdentifierField:
            # Each sub-model is a copy of the main identifier with its submodel_type set.
            return model.model_copy(update={"submodel_type": submodel_type})

        unet_field = UNetField(
            unet=part(SubModelType.UNet),
            scheduler=part(SubModelType.Scheduler),
            loras=[],
        )
        clip_field = CLIPField(
            tokenizer=part(SubModelType.Tokenizer),
            text_encoder=part(SubModelType.TextEncoder),
            loras=[],
            skipped_layers=0,
        )
        # The second CLIP output uses the "2" tokenizer / text-encoder sub-models.
        clip2_field = CLIPField(
            tokenizer=part(SubModelType.Tokenizer2),
            text_encoder=part(SubModelType.TextEncoder2),
            loras=[],
            skipped_layers=0,
        )
        vae_field = VAEField(vae=part(SubModelType.VAE))

        return MetadataToSDXLModelOutput(
            model=model,
            name=f"{model.base}: {model.name}",
            unet=unet_field,
            clip=clip_field,
            clip2=clip2_field,
            vae=vae_field,
        )
# Output type returned by the "+ Metadata" denoise nodes in this file. It declares no
# fields of its own; everything comes from the LatentsOutput and MetadataOutput bases.
@invocation_output("latents_meta_output")
class LatentsMetaOutput(LatentsOutput, MetadataOutput):
    """Latents + metadata"""
@invocation(
    "denoise_latents_meta",
    title=f"{DenoiseLatentsInvocation.UIConfig.title} + Metadata",
    tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
    category="latents",
    version="1.1.1",
)
class DenoiseLatentsMetaInvocation(DenoiseLatentsInvocation, WithMetadata):
    # Runs the inherited denoise step, then returns its result together with a metadata
    # dictionary describing the settings this node was run with.
    def invoke(self, context: InvocationContext) -> LatentsMetaOutput:
        # Node bookkeeping attributes that are left out of the dumped metadata.
        dump_exclude = {"id", "type", "is_intermediate", "use_cache"}

        def _to_json(obj: Union[Any, list[Any]]):
            # Accept one field or a list of fields; always return a list of dumps.
            items = obj if isinstance(obj, list) else [obj]
            return [item.model_dump(exclude_none=True, exclude=dump_exclude) for item in items]

        def _loras_to_json(obj: Union[Any, list[Any]]):
            # LoRAs are re-wrapped as LoRAMetadataField (model + weight) before dumping.
            items = obj if isinstance(obj, list) else [obj]
            return [
                LoRAMetadataField(
                    model=item.lora,
                    weight=item.weight,
                ).model_dump(exclude_none=True, exclude=dump_exclude)
                for item in items
            ]

        obj = super().invoke(context)

        # Copy the incoming dict so this node's input metadata is not modified in place.
        md: Dict[str, Any] = {} if self.metadata is None else dict(self.metadata.root)
        md.update(
            {
                "width": obj.width,
                "height": obj.height,
                "steps": self.steps,
                "cfg_scale": self.cfg_scale,
                "cfg_rescale_multiplier": self.cfg_rescale_multiplier,
                "denoising_start": self.denoising_start,
                "denoising_end": self.denoising_end,
                "scheduler": self.scheduler,
                "model": self.unet.unet,
            }
        )

        # Adapter sections are only written for a single field or a non-empty list.
        if isinstance(self.control, ControlField) or (isinstance(self.control, list) and len(self.control) > 0):
            md["controlnets"] = _to_json(self.control)
        if isinstance(self.ip_adapter, IPAdapterField) or (
            isinstance(self.ip_adapter, list) and len(self.ip_adapter) > 0
        ):
            md["ipAdapters"] = _to_json(self.ip_adapter)
        if isinstance(self.t2i_adapter, T2IAdapterField) or (
            isinstance(self.t2i_adapter, list) and len(self.t2i_adapter) > 0
        ):
            md["t2iAdapters"] = _to_json(self.t2i_adapter)
        if len(self.unet.loras) > 0:
            md["loras"] = _loras_to_json(self.unet.loras)
        if self.noise is not None:
            md["seed"] = self.noise.seed

        # Re-use every attribute of the parent's output except "type", which is removed
        # before the values are passed on to LatentsMetaOutput.
        params = obj.__dict__.copy()
        del params["type"]
        return LatentsMetaOutput(**params, metadata=MetadataField.model_validate(md))
@invocation(
    "flux_denoise_meta",
    title=f"{FluxDenoiseInvocation.UIConfig.title} + Metadata",
    tags=["flux", "latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
    category="latents",
    version="1.0.1",
)
class FluxDenoiseLatentsMetaInvocation(FluxDenoiseInvocation, WithMetadata):
    """Run denoising process with a FLUX transformer model + metadata."""

    def invoke(self, context: InvocationContext) -> LatentsMetaOutput:
        """Run the inherited denoise step and return its result plus updated metadata."""

        def _loras_to_json(obj: Union[Any, list[Any]]):
            # LoRAs are re-wrapped as LoRAMetadataField (model + weight) before dumping;
            # node bookkeeping attributes are excluded from the dump.
            items = obj if isinstance(obj, list) else [obj]
            return [
                LoRAMetadataField(
                    model=item.lora,
                    weight=item.weight,
                ).model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
                for item in items
            ]

        obj = super().invoke(context)

        # Copy the incoming dict so this node's input metadata is not modified in place.
        md: Dict[str, Any] = {} if self.metadata is None else dict(self.metadata.root)
        md.update(
            {
                "width": obj.width,
                "height": obj.height,
                "steps": self.num_steps,
                "guidance": self.guidance,
                "denoising_start": self.denoising_start,
                "denoising_end": self.denoising_end,
                "model": self.transformer.transformer,
                "seed": self.seed,
                "cfg_scale": self.cfg_scale,
                "cfg_scale_start_step": self.cfg_scale_start_step,
                "cfg_scale_end_step": self.cfg_scale_end_step,
            }
        )
        if len(self.transformer.loras) > 0:
            md["loras"] = _loras_to_json(self.transformer.loras)

        # Re-use every attribute of the parent's output except "type", which is removed
        # before the values are passed on to LatentsMetaOutput.
        params = obj.__dict__.copy()
        del params["type"]
        return LatentsMetaOutput(**params, metadata=MetadataField.model_validate(md))
@invocation(
    "metadata_to_vae",
    title="Metadata To VAE",
    tags=["metadata"],
    category="metadata",
    version="1.2.1",
    classification=Classification.Beta,
)
class MetadataToVAEInvocation(BaseInvocation, WithMetadata):
    """Extracts a VAE value of a label from metadata"""

    # Defaults to the predefined "vae" label rather than the custom entry.
    label: CORE_LABELS_VAE = InputField(
        description=FieldDescriptions.metadata_item_label,
        default="vae",
        input=Input.Direct,
    )
    custom_label: Optional[str] = InputField(
        description=FieldDescriptions.metadata_item_label,
        default=None,
        input=Input.Direct,
    )
    default_value: VAEField = InputField(
        description="The default VAE to use if not found in the metadata",
    )

    _validate_custom_label = model_validator(mode="after")(validate_custom_label)

    def invoke(self, context: InvocationContext) -> VAEOutput:
        """Resolve the VAE model named in the metadata and wrap it in a ``VAEOutput``."""
        metadata_root = {} if self.metadata is None else self.metadata.root
        chosen_label = self.custom_label if self.label == CUSTOM_LABEL else self.label
        # Falls back to the key of the default VAE when the metadata yields no known model.
        resolved_key = extract_model_key(
            metadata_root, chosen_label, self.default_value.vae.key, ModelType.VAE, context
        )
        vae_model = get_model(resolved_key, context)
        vae_model.submodel_type = SubModelType.VAE
        return VAEOutput(vae=VAEField(vae=vae_model))
@invocation_output("metadata_to_lora_collection_output")
class MetadataToLorasCollectionOutput(BaseInvocationOutput):
    """Model loader output"""

    # LoRA model + weight pairs collected by MetadataToLorasCollectionInvocation.
    lora: list[LoRAField] = OutputField(
        title="LoRAs",
        description="Collection of LoRA model and weights",
    )
@invocation(
    "metadata_to_lora_collection",
    title="Metadata To LoRA Collection",
    tags=["metadata"],
    category="metadata",
    version="1.1.0",
    classification=Classification.Beta,
)
class MetadataToLorasCollectionInvocation(BaseInvocation, WithMetadata):
    """Extracts Lora(s) from metadata into a collection"""

    # Metadata key holding the LoRA list; a blank value falls back to "loras".
    custom_label: str = InputField(
        default="loras",
        description=FieldDescriptions.metadata_item_label,
        input=Input.Direct,
    )
    loras: Optional[LoRAField | list[LoRAField]] = InputField(
        default=[], description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
    )

    def invoke(self, context: InvocationContext) -> MetadataToLorasCollectionOutput:
        """Merge the input LoRAs with those listed in the metadata, skipping duplicates.

        Input LoRAs come first. Metadata entries are resolved through their "model"
        entry, then their "lora" entry; entries that resolve to no known model are
        skipped. Each model key appears at most once in the result.

        Raises:
            TypeError: If an element of ``loras`` is not a ``LoRAField``.
        """
        metadata = {} if self.metadata is None else self.metadata.root
        key: str = self.custom_label.strip() or "loras"
        md_loras = metadata[key] if key in metadata else []

        # `loras` is Optional: treat None as "no input LoRAs" rather than a one-item list.
        if self.loras is None:
            input_loras = []
        elif isinstance(self.loras, list):
            input_loras = self.loras
        else:
            input_loras = [self.loras]

        output = MetadataToLorasCollectionOutput(lora=[])
        added_keys: set[str] = set()

        for lora in input_loras:
            # isinstance, not `is`: an instance is never identical to its class.
            if not isinstance(lora, LoRAField):
                raise TypeError(f"Expected LoRAField, got {type(lora).__name__}")
            if lora.lora.key in added_keys:
                continue
            output.lora.append(lora)
            added_keys.add(lora.lora.key)

        for lora in md_loras:
            model_key = extract_model_key(lora, "model", "", ModelType.LoRA, context)
            if not model_key:
                model_key = extract_model_key(lora, "lora", "", ModelType.LoRA, context)
            if not model_key:
                continue

            model = get_model(model_key, context)
            if model.key in added_keys:
                continue
            output.lora.append(LoRAField(lora=model, weight=float(lora["weight"])))
            # Record the key so a repeated metadata entry is not added twice.
            added_keys.add(model.key)

        return output
@invocation(
    "metadata_to_loras",
    title="Metadata To LoRAs",
    tags=["metadata"],
    category="metadata",
    version="1.1.1",
    classification=Classification.Beta,
)
class MetadataToLorasInvocation(BaseInvocation, WithMetadata):
    """Extracts a Loras value of a label from metadata"""

    unet: Optional[UNetField] = InputField(
        title="UNet",
        description=FieldDescriptions.unet,
        default=None,
        input=Input.Connection,
    )
    clip: Optional[CLIPField] = InputField(
        title="CLIP",
        description=FieldDescriptions.clip,
        default=None,
        input=Input.Connection,
    )

    def invoke(self, context: InvocationContext) -> LoRALoaderOutput:
        """Append the LoRAs listed under "loras" in the metadata to copies of the inputs."""
        metadata_root = {} if self.metadata is None else self.metadata.root
        md_loras = metadata_root["loras"] if "loras" in metadata_root else []

        output = LoRALoaderOutput()
        # Deep copies keep the connected input fields untouched.
        if self.unet is not None:
            output.unet = copy.deepcopy(self.unet)
        if self.clip is not None:
            output.clip = copy.deepcopy(self.clip)

        for entry in md_loras:
            model_key = extract_model_key(entry, "model", "", ModelType.LoRA, context)
            if model_key == "":
                # No known model could be resolved for this entry.
                continue

            model = get_model(model_key, context)
            weight = float(entry["weight"])

            for target_name, target in (("unet", output.unet), ("clip", output.clip)):
                if target is None:
                    continue
                if any(existing.lora.key == model_key for existing in target.loras):
                    context.logger.info(f'LoRA "{model_key}" already applied to {target_name}')
                else:
                    target.loras.append(LoRAField(lora=model, weight=weight))

        return output
@invocation(
    "metadata_to_sdlx_loras",
    title="Metadata To SDXL LoRAs",
    tags=["metadata"],
    category="metadata",
    version="1.1.1",
    classification=Classification.Beta,
)
class MetadataToSDXLLorasInvocation(BaseInvocation, WithMetadata):
    """Extracts a SDXL Loras value of a label from metadata"""

    unet: Optional[UNetField] = InputField(
        default=None,
        description=FieldDescriptions.unet,
        input=Input.Connection,
        title="UNet",
    )
    clip: Optional[CLIPField] = InputField(
        default=None,
        description=FieldDescriptions.clip,
        input=Input.Connection,
        title="CLIP 1",
    )
    clip2: Optional[CLIPField] = InputField(
        default=None,
        description=FieldDescriptions.clip,
        input=Input.Connection,
        title="CLIP 2",
    )

    def invoke(self, context: InvocationContext) -> SDXLLoRALoaderOutput:
        """Append the LoRAs listed under "loras" in the metadata to copies of the inputs.

        Entries that resolve to no known model are skipped. A LoRA already present on
        a target is not added again; an info message is logged instead.
        """
        data = {} if self.metadata is None else self.metadata.root
        loras = data["loras"] if "loras" in data else []

        output = SDXLLoRALoaderOutput()
        # Deep copies keep the connected input fields untouched.
        if self.unet is not None:
            output.unet = copy.deepcopy(self.unet)
        if self.clip is not None:
            output.clip = copy.deepcopy(self.clip)
        if self.clip2 is not None:
            output.clip2 = copy.deepcopy(self.clip2)

        for lora in loras:
            model_key = extract_model_key(lora, "model", "", ModelType.LoRA, context)
            if model_key == "":
                continue

            model = get_model(model_key, context)
            weight = float(lora["weight"])

            # The log message names the actual target, so clip2 is no longer reported as "clip".
            targets = (("unet", output.unet), ("clip", output.clip), ("clip2", output.clip2))
            for target_name, target in targets:
                if target is None:
                    continue
                if any(existing.lora.key == model_key for existing in target.loras):
                    context.logger.info(f'LoRA "{model_key}" already applied to {target_name}')
                else:
                    target.loras.append(
                        LoRAField(
                            lora=model,
                            weight=weight,
                        )
                    )

        return output
@invocation_output("md_control_list_output")
class MDControlListOutput(BaseInvocationOutput):
    # Single output: one ControlField, a list of them, or None.
    control_list: Optional[Union[ControlField, list[ControlField]]] = OutputField(
        title="ControlNet-List",
        description=FieldDescriptions.control,
    )
@invocation(
    "metadata_to_controlnets",
    title="Metadata To ControlNets",
    tags=["metadata"],
    category="metadata",
    version="1.2.0",
    classification=Classification.Beta,
)
class MetadataToControlnetsInvocation(BaseInvocation, WithMetadata):
    """Extracts a Controlnets value of a label from metadata"""

    control_list: Optional[Union[ControlField, list[ControlField]]] = InputField(
        title="ControlNet-List",
        default=None,
        input=Input.Connection,
    )

    def invoke(self, context: InvocationContext) -> MDControlListOutput:
        """Rebuild the ControlNets stored under "controlnets" and append them to the input."""
        metadata_root = {} if self.metadata is None else self.metadata.root
        stored_controls = metadata_root["controlnets"] if "controlnets" in metadata_root else []

        # Start from the connected input, or an empty list when nothing is connected.
        controls: Optional[Union[ControlField, list[ControlField]]] = (
            [] if self.control_list is None else self.control_list
        )

        for entry in stored_controls:
            # NOTE(review): an unresolved model yields "" here, which get_model rejects.
            model_key = extract_model_key(entry, "control_model", "", ModelType.ControlNet, context)
            model = get_model(model_key, context)
            # The ControlField is produced by running a ControlNet node on the stored values.
            node = ControlNetInvocation(
                image=entry["image"],
                control_model=model,
                control_weight=entry["control_weight"],
                begin_step_percent=entry["begin_step_percent"],
                end_step_percent=entry["end_step_percent"],
                control_mode=entry["control_mode"],
                resize_mode=entry["resize_mode"],
            )
            result = node.invoke(context)
            controls = append_list(ControlField, result.control, controls)

        return MDControlListOutput(control_list=controls)
@invocation_output("md_ip_adapter_list_output")
class MDIPAdapterListOutput(BaseInvocationOutput):
    # Single output: one IPAdapterField, a list of them, or None.
    ip_adapter_list: Optional[Union[IPAdapterField, list[IPAdapterField]]] = OutputField(
        title="IP-Adapter-List",
        description=FieldDescriptions.ip_adapter,
    )
@invocation(
    "metadata_to_ip_adapters",
    title="Metadata To IP-Adapters",
    tags=["metadata"],
    category="metadata",
    version="1.2.0",
    classification=Classification.Beta,
)
class MetadataToIPAdaptersInvocation(BaseInvocation, WithMetadata):
    """Extracts a IP-Adapters value of a label from metadata"""

    ip_adapter_list: Optional[Union[IPAdapterField, list[IPAdapterField]]] = InputField(
        title="IP-Adapter-List",
        description=FieldDescriptions.ip_adapter,
        default=None,
        input=Input.Connection,
    )

    def invoke(self, context: InvocationContext) -> MDIPAdapterListOutput:
        """Rebuild the IP-Adapters stored under "ipAdapters" and append them to the input."""
        metadata_root = {} if self.metadata is None else self.metadata.root
        stored_adapters = metadata_root["ipAdapters"] if "ipAdapters" in metadata_root else []

        # Start from the connected input, or an empty list when nothing is connected.
        adapters: Optional[Union[IPAdapterField, list[IPAdapterField]]] = (
            [] if self.ip_adapter_list is None else self.ip_adapter_list
        )

        for entry in stored_adapters:
            # NOTE(review): an unresolved model yields "" here, which get_model rejects.
            model_key = extract_model_key(entry, "ip_adapter_model", "", ModelType.IPAdapter, context)
            model = get_model(model_key, context)
            # The IPAdapterField is produced by running an IP-Adapter node on the stored values.
            node = IPAdapterInvocation(
                image=entry["image"],
                ip_adapter_model=model,
                weight=entry["weight"],
                begin_step_percent=entry["begin_step_percent"],
                end_step_percent=entry["end_step_percent"],
            )
            result = node.invoke(context)
            adapters = append_list(IPAdapterField, result.ip_adapter, adapters)

        return MDIPAdapterListOutput(ip_adapter_list=adapters)
# NOTE(review): the registered output name says "ip_adapters" although this class carries
# T2I adapters; it is left unchanged because the name is part of the public identifier.
@invocation_output("md_ip_adapters_output")
class MDT2IAdapterListOutput(BaseInvocationOutput):
    # Single output: one T2IAdapterField, a list of them, or None.
    t2i_adapter_list: Optional[Union[T2IAdapterField, list[T2IAdapterField]]] = OutputField(
        title="T2I Adapter-List",
        description=FieldDescriptions.t2i_adapter,
    )
@invocation(
    "metadata_to_t2i_adapters",
    title="Metadata To T2I-Adapters",
    tags=["metadata"],
    category="metadata",
    version="1.2.0",
    classification=Classification.Beta,
)
class MetadataToT2IAdaptersInvocation(BaseInvocation, WithMetadata):
    """Extracts a T2I-Adapters value of a label from metadata"""

    # Uses the T2I-Adapter description (previously the IP-Adapter text was shown).
    t2i_adapter_list: Optional[Union[T2IAdapterField, list[T2IAdapterField]]] = InputField(
        description=FieldDescriptions.t2i_adapter,
        title="T2I-Adapter",
        default=None,
        input=Input.Connection,
    )

    def invoke(self, context: InvocationContext) -> MDT2IAdapterListOutput:
        """Rebuild the T2I-Adapters stored under "t2iAdapters" and append them to the input."""
        data = {} if self.metadata is None else self.metadata.root
        md_adapters = data["t2iAdapters"] if "t2iAdapters" in data else []

        # Start from the connected input, or an empty list when nothing is connected.
        adapters: Optional[Union[T2IAdapterField, list[T2IAdapterField]]] = (
            [] if self.t2i_adapter_list is None else self.t2i_adapter_list
        )

        for entry in md_adapters:
            # NOTE(review): an unresolved model yields "" here, which get_model rejects.
            model_key = extract_model_key(entry, "t2i_adapter_model", "", ModelType.T2IAdapter, context)
            model = get_model(model_key, context)
            # The T2IAdapterField is produced by running a T2I-Adapter node on the stored values.
            t2i = T2IAdapterInvocation(
                image=entry["image"],
                t2i_adapter_model=model,
                weight=entry["weight"],
                begin_step_percent=entry["begin_step_percent"],
                end_step_percent=entry["end_step_percent"],
                resize_mode=entry["resize_mode"],
            )
            result = t2i.invoke(context)
            adapters = append_list(T2IAdapterField, result.t2i_adapter, adapters)

        return MDT2IAdapterListOutput(t2i_adapter_list=adapters)
@invocation(
    "metadata_to_string_collection",
    title="Metadata To String Collection",
    tags=["metadata"],
    category="metadata",
    version="1.0.0",
    classification=Classification.Beta,
)
class MetadataToStringCollectionInvocation(BaseInvocation, WithMetadata):
    """Extracts a string collection value of a label from metadata"""

    # Predefined label to look up; CUSTOM_LABEL switches the lookup to `custom_label`.
    label: CORE_LABELS_STRING = InputField(
        default=CUSTOM_LABEL,
        description=FieldDescriptions.metadata_item_label,
        input=Input.Direct,
    )
    # Free-form label, used only when `label` is CUSTOM_LABEL.
    custom_label: Optional[str] = InputField(
        default=None,
        description=FieldDescriptions.metadata_item_label,
        input=Input.Direct,
    )
    # Fallback returned when the label is absent from the metadata.
    default_value: list[str] = InputField(
        description="The default string collection to use if not found in the metadata"
    )

    # Shared validator defined elsewhere in this file; runs after field validation.
    _validate_custom_label = model_validator(mode="after")(validate_custom_label)

    def invoke(self, context: InvocationContext) -> StringCollectionOutput:
        # Missing metadata is treated like an empty mapping, so the default is returned.
        metadata_root: Dict[str, Any] = self.metadata.root if self.metadata is not None else {}

        # Pick the effective metadata key: the custom label when selected, else the core label.
        if self.label == CUSTOM_LABEL:
            lookup_key = str(self.custom_label)
        else:
            lookup_key = str(self.label)

        found = metadata_root.get(lookup_key, self.default_value)
        return StringCollectionOutput(collection=found)
@invocation(
    "metadata_to_integer_collection",
    title="Metadata To Integer Collection",
    tags=["metadata"],
    category="metadata",
    version="1.0.0",
    classification=Classification.Beta,
)
class MetadataToIntegerCollectionInvocation(BaseInvocation, WithMetadata):
    """Extracts an integer collection value of a label from metadata"""

    # Predefined label to look up; CUSTOM_LABEL switches the lookup to `custom_label`.
    label: CORE_LABELS_INTEGER = InputField(
        default=CUSTOM_LABEL,
        description=FieldDescriptions.metadata_item_label,
        input=Input.Direct,
    )
    # Free-form label, used only when `label` is CUSTOM_LABEL.
    custom_label: Optional[str] = InputField(
        default=None,
        description=FieldDescriptions.metadata_item_label,
        input=Input.Direct,
    )
    # Fallback returned when the label is absent from the metadata.
    # The description names a collection, matching the list[int] type and the string sibling.
    default_value: list[int] = InputField(
        description="The default integer collection to use if not found in the metadata"
    )

    # Shared validator defined elsewhere in this file; runs after field validation.
    _validate_custom_label = model_validator(mode="after")(validate_custom_label)

    def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
        # Missing metadata is treated like an empty mapping, so the default is returned.
        data: Dict[str, Any] = {} if self.metadata is None else self.metadata.root

        # Pick the effective metadata key: the custom label when selected, else the core label.
        key = str(self.custom_label if self.label == CUSTOM_LABEL else self.label)

        # NOTE(review): the stored value is passed through unchanged; a non-list value is
        # presumably rejected by the output model's validation - confirm.
        output = data.get(key, self.default_value)
        return IntegerCollectionOutput(collection=output)
@invocation(
    "metadata_to_float_collection",
    title="Metadata To Float Collection",
    tags=["metadata"],
    category="metadata",
    version="1.0.0",
    classification=Classification.Beta,
)
class MetadataToFloatCollectionInvocation(BaseInvocation, WithMetadata):
    """Extracts a float collection value of a label from metadata"""

    # Predefined label to look up; CUSTOM_LABEL switches the lookup to `custom_label`.
    label: CORE_LABELS_FLOAT = InputField(
        default=CUSTOM_LABEL,
        description=FieldDescriptions.metadata_item_label,
        input=Input.Direct,
    )
    # Free-form label, used only when `label` is CUSTOM_LABEL.
    custom_label: Optional[str] = InputField(
        default=None,
        description=FieldDescriptions.metadata_item_label,
        input=Input.Direct,
    )
    # Fallback returned when the label is absent from the metadata.
    # The description names a collection, matching the list[float] type and the string sibling.
    default_value: list[float] = InputField(
        description="The default float collection to use if not found in the metadata"
    )

    # Shared validator defined elsewhere in this file; runs after field validation.
    _validate_custom_label = model_validator(mode="after")(validate_custom_label)

    def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
        # Missing metadata is treated like an empty mapping, so the default is returned.
        data: Dict[str, Any] = {} if self.metadata is None else self.metadata.root

        # Pick the effective metadata key: the custom label when selected, else the core label.
        key = str(self.custom_label if self.label == CUSTOM_LABEL else self.label)

        # NOTE(review): the stored value is passed through unchanged; a non-list value is
        # presumably rejected by the output model's validation - confirm.
        output = data.get(key, self.default_value)
        return FloatCollectionOutput(collection=output)
@invocation(
    "metadata_to_bool_collection",
    title="Metadata To Bool Collection",
    tags=["metadata"],
    category="metadata",
    version="1.0.0",
    classification=Classification.Beta,
)
class MetadataToBoolCollectionInvocation(BaseInvocation, WithMetadata):
    """Extracts a boolean collection value of a label from metadata"""

    # Predefined label to look up; CUSTOM_LABEL switches the lookup to `custom_label`.
    label: CORE_LABELS_BOOL = InputField(
        default=CUSTOM_LABEL,
        description=FieldDescriptions.metadata_item_label,
        input=Input.Direct,
    )
    # Free-form label, used only when `label` is CUSTOM_LABEL.
    custom_label: Optional[str] = InputField(
        default=None,
        description=FieldDescriptions.metadata_item_label,
        input=Input.Direct,
    )
    # Fallback returned when the label is absent from the metadata.
    # The description names a collection, matching the list[bool] type and the string sibling.
    default_value: list[bool] = InputField(
        description="The default bool collection to use if not found in the metadata"
    )

    # Shared validator defined elsewhere in this file; runs after field validation.
    _validate_custom_label = model_validator(mode="after")(validate_custom_label)

    def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
        # Missing metadata is treated like an empty mapping, so the default is returned.
        data: Dict[str, Any] = {} if self.metadata is None else self.metadata.root

        # Pick the effective metadata key: the custom label when selected, else the core label.
        key = str(self.custom_label if self.label == CUSTOM_LABEL else self.label)

        # NOTE(review): the stored value is passed through unchanged; a non-list value is
        # presumably rejected by the output model's validation - confirm.
        output = data.get(key, self.default_value)
        return BooleanCollectionOutput(collection=output)