Compare commits

...

5 Commits

35 changed files with 2007 additions and 42 deletions

View File

@@ -56,6 +56,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
CLIPLEmbedModel = "CLIPLEmbedModelField"
CLIPGEmbedModel = "CLIPGEmbedModelField"
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
StructuralLoRAModel = "StructuralLoRAModelField"
# endregion
# region Misc Field Types
@@ -143,6 +144,7 @@ class FieldDescriptions:
controlnet_model = "ControlNet model to load"
vae_model = "VAE model to load"
lora_model = "LoRA model to load"
structural_lora_model = "Structural LoRA model to load"
main_model = "Main model (UNet, VAE, CLIP) to load"
flux_model = "Flux model (Transformer) to load"
sd3_model = "SD3 model (MMDiTX) to load"

View File

@@ -1,5 +1,5 @@
from contextlib import ExitStack
from typing import Callable, Iterator, Optional, Tuple
from typing import Callable, Iterator, Optional, Tuple, Union
import numpy as np
import numpy.typing as npt
@@ -22,7 +22,7 @@ from invokeai.app.invocations.fields import (
)
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
from invokeai.app.invocations.ip_adapter import IPAdapterField
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.model import LoRAField, StructuralLoRAField, TransformerField, VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
@@ -33,8 +33,10 @@ from invokeai.backend.flux.extensions.instantx_controlnet_extension import Insta
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.flux_tools_sampling_utils import prepare_control
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.sampling_utils import (
clip_timestep_schedule_fractional,
generate_img_ids,
@@ -284,6 +286,16 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
dtype=inference_dtype,
device=x.device,
)
img_cond = None
if struct_lora := self.transformer.structural_lora:
# What should we do when we have multiple of these?
if not self.controlnet_vae:
raise ValueError("controlnet_vae must be set when using a strutural lora")
ae_info = context.models.load(self.controlnet_vae.vae)
img = context.images.get_pil(struct_lora.img.image_name)
with ae_info as ae:
assert isinstance(ae, AutoEncoder)
img_cond = prepare_control(self.height, self.width, self.seed, ae, img)
# Load the transformer model.
(cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
@@ -345,6 +357,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
controlnet_extensions=controlnet_extensions,
pos_ip_adapter_extensions=pos_ip_adapter_extensions,
neg_ip_adapter_extensions=neg_ip_adapter_extensions,
img_cond=img_cond,
)
x = unpack(x.float(), self.height, self.width)
@@ -682,7 +695,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
return pos_ip_adapter_extensions, neg_ip_adapter_extensions
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.transformer.loras:
loras: list[Union[LoRAField, StructuralLoRAField]] = [*self.transformer.loras]
if self.transformer.structural_lora:
loras.append(self.transformer.structural_lora)
for lora in loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)

View File

@@ -81,8 +81,10 @@ class FluxModelLoaderInvocation(BaseInvocation):
assert isinstance(transformer_config, CheckpointConfigBase)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
transformer=TransformerField(transformer=transformer, loras=[], structural_loras=[]),
clip=CLIPField(
tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], structural_loras=[], skipped_layers=0
),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],

View File

@@ -0,0 +1,74 @@
from typing import Optional
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField, StructuralLoRAField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation_output("flux_structural_lora_loader_output")
class FluxStructuralLoRALoaderOutput(BaseInvocationOutput):
"""Flux Structural LoRA Loader Output"""
transformer: Optional[TransformerField] = OutputField(
default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
)
@invocation(
"flux_structural_lora_loader",
title="Flux Structural LoRA",
tags=["lora", "model", "flux"],
category="model",
version="1.1.0",
classification=Classification.Prototype,
)
class FluxStructuralLoRALoaderInvocation(BaseInvocation):
"""Apply a LoRA model to a FLUX transformer and/or text encoder."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.structural_lora_model, title="Structural LoRA", ui_type=UIType.StructuralLoRAModel
)
transformer: TransformerField | None = InputField(
default=None,
description=FieldDescriptions.transformer,
input=Input.Connection,
title="FLUX Transformer",
)
image: ImageField = InputField(
description="The image to encode.",
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
def invoke(self, context: InvocationContext) -> FluxStructuralLoRALoaderOutput:
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise ValueError(f"Unknown lora: {lora_key}!")
# Check for existing LoRAs with the same key.
if (
self.transformer
and self.transformer.structural_lora
and self.transformer.structural_lora.lora.key == lora_key
):
raise ValueError(f'Structural LoRA "{lora_key}" already applied to transformer.')
output = FluxStructuralLoRALoaderOutput()
# Attach LoRA layers to the models.
if self.transformer is not None:
output.transformer = self.transformer.model_copy(deep=True)
output.transformer.structural_lora = StructuralLoRAField(
lora=self.lora,
img=self.image,
weight=self.weight,
)
return output

View File

@@ -10,7 +10,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import (
@@ -65,11 +65,6 @@ class CLIPField(BaseModel):
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class T5EncoderField(BaseModel):
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
@@ -80,6 +75,18 @@ class VAEField(BaseModel):
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
class StructuralLoRAField(LoRAField):
img: ImageField = Field(description="Image to use in structural conditioning")
class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
structural_lora: Optional[StructuralLoRAField] = Field(
description="Structural LoRAs to apply on model loading", default=None
)
@invocation_output("unet_output")
class UNetOutput(BaseInvocationOutput):
"""Base class for invocations that output a UNet field."""

View File

@@ -30,6 +30,8 @@ def denoise(
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
pos_ip_adapter_extensions: list[XLabsIPAdapterExtension],
neg_ip_adapter_extensions: list[XLabsIPAdapterExtension],
# extra img tokens
img_cond: torch.Tensor | None = None,
):
# step 0 is the initial state
total_steps = len(timesteps) - 1
@@ -69,9 +71,9 @@ def denoise(
# controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
# tensors. Calculating the sum materializes each tensor into its own instance.
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
pred_img = torch.cat((img, img_cond), dim=-1) if img_cond is not None else img
pred = model(
img=img,
img=pred_img,
img_ids=img_ids,
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,

View File

@@ -0,0 +1,28 @@
import numpy as np
import torch
from einops import rearrange
from PIL import Image
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
def prepare_control(
height: int,
width: int,
seed: int,
ae: AutoEncoder,
cond_image: Image.Image,
) -> torch.Tensor:
# load and encode the conditioning image
img_cond = cond_image.convert("RGB")
img_cond = img_cond.resize((width, height), Image.Resampling.LANCZOS)
img_cond = np.array(img_cond)
img_cond = torch.from_numpy(img_cond).float()
img_cond = rearrange(img_cond, "h w c -> 1 c h w")
ae_dtype = next(iter(ae.parameters())).dtype
ae_device = next(iter(ae.parameters())).device
img_cond = img_cond.to(device=ae_device, dtype=ae_dtype)
generator = torch.Generator(device=ae_device).manual_seed(seed)
img_cond = ae.encode(img_cond, sample=True, generator=generator)
img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
return img_cond

View File

@@ -1,6 +1,7 @@
# Initially pulled from https://github.com/black-forest-labs/flux
from dataclasses import dataclass
from typing import Optional
import torch
from torch import Tensor, nn
@@ -35,6 +36,7 @@ class FluxParams:
theta: int
qkv_bias: bool
guidance_embed: bool
out_channels: Optional[int] = None
class Flux(nn.Module):
@@ -47,7 +49,7 @@ class Flux(nn.Module):
self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
self.out_channels = params.out_channels or self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
pe_dim = params.hidden_size // params.num_heads

View File

@@ -0,0 +1,50 @@
import cv2
import numpy as np
import torch
from einops import rearrange, repeat
from transformers import AutoModelForDepthEstimation, AutoProcessor
class DepthImageEncoder:
depth_model_name = "LiheYoung/depth-anything-large-hf"
def __init__(self, device):
self.device = device
self.depth_model = AutoModelForDepthEstimation.from_pretrained(self.depth_model_name).to(device)
self.processor = AutoProcessor.from_pretrained(self.depth_model_name)
def __call__(self, img: torch.Tensor) -> torch.Tensor:
hw = img.shape[-2:]
img = torch.clamp(img, -1.0, 1.0)
img_byte = ((img + 1.0) * 127.5).byte()
img = self.processor(img_byte, return_tensors="pt")["pixel_values"]
depth = self.depth_model(img.to(self.device)).predicted_depth
depth = repeat(depth, "b h w -> b 3 h w")
depth = torch.nn.functional.interpolate(depth, hw, mode="bicubic", antialias=True)
depth = depth / 127.5 - 1.0
return depth
class CannyImageEncoder:
def __init__(
self,
device,
min_t: int = 50,
max_t: int = 200,
):
self.device = device
self.min_t = min_t
self.max_t = max_t
def __call__(self, img: torch.Tensor) -> torch.Tensor:
assert img.shape[0] == 1, "Only batch size 1 is supported"
img = rearrange(img[0], "c h w -> h w c")
img = torch.clamp(img, -1.0, 1.0)
img_np = ((img + 1.0) * 127.5).numpy().astype(np.uint8)
# Apply Canny edge detection
canny = cv2.Canny(img_np, self.min_t, self.max_t)
# Convert back to torch tensor and reshape
canny = torch.from_numpy(canny).float() / 127.5 - 1.0
canny = rearrange(canny, "h w -> 1 1 h w")
canny = repeat(canny, "b 1 ... -> b 3 ...")
return canny.to(self.device)

View File

@@ -0,0 +1,65 @@
import re
from typing import Any, Dict
import torch
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
# A regex pattern that matches all of the keys in the Flux Dev/Canny LoRA format.
# Example keys:
# guidance_in.in_layer.lora_B.bias
# single_blocks.0.linear1.lora_A.weight
# double_blocks.0.img_attn.norm.key_norm.scale
FLUX_STRUCTURAL_TRANSFORMER_KEY_REGEX = r"(final_layer|vector_in|txt_in|time_in|img_in|guidance_in|\w+_blocks)(\.(\d+))?\.(lora_(A|B)|(in|out)_layer|adaLN_modulation|img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear|linear1|linear2|modulation|norm)\.?(.*)"
def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
"""Checks if the provided state dict is likely in the FLUX Control LoRA format.
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
"""
return all(
re.match(FLUX_STRUCTURAL_TRANSFORMER_KEY_REGEX, k) or re.match(FLUX_STRUCTURAL_TRANSFORMER_KEY_REGEX, k)
for k in state_dict.keys()
)
def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
# converted_state_dict = _convert_lora_bfl_control(state_dict=state_dict)
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
key_props = key.split(".")
# Got it loading using lora_down and lora_up but it didn't seem to match this lora's structure
# Leaving this in since it doesn't hurt anything and may be better
layer_prop_size = -2 if any(prop in key for prop in ["lora_B", "lora_A"]) else -1
layer_name = ".".join(key_props[:layer_prop_size])
param_name = ".".join(key_props[layer_prop_size:])
if layer_name not in grouped_state_dict:
grouped_state_dict[layer_name] = {}
grouped_state_dict[layer_name][param_name] = value
# Create LoRA layers.
layers: dict[str, AnyLoRALayer] = {}
for layer_key, layer_state_dict in grouped_state_dict.items():
# Convert to a full layer diff
prefixed_key = f"{FLUX_LORA_TRANSFORMER_PREFIX}{layer_key}"
if all(k in layer_state_dict for k in ["lora_A.weight", "lora_B.bias", "lora_B.weight"]):
layers[prefixed_key] = LoRALayer(
layer_state_dict["lora_B.weight"],
None,
layer_state_dict["lora_A.weight"],
None,
layer_state_dict["lora_B.bias"],
)
elif "scale" in layer_state_dict:
layers[prefixed_key] = SetParameterLayer("scale", layer_state_dict["scale"])
else:
raise AssertionError(f"{layer_key} not expected")
# Create and return the LoRAModelRaw.
return LoRAModelRaw(layers=layers)

View File

@@ -7,5 +7,8 @@ from invokeai.backend.lora.layers.loha_layer import LoHALayer
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer]
AnyLoRALayer = Union[
LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer, SetParameterLayer
]

View File

@@ -0,0 +1,34 @@
from typing import Optional
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
class ReshapeWeightLayer(LoRALayerBase):
# TODO: Just everything in this class
def __init__(self, weight: Optional[torch.Tensor], bias: Optional[torch.Tensor], scale: Optional[torch.Tensor]):
super().__init__(alpha=None, bias=bias)
self.weight = torch.nn.Parameter(weight) if weight is not None else None
self.bias = torch.nn.Parameter(bias) if bias is not None else None
self.manual_scale = scale
def scale(self):
return self.manual_scale.float() if self.manual_scale is not None else super().scale()
def rank(self) -> int | None:
return None
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return orig_weight
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
super().to(device=device, dtype=dtype)
if self.weight is not None:
self.weight = self.weight.to(device=device, dtype=dtype)
if self.manual_scale is not None:
self.manual_scale = self.manual_scale.to(device=device, dtype=dtype)
def calc_size(self) -> int:
return super().calc_size() + calc_tensor_size(self.manual_scale)

View File

@@ -0,0 +1,29 @@
from typing import Dict
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
class SetParameterLayer(LoRALayerBase):
def __init__(self, param_name: str, weight: torch.Tensor):
super().__init__(None, None)
self.weight = weight
self.param_name = param_name
def rank(self) -> int | None:
return None
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight - orig_weight
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
return {self.param_name: self.get_weight(orig_module.get_parameter(self.param_name))}
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
def calc_size(self) -> int:
return super().calc_size() + calc_tensor_size(self.weight)

View File

@@ -52,7 +52,8 @@ class LoRAPatcher:
yield
finally:
for param_key, weight in original_weights.get_changed_weights():
model.get_parameter(param_key).copy_(weight)
cur_param = model.get_parameter(param_key)
cur_param.data = weight.to(dtype=cur_param.dtype, device=cur_param.device, copy=True)
@staticmethod
@torch.no_grad()
@@ -93,8 +94,9 @@ class LoRAPatcher:
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
device = module.weight.device
dtype = module.weight.dtype
first_param = next(module.parameters())
device = first_param.device
dtype = first_param.dtype
layer_scale = layer.scale()
@@ -114,8 +116,23 @@ class LoRAPatcher:
original_weights.save(param_key, module_param)
if module_param.shape != lora_param_weight.shape:
lora_param_weight = lora_param_weight.reshape(module_param.shape)
if module_param.nelement() == lora_param_weight.nelement():
lora_param_weight = lora_param_weight.reshape(module_param.shape)
else:
# This condition was added to handle layers in FLUX control LoRAs.
# TODO(ryand): Move the weight update into the LoRA layer so that the LoRAPatcher doesn't need
# to worry about this?
expanded_weight = torch.zeros_like(
lora_param_weight, dtype=module_param.dtype, device=module_param.device
)
slices = tuple(slice(0, dim) for dim in module_param.shape)
expanded_weight[slices] = module_param
setattr(
module,
param_name,
torch.nn.Parameter(expanded_weight, requires_grad=module_param.requires_grad),
)
module_param = expanded_weight
lora_param_weight *= patch_weight * layer_scale
module_param += lora_param_weight.to(dtype=dtype)

View File

@@ -67,6 +67,7 @@ class ModelType(str, Enum):
Main = "main"
VAE = "vae"
LoRA = "lora"
StructuralLoRa = "structural_lora"
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
IPAdapter = "ip_adapter"
@@ -273,6 +274,18 @@ class LoRALyCORISConfig(LoRAConfigBase):
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
class StructuralLoRALyCORISConfig(ModelConfigBase):
"""Model config for Structural LoRA/Lycoris models."""
type: Literal[ModelType.StructuralLoRa] = ModelType.StructuralLoRa
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.StructuralLoRa.value}.{ModelFormat.LyCORIS.value}")
class LoRADiffusersConfig(LoRAConfigBase):
"""Model config for LoRA/Diffusers models."""
@@ -535,6 +548,7 @@ AnyModelConfig = Annotated[
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
Annotated[StructuralLoRALyCORISConfig, StructuralLoRALyCORISConfig.get_tag()],
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],

View File

@@ -9,10 +9,15 @@ import torch
from safetensors.torch import load_file
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.lora.conversions.flux_control_lora_utils import (
is_state_dict_likely_flux_control,
lora_model_from_flux_control_state_dict,
)
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
lora_model_from_flux_diffusers_state_dict,
)
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
is_state_dict_likely_in_flux_kohya_format,
lora_model_from_flux_kohya_state_dict,
)
from invokeai.backend.lora.conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
@@ -32,6 +37,7 @@ from invokeai.backend.model_manager.load.model_loader_registry import ModelLoade
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.StructuralLoRa, format=ModelFormat.LyCORIS)
class LoRALoader(ModelLoader):
"""Class to load LoRA models."""
@@ -75,7 +81,10 @@ class LoRALoader(ModelLoader):
# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_flux.py#L1194
model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None)
elif config.format == ModelFormat.LyCORIS:
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
if is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict):
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
elif is_state_dict_likely_flux_control(state_dict=state_dict):
model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
else:
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:

View File

@@ -15,6 +15,7 @@ from invokeai.backend.flux.controlnet.state_dict_utils import (
is_state_dict_xlabs_controlnet,
)
from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter
from invokeai.backend.lora.conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
)
@@ -258,6 +259,18 @@ class ModelProbe(object):
ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
ckpt = ckpt.get("state_dict", ckpt)
if isinstance(ckpt, dict) and "img_in.lora_A.weight" in ckpt and "img_in.lora_B.weight" in ckpt:
tensor_a, tensor_b = ckpt["img_in.lora_A.weight"], ckpt["img_in.lora_B.weight"]
if (
tensor_a is not None
and isinstance(tensor_a, torch.Tensor)
and tensor_a.shape[1] == 128
and tensor_b is not None
and isinstance(tensor_b, torch.Tensor)
and tensor_b.shape[0] == 3072
):
return ModelType.StructuralLoRa
for key in [str(k) for k in ckpt.keys()]:
if key.startswith(
(
@@ -624,8 +637,10 @@ class LoRACheckpointProbe(CheckpointProbeBase):
return ModelFormat.LyCORIS
def get_base_type(self) -> BaseModelType:
if is_state_dict_likely_in_flux_kohya_format(self.checkpoint) or is_state_dict_likely_in_flux_diffusers_format(
self.checkpoint
if (
is_state_dict_likely_in_flux_kohya_format(self.checkpoint)
or is_state_dict_likely_in_flux_diffusers_format(self.checkpoint)
or is_state_dict_likely_flux_control(self.checkpoint)
):
return BaseModelType.Flux
@@ -1046,6 +1061,7 @@ ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelI
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.StructuralLoRa, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)

View File

@@ -809,6 +809,7 @@
"starterBundleHelpText": "Easily install all models needed to get started with a base model, including a main model, controlnets, IP adapters, and more. Selecting a bundle will skip any models that you already have installed.",
"starterModels": "Starter Models",
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
"structuralLora": "Structural LoRA",
"syncModels": "Sync Models",
"textualInversions": "Textual Inversions",
"triggerPhrases": "Trigger Phrases",

View File

@@ -24,6 +24,7 @@ import type {
ParameterSeed,
ParameterSteps,
ParameterStrength,
ParameterStructuralLoRAModel,
ParameterT5EncoderModel,
ParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
@@ -75,6 +76,7 @@ export type ParamsState = {
clipEmbedModel: ParameterCLIPEmbedModel | null;
clipLEmbedModel: ParameterCLIPLEmbedModel | null;
clipGEmbedModel: ParameterCLIPGEmbedModel | null;
structuralLora: ParameterStructuralLoRAModel | null;
};
const initialState: ParamsState = {
@@ -121,6 +123,7 @@ const initialState: ParamsState = {
clipEmbedModel: null,
clipLEmbedModel: null,
clipGEmbedModel: null,
structuralLora: null,
};
export const paramsSlice = createSlice({
@@ -195,6 +198,9 @@ export const paramsSlice = createSlice({
t5EncoderModelSelected: (state, action: PayloadAction<ParameterT5EncoderModel | null>) => {
state.t5EncoderModel = action.payload;
},
structuralLoRAModelSelected: (state, action: PayloadAction<ParameterStructuralLoRAModel | null>) => {
state.structuralLora = action.payload;
},
clipEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPEmbedModel | null>) => {
state.clipEmbedModel = action.payload;
},

View File

@@ -46,6 +46,7 @@ import type {
ParameterSeed,
ParameterSteps,
ParameterStrength,
ParameterStructuralLoRAModel,
ParameterVAEModel,
ParameterWidth,
} from 'features/parameters/types/parameterSchemas';
@@ -80,6 +81,7 @@ import {
isLoRAModelConfig,
isNonRefinerMainModelConfig,
isRefinerMainModelModelConfig,
isStructuralLoRAModelConfig,
isT2IAdapterModelConfig,
isVAEModelConfig,
} from 'services/api/types';
@@ -226,6 +228,14 @@ const parseVAEModel: MetadataParseFunc<ParameterVAEModel> = async (metadata) =>
return modelIdentifier;
};
const parseStructuralLoRAModel: MetadataParseFunc<ParameterStructuralLoRAModel> = async (metadata) => {
const slora = await getProperty(metadata, 'structural_lora', undefined);
const key = await getModelKey(slora, 'structural_lora');
const sloraModelConfig = await fetchModelConfigWithTypeGuard(key, isStructuralLoRAModelConfig);
const modelIdentifier = zModelIdentifierField.parse(sloraModelConfig);
return modelIdentifier;
};
const parseLoRA: MetadataParseFunc<LoRA> = async (metadataItem) => {
// Previously, the LoRA model identifier parts were stored in the LoRA metadata: `{key: ..., weight: 0.75}`
const modelV1 = await getProperty(metadataItem, 'lora', undefined);
@@ -671,6 +681,7 @@ export const parsers = {
mainModel: parseMainModel,
refinerModel: parseRefinerModel,
vaeModel: parseVAEModel,
structuralLora: parseStructuralLoRAModel,
lora: parseLoRA,
loras: parseAllLoRAs,
controlNet: parseControlNet,

View File

@@ -18,6 +18,7 @@ import {
useMainModels,
useRefinerModels,
useSpandrelImageToImageModels,
useStructuralLoRAModel,
useT2IAdapterModels,
useT5EncoderModels,
useVAEModels,
@@ -92,6 +93,12 @@ const ModelList = () => {
[t5EncoderModels, searchTerm, filteredModelType]
);
const [structuralLoRAModels, { isLoading: isLoadingStructuralLoRAModels }] = useStructuralLoRAModel();
const filteredStructuralLoRAModels = useMemo(
() => modelsFilter(structuralLoRAModels, searchTerm, filteredModelType),
[structuralLoRAModels, searchTerm, filteredModelType]
);
const [clipEmbedModels, { isLoading: isLoadingClipEmbedModels }] = useCLIPEmbedModels({ excludeSubmodels: true });
const filteredClipEmbedModels = useMemo(
() => modelsFilter(clipEmbedModels, searchTerm, filteredModelType),
@@ -118,7 +125,8 @@ const ModelList = () => {
filteredVAEModels.length +
filteredSpandrelImageToImageModels.length +
t5EncoderModels.length +
clipEmbedModels.length
clipEmbedModels.length +
structuralLoRAModels.length
);
}, [
filteredControlNetModels.length,
@@ -133,6 +141,7 @@ const ModelList = () => {
filteredSpandrelImageToImageModels.length,
t5EncoderModels.length,
clipEmbedModels.length,
structuralLoRAModels.length,
]);
return (
@@ -195,6 +204,15 @@ const ModelList = () => {
{!isLoadingT5EncoderModels && filteredT5EncoderModels.length > 0 && (
<ModelListWrapper title={t('modelManager.t5Encoder')} modelList={filteredT5EncoderModels} key="t5-encoder" />
)}
{/* Structural Lora List */}
{isLoadingStructuralLoRAModels && <FetchingModelsLoader loadingMessage="Loading Structural Loras..." />}
{!isLoadingStructuralLoRAModels && filteredStructuralLoRAModels.length > 0 && (
<ModelListWrapper
title={t('modelManager.structuralLora')}
modelList={filteredStructuralLoRAModels}
key="structural-lora"
/>
)}
{/* Clip Embed List */}
{isLoadingClipEmbedModels && <FetchingModelsLoader loadingMessage="Loading Clip Embed Models..." />}
{!isLoadingClipEmbedModels && filteredClipEmbedModels.length > 0 && (

View File

@@ -24,6 +24,7 @@ export const ModelTypeFilter = memo(() => {
ip_adapter: t('common.ipAdapter'),
clip_vision: 'CLIP Vision',
spandrel_image_to_image: t('modelManager.spandrelImageToImage'),
structural_lora: t('modelManager.structuralLora'),
}),
[t]
);

View File

@@ -51,6 +51,8 @@ import {
isSpandrelImageToImageModelFieldInputTemplate,
isStringFieldInputInstance,
isStringFieldInputTemplate,
isStructuralLoRAModelFieldInputInstance,
isStructuralLoRAModelFieldInputTemplate,
isT2IAdapterModelFieldInputInstance,
isT2IAdapterModelFieldInputTemplate,
isT5EncoderModelFieldInputInstance,
@@ -81,6 +83,7 @@ import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComp
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
import StructuralLoRAModelFieldInputComponent from './inputs/StructuralLoraModelFieldInputComponent';
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
import T5EncoderModelFieldInputComponent from './inputs/T5EncoderModelFieldInputComponent';
import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
@@ -156,6 +159,15 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
return <CLIPGEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (
isStructuralLoRAModelFieldInputInstance(fieldInstance) &&
isStructuralLoRAModelFieldInputTemplate(fieldTemplate)
) {
return (
<StructuralLoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />
);
}
if (isFluxVAEModelFieldInputInstance(fieldInstance) && isFluxVAEModelFieldInputTemplate(fieldTemplate)) {
return <FluxVAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

View File

@@ -0,0 +1,65 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldStructuralLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
import type {
StructuralLoRAModelFieldInputInstance,
StructuralLoRAModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useStructuralLoRAModel } from 'services/api/hooks/modelsByType';
import { isStructuralLoRAModelConfig, type StructuralLoRAModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<StructuralLoRAModelFieldInputInstance, StructuralLoRAModelFieldInputTemplate>;
const StructuralLoRAModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useStructuralLoRAModel();
const _onChange = useCallback(
(value: StructuralLoRAModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldStructuralLoRAModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs: modelConfigs.filter((config) => isStructuralLoRAModelConfig(config)),
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
const required = props.fieldTemplate.required;
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value && required}>
<Combobox
value={value}
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
</Flex>
);
};
export default memo(StructuralLoRAModelFieldInputComponent);

View File

@@ -28,6 +28,7 @@ import type {
SpandrelImageToImageModelFieldValue,
StatefulFieldValue,
StringFieldValue,
StructuralLoRAModelFieldValue,
T2IAdapterModelFieldValue,
T5EncoderModelFieldValue,
VAEModelFieldValue,
@@ -55,6 +56,7 @@ import {
zSpandrelImageToImageModelFieldValue,
zStatefulFieldValue,
zStringFieldValue,
zStructuralLoRAModelFieldValue,
zT2IAdapterModelFieldValue,
zT5EncoderModelFieldValue,
zVAEModelFieldValue,
@@ -369,6 +371,9 @@ export const nodesSlice = createSlice({
fieldCLIPGEmbedValueChanged: (state, action: FieldValueAction<CLIPGEmbedModelFieldValue>) => {
fieldValueReducer(state, action, zCLIPGEmbedModelFieldValue);
},
fieldStructuralLoRAModelValueChanged: (state, action: FieldValueAction<StructuralLoRAModelFieldValue>) => {
fieldValueReducer(state, action, zStructuralLoRAModelFieldValue);
},
fieldFluxVAEModelValueChanged: (state, action: FieldValueAction<FluxVAEModelFieldValue>) => {
fieldValueReducer(state, action, zFluxVAEModelFieldValue);
},
@@ -438,6 +443,7 @@ export const {
fieldCLIPEmbedValueChanged,
fieldCLIPLEmbedValueChanged,
fieldCLIPGEmbedValueChanged,
fieldStructuralLoRAModelValueChanged,
fieldFluxVAEModelValueChanged,
nodeEditorReset,
nodeIsIntermediateChanged,

View File

@@ -69,6 +69,7 @@ const zModelType = z.enum([
'main',
'vae',
'lora',
'structural_lora',
'controlnet',
't2i_adapter',
'ip_adapter',

View File

@@ -178,6 +178,10 @@ const zCLIPGEmbedModelFieldType = zFieldTypeBase.extend({
name: z.literal('CLIPGEmbedModelField'),
originalType: zStatelessFieldType.optional(),
});
const zStructuralLoRAModelFieldType = zFieldTypeBase.extend({
name: z.literal('StructuralLoRAModelField'),
originalType: zStatelessFieldType.optional(),
});
const zFluxVAEModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxVAEModelField'),
originalType: zStatelessFieldType.optional(),
@@ -210,6 +214,7 @@ const zStatefulFieldType = z.union([
zCLIPEmbedModelFieldType,
zCLIPLEmbedModelFieldType,
zCLIPGEmbedModelFieldType,
zStructuralLoRAModelFieldType,
zFluxVAEModelFieldType,
zColorFieldType,
zSchedulerFieldType,
@@ -864,6 +869,29 @@ export const isCLIPGEmbedModelFieldInputTemplate = (val: unknown): val is CLIPGE
// #endregion
// #region StructuralLoRAModelField
export const zStructuralLoRAModelFieldValue = zModelIdentifierField.optional();
const zStructuralLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zStructuralLoRAModelFieldValue,
});
const zStructuralLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zStructuralLoRAModelFieldType,
originalType: zFieldType.optional(),
default: zStructuralLoRAModelFieldValue,
});
export type StructuralLoRAModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldValue>;
export type StructuralLoRAModelFieldInputInstance = z.infer<typeof zStructuralLoRAModelFieldInputInstance>;
export type StructuralLoRAModelFieldInputTemplate = z.infer<typeof zStructuralLoRAModelFieldInputTemplate>;
export const isStructuralLoRAModelFieldInputInstance = (val: unknown): val is StructuralLoRAModelFieldInputInstance =>
zStructuralLoRAModelFieldInputInstance.safeParse(val).success;
export const isStructuralLoRAModelFieldInputTemplate = (val: unknown): val is StructuralLoRAModelFieldInputTemplate =>
zStructuralLoRAModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region SchedulerField
export const zSchedulerFieldValue = zSchedulerField.optional();
@@ -959,6 +987,7 @@ export const zStatefulFieldValue = z.union([
zCLIPEmbedModelFieldValue,
zCLIPLEmbedModelFieldValue,
zCLIPGEmbedModelFieldValue,
zStructuralLoRAModelFieldValue,
zColorFieldValue,
zSchedulerFieldValue,
]);
@@ -1030,6 +1059,7 @@ const zStatefulFieldInputTemplate = z.union([
zCLIPEmbedModelFieldInputTemplate,
zCLIPLEmbedModelFieldInputTemplate,
zCLIPGEmbedModelFieldInputTemplate,
zStructuralLoRAModelFieldInputTemplate,
zColorFieldInputTemplate,
zSchedulerFieldInputTemplate,
zStatelessFieldInputTemplate,

View File

@@ -28,6 +28,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
CLIPEmbedModelField: undefined,
CLIPLEmbedModelField: undefined,
CLIPGEmbedModelField: undefined,
StructuralLoRAModelField: undefined,
};
export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => {

View File

@@ -28,6 +28,7 @@ import type {
StatefulFieldType,
StatelessFieldInputTemplate,
StringFieldInputTemplate,
StructuralLoRAModelFieldInputTemplate,
T2IAdapterModelFieldInputTemplate,
T5EncoderModelFieldInputTemplate,
VAEModelFieldInputTemplate,
@@ -300,6 +301,20 @@ const buildCLIPGEmbedModelFieldInputTemplate: FieldInputTemplateBuilder<CLIPGEmb
return template;
};
const buildStructuralLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<StructuralLoRAModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: StructuralLoRAModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildFluxVAEModelFieldInputTemplate: FieldInputTemplateBuilder<FluxVAEModelFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -526,6 +541,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
CLIPLEmbedModelField: buildCLIPLEmbedModelFieldInputTemplate,
CLIPGEmbedModelField: buildCLIPGEmbedModelFieldInputTemplate,
FluxVAEModelField: buildFluxVAEModelFieldInputTemplate,
StructuralLoRAModelField: buildStructuralLoRAModelFieldInputTemplate,
} as const;
export const buildFieldInputTemplate = (

View File

@@ -113,6 +113,11 @@ export const zParameterVAEModel = zModelIdentifierField;
export type ParameterVAEModel = z.infer<typeof zParameterVAEModel>;
// #endregion
// #region Structural Lora Model
export const zParameterStructuralLoRAModel = zModelIdentifierField;
export type ParameterStructuralLoRAModel = z.infer<typeof zParameterStructuralLoRAModel>;
// #endregion
// #region T5Encoder Model
export const zParameterT5EncoderModel = zModelIdentifierField;
export type ParameterT5EncoderModel = z.infer<typeof zParameterT5EncoderModel>;

View File

@@ -23,6 +23,7 @@ import {
isSD3MainModelModelConfig,
isSDXLMainModelModelConfig,
isSpandrelImageToImageModelConfig,
isStructuralLoRAModelConfig,
isT2IAdapterModelConfig,
isT5EncoderModelConfig,
isTIModelConfig,
@@ -58,6 +59,7 @@ export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);
export const useSD3Models = buildModelsHook(isSD3MainModelModelConfig);
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
export const useStructuralLoRAModel = buildModelsHook(isStructuralLoRAModelConfig);
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig);

File diff suppressed because one or more lines are too long

View File

@@ -44,6 +44,7 @@ export type BaseModelType = S['BaseModelType'];
// Model Configs
export type StructuralLoRAModelConfig = S['StructuralLoRALyCORISConfig'];
// TODO(MM2): Can we make key required in the pydantic model?
export type LoRAModelConfig = S['LoRADiffusersConfig'] | S['LoRALyCORISConfig'];
// TODO(MM2): Can we rename this from Vae -> VAE
@@ -63,6 +64,7 @@ export type CheckpointModelConfig = S['MainCheckpointConfig'];
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
export type AnyModelConfig =
| StructuralLoRAModelConfig
| LoRAModelConfig
| VAEModelConfig
| ControlNetModelConfig
@@ -114,6 +116,10 @@ export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelCo
return config.type === 'lora';
};
export const isStructuralLoRAModelConfig = (config: AnyModelConfig): config is StructuralLoRAModelConfig => {
return config.type === 'structural_lora';
};
export const isVAEModelConfig = (config: AnyModelConfig, excludeSubmodels?: boolean): config is VAEModelConfig => {
return config.type === 'vae' || (!excludeSubmodels && config.type === 'main' && checkSubmodels(['vae'], config));
};

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,71 @@
import pytest
import torch
from invokeai.backend.lora.conversions.flux_control_lora_utils import (
is_state_dict_likely_flux_control,
lora_model_from_flux_control_state_dict,
)
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from tests.backend.lora.conversions.lora_state_dicts.flux_control_lora_format import (
state_dict_keys as flux_control_lora_state_dict_keys,
)
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
state_dict_keys as flux_diffusers_state_dict_keys,
)
from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_state_dict
@pytest.mark.parametrize("sd_keys", [flux_control_lora_state_dict_keys])
def test_is_state_dict_likely_in_flux_control_format_true(sd_keys: dict[str, list[int]]):
"""Test that is_state_dict_likely_flux_control() can identify a state dict in the FLUX Control LoRA format."""
# Construct a state dict that is in the Diffusers FLUX LoRA format.
state_dict = keys_to_mock_state_dict(sd_keys)
assert is_state_dict_likely_flux_control(state_dict)
@pytest.mark.parametrize("sd_keys", [flux_diffusers_state_dict_keys])
def test_is_state_dict_likely_in_flux_control_format_false(sd_keys: dict[str, list[int]]):
"""Test that is_state_dict_likely_flux_control() returns False for a state dict that is in the Diffusers
FLUX LoRA format.
"""
# Construct a state dict that is not in the FLUX Control LoRA format.
state_dict = keys_to_mock_state_dict(sd_keys)
assert not is_state_dict_likely_flux_control(state_dict)
@pytest.mark.parametrize("sd_keys", [flux_control_lora_state_dict_keys])
def test_lora_model_from_flux_control_state_dict(sd_keys: dict[str, list[int]]):
"""Test that lora_model_from_flux_control_state_dict() can load a state dict in the FLUX Control LoRA format."""
# Construct a state dict that is in the FLUX Control LoRA format.
state_dict = keys_to_mock_state_dict(sd_keys)
# Load the state dict into a LoRAModelRaw object.
model = lora_model_from_flux_control_state_dict(state_dict)
# Check that the model has the correct number of LoRA layers.
expected_lora_layers: set[str] = set()
for k in sd_keys:
k = k.replace("lora_A.weight", "")
k = k.replace("lora_B.weight", "")
k = k.replace("lora_B.bias", "")
k = k.replace(".scale", "")
expected_lora_layers.add(k)
# Drop the K/V/proj_mlp weights because these are all concatenated into a single layer in the BFL format (we keep
# the Q weights so that we count these layers once).
assert len(model.layers) == len(expected_lora_layers)
assert all(k.startswith(FLUX_LORA_TRANSFORMER_PREFIX) for k in model.layers.keys())
def test_lora_model_from_flux_control_state_dict_extra_keys_error():
"""Test that lora_model_from_flux_control_state_dict() raises an error if the input state_dict contains unexpected
keys that we don't handle.
"""
# Construct a state dict that is in the FLUX Control LoRA format.
state_dict = keys_to_mock_state_dict(flux_control_lora_state_dict_keys)
# Add an unexpected key.
state_dict["transformer.single_transformer_blocks.0.unexpected_key.lora_A.weight"] = torch.empty(1)
# Check that an error is raised.
with pytest.raises(AssertionError):
lora_model_from_flux_control_state_dict(state_dict)