move key converter to wrapper

Jordan
2023-02-21 21:38:15 -07:00
parent af3543a8c7
commit cd333e414b

@@ -25,6 +25,7 @@ class LoRALayer:
class LoRAModuleWrapper:
    unet: UNet2DConditionModel
    text_encoder: CLIPTextModel
    hooks: list[RemovableHandle]

    def __init__(self, unet, text_encoder):
        self.unet = unet
@@ -41,6 +42,26 @@ class LoRAModuleWrapper:
        self.LORA_PREFIX_UNET = 'lora_unet'
        self.LORA_PREFIX_TEXT_ENCODER = 'lora_te'
        self.re_digits = re.compile(r"\d+")
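        # Patterns for the different LoRA key layouts handled by convert_key_to_diffusers below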
        self.re_unet_transformer_attn_blocks = re.compile(
            r"lora_unet_(.+)_blocks_(\d+)_attentions_(\d+)_transformer_blocks_(\d+)_attn(\d+)_(.+).(weight|alpha)"
        )
        self.re_unet_mid_blocks = re.compile(
            r"lora_unet_mid_block_attentions_(\d+)_(.+).(weight|alpha)"
        )
        self.re_unet_transformer_blocks = re.compile(
            r"lora_unet_(.+)_blocks_(\d+)_attentions_(\d+)_transformer_blocks_(\d+)_(.+).(weight|alpha)"
        )
        self.re_unet_mid_transformer_blocks = re.compile(
            r"lora_unet_mid_block_attentions_(\d+)_transformer_blocks_(\d+)_(.+).(weight|alpha)"
        )
        self.re_unet_norm_blocks = re.compile(
            r"lora_unet_(.+)_blocks_(\d+)_attentions_(\d+)_(.+).(weight|alpha)"
        )
        self.re_out = re.compile(r"to_out_(\d+)")
        self.re_processor_weight = re.compile(r"(.+)_(\d+)_(.+)")
        self.re_processor_alpha = re.compile(r"(.+)_(\d+)")

        def find_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> dict[str, torch.nn.Module]:
            mapping = {}
            for name, module in root_module.named_modules():
@@ -68,6 +89,71 @@ class LoRAModuleWrapper:
            self.UNET_TARGET_REPLACE_MODULE
        )
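    # Map a LoRA state-dict key with the "lora_unet" prefix onto the dotted
    # module path used by diffusers (e.g. "down_blocks.0.attentions.1....").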
    def convert_key_to_diffusers(self, key):
        def match(match_list, regex, subject):
            r = re.match(regex, subject)
            if not r:
                return False

            match_list.clear()
            match_list.extend([int(x) if re.match(self.re_digits, x) else x for x in r.groups()])
            return True

        m = []

        def get_front_block(first, second, third, fourth=None):
            if first == "mid":
                b_type = f"mid_block"
            else:
                b_type = f"{first}_blocks.{second}"

            if fourth is None:
                return f"{b_type}.attentions.{third}"
            return f"{b_type}.attentions.{third}.transformer_blocks.{fourth}"

        def get_back_block(first, second, third):
            second = second.replace(".lora_", "_lora.")

            if third == "weight":
                bm = []
                if match(bm, self.re_processor_weight, second):
                    s_bm = bm[2].split('.')
                    s_front = f"{bm[0]}_{s_bm[0]}"
                    s_back = f"{s_bm[1]}"
                    if int(bm[1]) == 0:
                        second = f"{s_front}.{s_back}"
                    else:
                        second = f"{s_front}.{bm[1]}.{s_back}"
            elif third == "alpha":
                bma = []
                if match(bma, self.re_processor_alpha, second):
                    if int(bma[1]) == 0:
                        second = f"{bma[0]}"
                    else:
                        second = f"{bma[0]}.{bma[1]}"

            if first is None:
                return f"processor.{second}.{third}"
            return f"attn{first}.processor.{second}.{third}"

        if match(m, self.re_unet_transformer_attn_blocks, key):
            return f"{get_front_block(m[0], m[1], m[2], m[3])}.{get_back_block(m[4], m[5], m[6])}"

        if match(m, self.re_unet_transformer_blocks, key):
            return f"{get_front_block(m[0], m[1], m[2], m[3])}.{get_back_block(None, m[4], m[5])}"

        if match(m, self.re_unet_mid_transformer_blocks, key):
            return f"{get_front_block('mid', None, m[0], m[1])}.{get_back_block(None, m[2], m[3])}"

        if match(m, self.re_unet_norm_blocks, key):
            return f"{get_front_block(m[0], m[1], m[2])}.{get_back_block(None, m[3], m[4])}"

        if match(m, self.re_unet_mid_blocks, key):
            return f"{get_front_block('mid', None, m[0])}.{get_back_block(None, m[1], m[2])}"

        return key
    def lora_forward_hook(self, name):
        wrapper = self
@@ -187,93 +273,6 @@ class LoRA:
        return


re_digits = re.compile(r"\d+")
re_unet_transformer_attn_blocks = re.compile(
    r"lora_unet_(.+)_blocks_(\d+)_attentions_(\d+)_transformer_blocks_(\d+)_attn(\d+)_(.+).(weight|alpha)"
)
re_unet_mid_blocks = re.compile(
    r"lora_unet_mid_block_attentions_(\d+)_(.+).(weight|alpha)"
)
re_unet_transformer_blocks = re.compile(
    r"lora_unet_(.+)_blocks_(\d+)_attentions_(\d+)_transformer_blocks_(\d+)_(.+).(weight|alpha)"
)
re_unet_mid_transformer_blocks = re.compile(
    r"lora_unet_mid_block_attentions_(\d+)_transformer_blocks_(\d+)_(.+).(weight|alpha)"
)
re_unet_norm_blocks = re.compile(
    r"lora_unet_(.+)_blocks_(\d+)_attentions_(\d+)_(.+).(weight|alpha)"
)
re_out = re.compile(r"to_out_(\d+)")
re_processor_weight = re.compile(r"(.+)_(\d+)_(.+)")
re_processor_alpha = re.compile(r"(.+)_(\d+)")
def convert_key_to_diffusers(key):
    def match(match_list, regex, subject):
        r = re.match(regex, subject)
        if not r:
            return False

        match_list.clear()
        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
        return True

    m = []

    def get_front_block(first, second, third, fourth=None):
        if first == "mid":
            b_type = f"mid_block"
        else:
            b_type = f"{first}_blocks.{second}"

        if fourth is None:
            return f"{b_type}.attentions.{third}"
        return f"{b_type}.attentions.{third}.transformer_blocks.{fourth}"

    def get_back_block(first, second, third):
        second = second.replace(".lora_", "_lora.")

        if third == "weight":
            bm = []
            if match(bm, re_processor_weight, second):
                s_bm = bm[2].split('.')
                s_front = f"{bm[0]}_{s_bm[0]}"
                s_back = f"{s_bm[1]}"
                if int(bm[1]) == 0:
                    second = f"{s_front}.{s_back}"
                else:
                    second = f"{s_front}.{bm[1]}.{s_back}"
        elif third == "alpha":
            bma = []
            if match(bma, re_processor_alpha, second):
                if int(bma[1]) == 0:
                    second = f"{bma[0]}"
                else:
                    second = f"{bma[0]}.{bma[1]}"

        if first is None:
            return f"processor.{second}.{third}"
        return f"attn{first}.processor.{second}.{third}"

    if match(m, re_unet_transformer_attn_blocks, key):
        return f"{get_front_block(m[0], m[1], m[2], m[3])}.{get_back_block(m[4], m[5], m[6])}"

    if match(m, re_unet_transformer_blocks, key):
        return f"{get_front_block(m[0], m[1], m[2], m[3])}.{get_back_block(None, m[4], m[5])}"

    if match(m, re_unet_mid_transformer_blocks, key):
        return f"{get_front_block('mid', None, m[0], m[1])}.{get_back_block(None, m[2], m[3])}"

    if match(m, re_unet_norm_blocks, key):
        return f"{get_front_block(m[0], m[1], m[2])}.{get_back_block(None, m[3], m[4])}"

    if match(m, re_unet_mid_blocks, key):
        return f"{get_front_block('mid', None, m[0])}.{get_back_block(None, m[1], m[2])}"

    return key
def load_lora_attn(
    name: str,
    path_file: Path,
@@ -289,10 +288,10 @@ def load_lora_attn(
    for key in list(checkpoint.keys()):
        if key.startswith(wrapper.LORA_PREFIX_UNET):
            # convert unet keys
            checkpoint[convert_key_to_diffusers(key)] = checkpoint.pop(key)
            checkpoint[wrapper.convert_key_to_diffusers(key)] = checkpoint.pop(key)
        elif key.startswith(wrapper.LORA_PREFIX_TEXT_ENCODER):
            # convert text encoder keys (not yet supported)
            # state_dict[convert_key_to_diffusers(key)] = state_dict.pop(key)
            # state_dict[wrapper.convert_key_to_diffusers(key)] = state_dict.pop(key)
            checkpoint.pop(key)
        else:
            # remove invalid key
@@ -304,7 +303,6 @@ def load_lora_attn(
class LoraManager:
    loras_to_load: dict[str, float]
    hooks: list[RemovableHandle]

    def __init__(self, pipe):
        self.lora_path = Path(global_models_dir(), 'lora')
@@ -356,7 +354,6 @@ class LoraManager:
        self.wrapper.applied_loras[name] = lora

    def load_lora(self):
        print(self.loras_to_load)
        for name, multiplier in self.loras_to_load.items():
            self.apply_lora_model(name, multiplier)
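
Illustrative usage (not part of this commit): a minimal sketch of the relocated converter, assuming the file's existing imports and a loaded diffusers pipeline `pipe`; the sample key and expected output below are only an example of the mapping the regexes above implement.

    # Hypothetical sketch: `pipe` is assumed to be an already-constructed StableDiffusionPipeline.
    wrapper = LoRAModuleWrapper(pipe.unet, pipe.text_encoder)

    # A kohya-style UNet LoRA key is rewritten to the dotted diffusers attention-processor path.
    key = "lora_unet_down_blocks_0_attentions_1_transformer_blocks_0_attn1_to_q.lora_down.weight"
    print(wrapper.convert_key_to_diffusers(key))
    # -> down_blocks.0.attentions.1.transformer_blocks.0.attn1.processor.to_q_lora.down.weight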