Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-16 16:07:54 -05:00)

Compare commits: main...revert-798 (1 commit)
| Author | SHA1 | Date |
|---|---|---|
| | 2576bb00ab | |
```diff
@@ -13,12 +13,6 @@ from invokeai.backend.patches.layers.lora_layer import LoRALayer
 
 def linear_lora_forward(input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
     """An optimized implementation of the residual calculation for a sidecar linear LoRALayer."""
-    # up matrix and down matrix have different ranks so we can't simply multiply them
-    if lora_layer.up.shape[1] != lora_layer.down.shape[0]:
-        x = torch.nn.functional.linear(input, lora_layer.get_weight(lora_weight), bias=lora_layer.bias)
-        x *= lora_weight * lora_layer.scale()
-        return x
-
     x = torch.nn.functional.linear(input, lora_layer.down)
     if lora_layer.mid is not None:
         x = torch.nn.functional.linear(x, lora_layer.mid)
```
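For context on the removed branch: the sidecar fast path computes the LoRA residual as two small matmuls (down, then up), which is only valid when the two matrices' inner dimensions agree. A minimal sketch of that equivalence; the shapes and tensor names below are illustrative assumptions, not values from the repository:

```python
import torch

# Illustrative shapes: a rank-4 LoRA on a 16 -> 32 linear layer.
out_features, in_features, rank = 32, 16, 4
down = torch.randn(rank, in_features)  # "down" projects input -> rank
up = torch.randn(out_features, rank)   # "up" projects rank -> output
x = torch.randn(2, in_features)        # batch of 2 inputs

# Factored form used by the sidecar fast path: two small matmuls.
residual_factored = torch.nn.functional.linear(torch.nn.functional.linear(x, down), up)

# Equivalent fused form: materialize the full (out, in) weight delta first.
residual_fused = torch.nn.functional.linear(x, up @ down)

# The two forms agree whenever up.shape[1] == down.shape[0] (equal ranks).
assert torch.allclose(residual_factored, residual_fused, atol=1e-5)
```

When the ranks disagree (the case the deleted branch guarded), `up @ down` is undefined, so that branch fell back to materializing a fused weight via `get_weight` before applying it.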
```diff
@@ -19,7 +19,6 @@ class LoRALayer(LoRALayerBase):
         self.up = up
         self.mid = mid
         self.down = down
-        self.are_ranks_equal = up.shape[1] == down.shape[0]
 
     @classmethod
     def from_state_dict_values(
```
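The deleted `are_ranks_equal` attribute cached the shape-compatibility check used by the forward and `get_weight` paths. A small illustration, with shapes assumed for the example:

```python
import torch

# For a standard LoRA pair, up is (out_features, rank) and down is (rank, in_features),
# so up @ down is well-defined exactly when the inner dimensions match.
up = torch.randn(32, 4)
down = torch.randn(4, 16)
print(up.shape[1] == down.shape[0])  # True: ranks match, a plain matmul works

# A checkpoint with fused projections can yield a "down" of a larger rank,
# e.g. 3x the rank of "up" for a fused KQV projection.
down_fused = torch.randn(12, 16)
print(up.shape[1] == down_fused.shape[0])  # False: the (now removed) fusion path was needed
```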
```diff
@@ -59,42 +58,12 @@ class LoRALayer(LoRALayerBase):
     def _rank(self) -> int:
         return self.down.shape[0]
 
-    def fuse_weights(self, up: torch.Tensor, down: torch.Tensor) -> torch.Tensor:
-        """
-        Fuse the weights of the up and down matrices of a LoRA layer with different ranks.
-
-        Since the Huggingface implementation of KQV projections are fused, when we convert to Kohya format
-        the LoRA weights have different ranks. This function handles the fusion of these differently sized
-        matrices.
-        """
-
-        fused_lora = torch.zeros((up.shape[0], down.shape[1]), device=down.device, dtype=down.dtype)
-        rank_diff = down.shape[0] / up.shape[1]
-
-        if rank_diff > 1:
-            rank_diff = down.shape[0] / up.shape[1]
-            w_down = down.chunk(int(rank_diff), dim=0)
-            for w_down_chunk in w_down:
-                fused_lora = fused_lora + (torch.mm(up, w_down_chunk))
-        else:
-            rank_diff = up.shape[1] / down.shape[0]
-            w_up = up.chunk(int(rank_diff), dim=0)
-            for w_up_chunk in w_up:
-                fused_lora = fused_lora + (torch.mm(w_up_chunk, down))
-
-        return fused_lora
-
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.mid is not None:
             up = self.up.reshape(self.up.shape[0], self.up.shape[1])
             down = self.down.reshape(self.down.shape[0], self.down.shape[1])
             weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
         else:
-            # up matrix and down matrix have different ranks so we can't simply multiply them
-            if not self.are_ranks_equal:
-                weight = self.fuse_weights(self.up, self.down)
-                return weight
-
             weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
 
         return weight
```
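To make the removed `fuse_weights` logic concrete: when `down`'s rank is an integer multiple of `up`'s (as with fused KQV projections converted to Kohya format), `down` is split into same-rank chunks along dim 0 and the per-chunk products are summed into a full-size weight delta. A standalone sketch of the `rank_diff > 1` branch, with made-up shapes:

```python
import torch

# Illustrative shapes only: up has rank 4, down has rank 12 (e.g. a fused 3-way projection).
out_features, in_features = 32, 16
up = torch.randn(out_features, 4)
down = torch.randn(12, in_features)

# down's rank is 3x up's, so down is split into rank-4 chunks along dim 0;
# each chunk then has a compatible inner dimension, and the partial products are summed.
rank_diff = down.shape[0] // up.shape[1]         # 3
fused = torch.zeros(out_features, in_features, device=down.device, dtype=down.dtype)
for down_chunk in down.chunk(rank_diff, dim=0):  # three (4, 16) chunks
    fused += up @ down_chunk                     # each product is (32, 16)

print(fused.shape)  # torch.Size([32, 16]) -- a full-size weight delta
```

The `else` branch was meant to mirror this by chunking `up` instead; note that, as captured, it chunks `up` along dim 0, which splits the output dimension rather than the rank, so `torch.mm(w_up_chunk, down)` would still face mismatched inner dimensions in that case.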
```diff
@@ -20,14 +20,6 @@ from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
 FLUX_KOHYA_TRANSFORMER_KEY_REGEX = (
     r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
 )
-
-# A regex pattern that matches all of the last layer keys in the Kohya FLUX LoRA format.
-# Example keys:
-#   lora_unet_final_layer_linear.alpha
-#   lora_unet_final_layer_linear.lora_down.weight
-#   lora_unet_final_layer_linear.lora_up.weight
-FLUX_KOHYA_LAST_LAYER_KEY_REGEX = r"lora_unet_final_layer_(linear|linear1|linear2)_?(.*)"
-
 # A regex pattern that matches all of the CLIP keys in the Kohya FLUX LoRA format.
 # Example keys:
 #   lora_te1_text_model_encoder_layers_0_mlp_fc1.alpha
```
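A quick demonstration of what the deleted pattern matched, using the example keys from the removed comment:

```python
import re

FLUX_KOHYA_LAST_LAYER_KEY_REGEX = r"lora_unet_final_layer_(linear|linear1|linear2)_?(.*)"

for key in [
    "lora_unet_final_layer_linear.alpha",
    "lora_unet_final_layer_linear.lora_down.weight",
    "lora_unet_final_layer_linear.lora_up.weight",
]:
    m = re.match(FLUX_KOHYA_LAST_LAYER_KEY_REGEX, key)
    # Prints the captured (layer, suffix) groups, e.g. ('linear', '.alpha').
    print(key, "->", m.groups() if m else None)
```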
```diff
@@ -52,7 +44,6 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
     """
     return all(
         re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k)
-        or re.match(FLUX_KOHYA_LAST_LAYER_KEY_REGEX, k)
         or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
         or re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
         for k in state_dict.keys()
```
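With the `FLUX_KOHYA_LAST_LAYER_KEY_REGEX` clause removed, a state dict containing final-layer keys no longer passes the every-key-must-match detection. A toy illustration of the `all(...)` idiom, using only the transformer pattern captured above; the helper name and sample state dicts are hypothetical:

```python
import re
from typing import Any, Dict

FLUX_KOHYA_TRANSFORMER_KEY_REGEX = (
    r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
)

def looks_like_flux_kohya(state_dict: Dict[str, Any]) -> bool:
    # Simplified detector: every key must match at least one known pattern.
    return all(re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k) for k in state_dict)

print(looks_like_flux_kohya({"lora_unet_double_blocks_0_img_attn_proj.alpha": 1.0}))  # True
print(looks_like_flux_kohya({"lora_unet_final_layer_linear.alpha": 1.0}))             # False once the pattern is gone
```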
```diff
@@ -74,9 +65,6 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
     t5_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
     for layer_name, layer_state_dict in grouped_state_dict.items():
         if layer_name.startswith("lora_unet"):
-            # Skip the final layer. This is incompatible with current model definition.
-            if layer_name.startswith("lora_unet_final_layer"):
-                continue
             transformer_grouped_sd[layer_name] = layer_state_dict
         elif layer_name.startswith("lora_te1"):
             clip_grouped_sd[layer_name] = layer_state_dict
```
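The surrounding loop partitions grouped per-layer dicts by key prefix. A toy version of that routing, showing where a final-layer key lands once the skip is removed; the sample layer names and empty dicts are placeholders standing in for `grouped_state_dict` in the captured code:

```python
# Hypothetical grouped layer names (values would be per-layer tensor dicts).
grouped_state_dict: dict[str, dict] = {
    "lora_unet_double_blocks_0_img_attn_proj": {},
    "lora_unet_final_layer_linear": {},
    "lora_te1_text_model_encoder_layers_0_mlp_fc1": {},
}

transformer_grouped_sd: dict[str, dict] = {}
clip_grouped_sd: dict[str, dict] = {}

for layer_name, layer_state_dict in grouped_state_dict.items():
    if layer_name.startswith("lora_unet"):
        # Without the skip, "lora_unet_final_layer_*" falls into the transformer group.
        # After the revert, such checkpoints fail format detection upstream,
        # so this branch should no longer see final-layer keys at all.
        transformer_grouped_sd[layer_name] = layer_state_dict
    elif layer_name.startswith("lora_te1"):
        clip_grouped_sd[layer_name] = layer_state_dict

print(sorted(transformer_grouped_sd))
print(sorted(clip_grouped_sd))
```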