From af3543a8c7ecf91bae24383c43ce287d47427197 Mon Sep 17 00:00:00 2001 From: Jordan Date: Tue, 21 Feb 2023 20:42:40 -0700 Subject: [PATCH] further cleanup and implement wrapper --- ldm/modules/lora_manager.py | 250 ++++++++++++++++++------------------ 1 file changed, 128 insertions(+), 122 deletions(-) diff --git a/ldm/modules/lora_manager.py b/ldm/modules/lora_manager.py index b3f85ff0d5..23d7c9eb22 100644 --- a/ldm/modules/lora_manager.py +++ b/ldm/modules/lora_manager.py @@ -22,26 +22,112 @@ class LoRALayer: self.scale = alpha / rank +class LoRAModuleWrapper: + unet: UNet2DConditionModel + text_encoder: CLIPTextModel + + def __init__(self, unet, text_encoder): + self.unet = unet + self.text_encoder = text_encoder + self.hooks = [] + self.text_modules = None + self.unet_modules = None + + self.applied_loras = {} + self.loaded_loras = {} + + self.UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] + self.TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + self.LORA_PREFIX_UNET = 'lora_unet' + self.LORA_PREFIX_TEXT_ENCODER = 'lora_te' + + def find_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> dict[str, torch.nn.Module]: + mapping = {} + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + layer_type = child_module.__class__.__name__ + if layer_type == "Linear" or (layer_type == "Conv2d" and child_module.kernel_size == (1, 1)): + lora_name = prefix + '.' + name + '.' + child_name + lora_name = lora_name.replace('.', '_') + mapping[lora_name] = child_module + self.apply_module_forward(child_module, lora_name) + return mapping + + if self.text_modules is None: + self.text_modules = find_modules( + self.LORA_PREFIX_TEXT_ENCODER, + text_encoder, + self.TEXT_ENCODER_TARGET_REPLACE_MODULE + ) + + if self.unet_modules is None: + self.unet_modules = find_modules( + self.LORA_PREFIX_UNET, + unet, + self.UNET_TARGET_REPLACE_MODULE + ) + + def lora_forward_hook(self, name): + wrapper = self + + def lora_forward(module, input_h, output): + if len(wrapper.loaded_loras) == 0: + return output + + for lora in wrapper.applied_loras.values(): + layer = lora.layers.get(name, None) + if layer is None: + continue + output = output + layer.up(layer.down(*input_h)) * lora.multiplier * layer.scale + return output + + return lora_forward + + def apply_module_forward(self, module, name): + handle = module.register_forward_hook(self.lora_forward_hook(name)) + self.hooks.append(handle) + + def clear_hooks(self): + for hook in self.hooks: + hook.remove() + + self.hooks.clear() + + def clear_applied_loras(self): + self.applied_loras.clear() + + def clear_loaded_loras(self): + self.loaded_loras.clear() + + def __del__(self): + self.clear_hooks() + self.clear_applied_loras() + self.clear_loaded_loras() + del self.text_modules + del self.unet_modules + del self.hooks + + class LoRA: name: str layers: dict[str, LoRALayer] device: torch.device dtype: torch.dtype + wrapper: LoRAModuleWrapper multiplier: float - def __init__(self, name: str, device, dtype, multiplier=1.0): + def __init__(self, name: str, device, dtype, wrapper, multiplier=1.0): self.name = name self.layers = {} self.multiplier = multiplier self.device = device self.dtype = dtype + self.wrapper = wrapper self.rank = None self.alpha = None - def load_from_dict(self, - state_dict, - text_modules: dict[str, torch.nn.Module], - unet_modules: dict[str, torch.nn.Module]): + def load_from_dict(self, 
state_dict): for key, value in state_dict.items(): stem, leaf = key.split(".", 1) @@ -50,8 +136,8 @@ class LoRA: self.alpha = value.item() continue - if stem.startswith(LORA_PREFIX_TEXT_ENCODER): - wrapped = text_modules.get(stem, None) + if stem.startswith(self.wrapper.LORA_PREFIX_TEXT_ENCODER): + wrapped = self.wrapper.text_modules.get(stem, None) if wrapped is None: print(f">> Missing layer: {stem}") continue @@ -60,8 +146,8 @@ class LoRA: self.rank = value.shape[0] self.load_lora_layer(stem, leaf, value, wrapped) continue - elif stem.startswith(LORA_PREFIX_UNET): - wrapped = unet_modules.get(stem, None) + elif stem.startswith(self.wrapper.LORA_PREFIX_UNET): + wrapped = self.wrapper.unet_modules.get(stem, None) if wrapped is None: print(f">> Missing layer: {stem}") continue @@ -101,11 +187,6 @@ class LoRA: return -UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] -TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] -LORA_PREFIX_UNET = 'lora_unet' -LORA_PREFIX_TEXT_ENCODER = 'lora_te' - re_digits = re.compile(r"\d+") re_unet_transformer_attn_blocks = re.compile( r"lora_unet_(.+)_blocks_(\d+)_attentions_(\d+)_transformer_blocks_(\d+)_attn(\d+)_(.+).(weight|alpha)" @@ -196,8 +277,7 @@ def convert_key_to_diffusers(key): def load_lora_attn( name: str, path_file: Path, - unet: UNet2DConditionModel, - text_encoder: CLIPTextModel, + wrapper: LoRAModuleWrapper, multiplier=1.0 ): print(f">> Loading lora {name} from {path_file}") @@ -207,10 +287,10 @@ def load_lora_attn( checkpoint = torch.load(path_file, map_location='cpu') for key in list(checkpoint.keys()): - if key.startswith(LORA_PREFIX_UNET): + if key.startswith(wrapper.LORA_PREFIX_UNET): # convert unet keys checkpoint[convert_key_to_diffusers(key)] = checkpoint.pop(key) - elif key.startswith(LORA_PREFIX_UNET): + elif key.startswith(wrapper.LORA_PREFIX_UNET): # convert text encoder keys (not yet supported) # state_dict[convert_key_to_diffusers(key)] = state_dict.pop(key) checkpoint.pop(key) @@ -218,44 +298,8 @@ def load_lora_attn( # remove invalid key checkpoint.pop(key) - unet.load_attn_procs(checkpoint) - # text_encoder.load_attn_procs(checkpoint) - - -def lora_forward_hook(name): - def lora_forward(module, input_h, output): - if len(loaded_loras) == 0: - return output - - for lora in applied_loras.values(): - layer = lora.layers.get(name, None) - if layer is None: - continue - output = output + layer.up(layer.down(*input_h)) * lora.multiplier * layer.scale - return output - - return lora_forward - - -def load_lora( - name: str, - path_file: Path, - device: torch.device, - dtype: torch.dtype, - text_modules: dict[str, torch.nn.Module], - unet_modules: dict[str, torch.nn.Module], - multiplier=1.0 -): - print(f">> Loading lora {name} from {path_file}") - if path_file.suffix == '.safetensors': - checkpoint = load_file(path_file.absolute().as_posix(), device='cpu') - else: - checkpoint = torch.load(path_file, map_location='cpu') - - lora = LoRA(name, device, dtype, multiplier) - lora.load_from_dict(checkpoint, text_modules, unet_modules) - - return lora + wrapper.unet.load_attn_procs(checkpoint) + # wrapper.text_encoder.load_attn_procs(checkpoint) class LoraManager: @@ -269,37 +313,23 @@ class LoraManager: self.device = torch.device(choose_torch_device()) self.dtype = pipe.unet.dtype self.loras_to_load = {} - self.hooks = [] + self.wrapper = LoRAModuleWrapper(pipe.unet, pipe.text_encoder) - def find_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> dict[str, torch.nn.Module]: - mapping = {} - 
for name, module in root_module.named_modules(): - if module.__class__.__name__ in target_replace_modules: - for child_name, child_module in module.named_modules(): - layer_type = child_module.__class__.__name__ - if layer_type == "Linear" or (layer_type == "Conv2d" and child_module.kernel_size == (1, 1)): - lora_name = prefix + '.' + name + '.' + child_name - lora_name = lora_name.replace('.', '_') - mapping[lora_name] = child_module - self.apply_module_forward(child_module, lora_name) - return mapping - - self.text_modules = find_modules( - LORA_PREFIX_TEXT_ENCODER, self.text_encoder, TEXT_ENCODER_TARGET_REPLACE_MODULE) - - self.unet_modules = find_modules( - LORA_PREFIX_UNET, self.unet, UNET_TARGET_REPLACE_MODULE) - - def _load_lora(self, name, path_file, multiplier: float = 1.0): + def load_lora_module(self, name, path_file, multiplier: float = 1.0): # can be used instead to load through diffusers, once enough support is added - # lora = load_lora_attn(name, path_file, self.unet, self.text_encoder, multiplier) - lora = load_lora(name, path_file, self.device, self.dtype, self.text_modules, self.unet_modules, multiplier) - loaded_loras[name] = lora - return lora + # lora = load_lora_attn(name, path_file, self.wrapper, multiplier) - def apply_module_forward(self, module, lora_name): - handle = module.register_forward_hook(lora_forward_hook(lora_name)) - self.hooks.append(handle) + print(f">> Loading lora {name} from {path_file}") + if path_file.suffix == '.safetensors': + checkpoint = load_file(path_file.absolute().as_posix(), device='cpu') + else: + checkpoint = torch.load(path_file, map_location='cpu') + + lora = LoRA(name, self.device, self.dtype, self.wrapper, multiplier) + lora.load_from_dict(checkpoint) + self.wrapper.loaded_loras[name] = lora + + return lora def apply_lora_model(self, name, mult: float = 1.0): path = Path(self.lora_path, name) @@ -318,31 +348,30 @@ class LoraManager: print(f">> Unable to find lora: {name}") return - lora = loaded_loras.get(name, None) + lora = self.wrapper.loaded_loras.get(name, None) if lora is None: - lora = self._load_lora(name, path_file, mult) + lora = self.load_lora_module(name, path_file, mult) lora.multiplier = mult - applied_loras[name] = lora + self.wrapper.applied_loras[name] = lora def load_lora(self): + print(self.loras_to_load) for name, multiplier in self.loras_to_load.items(): self.apply_lora_model(name, multiplier) # unload any lora's not defined by loras_to_load - for name in list(applied_loras.keys()): + for name in list(self.wrapper.applied_loras.keys()): if name not in self.loras_to_load: self.unload_applied_lora(name) - @staticmethod - def unload_applied_lora(lora_name: str): - if lora_name in applied_loras: - del applied_loras[lora_name] + def unload_applied_lora(self, lora_name: str): + if lora_name in self.wrapper.applied_loras: + del self.wrapper.applied_loras[lora_name] - @staticmethod - def unload_lora(lora_name: str): - if lora_name in loaded_loras: - del loaded_loras[lora_name] + def unload_lora(self, lora_name: str): + if lora_name in self.wrapper.loaded_loras: + del self.wrapper.loaded_loras[lora_name] # Define a lora to be loaded # Can be used to define a lora to be loaded outside of prompts @@ -350,8 +379,8 @@ class LoraManager: self.loras_to_load[name] = multiplier # update the multiplier if the lora was already loaded - if name in loaded_loras: - loaded_loras[name].multiplier = multiplier + if name in self.wrapper.loaded_loras: + self.wrapper.loaded_loras[name].multiplier = multiplier # Load the lora from a 
prompt, syntax is # Multiplier should be a value between 0.0 and 1.0 @@ -374,31 +403,8 @@ class LoraManager: return re.sub(lora_match, "", prompt) def clear_loras(self): - clear_applied_loras() + self.wrapper.clear_applied_loras() self.loras_to_load = {} - def clear_hooks(self): - for hook in self.hooks: - hook.remove() - - self.hooks.clear() - def __del__(self): - self.clear_hooks() - clear_applied_loras() - clear_loaded_loras() - del self.text_modules - del self.unet_modules del self.loras_to_load - - -applied_loras = {} -loaded_loras = {} - - -def clear_applied_loras(): - applied_loras.clear() - - -def clear_loaded_loras(): - loaded_loras.clear()
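
Reviewer note: below is a minimal, self-contained sketch of the forward-hook
mechanism that this patch consolidates into LoRAModuleWrapper, for anyone
evaluating the refactor without running the full pipeline. It is illustrative
only: the bare Linear layer and the rank/alpha/multiplier values are made-up
stand-ins, whereas the real wrapper discovers Linear and 1x1 Conv2d children
of the UNet/text-encoder target blocks and keeps every handle in self.hooks.

    import torch

    base = torch.nn.Linear(32, 32)          # stands in for a wrapped UNet/CLIP sub-module

    rank, alpha, multiplier = 4, 4.0, 1.0   # hypothetical values for illustration
    down = torch.nn.Linear(32, rank, bias=False)
    up = torch.nn.Linear(rank, 32, bias=False)
    scale = alpha / rank                    # same formula as LoRALayer.scale

    def lora_forward(module, input_h, output):
        # Same shape as the hook built by lora_forward_hook(): add the
        # low-rank update on top of the frozen module's output.
        return output + up(down(*input_h)) * multiplier * scale

    handle = base.register_forward_hook(lora_forward)
    y = base(torch.randn(1, 32))            # LoRA-adjusted output
    handle.remove()                         # what clear_hooks() does per handle

Because the equivalent state now lives on the wrapper (applied_loras,
loaded_loras, hooks), LoraManager no longer has to clean up module-level
globals on teardown.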