From cd333e414bc88fed16b17f018188b07f76eb0e59 Mon Sep 17 00:00:00 2001
From: Jordan
Date: Tue, 21 Feb 2023 21:38:15 -0700
Subject: [PATCH] move key converter to wrapper

---
 ldm/modules/lora_manager.py | 179 ++++++++++++++++++------------------
 1 file changed, 88 insertions(+), 91 deletions(-)

diff --git a/ldm/modules/lora_manager.py b/ldm/modules/lora_manager.py
index 23d7c9eb22..b62feae2fb 100644
--- a/ldm/modules/lora_manager.py
+++ b/ldm/modules/lora_manager.py
@@ -25,6 +25,7 @@ class LoRALayer:
 class LoRAModuleWrapper:
     unet: UNet2DConditionModel
     text_encoder: CLIPTextModel
+    hooks: list[RemovableHandle]
 
     def __init__(self, unet, text_encoder):
         self.unet = unet
         self.text_encoder = text_encoder
@@ -41,6 +42,26 @@ class LoRAModuleWrapper:
         self.LORA_PREFIX_UNET = 'lora_unet'
         self.LORA_PREFIX_TEXT_ENCODER = 'lora_te'
+        self.re_digits = re.compile(r"\d+")
+        self.re_unet_transformer_attn_blocks = re.compile(
+            r"lora_unet_(.+)_blocks_(\d+)_attentions_(\d+)_transformer_blocks_(\d+)_attn(\d+)_(.+).(weight|alpha)"
+        )
+        self.re_unet_mid_blocks = re.compile(
+            r"lora_unet_mid_block_attentions_(\d+)_(.+).(weight|alpha)"
+        )
+        self.re_unet_transformer_blocks = re.compile(
+            r"lora_unet_(.+)_blocks_(\d+)_attentions_(\d+)_transformer_blocks_(\d+)_(.+).(weight|alpha)"
+        )
+        self.re_unet_mid_transformer_blocks = re.compile(
+            r"lora_unet_mid_block_attentions_(\d+)_transformer_blocks_(\d+)_(.+).(weight|alpha)"
+        )
+        self.re_unet_norm_blocks = re.compile(
+            r"lora_unet_(.+)_blocks_(\d+)_attentions_(\d+)_(.+).(weight|alpha)"
+        )
+        self.re_out = re.compile(r"to_out_(\d+)")
+        self.re_processor_weight = re.compile(r"(.+)_(\d+)_(.+)")
+        self.re_processor_alpha = re.compile(r"(.+)_(\d+)")
+
     def find_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> dict[str, torch.nn.Module]:
         mapping = {}
         for name, module in root_module.named_modules():
@@ -68,6 +89,71 @@ class LoRAModuleWrapper:
             self.UNET_TARGET_REPLACE_MODULE
         )
 
+    def convert_key_to_diffusers(self, key):
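+        """Convert a flattened `lora_unet_*` checkpoint key to the dotted,
+        diffusers-style attention-processor path, e.g.
+
+            lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight
+            -> down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight
+
+        Keys that match none of the patterns below are returned unchanged.
+        """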
+        def match(match_list, regex, subject):
+            r = re.match(regex, subject)
+            if not r:
+                return False
+
+            match_list.clear()
+            match_list.extend([int(x) if re.match(self.re_digits, x) else x for x in r.groups()])
+            return True
+
+        m = []
+
+        def get_front_block(first, second, third, fourth=None):
+            if first == "mid":
+                b_type = f"mid_block"
+            else:
+                b_type = f"{first}_blocks.{second}"
+
+            if fourth is None:
+                return f"{b_type}.attentions.{third}"
+
+            return f"{b_type}.attentions.{third}.transformer_blocks.{fourth}"
+
+        def get_back_block(first, second, third):
+            second = second.replace(".lora_", "_lora.")
+            if third == "weight":
+                bm = []
+                if match(bm, self.re_processor_weight, second):
+                    s_bm = bm[2].split('.')
+                    s_front = f"{bm[0]}_{s_bm[0]}"
+                    s_back = f"{s_bm[1]}"
+                    if int(bm[1]) == 0:
+                        second = f"{s_front}.{s_back}"
+                    else:
+                        second = f"{s_front}.{bm[1]}.{s_back}"
+            elif third == "alpha":
+                bma = []
+                if match(bma, self.re_processor_alpha, second):
+                    if int(bma[1]) == 0:
+                        second = f"{bma[0]}"
+                    else:
+                        second = f"{bma[0]}.{bma[1]}"
+
+            if first is None:
+                return f"processor.{second}.{third}"
+
+            return f"attn{first}.processor.{second}.{third}"
+
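+        # Try the patterns from most to least specific; the first regex that
+        # matches determines the converted key.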
+        if match(m, self.re_unet_transformer_attn_blocks, key):
+            return f"{get_front_block(m[0], m[1], m[2], m[3])}.{get_back_block(m[4], m[5], m[6])}"
+
+        if match(m, self.re_unet_transformer_blocks, key):
+            return f"{get_front_block(m[0], m[1], m[2], m[3])}.{get_back_block(None, m[4], m[5])}"
+
+        if match(m, self.re_unet_mid_transformer_blocks, key):
+            return f"{get_front_block('mid', None, m[0], m[1])}.{get_back_block(None, m[2], m[3])}"
+
+        if match(m, self.re_unet_norm_blocks, key):
+            return f"{get_front_block(m[0], m[1], m[2])}.{get_back_block(None, m[3], m[4])}"
+
+        if match(m, self.re_unet_mid_blocks, key):
+            return f"{get_front_block('mid', None, m[0])}.{get_back_block(None, m[1], m[2])}"
+
+        return key
+
     def lora_forward_hook(self, name):
         wrapper = self
 
@@ -187,93 +273,6 @@ class LoRA:
         return
 
 
-re_digits = re.compile(r"\d+")
-re_unet_transformer_attn_blocks = re.compile(
-    r"lora_unet_(.+)_blocks_(\d+)_attentions_(\d+)_transformer_blocks_(\d+)_attn(\d+)_(.+).(weight|alpha)"
-)
-re_unet_mid_blocks = re.compile(
-    r"lora_unet_mid_block_attentions_(\d+)_(.+).(weight|alpha)"
-)
-re_unet_transformer_blocks = re.compile(
-    r"lora_unet_(.+)_blocks_(\d+)_attentions_(\d+)_transformer_blocks_(\d+)_(.+).(weight|alpha)"
-)
-re_unet_mid_transformer_blocks = re.compile(
-    r"lora_unet_mid_block_attentions_(\d+)_transformer_blocks_(\d+)_(.+).(weight|alpha)"
-)
-re_unet_norm_blocks = re.compile(
-    r"lora_unet_(.+)_blocks_(\d+)_attentions_(\d+)_(.+).(weight|alpha)"
-)
-re_out = re.compile(r"to_out_(\d+)")
-re_processor_weight = re.compile(r"(.+)_(\d+)_(.+)")
-re_processor_alpha = re.compile(r"(.+)_(\d+)")
-
-
-def convert_key_to_diffusers(key):
-    def match(match_list, regex, subject):
-        r = re.match(regex, subject)
-        if not r:
-            return False
-
-        match_list.clear()
-        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
-        return True
-
-    m = []
-
-    def get_front_block(first, second, third, fourth=None):
-        if first == "mid":
-            b_type = f"mid_block"
-        else:
-            b_type = f"{first}_blocks.{second}"
-
-        if fourth is None:
-            return f"{b_type}.attentions.{third}"
-
-        return f"{b_type}.attentions.{third}.transformer_blocks.{fourth}"
-
-    def get_back_block(first, second, third):
-        second = second.replace(".lora_", "_lora.")
-        if third == "weight":
-            bm = []
-            if match(bm, re_processor_weight, second):
-                s_bm = bm[2].split('.')
-                s_front = f"{bm[0]}_{s_bm[0]}"
-                s_back = f"{s_bm[1]}"
-                if int(bm[1]) == 0:
-                    second = f"{s_front}.{s_back}"
-                else:
-                    second = f"{s_front}.{bm[1]}.{s_back}"
-        elif third == "alpha":
-            bma = []
-            if match(bma, re_processor_alpha, second):
-                if int(bma[1]) == 0:
-                    second = f"{bma[0]}"
-                else:
-                    second = f"{bma[0]}.{bma[1]}"
-
-        if first is None:
-            return f"processor.{second}.{third}"
-
-        return f"attn{first}.processor.{second}.{third}"
-
-    if match(m, re_unet_transformer_attn_blocks, key):
-        return f"{get_front_block(m[0], m[1], m[2], m[3])}.{get_back_block(m[4], m[5], m[6])}"
-
-    if match(m, re_unet_transformer_blocks, key):
-        return f"{get_front_block(m[0], m[1], m[2], m[3])}.{get_back_block(None, m[4], m[5])}"
-
-    if match(m, re_unet_mid_transformer_blocks, key):
-        return f"{get_front_block('mid', None, m[0], m[1])}.{get_back_block(None, m[2], m[3])}"
-
-    if match(m, re_unet_norm_blocks, key):
-        return f"{get_front_block(m[0], m[1], m[2])}.{get_back_block(None, m[3], m[4])}"
-
-    if match(m, re_unet_mid_blocks, key):
-        return f"{get_front_block('mid', None, m[0])}.{get_back_block(None, m[1], m[2])}"
-
-    return key
-
-
 def load_lora_attn(
     name: str,
     path_file: Path,
@@ -289,10 +288,10 @@ def load_lora_attn(
     for key in list(checkpoint.keys()):
         if key.startswith(wrapper.LORA_PREFIX_UNET):
             # convert unet keys
-            checkpoint[convert_key_to_diffusers(key)] = checkpoint.pop(key)
+            checkpoint[wrapper.convert_key_to_diffusers(key)] = checkpoint.pop(key)
         elif key.startswith(wrapper.LORA_PREFIX_UNET):
             # convert text encoder keys (not yet supported)
-            # state_dict[convert_key_to_diffusers(key)] = state_dict.pop(key)
+            # state_dict[wrapper.convert_key_to_diffusers(key)] = state_dict.pop(key)
             checkpoint.pop(key)
         else:
             # remove invalid key
             checkpoint.pop(key)
@@ -304,7 +303,6 @@ class LoraManager:
     loras_to_load: dict[str, float]
-    hooks: list[RemovableHandle]
 
     def __init__(self, pipe):
         self.lora_path = Path(global_models_dir(), 'lora')
@@ -356,7 +354,6 @@ class LoraManager:
         self.wrapper.applied_loras[name] = lora
 
     def load_lora(self):
-        print(self.loras_to_load)
        for name, multiplier in self.loras_to_load.items():
            self.apply_lora_model(name, multiplier)