from typing import Dict, Optional

import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensors_size


class DoRALayer(LoRALayerBase):
    """A DoRA layer, as defined in https://arxiv.org/pdf/2402.09353."""
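
    # Background (from the DoRA paper): a DoRA patch decomposes the weight
    # update into a magnitude component (`dora_scale`) and a direction
    # component (the low-rank `up @ down` product). Applying the patch
    # renormalizes each column of the patched weight and rescales it by
    # `dora_scale`, as implemented in get_weight() below.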

    def __init__(
        self,
        up: torch.Tensor,
        down: torch.Tensor,
        dora_scale: torch.Tensor,
        alpha: float | None,
        bias: Optional[torch.Tensor],
    ):
        super().__init__(alpha, bias)
        self.up = up
        self.down = down
        self.dora_scale = dora_scale

    @classmethod
    def from_state_dict_values(cls, values: Dict[str, torch.Tensor]):
        alpha = cls._parse_alpha(values.get("alpha", None))
        bias = cls._parse_bias(
            values.get("bias_indices", None), values.get("bias_values", None), values.get("bias_size", None)
        )

        layer = cls(
            up=values["lora_up.weight"],
            down=values["lora_down.weight"],
            dora_scale=values["dora_scale"],
            alpha=alpha,
            bias=bias,
        )

        cls.warn_on_unhandled_keys(
            values=values,
            handled_keys={
                # Default keys.
                "alpha",
                "bias_indices",
                "bias_values",
                "bias_size",
                # Layer-specific keys.
                "lora_up.weight",
                "lora_down.weight",
                "dora_scale",
            },
        )

        return layer

    def _rank(self) -> int:
        return self.down.shape[0]

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        orig_weight = cast_to_device(orig_weight, self.up.device)

        # Note: Variable names (e.g. delta_v) are based on the paper.
        delta_v = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
        delta_v = delta_v.reshape(orig_weight.shape)

        delta_v = delta_v * self.scale()

        # At this point, out_weight is the unnormalized direction matrix.
        out_weight = orig_weight + delta_v
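
        # The transpose/reshape chain below computes the norm of out_weight
        # over every dim except dim 1 and broadcasts it back to out_weight's
        # shape. For a 2D Linear weight this is the column-wise norm
        # ||W_0 + delta_V||_c from the paper. After rescaling by
        # dora_scale / norm, the original weight is subtracted again so that
        # get_weight() returns an additive delta, consistent with the other
        # LoRA-family layer types.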
        # TODO(ryand): Simplify this logic.
        direction_norm = (
            out_weight.transpose(0, 1)
            .reshape(out_weight.shape[1], -1)
            .norm(dim=1, keepdim=True)
            .reshape(out_weight.shape[1], *[1] * (out_weight.dim() - 1))
            .transpose(0, 1)
        )

        out_weight *= self.dora_scale / direction_norm

        return out_weight - orig_weight

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().to(device=device, dtype=dtype)
        self.up = self.up.to(device=device, dtype=dtype)
        self.down = self.down.to(device=device, dtype=dtype)
        self.dora_scale = self.dora_scale.to(device=device, dtype=dtype)

    def calc_size(self) -> int:
        return super().calc_size() + calc_tensors_size([self.up, self.down, self.dora_scale])
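
    # Entry point used when the patch is applied to a module: returns additive
    # parameter deltas (already multiplied by the caller-supplied patch
    # weight), keyed by the original module's parameter names.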
    def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
        if any(p.device.type == "meta" for p in orig_parameters.values()):
            # If any of the original parameters are on the 'meta' device, we assume this is because the base model
            # is in a quantization format that doesn't allow easy dequantization.
            raise RuntimeError(
                "The base model quantization format (likely bitsandbytes) is not compatible with DoRA patches."
            )
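
        # Note: get_weight() already applies self.scale() to the weight delta,
        # so only the caller-supplied patch weight is applied to it here. The
        # bias delta, by contrast, needs both factors.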
        scale = self.scale()
        params = {"weight": self.get_weight(orig_parameters["weight"]) * weight}
        bias = self.get_bias(orig_parameters.get("bias", None))
        if bias is not None:
            params["bias"] = bias * (weight * scale)

        # Reshape all params to match the original module's shape.
        for param_name, param_weight in params.items():
            orig_param = orig_parameters[param_name]
            if param_weight.shape != orig_param.shape:
                params[param_name] = param_weight.reshape(orig_param.shape)

        return params
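

# A minimal smoke test of the layer above (an illustrative sketch, not part of
# the original module). The tensor shapes and the "alpha" entry are
# assumptions chosen to exercise from_state_dict_values() and get_weight() on
# a 2D Linear-style weight; real DoRA state dicts may store "dora_scale" with
# a different shape.
if __name__ == "__main__":
    values = {
        "lora_up.weight": torch.randn(64, 8),
        "lora_down.weight": torch.randn(8, 32),
        # Assumed shape: anything broadcastable against the [1, in_features]
        # column-norm computed in get_weight() would work here.
        "dora_scale": torch.randn(1, 32),
        "alpha": torch.tensor(8.0),
    }
    layer = DoRALayer.from_state_dict_values(values)
    delta = layer.get_weight(torch.randn(64, 32))
    assert delta.shape == (64, 32)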