mirror of https://github.com/invoke-ai/InvokeAI.git
WIP - Implement sidecar LoRA layers using functional API.
@@ -1,3 +1,4 @@
from contextlib import ExitStack
from typing import Callable, Iterator, Optional, Tuple

import torch
@@ -31,6 +32,7 @@ from invokeai.backend.flux.sampling_utils import (
)
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@@ -191,21 +193,38 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):

        with (
            transformer_info.model_on_device() as (cached_weights, transformer),
            # Apply the LoRA after the transformer has been moved to its target device for faster patching.
            # LoRAPatcher.apply_lora_sidecar_patches(
            #     model=transformer,
            #     patches=self._lora_iterator(context),
            #     prefix="",
            # ),
            LoRAPatcher.apply_lora_patches(
                model=transformer,
                patches=self._lora_iterator(context),
                prefix="",
                cached_weights=cached_weights,
            ),
            ExitStack() as exit_stack,
        ):
            assert isinstance(transformer, Flux)

            config = transformer_info.config
            assert config is not None

            # Apply LoRA models to the transformer.
            # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
            if config.format in [ModelFormat.Checkpoint]:
                # The model is non-quantized, so we can apply the LoRA weights directly into the model.
                exit_stack.enter_context(
                    LoRAPatcher.apply_lora_patches(
                        model=transformer,
                        patches=self._lora_iterator(context),
                        prefix="",
                        cached_weights=cached_weights,
                    )
                )
            elif config.format in [ModelFormat.BnbQuantizedLlmInt8b, ModelFormat.BnbQuantizednf4b]:
                # The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower
                # inference than directly patching the weights, but is agnostic to the quantization format.
                exit_stack.enter_context(
                    LoRAPatcher.apply_lora_sidecar_patches(
                        model=transformer,
                        patches=self._lora_iterator(context),
                        prefix="",
                    )
                )
            else:
                raise ValueError(f"Unsupported model format: {config.format}")

            x = denoise(
                model=transformer,
                img=x,
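For readers skimming the diff: the branch above chooses between two strategies, and the following standalone sketch (not InvokeAI code; names and signatures are illustrative) shows the arithmetic each one performs.

import torch

# Illustrative sketch of the two patching strategies selected above (not InvokeAI's API).
#
# Direct patching (non-quantized models): fold the LoRA delta into the weight once.
# Sidecar patching (quantized models): leave the weight untouched and add the LoRA
# contribution to the layer's output at runtime.


def direct_patch(weight: torch.Tensor, up: torch.Tensor, down: torch.Tensor, scale: float) -> torch.Tensor:
    """Return a patched copy of `weight`; only possible when `weight` is a plain float tensor."""
    return weight + scale * (up @ down)


def sidecar_output(x: torch.Tensor, base: torch.nn.Module, up: torch.Tensor, down: torch.Tensor, scale: float) -> torch.Tensor:
    """Add the LoRA contribution on top of the (possibly quantized) base layer's output."""
    lora_out = torch.nn.functional.linear(torch.nn.functional.linear(x, down), up)
    return base(x) + scale * lora_out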
@@ -179,7 +179,9 @@ class LoRAPatcher:

            # Move the LoRA sidecar layer to the same device/dtype as the original module.
            # TODO(ryand): Experiment with moving to the device first, then casting. This could be faster.
            lora_sidecar_layer.to(device=module.weight.device, dtype=module.weight.dtype)
            # HACK(ryand): Set the dtype properly here. We want to set it to the *compute* dtype of the original
            # module. In the case of quantized layers, this may be different from the weight dtype.
            lora_sidecar_layer.to(device=module.weight.device, dtype=torch.bfloat16)

            if module_key in original_modules:
                # The module has already been patched with a LoRASidecarModule. Append to it.
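The hardcoded torch.bfloat16 above is flagged as a HACK. Below is a minimal sketch of one way the compute dtype could be resolved instead of hardcoded; the get_compute_dtype helper is hypothetical and not part of the LoRAPatcher API.

import torch


def get_compute_dtype(module: torch.nn.Module, default: torch.dtype = torch.bfloat16) -> torch.dtype:
    # For ordinary layers, the weight dtype is the compute dtype. For quantized layers
    # (e.g. bitsandbytes int8/nf4), the stored weight dtype is an integer type, so fall
    # back to a configured default compute dtype.
    weight = getattr(module, "weight", None)
    if weight is not None and weight.dtype.is_floating_point:
        return weight.dtype
    return default


# Usage sketch: lora_sidecar_layer.to(device=module.weight.device, dtype=get_compute_dtype(module))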
@@ -197,7 +199,7 @@ class LoRAPatcher:
    def _initialize_lora_sidecar_layer(orig_layer: torch.nn.Module, lora_layer: AnyLoRALayer, patch_weight: float):
        if isinstance(orig_layer, torch.nn.Linear):
            if isinstance(lora_layer, LoRALayer):
                return LoRALinearSidecarLayer.from_layers(orig_layer, lora_layer, patch_weight)
                return LoRALinearSidecarLayer(lora_layer=lora_layer, weight=patch_weight)
            else:
                raise ValueError(f"Unsupported Linear LoRA layer type: {type(lora_layer)}")
        elif isinstance(orig_layer, torch.nn.Conv1d):
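The sidecar layer returned above ultimately gets attached to the patched module. As a rough illustration of the "LoRASidecarModule" idea referenced earlier in the diff (this wrapper is a stand-in written for this note, not InvokeAI's class), the wrapping could look like:

import torch


class SidecarWrapper(torch.nn.Module):
    # Stand-in for the sidecar-module concept: wrap the original layer and sum the
    # outputs of any number of sidecar layers onto its output.
    def __init__(self, orig_module: torch.nn.Module):
        super().__init__()
        self._orig_module = orig_module
        self._sidecar_layers = torch.nn.ModuleList()

    def add_sidecar_layer(self, layer: torch.nn.Module) -> None:
        self._sidecar_layers.append(layer)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self._orig_module(x)
        for layer in self._sidecar_layers:
            y = y + layer(x)
        return y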
@@ -70,6 +70,8 @@ class LoRAConvSidecarLayer(torch.nn.Module):
                weight=weight,
            )

        # TODO(ryand): Are there cases where we need to reshape the weight matrices to match the conv layers?

        # Inject weight into the LoRA layer.
        assert model._up.weight.shape == lora_layer.up.shape
        assert model._down.weight.shape == lora_layer.down.shape
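Regarding the reshape TODO above: if reshaping does turn out to be needed, one plausible approach (an assumption, not something the diff implements) is to view a flattened 2-D LoRA matrix as a conv kernel by restoring the spatial dimensions:

import torch


def as_conv_kernel(matrix: torch.Tensor, kernel_size: tuple[int, int] = (1, 1)) -> torch.Tensor:
    # View a flattened (out_features, in_features) LoRA matrix as a conv kernel of shape
    # (out_channels, in_channels, kH, kW). Assumes in_features == in_channels * kH * kW.
    out_features, in_features = matrix.shape
    in_channels = in_features // (kernel_size[0] * kernel_size[1])
    return matrix.reshape(out_features, in_channels, *kernel_size)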
@@ -4,95 +4,24 @@ from invokeai.backend.lora.layers.lora_layer import LoRALayer


class LoRALinearSidecarLayer(torch.nn.Module):
    """An implementation of a linear LoRA layer based on the paper 'LoRA: Low-Rank Adaptation of Large Language Models'.
    (https://arxiv.org/pdf/2106.09685.pdf)
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        include_mid: bool,
        rank: int,
        alpha: float,
        lora_layer: LoRALayer,
        weight: float,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_features, out_features)}")

        self._down = torch.nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
        self._up = torch.nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
        self._mid = None
        if include_mid:
            self._mid = torch.nn.Linear(rank, rank, bias=False, device=device, dtype=dtype)

        # Register alpha as a buffer so that it is not trained, but still gets saved to the state_dict.
        self.register_buffer("alpha", torch.tensor(alpha, device=device, dtype=dtype))

        self._lora_layer = lora_layer
        self._weight = weight
        self._rank = rank

    @classmethod
    def from_layers(cls, orig_layer: torch.nn.Module, lora_layer: LoRALayer, weight: float):
        # Initialize the LoRA layer.
        with torch.device("meta"):
            model = cls.from_orig_layer(
                orig_layer,
                include_mid=lora_layer.mid is not None,
                rank=lora_layer.rank,
                # TODO(ryand): Is this the right default in case of missing alpha?
                alpha=lora_layer.alpha if lora_layer.alpha is not None else lora_layer.rank,
                weight=weight,
            )

        # TODO(ryand): Are there cases where we need to reshape the weight matrices to match the conv layers?

        # Inject weight into the LoRA layer.
        assert model._up.weight.shape == lora_layer.up.shape
        assert model._down.weight.shape == lora_layer.down.shape
        model._up.weight.data = lora_layer.up
        model._down.weight.data = lora_layer.down
        if lora_layer.mid is not None:
            assert model._mid is not None
            assert model._mid.weight.shape == lora_layer.mid.shape
            model._mid.weight.data = lora_layer.mid

        return model

    @classmethod
    def from_orig_layer(
        cls,
        layer: torch.nn.Module,
        include_mid: bool,
        rank: int,
        alpha: float,
        weight: float,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        if not isinstance(layer, torch.nn.Linear):
            raise TypeError(f"'{__class__.__name__}' cannot be initialized from a layer of type '{type(layer)}'.")

        return cls(
            in_features=layer.in_features,
            out_features=layer.out_features,
            include_mid=include_mid,
            rank=rank,
            alpha=alpha,
            weight=weight,
            device=layer.weight.device if device is None else device,
            dtype=layer.weight.dtype if dtype is None else dtype,
        )

    def to(self, device: torch.device, dtype: torch.dtype):
        self._lora_layer.to(device, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self._down(x)
        if self._mid is not None:
            x = self._mid(x)
        x = self._up(x)

        x *= self._weight * self.alpha / self._rank
        x = torch.nn.functional.linear(x, self._lora_layer.down)
        if self._lora_layer.mid is not None:
            x = torch.nn.functional.linear(x, self._lora_layer.mid)
        x = torch.nn.functional.linear(x, self._lora_layer.up, bias=self._lora_layer.bias)
        scale = self._lora_layer.alpha / self._lora_layer.rank if self._lora_layer.alpha is not None else 1.0
        x *= self._weight * scale
        return x
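A quick sanity-check sketch for the functional forward pass above: with weight=1.0, alpha == rank, and no mid layer, the chained torch.nn.functional.linear calls should match an explicit delta-weight matmul. FakeLoRALayer is a stand-in for the real LoRALayer, used only for this check.

import torch


class FakeLoRALayer:
    # Minimal stand-in exposing the attributes the forward pass reads: down, mid, up, bias, alpha, rank.
    def __init__(self, in_features: int, out_features: int, rank: int):
        self.down = torch.randn(rank, in_features)
        self.mid = None
        self.up = torch.randn(out_features, rank)
        self.bias = None
        self.alpha = float(rank)
        self.rank = rank


lora = FakeLoRALayer(in_features=16, out_features=32, rank=4)
x = torch.randn(2, 16)

# Functional path, mirroring the forward pass above with weight=1.0 and no mid layer.
y_functional = torch.nn.functional.linear(torch.nn.functional.linear(x, lora.down), lora.up)

# Reference path: explicit low-rank delta-weight matmul.
y_reference = x @ (lora.up @ lora.down).T

torch.testing.assert_close(y_functional, y_reference)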