mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-19 09:54:24 -05:00
move key converter to wrapper
This commit is contained in:
@@ -25,6 +25,7 @@ class LoRALayer:
|
||||
class LoRAModuleWrapper:
|
||||
unet: UNet2DConditionModel
|
||||
text_encoder: CLIPTextModel
|
||||
hooks: list[RemovableHandle]
|
||||
|
||||
def __init__(self, unet, text_encoder):
|
||||
self.unet = unet
|
||||
@@ -41,6 +42,26 @@ class LoRAModuleWrapper:
|
||||
self.LORA_PREFIX_UNET = 'lora_unet'
|
||||
self.LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
||||
|
||||
self.re_digits = re.compile(r"\d+")
|
||||
self.re_unet_transformer_attn_blocks = re.compile(
|
||||
r"lora_unet_(.+)_blocks_(\d+)_attentions_(\d+)_transformer_blocks_(\d+)_attn(\d+)_(.+).(weight|alpha)"
|
||||
)
|
||||
self.re_unet_mid_blocks = re.compile(
|
||||
r"lora_unet_mid_block_attentions_(\d+)_(.+).(weight|alpha)"
|
||||
)
|
||||
self.re_unet_transformer_blocks = re.compile(
|
||||
r"lora_unet_(.+)_blocks_(\d+)_attentions_(\d+)_transformer_blocks_(\d+)_(.+).(weight|alpha)"
|
||||
)
|
||||
self.re_unet_mid_transformer_blocks = re.compile(
|
||||
r"lora_unet_mid_block_attentions_(\d+)_transformer_blocks_(\d+)_(.+).(weight|alpha)"
|
||||
)
|
||||
self.re_unet_norm_blocks = re.compile(
|
||||
r"lora_unet_(.+)_blocks_(\d+)_attentions_(\d+)_(.+).(weight|alpha)"
|
||||
)
|
||||
self.re_out = re.compile(r"to_out_(\d+)")
|
||||
self.re_processor_weight = re.compile(r"(.+)_(\d+)_(.+)")
|
||||
self.re_processor_alpha = re.compile(r"(.+)_(\d+)")
|
||||
|
||||
def find_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> dict[str, torch.nn.Module]:
|
||||
mapping = {}
|
||||
for name, module in root_module.named_modules():
|
||||
@@ -68,6 +89,71 @@ class LoRAModuleWrapper:
|
||||
self.UNET_TARGET_REPLACE_MODULE
|
||||
)
|
||||
|
||||
def convert_key_to_diffusers(self, key):
|
||||
def match(match_list, regex, subject):
|
||||
r = re.match(regex, subject)
|
||||
if not r:
|
||||
return False
|
||||
|
||||
match_list.clear()
|
||||
match_list.extend([int(x) if re.match(self.re_digits, x) else x for x in r.groups()])
|
||||
return True
|
||||
|
||||
m = []
|
||||
|
||||
def get_front_block(first, second, third, fourth=None):
|
||||
if first == "mid":
|
||||
b_type = f"mid_block"
|
||||
else:
|
||||
b_type = f"{first}_blocks.{second}"
|
||||
|
||||
if fourth is None:
|
||||
return f"{b_type}.attentions.{third}"
|
||||
|
||||
return f"{b_type}.attentions.{third}.transformer_blocks.{fourth}"
|
||||
|
||||
def get_back_block(first, second, third):
|
||||
second = second.replace(".lora_", "_lora.")
|
||||
if third == "weight":
|
||||
bm = []
|
||||
if match(bm, self.re_processor_weight, second):
|
||||
s_bm = bm[2].split('.')
|
||||
s_front = f"{bm[0]}_{s_bm[0]}"
|
||||
s_back = f"{s_bm[1]}"
|
||||
if int(bm[1]) == 0:
|
||||
second = f"{s_front}.{s_back}"
|
||||
else:
|
||||
second = f"{s_front}.{bm[1]}.{s_back}"
|
||||
elif third == "alpha":
|
||||
bma = []
|
||||
if match(bma, self.re_processor_alpha, second):
|
||||
if int(bma[1]) == 0:
|
||||
second = f"{bma[0]}"
|
||||
else:
|
||||
second = f"{bma[0]}.{bma[1]}"
|
||||
|
||||
if first is None:
|
||||
return f"processor.{second}.{third}"
|
||||
|
||||
return f"attn{first}.processor.{second}.{third}"
|
||||
|
||||
if match(m, self.re_unet_transformer_attn_blocks, key):
|
||||
return f"{get_front_block(m[0], m[1], m[2], m[3])}.{get_back_block(m[4], m[5], m[6])}"
|
||||
|
||||
if match(m, self.re_unet_transformer_blocks, key):
|
||||
return f"{get_front_block(m[0], m[1], m[2], m[3])}.{get_back_block(None, m[4], m[5])}"
|
||||
|
||||
if match(m, self.re_unet_mid_transformer_blocks, key):
|
||||
return f"{get_front_block('mid', None, m[0], m[1])}.{get_back_block(None, m[2], m[3])}"
|
||||
|
||||
if match(m, self.re_unet_norm_blocks, key):
|
||||
return f"{get_front_block(m[0], m[1], m[2])}.{get_back_block(None, m[3], m[4])}"
|
||||
|
||||
if match(m, self.re_unet_mid_blocks, key):
|
||||
return f"{get_front_block('mid', None, m[0])}.{get_back_block(None, m[1], m[2])}"
|
||||
|
||||
return key
|
||||
|
||||
def lora_forward_hook(self, name):
|
||||
wrapper = self
|
||||
|
||||
@@ -187,93 +273,6 @@ class LoRA:
|
||||
return
|
||||
|
||||
|
||||
re_digits = re.compile(r"\d+")
|
||||
re_unet_transformer_attn_blocks = re.compile(
|
||||
r"lora_unet_(.+)_blocks_(\d+)_attentions_(\d+)_transformer_blocks_(\d+)_attn(\d+)_(.+).(weight|alpha)"
|
||||
)
|
||||
re_unet_mid_blocks = re.compile(
|
||||
r"lora_unet_mid_block_attentions_(\d+)_(.+).(weight|alpha)"
|
||||
)
|
||||
re_unet_transformer_blocks = re.compile(
|
||||
r"lora_unet_(.+)_blocks_(\d+)_attentions_(\d+)_transformer_blocks_(\d+)_(.+).(weight|alpha)"
|
||||
)
|
||||
re_unet_mid_transformer_blocks = re.compile(
|
||||
r"lora_unet_mid_block_attentions_(\d+)_transformer_blocks_(\d+)_(.+).(weight|alpha)"
|
||||
)
|
||||
re_unet_norm_blocks = re.compile(
|
||||
r"lora_unet_(.+)_blocks_(\d+)_attentions_(\d+)_(.+).(weight|alpha)"
|
||||
)
|
||||
re_out = re.compile(r"to_out_(\d+)")
|
||||
re_processor_weight = re.compile(r"(.+)_(\d+)_(.+)")
|
||||
re_processor_alpha = re.compile(r"(.+)_(\d+)")
|
||||
|
||||
|
||||
def convert_key_to_diffusers(key):
|
||||
def match(match_list, regex, subject):
|
||||
r = re.match(regex, subject)
|
||||
if not r:
|
||||
return False
|
||||
|
||||
match_list.clear()
|
||||
match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
|
||||
return True
|
||||
|
||||
m = []
|
||||
|
||||
def get_front_block(first, second, third, fourth=None):
|
||||
if first == "mid":
|
||||
b_type = f"mid_block"
|
||||
else:
|
||||
b_type = f"{first}_blocks.{second}"
|
||||
|
||||
if fourth is None:
|
||||
return f"{b_type}.attentions.{third}"
|
||||
|
||||
return f"{b_type}.attentions.{third}.transformer_blocks.{fourth}"
|
||||
|
||||
def get_back_block(first, second, third):
|
||||
second = second.replace(".lora_", "_lora.")
|
||||
if third == "weight":
|
||||
bm = []
|
||||
if match(bm, re_processor_weight, second):
|
||||
s_bm = bm[2].split('.')
|
||||
s_front = f"{bm[0]}_{s_bm[0]}"
|
||||
s_back = f"{s_bm[1]}"
|
||||
if int(bm[1]) == 0:
|
||||
second = f"{s_front}.{s_back}"
|
||||
else:
|
||||
second = f"{s_front}.{bm[1]}.{s_back}"
|
||||
elif third == "alpha":
|
||||
bma = []
|
||||
if match(bma, re_processor_alpha, second):
|
||||
if int(bma[1]) == 0:
|
||||
second = f"{bma[0]}"
|
||||
else:
|
||||
second = f"{bma[0]}.{bma[1]}"
|
||||
|
||||
if first is None:
|
||||
return f"processor.{second}.{third}"
|
||||
|
||||
return f"attn{first}.processor.{second}.{third}"
|
||||
|
||||
if match(m, re_unet_transformer_attn_blocks, key):
|
||||
return f"{get_front_block(m[0], m[1], m[2], m[3])}.{get_back_block(m[4], m[5], m[6])}"
|
||||
|
||||
if match(m, re_unet_transformer_blocks, key):
|
||||
return f"{get_front_block(m[0], m[1], m[2], m[3])}.{get_back_block(None, m[4], m[5])}"
|
||||
|
||||
if match(m, re_unet_mid_transformer_blocks, key):
|
||||
return f"{get_front_block('mid', None, m[0], m[1])}.{get_back_block(None, m[2], m[3])}"
|
||||
|
||||
if match(m, re_unet_norm_blocks, key):
|
||||
return f"{get_front_block(m[0], m[1], m[2])}.{get_back_block(None, m[3], m[4])}"
|
||||
|
||||
if match(m, re_unet_mid_blocks, key):
|
||||
return f"{get_front_block('mid', None, m[0])}.{get_back_block(None, m[1], m[2])}"
|
||||
|
||||
return key
|
||||
|
||||
|
||||
def load_lora_attn(
|
||||
name: str,
|
||||
path_file: Path,
|
||||
@@ -289,10 +288,10 @@ def load_lora_attn(
|
||||
for key in list(checkpoint.keys()):
|
||||
if key.startswith(wrapper.LORA_PREFIX_UNET):
|
||||
# convert unet keys
|
||||
checkpoint[convert_key_to_diffusers(key)] = checkpoint.pop(key)
|
||||
checkpoint[wrapper.convert_key_to_diffusers(key)] = checkpoint.pop(key)
|
||||
elif key.startswith(wrapper.LORA_PREFIX_UNET):
|
||||
# convert text encoder keys (not yet supported)
|
||||
# state_dict[convert_key_to_diffusers(key)] = state_dict.pop(key)
|
||||
# state_dict[wrapper.convert_key_to_diffusers(key)] = state_dict.pop(key)
|
||||
checkpoint.pop(key)
|
||||
else:
|
||||
# remove invalid key
|
||||
@@ -304,7 +303,6 @@ def load_lora_attn(
|
||||
|
||||
class LoraManager:
|
||||
loras_to_load: dict[str, float]
|
||||
hooks: list[RemovableHandle]
|
||||
|
||||
def __init__(self, pipe):
|
||||
self.lora_path = Path(global_models_dir(), 'lora')
|
||||
@@ -356,7 +354,6 @@ class LoraManager:
|
||||
self.wrapper.applied_loras[name] = lora
|
||||
|
||||
def load_lora(self):
|
||||
print(self.loras_to_load)
|
||||
for name, multiplier in self.loras_to_load.items():
|
||||
self.apply_lora_model(name, multiplier)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user