Fixes to get FLUX Control LoRA working.

Author: Ryan Dick
Date: 2024-12-12 00:19:39 +00:00
Committed by: Kent Keirsey
Parent: f53da60b84
Commit: 040551d4fb


@@ -6,13 +6,11 @@ import torch
 from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
 from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
 from invokeai.backend.lora.layers.lora_layer import LoRALayer
-from invokeai.backend.lora.layers.full_layer import FullLayer
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.lora.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import (
     ConcatenatedLoRALinearSidecarLayer,
 )
 from invokeai.backend.lora.sidecar_layers.lora.lora_linear_sidecar_layer import LoRALinearSidecarLayer
-from invokeai.backend.lora.sidecar_layers.lora.lora_full_linear_sidecar_layer import LoRAFullLinearSidecarLayer
 from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
@@ -54,7 +52,8 @@ class LoRAPatcher:
             yield
         finally:
             for param_key, weight in original_weights.get_changed_weights():
-                model.get_parameter(param_key).copy_(weight)
+                cur_param = model.get_parameter(param_key)
+                cur_param.data = weight.to(dtype=cur_param.dtype, device=cur_param.device, copy=True)

     @staticmethod
     @torch.no_grad()
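
Note on the restore path in the hunk above: a plausible reading (not stated in the diff itself) is that an in-place copy_() can no longer be used because patching may swap a parameter for a zero-expanded tensor of a different shape (see the FLUX control LoRA handling in the next hunk), so the restore instead rebinds .data to a fresh copy of the saved weight. A minimal sketch with made-up shapes:

import torch

# Hypothetical scenario: the original weight was (4, 4), but while patched the
# parameter was swapped for a zero-expanded (4, 8) tensor.
param = torch.nn.Parameter(torch.zeros(4, 8))
saved_original = torch.randn(4, 4)

# An in-place restore would fail because the shapes no longer match:
# param.copy_(saved_original)  # RuntimeError: shape mismatch

# Rebinding .data restores the saved tensor regardless of the current shape,
# mirroring the approach taken in the hunk above.
param.data = saved_original.to(dtype=param.dtype, device=param.device, copy=True)
print(param.shape)  # torch.Size([4, 4])
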
@@ -120,10 +119,20 @@ class LoRAPatcher:
                 if module_param.nelement() == lora_param_weight.nelement():
                     lora_param_weight = lora_param_weight.reshape(module_param.shape)
                 else:
-                    expanded_weight = torch.zeros_like(lora_param_weight, device=module_param.device)
+                    # This condition was added to handle layers in FLUX control LoRAs.
+                    # TODO(ryand): Move the weight update into the LoRA layer so that the LoRAPatcher doesn't need
+                    # to worry about this?
+                    expanded_weight = torch.zeros_like(
+                        lora_param_weight, dtype=module_param.dtype, device=module_param.device
+                    )
                     slices = tuple(slice(0, dim) for dim in module_param.shape)
                     expanded_weight[slices] = module_param
-                    setattr(module, param_name, expanded_weight)
+                    setattr(
+                        module,
+                        param_name,
+                        torch.nn.Parameter(expanded_weight, requires_grad=module_param.requires_grad),
+                    )
+                    module_param = expanded_weight

             lora_param_weight *= patch_weight * layer_scale
             module_param += lora_param_weight.to(dtype=dtype)
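
For context, the zero-expansion above pads the existing module weight out to the (larger) shape of the incoming patch parameter before the patch is added. A self-contained sketch of that step, using illustrative shapes (the specific sizes are assumptions, not taken from the diff):

import torch

# Illustrative shapes: the module weight is (8, 4); the FLUX control LoRA
# supplies a patch for an expanded input dimension, e.g. (8, 6).
module_weight = torch.randn(8, 4)
lora_param_weight = torch.randn(8, 6)

# Zero-expand the module weight to the patch's shape, keeping the module's
# dtype and device (mirroring the torch.zeros_like(...) call above).
expanded_weight = torch.zeros_like(
    lora_param_weight, dtype=module_weight.dtype, device=module_weight.device
)
slices = tuple(slice(0, dim) for dim in module_weight.shape)
expanded_weight[slices] = module_weight  # original weights fill the leading block

# The patch can then be added element-wise onto the expanded weight.
expanded_weight += lora_param_weight.to(dtype=expanded_weight.dtype)
print(expanded_weight.shape)  # torch.Size([8, 6])
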
@@ -253,8 +262,6 @@ class LoRAPatcher:
                 return LoRALinearSidecarLayer(lora_layer=lora_layer, weight=patch_weight)
             elif isinstance(lora_layer, ConcatenatedLoRALayer):
                 return ConcatenatedLoRALinearSidecarLayer(concatenated_lora_layer=lora_layer, weight=patch_weight)
-            elif isinstance(lora_layer, FullLayer):
-                return LoRAFullLinearSidecarLayer(lora_layer=lora_layer, weight=patch_weight)
             else:
                 raise ValueError(f"Unsupported Linear LoRA layer type: {type(lora_layer)}")
         else: