Update kohya_lora_manager.py

Bias parsing, fix LoHa parsing and weight calculation
Author: Sergey Borisov
Date:   2023-04-06 01:44:20 +03:00
Parent: b62cce20b8
Commit: baf60948ee


@@ -21,6 +21,7 @@ class LoRALayer:
lora_name: str
name: str
scale: float
up: torch.nn.Module
mid: Optional[torch.nn.Module] = None
down: torch.nn.Module
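For context: a LoRALayer holds the low-rank factors of the weight delta for one wrapped module, with mid typically only present for LoCon-style conv layers. Roughly how up/down combine for a plain linear layer (a generic LoRA sketch, not code from this file):

    import torch

    # hypothetical 768 -> 768 linear layer patched at rank 4
    rank, d_out, d_in, alpha = 4, 768, 768, 4.0
    down = torch.randn(rank, d_in)   # "lora_down.weight"
    up = torch.randn(d_out, rank)    # "lora_up.weight"

    # delta added on top of the frozen weight; the scale is conventionally alpha / rank
    delta_w = (up @ down) * (alpha / rank)
    assert delta_w.shape == (d_out, d_in)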
@@ -54,6 +55,7 @@ class LoHALayer:
w2_b: torch.Tensor
t1: Optional[torch.Tensor] = None
t2: Optional[torch.Tensor] = None
bias: Optional[torch.Tensor] = None
org_module: torch.nn.Module
@@ -77,51 +79,21 @@ class LoHALayer:
op = torch.nn.functional.linear
extra_args = {}
# implementation according to lycoris
# i'm not so sure what happens here, but to properly work
# i moved scaling from weight calculation to output calculation
# https://github.com/KohakuBlueleaf/LyCORIS/blob/main/lycoris/loha.py#L175
if self.t1 is None:
diff_weight = ((self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b))
weight = self.org_module.weight.data.reshape(diff_weight.shape) + diff_weight
weight = ((self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b))
else:
rebuild1 = torch.einsum('i j k l, j r, i p -> p r k l', self.t1, self.w1_b, self.w1_a)
rebuild2 = torch.einsum('i j k l, j r, i p -> p r k l', self.t2, self.w2_b, self.w2_a)
weight = self.org_module.weight.data + rebuild1 * rebuild2
weight = rebuild1 * rebuild2
bias = None if self.org_module.bias is None else self.org_module.bias.data
bias = self.bias if self.bias is not None else 0
return output + op(
*input_h,
weight.view(self.org_module.weight.shape),
bias,
(weight + bias).view(self.org_module.weight.shape),
None,
**extra_args,
) * lora.multiplier * self.scale
# implementation according to a1111-sd-webui-locon extension
# https://github.com/KohakuBlueleaf/a1111-sd-webui-locon/blob/main/scripts/main.py#L248
def pro3(t, wa, wb):
temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
return torch.einsum('i j k l, i r -> r j k l', temp, wa)
bias = 0 # TODO: implement bias
if self.t1 is None:
return output + op(
*input_h,
((self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) + bias).view(self.org_module.weight.shape),
bias=None,
**extra_args
) * lora.multiplier * self.scale
else:
return output + op(
*input_h,
(pro3(self.t1, self.w1_a, self.w1_b)
* pro3(self.t2, self.w2_a, self.w2_b) + bias).view(self.org_module.weight.shape),
bias=None,
**extra_args
) * lora.multiplier * self.scale
class LoRAModuleWrapper:
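The weight-calculation fix in this hunk: since output already carries the wrapped module's own result, the old code that folded self.org_module.weight.data into the LoHa weight double-counted the original weights. The new code passes only the delta (plus the parsed bias) through op. A self-contained sketch of the delta the new branch computes, assuming LyCORIS-style LoHa shapes:

    import torch

    def loha_delta(w1_a, w1_b, w2_a, w2_b, t1=None, t2=None):
        # LoHa: elementwise (Hadamard) product of two low-rank reconstructions
        if t1 is None:
            return (w1_a @ w1_b) * (w2_a @ w2_b)
        # tucker path for conv kernels, mirroring the einsums in the hunk above
        rebuild1 = torch.einsum('i j k l, j r, i p -> p r k l', t1, w1_b, w1_a)
        rebuild2 = torch.einsum('i j k l, j r, i p -> p r k l', t2, w2_b, w2_a)
        return rebuild1 * rebuild2

The result is reshaped to org_module.weight.shape and applied through op(...) * lora.multiplier * self.scale, as the return statement above shows.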
@@ -223,8 +195,6 @@ class LoRA:
self.device = device
self.dtype = dtype
self.wrapper = wrapper
self.rank = None
self.alpha = None
def load_from_dict(self, state_dict):
state_dict_groupped = dict()
@@ -235,22 +205,6 @@ class LoRA:
state_dict_groupped[stem] = dict()
state_dict_groupped[stem][leaf] = value
if leaf.endswith("alpha"):
if self.alpha is None:
self.alpha = value.item()
continue
if (
stem.startswith(self.wrapper.LORA_PREFIX_TEXT_ENCODER)
or stem.startswith(self.wrapper.LORA_PREFIX_UNET)
):
if (
self.rank is None
and leaf == "lora_down.weight"
and len(value.size()) == 2
):
self.rank = value.shape[0]
for stem, values in state_dict_groupped.items():
if stem.startswith(self.wrapper.LORA_PREFIX_TEXT_ENCODER):
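The removed block above was the old module-wide rank/alpha bookkeeping; both values are now read per layer further down. Before that loop, the flat checkpoint keys are regrouped by module stem, roughly like this (a simplified illustration, not the exact splitting code):

    state_dict_groupped = {}
    for key, value in state_dict.items():
        # e.g. "lora_unet_..._proj_in.lora_down.weight" -> stem + leaf
        stem, leaf = key.split(".", 1)
        state_dict_groupped.setdefault(stem, {})[leaf] = value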
@@ -264,6 +218,22 @@ class LoRA:
print(f">> Missing layer: {stem}")
continue
# TODO: diff key
bias = None
alpha = None
if "alpha" in values:
alpha = values["alpha"].item()
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
bias = torch.sparse_coo_tensor(
values["bias_indices"],
values["bias_values"],
tuple(values["bias_size"]),
).to(device=self.device, dtype=self.dtype)
# lora and locon
if "lora_down.weight" in values:
value_down = values["lora_down.weight"]
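New in this hunk: alpha and the optional sparse bias are parsed per layer before the per-type branches. The bias reconstruction is a plain torch.sparse_coo_tensor call; a standalone sketch with made-up shapes:

    import torch

    # hypothetical sparse bias stored as indices / values / size in the checkpoint
    bias_indices = torch.tensor([[0, 0], [1, 3]])   # (ndim, nnz)
    bias_values = torch.tensor([0.5, -0.25])
    bias_size = (1, 4)

    bias = torch.sparse_coo_tensor(bias_indices, bias_values, bias_size)
    bias = bias.to(device="cpu", dtype=torch.float32)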
@@ -305,12 +275,10 @@ class LoRA:
layer_up.to(device=self.device, dtype=self.dtype)
alpha = None
if "alpha" in values:
alpha = values["alpha"].item()
rank = value_down.shape[0]
layer = LoRALayer(self.name, stem, self.rank, alpha)
layer = LoRALayer(self.name, stem, rank, alpha)
#layer.bias = bias # TODO: find and debug lora/locon with bias
layer.down = layer_down
layer.mid = layer_mid
layer.up = layer_up
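With the module-wide fields gone, rank now comes from this layer's own lora_down.weight and alpha from the per-layer parsing above. Assuming LoRALayer turns them into the usual alpha / rank scale (a hedged reading, not confirmed by this diff), the construction amounts to:

    # illustration: per-layer rank/alpha instead of module-wide values
    value_down = values["lora_down.weight"]            # (rank, in_features) for a linear layer
    rank = value_down.shape[0]
    alpha = values["alpha"].item() if "alpha" in values else None
    scale = alpha / rank if alpha is not None else 1.0  # conventional LoRA scaling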
@@ -318,26 +286,25 @@ class LoRA:
# loha
elif "hada_w1_b" in values:
alpha = None
if "alpha" in values:
alpha = values["alpha"].item()
rank = values["hada_w1_b"].shape[0]
layer = LoHALayer(self.name, stem, rank, alpha)
layer.org_module = wrapped
layer.bias = bias
layer.w1_a = values["hada_w1_a"].to(device=self.device, dtype=self.dtype).requires_grad_(False)
layer.w1_b = values["hada_w1_b"].to(device=self.device, dtype=self.dtype).requires_grad_(False)
layer.w2_a = values["hada_w2_a"].to(device=self.device, dtype=self.dtype).requires_grad_(False)
layer.w2_b = values["hada_w2_b"].to(device=self.device, dtype=self.dtype).requires_grad_(False)
if type(wrapped) == torch.nn.Conv2d and wrapped.kernel_size != (1, 1):
layer.t1 = values["hada_t1"].to(device=self.device, dtype=self.dtype).requires_grad_(False)
layer.t2 = values["hada_t2"].to(device=self.device, dtype=self.dtype).requires_grad_(False)
layer.w1_a = values["hada_w1_a"].to(device=self.device, dtype=self.dtype)
layer.w1_b = values["hada_w1_b"].to(device=self.device, dtype=self.dtype)
layer.w2_a = values["hada_w2_a"].to(device=self.device, dtype=self.dtype)
layer.w2_b = values["hada_w2_b"].to(device=self.device, dtype=self.dtype)
if "hada_t1" in values:
layer.t1 = values["hada_t1"].to(device=self.device, dtype=self.dtype)
else:
layer.t1 = None
if "hada_t2" in values:
layer.t2 = values["hada_t2"].to(device=self.device, dtype=self.dtype)
else:
layer.t2 = None
else:
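For the LoHa branch, loading now checks for hada_t1/hada_t2 in the checkpoint instead of inspecting the wrapped conv's kernel size, and drops the explicit requires_grad_(False) calls. Condensed, the new parsing reads roughly as:

    # condensed sketch of the new LoHa parsing path
    rank = values["hada_w1_b"].shape[0]
    layer = LoHALayer(self.name, stem, rank, alpha)
    layer.org_module = wrapped
    layer.bias = bias   # sparse bias parsed earlier, or None

    def to_dev(t):
        return t.to(device=self.device, dtype=self.dtype)

    layer.w1_a, layer.w1_b = to_dev(values["hada_w1_a"]), to_dev(values["hada_w1_b"])
    layer.w2_a, layer.w2_b = to_dev(values["hada_w2_a"]), to_dev(values["hada_w2_b"])
    layer.t1 = to_dev(values["hada_t1"]) if "hada_t1" in values else None
    layer.t2 = to_dev(values["hada_t2"]) if "hada_t2" in values else None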