from typing import Optional

import torch

import invokeai.backend.util.logging as logger
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.param_shape_utils import get_param_shape
from invokeai.backend.util.calc_tensor_size import calc_tensors_size


class LoRALayerBase(BaseLayerPatch):
    """Base class for all LoRA-like patching layers."""

    # Note: It is tempting to make this a torch.nn.Module sub-class and make all tensors 'torch.nn.Parameter's. Then we
    # could inherit automatic .to(...) behavior for this class, its subclasses, and all sidecar layers that wrap a
    # LoRALayerBase. We would also be able to implement a single calc_size() method that could be inherited by all
    # subclasses. But, it turns out that the speed overhead of the default .to(...) implementation in torch.nn.Module is
    # noticeable, so for now we have opted not to use torch.nn.Module.

    def __init__(self, alpha: float | None, bias: torch.Tensor | None):
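        # `alpha` is the scaling value read from the checkpoint and is combined with the rank in scale() below;
        # `bias` is an optional additive bias patch that get_bias() returns.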
        self._alpha = alpha
        self.bias = bias

    @classmethod
    def _parse_bias(
        cls, bias_indices: torch.Tensor | None, bias_values: torch.Tensor | None, bias_size: torch.Tensor | None
    ) -> torch.Tensor | None:
        """Helper function to parse a bias tensor from a state dict in LyCORIS format."""
        assert (bias_indices is None) == (bias_values is None) == (bias_size is None)

        bias = None
        if bias_indices is not None:
            bias = torch.sparse_coo_tensor(bias_indices, bias_values, tuple(bias_size))
        return bias

    @classmethod
    def _parse_alpha(
        cls,
        alpha: torch.Tensor | None,
    ) -> float | None:
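        # In LoRA state dicts, alpha typically arrives as a single-element tensor; .item() extracts it as a plain
        # Python float.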
        return alpha.item() if alpha is not None else None

    def _rank(self) -> int | None:
        """Return the rank of the LoRA-like layer, or None if the layer does not have a rank. This value is used to
        calculate the scale.
        """
        raise NotImplementedError()

    def scale(self) -> float:
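        # Conventional LoRA scaling: the effective multiplier is alpha / rank, falling back to 1.0 when either value
        # is unavailable.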
        rank = self._rank()
        if self._alpha is None or rank is None:
            return 1.0
        return self._alpha / rank

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError()

    def get_bias(self, orig_bias: torch.Tensor | None) -> Optional[torch.Tensor]:
        return self.bias

    def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
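        # The effective multiplier applied to the patched parameters is the caller-supplied patch weight times this
        # layer's own scale (alpha / rank).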
        scale = self.scale()
        params = {"weight": self.get_weight(orig_parameters["weight"]) * (weight * scale)}
        bias = self.get_bias(orig_parameters.get("bias", None))
        if bias is not None:
            params["bias"] = bias * (weight * scale)

        # Reshape all params to match the original module's shape.
        for param_name, param_weight in params.items():
            orig_param = orig_parameters[param_name]
            if param_weight.shape != get_param_shape(orig_param):
                params[param_name] = param_weight.reshape(get_param_shape(orig_param))

        return params

    @classmethod
    def warn_on_unhandled_keys(cls, values: dict[str, torch.Tensor], handled_keys: set[str]):
        """Log a warning if values contains unhandled keys."""
        unknown_keys = set(values.keys()) - handled_keys
        if unknown_keys:
            logger.warning(
                f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Unexpected keys: {unknown_keys}"
            )

    def calc_size(self) -> int:
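        # The base class only owns the (optional) bias tensor; subclasses are expected to extend this to account for
        # their own tensors (see the note at the top of the class on why a single inherited calc_size is not used).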
        return calc_tensors_size([self.bias])

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
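        # Manual device/dtype transfer in place of torch.nn.Module.to(...); see the note at the top of the class for
        # why torch.nn.Module is avoided.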
        if self.bias is not None:
            self.bias = self.bias.to(device=device, dtype=dtype)
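

# --- Illustrative sketch, not part of the original module -------------------------------------------------------
# A minimal, hypothetical subclass showing how the abstract hooks above fit together. The names below
# (SimpleLoRALayer, `up`, `down`) are assumptions for illustration only; the real LoRA layer implementations in
# this package differ.
class SimpleLoRALayer(LoRALayerBase):
    def __init__(self, up: torch.Tensor, down: torch.Tensor, alpha: float | None, bias: torch.Tensor | None):
        super().__init__(alpha=alpha, bias=bias)
        self.up = up  # Shape: (out_features, rank).
        self.down = down  # Shape: (rank, in_features).

    def _rank(self) -> int | None:
        # The rank is the shared inner dimension of the low-rank factors.
        return self.down.shape[0]

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        # The weight delta is the product of the low-rank factors; get_parameters() applies weight * scale.
        return self.up @ self.down

    def calc_size(self) -> int:
        # Account for this layer's own tensors in addition to the base class's bias.
        return super().calc_size() + calc_tensors_size([self.up, self.down])

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().to(device=device, dtype=dtype)
        self.up = self.up.to(device=device, dtype=dtype)
        self.down = self.down.to(device=device, dtype=dtype)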