Compare commits

...

1 Commits

Author SHA1 Message Date
psychedelicious
2576bb00ab Revert "Implementing support for Non-Standard LoRA Format (#7985)"
This reverts commit 1f63b60021.
2025-05-06 10:53:41 +10:00
3 changed files with 0 additions and 49 deletions

View File

@@ -13,12 +13,6 @@ from invokeai.backend.patches.layers.lora_layer import LoRALayer
def linear_lora_forward(input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
"""An optimized implementation of the residual calculation for a sidecar linear LoRALayer."""
# up matrix and down matrix have different ranks so we can't simply multiply them
if lora_layer.up.shape[1] != lora_layer.down.shape[0]:
x = torch.nn.functional.linear(input, lora_layer.get_weight(lora_weight), bias=lora_layer.bias)
x *= lora_weight * lora_layer.scale()
return x
x = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x = torch.nn.functional.linear(x, lora_layer.mid)

View File

@@ -19,7 +19,6 @@ class LoRALayer(LoRALayerBase):
self.up = up
self.mid = mid
self.down = down
self.are_ranks_equal = up.shape[1] == down.shape[0]
@classmethod
def from_state_dict_values(
@@ -59,42 +58,12 @@ class LoRALayer(LoRALayerBase):
def _rank(self) -> int:
return self.down.shape[0]
def fuse_weights(self, up: torch.Tensor, down: torch.Tensor) -> torch.Tensor:
"""
Fuse the weights of the up and down matrices of a LoRA layer with different ranks.
Since the Huggingface implementation of KQV projections are fused, when we convert to Kohya format
the LoRA weights have different ranks. This function handles the fusion of these differently sized
matrices.
"""
fused_lora = torch.zeros((up.shape[0], down.shape[1]), device=down.device, dtype=down.dtype)
rank_diff = down.shape[0] / up.shape[1]
if rank_diff > 1:
rank_diff = down.shape[0] / up.shape[1]
w_down = down.chunk(int(rank_diff), dim=0)
for w_down_chunk in w_down:
fused_lora = fused_lora + (torch.mm(up, w_down_chunk))
else:
rank_diff = up.shape[1] / down.shape[0]
w_up = up.chunk(int(rank_diff), dim=0)
for w_up_chunk in w_up:
fused_lora = fused_lora + (torch.mm(w_up_chunk, down))
return fused_lora
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
else:
# up matrix and down matrix have different ranks so we can't simply multiply them
if not self.are_ranks_equal:
weight = self.fuse_weights(self.up, self.down)
return weight
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
return weight

View File

@@ -20,14 +20,6 @@ from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
FLUX_KOHYA_TRANSFORMER_KEY_REGEX = (
r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
)
# A regex pattern that matches all of the last layer keys in the Kohya FLUX LoRA format.
# Example keys:
# lora_unet_final_layer_linear.alpha
# lora_unet_final_layer_linear.lora_down.weight
# lora_unet_final_layer_linear.lora_up.weight
FLUX_KOHYA_LAST_LAYER_KEY_REGEX = r"lora_unet_final_layer_(linear|linear1|linear2)_?(.*)"
# A regex pattern that matches all of the CLIP keys in the Kohya FLUX LoRA format.
# Example keys:
# lora_te1_text_model_encoder_layers_0_mlp_fc1.alpha
@@ -52,7 +44,6 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo
"""
return all(
re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k)
or re.match(FLUX_KOHYA_LAST_LAYER_KEY_REGEX, k)
or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
or re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
for k in state_dict.keys()
@@ -74,9 +65,6 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
t5_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
for layer_name, layer_state_dict in grouped_state_dict.items():
if layer_name.startswith("lora_unet"):
# Skip the final layer. This is incompatible with current model definition.
if layer_name.startswith("lora_unet_final_layer"):
continue
transformer_grouped_sd[layer_name] = layer_state_dict
elif layer_name.startswith("lora_te1"):
clip_grouped_sd[layer_name] = layer_state_dict