diff --git a/invokeai/backend/lora/lora_layer_wrappers.py b/invokeai/backend/lora/lora_layer_wrappers.py
index 78edd2b3b9..59d4f58060 100644
--- a/invokeai/backend/lora/lora_layer_wrappers.py
+++ b/invokeai/backend/lora/lora_layer_wrappers.py
@@ -1,9 +1,11 @@
 import torch
 
 from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
+from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
+from invokeai.backend.lora.layers.lora_layer import LoRALayer
 
 
-class LoRAModuleWrapper(torch.nn.Module):
+class LoRASidecarWrapper(torch.nn.Module):
     def __init__(self, orig_module: torch.nn.Module, lora_layers: list[AnyLoRALayer], lora_weights: list[float]):
         super().__init__()
         self._orig_module = orig_module
@@ -19,40 +21,113 @@ class LoRAModuleWrapper(torch.nn.Module):
         self._lora_weights.append(lora_weight)
 
     @torch.no_grad()
-    def _get_lora_patched_parameters(self, params: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
-        for lora_layer, lora_weight in zip(self._lora_layers, self._lora_weights, strict=True):
+    def _get_lora_patched_parameters(
+        self, orig_params: dict[str, torch.Tensor], lora_layers: list[AnyLoRALayer], lora_weights: list[float]
+    ) -> dict[str, torch.Tensor]:
+        params: dict[str, torch.Tensor] = {}
+        for lora_layer, lora_weight in zip(lora_layers, lora_weights, strict=True):
             layer_params = lora_layer.get_parameters(self._orig_module)
 
             for param_name, param_weight in layer_params.items():
-                if params[param_name].shape != param_weight.shape:
-                    param_weight = param_weight.reshape(params[param_name].shape)
+                if orig_params[param_name].shape != param_weight.shape:
+                    param_weight = param_weight.reshape(orig_params[param_name].shape)
 
-                # NOTE: It is important that params[param_name] is not modified in-place, because we initialize it
-                # with the original parameter - which we don't want to modify. In other words,
-                # `out_params[param_name] += ...` would not work.
-                params[param_name] = params[param_name] + param_weight * (lora_layer.scale() * lora_weight)
+                if param_name not in params:
+                    params[param_name] = param_weight * (lora_layer.scale() * lora_weight)
+                else:
+                    params[param_name] += param_weight * (lora_layer.scale() * lora_weight)
         return params
 
 
-class LoRALinearWrapper(LoRAModuleWrapper):
+class LoRALinearWrapper(LoRASidecarWrapper):
+    def _lora_linear_forward(self, input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
+        """An optimized implementation of the residual calculation for a Linear LoRALayer."""
+        x = torch.nn.functional.linear(input, lora_layer.down)
+        if lora_layer.mid is not None:
+            x = torch.nn.functional.linear(x, lora_layer.mid)
+        x = torch.nn.functional.linear(x, lora_layer.up, bias=lora_layer.bias)
+        x *= lora_weight * lora_layer.scale()
+        return x
+
+    def _concatenated_lora_forward(
+        self, input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float
+    ) -> torch.Tensor:
+        """An optimized implementation of the residual calculation for a Linear ConcatenatedLoRALayer."""
+        x_chunks: list[torch.Tensor] = []
+        for lora_layer in concatenated_lora_layer.lora_layers:
+            x_chunk = torch.nn.functional.linear(input, lora_layer.down)
+            if lora_layer.mid is not None:
+                x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
+            x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
+            x_chunk *= lora_weight * lora_layer.scale()
+            x_chunks.append(x_chunk)
+
+        # TODO(ryand): Generalize to support concat_axis != 0.
+        assert concatenated_lora_layer.concat_axis == 0
+        x = torch.cat(x_chunks, dim=-1)
+        return x
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        # Split the LoRA layers into those that have optimized implementations and those that don't.
+        optimized_layer_types = (LoRALayer, ConcatenatedLoRALayer)
+        optimized_layers = [
+            (layer, weight)
+            for layer, weight in zip(self._lora_layers, self._lora_weights, strict=True)
+            if isinstance(layer, optimized_layer_types)
+        ]
+        non_optimized_layers = [
+            (layer, weight)
+            for layer, weight in zip(self._lora_layers, self._lora_weights, strict=True)
+            if not isinstance(layer, optimized_layer_types)
+        ]
+
+        # First, calculate the residual for LoRA layers for which there is an optimized implementation.
+        residual = None
+        for lora_layer, lora_weight in optimized_layers:
+            if isinstance(lora_layer, LoRALayer):
+                added_residual = self._lora_linear_forward(input, lora_layer, lora_weight)
+            elif isinstance(lora_layer, ConcatenatedLoRALayer):
+                added_residual = self._concatenated_lora_forward(input, lora_layer, lora_weight)
+            else:
+                raise ValueError(f"Unsupported LoRA layer type: {type(lora_layer)}")
+
+            if residual is None:
+                residual = added_residual
+            else:
+                residual += added_residual
+
+        # Next, calculate the residuals for the LoRA layers for which there is no optimized implementation.
+        if non_optimized_layers:
+            unoptimized_layers, unoptimized_weights = zip(*non_optimized_layers, strict=True)
+            params = self._get_lora_patched_parameters(
+                orig_params={"weight": self._orig_module.weight, "bias": self._orig_module.bias},
+                lora_layers=unoptimized_layers,
+                lora_weights=unoptimized_weights,
+            )
+            added_residual = torch.nn.functional.linear(input, params["weight"], params.get("bias", None))
+            if residual is None:
+                residual = added_residual
+            else:
+                residual += added_residual
+
+        return self.orig_module(input) + residual
+
+
+class LoRAConv1dWrapper(LoRASidecarWrapper):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         params = self._get_lora_patched_parameters(
-            params={"weight": self._orig_module.weight, "bias": self._orig_module.bias}
+            orig_params={"weight": self._orig_module.weight, "bias": self._orig_module.bias},
+            lora_layers=self._lora_layers,
+            lora_weights=self._lora_weights,
         )
-        return torch.nn.functional.linear(input, params["weight"], params["bias"])
+        return self.orig_module(input) + torch.nn.functional.conv1d(input, params["weight"], params.get("bias", None))
 
 
-class LoRAConv1dWrapper(LoRAModuleWrapper):
+class LoRAConv2dWrapper(LoRASidecarWrapper):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         params = self._get_lora_patched_parameters(
-            params={"weight": self._orig_module.weight, "bias": self._orig_module.bias}
+            orig_params={"weight": self._orig_module.weight, "bias": self._orig_module.bias},
+            lora_layers=self._lora_layers,
+            lora_weights=self._lora_weights,
         )
-        return torch.nn.functional.conv1d(input, params["weight"], params["bias"])
-
-
-class LoRAConv2dWrapper(LoRAModuleWrapper):
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        params = self._get_lora_patched_parameters(
-            params={"weight": self._orig_module.weight, "bias": self._orig_module.bias}
-        )
-        return torch.nn.functional.conv2d(input, params["weight"], params["bias"])
+        return self.orig_module(input) + torch.nn.functional.conv2d(input, params["weight"], params.get("bias", None))
diff --git a/invokeai/backend/lora/lora_patcher.py b/invokeai/backend/lora/lora_patcher.py
index 31d2457980..0859c6cdd7 100644
--- a/invokeai/backend/lora/lora_patcher.py
+++ b/invokeai/backend/lora/lora_patcher.py
@@ -10,7 +10,7 @@ from invokeai.backend.lora.lora_layer_wrappers import (
     LoRAConv1dWrapper,
     LoRAConv2dWrapper,
     LoRALinearWrapper,
-    LoRAModuleWrapper,
+    LoRASidecarWrapper,
 )
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.lora.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import (
@@ -196,8 +196,8 @@ class LoRAPatcher:
                model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
            )
 
-            # Replace the original module with a LoRAModuleWrapper if it has not already been done.
-            if not isinstance(module, LoRAModuleWrapper):
+            # Replace the original module with a LoRASidecarWrapper if it has not already been done.
+            if not isinstance(module, LoRASidecarWrapper):
                lora_wrapper_layer = LoRAPatcher._initialize_lora_wrapper_layer(module)
                original_modules[module_key] = module
                module_parent_key, module_name = LoRAPatcher._split_parent_key(module_key)
@@ -211,7 +211,7 @@ class LoRAPatcher:
            # Move the LoRA layer to the same device/dtype as the orig module.
            layer.to(device=orig_module.weight.device, dtype=orig_module.weight.dtype)
 
-            # Add the LoRA wrapper layer to the LoRAModuleWrapper.
+            # Add the LoRA wrapper layer to the LoRASidecarWrapper.
            lora_wrapper_layer.add_lora_layer(layer, patch_weight)
 
     @staticmethod
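
To make the intent of the wrappers above concrete, here is a minimal, self-contained sketch of the sidecar pattern that LoRASidecarWrapper implements for Linear layers: the frozen module runs unchanged and each attached LoRA layer adds a scaled low-rank residual to its output. ToyLoRALayer and ToyLinearSidecarWrapper are hypothetical names invented for this illustration; the real classes in the diff also handle mid weights, biases, ConcatenatedLoRALayer chunks, and the generic patched-parameter fallback, so this is a sketch of the idea, not the invokeai implementation.

import torch

# Illustrative-only stand-in for a LoRA layer: "down" projects to rank r,
# "up" projects back to the output dimension. Not an invokeai class.
class ToyLoRALayer:
    def __init__(self, in_features: int, out_features: int, rank: int, alpha: float = 1.0):
        # up starts at zero so the residual is initially a no-op, mirroring common LoRA init.
        self.down = torch.randn(rank, in_features) * 0.01
        self.up = torch.zeros(out_features, rank)
        self.alpha = alpha
        self.rank = rank

    def scale(self) -> float:
        return self.alpha / self.rank


class ToyLinearSidecarWrapper(torch.nn.Module):
    """Adds LoRA residuals alongside a frozen Linear without touching its weights."""

    def __init__(self, orig_module: torch.nn.Linear):
        super().__init__()
        self._orig_module = orig_module
        self._lora_layers: list[ToyLoRALayer] = []
        self._lora_weights: list[float] = []

    def add_lora_layer(self, layer: ToyLoRALayer, weight: float) -> None:
        self._lora_layers.append(layer)
        self._lora_weights.append(weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self._orig_module(x)
        for layer, weight in zip(self._lora_layers, self._lora_weights, strict=True):
            # Residual = ((x @ down^T) @ up^T) * (alpha / rank) * patch_weight.
            residual = torch.nn.functional.linear(torch.nn.functional.linear(x, layer.down), layer.up)
            out = out + residual * (layer.scale() * weight)
        return out


if __name__ == "__main__":
    base = torch.nn.Linear(8, 4)
    wrapped = ToyLinearSidecarWrapper(base)
    wrapped.add_lora_layer(ToyLoRALayer(in_features=8, out_features=4, rank=2), weight=0.75)

    x = torch.randn(3, 8)
    # With `up` initialized to zeros the wrapper reproduces the original output exactly.
    assert torch.allclose(wrapped(x), base(x))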