mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 08:28:14 -05:00
Compare commits
15 Commits
revert-798
...
ryan/lora-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3ed6e65a6e | ||
|
|
52c9646f84 | ||
|
|
7662f0522b | ||
|
|
e50fe69839 | ||
|
|
5a9f884620 | ||
|
|
edc72d1739 | ||
|
|
23f521dc7c | ||
|
|
3d6b93efdd | ||
|
|
3f28d3afad | ||
|
|
9353bfbdd6 | ||
|
|
93f2bc6118 | ||
|
|
9019026d6d | ||
|
|
c195b326ec | ||
|
|
2f460d2a45 | ||
|
|
4473cba512 |
@@ -82,10 +82,11 @@ class CompelInvocation(BaseInvocation):
|
||||
# apply all patches while the model is on the target device
|
||||
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
|
||||
tokenizer_info as tokenizer,
|
||||
LoRAPatcher.apply_lora_patches(
|
||||
LoRAPatcher.apply_smart_lora_patches(
|
||||
model=text_encoder,
|
||||
patches=_lora_loader(),
|
||||
prefix="lora_te_",
|
||||
dtype=TorchDevice.choose_torch_dtype(),
|
||||
cached_weights=cached_weights,
|
||||
),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
@@ -179,10 +180,11 @@ class SDXLPromptInvocationBase:
|
||||
# apply all patches while the model is on the target device
|
||||
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
|
||||
tokenizer_info as tokenizer,
|
||||
LoRAPatcher.apply_lora_patches(
|
||||
LoRAPatcher.apply_smart_lora_patches(
|
||||
text_encoder,
|
||||
patches=_lora_loader(),
|
||||
prefix=lora_prefix,
|
||||
dtype=TorchDevice.choose_torch_dtype(),
|
||||
cached_weights=cached_weights,
|
||||
),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
|
||||
@@ -1003,10 +1003,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
|
||||
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
|
||||
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||
LoRAPatcher.apply_lora_patches(
|
||||
LoRAPatcher.apply_smart_lora_patches(
|
||||
model=unet,
|
||||
patches=_lora_loader(),
|
||||
prefix="lora_unet_",
|
||||
dtype=unet.dtype,
|
||||
cached_weights=cached_weights,
|
||||
),
|
||||
):
|
||||
|
||||
@@ -56,6 +56,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
CLIPLEmbedModel = "CLIPLEmbedModelField"
|
||||
CLIPGEmbedModel = "CLIPGEmbedModelField"
|
||||
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
|
||||
StructuralLoRAModel = "StructuralLoRAModelField"
|
||||
# endregion
|
||||
|
||||
# region Misc Field Types
|
||||
@@ -143,6 +144,7 @@ class FieldDescriptions:
|
||||
controlnet_model = "ControlNet model to load"
|
||||
vae_model = "VAE model to load"
|
||||
lora_model = "LoRA model to load"
|
||||
structural_lora_model = "Structural LoRA model to load"
|
||||
main_model = "Main model (UNet, VAE, CLIP) to load"
|
||||
flux_model = "Flux model (Transformer) to load"
|
||||
sd3_model = "SD3 model (MMDiTX) to load"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from contextlib import ExitStack
|
||||
from typing import Callable, Iterator, Optional, Tuple
|
||||
from typing import Callable, Iterator, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
@@ -8,6 +8,8 @@ import torchvision.transforms as tv_transforms
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
DenoiseMaskField,
|
||||
@@ -22,7 +24,7 @@ from invokeai.app.invocations.fields import (
|
||||
)
|
||||
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
|
||||
from invokeai.app.invocations.ip_adapter import IPAdapterField
|
||||
from invokeai.app.invocations.model import TransformerField, VAEField
|
||||
from invokeai.app.invocations.model import TransformerField, VAEField, StructuralLoRAField, LoRAField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
|
||||
@@ -43,6 +45,8 @@ from invokeai.backend.flux.sampling_utils import (
|
||||
pack,
|
||||
unpack,
|
||||
)
|
||||
from invokeai.backend.flux.flux_tools_sampling_utils import prepare_control
|
||||
from invokeai.backend.flux.modules.conditioner import HFEncoder
|
||||
from invokeai.backend.flux.text_conditioning import FluxTextConditioning
|
||||
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
@@ -284,6 +288,16 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
dtype=inference_dtype,
|
||||
device=x.device,
|
||||
)
|
||||
img_cond = None
|
||||
if struct_lora := self.transformer.structural_lora:
|
||||
# What should we do when we have multiple of these?
|
||||
if not self.controlnet_vae:
|
||||
raise ValueError("controlnet_vae must be set when using a strutural lora")
|
||||
ae_info = context.models.load(self.controlnet_vae.vae)
|
||||
img = context.images.get_pil(struct_lora.img.image_name)
|
||||
with ae_info as ae:
|
||||
assert isinstance(ae, AutoEncoder)
|
||||
img_cond = prepare_control(self.height, self.width, self.seed, ae, img)
|
||||
|
||||
# Load the transformer model.
|
||||
(cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
|
||||
@@ -296,10 +310,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
if config.format in [ModelFormat.Checkpoint]:
|
||||
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
|
||||
exit_stack.enter_context(
|
||||
LoRAPatcher.apply_lora_patches(
|
||||
LoRAPatcher.apply_smart_lora_patches(
|
||||
model=transformer,
|
||||
patches=self._lora_iterator(context),
|
||||
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
|
||||
dtype=inference_dtype,
|
||||
cached_weights=cached_weights,
|
||||
)
|
||||
)
|
||||
@@ -311,7 +326,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference,
|
||||
# than directly patching the weights, but is agnostic to the quantization format.
|
||||
exit_stack.enter_context(
|
||||
LoRAPatcher.apply_lora_sidecar_patches(
|
||||
LoRAPatcher.apply_lora_wrapper_patches(
|
||||
model=transformer,
|
||||
patches=self._lora_iterator(context),
|
||||
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
|
||||
@@ -345,6 +360,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
controlnet_extensions=controlnet_extensions,
|
||||
pos_ip_adapter_extensions=pos_ip_adapter_extensions,
|
||||
neg_ip_adapter_extensions=neg_ip_adapter_extensions,
|
||||
img_cond=img_cond
|
||||
)
|
||||
|
||||
x = unpack(x.float(), self.height, self.width)
|
||||
@@ -682,7 +698,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
return pos_ip_adapter_extensions, neg_ip_adapter_extensions
|
||||
|
||||
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.transformer.loras:
|
||||
loras: list[Union[LoRAField, StructuralLoRAField]] = [*self.transformer.loras]
|
||||
if self.transformer.structural_lora:
|
||||
loras.append(self.transformer.structural_lora)
|
||||
for lora in loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
|
||||
@@ -81,8 +81,8 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
||||
assert isinstance(transformer_config, CheckpointConfigBase)
|
||||
|
||||
return FluxModelLoaderOutput(
|
||||
transformer=TransformerField(transformer=transformer, loras=[]),
|
||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
|
||||
transformer=TransformerField(transformer=transformer, loras=[], structural_loras=[]),
|
||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], structural_loras=[], skipped_layers=0),
|
||||
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
|
||||
vae=VAEField(vae=vae),
|
||||
max_seq_len=max_seq_lengths[transformer_config.config_path],
|
||||
|
||||
70
invokeai/app/invocations/flux_structural_lora_loader.py
Normal file
70
invokeai/app/invocations/flux_structural_lora_loader.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from typing import Optional, Literal
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, ImageField
|
||||
from invokeai.app.invocations.model import VAEField, StructuralLoRAField, ModelIdentifierField, TransformerField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
@invocation_output("flux_structural_lora_loader_output")
|
||||
class FluxStructuralLoRALoaderOutput(BaseInvocationOutput):
|
||||
"""Flux Structural LoRA Loader Output"""
|
||||
|
||||
transformer: Optional[TransformerField] = OutputField(
|
||||
default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_structural_lora_loader",
|
||||
title="Flux Structural LoRA",
|
||||
tags=["lora", "model", "flux"],
|
||||
category="model",
|
||||
version="1.1.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxStructuralLoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply a LoRA model to a FLUX transformer and/or text encoder."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.structural_lora_model, title="Structural LoRA", ui_type=UIType.StructuralLoRAModel
|
||||
)
|
||||
transformer: TransformerField | None = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.transformer,
|
||||
input=Input.Connection,
|
||||
title="FLUX Transformer",
|
||||
)
|
||||
image: ImageField = InputField(
|
||||
description="The image to encode.",
|
||||
)
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxStructuralLoRALoaderOutput:
|
||||
lora_key = self.lora.key
|
||||
|
||||
if not context.models.exists(lora_key):
|
||||
raise ValueError(f"Unknown lora: {lora_key}!")
|
||||
|
||||
# Check for existing LoRAs with the same key.
|
||||
if self.transformer and self.transformer.structural_lora and self.transformer.structural_lora.lora.key == lora_key:
|
||||
raise ValueError(f'Structural LoRA "{lora_key}" already applied to transformer.')
|
||||
|
||||
output = FluxStructuralLoRALoaderOutput()
|
||||
|
||||
# Attach LoRA layers to the models.
|
||||
if self.transformer is not None:
|
||||
output.transformer = self.transformer.model_copy(deep=True)
|
||||
output.transformer.structural_lora = StructuralLoRAField(
|
||||
lora=self.lora,
|
||||
img=self.image,
|
||||
weight=self.weight,
|
||||
)
|
||||
|
||||
return output
|
||||
@@ -22,6 +22,7 @@ from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_patcher import LoRAPatcher
|
||||
from invokeai.backend.model_manager.config import ModelFormat
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -111,10 +112,11 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
|
||||
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
|
||||
exit_stack.enter_context(
|
||||
LoRAPatcher.apply_lora_patches(
|
||||
LoRAPatcher.apply_smart_lora_patches(
|
||||
model=clip_text_encoder,
|
||||
patches=self._clip_lora_iterator(context),
|
||||
prefix=FLUX_LORA_CLIP_PREFIX,
|
||||
dtype=TorchDevice.choose_torch_dtype(),
|
||||
cached_weights=cached_weights,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import copy
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -10,7 +10,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, ImageField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.model_manager.config import (
|
||||
@@ -65,11 +65,6 @@ class CLIPField(BaseModel):
|
||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||
|
||||
|
||||
class TransformerField(BaseModel):
|
||||
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
|
||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||
|
||||
|
||||
class T5EncoderField(BaseModel):
|
||||
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
|
||||
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
|
||||
@@ -79,6 +74,13 @@ class VAEField(BaseModel):
|
||||
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
|
||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||
|
||||
class StructuralLoRAField(LoRAField):
|
||||
img: ImageField = Field(description="Image to use in structural conditioning")
|
||||
|
||||
class TransformerField(BaseModel):
|
||||
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
|
||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||
structural_lora: Optional[StructuralLoRAField] = Field(description="Structural LoRAs to apply on model loading", default=None)
|
||||
|
||||
@invocation_output("unet_output")
|
||||
class UNetOutput(BaseInvocationOutput):
|
||||
|
||||
@@ -21,6 +21,7 @@ from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_patcher import LoRAPatcher
|
||||
from invokeai.backend.model_manager.config import ModelFormat
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
# The SD3 T5 Max Sequence Length set based on the default in diffusers.
|
||||
SD3_T5_MAX_SEQ_LEN = 256
|
||||
@@ -150,10 +151,11 @@ class Sd3TextEncoderInvocation(BaseInvocation):
|
||||
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
|
||||
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
|
||||
exit_stack.enter_context(
|
||||
LoRAPatcher.apply_lora_patches(
|
||||
LoRAPatcher.apply_smart_lora_patches(
|
||||
model=clip_text_encoder,
|
||||
patches=self._clip_lora_iterator(context, clip_model),
|
||||
prefix=FLUX_LORA_CLIP_PREFIX,
|
||||
dtype=TorchDevice.choose_torch_dtype(),
|
||||
cached_weights=cached_weights,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -207,7 +207,9 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
unet_info as unet,
|
||||
LoRAPatcher.apply_lora_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
|
||||
LoRAPatcher.apply_smart_lora_patches(
|
||||
model=unet, patches=_lora_loader(), prefix="lora_unet_", dtype=unet.dtype
|
||||
),
|
||||
):
|
||||
assert isinstance(unet, UNet2DConditionModel)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
@@ -30,6 +30,8 @@ def denoise(
|
||||
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
|
||||
pos_ip_adapter_extensions: list[XLabsIPAdapterExtension],
|
||||
neg_ip_adapter_extensions: list[XLabsIPAdapterExtension],
|
||||
# extra img tokens
|
||||
img_cond: torch.Tensor | None = None,
|
||||
):
|
||||
# step 0 is the initial state
|
||||
total_steps = len(timesteps) - 1
|
||||
@@ -69,9 +71,9 @@ def denoise(
|
||||
# controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
|
||||
# tensors. Calculating the sum materializes each tensor into its own instance.
|
||||
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
|
||||
|
||||
pred_img = torch.cat((img, img_cond), dim=-1) if img_cond is not None else img
|
||||
pred = model(
|
||||
img=img,
|
||||
img=pred_img,
|
||||
img_ids=img_ids,
|
||||
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
|
||||
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
|
||||
|
||||
27
invokeai/backend/flux/flux_tools_sampling_utils.py
Normal file
27
invokeai/backend/flux/flux_tools_sampling_utils.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from einops import rearrange
|
||||
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
|
||||
def prepare_control(
|
||||
height: int,
|
||||
width: int,
|
||||
seed: int,
|
||||
ae: AutoEncoder,
|
||||
cond_image: Image.Image,
|
||||
) -> torch.Tensor:
|
||||
# load and encode the conditioning image
|
||||
img_cond = cond_image.convert("RGB")
|
||||
img_cond = img_cond.resize((width, height), Image.Resampling.LANCZOS)
|
||||
img_cond = np.array(img_cond)
|
||||
img_cond = torch.from_numpy(img_cond).float()
|
||||
img_cond = rearrange(img_cond, "h w c -> 1 c h w")
|
||||
ae_dtype = next(iter(ae.parameters())).dtype
|
||||
ae_device = next(iter(ae.parameters())).device
|
||||
img_cond = img_cond.to(device=ae_device, dtype=ae_dtype)
|
||||
generator = torch.Generator(device=ae_device).manual_seed(seed)
|
||||
img_cond = ae.encode(img_cond, sample=True, generator=generator)
|
||||
img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
return img_cond
|
||||
@@ -4,6 +4,7 @@ from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.flux.custom_block_processor import (
|
||||
CustomDoubleStreamBlockProcessor,
|
||||
@@ -35,6 +36,7 @@ class FluxParams:
|
||||
theta: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
out_channels: Optional[int] = None
|
||||
|
||||
|
||||
class Flux(nn.Module):
|
||||
@@ -47,7 +49,7 @@ class Flux(nn.Module):
|
||||
|
||||
self.params = params
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = self.in_channels
|
||||
self.out_channels = params.out_channels or self.in_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
|
||||
50
invokeai/backend/flux/modules/image_embedders.py
Normal file
50
invokeai/backend/flux/modules/image_embedders.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file as load_sft
|
||||
from torch import nn
|
||||
from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel
|
||||
|
||||
class DepthImageEncoder:
|
||||
depth_model_name = "LiheYoung/depth-anything-large-hf"
|
||||
def __init__(self, device):
|
||||
self.device = device
|
||||
self.depth_model = AutoModelForDepthEstimation.from_pretrained(self.depth_model_name).to(device)
|
||||
self.processor = AutoProcessor.from_pretrained(self.depth_model_name)
|
||||
def __call__(self, img: torch.Tensor) -> torch.Tensor:
|
||||
hw = img.shape[-2:]
|
||||
img = torch.clamp(img, -1.0, 1.0)
|
||||
img_byte = ((img + 1.0) * 127.5).byte()
|
||||
img = self.processor(img_byte, return_tensors="pt")["pixel_values"]
|
||||
depth = self.depth_model(img.to(self.device)).predicted_depth
|
||||
depth = repeat(depth, "b h w -> b 3 h w")
|
||||
depth = torch.nn.functional.interpolate(depth, hw, mode="bicubic", antialias=True)
|
||||
depth = depth / 127.5 - 1.0
|
||||
return depth
|
||||
|
||||
class CannyImageEncoder:
|
||||
def __init__(
|
||||
self,
|
||||
device,
|
||||
min_t: int = 50,
|
||||
max_t: int = 200,
|
||||
):
|
||||
self.device = device
|
||||
self.min_t = min_t
|
||||
self.max_t = max_t
|
||||
def __call__(self, img: torch.Tensor) -> torch.Tensor:
|
||||
assert img.shape[0] == 1, "Only batch size 1 is supported"
|
||||
img = rearrange(img[0], "c h w -> h w c")
|
||||
img = torch.clamp(img, -1.0, 1.0)
|
||||
img_np = ((img + 1.0) * 127.5).numpy().astype(np.uint8)
|
||||
# Apply Canny edge detection
|
||||
canny = cv2.Canny(img_np, self.min_t, self.max_t)
|
||||
# Convert back to torch tensor and reshape
|
||||
canny = torch.from_numpy(canny).float() / 127.5 - 1.0
|
||||
canny = rearrange(canny, "h w -> 1 1 h w")
|
||||
canny = repeat(canny, "b 1 ... -> b 3 ...")
|
||||
return canny.to(self.device)
|
||||
65
invokeai/backend/lora/conversions/flux_control_lora_utils.py
Normal file
65
invokeai/backend/lora/conversions/flux_control_lora_utils.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import re
|
||||
import torch
|
||||
|
||||
from typing import Any, Dict
|
||||
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
|
||||
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer
|
||||
|
||||
|
||||
# A regex pattern that matches all of the keys in the Flux Dev/Canny LoRA format.
|
||||
# Example keys:
|
||||
# guidance_in.in_layer.lora_B.bias
|
||||
# single_blocks.0.linear1.lora_A.weight
|
||||
# double_blocks.0.img_attn.norm.key_norm.scale
|
||||
FLUX_STRUCTURAL_TRANSFORMER_KEY_REGEX = r"(final_layer|vector_in|txt_in|time_in|img_in|guidance_in|\w+_blocks)(\.(\d+))?\.(lora_(A|B)|(in|out)_layer|adaLN_modulation|img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear|linear1|linear2|modulation|norm)\.?(.*)"
|
||||
|
||||
def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
|
||||
"""Checks if the provided state dict is likely in the FLUX Control LoRA format.
|
||||
|
||||
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
|
||||
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
|
||||
"""
|
||||
return all(
|
||||
re.match(FLUX_STRUCTURAL_TRANSFORMER_KEY_REGEX, k) or re.match(FLUX_STRUCTURAL_TRANSFORMER_KEY_REGEX, k)
|
||||
for k in state_dict.keys()
|
||||
)
|
||||
|
||||
def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
|
||||
# converted_state_dict = _convert_lora_bfl_control(state_dict=state_dict)
|
||||
# Group keys by layer.
|
||||
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
|
||||
for key, value in state_dict.items():
|
||||
key_props = key.split(".")
|
||||
# Got it loading using lora_down and lora_up but it didn't seem to match this lora's structure
|
||||
# Leaving this in since it doesn't hurt anything and may be better
|
||||
layer_prop_size = -2 if any(prop in key for prop in ["lora_B", "lora_A"]) else -1
|
||||
layer_name = ".".join(key_props[:layer_prop_size])
|
||||
param_name = ".".join(key_props[layer_prop_size:])
|
||||
if layer_name not in grouped_state_dict:
|
||||
grouped_state_dict[layer_name] = {}
|
||||
grouped_state_dict[layer_name][param_name] = value
|
||||
|
||||
# Create LoRA layers.
|
||||
layers: dict[str, AnyLoRALayer] = {}
|
||||
for layer_key, layer_state_dict in grouped_state_dict.items():
|
||||
# Convert to a full layer diff
|
||||
prefixed_key = f"{FLUX_LORA_TRANSFORMER_PREFIX}{layer_key}"
|
||||
if all(k in layer_state_dict for k in ["lora_A.weight", "lora_B.bias", "lora_B.weight"]):
|
||||
layers[prefixed_key] = LoRALayer(
|
||||
layer_state_dict["lora_B.weight"],
|
||||
None,
|
||||
layer_state_dict["lora_A.weight"],
|
||||
None,
|
||||
layer_state_dict["lora_B.bias"]
|
||||
)
|
||||
elif "scale" in layer_state_dict:
|
||||
layers[prefixed_key] = SetParameterLayer("scale", layer_state_dict["scale"])
|
||||
else:
|
||||
raise AssertionError(f"{layer_key} not expected")
|
||||
# Create and return the LoRAModelRaw.
|
||||
return LoRAModelRaw(layers=layers)
|
||||
|
||||
@@ -7,5 +7,6 @@ from invokeai.backend.lora.layers.loha_layer import LoHALayer
|
||||
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.layers.norm_layer import NormLayer
|
||||
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer
|
||||
|
||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer]
|
||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer, SetParameterLayer]
|
||||
|
||||
34
invokeai/backend/lora/layers/reshape_weight_layer.py
Normal file
34
invokeai/backend/lora/layers/reshape_weight_layer.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
|
||||
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
|
||||
|
||||
|
||||
class ReshapeWeightLayer(LoRALayerBase):
|
||||
# TODO: Just everything in this class
|
||||
def __init__(self, weight: Optional[torch.Tensor], bias: Optional[torch.Tensor], scale: Optional[torch.Tensor]):
|
||||
super().__init__(alpha=None, bias=bias)
|
||||
self.weight = torch.nn.Parameter(weight) if weight is not None else None
|
||||
self.bias = torch.nn.Parameter(bias) if bias is not None else None
|
||||
self.manual_scale = scale
|
||||
|
||||
def scale(self):
|
||||
return self.manual_scale.float() if self.manual_scale is not None else super().scale()
|
||||
|
||||
def rank(self) -> int | None:
|
||||
return None
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
return orig_weight
|
||||
|
||||
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
|
||||
super().to(device=device, dtype=dtype)
|
||||
if self.weight is not None:
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
if self.manual_scale is not None:
|
||||
self.manual_scale = self.manual_scale.to(device=device, dtype=dtype)
|
||||
|
||||
def calc_size(self) -> int:
|
||||
return super().calc_size() + calc_tensor_size(self.manual_scale)
|
||||
29
invokeai/backend/lora/layers/set_parameter_layer.py
Normal file
29
invokeai/backend/lora/layers/set_parameter_layer.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
|
||||
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
|
||||
|
||||
|
||||
class SetParameterLayer(LoRALayerBase):
|
||||
def __init__(self, param_name: str, weight: torch.Tensor):
|
||||
super().__init__(None, None)
|
||||
self.weight = weight
|
||||
self.param_name = param_name
|
||||
|
||||
def rank(self) -> int | None:
|
||||
return None
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
return self.weight - orig_weight
|
||||
|
||||
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
|
||||
return {self.param_name: self.get_weight(orig_module.get_parameter(self.param_name))}
|
||||
|
||||
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
|
||||
super().to(device=device, dtype=dtype)
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
|
||||
def calc_size(self) -> int:
|
||||
return super().calc_size() + calc_tensor_size(self.weight)
|
||||
@@ -9,6 +9,7 @@ from invokeai.backend.lora.layers.loha_layer import LoHALayer
|
||||
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.layers.norm_layer import NormLayer
|
||||
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer
|
||||
|
||||
|
||||
def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:
|
||||
|
||||
133
invokeai/backend/lora/lora_layer_wrappers.py
Normal file
133
invokeai/backend/lora/lora_layer_wrappers.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
|
||||
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
|
||||
|
||||
class LoRASidecarWrapper(torch.nn.Module):
|
||||
def __init__(self, orig_module: torch.nn.Module, lora_layers: list[AnyLoRALayer], lora_weights: list[float]):
|
||||
super().__init__()
|
||||
self._orig_module = orig_module
|
||||
self._lora_layers = lora_layers
|
||||
self._lora_weights = lora_weights
|
||||
|
||||
@property
|
||||
def orig_module(self) -> torch.nn.Module:
|
||||
return self._orig_module
|
||||
|
||||
def add_lora_layer(self, lora_layer: AnyLoRALayer, lora_weight: float):
|
||||
self._lora_layers.append(lora_layer)
|
||||
self._lora_weights.append(lora_weight)
|
||||
|
||||
@torch.no_grad()
|
||||
def _get_lora_patched_parameters(
|
||||
self, orig_params: dict[str, torch.Tensor], lora_layers: list[AnyLoRALayer], lora_weights: list[float]
|
||||
) -> dict[str, torch.Tensor]:
|
||||
params: dict[str, torch.Tensor] = {}
|
||||
for lora_layer, lora_weight in zip(lora_layers, lora_weights, strict=True):
|
||||
layer_params = lora_layer.get_parameters(self._orig_module)
|
||||
for param_name, param_weight in layer_params.items():
|
||||
if orig_params[param_name].shape != param_weight.shape:
|
||||
param_weight = param_weight.reshape(orig_params[param_name].shape)
|
||||
|
||||
if param_name not in params:
|
||||
params[param_name] = param_weight * (lora_layer.scale() * lora_weight)
|
||||
else:
|
||||
params[param_name] += param_weight * (lora_layer.scale() * lora_weight)
|
||||
|
||||
return params
|
||||
|
||||
|
||||
class LoRALinearWrapper(LoRASidecarWrapper):
|
||||
def _lora_linear_forward(self, input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
|
||||
"""An optimized implementation of the residual calculation for a Linear LoRALayer."""
|
||||
x = torch.nn.functional.linear(input, lora_layer.down)
|
||||
if lora_layer.mid is not None:
|
||||
x = torch.nn.functional.linear(x, lora_layer.mid)
|
||||
x = torch.nn.functional.linear(x, lora_layer.up, bias=lora_layer.bias)
|
||||
x *= lora_weight * lora_layer.scale()
|
||||
return x
|
||||
|
||||
def _concatenated_lora_forward(
|
||||
self, input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float
|
||||
) -> torch.Tensor:
|
||||
"""An optimized implementation of the residual calculation for a Linear ConcatenatedLoRALayer."""
|
||||
x_chunks: list[torch.Tensor] = []
|
||||
for lora_layer in concatenated_lora_layer.lora_layers:
|
||||
x_chunk = torch.nn.functional.linear(input, lora_layer.down)
|
||||
if lora_layer.mid is not None:
|
||||
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
|
||||
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
|
||||
x_chunk *= lora_weight * lora_layer.scale()
|
||||
x_chunks.append(x_chunk)
|
||||
|
||||
# TODO(ryand): Generalize to support concat_axis != 0.
|
||||
assert concatenated_lora_layer.concat_axis == 0
|
||||
x = torch.cat(x_chunks, dim=-1)
|
||||
return x
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
# Split the LoRA layers into those that have optimized implementations and those that don't.
|
||||
optimized_layer_types = (LoRALayer, ConcatenatedLoRALayer)
|
||||
optimized_layers = [
|
||||
(layer, weight)
|
||||
for layer, weight in zip(self._lora_layers, self._lora_weights, strict=True)
|
||||
if isinstance(layer, optimized_layer_types)
|
||||
]
|
||||
non_optimized_layers = [
|
||||
(layer, weight)
|
||||
for layer, weight in zip(self._lora_layers, self._lora_weights, strict=True)
|
||||
if not isinstance(layer, optimized_layer_types)
|
||||
]
|
||||
|
||||
# First, calculate the residual for LoRA layers for which there is an optimized implementation.
|
||||
residual = None
|
||||
for lora_layer, lora_weight in optimized_layers:
|
||||
if isinstance(lora_layer, LoRALayer):
|
||||
added_residual = self._lora_linear_forward(input, lora_layer, lora_weight)
|
||||
elif isinstance(lora_layer, ConcatenatedLoRALayer):
|
||||
added_residual = self._concatenated_lora_forward(input, lora_layer, lora_weight)
|
||||
else:
|
||||
raise ValueError(f"Unsupported LoRA layer type: {type(lora_layer)}")
|
||||
|
||||
if residual is None:
|
||||
residual = added_residual
|
||||
else:
|
||||
residual += added_residual
|
||||
|
||||
# Next, calculate the residuals for the LoRA layers for which there is no optimized implementation.
|
||||
if non_optimized_layers:
|
||||
unoptimized_layers, unoptimized_weights = zip(*non_optimized_layers, strict=True)
|
||||
params = self._get_lora_patched_parameters(
|
||||
orig_params={"weight": self._orig_module.weight, "bias": self._orig_module.bias},
|
||||
lora_layers=unoptimized_layers,
|
||||
lora_weights=unoptimized_weights,
|
||||
)
|
||||
added_residual = torch.nn.functional.linear(input, params["weight"], params.get("bias", None))
|
||||
if residual is None:
|
||||
residual = added_residual
|
||||
else:
|
||||
residual += added_residual
|
||||
|
||||
return self.orig_module(input) + residual
|
||||
|
||||
|
||||
class LoRAConv1dWrapper(LoRASidecarWrapper):
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
params = self._get_lora_patched_parameters(
|
||||
orig_params={"weight": self._orig_module.weight, "bias": self._orig_module.bias},
|
||||
lora_layers=self._lora_layers,
|
||||
lora_weights=self._lora_weights,
|
||||
)
|
||||
return self.orig_module(input) + torch.nn.functional.conv1d(input, params["weight"], params.get("bias", None))
|
||||
|
||||
|
||||
class LoRAConv2dWrapper(LoRASidecarWrapper):
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
params = self._get_lora_patched_parameters(
|
||||
orig_params={"weight": self._orig_module.weight, "bias": self._orig_module.bias},
|
||||
lora_layers=self._lora_layers,
|
||||
lora_weights=self._lora_weights,
|
||||
)
|
||||
return self.orig_module(input) + torch.nn.functional.conv2d(input, params["weight"], params.get("bias", None))
|
||||
@@ -4,19 +4,126 @@ from typing import Dict, Iterable, Optional, Tuple
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
|
||||
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import (
|
||||
ConcatenatedLoRALinearSidecarLayer,
|
||||
from invokeai.backend.lora.lora_layer_wrappers import (
|
||||
LoRAConv1dWrapper,
|
||||
LoRAConv2dWrapper,
|
||||
LoRALinearWrapper,
|
||||
LoRASidecarWrapper,
|
||||
)
|
||||
from invokeai.backend.lora.sidecar_layers.lora.lora_linear_sidecar_layer import LoRALinearSidecarLayer
|
||||
from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
|
||||
|
||||
|
||||
class LoRAPatcher:
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
@contextmanager
|
||||
def apply_smart_lora_patches(
|
||||
model: torch.nn.Module,
|
||||
patches: Iterable[Tuple[LoRAModelRaw, float]],
|
||||
prefix: str,
|
||||
dtype: torch.dtype,
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
):
|
||||
"""Apply 'smart' LoRA patching that chooses whether to use direct patching or a sidecar wrapper for each module."""
|
||||
|
||||
# original_weights are stored for unpatching layers that are directly patched.
|
||||
original_weights = OriginalWeightsStorage(cached_weights)
|
||||
# original_modules are stored for unpatching layers that are wrapped in a LoRASidecarWrapper.
|
||||
original_modules: dict[str, torch.nn.Module] = {}
|
||||
try:
|
||||
for patch, patch_weight in patches:
|
||||
LoRAPatcher._apply_smart_lora_patch(
|
||||
model=model,
|
||||
prefix=prefix,
|
||||
patch=patch,
|
||||
patch_weight=patch_weight,
|
||||
original_weights=original_weights,
|
||||
original_modules=original_modules,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
yield
|
||||
finally:
|
||||
# Restore directly patched layers.
|
||||
for param_key, weight in original_weights.get_changed_weights():
|
||||
model.get_parameter(param_key).copy_(weight)
|
||||
|
||||
# Restore LoRASidecarWrapper modules.
|
||||
# Note: This logic assumes no nested modules in original_modules.
|
||||
for module_key, orig_module in original_modules.items():
|
||||
module_parent_key, module_name = LoRAPatcher._split_parent_key(module_key)
|
||||
parent_module = model.get_submodule(module_parent_key)
|
||||
LoRAPatcher._set_submodule(parent_module, module_name, orig_module)
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def _apply_smart_lora_patch(
|
||||
model: torch.nn.Module,
|
||||
prefix: str,
|
||||
patch: LoRAModelRaw,
|
||||
patch_weight: float,
|
||||
original_weights: OriginalWeightsStorage,
|
||||
original_modules: dict[str, torch.nn.Module],
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""Apply a single LoRA patch to a model using the 'smart' patching strategy that chooses whether to use direct
|
||||
patching or a sidecar wrapper for each module.
|
||||
"""
|
||||
if patch_weight == 0:
|
||||
return
|
||||
|
||||
# If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
|
||||
# submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
|
||||
# replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
|
||||
# without searching, but some legacy code still uses flattened keys.
|
||||
layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
|
||||
|
||||
prefix_len = len(prefix)
|
||||
|
||||
for layer_key, layer in patch.layers.items():
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
|
||||
module_key, module = LoRAPatcher._get_submodule(
|
||||
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
|
||||
)
|
||||
|
||||
# Decide whether to use direct patching or a sidecar wrapper.
|
||||
# Direct patching is preferred, because it results in better runtime speed.
|
||||
# Reasons to use sidecar patching:
|
||||
# - The module is already wrapped in a LoRASidecarWrapper.
|
||||
# - The module is quantized.
|
||||
# - The module is on the CPU (and we don't want to store a second full copy of the original weights on the
|
||||
# CPU, since this would double the RAM usage)
|
||||
# NOTE: For now, we don't check if the layer is quantized here. We assume that this is checked in the caller
|
||||
# and that the caller will use the 'apply_lora_wrapper_patches' method if the layer is quantized.
|
||||
# TODO(ryand): Handle the case where we are running without a GPU. Should we set a config flag that allows
|
||||
# forcing full patching even on the CPU?
|
||||
if isinstance(module, LoRASidecarWrapper) or LoRAPatcher._is_any_part_of_layer_on_cpu(module):
|
||||
LoRAPatcher._apply_lora_layer_wrapper_patch(
|
||||
model=model,
|
||||
module_to_patch=module,
|
||||
module_to_patch_key=module_key,
|
||||
patch=layer,
|
||||
patch_weight=patch_weight,
|
||||
original_modules=original_modules,
|
||||
dtype=dtype,
|
||||
)
|
||||
else:
|
||||
LoRAPatcher._apply_lora_layer_patch(
|
||||
module_to_patch=module,
|
||||
module_to_patch_key=module_key,
|
||||
patch=layer,
|
||||
patch_weight=patch_weight,
|
||||
original_weights=original_weights,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_any_part_of_layer_on_cpu(layer: torch.nn.Module) -> bool:
|
||||
return any(p.device.type == "cpu" for p in layer.parameters())
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
@contextmanager
|
||||
@@ -40,7 +147,7 @@ class LoRAPatcher:
|
||||
original_weights = OriginalWeightsStorage(cached_weights)
|
||||
try:
|
||||
for patch, patch_weight in patches:
|
||||
LoRAPatcher.apply_lora_patch(
|
||||
LoRAPatcher._apply_lora_patch(
|
||||
model=model,
|
||||
prefix=prefix,
|
||||
patch=patch,
|
||||
@@ -52,11 +159,12 @@ class LoRAPatcher:
|
||||
yield
|
||||
finally:
|
||||
for param_key, weight in original_weights.get_changed_weights():
|
||||
model.get_parameter(param_key).copy_(weight)
|
||||
cur_param = model.get_parameter(param_key)
|
||||
cur_param.data = weight.to(dtype=cur_param.dtype, device=cur_param.device, copy=True)
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def apply_lora_patch(
|
||||
def _apply_lora_patch(
|
||||
model: torch.nn.Module,
|
||||
prefix: str,
|
||||
patch: LoRAModelRaw,
|
||||
@@ -91,48 +199,84 @@ class LoRAPatcher:
|
||||
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
|
||||
)
|
||||
|
||||
# All of the LoRA weight calculations will be done on the same device as the module weight.
|
||||
# (Performance will be best if this is a CUDA device.)
|
||||
device = module.weight.device
|
||||
dtype = module.weight.dtype
|
||||
LoRAPatcher._apply_lora_layer_patch(
|
||||
module_to_patch=module,
|
||||
module_to_patch_key=module_key,
|
||||
patch=layer,
|
||||
patch_weight=patch_weight,
|
||||
original_weights=original_weights,
|
||||
)
|
||||
|
||||
layer_scale = layer.scale()
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def _apply_lora_layer_patch(
|
||||
module_to_patch: torch.nn.Module,
|
||||
module_to_patch_key: str,
|
||||
patch: AnyLoRALayer,
|
||||
patch_weight: float,
|
||||
original_weights: OriginalWeightsStorage,
|
||||
):
|
||||
# All of the LoRA weight calculations will be done on the same device as the module weight.
|
||||
# (Performance will be best if this is a CUDA device.)
|
||||
first_param = next(module_to_patch.parameters())
|
||||
device = first_param.device
|
||||
dtype = first_param.dtype
|
||||
|
||||
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
||||
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
||||
# same thing in a single call to '.to(...)'.
|
||||
layer.to(device=device)
|
||||
layer.to(dtype=torch.float32)
|
||||
layer_scale = patch.scale()
|
||||
|
||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
||||
for param_name, lora_param_weight in layer.get_parameters(module).items():
|
||||
param_key = module_key + "." + param_name
|
||||
module_param = module.get_parameter(param_name)
|
||||
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
||||
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
||||
# same thing in a single call to '.to(...)'.
|
||||
patch.to(device=device)
|
||||
patch.to(dtype=torch.float32)
|
||||
|
||||
# Save original weight
|
||||
original_weights.save(param_key, module_param)
|
||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
||||
for param_name, lora_param_weight in patch.get_parameters(module_to_patch).items():
|
||||
param_key = module_to_patch_key + "." + param_name
|
||||
module_param = module_to_patch.get_parameter(param_name)
|
||||
|
||||
if module_param.shape != lora_param_weight.shape:
|
||||
# Save original weight
|
||||
original_weights.save(param_key, module_param)
|
||||
|
||||
if module_param.shape != lora_param_weight.shape:
|
||||
if module_param.nelement() == lora_param_weight.nelement():
|
||||
lora_param_weight = lora_param_weight.reshape(module_param.shape)
|
||||
else:
|
||||
# This condition was added to handle layers in FLUX control LoRAs.
|
||||
# TODO(ryand): Move the weight update into the LoRA layer so that the LoRAPatcher doesn't need
|
||||
# to worry about this?
|
||||
expanded_weight = torch.zeros_like(
|
||||
lora_param_weight, dtype=module_param.dtype, device=module_param.device
|
||||
)
|
||||
slices = tuple(slice(0, dim) for dim in module_param.shape)
|
||||
expanded_weight[slices] = module_param
|
||||
setattr(
|
||||
module,
|
||||
param_name,
|
||||
torch.nn.Parameter(expanded_weight, requires_grad=module_param.requires_grad),
|
||||
)
|
||||
module_param = expanded_weight
|
||||
|
||||
lora_param_weight *= patch_weight * layer_scale
|
||||
module_param += lora_param_weight.to(dtype=dtype)
|
||||
lora_param_weight *= patch_weight * layer_scale
|
||||
module_param += lora_param_weight.to(dtype=dtype)
|
||||
|
||||
layer.to(device=TorchDevice.CPU_DEVICE)
|
||||
patch.to(device=TorchDevice.CPU_DEVICE)
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
@contextmanager
|
||||
def apply_lora_sidecar_patches(
|
||||
def apply_lora_wrapper_patches(
|
||||
model: torch.nn.Module,
|
||||
patches: Iterable[Tuple[LoRAModelRaw, float]],
|
||||
prefix: str,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""Apply one or more LoRA sidecar patches to a model within a context manager. Sidecar patches incur some
|
||||
overhead compared to normal LoRA patching, but they allow for LoRA layers to applied to base layers in any
|
||||
quantization format.
|
||||
"""Apply one or more LoRA wrapper patches to a model within a context manager. Wrapper patches incur some
|
||||
runtime overhead compared to normal LoRA patching, but they enable:
|
||||
- LoRA layers to be applied to quantized models
|
||||
- LoRA layers to be applied to CPU layers without needing to store a full copy of the original weights (i.e.
|
||||
avoid doubling the memory requirements).
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The model to patch.
|
||||
@@ -140,14 +284,11 @@ class LoRAPatcher:
|
||||
associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory
|
||||
all at once.
|
||||
prefix (str): The keys in the patches will be filtered to only include weights with this prefix.
|
||||
dtype (torch.dtype): The compute dtype of the sidecar layers. This cannot easily be inferred from the model,
|
||||
since the sidecar layers are typically applied on top of quantized layers whose weight dtype is
|
||||
different from their compute dtype.
|
||||
"""
|
||||
original_modules: dict[str, torch.nn.Module] = {}
|
||||
try:
|
||||
for patch, patch_weight in patches:
|
||||
LoRAPatcher._apply_lora_sidecar_patch(
|
||||
LoRAPatcher._apply_lora_wrapper_patch(
|
||||
model=model,
|
||||
prefix=prefix,
|
||||
patch=patch,
|
||||
@@ -165,7 +306,7 @@ class LoRAPatcher:
|
||||
LoRAPatcher._set_submodule(parent_module, module_name, orig_module)
|
||||
|
||||
@staticmethod
|
||||
def _apply_lora_sidecar_patch(
|
||||
def _apply_lora_wrapper_patch(
|
||||
model: torch.nn.Module,
|
||||
patch: LoRAModelRaw,
|
||||
patch_weight: float,
|
||||
@@ -173,7 +314,7 @@ class LoRAPatcher:
|
||||
original_modules: dict[str, torch.nn.Module],
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""Apply a single LoRA sidecar patch to a model."""
|
||||
"""Apply a single LoRA wrapper patch to a model."""
|
||||
|
||||
if patch_weight == 0:
|
||||
return
|
||||
@@ -194,28 +335,47 @@ class LoRAPatcher:
|
||||
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
|
||||
)
|
||||
|
||||
# Initialize the LoRA sidecar layer.
|
||||
lora_sidecar_layer = LoRAPatcher._initialize_lora_sidecar_layer(module, layer, patch_weight)
|
||||
LoRAPatcher._apply_lora_layer_wrapper_patch(
|
||||
model=model,
|
||||
module_to_patch=module,
|
||||
module_to_patch_key=module_key,
|
||||
patch=layer,
|
||||
patch_weight=patch_weight,
|
||||
original_modules=original_modules,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Replace the original module with a LoRASidecarModule if it has not already been done.
|
||||
if module_key in original_modules:
|
||||
# The module has already been patched with a LoRASidecarModule. Append to it.
|
||||
assert isinstance(module, LoRASidecarModule)
|
||||
lora_sidecar_module = module
|
||||
else:
|
||||
# The module has not yet been patched with a LoRASidecarModule. Create one.
|
||||
lora_sidecar_module = LoRASidecarModule(module, [])
|
||||
original_modules[module_key] = module
|
||||
module_parent_key, module_name = LoRAPatcher._split_parent_key(module_key)
|
||||
module_parent = model.get_submodule(module_parent_key)
|
||||
LoRAPatcher._set_submodule(module_parent, module_name, lora_sidecar_module)
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def _apply_lora_layer_wrapper_patch(
|
||||
model: torch.nn.Module,
|
||||
module_to_patch: torch.nn.Module,
|
||||
module_to_patch_key: str,
|
||||
patch: AnyLoRALayer,
|
||||
patch_weight: float,
|
||||
original_modules: dict[str, torch.nn.Module],
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""Apply a single LoRA wrapper patch to a model."""
|
||||
|
||||
# Move the LoRA sidecar layer to the same device/dtype as the orig module.
|
||||
# TODO(ryand): Experiment with moving to the device first, then casting. This could be faster.
|
||||
lora_sidecar_layer.to(device=lora_sidecar_module.orig_module.weight.device, dtype=dtype)
|
||||
# Replace the original module with a LoRASidecarWrapper if it has not already been done.
|
||||
if not isinstance(module_to_patch, LoRASidecarWrapper):
|
||||
lora_wrapper_layer = LoRAPatcher._initialize_lora_wrapper_layer(module_to_patch)
|
||||
original_modules[module_to_patch_key] = module_to_patch
|
||||
module_parent_key, module_name = LoRAPatcher._split_parent_key(module_to_patch_key)
|
||||
module_parent = model.get_submodule(module_parent_key)
|
||||
LoRAPatcher._set_submodule(module_parent, module_name, lora_wrapper_layer)
|
||||
orig_module = module_to_patch
|
||||
else:
|
||||
assert module_to_patch_key in original_modules
|
||||
lora_wrapper_layer = module_to_patch
|
||||
orig_module = module_to_patch.orig_module
|
||||
|
||||
# Add the LoRA sidecar layer to the LoRASidecarModule.
|
||||
lora_sidecar_module.add_lora_layer(lora_sidecar_layer)
|
||||
# Move the LoRA layer to the same device/dtype as the orig module.
|
||||
patch.to(device=orig_module.weight.device, dtype=dtype)
|
||||
|
||||
# Add the LoRA wrapper layer to the LoRASidecarWrapper.
|
||||
lora_wrapper_layer.add_lora_layer(patch, patch_weight)
|
||||
|
||||
@staticmethod
|
||||
def _split_parent_key(module_key: str) -> tuple[str, str]:
|
||||
@@ -236,17 +396,13 @@ class LoRAPatcher:
|
||||
raise ValueError(f"Invalid module key: {module_key}")
|
||||
|
||||
@staticmethod
|
||||
def _initialize_lora_sidecar_layer(orig_layer: torch.nn.Module, lora_layer: AnyLoRALayer, patch_weight: float):
|
||||
# TODO(ryand): Add support for more original layer types and LoRA layer types.
|
||||
if isinstance(orig_layer, torch.nn.Linear) or (
|
||||
isinstance(orig_layer, LoRASidecarModule) and isinstance(orig_layer.orig_module, torch.nn.Linear)
|
||||
):
|
||||
if isinstance(lora_layer, LoRALayer):
|
||||
return LoRALinearSidecarLayer(lora_layer=lora_layer, weight=patch_weight)
|
||||
elif isinstance(lora_layer, ConcatenatedLoRALayer):
|
||||
return ConcatenatedLoRALinearSidecarLayer(concatenated_lora_layer=lora_layer, weight=patch_weight)
|
||||
else:
|
||||
raise ValueError(f"Unsupported Linear LoRA layer type: {type(lora_layer)}")
|
||||
def _initialize_lora_wrapper_layer(orig_layer: torch.nn.Module):
|
||||
if isinstance(orig_layer, torch.nn.Linear):
|
||||
return LoRALinearWrapper(orig_layer, [], [])
|
||||
elif isinstance(orig_layer, torch.nn.Conv1d):
|
||||
return LoRAConv1dWrapper(orig_layer, [], [])
|
||||
elif isinstance(orig_layer, torch.nn.Conv2d):
|
||||
return LoRAConv2dWrapper(orig_layer, [], [])
|
||||
else:
|
||||
raise ValueError(f"Unsupported layer type: {type(orig_layer)}")
|
||||
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
|
||||
|
||||
|
||||
class ConcatenatedLoRALinearSidecarLayer(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
concatenated_lora_layer: ConcatenatedLoRALayer,
|
||||
weight: float,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self._concatenated_lora_layer = concatenated_lora_layer
|
||||
self._weight = weight
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
x_chunks: list[torch.Tensor] = []
|
||||
for lora_layer in self._concatenated_lora_layer.lora_layers:
|
||||
x_chunk = torch.nn.functional.linear(input, lora_layer.down)
|
||||
if lora_layer.mid is not None:
|
||||
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
|
||||
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
|
||||
x_chunk *= self._weight * lora_layer.scale()
|
||||
x_chunks.append(x_chunk)
|
||||
|
||||
# TODO(ryand): Generalize to support concat_axis != 0.
|
||||
assert self._concatenated_lora_layer.concat_axis == 0
|
||||
x = torch.cat(x_chunks, dim=-1)
|
||||
return x
|
||||
|
||||
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
|
||||
self._concatenated_lora_layer.to(device=device, dtype=dtype)
|
||||
return self
|
||||
@@ -1,27 +0,0 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
|
||||
|
||||
class LoRALinearSidecarLayer(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
lora_layer: LoRALayer,
|
||||
weight: float,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self._lora_layer = lora_layer
|
||||
self._weight = weight
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = torch.nn.functional.linear(x, self._lora_layer.down)
|
||||
if self._lora_layer.mid is not None:
|
||||
x = torch.nn.functional.linear(x, self._lora_layer.mid)
|
||||
x = torch.nn.functional.linear(x, self._lora_layer.up, bias=self._lora_layer.bias)
|
||||
x *= self._weight * self._lora_layer.scale()
|
||||
return x
|
||||
|
||||
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
|
||||
self._lora_layer.to(device=device, dtype=dtype)
|
||||
return self
|
||||
@@ -1,24 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
class LoRASidecarModule(torch.nn.Module):
|
||||
"""A LoRA sidecar module that wraps an original module and adds LoRA layers to it."""
|
||||
|
||||
def __init__(self, orig_module: torch.nn.Module, lora_layers: list[torch.nn.Module]):
|
||||
super().__init__()
|
||||
self.orig_module = orig_module
|
||||
self._lora_layers = lora_layers
|
||||
|
||||
def add_lora_layer(self, lora_layer: torch.nn.Module):
|
||||
self._lora_layers.append(lora_layer)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
x = self.orig_module(input)
|
||||
for lora_layer in self._lora_layers:
|
||||
x += lora_layer(input)
|
||||
return x
|
||||
|
||||
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
|
||||
self._orig_module.to(device=device, dtype=dtype)
|
||||
for lora_layer in self._lora_layers:
|
||||
lora_layer.to(device=device, dtype=dtype)
|
||||
@@ -67,6 +67,7 @@ class ModelType(str, Enum):
|
||||
Main = "main"
|
||||
VAE = "vae"
|
||||
LoRA = "lora"
|
||||
StructuralLoRa = "structural_lora"
|
||||
ControlNet = "controlnet" # used by model_probe
|
||||
TextualInversion = "embedding"
|
||||
IPAdapter = "ip_adapter"
|
||||
@@ -273,6 +274,18 @@ class LoRALyCORISConfig(LoRAConfigBase):
|
||||
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
|
||||
|
||||
|
||||
class StructuralLoRALyCORISConfig(ModelConfigBase):
|
||||
"""Model config for Structural LoRA/Lycoris models."""
|
||||
|
||||
type: Literal[ModelType.StructuralLoRa] = ModelType.StructuralLoRa
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.StructuralLoRa.value}.{ModelFormat.LyCORIS.value}")
|
||||
|
||||
|
||||
class LoRADiffusersConfig(LoRAConfigBase):
|
||||
"""Model config for LoRA/Diffusers models."""
|
||||
|
||||
@@ -535,6 +548,7 @@ AnyModelConfig = Annotated[
|
||||
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
|
||||
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
|
||||
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
|
||||
Annotated[StructuralLoRALyCORISConfig, StructuralLoRALyCORISConfig.get_tag()],
|
||||
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
|
||||
Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
|
||||
Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],
|
||||
|
||||
@@ -13,8 +13,9 @@ from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils impo
|
||||
lora_model_from_flux_diffusers_state_dict,
|
||||
)
|
||||
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
|
||||
lora_model_from_flux_kohya_state_dict,
|
||||
is_state_dict_likely_in_flux_kohya_format, lora_model_from_flux_kohya_state_dict,
|
||||
)
|
||||
from invokeai.backend.lora.conversions.flux_control_lora_utils import is_state_dict_likely_flux_control, lora_model_from_flux_control_state_dict
|
||||
from invokeai.backend.lora.conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
|
||||
from invokeai.backend.lora.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
|
||||
from invokeai.backend.model_manager import (
|
||||
@@ -32,6 +33,7 @@ from invokeai.backend.model_manager.load.model_loader_registry import ModelLoade
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.StructuralLoRa, format=ModelFormat.LyCORIS)
|
||||
class LoRALoader(ModelLoader):
|
||||
"""Class to load LoRA models."""
|
||||
|
||||
@@ -75,7 +77,10 @@ class LoRALoader(ModelLoader):
|
||||
# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_flux.py#L1194
|
||||
model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None)
|
||||
elif config.format == ModelFormat.LyCORIS:
|
||||
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
|
||||
if is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict):
|
||||
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
|
||||
elif is_state_dict_likely_flux_control(state_dict=state_dict):
|
||||
model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
|
||||
else:
|
||||
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
|
||||
elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
|
||||
|
||||
@@ -18,6 +18,7 @@ from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlab
|
||||
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_diffusers_format,
|
||||
)
|
||||
from invokeai.backend.lora.conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
|
||||
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
|
||||
from invokeai.backend.model_manager.config import (
|
||||
@@ -258,6 +259,18 @@ class ModelProbe(object):
|
||||
ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
|
||||
ckpt = ckpt.get("state_dict", ckpt)
|
||||
|
||||
if isinstance(ckpt, dict) and "img_in.lora_A.weight" in ckpt and "img_in.lora_B.weight" in ckpt:
|
||||
tensor_a, tensor_b = ckpt["img_in.lora_A.weight"], ckpt["img_in.lora_B.weight"]
|
||||
if (
|
||||
tensor_a is not None
|
||||
and isinstance(tensor_a, torch.Tensor)
|
||||
and tensor_a.shape[1] == 128
|
||||
and tensor_b is not None
|
||||
and isinstance(tensor_b, torch.Tensor)
|
||||
and tensor_b.shape[0] == 3072
|
||||
):
|
||||
return ModelType.StructuralLoRa
|
||||
|
||||
for key in [str(k) for k in ckpt.keys()]:
|
||||
if key.startswith(
|
||||
(
|
||||
@@ -624,8 +637,10 @@ class LoRACheckpointProbe(CheckpointProbeBase):
|
||||
return ModelFormat.LyCORIS
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
if is_state_dict_likely_in_flux_kohya_format(self.checkpoint) or is_state_dict_likely_in_flux_diffusers_format(
|
||||
self.checkpoint
|
||||
if (
|
||||
is_state_dict_likely_in_flux_kohya_format(self.checkpoint)
|
||||
or is_state_dict_likely_in_flux_diffusers_format(self.checkpoint)
|
||||
or is_state_dict_likely_flux_control(self.checkpoint)
|
||||
):
|
||||
return BaseModelType.Flux
|
||||
|
||||
@@ -1046,6 +1061,7 @@ ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelI
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.StructuralLoRa, LoRACheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
|
||||
|
||||
@@ -809,6 +809,7 @@
|
||||
"starterBundleHelpText": "Easily install all models needed to get started with a base model, including a main model, controlnets, IP adapters, and more. Selecting a bundle will skip any models that you already have installed.",
|
||||
"starterModels": "Starter Models",
|
||||
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
|
||||
"structuralLora": "Structural LoRA",
|
||||
"syncModels": "Sync Models",
|
||||
"textualInversions": "Textual Inversions",
|
||||
"triggerPhrases": "Trigger Phrases",
|
||||
|
||||
@@ -24,6 +24,7 @@ import type {
|
||||
ParameterSeed,
|
||||
ParameterSteps,
|
||||
ParameterStrength,
|
||||
ParameterStructuralLoRAModel,
|
||||
ParameterT5EncoderModel,
|
||||
ParameterVAEModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
@@ -75,6 +76,7 @@ export type ParamsState = {
|
||||
clipEmbedModel: ParameterCLIPEmbedModel | null;
|
||||
clipLEmbedModel: ParameterCLIPLEmbedModel | null;
|
||||
clipGEmbedModel: ParameterCLIPGEmbedModel | null;
|
||||
structuralLora: ParameterStructuralLoRAModel | null;
|
||||
};
|
||||
|
||||
const initialState: ParamsState = {
|
||||
@@ -121,6 +123,7 @@ const initialState: ParamsState = {
|
||||
clipEmbedModel: null,
|
||||
clipLEmbedModel: null,
|
||||
clipGEmbedModel: null,
|
||||
structuralLora: null,
|
||||
};
|
||||
|
||||
export const paramsSlice = createSlice({
|
||||
@@ -195,6 +198,9 @@ export const paramsSlice = createSlice({
|
||||
t5EncoderModelSelected: (state, action: PayloadAction<ParameterT5EncoderModel | null>) => {
|
||||
state.t5EncoderModel = action.payload;
|
||||
},
|
||||
structuralLoRAModelSelected: (state, action: PayloadAction<ParameterStructuralLoRAModel | null>) => {
|
||||
state.structuralLora = action.payload;
|
||||
},
|
||||
clipEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPEmbedModel | null>) => {
|
||||
state.clipEmbedModel = action.payload;
|
||||
},
|
||||
|
||||
@@ -46,6 +46,7 @@ import type {
|
||||
ParameterSeed,
|
||||
ParameterSteps,
|
||||
ParameterStrength,
|
||||
ParameterStructuralLoRAModel,
|
||||
ParameterVAEModel,
|
||||
ParameterWidth,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
@@ -80,6 +81,7 @@ import {
|
||||
isLoRAModelConfig,
|
||||
isNonRefinerMainModelConfig,
|
||||
isRefinerMainModelModelConfig,
|
||||
isStructuralLoRAModelConfig,
|
||||
isT2IAdapterModelConfig,
|
||||
isVAEModelConfig,
|
||||
} from 'services/api/types';
|
||||
@@ -226,6 +228,14 @@ const parseVAEModel: MetadataParseFunc<ParameterVAEModel> = async (metadata) =>
|
||||
return modelIdentifier;
|
||||
};
|
||||
|
||||
const parseStructuralLoRAModel: MetadataParseFunc<ParameterStructuralLoRAModel> = async (metadata) => {
|
||||
const slora = await getProperty(metadata, 'structural_lora', undefined);
|
||||
const key = await getModelKey(slora, 'structural_lora');
|
||||
const sloraModelConfig = await fetchModelConfigWithTypeGuard(key, isStructuralLoRAModelConfig);
|
||||
const modelIdentifier = zModelIdentifierField.parse(sloraModelConfig);
|
||||
return modelIdentifier;
|
||||
};
|
||||
|
||||
const parseLoRA: MetadataParseFunc<LoRA> = async (metadataItem) => {
|
||||
// Previously, the LoRA model identifier parts were stored in the LoRA metadata: `{key: ..., weight: 0.75}`
|
||||
const modelV1 = await getProperty(metadataItem, 'lora', undefined);
|
||||
@@ -671,6 +681,7 @@ export const parsers = {
|
||||
mainModel: parseMainModel,
|
||||
refinerModel: parseRefinerModel,
|
||||
vaeModel: parseVAEModel,
|
||||
structuralLora: parseStructuralLoRAModel,
|
||||
lora: parseLoRA,
|
||||
loras: parseAllLoRAs,
|
||||
controlNet: parseControlNet,
|
||||
|
||||
@@ -18,6 +18,7 @@ import {
|
||||
useMainModels,
|
||||
useRefinerModels,
|
||||
useSpandrelImageToImageModels,
|
||||
useStructuralLoRAModel,
|
||||
useT2IAdapterModels,
|
||||
useT5EncoderModels,
|
||||
useVAEModels,
|
||||
@@ -92,6 +93,12 @@ const ModelList = () => {
|
||||
[t5EncoderModels, searchTerm, filteredModelType]
|
||||
);
|
||||
|
||||
const [structuralLoRAModels, { isLoading: isLoadingStructuralLoRAModels }] = useStructuralLoRAModel();
|
||||
const filteredStructuralLoRAModels = useMemo(
|
||||
() => modelsFilter(structuralLoRAModels, searchTerm, filteredModelType),
|
||||
[structuralLoRAModels, searchTerm, filteredModelType]
|
||||
);
|
||||
|
||||
const [clipEmbedModels, { isLoading: isLoadingClipEmbedModels }] = useCLIPEmbedModels({ excludeSubmodels: true });
|
||||
const filteredClipEmbedModels = useMemo(
|
||||
() => modelsFilter(clipEmbedModels, searchTerm, filteredModelType),
|
||||
@@ -118,7 +125,8 @@ const ModelList = () => {
|
||||
filteredVAEModels.length +
|
||||
filteredSpandrelImageToImageModels.length +
|
||||
t5EncoderModels.length +
|
||||
clipEmbedModels.length
|
||||
clipEmbedModels.length +
|
||||
structuralLoRAModels.length
|
||||
);
|
||||
}, [
|
||||
filteredControlNetModels.length,
|
||||
@@ -133,6 +141,7 @@ const ModelList = () => {
|
||||
filteredSpandrelImageToImageModels.length,
|
||||
t5EncoderModels.length,
|
||||
clipEmbedModels.length,
|
||||
structuralLoRAModels.length,
|
||||
]);
|
||||
|
||||
return (
|
||||
@@ -195,6 +204,15 @@ const ModelList = () => {
|
||||
{!isLoadingT5EncoderModels && filteredT5EncoderModels.length > 0 && (
|
||||
<ModelListWrapper title={t('modelManager.t5Encoder')} modelList={filteredT5EncoderModels} key="t5-encoder" />
|
||||
)}
|
||||
{/* Structural Lora List */}
|
||||
{isLoadingStructuralLoRAModels && <FetchingModelsLoader loadingMessage="Loading Structural Loras..." />}
|
||||
{!isLoadingStructuralLoRAModels && filteredStructuralLoRAModels.length > 0 && (
|
||||
<ModelListWrapper
|
||||
title={t('modelManager.structuralLora')}
|
||||
modelList={filteredStructuralLoRAModels}
|
||||
key="structural-lora"
|
||||
/>
|
||||
)}
|
||||
{/* Clip Embed List */}
|
||||
{isLoadingClipEmbedModels && <FetchingModelsLoader loadingMessage="Loading Clip Embed Models..." />}
|
||||
{!isLoadingClipEmbedModels && filteredClipEmbedModels.length > 0 && (
|
||||
|
||||
@@ -24,6 +24,7 @@ export const ModelTypeFilter = memo(() => {
|
||||
ip_adapter: t('common.ipAdapter'),
|
||||
clip_vision: 'CLIP Vision',
|
||||
spandrel_image_to_image: t('modelManager.spandrelImageToImage'),
|
||||
structural_lora: t('modelManager.structuralLora'),
|
||||
}),
|
||||
[t]
|
||||
);
|
||||
|
||||
@@ -51,6 +51,8 @@ import {
|
||||
isSpandrelImageToImageModelFieldInputTemplate,
|
||||
isStringFieldInputInstance,
|
||||
isStringFieldInputTemplate,
|
||||
isStructuralLoRAModelFieldInputInstance,
|
||||
isStructuralLoRAModelFieldInputTemplate,
|
||||
isT2IAdapterModelFieldInputInstance,
|
||||
isT2IAdapterModelFieldInputTemplate,
|
||||
isT5EncoderModelFieldInputInstance,
|
||||
@@ -81,6 +83,7 @@ import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComp
|
||||
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
|
||||
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
|
||||
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
|
||||
import StructuralLoRAModelFieldInputComponent from './inputs/StructuralLoraModelFieldInputComponent';
|
||||
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
|
||||
import T5EncoderModelFieldInputComponent from './inputs/T5EncoderModelFieldInputComponent';
|
||||
import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
|
||||
@@ -156,6 +159,15 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
return <CLIPGEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (
|
||||
isStructuralLoRAModelFieldInputInstance(fieldInstance) &&
|
||||
isStructuralLoRAModelFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<StructuralLoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />
|
||||
);
|
||||
}
|
||||
|
||||
if (isFluxVAEModelFieldInputInstance(fieldInstance) && isFluxVAEModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <FluxVAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldStructuralLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type {
|
||||
StructuralLoRAModelFieldInputInstance,
|
||||
StructuralLoRAModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useStructuralLoRAModel } from 'services/api/hooks/modelsByType';
|
||||
import { isStructuralLoRAModelConfig, type StructuralLoRAModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<StructuralLoRAModelFieldInputInstance, StructuralLoRAModelFieldInputTemplate>;
|
||||
|
||||
const StructuralLoRAModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const { t } = useTranslation();
|
||||
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useStructuralLoRAModel();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(value: StructuralLoRAModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldStructuralLoRAModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs: modelConfigs.filter((config) => isStructuralLoRAModelConfig(config)),
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
const required = props.fieldTemplate.required;
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
|
||||
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value && required}>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(StructuralLoRAModelFieldInputComponent);
|
||||
@@ -28,6 +28,7 @@ import type {
|
||||
SpandrelImageToImageModelFieldValue,
|
||||
StatefulFieldValue,
|
||||
StringFieldValue,
|
||||
StructuralLoRAModelFieldValue,
|
||||
T2IAdapterModelFieldValue,
|
||||
T5EncoderModelFieldValue,
|
||||
VAEModelFieldValue,
|
||||
@@ -55,6 +56,7 @@ import {
|
||||
zSpandrelImageToImageModelFieldValue,
|
||||
zStatefulFieldValue,
|
||||
zStringFieldValue,
|
||||
zStructuralLoRAModelFieldValue,
|
||||
zT2IAdapterModelFieldValue,
|
||||
zT5EncoderModelFieldValue,
|
||||
zVAEModelFieldValue,
|
||||
@@ -369,6 +371,9 @@ export const nodesSlice = createSlice({
|
||||
fieldCLIPGEmbedValueChanged: (state, action: FieldValueAction<CLIPGEmbedModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zCLIPGEmbedModelFieldValue);
|
||||
},
|
||||
fieldStructuralLoRAModelValueChanged: (state, action: FieldValueAction<StructuralLoRAModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zStructuralLoRAModelFieldValue);
|
||||
},
|
||||
fieldFluxVAEModelValueChanged: (state, action: FieldValueAction<FluxVAEModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zFluxVAEModelFieldValue);
|
||||
},
|
||||
@@ -438,6 +443,7 @@ export const {
|
||||
fieldCLIPEmbedValueChanged,
|
||||
fieldCLIPLEmbedValueChanged,
|
||||
fieldCLIPGEmbedValueChanged,
|
||||
fieldStructuralLoRAModelValueChanged,
|
||||
fieldFluxVAEModelValueChanged,
|
||||
nodeEditorReset,
|
||||
nodeIsIntermediateChanged,
|
||||
|
||||
@@ -69,6 +69,7 @@ const zModelType = z.enum([
|
||||
'main',
|
||||
'vae',
|
||||
'lora',
|
||||
'structural_lora',
|
||||
'controlnet',
|
||||
't2i_adapter',
|
||||
'ip_adapter',
|
||||
|
||||
@@ -178,6 +178,10 @@ const zCLIPGEmbedModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('CLIPGEmbedModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zStructuralLoRAModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('StructuralLoRAModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zFluxVAEModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('FluxVAEModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
@@ -210,6 +214,7 @@ const zStatefulFieldType = z.union([
|
||||
zCLIPEmbedModelFieldType,
|
||||
zCLIPLEmbedModelFieldType,
|
||||
zCLIPGEmbedModelFieldType,
|
||||
zStructuralLoRAModelFieldType,
|
||||
zFluxVAEModelFieldType,
|
||||
zColorFieldType,
|
||||
zSchedulerFieldType,
|
||||
@@ -864,6 +869,29 @@ export const isCLIPGEmbedModelFieldInputTemplate = (val: unknown): val is CLIPGE
|
||||
|
||||
// #endregion
|
||||
|
||||
// #region StructuralLoRAModelField
|
||||
|
||||
export const zStructuralLoRAModelFieldValue = zModelIdentifierField.optional();
|
||||
const zStructuralLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zStructuralLoRAModelFieldValue,
|
||||
});
|
||||
const zStructuralLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zStructuralLoRAModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zStructuralLoRAModelFieldValue,
|
||||
});
|
||||
|
||||
export type StructuralLoRAModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldValue>;
|
||||
|
||||
export type StructuralLoRAModelFieldInputInstance = z.infer<typeof zStructuralLoRAModelFieldInputInstance>;
|
||||
export type StructuralLoRAModelFieldInputTemplate = z.infer<typeof zStructuralLoRAModelFieldInputTemplate>;
|
||||
export const isStructuralLoRAModelFieldInputInstance = (val: unknown): val is StructuralLoRAModelFieldInputInstance =>
|
||||
zStructuralLoRAModelFieldInputInstance.safeParse(val).success;
|
||||
export const isStructuralLoRAModelFieldInputTemplate = (val: unknown): val is StructuralLoRAModelFieldInputTemplate =>
|
||||
zStructuralLoRAModelFieldInputTemplate.safeParse(val).success;
|
||||
|
||||
// #endregion
|
||||
|
||||
// #region SchedulerField
|
||||
|
||||
export const zSchedulerFieldValue = zSchedulerField.optional();
|
||||
@@ -959,6 +987,7 @@ export const zStatefulFieldValue = z.union([
|
||||
zCLIPEmbedModelFieldValue,
|
||||
zCLIPLEmbedModelFieldValue,
|
||||
zCLIPGEmbedModelFieldValue,
|
||||
zStructuralLoRAModelFieldValue,
|
||||
zColorFieldValue,
|
||||
zSchedulerFieldValue,
|
||||
]);
|
||||
@@ -1030,6 +1059,7 @@ const zStatefulFieldInputTemplate = z.union([
|
||||
zCLIPEmbedModelFieldInputTemplate,
|
||||
zCLIPLEmbedModelFieldInputTemplate,
|
||||
zCLIPGEmbedModelFieldInputTemplate,
|
||||
zStructuralLoRAModelFieldInputTemplate,
|
||||
zColorFieldInputTemplate,
|
||||
zSchedulerFieldInputTemplate,
|
||||
zStatelessFieldInputTemplate,
|
||||
|
||||
@@ -28,6 +28,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
|
||||
CLIPEmbedModelField: undefined,
|
||||
CLIPLEmbedModelField: undefined,
|
||||
CLIPGEmbedModelField: undefined,
|
||||
StructuralLoRAModelField: undefined,
|
||||
};
|
||||
|
||||
export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => {
|
||||
|
||||
@@ -28,6 +28,7 @@ import type {
|
||||
StatefulFieldType,
|
||||
StatelessFieldInputTemplate,
|
||||
StringFieldInputTemplate,
|
||||
StructuralLoRAModelFieldInputTemplate,
|
||||
T2IAdapterModelFieldInputTemplate,
|
||||
T5EncoderModelFieldInputTemplate,
|
||||
VAEModelFieldInputTemplate,
|
||||
@@ -300,6 +301,20 @@ const buildCLIPGEmbedModelFieldInputTemplate: FieldInputTemplateBuilder<CLIPGEmb
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildStructuralLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<StructuralLoRAModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: StructuralLoRAModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildFluxVAEModelFieldInputTemplate: FieldInputTemplateBuilder<FluxVAEModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@@ -526,6 +541,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
|
||||
CLIPLEmbedModelField: buildCLIPLEmbedModelFieldInputTemplate,
|
||||
CLIPGEmbedModelField: buildCLIPGEmbedModelFieldInputTemplate,
|
||||
FluxVAEModelField: buildFluxVAEModelFieldInputTemplate,
|
||||
StructuralLoRAModelField: buildStructuralLoRAModelFieldInputTemplate,
|
||||
} as const;
|
||||
|
||||
export const buildFieldInputTemplate = (
|
||||
|
||||
@@ -113,6 +113,11 @@ export const zParameterVAEModel = zModelIdentifierField;
|
||||
export type ParameterVAEModel = z.infer<typeof zParameterVAEModel>;
|
||||
// #endregion
|
||||
|
||||
// #region Structural Lora Model
|
||||
export const zParameterStructuralLoRAModel = zModelIdentifierField;
|
||||
export type ParameterStructuralLoRAModel = z.infer<typeof zParameterStructuralLoRAModel>;
|
||||
// #endregion
|
||||
|
||||
// #region T5Encoder Model
|
||||
export const zParameterT5EncoderModel = zModelIdentifierField;
|
||||
export type ParameterT5EncoderModel = z.infer<typeof zParameterT5EncoderModel>;
|
||||
|
||||
@@ -23,6 +23,7 @@ import {
|
||||
isSD3MainModelModelConfig,
|
||||
isSDXLMainModelModelConfig,
|
||||
isSpandrelImageToImageModelConfig,
|
||||
isStructuralLoRAModelConfig,
|
||||
isT2IAdapterModelConfig,
|
||||
isT5EncoderModelConfig,
|
||||
isTIModelConfig,
|
||||
@@ -58,6 +59,7 @@ export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);
|
||||
export const useSD3Models = buildModelsHook(isSD3MainModelModelConfig);
|
||||
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
|
||||
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
|
||||
export const useStructuralLoRAModel = buildModelsHook(isStructuralLoRAModelConfig);
|
||||
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);
|
||||
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
|
||||
export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig);
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -44,6 +44,7 @@ export type BaseModelType = S['BaseModelType'];
|
||||
|
||||
// Model Configs
|
||||
|
||||
export type StructuralLoRAModelConfig = S['StructuralLoRALyCORISConfig'];
|
||||
// TODO(MM2): Can we make key required in the pydantic model?
|
||||
export type LoRAModelConfig = S['LoRADiffusersConfig'] | S['LoRALyCORISConfig'];
|
||||
// TODO(MM2): Can we rename this from Vae -> VAE
|
||||
@@ -63,6 +64,7 @@ export type CheckpointModelConfig = S['MainCheckpointConfig'];
|
||||
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
|
||||
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
|
||||
export type AnyModelConfig =
|
||||
| StructuralLoRAModelConfig
|
||||
| LoRAModelConfig
|
||||
| VAEModelConfig
|
||||
| ControlNetModelConfig
|
||||
@@ -114,6 +116,10 @@ export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelCo
|
||||
return config.type === 'lora';
|
||||
};
|
||||
|
||||
export const isStructuralLoRAModelConfig = (config: AnyModelConfig): config is StructuralLoRAModelConfig => {
|
||||
return config.type === 'structural_lora';
|
||||
};
|
||||
|
||||
export const isVAEModelConfig = (config: AnyModelConfig, excludeSubmodels?: boolean): config is VAEModelConfig => {
|
||||
return config.type === 'vae' || (!excludeSubmodels && config.type === 'main' && checkSubmodels(['vae'], config));
|
||||
};
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,70 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.conversions.flux_control_lora_utils import (
|
||||
is_state_dict_likely_flux_control,
|
||||
lora_model_from_flux_control_state_dict,
|
||||
)
|
||||
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
|
||||
from tests.backend.lora.conversions.lora_state_dicts.flux_control_lora_format import (
|
||||
state_dict_keys as flux_control_lora_state_dict_keys,
|
||||
)
|
||||
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
|
||||
state_dict_keys as flux_diffusers_state_dict_keys,
|
||||
)
|
||||
from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_state_dict
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sd_keys", [flux_control_lora_state_dict_keys])
|
||||
def test_is_state_dict_likely_in_flux_control_format_true(sd_keys: dict[str, list[int]]):
|
||||
"""Test that is_state_dict_likely_flux_control() can identify a state dict in the FLUX Control LoRA format."""
|
||||
# Construct a state dict that is in the Diffusers FLUX LoRA format.
|
||||
state_dict = keys_to_mock_state_dict(sd_keys)
|
||||
|
||||
assert is_state_dict_likely_flux_control(state_dict)
|
||||
|
||||
@pytest.mark.parametrize("sd_keys", [flux_diffusers_state_dict_keys])
|
||||
def test_is_state_dict_likely_in_flux_control_format_false(sd_keys: dict[str, list[int]]):
|
||||
"""Test that is_state_dict_likely_flux_control() returns False for a state dict that is in the Diffusers
|
||||
FLUX LoRA format.
|
||||
"""
|
||||
# Construct a state dict that is not in the FLUX Control LoRA format.
|
||||
state_dict = keys_to_mock_state_dict(sd_keys)
|
||||
|
||||
assert not is_state_dict_likely_flux_control(state_dict)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sd_keys", [flux_control_lora_state_dict_keys])
|
||||
def test_lora_model_from_flux_control_state_dict(sd_keys: dict[str, list[int]]):
|
||||
"""Test that lora_model_from_flux_control_state_dict() can load a state dict in the FLUX Control LoRA format."""
|
||||
# Construct a state dict that is in the FLUX Control LoRA format.
|
||||
state_dict = keys_to_mock_state_dict(sd_keys)
|
||||
# Load the state dict into a LoRAModelRaw object.
|
||||
model = lora_model_from_flux_control_state_dict(state_dict)
|
||||
|
||||
# Check that the model has the correct number of LoRA layers.
|
||||
expected_lora_layers: set[str] = set()
|
||||
for k in sd_keys:
|
||||
k = k.replace("lora_A.weight", "")
|
||||
k = k.replace("lora_B.weight", "")
|
||||
k = k.replace("lora_B.bias", "")
|
||||
k = k.replace(".scale", "")
|
||||
expected_lora_layers.add(k)
|
||||
# Drop the K/V/proj_mlp weights because these are all concatenated into a single layer in the BFL format (we keep
|
||||
# the Q weights so that we count these layers once).
|
||||
assert len(model.layers) == len(expected_lora_layers)
|
||||
assert all(k.startswith(FLUX_LORA_TRANSFORMER_PREFIX) for k in model.layers.keys())
|
||||
|
||||
|
||||
def test_lora_model_from_flux_control_state_dict_extra_keys_error():
|
||||
"""Test that lora_model_from_flux_control_state_dict() raises an error if the input state_dict contains unexpected
|
||||
keys that we don't handle.
|
||||
"""
|
||||
# Construct a state dict that is in the FLUX Control LoRA format.
|
||||
state_dict = keys_to_mock_state_dict(flux_control_lora_state_dict_keys)
|
||||
# Add an unexpected key.
|
||||
state_dict["transformer.single_transformer_blocks.0.unexpected_key.lora_A.weight"] = torch.empty(1)
|
||||
|
||||
# Check that an error is raised.
|
||||
with pytest.raises(AssertionError):
|
||||
lora_model_from_flux_control_state_dict(state_dict)
|
||||
@@ -1,49 +0,0 @@
|
||||
import copy
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import (
|
||||
ConcatenatedLoRALinearSidecarLayer,
|
||||
)
|
||||
from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
|
||||
|
||||
|
||||
def test_concatenated_lora_linear_sidecar_layer():
|
||||
"""Test that a ConcatenatedLoRALinearSidecarLayer is equivalent to patching a linear layer with the ConcatenatedLoRA
|
||||
layer.
|
||||
"""
|
||||
|
||||
# Create a linear layer.
|
||||
in_features = 5
|
||||
sub_layer_out_features = [5, 10, 15]
|
||||
linear = torch.nn.Linear(in_features, sum(sub_layer_out_features))
|
||||
|
||||
# Create a ConcatenatedLoRA layer.
|
||||
rank = 4
|
||||
sub_layers: list[LoRALayer] = []
|
||||
for out_features in sub_layer_out_features:
|
||||
down = torch.randn(rank, in_features)
|
||||
up = torch.randn(out_features, rank)
|
||||
bias = torch.randn(out_features)
|
||||
sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias))
|
||||
concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)
|
||||
|
||||
# Patch the ConcatenatedLoRA layer into the linear layer.
|
||||
linear_patched = copy.deepcopy(linear)
|
||||
linear_patched.weight.data += (
|
||||
concatenated_lora_layer.get_weight(linear_patched.weight) * concatenated_lora_layer.scale()
|
||||
)
|
||||
linear_patched.bias.data += concatenated_lora_layer.get_bias(linear_patched.bias) * concatenated_lora_layer.scale()
|
||||
|
||||
# Create a ConcatenatedLoRALinearSidecarLayer.
|
||||
concatenated_lora_linear_sidecar_layer = ConcatenatedLoRALinearSidecarLayer(concatenated_lora_layer, weight=1.0)
|
||||
linear_with_sidecar = LoRASidecarModule(linear, [concatenated_lora_linear_sidecar_layer])
|
||||
|
||||
# Run the ConcatenatedLoRA-patched linear layer and the ConcatenatedLoRALinearSidecarLayer and assert they are
|
||||
# equal.
|
||||
input = torch.randn(1, in_features)
|
||||
output_patched = linear_patched(input)
|
||||
output_sidecar = linear_with_sidecar(input)
|
||||
assert torch.allclose(output_patched, output_sidecar, atol=1e-6)
|
||||
@@ -1,38 +0,0 @@
|
||||
import copy
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.sidecar_layers.lora.lora_linear_sidecar_layer import LoRALinearSidecarLayer
|
||||
from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_lora_linear_sidecar_layer():
|
||||
"""Test that a LoRALinearSidecarLayer is equivalent to patching a linear layer with the LoRA layer."""
|
||||
|
||||
# Create a linear layer.
|
||||
in_features = 10
|
||||
out_features = 20
|
||||
linear = torch.nn.Linear(in_features, out_features)
|
||||
|
||||
# Create a LoRA layer.
|
||||
rank = 4
|
||||
down = torch.randn(rank, in_features)
|
||||
up = torch.randn(out_features, rank)
|
||||
bias = torch.randn(out_features)
|
||||
lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)
|
||||
|
||||
# Patch the LoRA layer into the linear layer.
|
||||
linear_patched = copy.deepcopy(linear)
|
||||
linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
|
||||
linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()
|
||||
# Create a LoRALinearSidecarLayer.
|
||||
lora_linear_sidecar_layer = LoRALinearSidecarLayer(lora_layer, weight=1.0)
|
||||
linear_with_sidecar = LoRASidecarModule(linear, [lora_linear_sidecar_layer])
|
||||
|
||||
# Run the LoRA-patched linear layer and the LoRALinearSidecarLayer and assert they are equal.
|
||||
input = torch.randn(1, in_features)
|
||||
output_patched = linear_patched(input)
|
||||
output_sidecar = linear_with_sidecar(input)
|
||||
assert torch.allclose(output_patched, output_sidecar, atol=1e-6)
|
||||
69
tests/backend/lora/test_lora_layer_wrappers.py
Normal file
69
tests/backend/lora/test_lora_layer_wrappers.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import copy
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.lora_layer_wrappers import LoRALinearWrapper
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_lora_linear_wrapper():
|
||||
# Create a linear layer.
|
||||
in_features = 10
|
||||
out_features = 20
|
||||
linear = torch.nn.Linear(in_features, out_features)
|
||||
|
||||
# Create a LoRA layer.
|
||||
rank = 4
|
||||
down = torch.randn(rank, in_features)
|
||||
up = torch.randn(out_features, rank)
|
||||
bias = torch.randn(out_features)
|
||||
lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)
|
||||
|
||||
# Patch the LoRA layer into the linear layer.
|
||||
linear_patched = copy.deepcopy(linear)
|
||||
linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
|
||||
linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()
|
||||
|
||||
# Create a LoRALinearWrapper.
|
||||
lora_wrapped = LoRALinearWrapper(linear, [lora_layer], [1.0])
|
||||
|
||||
# Run the LoRA-patched linear layer and the LoRALinearWrapper and assert they are equal.
|
||||
input = torch.randn(1, in_features)
|
||||
output_patched = linear_patched(input)
|
||||
output_wrapped = lora_wrapped(input)
|
||||
assert torch.allclose(output_patched, output_wrapped, atol=1e-6)
|
||||
|
||||
|
||||
def test_concatenated_lora_linear_wrapper():
|
||||
# Create a linear layer.
|
||||
in_features = 5
|
||||
sub_layer_out_features = [5, 10, 15]
|
||||
linear = torch.nn.Linear(in_features, sum(sub_layer_out_features))
|
||||
|
||||
# Create a ConcatenatedLoRA layer.
|
||||
rank = 4
|
||||
sub_layers: list[LoRALayer] = []
|
||||
for out_features in sub_layer_out_features:
|
||||
down = torch.randn(rank, in_features)
|
||||
up = torch.randn(out_features, rank)
|
||||
bias = torch.randn(out_features)
|
||||
sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias))
|
||||
concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)
|
||||
|
||||
# Patch the ConcatenatedLoRA layer into the linear layer.
|
||||
linear_patched = copy.deepcopy(linear)
|
||||
linear_patched.weight.data += (
|
||||
concatenated_lora_layer.get_weight(linear_patched.weight) * concatenated_lora_layer.scale()
|
||||
)
|
||||
linear_patched.bias.data += concatenated_lora_layer.get_bias(linear_patched.bias) * concatenated_lora_layer.scale()
|
||||
|
||||
# Create a LoRALinearWrapper.
|
||||
lora_wrapped = LoRALinearWrapper(linear, [concatenated_lora_layer], [1.0])
|
||||
|
||||
# Run the ConcatenatedLoRA-patched linear layer and the LoRALinearWrapper and assert they are equal.
|
||||
input = torch.randn(1, in_features)
|
||||
output_patched = linear_patched(input)
|
||||
output_wrapped = lora_wrapped(input)
|
||||
assert torch.allclose(output_patched, output_wrapped, atol=1e-6)
|
||||
@@ -2,11 +2,15 @@ import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.lora_layer_wrappers import LoRASidecarWrapper
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_patcher import LoRAPatcher
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
|
||||
CachedModelWithPartialLoad,
|
||||
)
|
||||
|
||||
|
||||
class DummyModule(torch.nn.Module):
|
||||
class DummyModuleWithOneLayer(torch.nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, device: str, dtype: torch.dtype):
|
||||
super().__init__()
|
||||
self.linear_layer_1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
|
||||
@@ -15,8 +19,18 @@ class DummyModule(torch.nn.Module):
|
||||
return self.linear_layer_1(x)
|
||||
|
||||
|
||||
class DummyModuleWithTwoLayers(torch.nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, device: str, dtype: torch.dtype):
|
||||
super().__init__()
|
||||
self.linear_layer_1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
|
||||
self.linear_layer_2 = torch.nn.Linear(out_features, out_features, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.linear_layer_2(self.linear_layer_1(x))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["device", "num_layers"],
|
||||
["device", "num_loras"],
|
||||
[
|
||||
("cpu", 1),
|
||||
pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
|
||||
@@ -25,7 +39,7 @@ class DummyModule(torch.nn.Module):
|
||||
],
|
||||
)
|
||||
@torch.no_grad()
|
||||
def test_apply_lora_patches(device: str, num_layers: int):
|
||||
def test_apply_lora_patches(device: str, num_loras: int):
|
||||
"""Test the basic behavior of ModelPatcher.apply_lora_patches(...). Check that patching and unpatching produce the
|
||||
correct result, and that model/LoRA tensors are moved between devices as expected.
|
||||
"""
|
||||
@@ -33,12 +47,12 @@ def test_apply_lora_patches(device: str, num_layers: int):
|
||||
linear_in_features = 4
|
||||
linear_out_features = 8
|
||||
lora_rank = 2
|
||||
model = DummyModule(linear_in_features, linear_out_features, device=device, dtype=torch.float16)
|
||||
model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=torch.float16)
|
||||
|
||||
# Initialize num_layers LoRA models with weights of 0.5.
|
||||
# Initialize num_loras LoRA models with weights of 0.5.
|
||||
lora_weight = 0.5
|
||||
lora_models: list[tuple[LoRAModelRaw, float]] = []
|
||||
for _ in range(num_layers):
|
||||
for _ in range(num_loras):
|
||||
lora_layers = {
|
||||
"linear_layer_1": LoRALayer.from_state_dict_values(
|
||||
values={
|
||||
@@ -51,7 +65,7 @@ def test_apply_lora_patches(device: str, num_layers: int):
|
||||
lora_models.append((lora, lora_weight))
|
||||
|
||||
orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
|
||||
expected_patched_linear_weight = orig_linear_weight + (lora_rank * lora_weight * num_layers)
|
||||
expected_patched_linear_weight = orig_linear_weight + (lora_rank * lora_weight * num_loras)
|
||||
|
||||
with LoRAPatcher.apply_lora_patches(model=model, patches=lora_models, prefix=""):
|
||||
# After patching, all LoRA layer weights should have been moved back to the cpu.
|
||||
@@ -79,7 +93,7 @@ def test_apply_lora_patches_change_device():
|
||||
linear_out_features = 8
|
||||
lora_dim = 2
|
||||
# Initialize the model on the CPU.
|
||||
model = DummyModule(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)
|
||||
model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)
|
||||
|
||||
lora_layers = {
|
||||
"linear_layer_1": LoRALayer.from_state_dict_values(
|
||||
@@ -110,7 +124,7 @@ def test_apply_lora_patches_change_device():
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["device", "num_layers"],
|
||||
["device", "num_loras"],
|
||||
[
|
||||
("cpu", 1),
|
||||
pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
|
||||
@@ -118,18 +132,18 @@ def test_apply_lora_patches_change_device():
|
||||
pytest.param("cuda", 2, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
|
||||
],
|
||||
)
|
||||
def test_apply_lora_sidecar_patches(device: str, num_layers: int):
|
||||
"""Test the basic behavior of ModelPatcher.apply_lora_sidecar_patches(...). Check that unpatching works correctly."""
|
||||
def test_apply_lora_wrapper_patches(device: str, num_loras: int):
|
||||
"""Test the basic behavior of ModelPatcher.apply_lora_wrapper_patches(...). Check that unpatching works correctly."""
|
||||
dtype = torch.float16
|
||||
linear_in_features = 4
|
||||
linear_out_features = 8
|
||||
lora_rank = 2
|
||||
model = DummyModule(linear_in_features, linear_out_features, device=device, dtype=dtype)
|
||||
model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=dtype)
|
||||
|
||||
# Initialize num_layers LoRA models with weights of 0.5.
|
||||
# Initialize num_loras LoRA models with weights of 0.5.
|
||||
lora_weight = 0.5
|
||||
lora_models: list[tuple[LoRAModelRaw, float]] = []
|
||||
for _ in range(num_layers):
|
||||
for _ in range(num_loras):
|
||||
lora_layers = {
|
||||
"linear_layer_1": LoRALayer.from_state_dict_values(
|
||||
values={
|
||||
@@ -146,7 +160,7 @@ def test_apply_lora_sidecar_patches(device: str, num_layers: int):
|
||||
output_before_patch = model(input)
|
||||
|
||||
# Patch the model and run inference during the patch.
|
||||
with LoRAPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
|
||||
with LoRAPatcher.apply_lora_wrapper_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
|
||||
output_during_patch = model(input)
|
||||
|
||||
# Run inference after unpatching.
|
||||
@@ -159,20 +173,140 @@ def test_apply_lora_sidecar_patches(device: str, num_layers: int):
|
||||
assert torch.allclose(output_before_patch, output_after_patch)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["device", "num_loras"],
|
||||
[
|
||||
("cpu", 1),
|
||||
pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
|
||||
("cpu", 2),
|
||||
pytest.param("cuda", 2, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
|
||||
],
|
||||
)
|
||||
@torch.no_grad()
|
||||
@pytest.mark.parametrize(["num_layers"], [(1,), (2,)])
|
||||
def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int):
|
||||
"""Test that apply_lora_sidecar_patches(...) produces the same model outputs as apply_lora_patches(...)."""
|
||||
def test_apply_smart_lora_patches(device: str, num_loras: int):
|
||||
"""Test the basic behavior of ModelPatcher.apply_smart_lora_patches(...). Check that unpatching works correctly."""
|
||||
dtype = torch.float16
|
||||
linear_in_features = 4
|
||||
linear_out_features = 8
|
||||
lora_rank = 2
|
||||
model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=dtype)
|
||||
|
||||
# Initialize num_loras LoRA models with weights of 0.5.
|
||||
lora_weight = 0.5
|
||||
lora_models: list[tuple[LoRAModelRaw, float]] = []
|
||||
for _ in range(num_loras):
|
||||
lora_layers = {
|
||||
"linear_layer_1": LoRALayer.from_state_dict_values(
|
||||
values={
|
||||
"lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
|
||||
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
|
||||
},
|
||||
)
|
||||
}
|
||||
lora = LoRAModelRaw(lora_layers)
|
||||
lora_models.append((lora, lora_weight))
|
||||
|
||||
# Run inference before patching the model.
|
||||
input = torch.randn(1, linear_in_features, device=device, dtype=dtype)
|
||||
output_before_patch = model(input)
|
||||
|
||||
# Patch the model and run inference during the patch.
|
||||
with LoRAPatcher.apply_smart_lora_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
|
||||
output_during_patch = model(input)
|
||||
|
||||
# Run inference after unpatching.
|
||||
output_after_patch = model(input)
|
||||
|
||||
# Check that the output before patching is different from the output during patching.
|
||||
assert not torch.allclose(output_before_patch, output_during_patch)
|
||||
|
||||
# Check that the output before patching is the same as the output after patching.
|
||||
assert torch.allclose(output_before_patch, output_after_patch)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["num_loras"], [(1,), (2,)])
|
||||
@torch.no_grad()
|
||||
def test_apply_smart_lora_patches_to_partially_loaded_model(num_loras: int):
|
||||
"""Test the behavior of ModelPatcher.apply_smart_lora_patches(...) when it is applied to a
|
||||
CachedModelWithPartialLoad that is partially loaded into VRAM.
|
||||
"""
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("requires CUDA device")
|
||||
|
||||
# Initialize the model on the CPU.
|
||||
dtype = torch.float16
|
||||
linear_in_features = 4
|
||||
linear_out_features = 8
|
||||
lora_rank = 2
|
||||
model = DummyModuleWithTwoLayers(linear_in_features, linear_out_features, device="cpu", dtype=dtype)
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device("cuda"))
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Partially load the model into VRAM.
|
||||
target_vram_bytes = int(model_total_bytes * 0.6)
|
||||
_ = cached_model.partial_load_to_vram(target_vram_bytes)
|
||||
assert cached_model.model.linear_layer_1.weight.device.type == "cuda"
|
||||
assert cached_model.model.linear_layer_2.weight.device.type == "cpu"
|
||||
|
||||
# Initialize num_loras LoRA models with weights of 0.5.
|
||||
lora_weight = 0.5
|
||||
lora_models: list[tuple[LoRAModelRaw, float]] = []
|
||||
for _ in range(num_loras):
|
||||
lora_layers = {
|
||||
"linear_layer_1": LoRALayer.from_state_dict_values(
|
||||
values={
|
||||
"lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
|
||||
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
|
||||
},
|
||||
),
|
||||
"linear_layer_2": LoRALayer.from_state_dict_values(
|
||||
values={
|
||||
"lora_down.weight": torch.ones((lora_rank, linear_out_features), device="cpu", dtype=torch.float16),
|
||||
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
|
||||
},
|
||||
),
|
||||
}
|
||||
lora = LoRAModelRaw(lora_layers)
|
||||
lora_models.append((lora, lora_weight))
|
||||
|
||||
# Run inference before patching the model.
|
||||
input = torch.randn(1, linear_in_features, device="cuda", dtype=dtype)
|
||||
output_before_patch = cached_model.model(input)
|
||||
|
||||
# Patch the model and run inference during the patch.
|
||||
with LoRAPatcher.apply_smart_lora_patches(model=cached_model.model, patches=lora_models, prefix="", dtype=dtype):
|
||||
# Check that the second layer is wrapped in a LoRASidecarWrapper, but the first layer is not.
|
||||
assert not isinstance(cached_model.model.linear_layer_1, LoRASidecarWrapper)
|
||||
assert isinstance(cached_model.model.linear_layer_2, LoRASidecarWrapper)
|
||||
|
||||
output_during_patch = cached_model.model(input)
|
||||
|
||||
# Run inference after unpatching.
|
||||
output_after_patch = cached_model.model(input)
|
||||
|
||||
# Check that the output before patching is different from the output during patching.
|
||||
assert not torch.allclose(output_before_patch, output_during_patch)
|
||||
|
||||
# Check that the output before patching is the same as the output after patching.
|
||||
assert torch.allclose(output_before_patch, output_after_patch)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@pytest.mark.parametrize(["num_loras"], [(1,), (2,)])
|
||||
def test_all_patching_methods_produce_same_output(num_loras: int):
|
||||
"""Test that apply_lora_wrapper_patches(...) produces the same model outputs as apply_lora_patches(...)."""
|
||||
dtype = torch.float32
|
||||
linear_in_features = 4
|
||||
linear_out_features = 8
|
||||
lora_rank = 2
|
||||
model = DummyModule(linear_in_features, linear_out_features, device="cpu", dtype=dtype)
|
||||
model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=dtype)
|
||||
|
||||
# Initialize num_layers LoRA models with weights of 0.5.
|
||||
# Initialize num_loras LoRA models with weights of 0.5.
|
||||
lora_weight = 0.5
|
||||
lora_models: list[tuple[LoRAModelRaw, float]] = []
|
||||
for _ in range(num_layers):
|
||||
for _ in range(num_loras):
|
||||
lora_layers = {
|
||||
"linear_layer_1": LoRALayer.from_state_dict_values(
|
||||
values={
|
||||
@@ -189,9 +323,13 @@ def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int):
|
||||
with LoRAPatcher.apply_lora_patches(model=model, patches=lora_models, prefix=""):
|
||||
output_lora_patches = model(input)
|
||||
|
||||
with LoRAPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
|
||||
output_lora_sidecar_patches = model(input)
|
||||
with LoRAPatcher.apply_lora_wrapper_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
|
||||
output_lora_wrapper_patches = model(input)
|
||||
|
||||
with LoRAPatcher.apply_smart_lora_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
|
||||
output_smart_lora_patches = model(input)
|
||||
|
||||
# Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical
|
||||
# differences are tolerable and expected due to the difference between sidecar vs. patching.
|
||||
assert torch.allclose(output_lora_patches, output_lora_sidecar_patches, atol=1e-5)
|
||||
assert torch.allclose(output_lora_patches, output_lora_wrapper_patches, atol=1e-5)
|
||||
assert torch.allclose(output_lora_patches, output_smart_lora_patches, atol=1e-5)
|
||||
|
||||
Reference in New Issue
Block a user