from dataclasses import dataclass
from typing import Sequence

import torch

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.param_shape_utils import get_param_shape


@dataclass
class Range:
    start: int
    end: int


class MergedLayerPatch(BaseLayerPatch):
    """A patch layer that is composed of multiple sub-layers merged together.

    This class was created to handle a special case with FLUX LoRA models. In the BFL FLUX model format, the attention
    Q, K, V matrices are concatenated along the first dimension. In the diffusers LoRA format, the Q, K, V matrices are
    stored as separate tensors. This class enables diffusers LoRA layers to be used in BFL FLUX models.
    """

    def __init__(
        self,
        lora_layers: Sequence[BaseLayerPatch],
        ranges: Sequence[Range],
    ):
        super().__init__()

        self.lora_layers = lora_layers
        # self.ranges[i] is the range for the i'th LoRA layer along the 0'th weight dimension.
        self.ranges = ranges
        assert len(self.ranges) == len(self.lora_layers)

    def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
        out_parameters: dict[str, torch.Tensor] = {}

        for lora_layer, layer_range in zip(self.lora_layers, self.ranges, strict=True):
            sliced_parameters: dict[str, torch.Tensor] = {
                n: p[layer_range.start : layer_range.end] for n, p in orig_parameters.items()
            }

            # Note that `weight` is applied in the sub-layers, so there is no need to apply it in this function.
            layer_out_parameters = lora_layer.get_parameters(sliced_parameters, weight)

            for out_param_name, out_param in layer_out_parameters.items():
                if out_param_name not in out_parameters:
                    # If not already in the output dict, initialize an output tensor with the same shape as the full
                    # original parameter.
                    out_parameters[out_param_name] = torch.zeros(
                        get_param_shape(orig_parameters[out_param_name]),
                        dtype=out_param.dtype,
                        device=out_param.device,
                    )
                out_parameters[out_param_name][layer_range.start : layer_range.end] += out_param

        return out_parameters
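
    # For a concrete (hypothetical) picture: if the merged `weight` parameter has shape
    # (9216, 3072) and ranges == [Range(0, 3072), Range(3072, 6144), Range(6144, 9216)],
    # then each sub-layer receives `orig_parameters` sliced to shape (3072, 3072) along dim 0,
    # and its output is accumulated into the matching rows of a zero-initialized
    # (9216, 3072) output tensor.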

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        for lora_layer in self.lora_layers:
            lora_layer.to(device=device, dtype=dtype)

    def calc_size(self) -> int:
        return sum(lora_layer.calc_size() for lora_layer in self.lora_layers)