mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-01 03:01:13 -04:00
change to new method to load safetensors
This commit is contained in:
@@ -1,7 +1,68 @@
|
|||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from ldm.invoke.globals import global_models_dir
|
from ldm.invoke.globals import global_models_dir
|
||||||
from lora_diffusion import tune_lora_scale, patch_pipe
|
import torch
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
# modified from script at https://github.com/huggingface/diffusers/pull/2403
|
||||||
|
def merge_lora_into_pipe(pipeline, checkpoint_path, alpha):
|
||||||
|
# load LoRA weight from .safetensors
|
||||||
|
state_dict = load_file(checkpoint_path, device=torch.cuda.current_device())
|
||||||
|
|
||||||
|
visited = []
|
||||||
|
|
||||||
|
# directly update weight in diffusers model
|
||||||
|
for key in state_dict:
|
||||||
|
|
||||||
|
# it is suggested to print out the key, it usually will be something like below
|
||||||
|
# "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
|
||||||
|
|
||||||
|
# as we have set the alpha beforehand, so just skip
|
||||||
|
if ".alpha" in key or key in visited:
|
||||||
|
continue
|
||||||
|
if "text" in key:
|
||||||
|
layer_infos = key.split(".")[0].split("lora_te" + "_")[-1].split("_")
|
||||||
|
curr_layer = pipeline.text_encoder
|
||||||
|
else:
|
||||||
|
layer_infos = key.split(".")[0].split("lora_unet" + "_")[-1].split("_")
|
||||||
|
curr_layer = pipeline.unet
|
||||||
|
|
||||||
|
# find the target layer
|
||||||
|
temp_name = layer_infos.pop(0)
|
||||||
|
while len(layer_infos) > -1:
|
||||||
|
try:
|
||||||
|
curr_layer = curr_layer.__getattr__(temp_name)
|
||||||
|
if len(layer_infos) > 0:
|
||||||
|
temp_name = layer_infos.pop(0)
|
||||||
|
elif len(layer_infos) == 0:
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
if len(temp_name) > 0:
|
||||||
|
temp_name += "_" + layer_infos.pop(0)
|
||||||
|
else:
|
||||||
|
temp_name = layer_infos.pop(0)
|
||||||
|
|
||||||
|
pair_keys = []
|
||||||
|
if "lora_down" in key:
|
||||||
|
pair_keys.append(key.replace("lora_down", "lora_up"))
|
||||||
|
pair_keys.append(key)
|
||||||
|
else:
|
||||||
|
pair_keys.append(key)
|
||||||
|
pair_keys.append(key.replace("lora_up", "lora_down"))
|
||||||
|
|
||||||
|
# update weight
|
||||||
|
if len(state_dict[pair_keys[0]].shape) == 4:
|
||||||
|
weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
|
||||||
|
weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
|
||||||
|
curr_layer.weight.data += float(alpha) * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||||
|
else:
|
||||||
|
weight_up = state_dict[pair_keys[0]].to(torch.float32)
|
||||||
|
weight_down = state_dict[pair_keys[1]].to(torch.float32)
|
||||||
|
curr_layer.weight.data += float(alpha) * torch.mm(weight_up, weight_down)
|
||||||
|
|
||||||
|
# update visited list
|
||||||
|
for item in pair_keys:
|
||||||
|
visited.append(item)
|
||||||
|
|
||||||
|
|
||||||
class LoraManager:
|
class LoraManager:
|
||||||
@@ -16,32 +77,21 @@ class LoraManager:
|
|||||||
def apply_lora_model(self, args):
|
def apply_lora_model(self, args):
|
||||||
args = args.split(':')
|
args = args.split(':')
|
||||||
name = args[0]
|
name = args[0]
|
||||||
|
|
||||||
path = Path(self.lora_path, name)
|
path = Path(self.lora_path, name)
|
||||||
file = Path(path, "pytorch_lora_weights.bin")
|
file = Path(path, "pytorch_lora_weights.bin")
|
||||||
|
|
||||||
if path.is_dir() and file.is_file():
|
if path.is_dir() and file.is_file():
|
||||||
print(f"loading lora: {path}")
|
print(f"loading diffusers lora: {path}")
|
||||||
self.pipe.unet.load_attn_procs(path.absolute().as_posix())
|
self.pipe.unet.load_attn_procs(path.absolute().as_posix())
|
||||||
if len(args) == 2:
|
|
||||||
self.weights[name] = args[1]
|
|
||||||
else:
|
else:
|
||||||
# converting and saving in diffusers format
|
file = Path(self.lora_path, f"{name}.safetensors")
|
||||||
path_file = Path(self.lora_path, f'{name}.ckpt')
|
print(f"loading lora: {file}")
|
||||||
if Path(self.lora_path, f'{name}.safetensors').exists():
|
alpha = 1
|
||||||
path_file = Path(self.lora_path, f'{name}.safetensors')
|
if len(args) == 2:
|
||||||
|
alpha = args[1]
|
||||||
|
|
||||||
if path_file.is_file():
|
merge_lora_into_pipe(self.pipe, file.absolute().as_posix(), alpha)
|
||||||
print(f"loading lora: {path}")
|
|
||||||
patch_pipe(
|
|
||||||
self.pipe,
|
|
||||||
path_file.absolute().as_posix(),
|
|
||||||
patch_text=True,
|
|
||||||
patch_ti=True,
|
|
||||||
patch_unet=True,
|
|
||||||
)
|
|
||||||
if len(args) == 2:
|
|
||||||
tune_lora_scale(self.pipe.unet, args[1])
|
|
||||||
tune_lora_scale(self.pipe.text_encoder, args[1])
|
|
||||||
|
|
||||||
def load_lora_from_prompt(self, prompt: str):
|
def load_lora_from_prompt(self, prompt: str):
|
||||||
|
|
||||||
|
|||||||
@@ -1,86 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
import re
|
|
||||||
from pathlib import Path
|
|
||||||
import torch
|
|
||||||
from safetensors.torch import load_file
|
|
||||||
import argparse
|
|
||||||
from diffusers import UNet2DConditionModel
|
|
||||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import create_unet_diffusers_config
|
|
||||||
from omegaconf import OmegaConf
|
|
||||||
import requests
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args(input_args=None):
|
|
||||||
parser = argparse.ArgumentParser(description="Convert kohya lora to diffusers")
|
|
||||||
parser.add_argument(
|
|
||||||
"--lora_file",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
required=True,
|
|
||||||
help="Lora file to convert",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output_dir",
|
|
||||||
type=str,
|
|
||||||
default="models/lora",
|
|
||||||
help="The output directory where converted lora will be saved",
|
|
||||||
)
|
|
||||||
|
|
||||||
if input_args is not None:
|
|
||||||
args = parser.parse_args(input_args)
|
|
||||||
else:
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
return args
|
|
||||||
|
|
||||||
|
|
||||||
def replace_key_blocks(match_obj):
|
|
||||||
k = match_obj.groups()
|
|
||||||
|
|
||||||
return f"{k[0]}.{k[1]}"
|
|
||||||
|
|
||||||
|
|
||||||
def replace_key_out(match_obj):
|
|
||||||
return f"to_out"
|
|
||||||
|
|
||||||
|
|
||||||
def replace_key_main(match_obj):
|
|
||||||
k = match_obj.groups()
|
|
||||||
block = re.sub(r"(.+)_(\d+)", replace_key_blocks, k[0])
|
|
||||||
out = re.sub(r"to_out_(\d+)", replace_key_out, k[4])
|
|
||||||
|
|
||||||
return f"{block}.attentions.{k[1]}.transformer_blocks.{k[2]}.attn{k[3]}.processor.{out}_lora.{k[5]}"
|
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
|
||||||
response = requests.get(
|
|
||||||
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
|
||||||
)
|
|
||||||
original_config = OmegaConf.create(response.text)
|
|
||||||
|
|
||||||
new_dict = dict()
|
|
||||||
lora_file = Path(args.lora_file)
|
|
||||||
|
|
||||||
if lora_file.suffix == '.safetensors':
|
|
||||||
checkpoint = load_file(args.lora_file)
|
|
||||||
else:
|
|
||||||
checkpoint = torch.load(args.lora_file)
|
|
||||||
|
|
||||||
for idx, key in enumerate(checkpoint):
|
|
||||||
check = re.compile(r"lora_unet_(.+)_attentions_(\d+)_transformer_blocks_(\d+)_attn(\d+)_(.+).lora_(.+)")
|
|
||||||
if check.match(key):
|
|
||||||
new_key = check.sub(replace_key_main, key)
|
|
||||||
new_dict[new_key] = checkpoint[key]
|
|
||||||
|
|
||||||
unet_config = create_unet_diffusers_config(original_config, image_size=512)
|
|
||||||
unet = UNet2DConditionModel(**unet_config)
|
|
||||||
unet.load_attn_procs(new_dict)
|
|
||||||
|
|
||||||
output_dir = Path(args.output_dir, lora_file.name.split('.')[0])
|
|
||||||
unet.save_attn_procs(output_dir.absolute().as_posix())
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
args = parse_args()
|
|
||||||
main(args)
|
|
||||||
Reference in New Issue
Block a user