Update all lycoris layer types to use the new torch.nn.Module base class.
@@ -3,33 +3,24 @@ from typing import Dict, Optional
 import torch
 
 from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
-from invokeai.backend.util.calc_tensor_size import calc_tensor_size
 
 
 class FullLayer(LoRALayerBase):
-    # bias handled in LoRALayerBase(calc_size, to)
-    # weight: torch.Tensor
-    # bias: Optional[torch.Tensor]
-
-    def __init__(
-        self,
-        values: Dict[str, torch.Tensor],
-    ):
-        super().__init__(values)
-
-        self.weight = values["diff"]
-        self.bias = values.get("diff_b", None)
-
-        self.rank = None  # unscaled
-        self.check_keys(values, {"diff", "diff_b"})
+    def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor]):
+        super().__init__(alpha=None, bias=bias)
+        self.weight = torch.nn.Parameter(weight)
+
+    @classmethod
+    def from_state_dict_values(
+        cls,
+        values: Dict[str, torch.Tensor],
+    ):
+        layer = cls(weight=values["diff"], bias=values.get("diff_b", None))
+        cls.warn_on_unhandled_keys(values=values, handled_keys={"diff", "diff_b"})
+        return layer
+
+    def rank(self) -> int | None:
+        return None
 
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         return self.weight
-
-    def calc_size(self) -> int:
-        return calc_tensor_size(self.weight) + super().calc_size()
-
-    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
-        super().to(device=device, dtype=dtype)
-
-        self.weight = self.weight.to(device=device, dtype=dtype)
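For context, a minimal usage sketch of the refactored API (not part of the commit): the constructor now takes tensors directly, while from_state_dict_values handles the LyCORIS key parsing. The module path, shapes, and the final patching line are illustrative assumptions.

import torch

from invokeai.backend.lora.layers.full_layer import FullLayer  # assumed module path

original_weight = torch.randn(320, 320)
state_dict = {
    "diff": torch.randn(320, 320),  # full weight delta, same shape as the patched weight
    "diff_b": torch.randn(320),     # optional bias delta
}

layer = FullLayer.from_state_dict_values(state_dict)
delta = layer.get_weight(original_weight) * layer.scale()  # scale() is 1.0 here since rank() is None
patched_weight = original_weight + delta                    # assumed patching convention, not shown in this diff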
@@ -3,39 +3,46 @@ from typing import Dict, Optional
 import torch
 
 from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
-from invokeai.backend.util.calc_tensor_size import calc_tensors_size
 
 
 class IA3Layer(LoRALayerBase):
-    # weight: torch.Tensor
-    # on_input: torch.Tensor
-
-    def __init__(
-        self,
-        values: Dict[str, torch.Tensor],
-    ):
-        super().__init__(values)
-
-        self.weight = values["weight"]
-        self.on_input = values["on_input"]
-
-        self.rank = None  # unscaled
-        self.check_keys(values, {"weight", "on_input"})
+    def __init__(self, weight: torch.Tensor, on_input: torch.Tensor, bias: Optional[torch.Tensor]):
+        super().__init__(alpha=None, bias=bias)
+        self.weight = torch.nn.Parameter(weight)
+        self.on_input = torch.nn.Parameter(on_input)
+
+    def rank(self) -> int | None:
+        return None
+
+    @classmethod
+    def from_state_dict_values(
+        cls,
+        values: Dict[str, torch.Tensor],
+    ):
+        bias = cls._parse_bias(
+            values.get("bias_indices", None), values.get("bias_values", None), values.get("bias_size", None)
+        )
+        layer = cls(
+            weight=values["weight"],
+            on_input=values["on_input"],
+            bias=bias,
+        )
+        cls.warn_on_unhandled_keys(
+            values=values,
+            handled_keys={
+                # Default keys.
+                "bias_indices",
+                "bias_values",
+                "bias_size",
+                # Layer-specific keys.
+                "weight",
+                "on_input",
+            },
+        )
+        return layer
 
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         weight = self.weight
         if not self.on_input:
             weight = weight.reshape(-1, 1)
         assert orig_weight is not None
         return orig_weight * weight
-
-    def calc_size(self) -> int:
-        model_size = super().calc_size()
-        model_size += calc_tensors_size([self.weight, self.on_input])
-        return model_size
-
-    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
-        super().to(device=device, dtype=dtype)
-
-        self.weight = self.weight.to(device=device, dtype=dtype)
-        self.on_input = self.on_input.to(device=device, dtype=dtype)
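As a sketch of what get_weight() computes for IA3 (illustrative names and shapes, not code from the repo): the stored vector rescales the original weight per output channel unless the on_input flag is set.

import torch

orig_weight = torch.randn(640, 320)  # (out_features, in_features)
ia3_scale = torch.randn(640)         # one learned scale per output channel
on_input = torch.tensor(False)       # if True, the vector would scale input channels instead

weight = ia3_scale
if not on_input:
    weight = weight.reshape(-1, 1)   # broadcast across the input dimension
scaled = orig_weight * weight        # element-wise rescaling, mirroring IA3Layer.get_weight()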
@@ -1,33 +1,63 @@
-from typing import Dict, Optional
+from typing import Dict
 
 import torch
 
 from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
-from invokeai.backend.util.calc_tensor_size import calc_tensors_size
 
 
 class LoHALayer(LoRALayerBase):
-    # w1_a: torch.Tensor
-    # w1_b: torch.Tensor
-    # w2_a: torch.Tensor
-    # w2_b: torch.Tensor
-    # t1: Optional[torch.Tensor] = None
-    # t2: Optional[torch.Tensor] = None
-
-    def __init__(self, values: Dict[str, torch.Tensor]):
-        super().__init__(values)
-
-        self.w1_a = values["hada_w1_a"]
-        self.w1_b = values["hada_w1_b"]
-        self.w2_a = values["hada_w2_a"]
-        self.w2_b = values["hada_w2_b"]
-        self.t1 = values.get("hada_t1", None)
-        self.t2 = values.get("hada_t2", None)
-
-        self.rank = self.w1_b.shape[0]
-        self.check_keys(
-            values,
-            {
+    def __init__(
+        self,
+        w1_a: torch.Tensor,
+        w1_b: torch.Tensor,
+        w2_a: torch.Tensor,
+        w2_b: torch.Tensor,
+        t1: torch.Tensor | None,
+        t2: torch.Tensor | None,
+        alpha: float | None,
+        bias: torch.Tensor | None,
+    ):
+        super().__init__(alpha=alpha, bias=bias)
+        self.w1_a = torch.nn.Parameter(w1_a)
+        self.w1_b = torch.nn.Parameter(w1_b)
+        self.w2_a = torch.nn.Parameter(w2_a)
+        self.w2_b = torch.nn.Parameter(w2_b)
+        self.t1 = torch.nn.Parameter(t1) if t1 is not None else None
+        self.t2 = torch.nn.Parameter(t2) if t2 is not None else None
+        assert (self.t1 is None) == (self.t2 is None)
+
+    def rank(self) -> int | None:
+        return self.w1_b.shape[0]
+
+    @classmethod
+    def from_state_dict_values(
+        cls,
+        values: Dict[str, torch.Tensor],
+    ):
+        alpha = cls._parse_alpha(values.get("alpha", None))
+        bias = cls._parse_bias(
+            values.get("bias_indices", None), values.get("bias_values", None), values.get("bias_size", None)
+        )
+        layer = cls(
+            w1_a=values["hada_w1_a"],
+            w1_b=values["hada_w1_b"],
+            w2_a=values["hada_w2_a"],
+            w2_b=values["hada_w2_b"],
+            t1=values.get("hada_t1", None),
+            t2=values.get("hada_t2", None),
+            alpha=alpha,
+            bias=bias,
+        )
+
+        cls.warn_on_unhandled_keys(
+            values=values,
+            handled_keys={
                 # Default keys.
                 "alpha",
                 "bias_indices",
                 "bias_values",
                 "bias_size",
                 # Layer-specific keys.
                 "hada_w1_a",
                 "hada_w1_b",
                 "hada_w2_a",
@@ -37,31 +67,14 @@ class LoHALayer(LoRALayerBase):
             },
         )
 
+        return layer
+
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.t1 is None:
             weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
         else:
             rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
             rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
             weight = rebuild1 * rebuild2
 
         return weight
-
-    def calc_size(self) -> int:
-        model_size = super().calc_size()
-        model_size += calc_tensors_size([self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2])
-        return model_size
-
-    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
-        super().to(device=device, dtype=dtype)
-
-        self.w1_a = self.w1_a.to(device=device, dtype=dtype)
-        self.w1_b = self.w1_b.to(device=device, dtype=dtype)
-        if self.t1 is not None:
-            self.t1 = self.t1.to(device=device, dtype=dtype)
-
-        self.w2_a = self.w2_a.to(device=device, dtype=dtype)
-        self.w2_b = self.w2_b.to(device=device, dtype=dtype)
-        if self.t2 is not None:
-            self.t2 = self.t2.to(device=device, dtype=dtype)
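A small sketch of the LoHA reconstruction that get_weight() performs in the non-Tucker branch (shapes are made up for illustration): the weight delta is the Hadamard product of two rank-r factorizations.

import torch

out_features, in_features, r = 640, 320, 8
w1_a, w1_b = torch.randn(out_features, r), torch.randn(r, in_features)
w2_a, w2_b = torch.randn(out_features, r), torch.randn(r, in_features)

# Element-wise (Hadamard) product of two low-rank matrices, as in get_weight() above.
delta = (w1_a @ w1_b) * (w2_a @ w2_b)
assert delta.shape == (out_features, in_features)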
@@ -1,54 +1,75 @@
-from typing import Dict, Optional
+from typing import Dict
 
 import torch
 
 from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
-from invokeai.backend.util.calc_tensor_size import calc_tensors_size
 
 
 class LoKRLayer(LoRALayerBase):
-    # w1: Optional[torch.Tensor] = None
-    # w1_a: Optional[torch.Tensor] = None
-    # w1_b: Optional[torch.Tensor] = None
-    # w2: Optional[torch.Tensor] = None
-    # w2_a: Optional[torch.Tensor] = None
-    # w2_b: Optional[torch.Tensor] = None
-    # t2: Optional[torch.Tensor] = None
-
-    def __init__(self, values: Dict[str, torch.Tensor]):
-        super().__init__(values)
-
-        self.w1 = values.get("lokr_w1", None)
-        if self.w1 is None:
-            self.w1_a = values["lokr_w1_a"]
-            self.w1_b = values["lokr_w1_b"]
-        else:
-            self.w1_b = None
-            self.w1_a = None
-
-        self.w2 = values.get("lokr_w2", None)
-        if self.w2 is None:
-            self.w2_a = values["lokr_w2_a"]
-            self.w2_b = values["lokr_w2_b"]
-        else:
-            self.w2_a = None
-            self.w2_b = None
-
-        self.t2 = values.get("lokr_t2", None)
-
-        if self.w1_b is not None:
-            self.rank = self.w1_b.shape[0]
-        elif self.w2_b is not None:
-            self.rank = self.w2_b.shape[0]
-        else:
-            self.rank = None  # unscaled
-
-        self.check_keys(
+    def __init__(
+        self,
+        w1: torch.Tensor | None,
+        w1_a: torch.Tensor | None,
+        w1_b: torch.Tensor | None,
+        w2: torch.Tensor | None,
+        w2_a: torch.Tensor | None,
+        w2_b: torch.Tensor | None,
+        t2: torch.Tensor | None,
+        alpha: float | None,
+        bias: torch.Tensor | None,
+    ):
+        super().__init__(alpha=alpha, bias=bias)
+        self.w1 = torch.nn.Parameter(w1) if w1 is not None else None
+        self.w1_a = torch.nn.Parameter(w1_a) if w1_a is not None else None
+        self.w1_b = torch.nn.Parameter(w1_b) if w1_b is not None else None
+        self.w2 = torch.nn.Parameter(w2) if w2 is not None else None
+        self.w2_a = torch.nn.Parameter(w2_a) if w2_a is not None else None
+        self.w2_b = torch.nn.Parameter(w2_b) if w2_b is not None else None
+        self.t2 = torch.nn.Parameter(t2) if t2 is not None else None
+
+        # Validate parameters.
+        assert (self.w1 is None) != (self.w1_a is None)
+        assert (self.w1_a is None) == (self.w1_b is None)
+        assert (self.w2 is None) != (self.w2_a is None)
+
+    def rank(self) -> int | None:
+        if self.w1_b is not None:
+            return self.w1_b.shape[0]
+        elif self.w2_b is not None:
+            return self.w2_b.shape[0]
+        else:
+            return None
+
+    @classmethod
+    def from_state_dict_values(
+        cls,
+        values: Dict[str, torch.Tensor],
+    ):
+        alpha = cls._parse_alpha(values.get("alpha", None))
+        bias = cls._parse_bias(
+            values.get("bias_indices", None), values.get("bias_values", None), values.get("bias_size", None)
+        )
+        layer = cls(
+            w1=values.get("lokr_w1", None),
+            w1_a=values.get("lokr_w1_a", None),
+            w1_b=values.get("lokr_w1_b", None),
+            w2=values.get("lokr_w2", None),
+            w2_a=values.get("lokr_w2_a", None),
+            w2_b=values.get("lokr_w2_b", None),
+            t2=values.get("lokr_t2", None),
+            alpha=alpha,
+            bias=bias,
+        )
+
+        cls.warn_on_unhandled_keys(
             values,
             {
                 # Default keys.
                 "alpha",
                 "bias_indices",
                 "bias_values",
                 "bias_size",
                 # Layer-specific keys.
                 "lokr_w1",
                 "lokr_w1_a",
                 "lokr_w1_b",
@@ -59,8 +80,10 @@ class LoKRLayer(LoRALayerBase):
             },
         )
 
+        return layer
+
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
-        w1: Optional[torch.Tensor] = self.w1
+        w1 = self.w1
         if w1 is None:
             assert self.w1_a is not None
             assert self.w1_b is not None
@@ -78,35 +101,5 @@ class LoKRLayer(LoRALayerBase):
         if len(w2.shape) == 4:
             w1 = w1.unsqueeze(2).unsqueeze(2)
         w2 = w2.contiguous()
-        assert w1 is not None
-        assert w2 is not None
         weight = torch.kron(w1, w2)
 
         return weight
-
-    def calc_size(self) -> int:
-        model_size = super().calc_size()
-        model_size += calc_tensors_size([self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2])
-        return model_size
-
-    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
-        super().to(device=device, dtype=dtype)
-
-        if self.w1 is not None:
-            self.w1 = self.w1.to(device=device, dtype=dtype)
-        else:
-            assert self.w1_a is not None
-            assert self.w1_b is not None
-            self.w1_a = self.w1_a.to(device=device, dtype=dtype)
-            self.w1_b = self.w1_b.to(device=device, dtype=dtype)
-
-        if self.w2 is not None:
-            self.w2 = self.w2.to(device=device, dtype=dtype)
-        else:
-            assert self.w2_a is not None
-            assert self.w2_b is not None
-            self.w2_a = self.w2_a.to(device=device, dtype=dtype)
-            self.w2_b = self.w2_b.to(device=device, dtype=dtype)
-
-        if self.t2 is not None:
-            self.t2 = self.t2.to(device=device, dtype=dtype)
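For reference, a sketch of the Kronecker-product reconstruction at the heart of LoKR's get_weight() (shapes are illustrative assumptions): one factor may be stored whole, the other as a low-rank pair.

import torch

w1 = torch.randn(8, 4)                              # first factor, stored whole ("lokr_w1")
w2_a, w2_b = torch.randn(80, 6), torch.randn(6, 80)  # second factor stored as a low-rank pair
w2 = w2_a @ w2_b                                     # rebuild the second factor

delta = torch.kron(w1, w2)                           # Kronecker product -> (8 * 80, 4 * 80)
assert delta.shape == (640, 320)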
@@ -31,11 +31,13 @@ class LoRALayerBase(torch.nn.Module):
     ) -> float | None:
         return alpha.item() if alpha is not None else None
 
-    def rank(self) -> int:
+    def rank(self) -> int | None:
         raise NotImplementedError()
 
     def scale(self) -> float:
-        return self._alpha / self.rank() if self._alpha is not None else 1.0
+        if self._alpha is None or self.rank() is None:
+            return 1.0
+        return self._alpha / self.rank()
 
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError()
@@ -58,3 +60,9 @@ class LoRALayerBase(torch.nn.Module):
         logger.warning(
             f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Unexpected keys: {unknown_keys}"
         )
+
+    def calc_size(self) -> int:
+        # HACK(ryand): Fix this issue with circular imports.
+        from invokeai.backend.model_manager.load.model_util import calc_module_size
+
+        return calc_module_size(self)
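The scale() change above can be read as the following small rule (a standalone restatement for illustration, not code from the repo): unscaled layer types now report rank() as None and fall back to a scale of 1.0 instead of dividing by a missing rank.

def scale(alpha: float | None, rank: int | None) -> float:
    if alpha is None or rank is None:
        return 1.0
    return alpha / rank

assert scale(8.0, 16) == 0.5    # usual alpha / rank scaling
assert scale(None, 16) == 1.0   # no alpha stored in the state dict
assert scale(8.0, None) == 1.0  # unscaled layer types (Full, IA3, Norm)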
@@ -1,37 +1,26 @@
-from typing import Dict, Optional
+from typing import Dict
 
 import torch
 
 from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
-from invokeai.backend.util.calc_tensor_size import calc_tensor_size
 
 
 class NormLayer(LoRALayerBase):
-    # bias handled in LoRALayerBase(calc_size, to)
-    # weight: torch.Tensor
-    # bias: Optional[torch.Tensor]
-
-    def __init__(
-        self,
-        values: Dict[str, torch.Tensor],
-    ):
-        super().__init__(values)
-
-        self.weight = values["w_norm"]
-        self.bias = values.get("b_norm", None)
-
-        self.rank = None  # unscaled
-        self.check_keys(values, {"w_norm", "b_norm"})
+    def __init__(self, weight: torch.Tensor, bias: torch.Tensor | None):
+        super().__init__(alpha=None, bias=bias)
+        self.weight = torch.nn.Parameter(weight)
+
+    @classmethod
+    def from_state_dict_values(
+        cls,
+        values: Dict[str, torch.Tensor],
+    ):
+        layer = cls(weight=values["w_norm"], bias=values.get("b_norm", None))
+        cls.warn_on_unhandled_keys(values, {"w_norm", "b_norm"})
+        return layer
+
+    def rank(self) -> int | None:
+        return None
 
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         return self.weight
-
-    def calc_size(self) -> int:
-        model_size = super().calc_size()
-        model_size += calc_tensor_size(self.weight)
-        return model_size
-
-    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
-        super().to(device=device, dtype=dtype)
-
-        self.weight = self.weight.to(device=device, dtype=dtype)
@@ -19,15 +19,15 @@ def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:
         # LoRA a.k.a LoCon
         return LoRALayer.from_state_dict_values(state_dict)
     elif "hada_w1_a" in state_dict:
-        return LoHALayer(state_dict)
+        return LoHALayer.from_state_dict_values(state_dict)
     elif "lokr_w1" in state_dict or "lokr_w1_a" in state_dict:
-        return LoKRLayer(state_dict)
+        return LoKRLayer.from_state_dict_values(state_dict)
     elif "diff" in state_dict:
         # Full a.k.a Diff
-        return FullLayer(state_dict)
+        return FullLayer.from_state_dict_values(state_dict)
     elif "on_input" in state_dict:
-        return IA3Layer(state_dict)
+        return IA3Layer.from_state_dict_values(state_dict)
    elif "w_norm" in state_dict:
-        return NormLayer(state_dict)
+        return NormLayer.from_state_dict_values(state_dict)
     else:
         raise ValueError(f"Unsupported lora format: {state_dict.keys()}")
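A hypothetical call into the updated dispatcher (the import path, keys, and shapes are invented for illustration): the layer class is picked from which LyCORIS keys are present, and every branch now goes through from_state_dict_values.

import torch

from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict  # assumed module path

loha_sd = {
    "hada_w1_a": torch.randn(640, 8), "hada_w1_b": torch.randn(8, 320),
    "hada_w2_a": torch.randn(640, 8), "hada_w2_b": torch.randn(8, 320),
    "alpha": torch.tensor(8.0),
}
layer = any_lora_layer_from_state_dict(loha_sd)  # dispatches to LoHALayer.from_state_dict_values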