diff --git a/ldm/modules/lora_manager.py b/ldm/modules/lora_manager.py index eaaac19f65..0071bf5650 100644 --- a/ldm/modules/lora_manager.py +++ b/ldm/modules/lora_manager.py @@ -38,25 +38,19 @@ LORA_PREFIX_UNET = 'lora_unet' LORA_PREFIX_TEXT_ENCODER = 'lora_te' -def lora_forward(module, input_h, output): - if len(loaded_loras) == 0: +def lora_forward_hook(name): + def lora_forward(module, input_h, output): + if len(loaded_loras) == 0: + return output + + for lora in applied_loras.values(): + layer = lora.layers.get(name, None) + if layer is None: + continue + output = output + layer.up(layer.down(*input_h)) * lora.multiplier * layer.scale return output - lora_name = getattr(module, 'lora_name', None) - for lora in applied_loras.values(): - layer = lora.layers.get(lora_name, None) - if layer is None: - continue - output = output + layer.up(layer.down(*input_h)) * lora.multiplier * layer.scale - return output - - -def lora_linear_forward(self, input_h): - return lora_forward(self, input_h, torch.nn.Linear.forward_before_lora(self, input_h)) - - -def lora_conv2d_forward(self, input_h): - return lora_forward(self, input_h, torch.nn.Conv2d.forward_before_lora(self, input_h)) + return lora_forward def load_lora( @@ -141,6 +135,7 @@ def load_lora( class LoraManager: loras_to_load: dict[str, float] + hooks: list[RemovableHandle] def __init__(self, pipe): self.lora_path = Path(global_models_dir(), 'lora') @@ -149,26 +144,19 @@ class LoraManager: self.device = torch.device(choose_torch_device()) self.dtype = pipe.unet.dtype self.loras_to_load = {} - - if not hasattr(torch.nn.Linear, 'forward_before_lora'): - torch.nn.Linear.forward_before_lora = torch.nn.Linear.forward - - if not hasattr(torch.nn.Conv2d, 'forward_before_lora'): - torch.nn.Conv2d.forward_before_lora = torch.nn.Conv2d.forward - - torch.nn.Linear.forward = lora_linear_forward - torch.nn.Conv2d.forward = lora_conv2d_forward + self.hooks = [] def find_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> dict[str, torch.nn.Module]: mapping = {} for name, module in root_module.named_modules(): if module.__class__.__name__ in target_replace_modules: for child_name, child_module in module.named_modules(): - if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)): + layer_type = child_module.__class__.__name__ + if layer_type == "Linear" or (layer_type == "Conv2d" and child_module.kernel_size == (1, 1)): lora_name = prefix + '.' + name + '.' + child_name lora_name = lora_name.replace('.', '_') mapping[lora_name] = child_module - module.lora_name = lora_name + self.apply_module_forward(child_module, lora_name) return mapping self.text_modules = find_modules( @@ -182,6 +170,12 @@ class LoraManager: loaded_loras[name] = lora return lora + def apply_module_forward(self, module, lora_name): + handle = RemovableHandle(module._forward_hooks) + handle.id = 9000 + module._forward_hooks[handle.id] = lora_forward_hook(lora_name) + self.hooks.append(handle) + def apply_lora_model(self, name, mult: float = 1.0): path = Path(self.lora_path, name) file = Path(path, "pytorch_lora_weights.bin") @@ -258,16 +252,14 @@ class LoraManager: clear_applied_loras() self.loras_to_load = {} + def clear_hooks(self): + for hook in self.hooks: + hook.remove() + + self.hooks.clear() + def __del__(self): - # cleanup overrides - if hasattr(torch.nn.Linear, 'forward_before_lora'): - torch.nn.Linear.forward = torch.nn.Linear.forward_before_lora - del torch.nn.Linear.forward_before_lora - - if hasattr(torch.nn.Conv2d, 'forward_before_lora'): - torch.nn.Conv2d.forward = torch.nn.Conv2d.forward_before_lora - del torch.nn.Conv2d.forward_before_lora - + self.clear_hooks() clear_applied_loras() clear_loaded_loras() del self.text_modules