InvokeAI/invokeai/backend/patches/layers/lora_layer.py
Kent Keirsey 1f63b60021 Implementing support for Non-Standard LoRA Format (#7985)
* integrate LoRA

* idk anymore tbh

* enable fused matrix for quantized models

* ruff fix

---------

Co-authored-by: Sam <bhaskarmdutt@gmail.com>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2025-05-05 09:40:38 -04:00


from typing import Dict, Optional

import torch

from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensors_size


class LoRALayer(LoRALayerBase):
    def __init__(
        self,
        up: torch.Tensor,
        mid: Optional[torch.Tensor],
        down: torch.Tensor,
        alpha: float | None,
        bias: Optional[torch.Tensor],
    ):
        super().__init__(alpha, bias)
        self.up = up
        self.mid = mid
        self.down = down
        # Standard LoRA layers share a single rank, i.e. up.shape[1] == down.shape[0].
        # Non-standard formats (e.g. fused KQV projections converted to Kohya format) can violate
        # this, in which case get_weight() falls back to fuse_weights().
        self.are_ranks_equal = up.shape[1] == down.shape[0]

    @classmethod
    def from_state_dict_values(
        cls,
        values: Dict[str, torch.Tensor],
    ):
        alpha = cls._parse_alpha(values.get("alpha", None))
        bias = cls._parse_bias(
            values.get("bias_indices", None), values.get("bias_values", None), values.get("bias_size", None)
        )

        layer = cls(
            up=values["lora_up.weight"],
            down=values["lora_down.weight"],
            mid=values.get("lora_mid.weight", None),
            alpha=alpha,
            bias=bias,
        )

        cls.warn_on_unhandled_keys(
            values=values,
            handled_keys={
                # Default keys.
                "alpha",
                "bias_indices",
                "bias_values",
                "bias_size",
                # Layer-specific keys.
                "lora_up.weight",
                "lora_down.weight",
                "lora_mid.weight",
            },
        )

        return layer

    def _rank(self) -> int:
        return self.down.shape[0]

    def fuse_weights(self, up: torch.Tensor, down: torch.Tensor) -> torch.Tensor:
        """
        Fuse the weights of the up and down matrices of a LoRA layer with different ranks.

        Since the Huggingface implementation of the KQV projections is fused, converting to Kohya format
        produces LoRA weights with different ranks. This function handles the fusion of these differently
        sized matrices (see the illustrative sketch at the end of this file).
        """
        fused_lora = torch.zeros((up.shape[0], down.shape[1]), device=down.device, dtype=down.dtype)
        rank_diff = down.shape[0] / up.shape[1]

        if rank_diff > 1:
            # down has the larger rank: split its rank dimension into equally sized chunks and sum the
            # partial products.
            w_down = down.chunk(int(rank_diff), dim=0)
            for w_down_chunk in w_down:
                fused_lora = fused_lora + (torch.mm(up, w_down_chunk))
        else:
            # up has the larger rank: split its rank dimension (dim 1) into chunks whose inner dimension
            # matches down.shape[0], then sum the partial products.
            rank_diff = up.shape[1] / down.shape[0]
            w_up = up.chunk(int(rank_diff), dim=1)
            for w_up_chunk in w_up:
                fused_lora = fused_lora + (torch.mm(w_up_chunk, down))

        return fused_lora

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        if self.mid is not None:
            # A mid tensor is present for some conv LoRA formats: contract up, mid and down into a
            # 4D, kernel-shaped weight.
            up = self.up.reshape(self.up.shape[0], self.up.shape[1])
            down = self.down.reshape(self.down.shape[0], self.down.shape[1])
            weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
        else:
            # The up matrix and down matrix have different ranks, so we can't simply multiply them.
            if not self.are_ranks_equal:
                weight = self.fuse_weights(self.up, self.down)
                return weight

            weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)

        return weight

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().to(device=device, dtype=dtype)
        self.up = self.up.to(device=device, dtype=dtype)
        if self.mid is not None:
            self.mid = self.mid.to(device=device, dtype=dtype)
        self.down = self.down.to(device=device, dtype=dtype)

    def calc_size(self) -> int:
        return super().calc_size() + calc_tensors_size([self.up, self.mid, self.down])
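

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the upstream module): a minimal,
# hedged example of how a LoRALayer might be constructed directly from tensors
# and how get_weight() behaves for equal and mismatched ranks. The shapes and
# alpha values below are arbitrary assumptions chosen only for illustration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Standard case: up and down share rank 4, so get_weight() is a plain matmul.
    standard = LoRALayer(
        up=torch.randn(8, 4),
        mid=None,
        down=torch.randn(4, 16),
        alpha=4.0,
        bias=None,
    )
    print(standard.get_weight(torch.zeros(8, 16)).shape)  # torch.Size([8, 16])

    # Non-standard case: down has rank 6 while up has rank 2, so get_weight() falls back to
    # fuse_weights(), which splits `down` into rank-2 chunks and sums the partial products.
    non_standard = LoRALayer(
        up=torch.randn(8, 2),
        mid=None,
        down=torch.randn(6, 16),
        alpha=2.0,
        bias=None,
    )
    print(non_standard.get_weight(torch.zeros(8, 16)).shape)  # torch.Size([8, 16])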