from typing import Dict, Optional

import torch

from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensors_size


class LoRALayer(LoRALayerBase):
    def __init__(
        self,
        up: torch.Tensor,
        mid: Optional[torch.Tensor],
        down: torch.Tensor,
        alpha: float | None,
        bias: Optional[torch.Tensor],
    ):
        super().__init__(alpha, bias)
        self.up = up
        self.mid = mid
        self.down = down
        # The up and down matrices share a rank only when up's inner dim matches down's outer dim.
        # When they differ (e.g. converted fused-projection weights), get_weight() falls back to
        # fuse_weights().
        self.are_ranks_equal = up.shape[1] == down.shape[0]

    @classmethod
    def from_state_dict_values(
        cls,
        values: Dict[str, torch.Tensor],
    ):
        alpha = cls._parse_alpha(values.get("alpha", None))
        bias = cls._parse_bias(
            values.get("bias_indices", None), values.get("bias_values", None), values.get("bias_size", None)
        )

        layer = cls(
            up=values["lora_up.weight"],
            down=values["lora_down.weight"],
            mid=values.get("lora_mid.weight", None),
            alpha=alpha,
            bias=bias,
        )

        cls.warn_on_unhandled_keys(
            values=values,
            handled_keys={
                # Default keys.
                "alpha",
                "bias_indices",
                "bias_values",
                "bias_size",
                # Layer-specific keys.
                "lora_up.weight",
                "lora_down.weight",
                "lora_mid.weight",
            },
        )

        return layer

    def _rank(self) -> int:
        return self.down.shape[0]

    def fuse_weights(self, up: torch.Tensor, down: torch.Tensor) -> torch.Tensor:
        """Fuse the up and down matrices of a LoRA layer whose ranks differ.

        HuggingFace checkpoints fuse the Q/K/V projections into a single matrix, so after conversion
        to Kohya format the corresponding LoRA up and down weights can end up with different ranks.
        This function fuses those differently sized matrices into a single weight delta.
        """
        fused_lora = torch.zeros((up.shape[0], down.shape[1]), device=down.device, dtype=down.dtype)
        rank_diff = down.shape[0] / up.shape[1]

        if rank_diff > 1:
            # down's rank is a multiple of up's: split down along its rank dimension and accumulate
            # one matmul per chunk.
            w_down = down.chunk(int(rank_diff), dim=0)
            for w_down_chunk in w_down:
                fused_lora = fused_lora + torch.mm(up, w_down_chunk)
        else:
            # up's rank is a multiple of down's: split up along its rank dimension (dim=1) so each
            # chunk's inner dimension matches down's outer dimension.
            rank_diff = up.shape[1] / down.shape[0]
            w_up = up.chunk(int(rank_diff), dim=1)
            for w_up_chunk in w_up:
                fused_lora = fused_lora + torch.mm(w_up_chunk, down)

        return fused_lora

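    # Illustrative shape walk-through (hypothetical sizes, not from the original source): for a
    # converted fused-projection layer, `up` might be (640, 32) and `down` (96, 320). Then
    # rank_diff = 96 / 32 = 3, `down` splits into three (32, 320) chunks, each product
    # up @ chunk is (640, 320), and their sum is the fused (640, 320) weight delta.
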
    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        if self.mid is not None:
            # When a mid tensor is present, it carries the kernel dims; recombine up, mid, and down
            # with an einsum that preserves them.
            up = self.up.reshape(self.up.shape[0], self.up.shape[1])
            down = self.down.reshape(self.down.shape[0], self.down.shape[1])
            weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
        else:
            # The up and down matrices have different ranks, so we can't simply multiply them.
            if not self.are_ranks_equal:
                return self.fuse_weights(self.up, self.down)

            weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)

        return weight

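    # Illustrative shape walk-through (hypothetical sizes, not from the original source): for a
    # rank-8 LoRA on a 320 -> 320 conv with a 3x3 kernel, `up` could be (320, 8, 1, 1), `down`
    # (8, 320, 1, 1), and `mid` (8, 8, 3, 3); the reshapes give a (320, 8) up and an (8, 320)
    # down, and the einsum yields a (320, 320, 3, 3) delta matching the conv weight layout.
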
    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().to(device=device, dtype=dtype)
        self.up = self.up.to(device=device, dtype=dtype)
        if self.mid is not None:
            self.mid = self.mid.to(device=device, dtype=dtype)
        self.down = self.down.to(device=device, dtype=dtype)

    def calc_size(self) -> int:
        return super().calc_size() + calc_tensors_size([self.up, self.mid, self.down])
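

# A minimal usage sketch (not part of the original module): it builds a LoRALayer from a synthetic
# Kohya-style state dict and inspects the resulting weight delta. The tensor shapes and alpha value
# are illustrative assumptions, not values taken from any real checkpoint.
if __name__ == "__main__":
    rank, in_features, out_features = 8, 320, 640
    state_dict = {
        "lora_up.weight": torch.randn(out_features, rank),
        "lora_down.weight": torch.randn(rank, in_features),
        "alpha": torch.tensor(4.0),
    }
    layer = LoRALayer.from_state_dict_values(state_dict)

    base_weight = torch.zeros(out_features, in_features)
    delta = layer.get_weight(base_weight)
    print(delta.shape)  # torch.Size([640, 320])
    print(layer.calc_size())  # Approximate size of the patch tensors, in bytes.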