Update all lycoris layer types to use the new torch.nn.Module base class.

This commit is contained in:
Ryan Dick
2024-09-12 14:55:40 +00:00
committed by Kent Keirsey
parent 81fbaf2b8b
commit 9438ea608c
7 changed files with 182 additions and 181 deletions

View File

@@ -3,33 +3,24 @@ from typing import Dict, Optional
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
class FullLayer(LoRALayerBase):
    """A full ("diff") LyCORIS layer: stores the complete weight delta directly.

    The weight is registered as a torch.nn.Parameter so the torch.nn.Module base
    class machinery manages it; the optional bias is handled in LoRALayerBase.
    """

    def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor]):
        super().__init__(alpha=None, bias=bias)
        self.weight = torch.nn.Parameter(weight)

    @classmethod
    def from_state_dict_values(
        cls,
        values: Dict[str, torch.Tensor],
    ):
        """Build a FullLayer from a state-dict chunk with "diff" / "diff_b" keys."""
        layer = cls(weight=values["diff"], bias=values.get("diff_b", None))
        cls.warn_on_unhandled_keys(values=values, handled_keys={"diff", "diff_b"})
        return layer

    def rank(self) -> int | None:
        # A full layer is not a low-rank decomposition, so it has no rank (unscaled).
        return None

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        # The stored weight is itself the full delta; orig_weight is unused.
        return self.weight

View File

@@ -3,39 +3,46 @@ from typing import Dict, Optional
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensors_size
class IA3Layer(LoRALayerBase):
    """An IA3 LyCORIS layer: rescales the original weight elementwise.

    `on_input` controls the orientation of the scale vector (see get_weight()).
    """

    def __init__(self, weight: torch.Tensor, on_input: torch.Tensor, bias: Optional[torch.Tensor]):
        super().__init__(alpha=None, bias=bias)
        self.weight = torch.nn.Parameter(weight)
        self.on_input = torch.nn.Parameter(on_input)

    def rank(self) -> int | None:
        # IA3 is a per-channel scale, not a low-rank decomposition (unscaled).
        return None

    @classmethod
    def from_state_dict_values(
        cls,
        values: Dict[str, torch.Tensor],
    ):
        """Build an IA3Layer from a state-dict chunk with "weight" / "on_input" keys."""
        bias = cls._parse_bias(
            values.get("bias_indices", None), values.get("bias_values", None), values.get("bias_size", None)
        )
        layer = cls(
            weight=values["weight"],
            on_input=values["on_input"],
            bias=bias,
        )
        cls.warn_on_unhandled_keys(
            values=values,
            handled_keys={
                # Default keys.
                "bias_indices",
                "bias_values",
                "bias_size",
                # Layer-specific keys.
                "weight",
                "on_input",
            },
        )
        return layer

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        weight = self.weight
        if not self.on_input:
            # Scale is applied along the output dimension; make it a column vector.
            weight = weight.reshape(-1, 1)
        assert orig_weight is not None
        return orig_weight * weight

View File

@@ -1,33 +1,63 @@
from typing import Dict, Optional
from typing import Dict
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensors_size
class LoHALayer(LoRALayerBase):
    """A LoHA (Hadamard product) LyCORIS layer.

    The weight delta is (w1_a @ w1_b) * (w2_a @ w2_b). For conv layers, optional
    Tucker tensors t1/t2 are used instead; t1 and t2 must be both present or both
    absent.
    """

    def __init__(
        self,
        w1_a: torch.Tensor,
        w1_b: torch.Tensor,
        w2_a: torch.Tensor,
        w2_b: torch.Tensor,
        t1: torch.Tensor | None,
        t2: torch.Tensor | None,
        alpha: float | None,
        bias: torch.Tensor | None,
    ):
        super().__init__(alpha=alpha, bias=bias)
        self.w1_a = torch.nn.Parameter(w1_a)
        self.w1_b = torch.nn.Parameter(w1_b)
        self.w2_a = torch.nn.Parameter(w2_a)
        self.w2_b = torch.nn.Parameter(w2_b)
        self.t1 = torch.nn.Parameter(t1) if t1 is not None else None
        self.t2 = torch.nn.Parameter(t2) if t2 is not None else None

        # t1/t2 are either both present (Tucker form) or both absent.
        assert (self.t1 is None) == (self.t2 is None)

    def rank(self) -> int | None:
        return self.w1_b.shape[0]

    @classmethod
    def from_state_dict_values(
        cls,
        values: Dict[str, torch.Tensor],
    ):
        """Build a LoHALayer from a state-dict chunk with "hada_*" keys."""
        alpha = cls._parse_alpha(values.get("alpha", None))
        bias = cls._parse_bias(
            values.get("bias_indices", None), values.get("bias_values", None), values.get("bias_size", None)
        )
        layer = cls(
            w1_a=values["hada_w1_a"],
            w1_b=values["hada_w1_b"],
            w2_a=values["hada_w2_a"],
            w2_b=values["hada_w2_b"],
            t1=values.get("hada_t1", None),
            t2=values.get("hada_t2", None),
            alpha=alpha,
            bias=bias,
        )
        cls.warn_on_unhandled_keys(
            values=values,
            handled_keys={
                # Default keys.
                "alpha",
                "bias_indices",
                "bias_values",
                "bias_size",
                # Layer-specific keys (matches the keys read by the constructor call above).
                "hada_w1_a",
                "hada_w1_b",
                "hada_w2_a",
                "hada_w2_b",
                "hada_t1",
                "hada_t2",
            },
        )
        return layer

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        if self.t1 is None:
            weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
        else:
            # Rebuild each factor from its Tucker decomposition, then take the
            # elementwise (Hadamard) product.
            rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
            rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
            weight = rebuild1 * rebuild2

        return weight

View File

@@ -1,54 +1,75 @@
from typing import Dict, Optional
from typing import Dict
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensors_size
class LoKRLayer(LoRALayerBase):
    """A LoKR (Kronecker product) LyCORIS layer.

    Each of the two Kronecker factors is stored either whole (w1 / w2) or as a
    low-rank a/b pair (w1_a, w1_b / w2_a, w2_b) — exactly one form per factor.
    """

    def __init__(
        self,
        w1: torch.Tensor | None,
        w1_a: torch.Tensor | None,
        w1_b: torch.Tensor | None,
        w2: torch.Tensor | None,
        w2_a: torch.Tensor | None,
        w2_b: torch.Tensor | None,
        t2: torch.Tensor | None,
        alpha: float | None,
        bias: torch.Tensor | None,
    ):
        super().__init__(alpha=alpha, bias=bias)

        self.w1 = torch.nn.Parameter(w1) if w1 is not None else None
        self.w1_a = torch.nn.Parameter(w1_a) if w1_a is not None else None
        self.w1_b = torch.nn.Parameter(w1_b) if w1_b is not None else None
        self.w2 = torch.nn.Parameter(w2) if w2 is not None else None
        self.w2_a = torch.nn.Parameter(w2_a) if w2_a is not None else None
        self.w2_b = torch.nn.Parameter(w2_b) if w2_b is not None else None
        self.t2 = torch.nn.Parameter(t2) if t2 is not None else None

        # Validate parameters: each factor is supplied either whole or as an a/b pair.
        assert (self.w1 is None) != (self.w1_a is None)
        assert (self.w1_a is None) == (self.w1_b is None)
        assert (self.w2 is None) != (self.w2_a is None)

    def rank(self) -> int | None:
        # Rank comes from whichever factor is stored in low-rank form; None if
        # both factors are stored whole (unscaled).
        if self.w1_b is not None:
            return self.w1_b.shape[0]
        elif self.w2_b is not None:
            return self.w2_b.shape[0]
        else:
            return None

    @classmethod
    def from_state_dict_values(
        cls,
        values: Dict[str, torch.Tensor],
    ):
        """Build a LoKRLayer from a state-dict chunk with "lokr_*" keys."""
        alpha = cls._parse_alpha(values.get("alpha", None))
        bias = cls._parse_bias(
            values.get("bias_indices", None), values.get("bias_values", None), values.get("bias_size", None)
        )
        layer = cls(
            w1=values.get("lokr_w1", None),
            w1_a=values.get("lokr_w1_a", None),
            w1_b=values.get("lokr_w1_b", None),
            w2=values.get("lokr_w2", None),
            w2_a=values.get("lokr_w2_a", None),
            w2_b=values.get("lokr_w2_b", None),
            t2=values.get("lokr_t2", None),
            alpha=alpha,
            bias=bias,
        )
        cls.warn_on_unhandled_keys(
            values,
            {
                # Default keys.
                "alpha",
                "bias_indices",
                "bias_values",
                "bias_size",
                # Layer-specific keys (matches the keys read by the constructor call above).
                "lokr_w1",
                "lokr_w1_a",
                "lokr_w1_b",
                "lokr_w2",
                "lokr_w2_a",
                "lokr_w2_b",
                "lokr_t2",
            },
        )
        return layer
# NOTE(review): The lines below are a rendered diff fragment (pre-commit and
# post-commit lines interleaved), not valid Python. The hunk header
# "@@ -78,35 +101,5 @@" omits the middle of get_weight(), so the method cannot
# be fully reconstructed from this view — consult the upstream file. The
# calc_size()/to() overrides at the end are the removed pre-commit versions
# (superseded by the torch.nn.Module base class).
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
# Pre-commit annotated form of the next line, followed by its post-commit replacement:
w1: Optional[torch.Tensor] = self.w1
w1 = self.w1
if w1 is None:
assert self.w1_a is not None
assert self.w1_b is not None
@@ -78,35 +101,5 @@ class LoKRLayer(LoRALayerBase):
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
# The two lines below are removed pre-commit asserts:
assert w1 is not None
assert w2 is not None
weight = torch.kron(w1, w2)
return weight
# Removed pre-commit calc_size()/to() follow:
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += calc_tensors_size([self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2])
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
if self.w1 is not None:
self.w1 = self.w1.to(device=device, dtype=dtype)
else:
assert self.w1_a is not None
assert self.w1_b is not None
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype)
else:
assert self.w2_a is not None
assert self.w2_b is not None
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)

View File

@@ -31,11 +31,13 @@ class LoRALayerBase(torch.nn.Module):
) -> float | None:
return alpha.item() if alpha is not None else None
def rank(self) -> int:
def rank(self) -> int | None:
raise NotImplementedError()
def scale(self) -> float:
    """Return the scaling factor (alpha / rank) applied to this layer's weight.

    Falls back to 1.0 when either alpha or rank is undefined (unscaled layer
    types such as full/norm/IA3 layers return rank() == None).
    """
    # Hoist the rank() call so it is computed only once.
    rank = self.rank()
    if self._alpha is None or rank is None:
        return 1.0
    return self._alpha / rank
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
    """Compute this layer's weight contribution from the original weight.
    Subclasses must override."""
    raise NotImplementedError()
@@ -58,3 +60,9 @@ class LoRALayerBase(torch.nn.Module):
logger.warning(
f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Unexpected keys: {unknown_keys}"
)
def calc_size(self) -> int:
    """Return the total size of this module's tensors, as computed by
    calc_module_size()."""
    # HACK(ryand): Fix this issue with circular imports.
    from invokeai.backend.model_manager.load.model_util import calc_module_size

    return calc_module_size(self)

View File

@@ -1,37 +1,26 @@
from typing import Dict, Optional
from typing import Dict
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
class NormLayer(LoRALayerBase):
    """A norm LyCORIS layer: stores a replacement weight (and optional bias).

    The weight is registered as a torch.nn.Parameter so the torch.nn.Module base
    class machinery manages it; the optional bias is handled in LoRALayerBase.
    """

    def __init__(self, weight: torch.Tensor, bias: torch.Tensor | None):
        super().__init__(alpha=None, bias=bias)
        self.weight = torch.nn.Parameter(weight)

    @classmethod
    def from_state_dict_values(
        cls,
        values: Dict[str, torch.Tensor],
    ):
        """Build a NormLayer from a state-dict chunk with "w_norm" / "b_norm" keys."""
        layer = cls(weight=values["w_norm"], bias=values.get("b_norm", None))
        cls.warn_on_unhandled_keys(values, {"w_norm", "b_norm"})
        return layer

    def rank(self) -> int | None:
        # A norm layer is not a low-rank decomposition, so it has no rank (unscaled).
        return None

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        # The stored weight is the full contribution; orig_weight is unused.
        return self.weight

View File

# NOTE(review): Rendered diff fragment of any_lora_layer_from_state_dict() —
# in each pair below, "return XLayer(state_dict)" is the removed pre-commit
# line and "return XLayer.from_state_dict_values(state_dict)" is its
# post-commit replacement. The function's signature (visible in the hunk
# header) and its first condition lie outside this view.
@@ -19,15 +19,15 @@ def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLo
# LoRA a.k.a LoCon
return LoRALayer.from_state_dict_values(state_dict)
elif "hada_w1_a" in state_dict:
return LoHALayer(state_dict)
return LoHALayer.from_state_dict_values(state_dict)
elif "lokr_w1" in state_dict or "lokr_w1_a" in state_dict:
return LoKRLayer(state_dict)
return LoKRLayer.from_state_dict_values(state_dict)
elif "diff" in state_dict:
# Full a.k.a Diff
return FullLayer(state_dict)
return FullLayer.from_state_dict_values(state_dict)
elif "on_input" in state_dict:
return IA3Layer(state_dict)
return IA3Layer.from_state_dict_values(state_dict)
elif "w_norm" in state_dict:
return NormLayer(state_dict)
return NormLayer.from_state_dict_values(state_dict)
else:
raise ValueError(f"Unsupported lora format: {state_dict.keys()}")