diff --git a/invokeai/backend/lora/lora_patcher.py b/invokeai/backend/lora/lora_patcher.py index 6640dcfdd8..1e32053f5b 100644 --- a/invokeai/backend/lora/lora_patcher.py +++ b/invokeai/backend/lora/lora_patcher.py @@ -6,13 +6,11 @@ import torch from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer from invokeai.backend.lora.layers.lora_layer import LoRALayer -from invokeai.backend.lora.layers.full_layer import FullLayer from invokeai.backend.lora.lora_model_raw import LoRAModelRaw from invokeai.backend.lora.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import ( ConcatenatedLoRALinearSidecarLayer, ) from invokeai.backend.lora.sidecar_layers.lora.lora_linear_sidecar_layer import LoRALinearSidecarLayer -from invokeai.backend.lora.sidecar_layers.lora.lora_full_linear_sidecar_layer import LoRAFullLinearSidecarLayer from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage @@ -54,7 +52,8 @@ class LoRAPatcher: yield finally: for param_key, weight in original_weights.get_changed_weights(): - model.get_parameter(param_key).copy_(weight) + cur_param = model.get_parameter(param_key) + cur_param.data = weight.to(dtype=cur_param.dtype, device=cur_param.device, copy=True) @staticmethod @torch.no_grad() @@ -120,10 +119,20 @@ class LoRAPatcher: if module_param.nelement() == lora_param_weight.nelement(): lora_param_weight = lora_param_weight.reshape(module_param.shape) else: - expanded_weight = torch.zeros_like(lora_param_weight, device=module_param.device) + # This condition was added to handle layers in FLUX control LoRAs. + # TODO(ryand): Move the weight update into the LoRA layer so that the LoRAPatcher doesn't need + # to worry about this? + expanded_weight = torch.zeros_like( + lora_param_weight, dtype=module_param.dtype, device=module_param.device + ) slices = tuple(slice(0, dim) for dim in module_param.shape) expanded_weight[slices] = module_param - setattr(module, param_name, expanded_weight) + setattr( + module, + param_name, + torch.nn.Parameter(expanded_weight, requires_grad=module_param.requires_grad), + ) + module_param = expanded_weight lora_param_weight *= patch_weight * layer_scale module_param += lora_param_weight.to(dtype=dtype) @@ -253,8 +262,6 @@ class LoRAPatcher: return LoRALinearSidecarLayer(lora_layer=lora_layer, weight=patch_weight) elif isinstance(lora_layer, ConcatenatedLoRALayer): return ConcatenatedLoRALinearSidecarLayer(concatenated_lora_layer=lora_layer, weight=patch_weight) - elif isinstance(lora_layer, FullLayer): - return LoRAFullLinearSidecarLayer(lora_layer=lora_layer, weight=patch_weight) else: raise ValueError(f"Unsupported Linear LoRA layer type: {type(lora_layer)}") else: