Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-21 03:28:25 -05:00)

Compare commits: controlnet ... ryan/lora- (15 commits)

| SHA1 |
|---|
| 3ed6e65a6e |
| 52c9646f84 |
| 7662f0522b |
| e50fe69839 |
| 5a9f884620 |
| edc72d1739 |
| 23f521dc7c |
| 3d6b93efdd |
| 3f28d3afad |
| 9353bfbdd6 |
| 93f2bc6118 |
| 9019026d6d |
| c195b326ec |
| 2f460d2a45 |
| 4473cba512 |
@@ -82,10 +82,11 @@ class CompelInvocation(BaseInvocation):
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
-            LoRAPatcher.apply_lora_patches(
+            LoRAPatcher.apply_smart_lora_patches(
                 model=text_encoder,
                 patches=_lora_loader(),
                 prefix="lora_te_",
+                dtype=TorchDevice.choose_torch_dtype(),
                 cached_weights=cached_weights,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
@@ -179,10 +180,11 @@ class SDXLPromptInvocationBase:
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
-            LoRAPatcher.apply_lora_patches(
+            LoRAPatcher.apply_smart_lora_patches(
                 text_encoder,
                 patches=_lora_loader(),
                 prefix=lora_prefix,
+                dtype=TorchDevice.choose_torch_dtype(),
                 cached_weights=cached_weights,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
@@ -1003,10 +1003,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
             ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
             SeamlessExt.static_patch_model(unet, self.unet.seamless_axes),  # FIXME
             # Apply the LoRA after unet has been moved to its target device for faster patching.
-            LoRAPatcher.apply_lora_patches(
+            LoRAPatcher.apply_smart_lora_patches(
                 model=unet,
                 patches=_lora_loader(),
                 prefix="lora_unet_",
+                dtype=unet.dtype,
                 cached_weights=cached_weights,
             ),
         ):
@@ -56,6 +56,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
     CLIPLEmbedModel = "CLIPLEmbedModelField"
     CLIPGEmbedModel = "CLIPGEmbedModelField"
     SpandrelImageToImageModel = "SpandrelImageToImageModelField"
+    StructuralLoRAModel = "StructuralLoRAModelField"
     # endregion

     # region Misc Field Types
@@ -143,6 +144,7 @@ class FieldDescriptions:
     controlnet_model = "ControlNet model to load"
     vae_model = "VAE model to load"
     lora_model = "LoRA model to load"
+    structural_lora_model = "Structural LoRA model to load"
     main_model = "Main model (UNet, VAE, CLIP) to load"
     flux_model = "Flux model (Transformer) to load"
     sd3_model = "SD3 model (MMDiTX) to load"
@@ -1,5 +1,5 @@
 from contextlib import ExitStack
-from typing import Callable, Iterator, Optional, Tuple
+from typing import Callable, Iterator, Optional, Tuple, Union

 import numpy as np
 import numpy.typing as npt
@@ -8,6 +8,8 @@ import torchvision.transforms as tv_transforms
 from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

+from invokeai.backend.flux.modules.autoencoder import AutoEncoder
+
 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
 from invokeai.app.invocations.fields import (
     DenoiseMaskField,
@@ -22,7 +24,7 @@ from invokeai.app.invocations.fields import (
 )
 from invokeai.app.invocations.flux_controlnet import FluxControlNetField
 from invokeai.app.invocations.ip_adapter import IPAdapterField
-from invokeai.app.invocations.model import TransformerField, VAEField
+from invokeai.app.invocations.model import TransformerField, VAEField, StructuralLoRAField, LoRAField
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
@@ -43,6 +45,8 @@ from invokeai.backend.flux.sampling_utils import (
     pack,
     unpack,
 )
+from invokeai.backend.flux.flux_tools_sampling_utils import prepare_control
+from invokeai.backend.flux.modules.conditioner import HFEncoder
 from invokeai.backend.flux.text_conditioning import FluxTextConditioning
 from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
@@ -284,6 +288,16 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             dtype=inference_dtype,
             device=x.device,
         )
+        img_cond = None
+        if struct_lora := self.transformer.structural_lora:
+            # What should we do when we have multiple of these?
+            if not self.controlnet_vae:
+                raise ValueError("controlnet_vae must be set when using a structural lora")
+            ae_info = context.models.load(self.controlnet_vae.vae)
+            img = context.images.get_pil(struct_lora.img.image_name)
+            with ae_info as ae:
+                assert isinstance(ae, AutoEncoder)
+                img_cond = prepare_control(self.height, self.width, self.seed, ae, img)

         # Load the transformer model.
         (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
@@ -296,10 +310,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         if config.format in [ModelFormat.Checkpoint]:
             # The model is non-quantized, so we can apply the LoRA weights directly into the model.
             exit_stack.enter_context(
-                LoRAPatcher.apply_lora_patches(
+                LoRAPatcher.apply_smart_lora_patches(
                     model=transformer,
                     patches=self._lora_iterator(context),
                     prefix=FLUX_LORA_TRANSFORMER_PREFIX,
+                    dtype=inference_dtype,
                     cached_weights=cached_weights,
                 )
             )
@@ -311,7 +326,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             # The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference,
             # than directly patching the weights, but is agnostic to the quantization format.
             exit_stack.enter_context(
-                LoRAPatcher.apply_lora_sidecar_patches(
+                LoRAPatcher.apply_lora_wrapper_patches(
                     model=transformer,
                     patches=self._lora_iterator(context),
                     prefix=FLUX_LORA_TRANSFORMER_PREFIX,
@@ -345,6 +360,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             controlnet_extensions=controlnet_extensions,
             pos_ip_adapter_extensions=pos_ip_adapter_extensions,
             neg_ip_adapter_extensions=neg_ip_adapter_extensions,
+            img_cond=img_cond
         )

         x = unpack(x.float(), self.height, self.width)
@@ -682,7 +698,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         return pos_ip_adapter_extensions, neg_ip_adapter_extensions

     def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
-        for lora in self.transformer.loras:
+        loras: list[Union[LoRAField, StructuralLoRAField]] = [*self.transformer.loras]
+        if self.transformer.structural_lora:
+            loras.append(self.transformer.structural_lora)
+        for lora in loras:
             lora_info = context.models.load(lora.lora)
             assert isinstance(lora_info.model, LoRAModelRaw)
             yield (lora_info.model, lora.weight)
@@ -81,8 +81,8 @@ class FluxModelLoaderInvocation(BaseInvocation):
         assert isinstance(transformer_config, CheckpointConfigBase)

         return FluxModelLoaderOutput(
-            transformer=TransformerField(transformer=transformer, loras=[]),
-            clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
+            transformer=TransformerField(transformer=transformer, loras=[], structural_loras=[]),
+            clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], structural_loras=[], skipped_layers=0),
             t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
             vae=VAEField(vae=vae),
             max_seq_len=max_seq_lengths[transformer_config.config_path],
invokeai/app/invocations/flux_structural_lora_loader.py (new file, 70 lines)
@@ -0,0 +1,70 @@
from typing import Optional, Literal

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    Classification,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, ImageField
from invokeai.app.invocations.model import VAEField, StructuralLoRAField, ModelIdentifierField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext


@invocation_output("flux_structural_lora_loader_output")
class FluxStructuralLoRALoaderOutput(BaseInvocationOutput):
    """Flux Structural LoRA Loader Output"""

    transformer: Optional[TransformerField] = OutputField(
        default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
    )


@invocation(
    "flux_structural_lora_loader",
    title="Flux Structural LoRA",
    tags=["lora", "model", "flux"],
    category="model",
    version="1.1.0",
    classification=Classification.Prototype,
)
class FluxStructuralLoRALoaderInvocation(BaseInvocation):
    """Apply a LoRA model to a FLUX transformer and/or text encoder."""

    lora: ModelIdentifierField = InputField(
        description=FieldDescriptions.structural_lora_model, title="Structural LoRA", ui_type=UIType.StructuralLoRAModel
    )
    transformer: TransformerField | None = InputField(
        default=None,
        description=FieldDescriptions.transformer,
        input=Input.Connection,
        title="FLUX Transformer",
    )
    image: ImageField = InputField(
        description="The image to encode.",
    )
    weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)

    def invoke(self, context: InvocationContext) -> FluxStructuralLoRALoaderOutput:
        lora_key = self.lora.key

        if not context.models.exists(lora_key):
            raise ValueError(f"Unknown lora: {lora_key}!")

        # Check for existing LoRAs with the same key.
        if self.transformer and self.transformer.structural_lora and self.transformer.structural_lora.lora.key == lora_key:
            raise ValueError(f'Structural LoRA "{lora_key}" already applied to transformer.')

        output = FluxStructuralLoRALoaderOutput()

        # Attach LoRA layers to the models.
        if self.transformer is not None:
            output.transformer = self.transformer.model_copy(deep=True)
            output.transformer.structural_lora = StructuralLoRAField(
                lora=self.lora,
                img=self.image,
                weight=self.weight,
            )

        return output
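Note on the loader's copy-on-write pattern: `model_copy(deep=True)` (Pydantic v2) returns an independent copy of the incoming `TransformerField`, so attaching the structural LoRA does not mutate the field owned by the upstream node. A minimal sketch of the idea; the field classes below are stand-ins for illustration, not the InvokeAI types:

```python
from typing import Optional
from pydantic import BaseModel

class LoRARef(BaseModel):
    key: str
    weight: float = 0.75

class FakeTransformerField(BaseModel):
    name: str
    structural_lora: Optional[LoRARef] = None

upstream = FakeTransformerField(name="flux-dev")
# Deep-copy first, then attach the structural LoRA to the copy only.
downstream = upstream.model_copy(deep=True)
downstream.structural_lora = LoRARef(key="canny-lora")

assert upstream.structural_lora is None       # upstream node's output is untouched
assert downstream.structural_lora is not None
```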
@@ -22,6 +22,7 @@ from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager.config import ModelFormat
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
+from invokeai.backend.util.devices import TorchDevice


 @invocation(
@@ -111,10 +112,11 @@ class FluxTextEncoderInvocation(BaseInvocation):
         if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
             # The model is non-quantized, so we can apply the LoRA weights directly into the model.
             exit_stack.enter_context(
-                LoRAPatcher.apply_lora_patches(
+                LoRAPatcher.apply_smart_lora_patches(
                     model=clip_text_encoder,
                     patches=self._clip_lora_iterator(context),
                     prefix=FLUX_LORA_CLIP_PREFIX,
+                    dtype=TorchDevice.choose_torch_dtype(),
                     cached_weights=cached_weights,
                 )
             )
@@ -1,5 +1,5 @@
 import copy
-from typing import List, Optional
+from typing import List, Optional, Literal

 from pydantic import BaseModel, Field

@@ -10,7 +10,7 @@ from invokeai.app.invocations.baseinvocation import (
     invocation,
     invocation_output,
 )
-from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, ImageField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.shared.models import FreeUConfig
 from invokeai.backend.model_manager.config import (
@@ -65,11 +65,6 @@ class CLIPField(BaseModel):
     loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")


-class TransformerField(BaseModel):
-    transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
-    loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
-
-
 class T5EncoderField(BaseModel):
     tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
     text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
@@ -79,6 +74,13 @@ class VAEField(BaseModel):
     vae: ModelIdentifierField = Field(description="Info to load vae submodel")
     seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')

+class StructuralLoRAField(LoRAField):
+    img: ImageField = Field(description="Image to use in structural conditioning")
+
+class TransformerField(BaseModel):
+    transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
+    loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
+    structural_lora: Optional[StructuralLoRAField] = Field(description="Structural LoRAs to apply on model loading", default=None)

 @invocation_output("unet_output")
 class UNetOutput(BaseInvocationOutput):
@@ -21,6 +21,7 @@ from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager.config import ModelFormat
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
+from invokeai.backend.util.devices import TorchDevice

 # The SD3 T5 Max Sequence Length set based on the default in diffusers.
 SD3_T5_MAX_SEQ_LEN = 256
@@ -150,10 +151,11 @@ class Sd3TextEncoderInvocation(BaseInvocation):
         if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
             # The model is non-quantized, so we can apply the LoRA weights directly into the model.
             exit_stack.enter_context(
-                LoRAPatcher.apply_lora_patches(
+                LoRAPatcher.apply_smart_lora_patches(
                     model=clip_text_encoder,
                     patches=self._clip_lora_iterator(context, clip_model),
                     prefix=FLUX_LORA_CLIP_PREFIX,
+                    dtype=TorchDevice.choose_torch_dtype(),
                     cached_weights=cached_weights,
                 )
             )
@@ -207,7 +207,9 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
         with (
             ExitStack() as exit_stack,
             unet_info as unet,
-            LoRAPatcher.apply_lora_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
+            LoRAPatcher.apply_smart_lora_patches(
+                model=unet, patches=_lora_loader(), prefix="lora_unet_", dtype=unet.dtype
+            ),
         ):
             assert isinstance(unet, UNet2DConditionModel)
             latents = latents.to(device=unet.device, dtype=unet.dtype)
@@ -30,6 +30,8 @@ def denoise(
     controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
     pos_ip_adapter_extensions: list[XLabsIPAdapterExtension],
     neg_ip_adapter_extensions: list[XLabsIPAdapterExtension],
+    # extra img tokens
+    img_cond: torch.Tensor | None = None,
 ):
     # step 0 is the initial state
     total_steps = len(timesteps) - 1
@@ -69,9 +71,9 @@
         # controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
         # tensors. Calculating the sum materializes each tensor into its own instance.
         merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
+        pred_img = torch.cat((img, img_cond), dim=-1) if img_cond is not None else img
         pred = model(
-            img=img,
+            img=pred_img,
             img_ids=img_ids,
             txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
             txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
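The structural conditioning tokens are concatenated with the packed image tokens along the last (feature) dimension, so each token fed to the transformer carries twice the usual number of channels while the token count is unchanged. This appears to be why `FluxParams` later in this diff gains an `out_channels` override: the model consumes the widened input but still predicts only the original channel count. A shape-only sketch (the sizes are illustrative, not the real FLUX dimensions):

```python
import torch

b, tokens, ch = 1, 4096, 64            # illustrative: packed latent tokens
img = torch.randn(b, tokens, ch)       # packed noisy latents
img_cond = torch.randn(b, tokens, ch)  # packed structural conditioning (from prepare_control)

pred_img = torch.cat((img, img_cond), dim=-1)  # same token count, doubled features
print(pred_img.shape)  # torch.Size([1, 4096, 128])
# The transformer's input projection must accept 128 features here, while the
# prediction stays at 64 features, hence out_channels != in_channels.
```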
invokeai/backend/flux/flux_tools_sampling_utils.py (new file, 27 lines)
@@ -0,0 +1,27 @@
import torch
import numpy as np
from PIL import Image
from einops import rearrange

from invokeai.backend.flux.modules.autoencoder import AutoEncoder


def prepare_control(
    height: int,
    width: int,
    seed: int,
    ae: AutoEncoder,
    cond_image: Image.Image,
) -> torch.Tensor:
    # load and encode the conditioning image
    img_cond = cond_image.convert("RGB")
    img_cond = img_cond.resize((width, height), Image.Resampling.LANCZOS)
    img_cond = np.array(img_cond)
    img_cond = torch.from_numpy(img_cond).float()
    img_cond = rearrange(img_cond, "h w c -> 1 c h w")
    ae_dtype = next(iter(ae.parameters())).dtype
    ae_device = next(iter(ae.parameters())).device
    img_cond = img_cond.to(device=ae_device, dtype=ae_dtype)
    generator = torch.Generator(device=ae_device).manual_seed(seed)
    img_cond = ae.encode(img_cond, sample=True, generator=generator)
    img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    return img_cond
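The final `rearrange` packs the encoded latents into 2x2 patches, which matches the token layout used for the noisy latents elsewhere in the FLUX sampling code. A small sketch of what that pattern does to shapes (sizes are illustrative):

```python
import torch
from einops import rearrange

latents = torch.randn(1, 16, 64, 64)  # (b, c, h, w) output of the VAE encoder (illustrative sizes)
packed = rearrange(latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
print(packed.shape)  # torch.Size([1, 1024, 64]): one token per 2x2 latent patch, 16*2*2 features each
```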
@@ -4,6 +4,7 @@ from dataclasses import dataclass

 import torch
 from torch import Tensor, nn
+from typing import Optional

 from invokeai.backend.flux.custom_block_processor import (
     CustomDoubleStreamBlockProcessor,
@@ -35,6 +36,7 @@ class FluxParams:
     theta: int
     qkv_bias: bool
     guidance_embed: bool
+    out_channels: Optional[int] = None


 class Flux(nn.Module):
@@ -47,7 +49,7 @@ class Flux(nn.Module):

         self.params = params
         self.in_channels = params.in_channels
-        self.out_channels = self.in_channels
+        self.out_channels = params.out_channels or self.in_channels
         if params.hidden_size % params.num_heads != 0:
             raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
         pe_dim = params.hidden_size // params.num_heads
invokeai/backend/flux/modules/image_embedders.py (new file, 50 lines)
@@ -0,0 +1,50 @@
import os
import cv2
import numpy as np
import torch

from einops import rearrange, repeat
from PIL import Image
from safetensors.torch import load_file as load_sft
from torch import nn
from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel


class DepthImageEncoder:
    depth_model_name = "LiheYoung/depth-anything-large-hf"

    def __init__(self, device):
        self.device = device
        self.depth_model = AutoModelForDepthEstimation.from_pretrained(self.depth_model_name).to(device)
        self.processor = AutoProcessor.from_pretrained(self.depth_model_name)

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        hw = img.shape[-2:]
        img = torch.clamp(img, -1.0, 1.0)
        img_byte = ((img + 1.0) * 127.5).byte()
        img = self.processor(img_byte, return_tensors="pt")["pixel_values"]
        depth = self.depth_model(img.to(self.device)).predicted_depth
        depth = repeat(depth, "b h w -> b 3 h w")
        depth = torch.nn.functional.interpolate(depth, hw, mode="bicubic", antialias=True)
        depth = depth / 127.5 - 1.0
        return depth


class CannyImageEncoder:
    def __init__(
        self,
        device,
        min_t: int = 50,
        max_t: int = 200,
    ):
        self.device = device
        self.min_t = min_t
        self.max_t = max_t

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        assert img.shape[0] == 1, "Only batch size 1 is supported"
        img = rearrange(img[0], "c h w -> h w c")
        img = torch.clamp(img, -1.0, 1.0)
        img_np = ((img + 1.0) * 127.5).numpy().astype(np.uint8)

        # Apply Canny edge detection
        canny = cv2.Canny(img_np, self.min_t, self.max_t)

        # Convert back to torch tensor and reshape
        canny = torch.from_numpy(canny).float() / 127.5 - 1.0
        canny = rearrange(canny, "h w -> 1 1 h w")
        canny = repeat(canny, "b 1 ... -> b 3 ...")
        return canny.to(self.device)
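These encoders are not referenced in the hunks shown here; they appear to mirror the depth/canny preprocessors from the upstream FLUX "tools" reference code. Expected usage would look roughly like the following hedged sketch, assuming an input image already normalized to [-1, 1] in (1, 3, H, W) layout:

```python
import torch

from invokeai.backend.flux.modules.image_embedders import CannyImageEncoder  # new module from this diff

def fake_image(h: int = 512, w: int = 512) -> torch.Tensor:
    # Stand-in for a real image tensor scaled to [-1, 1].
    return torch.rand(1, 3, h, w) * 2.0 - 1.0

encoder = CannyImageEncoder(device=torch.device("cpu"), min_t=50, max_t=200)
edges = encoder(fake_image())
print(edges.shape)  # (1, 3, 512, 512): edge map replicated to 3 channels, values back in [-1, 1]
```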
invokeai/backend/lora/conversions/flux_control_lora_utils.py (new file, 65 lines)
@@ -0,0 +1,65 @@
import re
import torch

from typing import Any, Dict
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer


# A regex pattern that matches all of the keys in the Flux Dev/Canny LoRA format.
# Example keys:
#   guidance_in.in_layer.lora_B.bias
#   single_blocks.0.linear1.lora_A.weight
#   double_blocks.0.img_attn.norm.key_norm.scale
FLUX_STRUCTURAL_TRANSFORMER_KEY_REGEX = r"(final_layer|vector_in|txt_in|time_in|img_in|guidance_in|\w+_blocks)(\.(\d+))?\.(lora_(A|B)|(in|out)_layer|adaLN_modulation|img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear|linear1|linear2|modulation|norm)\.?(.*)"


def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
    """Checks if the provided state dict is likely in the FLUX Control LoRA format.

    This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
    perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
    """
    return all(
        re.match(FLUX_STRUCTURAL_TRANSFORMER_KEY_REGEX, k) or re.match(FLUX_STRUCTURAL_TRANSFORMER_KEY_REGEX, k)
        for k in state_dict.keys()
    )


def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
    # converted_state_dict = _convert_lora_bfl_control(state_dict=state_dict)
    # Group keys by layer.
    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
    for key, value in state_dict.items():
        key_props = key.split(".")
        # Got it loading using lora_down and lora_up but it didn't seem to match this lora's structure
        # Leaving this in since it doesn't hurt anything and may be better
        layer_prop_size = -2 if any(prop in key for prop in ["lora_B", "lora_A"]) else -1
        layer_name = ".".join(key_props[:layer_prop_size])
        param_name = ".".join(key_props[layer_prop_size:])
        if layer_name not in grouped_state_dict:
            grouped_state_dict[layer_name] = {}
        grouped_state_dict[layer_name][param_name] = value

    # Create LoRA layers.
    layers: dict[str, AnyLoRALayer] = {}
    for layer_key, layer_state_dict in grouped_state_dict.items():
        # Convert to a full layer diff
        prefixed_key = f"{FLUX_LORA_TRANSFORMER_PREFIX}{layer_key}"
        if all(k in layer_state_dict for k in ["lora_A.weight", "lora_B.bias", "lora_B.weight"]):
            layers[prefixed_key] = LoRALayer(
                layer_state_dict["lora_B.weight"],
                None,
                layer_state_dict["lora_A.weight"],
                None,
                layer_state_dict["lora_B.bias"],
            )
        elif "scale" in layer_state_dict:
            layers[prefixed_key] = SetParameterLayer("scale", layer_state_dict["scale"])
        else:
            raise AssertionError(f"{layer_key} not expected")
    # Create and return the LoRAModelRaw.
    return LoRAModelRaw(layers=layers)
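The grouping step strips either a two-part suffix (`lora_A.weight`, `lora_B.weight`, `lora_B.bias`) or a one-part suffix (e.g. `scale`) to recover the module path. A quick check of that splitting rule in isolation, using the example keys from the comment above the regex:

```python
def split_key(key: str) -> tuple[str, str]:
    props = key.split(".")
    # lora_A/lora_B keys carry a two-part suffix ("lora_A.weight"); others carry a one-part suffix ("scale").
    n = -2 if any(p in key for p in ["lora_B", "lora_A"]) else -1
    return ".".join(props[:n]), ".".join(props[n:])

print(split_key("single_blocks.0.linear1.lora_A.weight"))
# ('single_blocks.0.linear1', 'lora_A.weight')
print(split_key("double_blocks.0.img_attn.norm.key_norm.scale"))
# ('double_blocks.0.img_attn.norm.key_norm', 'scale')
```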
@@ -7,5 +7,6 @@ from invokeai.backend.lora.layers.loha_layer import LoHALayer
 from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
 from invokeai.backend.lora.layers.lora_layer import LoRALayer
 from invokeai.backend.lora.layers.norm_layer import NormLayer
+from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer

-AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer]
+AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer, SetParameterLayer]
invokeai/backend/lora/layers/reshape_weight_layer.py (new file, 34 lines)
@@ -0,0 +1,34 @@
from typing import Dict, Optional

import torch

from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensor_size


class ReshapeWeightLayer(LoRALayerBase):
    # TODO: Just everything in this class
    def __init__(self, weight: Optional[torch.Tensor], bias: Optional[torch.Tensor], scale: Optional[torch.Tensor]):
        super().__init__(alpha=None, bias=bias)
        self.weight = torch.nn.Parameter(weight) if weight is not None else None
        self.bias = torch.nn.Parameter(bias) if bias is not None else None
        self.manual_scale = scale

    def scale(self):
        return self.manual_scale.float() if self.manual_scale is not None else super().scale()

    def rank(self) -> int | None:
        return None

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        return orig_weight

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().to(device=device, dtype=dtype)
        if self.weight is not None:
            self.weight = self.weight.to(device=device, dtype=dtype)
        if self.manual_scale is not None:
            self.manual_scale = self.manual_scale.to(device=device, dtype=dtype)

    def calc_size(self) -> int:
        return super().calc_size() + calc_tensor_size(self.manual_scale)
invokeai/backend/lora/layers/set_parameter_layer.py (new file, 29 lines)
@@ -0,0 +1,29 @@
from typing import Dict, Optional

import torch

from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensor_size


class SetParameterLayer(LoRALayerBase):
    def __init__(self, param_name: str, weight: torch.Tensor):
        super().__init__(None, None)
        self.weight = weight
        self.param_name = param_name

    def rank(self) -> int | None:
        return None

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        return self.weight - orig_weight

    def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
        return {self.param_name: self.get_weight(orig_module.get_parameter(self.param_name))}

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().to(device=device, dtype=dtype)
        self.weight = self.weight.to(device=device, dtype=dtype)

    def calc_size(self) -> int:
        return super().calc_size() + calc_tensor_size(self.weight)
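`SetParameterLayer.get_weight()` returns the difference between the stored tensor and the module's current parameter, so the patcher's additive update (`module_param += diff * scale * weight`) lands exactly on the stored value when the combined layer scale and patch weight are 1.0; with other weights the parameter is only moved partway toward the stored value. A minimal arithmetic check of that behavior:

```python
import torch

orig = torch.tensor([1.0, 2.0, 3.0])     # current "scale" parameter in the model
target = torch.tensor([0.5, 0.5, 0.5])   # value stored in the FLUX control LoRA
diff = target - orig                     # what SetParameterLayer.get_weight() returns

patched = orig + diff * 1.0              # patcher applies: param += diff * (layer_scale * patch_weight)
assert torch.equal(patched, target)
```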
@@ -9,6 +9,7 @@ from invokeai.backend.lora.layers.loha_layer import LoHALayer
 from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
 from invokeai.backend.lora.layers.lora_layer import LoRALayer
 from invokeai.backend.lora.layers.norm_layer import NormLayer
+from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer


 def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:
invokeai/backend/lora/lora_layer_wrappers.py (new file, 133 lines)
@@ -0,0 +1,133 @@
import torch

from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer


class LoRASidecarWrapper(torch.nn.Module):
    def __init__(self, orig_module: torch.nn.Module, lora_layers: list[AnyLoRALayer], lora_weights: list[float]):
        super().__init__()
        self._orig_module = orig_module
        self._lora_layers = lora_layers
        self._lora_weights = lora_weights

    @property
    def orig_module(self) -> torch.nn.Module:
        return self._orig_module

    def add_lora_layer(self, lora_layer: AnyLoRALayer, lora_weight: float):
        self._lora_layers.append(lora_layer)
        self._lora_weights.append(lora_weight)

    @torch.no_grad()
    def _get_lora_patched_parameters(
        self, orig_params: dict[str, torch.Tensor], lora_layers: list[AnyLoRALayer], lora_weights: list[float]
    ) -> dict[str, torch.Tensor]:
        params: dict[str, torch.Tensor] = {}
        for lora_layer, lora_weight in zip(lora_layers, lora_weights, strict=True):
            layer_params = lora_layer.get_parameters(self._orig_module)
            for param_name, param_weight in layer_params.items():
                if orig_params[param_name].shape != param_weight.shape:
                    param_weight = param_weight.reshape(orig_params[param_name].shape)

                if param_name not in params:
                    params[param_name] = param_weight * (lora_layer.scale() * lora_weight)
                else:
                    params[param_name] += param_weight * (lora_layer.scale() * lora_weight)

        return params


class LoRALinearWrapper(LoRASidecarWrapper):
    def _lora_linear_forward(self, input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
        """An optimized implementation of the residual calculation for a Linear LoRALayer."""
        x = torch.nn.functional.linear(input, lora_layer.down)
        if lora_layer.mid is not None:
            x = torch.nn.functional.linear(x, lora_layer.mid)
        x = torch.nn.functional.linear(x, lora_layer.up, bias=lora_layer.bias)
        x *= lora_weight * lora_layer.scale()
        return x

    def _concatenated_lora_forward(
        self, input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float
    ) -> torch.Tensor:
        """An optimized implementation of the residual calculation for a Linear ConcatenatedLoRALayer."""
        x_chunks: list[torch.Tensor] = []
        for lora_layer in concatenated_lora_layer.lora_layers:
            x_chunk = torch.nn.functional.linear(input, lora_layer.down)
            if lora_layer.mid is not None:
                x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
            x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
            x_chunk *= lora_weight * lora_layer.scale()
            x_chunks.append(x_chunk)

        # TODO(ryand): Generalize to support concat_axis != 0.
        assert concatenated_lora_layer.concat_axis == 0
        x = torch.cat(x_chunks, dim=-1)
        return x

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Split the LoRA layers into those that have optimized implementations and those that don't.
        optimized_layer_types = (LoRALayer, ConcatenatedLoRALayer)
        optimized_layers = [
            (layer, weight)
            for layer, weight in zip(self._lora_layers, self._lora_weights, strict=True)
            if isinstance(layer, optimized_layer_types)
        ]
        non_optimized_layers = [
            (layer, weight)
            for layer, weight in zip(self._lora_layers, self._lora_weights, strict=True)
            if not isinstance(layer, optimized_layer_types)
        ]

        # First, calculate the residual for LoRA layers for which there is an optimized implementation.
        residual = None
        for lora_layer, lora_weight in optimized_layers:
            if isinstance(lora_layer, LoRALayer):
                added_residual = self._lora_linear_forward(input, lora_layer, lora_weight)
            elif isinstance(lora_layer, ConcatenatedLoRALayer):
                added_residual = self._concatenated_lora_forward(input, lora_layer, lora_weight)
            else:
                raise ValueError(f"Unsupported LoRA layer type: {type(lora_layer)}")

            if residual is None:
                residual = added_residual
            else:
                residual += added_residual

        # Next, calculate the residuals for the LoRA layers for which there is no optimized implementation.
        if non_optimized_layers:
            unoptimized_layers, unoptimized_weights = zip(*non_optimized_layers, strict=True)
            params = self._get_lora_patched_parameters(
                orig_params={"weight": self._orig_module.weight, "bias": self._orig_module.bias},
                lora_layers=unoptimized_layers,
                lora_weights=unoptimized_weights,
            )
            added_residual = torch.nn.functional.linear(input, params["weight"], params.get("bias", None))
            if residual is None:
                residual = added_residual
            else:
                residual += added_residual

        return self.orig_module(input) + residual


class LoRAConv1dWrapper(LoRASidecarWrapper):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        params = self._get_lora_patched_parameters(
            orig_params={"weight": self._orig_module.weight, "bias": self._orig_module.bias},
            lora_layers=self._lora_layers,
            lora_weights=self._lora_weights,
        )
        return self.orig_module(input) + torch.nn.functional.conv1d(input, params["weight"], params.get("bias", None))


class LoRAConv2dWrapper(LoRASidecarWrapper):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        params = self._get_lora_patched_parameters(
            orig_params={"weight": self._orig_module.weight, "bias": self._orig_module.bias},
            lora_layers=self._lora_layers,
            lora_weights=self._lora_weights,
        )
        return self.orig_module(input) + torch.nn.functional.conv2d(input, params["weight"], params.get("bias", None))
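All three wrappers implement the same contract: `forward(x) = orig_module(x) + residual(x)`, where the residual is what the LoRA deltas would have contributed had they been merged into the weights. For a Linear layer the two forms are mathematically identical, which is the property that makes direct patching and the sidecar path interchangeable. A self-contained sketch of that equivalence (generic torch code, not the InvokeAI classes):

```python
import torch

torch.manual_seed(0)
linear = torch.nn.Linear(8, 4, bias=True)
down, up = torch.randn(2, 8), torch.randn(4, 2)   # rank-2 LoRA factors
weight, scale = 0.75, 1.0
x = torch.randn(3, 8)

# Sidecar form: original output plus a separately computed residual.
residual = torch.nn.functional.linear(torch.nn.functional.linear(x, down), up) * (weight * scale)
sidecar_out = linear(x) + residual

# Direct-patch form: merge the delta into the weights, then run the layer once.
patched = torch.nn.Linear(8, 4, bias=True)
patched.load_state_dict(linear.state_dict())
with torch.no_grad():
    patched.weight += (up @ down) * (weight * scale)
direct_out = patched(x)

assert torch.allclose(sidecar_out, direct_out, atol=1e-5)
```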
@@ -4,19 +4,126 @@ from typing import Dict, Iterable, Optional, Tuple
 import torch

 from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
-from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
-from invokeai.backend.lora.layers.lora_layer import LoRALayer
-from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
-from invokeai.backend.lora.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import (
-    ConcatenatedLoRALinearSidecarLayer,
-)
-from invokeai.backend.lora.sidecar_layers.lora.lora_linear_sidecar_layer import LoRALinearSidecarLayer
-from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
+from invokeai.backend.lora.lora_layer_wrappers import (
+    LoRAConv1dWrapper,
+    LoRAConv2dWrapper,
+    LoRALinearWrapper,
+    LoRASidecarWrapper,
+)
+from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


 class LoRAPatcher:
+    @staticmethod
+    @torch.no_grad()
+    @contextmanager
+    def apply_smart_lora_patches(
+        model: torch.nn.Module,
+        patches: Iterable[Tuple[LoRAModelRaw, float]],
+        prefix: str,
+        dtype: torch.dtype,
+        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
+    ):
+        """Apply 'smart' LoRA patching that chooses whether to use direct patching or a sidecar wrapper for each module."""
+
+        # original_weights are stored for unpatching layers that are directly patched.
+        original_weights = OriginalWeightsStorage(cached_weights)
+        # original_modules are stored for unpatching layers that are wrapped in a LoRASidecarWrapper.
+        original_modules: dict[str, torch.nn.Module] = {}
+        try:
+            for patch, patch_weight in patches:
+                LoRAPatcher._apply_smart_lora_patch(
+                    model=model,
+                    prefix=prefix,
+                    patch=patch,
+                    patch_weight=patch_weight,
+                    original_weights=original_weights,
+                    original_modules=original_modules,
+                    dtype=dtype,
+                )
+
+            yield
+        finally:
+            # Restore directly patched layers.
+            for param_key, weight in original_weights.get_changed_weights():
+                model.get_parameter(param_key).copy_(weight)
+
+            # Restore LoRASidecarWrapper modules.
+            # Note: This logic assumes no nested modules in original_modules.
+            for module_key, orig_module in original_modules.items():
+                module_parent_key, module_name = LoRAPatcher._split_parent_key(module_key)
+                parent_module = model.get_submodule(module_parent_key)
+                LoRAPatcher._set_submodule(parent_module, module_name, orig_module)
+
+    @staticmethod
+    @torch.no_grad()
+    def _apply_smart_lora_patch(
+        model: torch.nn.Module,
+        prefix: str,
+        patch: LoRAModelRaw,
+        patch_weight: float,
+        original_weights: OriginalWeightsStorage,
+        original_modules: dict[str, torch.nn.Module],
+        dtype: torch.dtype,
+    ):
+        """Apply a single LoRA patch to a model using the 'smart' patching strategy that chooses whether to use direct
+        patching or a sidecar wrapper for each module.
+        """
+        if patch_weight == 0:
+            return
+
+        # If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
+        # submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
+        # replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
+        # without searching, but some legacy code still uses flattened keys.
+        layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
+
+        prefix_len = len(prefix)
+
+        for layer_key, layer in patch.layers.items():
+            if not layer_key.startswith(prefix):
+                continue
+
+            module_key, module = LoRAPatcher._get_submodule(
+                model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
+            )
+
+            # Decide whether to use direct patching or a sidecar wrapper.
+            # Direct patching is preferred, because it results in better runtime speed.
+            # Reasons to use sidecar patching:
+            # - The module is already wrapped in a LoRASidecarWrapper.
+            # - The module is quantized.
+            # - The module is on the CPU (and we don't want to store a second full copy of the original weights on the
+            #   CPU, since this would double the RAM usage)
+            # NOTE: For now, we don't check if the layer is quantized here. We assume that this is checked in the caller
+            # and that the caller will use the 'apply_lora_wrapper_patches' method if the layer is quantized.
+            # TODO(ryand): Handle the case where we are running without a GPU. Should we set a config flag that allows
+            # forcing full patching even on the CPU?
+            if isinstance(module, LoRASidecarWrapper) or LoRAPatcher._is_any_part_of_layer_on_cpu(module):
+                LoRAPatcher._apply_lora_layer_wrapper_patch(
+                    model=model,
+                    module_to_patch=module,
+                    module_to_patch_key=module_key,
+                    patch=layer,
+                    patch_weight=patch_weight,
+                    original_modules=original_modules,
+                    dtype=dtype,
+                )
+            else:
+                LoRAPatcher._apply_lora_layer_patch(
+                    module_to_patch=module,
+                    module_to_patch_key=module_key,
+                    patch=layer,
+                    patch_weight=patch_weight,
+                    original_weights=original_weights,
+                )
+
+    @staticmethod
+    def _is_any_part_of_layer_on_cpu(layer: torch.nn.Module) -> bool:
+        return any(p.device.type == "cpu" for p in layer.parameters())
+
     @staticmethod
     @torch.no_grad()
     @contextmanager
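With this change, callers treat `apply_smart_lora_patches` exactly like the previous `apply_lora_patches` context manager, plus the new required `dtype` argument; per module it only falls back to a sidecar wrapper when the module is already wrapped or has parameters on the CPU. Typical call shape, matching the call sites elsewhere in this diff (a sketch: `model`, `loras`, and `cached_weights` are placeholders):

```python
from contextlib import ExitStack

import torch

from invokeai.backend.lora.lora_patcher import LoRAPatcher

def run_with_loras(model: torch.nn.Module, loras, cached_weights=None):
    with ExitStack() as exit_stack:
        exit_stack.enter_context(
            LoRAPatcher.apply_smart_lora_patches(
                model=model,
                patches=loras,                  # iterable of (LoRAModelRaw, weight) pairs
                prefix="lora_unet_",            # key prefix filter, per model family
                dtype=torch.float16,
                cached_weights=cached_weights,  # optional: weights cached by the model manager
            )
        )
        return model  # run inference here; weights/wrappers are restored when the context exits
```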
@@ -40,7 +147,7 @@ class LoRAPatcher:
         original_weights = OriginalWeightsStorage(cached_weights)
         try:
             for patch, patch_weight in patches:
-                LoRAPatcher.apply_lora_patch(
+                LoRAPatcher._apply_lora_patch(
                     model=model,
                     prefix=prefix,
                     patch=patch,
@@ -52,11 +159,12 @@ class LoRAPatcher:
             yield
         finally:
             for param_key, weight in original_weights.get_changed_weights():
-                model.get_parameter(param_key).copy_(weight)
+                cur_param = model.get_parameter(param_key)
+                cur_param.data = weight.to(dtype=cur_param.dtype, device=cur_param.device, copy=True)

     @staticmethod
     @torch.no_grad()
-    def apply_lora_patch(
+    def _apply_lora_patch(
         model: torch.nn.Module,
         prefix: str,
         patch: LoRAModelRaw,
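The unpatch path in `apply_lora_patches` now restores weights by replacing the parameter's `.data` (with an explicit dtype/device cast) rather than calling `copy_()`. This likely matters because `copy_()` requires the saved tensor to be broadcastable into the current parameter, while the control-LoRA path can swap a parameter for an expanded tensor (see the `expanded_weight` handling below), in which case an in-place copy of the smaller original would fail. The behavioral difference in miniature:

```python
import torch

param = torch.nn.Parameter(torch.zeros(4, 6))   # parameter was expanded during patching (e.g. 4x4 -> 4x6)
saved = torch.ones(4, 4)                        # original weight captured before patching

with torch.no_grad():
    try:
        param.copy_(saved)                      # shape mismatch: RuntimeError
    except RuntimeError as e:
        print("copy_ failed:", e)

# Replacing .data swaps the whole tensor, restoring the original shape as well.
param.data = saved.to(dtype=param.dtype, device=param.device, copy=True)
print(param.shape)  # torch.Size([4, 4])
```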
@@ -91,48 +199,84 @@ class LoRAPatcher:
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)

+ LoRAPatcher._apply_lora_layer_patch(
+ module_to_patch=module,
+ module_to_patch_key=module_key,
+ patch=layer,
+ patch_weight=patch_weight,
+ original_weights=original_weights,
+ )

+ @staticmethod
+ @torch.no_grad()
+ def _apply_lora_layer_patch(
+ module_to_patch: torch.nn.Module,
+ module_to_patch_key: str,
+ patch: AnyLoRALayer,
+ patch_weight: float,
+ original_weights: OriginalWeightsStorage,
+ ):
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
- device = module.weight.device
- dtype = module.weight.dtype
+ first_param = next(module_to_patch.parameters())
+ device = first_param.device
+ dtype = first_param.dtype

- layer_scale = layer.scale()
+ layer_scale = patch.scale()

# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
- layer.to(device=device)
- layer.to(dtype=torch.float32)
+ patch.to(device=device)
+ patch.to(dtype=torch.float32)

# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
- for param_name, lora_param_weight in layer.get_parameters(module).items():
- param_key = module_key + "." + param_name
- module_param = module.get_parameter(param_name)
+ for param_name, lora_param_weight in patch.get_parameters(module_to_patch).items():
+ param_key = module_to_patch_key + "." + param_name
+ module_param = module_to_patch.get_parameter(param_name)

# Save original weight
original_weights.save(param_key, module_param)

if module_param.shape != lora_param_weight.shape:
+ if module_param.nelement() == lora_param_weight.nelement():
lora_param_weight = lora_param_weight.reshape(module_param.shape)
+ else:
+ # This condition was added to handle layers in FLUX control LoRAs.
+ # TODO(ryand): Move the weight update into the LoRA layer so that the LoRAPatcher doesn't need
+ # to worry about this?
+ expanded_weight = torch.zeros_like(
+ lora_param_weight, dtype=module_param.dtype, device=module_param.device
+ )
+ slices = tuple(slice(0, dim) for dim in module_param.shape)
+ expanded_weight[slices] = module_param
+ setattr(
+ module,
+ param_name,
+ torch.nn.Parameter(expanded_weight, requires_grad=module_param.requires_grad),
+ )
+ module_param = expanded_weight

lora_param_weight *= patch_weight * layer_scale
module_param += lora_param_weight.to(dtype=dtype)

- layer.to(device=TorchDevice.CPU_DEVICE)
+ patch.to(device=TorchDevice.CPU_DEVICE)

@staticmethod
@torch.no_grad()
@contextmanager
- def apply_lora_sidecar_patches(
+ def apply_lora_wrapper_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[LoRAModelRaw, float]],
prefix: str,
dtype: torch.dtype,
):
- """Apply one or more LoRA sidecar patches to a model within a context manager. Sidecar patches incur some
- overhead compared to normal LoRA patching, but they allow for LoRA layers to applied to base layers in any
- quantization format.
+ """Apply one or more LoRA wrapper patches to a model within a context manager. Wrapper patches incur some
+ runtime overhead compared to normal LoRA patching, but they enable:
+ - LoRA layers to be applied to quantized models
+ - LoRA layers to be applied to CPU layers without needing to store a full copy of the original weights (i.e.
+ avoid doubling the memory requirements).

Args:
model (torch.nn.Module): The model to patch.
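The expanded-weight branch in the hunk above handles FLUX control LoRAs whose patch tensors are larger than the base weight: the original weight is zero-padded up to the patch's shape before the scaled delta is added. A standalone sketch of just that expansion step; the function name and shapes here are illustrative.

import torch


def expand_to_match(module_param: torch.Tensor, lora_param_weight: torch.Tensor) -> torch.Tensor:
    """Zero-pad `module_param` up to the (larger) shape of `lora_param_weight`."""
    expanded = torch.zeros_like(lora_param_weight, dtype=module_param.dtype, device=module_param.device)
    # Copy the original values into the leading slice along every dimension.
    slices = tuple(slice(0, dim) for dim in module_param.shape)
    expanded[slices] = module_param
    return expanded


# Example: a (4, 3) base weight expanded to the (4, 5) shape of a control-LoRA delta.
base_weight = torch.ones(4, 3)
delta = torch.zeros(4, 5)
patched = expand_to_match(base_weight, delta) + delta
assert patched.shape == (4, 5)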
@@ -140,14 +284,11 @@ class LoRAPatcher:
associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory
all at once.
prefix (str): The keys in the patches will be filtered to only include weights with this prefix.
- dtype (torch.dtype): The compute dtype of the sidecar layers. This cannot easily be inferred from the model,
- since the sidecar layers are typically applied on top of quantized layers whose weight dtype is
- different from their compute dtype.
"""
original_modules: dict[str, torch.nn.Module] = {}
try:
for patch, patch_weight in patches:
- LoRAPatcher._apply_lora_sidecar_patch(
+ LoRAPatcher._apply_lora_wrapper_patch(
model=model,
prefix=prefix,
patch=patch,
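The patches argument documented above is an iterator, so LoRA files can be loaded one at a time rather than held in memory together. A rough sketch of such a lazy loader; the paths and the raw-state-dict yield are illustrative, since the real loader yields LoRAModelRaw objects rather than plain dicts.

import torch
from typing import Iterable, Iterator, Tuple


def iter_lora_patches(paths: Iterable[str], weight: float) -> Iterator[Tuple[dict, float]]:
    """Yield (state_dict, weight) pairs one at a time so only a single LoRA is resident in memory."""
    for path in paths:
        state_dict = torch.load(path, map_location="cpu")
        yield state_dict, weight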
@@ -165,7 +306,7 @@ class LoRAPatcher:
LoRAPatcher._set_submodule(parent_module, module_name, orig_module)

@staticmethod
- def _apply_lora_sidecar_patch(
+ def _apply_lora_wrapper_patch(
model: torch.nn.Module,
patch: LoRAModelRaw,
patch_weight: float,
@@ -173,7 +314,7 @@ class LoRAPatcher:
original_modules: dict[str, torch.nn.Module],
dtype: torch.dtype,
):
- """Apply a single LoRA sidecar patch to a model."""
+ """Apply a single LoRA wrapper patch to a model."""

if patch_weight == 0:
return
@@ -194,28 +335,47 @@ class LoRAPatcher:
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)

- # Initialize the LoRA sidecar layer.
- lora_sidecar_layer = LoRAPatcher._initialize_lora_sidecar_layer(module, layer, patch_weight)
+ LoRAPatcher._apply_lora_layer_wrapper_patch(
+ model=model,
+ module_to_patch=module,
+ module_to_patch_key=module_key,
+ patch=layer,
+ patch_weight=patch_weight,
+ original_modules=original_modules,
+ dtype=dtype,
+ )

- # Replace the original module with a LoRASidecarModule if it has not already been done.
- if module_key in original_modules:
- # The module has already been patched with a LoRASidecarModule. Append to it.
- assert isinstance(module, LoRASidecarModule)
- lora_sidecar_module = module
- else:
- # The module has not yet been patched with a LoRASidecarModule. Create one.
- lora_sidecar_module = LoRASidecarModule(module, [])
- original_modules[module_key] = module
- module_parent_key, module_name = LoRAPatcher._split_parent_key(module_key)
+ @staticmethod
+ @torch.no_grad()
+ def _apply_lora_layer_wrapper_patch(
+ model: torch.nn.Module,
+ module_to_patch: torch.nn.Module,
+ module_to_patch_key: str,
+ patch: AnyLoRALayer,
+ patch_weight: float,
+ original_modules: dict[str, torch.nn.Module],
+ dtype: torch.dtype,
+ ):
+ """Apply a single LoRA wrapper patch to a model."""
+
+ # Replace the original module with a LoRASidecarWrapper if it has not already been done.
+ if not isinstance(module_to_patch, LoRASidecarWrapper):
+ lora_wrapper_layer = LoRAPatcher._initialize_lora_wrapper_layer(module_to_patch)
+ original_modules[module_to_patch_key] = module_to_patch
+ module_parent_key, module_name = LoRAPatcher._split_parent_key(module_to_patch_key)
module_parent = model.get_submodule(module_parent_key)
- LoRAPatcher._set_submodule(module_parent, module_name, lora_sidecar_module)
+ LoRAPatcher._set_submodule(module_parent, module_name, lora_wrapper_layer)
+ orig_module = module_to_patch
+ else:
+ assert module_to_patch_key in original_modules
+ lora_wrapper_layer = module_to_patch
+ orig_module = module_to_patch.orig_module

- # Move the LoRA sidecar layer to the same device/dtype as the orig module.
- # TODO(ryand): Experiment with moving to the device first, then casting. This could be faster.
- lora_sidecar_layer.to(device=lora_sidecar_module.orig_module.weight.device, dtype=dtype)
+ # Move the LoRA layer to the same device/dtype as the orig module.
+ patch.to(device=orig_module.weight.device, dtype=dtype)

- # Add the LoRA sidecar layer to the LoRASidecarModule.
- lora_sidecar_module.add_lora_layer(lora_sidecar_layer)
+ # Add the LoRA wrapper layer to the LoRASidecarWrapper.
+ lora_wrapper_layer.add_lora_layer(patch, patch_weight)

@staticmethod
def _split_parent_key(module_key: str) -> tuple[str, str]:
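Wrapping a module means swapping it out of its parent: the module key is split into a parent path and a child name, the parent is looked up with get_submodule, and the child attribute is reassigned. A minimal sketch of that swap; this is not the actual _split_parent_key/_set_submodule implementation, and the dotted key in the docstring is only an example.

import torch


def replace_submodule(model: torch.nn.Module, module_key: str, new_module: torch.nn.Module) -> None:
    """Swap the submodule addressed by a dotted key (e.g. "encoder.block.0.proj") for `new_module`."""
    parent_key, _, child_name = module_key.rpartition(".")
    parent = model.get_submodule(parent_key) if parent_key else model
    setattr(parent, child_name, new_module)


# Example: replace the second layer of a tiny sequential model with another layer of the same shape.
net = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
replace_submodule(net, "1", torch.nn.Linear(4, 2))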
@@ -236,17 +396,13 @@ class LoRAPatcher:
raise ValueError(f"Invalid module key: {module_key}")

@staticmethod
- def _initialize_lora_sidecar_layer(orig_layer: torch.nn.Module, lora_layer: AnyLoRALayer, patch_weight: float):
- # TODO(ryand): Add support for more original layer types and LoRA layer types.
- if isinstance(orig_layer, torch.nn.Linear) or (
- isinstance(orig_layer, LoRASidecarModule) and isinstance(orig_layer.orig_module, torch.nn.Linear)
- ):
- if isinstance(lora_layer, LoRALayer):
- return LoRALinearSidecarLayer(lora_layer=lora_layer, weight=patch_weight)
- elif isinstance(lora_layer, ConcatenatedLoRALayer):
- return ConcatenatedLoRALinearSidecarLayer(concatenated_lora_layer=lora_layer, weight=patch_weight)
- else:
- raise ValueError(f"Unsupported Linear LoRA layer type: {type(lora_layer)}")
+ def _initialize_lora_wrapper_layer(orig_layer: torch.nn.Module):
+ if isinstance(orig_layer, torch.nn.Linear):
+ return LoRALinearWrapper(orig_layer, [], [])
+ elif isinstance(orig_layer, torch.nn.Conv1d):
+ return LoRAConv1dWrapper(orig_layer, [], [])
+ elif isinstance(orig_layer, torch.nn.Conv2d):
+ return LoRAConv2dWrapper(orig_layer, [], [])
else:
raise ValueError(f"Unsupported layer type: {type(orig_layer)}")

@@ -1,34 +0,0 @@
- import torch
-
- from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
-
-
- class ConcatenatedLoRALinearSidecarLayer(torch.nn.Module):
- def __init__(
- self,
- concatenated_lora_layer: ConcatenatedLoRALayer,
- weight: float,
- ):
- super().__init__()
-
- self._concatenated_lora_layer = concatenated_lora_layer
- self._weight = weight
-
- def forward(self, input: torch.Tensor) -> torch.Tensor:
- x_chunks: list[torch.Tensor] = []
- for lora_layer in self._concatenated_lora_layer.lora_layers:
- x_chunk = torch.nn.functional.linear(input, lora_layer.down)
- if lora_layer.mid is not None:
- x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
- x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
- x_chunk *= self._weight * lora_layer.scale()
- x_chunks.append(x_chunk)
-
- # TODO(ryand): Generalize to support concat_axis != 0.
- assert self._concatenated_lora_layer.concat_axis == 0
- x = torch.cat(x_chunks, dim=-1)
- return x
-
- def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
- self._concatenated_lora_layer.to(device=device, dtype=dtype)
- return self

@@ -1,27 +0,0 @@
- import torch
-
- from invokeai.backend.lora.layers.lora_layer import LoRALayer
-
-
- class LoRALinearSidecarLayer(torch.nn.Module):
- def __init__(
- self,
- lora_layer: LoRALayer,
- weight: float,
- ):
- super().__init__()
-
- self._lora_layer = lora_layer
- self._weight = weight
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- x = torch.nn.functional.linear(x, self._lora_layer.down)
- if self._lora_layer.mid is not None:
- x = torch.nn.functional.linear(x, self._lora_layer.mid)
- x = torch.nn.functional.linear(x, self._lora_layer.up, bias=self._lora_layer.bias)
- x *= self._weight * self._lora_layer.scale()
- return x
-
- def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
- self._lora_layer.to(device=device, dtype=dtype)
- return self

@@ -1,24 +0,0 @@
- import torch
-
-
- class LoRASidecarModule(torch.nn.Module):
- """A LoRA sidecar module that wraps an original module and adds LoRA layers to it."""
-
- def __init__(self, orig_module: torch.nn.Module, lora_layers: list[torch.nn.Module]):
- super().__init__()
- self.orig_module = orig_module
- self._lora_layers = lora_layers
-
- def add_lora_layer(self, lora_layer: torch.nn.Module):
- self._lora_layers.append(lora_layer)
-
- def forward(self, input: torch.Tensor) -> torch.Tensor:
- x = self.orig_module(input)
- for lora_layer in self._lora_layers:
- x += lora_layer(input)
- return x
-
- def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
- self._orig_module.to(device=device, dtype=dtype)
- for lora_layer in self._lora_layers:
- lora_layer.to(device=device, dtype=dtype)
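The deleted sidecar layers above implement LoRA without touching the base weights: the wrapper runs the original module unchanged and adds each LoRA branch's output, scaled by its patch weight. A condensed, self-contained sketch of that idea follows; these classes are simplified stand-ins, not the new LoRALinearWrapper API.

import torch


class LowRankBranch(torch.nn.Module):
    """The classic LoRA update y = (x @ down^T) @ up^T, kept separate from the base layer."""

    def __init__(self, down: torch.Tensor, up: torch.Tensor):
        super().__init__()
        self.down = torch.nn.Parameter(down, requires_grad=False)
        self.up = torch.nn.Parameter(up, requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.linear(torch.nn.functional.linear(x, self.down), self.up)


class SidecarLinear(torch.nn.Module):
    """Run the original layer untouched and add each weighted LoRA branch to its output."""

    def __init__(self, orig_module: torch.nn.Module):
        super().__init__()
        self.orig_module = orig_module
        self.branches: list[tuple[LowRankBranch, float]] = []

    def add_lora_layer(self, branch: LowRankBranch, weight: float) -> None:
        self.branches.append((branch, weight))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.orig_module(x)
        for branch, weight in self.branches:
            y = y + weight * branch(x)
        return y


# With all-zero LoRA weights the wrapped layer reproduces the original layer exactly.
base = torch.nn.Linear(8, 4)
wrapped = SidecarLinear(base)
wrapped.add_lora_layer(LowRankBranch(torch.zeros(2, 8), torch.zeros(4, 2)), weight=1.0)
x = torch.randn(3, 8)
assert torch.allclose(wrapped(x), base(x))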
@@ -67,6 +67,7 @@ class ModelType(str, Enum):
Main = "main"
VAE = "vae"
LoRA = "lora"
+ StructuralLoRa = "structural_lora"
ControlNet = "controlnet"  # used by model_probe
TextualInversion = "embedding"
IPAdapter = "ip_adapter"
@@ -273,6 +274,18 @@ class LoRALyCORISConfig(LoRAConfigBase):
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")


+ class StructuralLoRALyCORISConfig(ModelConfigBase):
+ """Model config for Structural LoRA/Lycoris models."""
+
+ type: Literal[ModelType.StructuralLoRa] = ModelType.StructuralLoRa
+ trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
+ format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
+
+ @staticmethod
+ def get_tag() -> Tag:
+ return Tag(f"{ModelType.StructuralLoRa.value}.{ModelFormat.LyCORIS.value}")
+
+
class LoRADiffusersConfig(LoRAConfigBase):
"""Model config for LoRA/Diffusers models."""

@@ -535,6 +548,7 @@ AnyModelConfig = Annotated[
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
+ Annotated[StructuralLoRALyCORISConfig, StructuralLoRALyCORISConfig.get_tag()],
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],

@@ -13,8 +13,9 @@ from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils impo
lora_model_from_flux_diffusers_state_dict,
)
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
- lora_model_from_flux_kohya_state_dict,
+ is_state_dict_likely_in_flux_kohya_format, lora_model_from_flux_kohya_state_dict,
)
+ from invokeai.backend.lora.conversions.flux_control_lora_utils import is_state_dict_likely_flux_control, lora_model_from_flux_control_state_dict
from invokeai.backend.lora.conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
from invokeai.backend.lora.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
from invokeai.backend.model_manager import (
@@ -32,6 +33,7 @@ from invokeai.backend.model_manager.load.model_loader_registry import ModelLoade

@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
+ @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.StructuralLoRa, format=ModelFormat.LyCORIS)
class LoRALoader(ModelLoader):
"""Class to load LoRA models."""

@@ -75,7 +77,10 @@ class LoRALoader(ModelLoader):
# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_flux.py#L1194
model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None)
elif config.format == ModelFormat.LyCORIS:
+ if is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict):
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
+ elif is_state_dict_likely_flux_control(state_dict=state_dict):
+ model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
else:
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:

@@ -18,6 +18,7 @@ from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlab
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
)
+ from invokeai.backend.lora.conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.config import (
@@ -258,6 +259,18 @@ class ModelProbe(object):
ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
ckpt = ckpt.get("state_dict", ckpt)

+ if isinstance(ckpt, dict) and "img_in.lora_A.weight" in ckpt and "img_in.lora_B.weight" in ckpt:
+ tensor_a, tensor_b = ckpt["img_in.lora_A.weight"], ckpt["img_in.lora_B.weight"]
+ if (
+ tensor_a is not None
+ and isinstance(tensor_a, torch.Tensor)
+ and tensor_a.shape[1] == 128
+ and tensor_b is not None
+ and isinstance(tensor_b, torch.Tensor)
+ and tensor_b.shape[0] == 3072
+ ):
+ return ModelType.StructuralLoRa
+
for key in [str(k) for k in ckpt.keys()]:
if key.startswith(
(
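The probe change above identifies a FLUX control ("structural") LoRA purely from the shapes of its img_in LoRA tensors. A minimal sketch of that shape heuristic, reusing the 128/3072 dimensions from the check above; the mock tensors are only for illustration.

import torch


def looks_like_flux_control_lora(state_dict: dict) -> bool:
    """Heuristic: FLUX control LoRAs carry img_in lora_A/lora_B tensors with these characteristic shapes."""
    tensor_a = state_dict.get("img_in.lora_A.weight")
    tensor_b = state_dict.get("img_in.lora_B.weight")
    return (
        isinstance(tensor_a, torch.Tensor)
        and tensor_a.shape[1] == 128
        and isinstance(tensor_b, torch.Tensor)
        and tensor_b.shape[0] == 3072
    )


# Example with mock tensors of the expected shapes.
mock_sd = {
    "img_in.lora_A.weight": torch.zeros(16, 128),
    "img_in.lora_B.weight": torch.zeros(3072, 16),
}
assert looks_like_flux_control_lora(mock_sd)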
@@ -624,8 +637,10 @@ class LoRACheckpointProbe(CheckpointProbeBase):
return ModelFormat.LyCORIS

def get_base_type(self) -> BaseModelType:
- if is_state_dict_likely_in_flux_kohya_format(self.checkpoint) or is_state_dict_likely_in_flux_diffusers_format(
- self.checkpoint
+ if (
+ is_state_dict_likely_in_flux_kohya_format(self.checkpoint)
+ or is_state_dict_likely_in_flux_diffusers_format(self.checkpoint)
+ or is_state_dict_likely_flux_control(self.checkpoint)
):
return BaseModelType.Flux

@@ -1046,6 +1061,7 @@ ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelI
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe)
+ ModelProbe.register_probe("checkpoint", ModelType.StructuralLoRa, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)

@@ -809,6 +809,7 @@
"starterBundleHelpText": "Easily install all models needed to get started with a base model, including a main model, controlnets, IP adapters, and more. Selecting a bundle will skip any models that you already have installed.",
"starterModels": "Starter Models",
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
+ "structuralLora": "Structural LoRA",
"syncModels": "Sync Models",
"textualInversions": "Textual Inversions",
"triggerPhrases": "Trigger Phrases",

@@ -24,6 +24,7 @@ import type {
ParameterSeed,
ParameterSteps,
ParameterStrength,
+ ParameterStructuralLoRAModel,
ParameterT5EncoderModel,
ParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
@@ -75,6 +76,7 @@ export type ParamsState = {
clipEmbedModel: ParameterCLIPEmbedModel | null;
clipLEmbedModel: ParameterCLIPLEmbedModel | null;
clipGEmbedModel: ParameterCLIPGEmbedModel | null;
+ structuralLora: ParameterStructuralLoRAModel | null;
};

const initialState: ParamsState = {
@@ -121,6 +123,7 @@ const initialState: ParamsState = {
clipEmbedModel: null,
clipLEmbedModel: null,
clipGEmbedModel: null,
+ structuralLora: null,
};

export const paramsSlice = createSlice({
@@ -195,6 +198,9 @@ export const paramsSlice = createSlice({
t5EncoderModelSelected: (state, action: PayloadAction<ParameterT5EncoderModel | null>) => {
state.t5EncoderModel = action.payload;
},
+ structuralLoRAModelSelected: (state, action: PayloadAction<ParameterStructuralLoRAModel | null>) => {
+ state.structuralLora = action.payload;
+ },
clipEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPEmbedModel | null>) => {
state.clipEmbedModel = action.payload;
},

@@ -46,6 +46,7 @@ import type {
ParameterSeed,
ParameterSteps,
ParameterStrength,
+ ParameterStructuralLoRAModel,
ParameterVAEModel,
ParameterWidth,
} from 'features/parameters/types/parameterSchemas';
@@ -80,6 +81,7 @@ import {
isLoRAModelConfig,
isNonRefinerMainModelConfig,
isRefinerMainModelModelConfig,
+ isStructuralLoRAModelConfig,
isT2IAdapterModelConfig,
isVAEModelConfig,
} from 'services/api/types';
@@ -226,6 +228,14 @@ const parseVAEModel: MetadataParseFunc<ParameterVAEModel> = async (metadata) =>
return modelIdentifier;
};

+ const parseStructuralLoRAModel: MetadataParseFunc<ParameterStructuralLoRAModel> = async (metadata) => {
+ const slora = await getProperty(metadata, 'structural_lora', undefined);
+ const key = await getModelKey(slora, 'structural_lora');
+ const sloraModelConfig = await fetchModelConfigWithTypeGuard(key, isStructuralLoRAModelConfig);
+ const modelIdentifier = zModelIdentifierField.parse(sloraModelConfig);
+ return modelIdentifier;
+ };
+
const parseLoRA: MetadataParseFunc<LoRA> = async (metadataItem) => {
// Previously, the LoRA model identifier parts were stored in the LoRA metadata: `{key: ..., weight: 0.75}`
const modelV1 = await getProperty(metadataItem, 'lora', undefined);
@@ -671,6 +681,7 @@ export const parsers = {
mainModel: parseMainModel,
refinerModel: parseRefinerModel,
vaeModel: parseVAEModel,
+ structuralLora: parseStructuralLoRAModel,
lora: parseLoRA,
loras: parseAllLoRAs,
controlNet: parseControlNet,

@@ -18,6 +18,7 @@ import {
useMainModels,
useRefinerModels,
useSpandrelImageToImageModels,
+ useStructuralLoRAModel,
useT2IAdapterModels,
useT5EncoderModels,
useVAEModels,
@@ -92,6 +93,12 @@ const ModelList = () => {
[t5EncoderModels, searchTerm, filteredModelType]
);

+ const [structuralLoRAModels, { isLoading: isLoadingStructuralLoRAModels }] = useStructuralLoRAModel();
+ const filteredStructuralLoRAModels = useMemo(
+ () => modelsFilter(structuralLoRAModels, searchTerm, filteredModelType),
+ [structuralLoRAModels, searchTerm, filteredModelType]
+ );
+
const [clipEmbedModels, { isLoading: isLoadingClipEmbedModels }] = useCLIPEmbedModels({ excludeSubmodels: true });
const filteredClipEmbedModels = useMemo(
() => modelsFilter(clipEmbedModels, searchTerm, filteredModelType),
@@ -118,7 +125,8 @@ const ModelList = () => {
filteredVAEModels.length +
filteredSpandrelImageToImageModels.length +
t5EncoderModels.length +
- clipEmbedModels.length
+ clipEmbedModels.length +
+ structuralLoRAModels.length
);
}, [
filteredControlNetModels.length,
@@ -133,6 +141,7 @@ const ModelList = () => {
filteredSpandrelImageToImageModels.length,
t5EncoderModels.length,
clipEmbedModels.length,
+ structuralLoRAModels.length,
]);

return (
@@ -195,6 +204,15 @@ const ModelList = () => {
{!isLoadingT5EncoderModels && filteredT5EncoderModels.length > 0 && (
<ModelListWrapper title={t('modelManager.t5Encoder')} modelList={filteredT5EncoderModels} key="t5-encoder" />
)}
+ {/* Structural Lora List */}
+ {isLoadingStructuralLoRAModels && <FetchingModelsLoader loadingMessage="Loading Structural Loras..." />}
+ {!isLoadingStructuralLoRAModels && filteredStructuralLoRAModels.length > 0 && (
+ <ModelListWrapper
+ title={t('modelManager.structuralLora')}
+ modelList={filteredStructuralLoRAModels}
+ key="structural-lora"
+ />
+ )}
{/* Clip Embed List */}
{isLoadingClipEmbedModels && <FetchingModelsLoader loadingMessage="Loading Clip Embed Models..." />}
{!isLoadingClipEmbedModels && filteredClipEmbedModels.length > 0 && (

@@ -24,6 +24,7 @@ export const ModelTypeFilter = memo(() => {
ip_adapter: t('common.ipAdapter'),
clip_vision: 'CLIP Vision',
spandrel_image_to_image: t('modelManager.spandrelImageToImage'),
+ structural_lora: t('modelManager.structuralLora'),
}),
[t]
);

@@ -51,6 +51,8 @@ import {
isSpandrelImageToImageModelFieldInputTemplate,
isStringFieldInputInstance,
isStringFieldInputTemplate,
+ isStructuralLoRAModelFieldInputInstance,
+ isStructuralLoRAModelFieldInputTemplate,
isT2IAdapterModelFieldInputInstance,
isT2IAdapterModelFieldInputTemplate,
isT5EncoderModelFieldInputInstance,
@@ -81,6 +83,7 @@ import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComp
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
+ import StructuralLoRAModelFieldInputComponent from './inputs/StructuralLoraModelFieldInputComponent';
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
import T5EncoderModelFieldInputComponent from './inputs/T5EncoderModelFieldInputComponent';
import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
@@ -156,6 +159,15 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
return <CLIPGEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

+ if (
+ isStructuralLoRAModelFieldInputInstance(fieldInstance) &&
+ isStructuralLoRAModelFieldInputTemplate(fieldTemplate)
+ ) {
+ return (
+ <StructuralLoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />
+ );
+ }
+
if (isFluxVAEModelFieldInputInstance(fieldInstance) && isFluxVAEModelFieldInputTemplate(fieldTemplate)) {
return <FluxVAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

@@ -0,0 +1,65 @@
+ import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
+ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
+ import { fieldStructuralLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
+ import type {
+ StructuralLoRAModelFieldInputInstance,
+ StructuralLoRAModelFieldInputTemplate,
+ } from 'features/nodes/types/field';
+ import { memo, useCallback } from 'react';
+ import { useTranslation } from 'react-i18next';
+ import { useStructuralLoRAModel } from 'services/api/hooks/modelsByType';
+ import { isStructuralLoRAModelConfig, type StructuralLoRAModelConfig } from 'services/api/types';
+
+ import type { FieldComponentProps } from './types';
+
+ type Props = FieldComponentProps<StructuralLoRAModelFieldInputInstance, StructuralLoRAModelFieldInputTemplate>;
+
+ const StructuralLoRAModelFieldInputComponent = (props: Props) => {
+ const { nodeId, field } = props;
+ const { t } = useTranslation();
+ const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
+ const dispatch = useAppDispatch();
+ const [modelConfigs, { isLoading }] = useStructuralLoRAModel();
+
+ const _onChange = useCallback(
+ (value: StructuralLoRAModelConfig | null) => {
+ if (!value) {
+ return;
+ }
+ dispatch(
+ fieldStructuralLoRAModelValueChanged({
+ nodeId,
+ fieldName: field.name,
+ value,
+ })
+ );
+ },
+ [dispatch, field.name, nodeId]
+ );
+ const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
+ modelConfigs: modelConfigs.filter((config) => isStructuralLoRAModelConfig(config)),
+ onChange: _onChange,
+ isLoading,
+ selectedModel: field.value,
+ });
+ const required = props.fieldTemplate.required;
+
+ return (
+ <Flex w="full" alignItems="center" gap={2}>
+ <Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
+ <FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value && required}>
+ <Combobox
+ value={value}
+ placeholder={required ? placeholder : `(Optional) ${placeholder}`}
+ options={options}
+ onChange={onChange}
+ noOptionsMessage={noOptionsMessage}
+ />
+ </FormControl>
+ </Tooltip>
+ </Flex>
+ );
+ };
+
+ export default memo(StructuralLoRAModelFieldInputComponent);

@@ -28,6 +28,7 @@ import type {
SpandrelImageToImageModelFieldValue,
StatefulFieldValue,
StringFieldValue,
+ StructuralLoRAModelFieldValue,
T2IAdapterModelFieldValue,
T5EncoderModelFieldValue,
VAEModelFieldValue,
@@ -55,6 +56,7 @@ import {
zSpandrelImageToImageModelFieldValue,
zStatefulFieldValue,
zStringFieldValue,
+ zStructuralLoRAModelFieldValue,
zT2IAdapterModelFieldValue,
zT5EncoderModelFieldValue,
zVAEModelFieldValue,
@@ -369,6 +371,9 @@ export const nodesSlice = createSlice({
fieldCLIPGEmbedValueChanged: (state, action: FieldValueAction<CLIPGEmbedModelFieldValue>) => {
fieldValueReducer(state, action, zCLIPGEmbedModelFieldValue);
},
+ fieldStructuralLoRAModelValueChanged: (state, action: FieldValueAction<StructuralLoRAModelFieldValue>) => {
+ fieldValueReducer(state, action, zStructuralLoRAModelFieldValue);
+ },
fieldFluxVAEModelValueChanged: (state, action: FieldValueAction<FluxVAEModelFieldValue>) => {
fieldValueReducer(state, action, zFluxVAEModelFieldValue);
},
@@ -438,6 +443,7 @@ export const {
fieldCLIPEmbedValueChanged,
fieldCLIPLEmbedValueChanged,
fieldCLIPGEmbedValueChanged,
+ fieldStructuralLoRAModelValueChanged,
fieldFluxVAEModelValueChanged,
nodeEditorReset,
nodeIsIntermediateChanged,

@@ -69,6 +69,7 @@ const zModelType = z.enum([
'main',
'vae',
'lora',
+ 'structural_lora',
'controlnet',
't2i_adapter',
'ip_adapter',

@@ -178,6 +178,10 @@ const zCLIPGEmbedModelFieldType = zFieldTypeBase.extend({
name: z.literal('CLIPGEmbedModelField'),
originalType: zStatelessFieldType.optional(),
});
+ const zStructuralLoRAModelFieldType = zFieldTypeBase.extend({
+ name: z.literal('StructuralLoRAModelField'),
+ originalType: zStatelessFieldType.optional(),
+ });
const zFluxVAEModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxVAEModelField'),
originalType: zStatelessFieldType.optional(),
@@ -210,6 +214,7 @@ const zStatefulFieldType = z.union([
zCLIPEmbedModelFieldType,
zCLIPLEmbedModelFieldType,
zCLIPGEmbedModelFieldType,
+ zStructuralLoRAModelFieldType,
zFluxVAEModelFieldType,
zColorFieldType,
zSchedulerFieldType,
@@ -864,6 +869,29 @@ export const isCLIPGEmbedModelFieldInputTemplate = (val: unknown): val is CLIPGE

// #endregion

+ // #region StructuralLoRAModelField
+
+ export const zStructuralLoRAModelFieldValue = zModelIdentifierField.optional();
+ const zStructuralLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({
+ value: zStructuralLoRAModelFieldValue,
+ });
+ const zStructuralLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
+ type: zStructuralLoRAModelFieldType,
+ originalType: zFieldType.optional(),
+ default: zStructuralLoRAModelFieldValue,
+ });
+
+ export type StructuralLoRAModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldValue>;
+
+ export type StructuralLoRAModelFieldInputInstance = z.infer<typeof zStructuralLoRAModelFieldInputInstance>;
+ export type StructuralLoRAModelFieldInputTemplate = z.infer<typeof zStructuralLoRAModelFieldInputTemplate>;
+ export const isStructuralLoRAModelFieldInputInstance = (val: unknown): val is StructuralLoRAModelFieldInputInstance =>
+ zStructuralLoRAModelFieldInputInstance.safeParse(val).success;
+ export const isStructuralLoRAModelFieldInputTemplate = (val: unknown): val is StructuralLoRAModelFieldInputTemplate =>
+ zStructuralLoRAModelFieldInputTemplate.safeParse(val).success;
+
+ // #endregion
+
// #region SchedulerField

export const zSchedulerFieldValue = zSchedulerField.optional();
@@ -959,6 +987,7 @@ export const zStatefulFieldValue = z.union([
zCLIPEmbedModelFieldValue,
zCLIPLEmbedModelFieldValue,
zCLIPGEmbedModelFieldValue,
+ zStructuralLoRAModelFieldValue,
zColorFieldValue,
zSchedulerFieldValue,
]);
@@ -1030,6 +1059,7 @@ const zStatefulFieldInputTemplate = z.union([
zCLIPEmbedModelFieldInputTemplate,
zCLIPLEmbedModelFieldInputTemplate,
zCLIPGEmbedModelFieldInputTemplate,
+ zStructuralLoRAModelFieldInputTemplate,
zColorFieldInputTemplate,
zSchedulerFieldInputTemplate,
zStatelessFieldInputTemplate,

@@ -28,6 +28,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
CLIPEmbedModelField: undefined,
CLIPLEmbedModelField: undefined,
CLIPGEmbedModelField: undefined,
+ StructuralLoRAModelField: undefined,
};

export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => {

@@ -28,6 +28,7 @@ import type {
StatefulFieldType,
StatelessFieldInputTemplate,
StringFieldInputTemplate,
+ StructuralLoRAModelFieldInputTemplate,
T2IAdapterModelFieldInputTemplate,
T5EncoderModelFieldInputTemplate,
VAEModelFieldInputTemplate,
@@ -300,6 +301,20 @@ const buildCLIPGEmbedModelFieldInputTemplate: FieldInputTemplateBuilder<CLIPGEmb
return template;
};

+ const buildStructuralLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<StructuralLoRAModelFieldInputTemplate> = ({
+ schemaObject,
+ baseField,
+ fieldType,
+ }) => {
+ const template: StructuralLoRAModelFieldInputTemplate = {
+ ...baseField,
+ type: fieldType,
+ default: schemaObject.default ?? undefined,
+ };
+
+ return template;
+ };
+
const buildFluxVAEModelFieldInputTemplate: FieldInputTemplateBuilder<FluxVAEModelFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -526,6 +541,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
CLIPLEmbedModelField: buildCLIPLEmbedModelFieldInputTemplate,
CLIPGEmbedModelField: buildCLIPGEmbedModelFieldInputTemplate,
FluxVAEModelField: buildFluxVAEModelFieldInputTemplate,
+ StructuralLoRAModelField: buildStructuralLoRAModelFieldInputTemplate,
} as const;

export const buildFieldInputTemplate = (

@@ -113,6 +113,11 @@ export const zParameterVAEModel = zModelIdentifierField;
export type ParameterVAEModel = z.infer<typeof zParameterVAEModel>;
// #endregion

+ // #region Structural Lora Model
+ export const zParameterStructuralLoRAModel = zModelIdentifierField;
+ export type ParameterStructuralLoRAModel = z.infer<typeof zParameterStructuralLoRAModel>;
+ // #endregion
+
// #region T5Encoder Model
export const zParameterT5EncoderModel = zModelIdentifierField;
export type ParameterT5EncoderModel = z.infer<typeof zParameterT5EncoderModel>;

@@ -23,6 +23,7 @@ import {
isSD3MainModelModelConfig,
isSDXLMainModelModelConfig,
isSpandrelImageToImageModelConfig,
+ isStructuralLoRAModelConfig,
isT2IAdapterModelConfig,
isT5EncoderModelConfig,
isTIModelConfig,
@@ -58,6 +59,7 @@ export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);
export const useSD3Models = buildModelsHook(isSD3MainModelModelConfig);
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
+ export const useStructuralLoRAModel = buildModelsHook(isStructuralLoRAModelConfig);
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig);
File diff suppressed because one or more lines are too long
@@ -44,6 +44,7 @@ export type BaseModelType = S['BaseModelType'];

// Model Configs

+ export type StructuralLoRAModelConfig = S['StructuralLoRALyCORISConfig'];
// TODO(MM2): Can we make key required in the pydantic model?
export type LoRAModelConfig = S['LoRADiffusersConfig'] | S['LoRALyCORISConfig'];
// TODO(MM2): Can we rename this from Vae -> VAE
@@ -63,6 +64,7 @@ export type CheckpointModelConfig = S['MainCheckpointConfig'];
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
export type AnyModelConfig =
+ | StructuralLoRAModelConfig
| LoRAModelConfig
| VAEModelConfig
| ControlNetModelConfig
@@ -114,6 +116,10 @@ export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelCo
return config.type === 'lora';
};

+ export const isStructuralLoRAModelConfig = (config: AnyModelConfig): config is StructuralLoRAModelConfig => {
+ return config.type === 'structural_lora';
+ };
+
export const isVAEModelConfig = (config: AnyModelConfig, excludeSubmodels?: boolean): config is VAEModelConfig => {
return config.type === 'vae' || (!excludeSubmodels && config.type === 'main' && checkSubmodels(['vae'], config));
};
File diff suppressed because it is too large
@@ -0,0 +1,70 @@ (new file)
import pytest
import torch

from invokeai.backend.lora.conversions.flux_control_lora_utils import (
    is_state_dict_likely_flux_control,
    lora_model_from_flux_control_state_dict,
)
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from tests.backend.lora.conversions.lora_state_dicts.flux_control_lora_format import (
    state_dict_keys as flux_control_lora_state_dict_keys,
)
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
    state_dict_keys as flux_diffusers_state_dict_keys,
)
from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_state_dict


@pytest.mark.parametrize("sd_keys", [flux_control_lora_state_dict_keys])
def test_is_state_dict_likely_in_flux_control_format_true(sd_keys: dict[str, list[int]]):
    """Test that is_state_dict_likely_flux_control() can identify a state dict in the FLUX Control LoRA format."""
    # Construct a state dict that is in the FLUX Control LoRA format.
    state_dict = keys_to_mock_state_dict(sd_keys)

    assert is_state_dict_likely_flux_control(state_dict)


@pytest.mark.parametrize("sd_keys", [flux_diffusers_state_dict_keys])
def test_is_state_dict_likely_in_flux_control_format_false(sd_keys: dict[str, list[int]]):
    """Test that is_state_dict_likely_flux_control() returns False for a state dict that is in the Diffusers
    FLUX LoRA format.
    """
    # Construct a state dict that is not in the FLUX Control LoRA format.
    state_dict = keys_to_mock_state_dict(sd_keys)

    assert not is_state_dict_likely_flux_control(state_dict)


@pytest.mark.parametrize("sd_keys", [flux_control_lora_state_dict_keys])
def test_lora_model_from_flux_control_state_dict(sd_keys: dict[str, list[int]]):
    """Test that lora_model_from_flux_control_state_dict() can load a state dict in the FLUX Control LoRA format."""
    # Construct a state dict that is in the FLUX Control LoRA format.
    state_dict = keys_to_mock_state_dict(sd_keys)
    # Load the state dict into a LoRAModelRaw object.
    model = lora_model_from_flux_control_state_dict(state_dict)

    # Check that the model has the correct number of LoRA layers.
    expected_lora_layers: set[str] = set()
    for k in sd_keys:
        k = k.replace("lora_A.weight", "")
        k = k.replace("lora_B.weight", "")
        k = k.replace("lora_B.bias", "")
        k = k.replace(".scale", "")
        expected_lora_layers.add(k)
    # Drop the K/V/proj_mlp weights because these are all concatenated into a single layer in the BFL format (we keep
    # the Q weights so that we count these layers once).
    assert len(model.layers) == len(expected_lora_layers)
    assert all(k.startswith(FLUX_LORA_TRANSFORMER_PREFIX) for k in model.layers.keys())


def test_lora_model_from_flux_control_state_dict_extra_keys_error():
    """Test that lora_model_from_flux_control_state_dict() raises an error if the input state_dict contains unexpected
    keys that we don't handle.
    """
    # Construct a state dict that is in the FLUX Control LoRA format.
    state_dict = keys_to_mock_state_dict(flux_control_lora_state_dict_keys)
    # Add an unexpected key.
    state_dict["transformer.single_transformer_blocks.0.unexpected_key.lora_A.weight"] = torch.empty(1)

    # Check that an error is raised.
    with pytest.raises(AssertionError):
        lora_model_from_flux_control_state_dict(state_dict)
@@ -1,49 +0,0 @@ (deleted file)
import copy

import torch

from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import (
    ConcatenatedLoRALinearSidecarLayer,
)
from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule


def test_concatenated_lora_linear_sidecar_layer():
    """Test that a ConcatenatedLoRALinearSidecarLayer is equivalent to patching a linear layer with the ConcatenatedLoRA
    layer.
    """

    # Create a linear layer.
    in_features = 5
    sub_layer_out_features = [5, 10, 15]
    linear = torch.nn.Linear(in_features, sum(sub_layer_out_features))

    # Create a ConcatenatedLoRA layer.
    rank = 4
    sub_layers: list[LoRALayer] = []
    for out_features in sub_layer_out_features:
        down = torch.randn(rank, in_features)
        up = torch.randn(out_features, rank)
        bias = torch.randn(out_features)
        sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias))
    concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)

    # Patch the ConcatenatedLoRA layer into the linear layer.
    linear_patched = copy.deepcopy(linear)
    linear_patched.weight.data += (
        concatenated_lora_layer.get_weight(linear_patched.weight) * concatenated_lora_layer.scale()
    )
    linear_patched.bias.data += concatenated_lora_layer.get_bias(linear_patched.bias) * concatenated_lora_layer.scale()

    # Create a ConcatenatedLoRALinearSidecarLayer.
    concatenated_lora_linear_sidecar_layer = ConcatenatedLoRALinearSidecarLayer(concatenated_lora_layer, weight=1.0)
    linear_with_sidecar = LoRASidecarModule(linear, [concatenated_lora_linear_sidecar_layer])

    # Run the ConcatenatedLoRA-patched linear layer and the ConcatenatedLoRALinearSidecarLayer and assert they are
    # equal.
    input = torch.randn(1, in_features)
    output_patched = linear_patched(input)
    output_sidecar = linear_with_sidecar(input)
    assert torch.allclose(output_patched, output_sidecar, atol=1e-6)
@@ -1,38 +0,0 @@ (deleted file)
import copy

import torch

from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.sidecar_layers.lora.lora_linear_sidecar_layer import LoRALinearSidecarLayer
from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule


@torch.no_grad()
def test_lora_linear_sidecar_layer():
    """Test that a LoRALinearSidecarLayer is equivalent to patching a linear layer with the LoRA layer."""

    # Create a linear layer.
    in_features = 10
    out_features = 20
    linear = torch.nn.Linear(in_features, out_features)

    # Create a LoRA layer.
    rank = 4
    down = torch.randn(rank, in_features)
    up = torch.randn(out_features, rank)
    bias = torch.randn(out_features)
    lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)

    # Patch the LoRA layer into the linear layer.
    linear_patched = copy.deepcopy(linear)
    linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
    linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()

    # Create a LoRALinearSidecarLayer.
    lora_linear_sidecar_layer = LoRALinearSidecarLayer(lora_layer, weight=1.0)
    linear_with_sidecar = LoRASidecarModule(linear, [lora_linear_sidecar_layer])

    # Run the LoRA-patched linear layer and the LoRALinearSidecarLayer and assert they are equal.
    input = torch.randn(1, in_features)
    output_patched = linear_patched(input)
    output_sidecar = linear_with_sidecar(input)
    assert torch.allclose(output_patched, output_sidecar, atol=1e-6)
69  tests/backend/lora/test_lora_layer_wrappers.py  Normal file
@@ -0,0 +1,69 @@ (new file)
import copy

import torch

from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.lora_layer_wrappers import LoRALinearWrapper


@torch.no_grad()
def test_lora_linear_wrapper():
    # Create a linear layer.
    in_features = 10
    out_features = 20
    linear = torch.nn.Linear(in_features, out_features)

    # Create a LoRA layer.
    rank = 4
    down = torch.randn(rank, in_features)
    up = torch.randn(out_features, rank)
    bias = torch.randn(out_features)
    lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)

    # Patch the LoRA layer into the linear layer.
    linear_patched = copy.deepcopy(linear)
    linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
    linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()

    # Create a LoRALinearWrapper.
    lora_wrapped = LoRALinearWrapper(linear, [lora_layer], [1.0])

    # Run the LoRA-patched linear layer and the LoRALinearWrapper and assert they are equal.
    input = torch.randn(1, in_features)
    output_patched = linear_patched(input)
    output_wrapped = lora_wrapped(input)
    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)


def test_concatenated_lora_linear_wrapper():
    # Create a linear layer.
    in_features = 5
    sub_layer_out_features = [5, 10, 15]
    linear = torch.nn.Linear(in_features, sum(sub_layer_out_features))

    # Create a ConcatenatedLoRA layer.
    rank = 4
    sub_layers: list[LoRALayer] = []
    for out_features in sub_layer_out_features:
        down = torch.randn(rank, in_features)
        up = torch.randn(out_features, rank)
        bias = torch.randn(out_features)
        sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias))
    concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)

    # Patch the ConcatenatedLoRA layer into the linear layer.
    linear_patched = copy.deepcopy(linear)
    linear_patched.weight.data += (
        concatenated_lora_layer.get_weight(linear_patched.weight) * concatenated_lora_layer.scale()
    )
    linear_patched.bias.data += concatenated_lora_layer.get_bias(linear_patched.bias) * concatenated_lora_layer.scale()

    # Create a LoRALinearWrapper.
    lora_wrapped = LoRALinearWrapper(linear, [concatenated_lora_layer], [1.0])

    # Run the ConcatenatedLoRA-patched linear layer and the LoRALinearWrapper and assert they are equal.
    input = torch.randn(1, in_features)
    output_patched = linear_patched(input)
    output_wrapped = lora_wrapped(input)
    assert torch.allclose(output_patched, output_wrapped, atol=1e-6)
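The wrapper tests above (and the deleted sidecar tests they replace) all assert the same underlying identity: folding the low-rank update into the base weight gives the same result as adding it as a separate term on the forward pass. The following is a minimal plain-PyTorch sketch of that identity; it deliberately avoids the InvokeAI wrapper classes, and the scale value is an arbitrary illustration rather than InvokeAI's internal scaling rule.

import torch

# Equivalence of "patch the weight" vs. "add a sidecar term at runtime".
torch.manual_seed(0)
in_features, out_features, rank, scale = 10, 20, 4, 0.5

linear = torch.nn.Linear(in_features, out_features)
down = torch.randn(rank, in_features)
up = torch.randn(out_features, rank)
x = torch.randn(1, in_features)

# Option 1: materialize the low-rank update into the weight (what direct patching does).
patched = torch.nn.Linear(in_features, out_features)
patched.load_state_dict(linear.state_dict())
with torch.no_grad():
    patched.weight += (up @ down) * scale
out_patched = patched(x)

# Option 2: leave the base weight untouched and add the low-rank term on the forward
# pass (what the sidecar/wrapper layers do).
out_sidecar = linear(x) + (x @ down.T @ up.T) * scale

assert torch.allclose(out_patched, out_sidecar, atol=1e-6)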
@@ -2,11 +2,15 @@ import pytest
 import torch
 
 from invokeai.backend.lora.layers.lora_layer import LoRALayer
+from invokeai.backend.lora.lora_layer_wrappers import LoRASidecarWrapper
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.lora.lora_patcher import LoRAPatcher
+from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
+    CachedModelWithPartialLoad,
+)
 
 
-class DummyModule(torch.nn.Module):
+class DummyModuleWithOneLayer(torch.nn.Module):
     def __init__(self, in_features: int, out_features: int, device: str, dtype: torch.dtype):
         super().__init__()
         self.linear_layer_1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
@@ -15,8 +19,18 @@ class DummyModule(torch.nn.Module):
         return self.linear_layer_1(x)
 
 
+class DummyModuleWithTwoLayers(torch.nn.Module):
+    def __init__(self, in_features: int, out_features: int, device: str, dtype: torch.dtype):
+        super().__init__()
+        self.linear_layer_1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
+        self.linear_layer_2 = torch.nn.Linear(out_features, out_features, device=device, dtype=dtype)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.linear_layer_2(self.linear_layer_1(x))
+
+
 @pytest.mark.parametrize(
-    ["device", "num_layers"],
+    ["device", "num_loras"],
     [
         ("cpu", 1),
         pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
@@ -25,7 +39,7 @@ class DummyModule(torch.nn.Module):
     ],
 )
 @torch.no_grad()
-def test_apply_lora_patches(device: str, num_layers: int):
+def test_apply_lora_patches(device: str, num_loras: int):
     """Test the basic behavior of ModelPatcher.apply_lora_patches(...). Check that patching and unpatching produce the
     correct result, and that model/LoRA tensors are moved between devices as expected.
     """
@@ -33,12 +47,12 @@ def test_apply_lora_patches(device: str, num_layers: int):
     linear_in_features = 4
     linear_out_features = 8
     lora_rank = 2
-    model = DummyModule(linear_in_features, linear_out_features, device=device, dtype=torch.float16)
+    model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=torch.float16)
 
-    # Initialize num_layers LoRA models with weights of 0.5.
+    # Initialize num_loras LoRA models with weights of 0.5.
     lora_weight = 0.5
     lora_models: list[tuple[LoRAModelRaw, float]] = []
-    for _ in range(num_layers):
+    for _ in range(num_loras):
         lora_layers = {
             "linear_layer_1": LoRALayer.from_state_dict_values(
                 values={
@@ -51,7 +65,7 @@ def test_apply_lora_patches(device: str, num_layers: int):
         lora_models.append((lora, lora_weight))
 
     orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
-    expected_patched_linear_weight = orig_linear_weight + (lora_rank * lora_weight * num_layers)
+    expected_patched_linear_weight = orig_linear_weight + (lora_rank * lora_weight * num_loras)
 
     with LoRAPatcher.apply_lora_patches(model=model, patches=lora_models, prefix=""):
         # After patching, all LoRA layer weights should have been moved back to the cpu.
@@ -79,7 +93,7 @@ def test_apply_lora_patches_change_device():
     linear_out_features = 8
     lora_dim = 2
     # Initialize the model on the CPU.
-    model = DummyModule(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)
+    model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)
 
     lora_layers = {
         "linear_layer_1": LoRALayer.from_state_dict_values(
@@ -110,7 +124,7 @@ def test_apply_lora_patches_change_device():
 
 
 @pytest.mark.parametrize(
-    ["device", "num_layers"],
+    ["device", "num_loras"],
     [
         ("cpu", 1),
         pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
@@ -118,18 +132,18 @@ def test_apply_lora_patches_change_device():
         pytest.param("cuda", 2, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
     ],
 )
-def test_apply_lora_sidecar_patches(device: str, num_layers: int):
-    """Test the basic behavior of ModelPatcher.apply_lora_sidecar_patches(...). Check that unpatching works correctly."""
+def test_apply_lora_wrapper_patches(device: str, num_loras: int):
+    """Test the basic behavior of ModelPatcher.apply_lora_wrapper_patches(...). Check that unpatching works correctly."""
     dtype = torch.float16
     linear_in_features = 4
     linear_out_features = 8
     lora_rank = 2
-    model = DummyModule(linear_in_features, linear_out_features, device=device, dtype=dtype)
+    model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=dtype)
 
-    # Initialize num_layers LoRA models with weights of 0.5.
+    # Initialize num_loras LoRA models with weights of 0.5.
     lora_weight = 0.5
     lora_models: list[tuple[LoRAModelRaw, float]] = []
-    for _ in range(num_layers):
+    for _ in range(num_loras):
         lora_layers = {
             "linear_layer_1": LoRALayer.from_state_dict_values(
                 values={
@@ -146,7 +160,7 @@ def test_apply_lora_sidecar_patches(device: str, num_layers: int):
     output_before_patch = model(input)
 
     # Patch the model and run inference during the patch.
-    with LoRAPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
+    with LoRAPatcher.apply_lora_wrapper_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
         output_during_patch = model(input)
 
     # Run inference after unpatching.
@@ -159,20 +173,140 @@ def test_apply_lora_sidecar_patches(device: str, num_layers: int):
     assert torch.allclose(output_before_patch, output_after_patch)
 
 
+@pytest.mark.parametrize(
+    ["device", "num_loras"],
+    [
+        ("cpu", 1),
+        pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
+        ("cpu", 2),
+        pytest.param("cuda", 2, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
+    ],
+)
 @torch.no_grad()
-@pytest.mark.parametrize(["num_layers"], [(1,), (2,)])
-def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int):
-    """Test that apply_lora_sidecar_patches(...) produces the same model outputs as apply_lora_patches(...)."""
+def test_apply_smart_lora_patches(device: str, num_loras: int):
+    """Test the basic behavior of ModelPatcher.apply_smart_lora_patches(...). Check that unpatching works correctly."""
+    dtype = torch.float16
+    linear_in_features = 4
+    linear_out_features = 8
+    lora_rank = 2
+    model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=dtype)
+
+    # Initialize num_loras LoRA models with weights of 0.5.
+    lora_weight = 0.5
+    lora_models: list[tuple[LoRAModelRaw, float]] = []
+    for _ in range(num_loras):
+        lora_layers = {
+            "linear_layer_1": LoRALayer.from_state_dict_values(
+                values={
+                    "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
+                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
+                },
+            )
+        }
+        lora = LoRAModelRaw(lora_layers)
+        lora_models.append((lora, lora_weight))
+
+    # Run inference before patching the model.
+    input = torch.randn(1, linear_in_features, device=device, dtype=dtype)
+    output_before_patch = model(input)
+
+    # Patch the model and run inference during the patch.
+    with LoRAPatcher.apply_smart_lora_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
+        output_during_patch = model(input)
+
+    # Run inference after unpatching.
+    output_after_patch = model(input)
+
+    # Check that the output before patching is different from the output during patching.
+    assert not torch.allclose(output_before_patch, output_during_patch)
+
+    # Check that the output before patching is the same as the output after patching.
+    assert torch.allclose(output_before_patch, output_after_patch)
+
+
+@pytest.mark.parametrize(["num_loras"], [(1,), (2,)])
+@torch.no_grad()
+def test_apply_smart_lora_patches_to_partially_loaded_model(num_loras: int):
+    """Test the behavior of ModelPatcher.apply_smart_lora_patches(...) when it is applied to a
+    CachedModelWithPartialLoad that is partially loaded into VRAM.
+    """
+
+    if not torch.cuda.is_available():
+        pytest.skip("requires CUDA device")
+
+    # Initialize the model on the CPU.
+    dtype = torch.float16
+    linear_in_features = 4
+    linear_out_features = 8
+    lora_rank = 2
+    model = DummyModuleWithTwoLayers(linear_in_features, linear_out_features, device="cpu", dtype=dtype)
+    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device("cuda"))
+    model_total_bytes = cached_model.total_bytes()
+    assert cached_model.cur_vram_bytes() == 0
+
+    # Partially load the model into VRAM.
+    target_vram_bytes = int(model_total_bytes * 0.6)
+    _ = cached_model.partial_load_to_vram(target_vram_bytes)
+    assert cached_model.model.linear_layer_1.weight.device.type == "cuda"
+    assert cached_model.model.linear_layer_2.weight.device.type == "cpu"
+
+    # Initialize num_loras LoRA models with weights of 0.5.
+    lora_weight = 0.5
+    lora_models: list[tuple[LoRAModelRaw, float]] = []
+    for _ in range(num_loras):
+        lora_layers = {
+            "linear_layer_1": LoRALayer.from_state_dict_values(
+                values={
+                    "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
+                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
+                },
+            ),
+            "linear_layer_2": LoRALayer.from_state_dict_values(
+                values={
+                    "lora_down.weight": torch.ones((lora_rank, linear_out_features), device="cpu", dtype=torch.float16),
+                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
+                },
+            ),
+        }
+        lora = LoRAModelRaw(lora_layers)
+        lora_models.append((lora, lora_weight))
+
+    # Run inference before patching the model.
+    input = torch.randn(1, linear_in_features, device="cuda", dtype=dtype)
+    output_before_patch = cached_model.model(input)
+
+    # Patch the model and run inference during the patch.
+    with LoRAPatcher.apply_smart_lora_patches(model=cached_model.model, patches=lora_models, prefix="", dtype=dtype):
+        # Check that the second layer is wrapped in a LoRASidecarWrapper, but the first layer is not.
+        assert not isinstance(cached_model.model.linear_layer_1, LoRASidecarWrapper)
+        assert isinstance(cached_model.model.linear_layer_2, LoRASidecarWrapper)
+
+        output_during_patch = cached_model.model(input)
+
+    # Run inference after unpatching.
+    output_after_patch = cached_model.model(input)
+
+    # Check that the output before patching is different from the output during patching.
+    assert not torch.allclose(output_before_patch, output_during_patch)
+
+    # Check that the output before patching is the same as the output after patching.
+    assert torch.allclose(output_before_patch, output_after_patch)
+
+
+@torch.no_grad()
+@pytest.mark.parametrize(["num_loras"], [(1,), (2,)])
+def test_all_patching_methods_produce_same_output(num_loras: int):
+    """Test that apply_lora_wrapper_patches(...) produces the same model outputs as apply_lora_patches(...)."""
     dtype = torch.float32
     linear_in_features = 4
     linear_out_features = 8
     lora_rank = 2
-    model = DummyModule(linear_in_features, linear_out_features, device="cpu", dtype=dtype)
+    model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=dtype)
 
-    # Initialize num_layers LoRA models with weights of 0.5.
+    # Initialize num_loras LoRA models with weights of 0.5.
     lora_weight = 0.5
     lora_models: list[tuple[LoRAModelRaw, float]] = []
-    for _ in range(num_layers):
+    for _ in range(num_loras):
         lora_layers = {
             "linear_layer_1": LoRALayer.from_state_dict_values(
                 values={
@@ -189,9 +323,13 @@ def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int):
     with LoRAPatcher.apply_lora_patches(model=model, patches=lora_models, prefix=""):
         output_lora_patches = model(input)
 
-    with LoRAPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
-        output_lora_sidecar_patches = model(input)
+    with LoRAPatcher.apply_lora_wrapper_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
+        output_lora_wrapper_patches = model(input)
+
+    with LoRAPatcher.apply_smart_lora_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
+        output_smart_lora_patches = model(input)
 
     # Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical
     # differences are tolerable and expected due to the difference between sidecar vs. patching.
-    assert torch.allclose(output_lora_patches, output_lora_sidecar_patches, atol=1e-5)
+    assert torch.allclose(output_lora_patches, output_lora_wrapper_patches, atol=1e-5)
+    assert torch.allclose(output_lora_patches, output_smart_lora_patches, atol=1e-5)
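Read together, these tests suggest the intended contract of apply_smart_lora_patches(...): layers resident on the compute device are patched in place (they are not wrapped), layers still offloaded to the CPU are wrapped in a LoRASidecarWrapper for the duration of the context, and the model returns to its pre-patch behavior on exit. The sketch below condenses the partially-loaded test into a usage example; the layer sizes and the 0.6 VRAM fraction are simply the values the test uses, and a CUDA device is assumed.

import torch

from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.lora_layer_wrappers import LoRASidecarWrapper
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
    CachedModelWithPartialLoad,
)

# Mirror the @torch.no_grad() used by the tests above.
torch.set_grad_enabled(False)


class TwoLayerModel(torch.nn.Module):
    """Stand-in for the DummyModuleWithTwoLayers test helper above."""

    def __init__(self) -> None:
        super().__init__()
        self.linear_layer_1 = torch.nn.Linear(4, 8, device="cpu", dtype=torch.float16)
        self.linear_layer_2 = torch.nn.Linear(8, 8, device="cpu", dtype=torch.float16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear_layer_2(self.linear_layer_1(x))


model = TwoLayerModel()
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device("cuda"))
# Load roughly 60% of the model into VRAM so that linear_layer_2 stays on the CPU.
cached_model.partial_load_to_vram(int(cached_model.total_bytes() * 0.6))

lora = LoRAModelRaw(
    {
        "linear_layer_1": LoRALayer.from_state_dict_values(
            values={
                "lora_down.weight": torch.ones((2, 4), device="cpu", dtype=torch.float16),
                "lora_up.weight": torch.ones((8, 2), device="cpu", dtype=torch.float16),
            },
        ),
    }
)

x = torch.randn(1, 4, device="cuda", dtype=torch.float16)
with LoRAPatcher.apply_smart_lora_patches(
    model=cached_model.model, patches=[(lora, 0.5)], prefix="", dtype=torch.float16
):
    # On-device layers are patched in place; offloaded layers get a sidecar wrapper.
    assert not isinstance(cached_model.model.linear_layer_1, LoRASidecarWrapper)
    assert isinstance(cached_model.model.linear_layer_2, LoRASidecarWrapper)
    y = cached_model.model(x)
# On exit, weights and module structure are restored to their pre-patch state.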