diff --git a/invokeai/backend/patches/layers/dora_layer.py b/invokeai/backend/patches/layers/dora_layer.py
new file mode 100644
index 0000000000..3dff8876ea
--- /dev/null
+++ b/invokeai/backend/patches/layers/dora_layer.py
@@ -0,0 +1,92 @@
+from typing import Dict, Optional
+
+import torch
+
+from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
+from invokeai.backend.util.calc_tensor_size import calc_tensors_size
+
+
+class DoRALayer(LoRALayerBase):
+    """A DoRA layer. As defined in https://arxiv.org/pdf/2402.09353."""
+
+    def __init__(
+        self,
+        up: torch.Tensor,
+        down: torch.Tensor,
+        dora_scale: torch.Tensor,
+        alpha: float | None,
+        bias: Optional[torch.Tensor],
+    ):
+        super().__init__(alpha, bias)
+        self.up = up
+        self.down = down
+        self.dora_scale = dora_scale
+
+    @classmethod
+    def from_state_dict_values(cls, values: Dict[str, torch.Tensor]):
+        alpha = cls._parse_alpha(values.get("alpha", None))
+        bias = cls._parse_bias(
+            values.get("bias_indices", None), values.get("bias_values", None), values.get("bias_size", None)
+        )
+
+        layer = cls(
+            up=values["lora_up.weight"],
+            down=values["lora_down.weight"],
+            dora_scale=values["dora_scale"],
+            alpha=alpha,
+            bias=bias,
+        )
+
+        cls.warn_on_unhandled_keys(
+            values=values,
+            handled_keys={
+                # Default keys.
+                "alpha",
+                "bias_indices",
+                "bias_values",
+                "bias_size",
+                # Layer-specific keys.
+                "lora_up.weight",
+                "lora_down.weight",
+                "dora_scale",
+            },
+        )
+
+        return layer
+
+    def _rank(self) -> int:
+        return self.down.shape[0]
+
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
+        # Note: Variable names (e.g. delta_v) are based on the paper.
+        delta_v = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
+        delta_v = delta_v.reshape(orig_weight.shape)
+
+        # TODO(ryand): Should alpha be applied to delta_v here rather than the final diff?
+        # TODO(ryand): I expect this to fail if the original weight is BnB Quantized. This class shouldn't have to
+        # worry about that, but we should add a clear error message further up the stack.
+
+        # At this point, out_weight is the unnormalized direction matrix.
+        out_weight = orig_weight + delta_v
+
+        # TODO(ryand): Simplify this logic.
+        direction_norm = (
+            out_weight.transpose(0, 1)
+            .reshape(out_weight.shape[1], -1)
+            .norm(dim=1, keepdim=True)
+            .reshape(out_weight.shape[1], *[1] * (out_weight.dim() - 1))
+            .transpose(0, 1)
+        )
+
+        out_weight *= self.dora_scale / direction_norm
+
+        return out_weight - orig_weight
+
+    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
+        super().to(device=device, dtype=dtype)
+        self.up = self.up.to(device=device, dtype=dtype)
+        self.down = self.down.to(device=device, dtype=dtype)
+        self.dora_scale = self.dora_scale.to(device=device, dtype=dtype)
+
+    def calc_size(self) -> int:
+        return super().calc_size() + calc_tensors_size([self.up, self.down, self.dora_scale])
diff --git a/invokeai/backend/patches/layers/utils.py b/invokeai/backend/patches/layers/utils.py
index 85a0c18dd5..8141a56644 100644
--- a/invokeai/backend/patches/layers/utils.py
+++ b/invokeai/backend/patches/layers/utils.py
@@ -3,6 +3,7 @@ from typing import Dict
 import torch
 
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
+from invokeai.backend.patches.layers.dora_layer import DoRALayer
 from invokeai.backend.patches.layers.full_layer import FullLayer
 from invokeai.backend.patches.layers.ia3_layer import IA3Layer
 from invokeai.backend.patches.layers.loha_layer import LoHALayer
@@ -14,8 +15,9 @@ from invokeai.backend.patches.layers.norm_layer import NormLayer
 def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> BaseLayerPatch:
     # Detect layers according to LyCORIS detection logic(`weight_list_det`)
     # https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
-
-    if "lora_up.weight" in state_dict:
+    if "dora_scale" in state_dict:
+        return DoRALayer.from_state_dict_values(state_dict)
+    elif "lora_up.weight" in state_dict:
         # LoRA a.k.a LoCon
         return LoRALayer.from_state_dict_values(state_dict)
     elif "hada_w1_a" in state_dict:
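
For reference, the transpose/reshape chain in `get_weight()` implements the DoRA merge from the paper: the patched weight is `dora_scale * (W0 + delta_v) / ||W0 + delta_v||_col`, where the norm is taken per column of the weight, and the layer returns the difference from `W0` so it can be applied as an additive patch. Below is a minimal sketch of the equivalent computation for a plain 2D Linear weight; the shapes are illustrative assumptions, and `dora_scale` is assumed to follow the kohya/LyCORIS convention of shape `(1, in_features)`. Alpha/rank scaling (applied elsewhere, per the TODO above) is omitted.

```python
import torch

# Hypothetical shapes for illustration only: a 2D Linear weight of shape
# (out_features, in_features), a rank-4 LoRA pair, and a per-column magnitude
# vector ("dora_scale") of shape (1, in_features).
out_features, in_features, rank = 8, 16, 4
orig_weight = torch.randn(out_features, in_features)
up = torch.randn(out_features, rank)
down = torch.randn(rank, in_features)
dora_scale = torch.randn(1, in_features)

# Unnormalized direction matrix: the base weight plus the low-rank update.
delta_v = up @ down
out_weight = orig_weight + delta_v

# For a 2D weight, the transpose/reshape/norm chain in get_weight() reduces to
# a per-column L2 norm of shape (1, in_features).
direction_norm = out_weight.norm(dim=0, keepdim=True)

# Rescale each column to the learned magnitude, then return the diff from the
# original weight (the patch that is added back onto the base model weight).
patch = out_weight * (dora_scale / direction_norm) - orig_weight
print(patch.shape)  # torch.Size([8, 16])
```

The reshape/transpose gymnastics in the actual implementation presumably exist so that the same per-input-channel norm also works for conv weights with trailing kernel dimensions, which a plain `norm(dim=0)` would not handle.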