mirror of https://github.com/invoke-ai/InvokeAI.git
Update kohya_lora_manager.py
Bias parsing, fix LoHa parsing and weight calculation
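
For context: after this change, LoHALayer adds only its low-rank delta to the wrapped module's output, folding the newly parsed bias into that delta and applying multiplier and scale at the output instead of inside the weight. A minimal standalone sketch of the fixed computation (illustrative helper names, linear case only; shapes are assumptions, not from the commit):

import torch

# Hypothetical factors: w1_a/w2_a are (out, rank), w1_b/w2_b are (rank, in).
def loha_delta(w1_a, w1_b, w2_a, w2_b, bias=None):
    # LoHA composes two low-rank products with an elementwise (Hadamard) product.
    weight = (w1_a @ w1_b) * (w2_a @ w2_b)
    if bias is not None:
        weight = weight + bias  # the parsed sparse bias shares the weight's shape
    return weight

def apply_loha(x, org_module, delta, multiplier, scale):
    # The wrapped module's own output is left untouched; only the delta is scaled.
    return org_module(x) + torch.nn.functional.linear(x, delta) * multiplier * scale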
@@ -21,6 +21,7 @@ class LoRALayer:
     lora_name: str
     name: str
     scale: float

     up: torch.nn.Module
+    mid: Optional[torch.nn.Module] = None
     down: torch.nn.Module
@@ -54,6 +55,7 @@ class LoHALayer:
     w2_b: torch.Tensor
     t1: Optional[torch.Tensor] = None
     t2: Optional[torch.Tensor] = None
+    bias: Optional[torch.Tensor] = None

     org_module: torch.nn.Module

@@ -77,51 +79,21 @@ class LoHALayer:
             op = torch.nn.functional.linear
             extra_args = {}

+        # implementation according to lycoris
+        # i'm not so sure what happens here, but to properly work
+        # i moved scaling from weight calculation to output calculation
+        # https://github.com/KohakuBlueleaf/LyCORIS/blob/main/lycoris/loha.py#L175
         if self.t1 is None:
-            diff_weight = ((self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b))
-            weight = self.org_module.weight.data.reshape(diff_weight.shape) + diff_weight
+            weight = ((self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b))

         else:
             rebuild1 = torch.einsum('i j k l, j r, i p -> p r k l', self.t1, self.w1_b, self.w1_a)
             rebuild2 = torch.einsum('i j k l, j r, i p -> p r k l', self.t2, self.w2_b, self.w2_a)
-            weight = self.org_module.weight.data + rebuild1 * rebuild2
+            weight = rebuild1 * rebuild2

-        bias = None if self.org_module.bias is None else self.org_module.bias.data
+        bias = self.bias if self.bias is not None else 0
         return output + op(
             *input_h,
-            weight.view(self.org_module.weight.shape),
-            bias,
+            (weight + bias).view(self.org_module.weight.shape),
+            None,
             **extra_args,
         ) * lora.multiplier * self.scale
-
-
-        # implementation according to a1111-sd-webui-locon extension
-        # https://github.com/KohakuBlueleaf/a1111-sd-webui-locon/blob/main/scripts/main.py#L248
-        def pro3(t, wa, wb):
-            temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
-            return torch.einsum('i j k l, i r -> r j k l', temp, wa)
-
-        bias = 0 # TODO: implement bias
-        if self.t1 is None:
-            return output + op(
-                *input_h,
-                ((self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) + bias).view(self.org_module.weight.shape),
-                bias=None,
-                **extra_args
-            ) * lora.multiplier * self.scale
-
-        else:
-            return output + op(
-                *input_h,
-                (pro3(self.t1, self.w1_a, self.w1_b)
-                 * pro3(self.t2, self.w2_a, self.w2_b) + bias).view(self.org_module.weight.shape),
-                bias=None,
-                **extra_args
-            ) * lora.multiplier * self.scale


 class LoRAModuleWrapper:
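The conv branch above rebuilds a full kernel from a Tucker-style core (t1/t2) and two factor matrices. A quick shape check of that einsum with illustrative sizes (the values are made up; only the contraction string matches the diff):

import torch

# Illustrative shapes: rank=4, out_ch=8, in_ch=16, 3x3 kernel.
rank, out_ch, in_ch, k = 4, 8, 16, 3
t1 = torch.randn(rank, rank, k, k)   # core tensor
w1_b = torch.randn(rank, in_ch)      # 'j r' factor
w1_a = torch.randn(rank, out_ch)     # 'i p' factor

# Same contraction as the diff: 'i j k l, j r, i p -> p r k l'
rebuild1 = torch.einsum('i j k l, j r, i p -> p r k l', t1, w1_b, w1_a)
print(rebuild1.shape)  # torch.Size([8, 16, 3, 3]) -- a full conv weight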
@@ -223,8 +195,6 @@ class LoRA:
         self.device = device
         self.dtype = dtype
         self.wrapper = wrapper
-        self.rank = None
-        self.alpha = None

     def load_from_dict(self, state_dict):
         state_dict_groupped = dict()
@@ -235,22 +205,6 @@ class LoRA:
                 state_dict_groupped[stem] = dict()
             state_dict_groupped[stem][leaf] = value

-            if leaf.endswith("alpha"):
-                if self.alpha is None:
-                    self.alpha = value.item()
-                continue
-
-            if (
-                stem.startswith(self.wrapper.LORA_PREFIX_TEXT_ENCODER)
-                or stem.startswith(self.wrapper.LORA_PREFIX_UNET)
-            ):
-                if (
-                    self.rank is None
-                    and leaf == "lora_down.weight"
-                    and len(value.size()) == 2
-                ):
-                    self.rank = value.shape[0]
-

         for stem, values in state_dict_groupped.items():
             if stem.startswith(self.wrapper.LORA_PREFIX_TEXT_ENCODER):
@@ -264,6 +218,22 @@ class LoRA:
                 print(f">> Missing layer: {stem}")
                 continue

+            # TODO: diff key
+
+            bias = None
+            alpha = None
+
+            if "alpha" in values:
+                alpha = values["alpha"].item()
+
+            if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
+                bias = torch.sparse_coo_tensor(
+                    values["bias_indices"],
+                    values["bias_values"],
+                    tuple(values["bias_size"]),
+                ).to(device=self.device, dtype=self.dtype)
+
+
             # lora and locon
             if "lora_down.weight" in values:
                 value_down = values["lora_down.weight"]
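The new bias parsing reconstructs a sparse tensor from the three LyCORIS state-dict entries. A small standalone sketch with fabricated entries, showing what torch.sparse_coo_tensor produces here:

import torch

# Hypothetical state-dict entries in the form the parser above expects:
# COO indices, matching values, and the dense shape of the weight.
values = {
    "bias_indices": torch.tensor([[0, 2], [1, 3]]),  # 2 nonzeros in a 2D weight
    "bias_values": torch.tensor([0.5, -0.25]),
    "bias_size": torch.tensor([4, 8]),
}

bias = torch.sparse_coo_tensor(
    values["bias_indices"],
    values["bias_values"],
    tuple(values["bias_size"]),
)
print(bias.to_dense())  # 4x8 tensor, zero except at (0, 1) and (2, 3)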
@@ -305,12 +275,10 @@ class LoRA:
                 layer_up.to(device=self.device, dtype=self.dtype)

-                alpha = None
-                if "alpha" in values:
-                    alpha = values["alpha"].item()
+                rank = value_down.shape[0]

-                layer = LoRALayer(self.name, stem, self.rank, alpha)
+                layer = LoRALayer(self.name, stem, rank, alpha)
+                #layer.bias = bias # TODO: find and debug lora/locon with bias
                 layer.down = layer_down
                 layer.mid = layer_mid
                 layer.up = layer_up
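Replacing the file-wide self.rank with a per-layer rank = value_down.shape[0] matters because different layers in one file can be trained at different ranks (common in LoCon). Assuming the layer constructors derive the usual kohya-style scale of alpha / rank (the constructor body is not shown in this diff), a sketch of why a single shared rank mis-scales layers:

# Assumed scaling rule (standard in kohya-style LoRA; an assumption here).
def layer_scale(alpha, rank):
    return alpha / rank if alpha is not None else 1.0

# Two layers from the same file with the same alpha but different ranks:
print(layer_scale(4.0, 8))   # 0.5  -- layer trained at rank 8
print(layer_scale(4.0, 16))  # 0.25 -- layer trained at rank 16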
@@ -318,26 +286,25 @@ class LoRA:
             # loha
             elif "hada_w1_b" in values:

-                alpha = None
-                if "alpha" in values:
-                    alpha = values["alpha"].item()
-
                 rank = values["hada_w1_b"].shape[0]

                 layer = LoHALayer(self.name, stem, rank, alpha)
                 layer.org_module = wrapped
+                layer.bias = bias

-                layer.w1_a = values["hada_w1_a"].to(device=self.device, dtype=self.dtype).requires_grad_(False)
-                layer.w1_b = values["hada_w1_b"].to(device=self.device, dtype=self.dtype).requires_grad_(False)
-                layer.w2_a = values["hada_w2_a"].to(device=self.device, dtype=self.dtype).requires_grad_(False)
-                layer.w2_b = values["hada_w2_b"].to(device=self.device, dtype=self.dtype).requires_grad_(False)
-
-                if type(wrapped) == torch.nn.Conv2d and wrapped.kernel_size != (1, 1):
-                    layer.t1 = values["hada_t1"].to(device=self.device, dtype=self.dtype).requires_grad_(False)
-                    layer.t2 = values["hada_t2"].to(device=self.device, dtype=self.dtype).requires_grad_(False)
+                layer.w1_a = values["hada_w1_a"].to(device=self.device, dtype=self.dtype)
+                layer.w1_b = values["hada_w1_b"].to(device=self.device, dtype=self.dtype)
+                layer.w2_a = values["hada_w2_a"].to(device=self.device, dtype=self.dtype)
+                layer.w2_b = values["hada_w2_b"].to(device=self.device, dtype=self.dtype)
+
+                if "hada_t1" in values:
+                    layer.t1 = values["hada_t1"].to(device=self.device, dtype=self.dtype)
+                else:
+                    layer.t1 = None
+
+                if "hada_t2" in values:
+                    layer.t2 = values["hada_t2"].to(device=self.device, dtype=self.dtype)
+                else:
+                    layer.t2 = None

             else: