Use a FluxControlLoRALayer when loading FLUX control LoRAs.

Ryan Dick
2024-12-14 01:18:18 +00:00
parent 37e3089457
commit 80f64abd1e
2 changed files with 12 additions and 2 deletions


@@ -4,6 +4,7 @@ from typing import Any, Dict
 import torch
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
+from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
 from invokeai.backend.patches.layers.lora_layer import LoRALayer
 from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
@@ -58,7 +59,16 @@ def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor])
     layers: dict[str, BaseLayerPatch] = {}
     for layer_key, layer_state_dict in grouped_state_dict.items():
         prefixed_key = f"{FLUX_LORA_TRANSFORMER_PREFIX}{layer_key}"
-        if all(k in layer_state_dict for k in ["lora_A.weight", "lora_B.bias", "lora_B.weight"]):
+        if layer_key == "img_in":
+            # img_in is a special case because it changes the shape of the original weight.
+            layers[prefixed_key] = FluxControlLoRALayer(
+                layer_state_dict["lora_B.weight"],
+                None,
+                layer_state_dict["lora_A.weight"],
+                None,
+                layer_state_dict["lora_B.bias"],
+            )
+        elif all(k in layer_state_dict for k in ["lora_A.weight", "lora_B.bias", "lora_B.weight"]):
             layers[prefixed_key] = LoRALayer(
                 layer_state_dict["lora_B.weight"],
                 None,
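
For background on the new img_in branch above: a FLUX control LoRA widens img_in's input dimension, so the fused lora_B @ lora_A delta no longer matches the base weight's shape and a plain LoRALayer cannot simply be added onto it. Below is a minimal plain-PyTorch sketch of that shape mismatch, not the FluxControlLoRALayer implementation itself; the concrete sizes (in_features 64 widened to 128, out_features 3072, rank 16) are illustrative assumptions that do not appear in this diff.

import torch

# Assumed shapes (illustrative only): FLUX's img_in is Linear(64, 3072); the control
# LoRA's lora_A expects a 128-wide input because a control latent is concatenated.
rank = 16
base_weight = torch.zeros(3072, 64)      # original img_in weight
lora_A = torch.randn(rank, 128)          # down-projection, widened input
lora_B = torch.randn(3072, rank)         # up-projection
lora_B_bias = torch.randn(3072)          # bias shipped with the control LoRA (last constructor argument above)

delta = lora_B @ lora_A                  # (3072, 128): wider than base_weight
assert delta.shape != base_weight.shape  # a plain LoRALayer patch would not line up

# One way to reconcile the shapes: zero-pad the base weight's input dimension before
# adding the delta. This mirrors the idea behind FluxControlLoRALayer, not its exact code.
padded_weight = torch.zeros_like(delta)
padded_weight[:, : base_weight.shape[1]] = base_weight
patched_weight = padded_weight + delta   # (3072, 128), ready for the widened input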


@@ -40,7 +40,7 @@ class LinearSidecarWrapper(BaseSidecarWrapper):
         # NOTE: We slice the input to match the original weight shape in order to work with FluxControlLoRAs, which
         # change the linear layer's in_features.
         orig_input = input
-        input = orig_input[..., : self.orig_module.weight.shape[1]]
+        input = orig_input[..., : self.orig_module.in_features]
         output = self.orig_module(input)
         # Then, apply layers for which we have optimized implementations.
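
To see what this slicing protects against: with a FLUX control LoRA applied, the activation flowing into img_in is wider than what the unpatched base Linear accepts, so the wrapper feeds the base module only its first in_features columns while the LoRA sidecar layers consume the full input. A minimal sketch of just that step, assuming the same illustrative 64 to 128 widening as above (the batch and sequence sizes are likewise made up):

import torch

orig_module = torch.nn.Linear(64, 3072)  # stands in for the unpatched img_in
x = torch.randn(1, 4096, 128)            # widened activation produced upstream

# Slice to the base module's input width before calling it, as the wrapper does.
base_out = orig_module(x[..., : orig_module.in_features])  # shape (1, 4096, 3072)
# The LoRA sidecar layers would then consume the full 128-wide x and add their contribution.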