From 096e1d3a5d36fc2cc5cf5730e02b5f635d8df04d Mon Sep 17 00:00:00 2001
From: Jordan
Date: Mon, 20 Feb 2023 02:37:44 -0700
Subject: [PATCH] start of rewrite for add / remove

---
 ldm/generate.py             |   9 ++
 ldm/modules/lora_manager.py | 230 ++++++++++++++++++++++++++++--------
 2 files changed, 191 insertions(+), 48 deletions(-)

diff --git a/ldm/generate.py b/ldm/generate.py
index 8e21f39f33..2833787458 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -575,7 +575,13 @@ class Generate:
                 save_original = save_original,
                 image_callback = image_callback)
 
+            if self.model.lora_manager:
+                self.model.lora_manager.reset_lora()
+
         except KeyboardInterrupt:
+            if self.model.lora_manager:
+                self.model.lora_manager.reset_lora()
+
             # Clear the CUDA cache on an exception
             self.clear_cuda_cache()
 
@@ -584,6 +590,9 @@ class Generate:
             else:
                 raise KeyboardInterrupt
         except RuntimeError:
+            if self.model.lora_manager:
+                self.model.lora_manager.reset_lora()
+
             # Clear the CUDA cache on an exception
             self.clear_cuda_cache()
 
diff --git a/ldm/modules/lora_manager.py b/ldm/modules/lora_manager.py
index 008864a4b0..16b1d4363c 100644
--- a/ldm/modules/lora_manager.py
+++ b/ldm/modules/lora_manager.py
@@ -3,44 +3,99 @@ from pathlib import Path
 from ldm.invoke.globals import global_models_dir
 import torch
 from safetensors.torch import load_file
+from typing import List, Optional, Set, Type
 
-# modified from script at https://github.com/huggingface/diffusers/pull/2403
-def merge_lora_into_pipe(pipeline, checkpoint_path, alpha):
-    # load LoRA weight from .safetensors
-    state_dict = load_file(checkpoint_path, device=torch.cuda.current_device())
-    visited = []
+
+class LoraLinear(torch.nn.Module):
+    def __init__(
+        self, in_features, out_features, rank=4
+    ):
+        super().__init__()
 
-    # directly update weight in diffusers model
-    for key in state_dict:
+        if rank > min(in_features, out_features):
+            raise ValueError(
+                f"LoRA rank {rank} must be less than or equal to {min(in_features, out_features)}"
+            )
+        self.rank = rank
+        self.linear = torch.nn.Linear(in_features, out_features, bias=False)
+        self.lora = torch.nn.Linear(in_features, out_features, bias=False)
 
-        # it is suggested to print out the key, it usually will be something like below
-        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
+    def forward(self, hidden_states):
+        orig_dtype = hidden_states.dtype
+        dtype = self.lora.weight.dtype
+        # original projection plus the LoRA delta, computed in the LoRA weight's dtype
+        return self.linear(hidden_states) + self.lora(hidden_states.to(dtype)).to(orig_dtype)
 
-        # as we have set the alpha beforehand, so just skip
-        if ".alpha" in key or key in visited:
-            continue
 
-        if "text" in key:
-            layer_infos = key.split(".")[0].split("lora_te" + "_")[-1].split("_")
-            curr_layer = pipeline.text_encoder
+class LoraManager:
+
+    def __init__(self, pipe):
+        self.pipe = pipe
+        self.lora_path = Path(global_models_dir(), 'lora')
+        self.lora_match = re.compile(r"<lora:([^>]+)>")
+        self.prompt = None
+
+    def _process_lora(self, lora):
+        processed_lora = {
+            "unet": [],
+            "text_encoder": []
+        }
+        visited = []
+        for key in lora:
+            if ".alpha" in key or key in visited:
+                continue
+            if "text" in key:
+                lora_type, pair_keys = self._find_layer(
+                    "text_encoder",
+                    key.split(".")[0].split("lora_te" + "_")[-1].split("_"),
+                    key
+                )
+            else:
+                lora_type, pair_keys = self._find_layer(
+                    "unet",
+                    key.split(".")[0].split("lora_unet" + "_")[-1].split("_"),
+                    key
+                )
+
+            if len(lora[pair_keys[0]].shape) == 4:
+                weight_up = lora[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
+                weight_down = lora[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
+                weight = torch.mm(weight_up, weight_down)
+            else:
+                weight_up = lora[pair_keys[0]].to(torch.float32)
+                weight_down = lora[pair_keys[1]].to(torch.float32)
+                weight = torch.mm(weight_up, weight_down)
+
+            processed_lora[lora_type].append({
+                "weight": weight,
+                "rank": lora[pair_keys[1]].shape[0]
+            })
+
+            for item in pair_keys:
+                visited.append(item)
+
+        return processed_lora
+
+    def _find_layer(self, lora_type, layer_key, key):
+        temp_name = layer_key.pop(0)
+        if lora_type == "unet":
+            curr_layer = self.pipe.unet
+        elif lora_type == "text_encoder":
+            curr_layer = self.pipe.text_encoder
         else:
-            layer_infos = key.split(".")[0].split("lora_unet" + "_")[-1].split("_")
-            curr_layer = pipeline.unet
+            raise ValueError("Invalid Lora Type")
 
-        # find the target layer
-        temp_name = layer_infos.pop(0)
-        while len(layer_infos) > -1:
+        while len(layer_key) > -1:
             try:
                 curr_layer = curr_layer.__getattr__(temp_name)
-                if len(layer_infos) > 0:
-                    temp_name = layer_infos.pop(0)
-                elif len(layer_infos) == 0:
+                if len(layer_key) > 0:
+                    temp_name = layer_key.pop(0)
+                elif len(layer_key) == 0:
                     break
             except Exception:
                 if len(temp_name) > 0:
-                    temp_name += "_" + layer_infos.pop(0)
+                    temp_name += "_" + layer_key.pop(0)
                 else:
-                    temp_name = layer_infos.pop(0)
+                    temp_name = layer_key.pop(0)
 
         pair_keys = []
         if "lora_down" in key:
@@ -50,29 +105,89 @@ def merge_lora_into_pipe(pipeline, checkpoint_path, alpha):
             pair_keys.append(key)
             pair_keys.append(key.replace("lora_up", "lora_down"))
 
-        # update weight
-        if len(state_dict[pair_keys[0]].shape) == 4:
-            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
-            weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
-            curr_layer.weight.data += float(alpha) * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+        return lora_type, pair_keys
+
+    @staticmethod
+    def _find_modules(
+        model,
+        ancestor_class: Optional[Set[str]] = None,
+        search_class: List[Type[torch.nn.Module]] = [torch.nn.Linear],
+        exclude_children_of: Optional[List[Type[torch.nn.Module]]] = [LoraLinear],
+    ):
+        """
+        Find all modules of a certain class (or union of classes) that are direct or
+        indirect descendants of other modules of a certain class (or union of classes).
+        Returns all matching modules, along with the parent of those modules and the
+        names they are referenced by.
+        """
+
+        # Get the targets we should replace all linears under
+        if ancestor_class is not None:
+            ancestors = (
+                module
+                for module in model.modules()
+                if module.__class__.__name__ in ancestor_class
+            )
         else:
-            weight_up = state_dict[pair_keys[0]].to(torch.float32)
-            weight_down = state_dict[pair_keys[1]].to(torch.float32)
-            curr_layer.weight.data += float(alpha) * torch.mm(weight_up, weight_down)
+            # this, in case you want to naively iterate over all modules.
+            ancestors = [module for module in model.modules()]
 
-        # update visited list
-        for item in pair_keys:
-            visited.append(item)
+        # For each target, find every search_class module that isn't already a child of a LoraLinear
+        for ancestor in ancestors:
+            for fullname, module in ancestor.named_modules():
+                if any([isinstance(module, _class) for _class in search_class]):
+                    # Find the direct parent if this is a descendant, not a child, of target
+                    *path, name = fullname.split(".")
+                    parent = ancestor
+                    while path:
+                        parent = parent.get_submodule(path.pop(0))
+                    # Skip this linear if it's a child of a LoraLinear
+                    if exclude_children_of and any(
+                        [isinstance(parent, _class) for _class in exclude_children_of]
+                    ):
+                        continue
+                    # Otherwise, yield it
+                    yield parent, name, module
 
+    @staticmethod
+    def patch_module(lora_type, processed_lora, module, name, child_module, scale: float = 1.0):
+        _source = (
+            child_module.linear
+            if isinstance(child_module, LoraLinear)
+            else child_module
+        )
 
-class LoraManager:
+        lora = processed_lora[lora_type].pop(0)
 
-    def __init__(self, pipe):
-        self.weights = {}
-        self.pipe = pipe
-        self.lora_path = Path(global_models_dir(), 'lora')
-        self.lora_match = re.compile(r"<lora:([^>]+)>")
-        self.prompt = None
+        weight = _source.weight
+        _tmp = LoraLinear(
+            in_features=_source.in_features,
+            out_features=_source.out_features,
+            rank=lora["rank"]
+        )
+        _tmp.linear.weight = weight
+
+        # switch the module, applying the requested scale to the LoRA delta
+        module._modules[name] = _tmp
+        module._modules[name].lora.weight.data = lora["weight"] * scale
+        module._modules[name].to(weight.device)
+
+    def patch_lora(self, lora_path, scale: float = 1.0):
+        lora = load_file(lora_path)
+        processed_lora = self._process_lora(lora)
+        for module, name, child_module in self._find_modules(
+            self.pipe.unet,
+            {"CrossAttention", "Attention", "GEGLU"},
+            search_class=[torch.nn.Linear, LoraLinear]
+        ):
+            self.patch_module("unet", processed_lora, module, name, child_module, scale)
+
+        for module, name, child_module in self._find_modules(
+            self.pipe.text_encoder,
+            {"CLIPAttention"},
+            search_class=[torch.nn.Linear, LoraLinear]
+        ):
+            self.patch_module("text_encoder", processed_lora, module, name, child_module, scale)
 
     def apply_lora_model(self, args):
         args = args.split(':')
@@ -87,14 +202,34 @@ class LoraManager:
         else:
             file = Path(self.lora_path, f"{name}.safetensors")
         print(f"loading lora: {file}")
-        alpha = 1
+        scale = 1.0
         if len(args) == 2:
-            alpha = args[1]
+            scale = float(args[1])
 
-        merge_lora_into_pipe(self.pipe, file.absolute().as_posix(), alpha)
+        self.patch_lora(file.absolute().as_posix(), scale)
+
+    @staticmethod
+    def remove_lora(module, name, child_module):
+        _source = child_module.linear
+        weight = _source.weight
+
+        _tmp = torch.nn.Linear(_source.in_features, _source.out_features, bias=False)
+        _tmp.weight = weight
+
+        # swap the plain Linear back in place of the LoraLinear wrapper
+        module._modules[name] = _tmp
+
+    def reset_lora(self):
+        for module, name, child_module in self._find_modules(
+            self.pipe.unet,
+            search_class=[LoraLinear]
+        ):
+            self.remove_lora(module, name, child_module)
+
+        for module, name, child_module in self._find_modules(
+            self.pipe.text_encoder,
+            search_class=[LoraLinear]
+        ):
+            self.remove_lora(module, name, child_module)
 
     def load_lora_from_prompt(self, prompt: str):
-
         for m in re.findall(self.lora_match, prompt):
             self.apply_lora_model(m)
 
@@ -108,4 +243,3 @@ class LoraManager:
             return ""
 
         return re.sub(self.lora_match, found, prompt)
-
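Usage sketch (not part of the patch): roughly how the new add/remove API is meant to be driven. This is a minimal sketch under stated assumptions: it assumes a diffusers-style pipeline exposing unet and text_encoder (which is what LoraManager walks), a LoRA file at models/lora/example.safetensors, and the <lora:name:scale> prompt tag parsed by lora_match; the pipeline id, prompt, and file name below are illustrative only.

    import re

    from diffusers import StableDiffusionPipeline

    from ldm.modules.lora_manager import LoraManager

    # Illustrative pipeline; InvokeAI builds its own diffusers-derived pipeline,
    # but anything exposing .unet and .text_encoder works for this sketch.
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    manager = LoraManager(pipe)

    prompt = "a castle on a hill <lora:example:0.8>"

    # Wrap the matching nn.Linear layers in LoraLinear for every <lora:name:scale>
    # tag in the prompt, then strip the tags before the prompt is tokenized.
    manager.load_lora_from_prompt(prompt)
    clean_prompt = re.sub(manager.lora_match, "", prompt).strip()

    image = pipe(clean_prompt).images[0]

    # Swap the plain nn.Linear layers back in so the next generation starts clean,
    # mirroring the reset_lora() calls this patch adds to ldm/generate.py.
    manager.reset_lora()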