change to new method for loading safetensors

commit 82e4d5aed2
parent 5a7145c485
Author: Jordan
Date:   2023-02-19 17:33:24 -07:00
2 changed files with 70 additions and 106 deletions


@@ -1,7 +1,68 @@
 import re
 from pathlib import Path
 from ldm.invoke.globals import global_models_dir
-from lora_diffusion import tune_lora_scale, patch_pipe
+import torch
+from safetensors.torch import load_file
+
+
+# modified from script at https://github.com/huggingface/diffusers/pull/2403
+def merge_lora_into_pipe(pipeline, checkpoint_path, alpha):
+    # load LoRA weights from the .safetensors file
+    state_dict = load_file(checkpoint_path, device=torch.cuda.current_device())
+
+    visited = []
+
+    # directly update the weights in the diffusers model
+    for key in state_dict:
+        # it is suggested to print out the key; it will usually look like:
+        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
+
+        # as alpha has been set beforehand, just skip those entries
+        if ".alpha" in key or key in visited:
+            continue
+
+        if "text" in key:
+            layer_infos = key.split(".")[0].split("lora_te" + "_")[-1].split("_")
+            curr_layer = pipeline.text_encoder
+        else:
+            layer_infos = key.split(".")[0].split("lora_unet" + "_")[-1].split("_")
+            curr_layer = pipeline.unet
+
+        # find the target layer (len() is never negative, so this loop
+        # only exits through the break below)
+        temp_name = layer_infos.pop(0)
+        while len(layer_infos) > -1:
+            try:
+                curr_layer = curr_layer.__getattr__(temp_name)
+                if len(layer_infos) > 0:
+                    temp_name = layer_infos.pop(0)
+                elif len(layer_infos) == 0:
+                    break
+            except Exception:
+                if len(temp_name) > 0:
+                    temp_name += "_" + layer_infos.pop(0)
+                else:
+                    temp_name = layer_infos.pop(0)
+
+        # collect the (lora_up, lora_down) pair for this key
+        pair_keys = []
+        if "lora_down" in key:
+            pair_keys.append(key.replace("lora_down", "lora_up"))
+            pair_keys.append(key)
+        else:
+            pair_keys.append(key)
+            pair_keys.append(key.replace("lora_up", "lora_down"))
+
+        # update the weight
+        if len(state_dict[pair_keys[0]].shape) == 4:
+            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
+            weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
+            curr_layer.weight.data += float(alpha) * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+        else:
+            weight_up = state_dict[pair_keys[0]].to(torch.float32)
+            weight_down = state_dict[pair_keys[1]].to(torch.float32)
+            curr_layer.weight.data += float(alpha) * torch.mm(weight_up, weight_down)
+
+        # update the visited list
+        for item in pair_keys:
+            visited.append(item)
+
+
 class LoraManager:
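
The getattr walk above re-joins underscore-separated key segments until they resolve to an actual submodule (a segment such as "down_blocks" arrives split in two). The weight update itself is the standard LoRA merge, W += alpha * (up @ down). Below is a minimal standalone sketch of that update, not part of the commit; the shapes are illustrative only:

import torch

out_features, in_features, rank, alpha = 320, 768, 4, 1.0
weight = torch.zeros(out_features, in_features)        # stands in for curr_layer.weight.data
lora_up = torch.randn(out_features, rank)              # "lora_up" half of the key pair
lora_down = torch.randn(rank, in_features)             # "lora_down" half of the key pair
weight += float(alpha) * torch.mm(lora_up, lora_down)  # the same update the loop applies

For 4-dimensional (conv) weights, the function squeezes the pair down to 2D, performs the same matrix product, and unsqueezes the result back before adding it.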
@@ -16,32 +77,21 @@ class LoraManager:
     def apply_lora_model(self, args):
         args = args.split(':')
         name = args[0]
         path = Path(self.lora_path, name)
         file = Path(path, "pytorch_lora_weights.bin")
         if path.is_dir() and file.is_file():
-            print(f"loading lora: {path}")
+            print(f"loading diffusers lora: {path}")
             self.pipe.unet.load_attn_procs(path.absolute().as_posix())
+            if len(args) == 2:
+                self.weights[name] = args[1]
         else:
-            # converting and saving in diffusers format
-            path_file = Path(self.lora_path, f'{name}.ckpt')
-            if Path(self.lora_path, f'{name}.safetensors').exists():
-                path_file = Path(self.lora_path, f'{name}.safetensors')
-
-            if path_file.is_file():
-                print(f"loading lora: {path}")
-                patch_pipe(
-                    self.pipe,
-                    path_file.absolute().as_posix(),
-                    patch_text=True,
-                    patch_ti=True,
-                    patch_unet=True,
-                )
-
-                if len(args) == 2:
-                    tune_lora_scale(self.pipe.unet, args[1])
-                    tune_lora_scale(self.pipe.text_encoder, args[1])
+            file = Path(self.lora_path, f"{name}.safetensors")
+            print(f"loading lora: {file}")
+            alpha = 1
+            if len(args) == 2:
+                alpha = args[1]
+            merge_lora_into_pipe(self.pipe, file.absolute().as_posix(), alpha)
 
     def load_lora_from_prompt(self, prompt: str):
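
With this change, apply_lora_model takes a "name:alpha" string: it first looks for a diffusers-format pytorch_lora_weights.bin under a directory named after the lora, and otherwise merges <name>.safetensors directly into the pipeline weights at the given alpha (defaulting to 1). A hypothetical call, where `manager` and "mylora" are illustrative names only:

# assumes a constructed LoraManager and an existing models/lora/mylora.safetensors
manager.apply_lora_model("mylora:0.8")  # merges the lora into the pipeline at alpha 0.8

The second changed file, shown next, deletes the standalone kohya-to-diffusers conversion script that this merge-on-load approach replaces.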


@@ -1,86 +0,0 @@
-#!/usr/bin/env python
-
-import re
-from pathlib import Path
-
-import torch
-from safetensors.torch import load_file
-import argparse
-
-from diffusers import UNet2DConditionModel
-from diffusers.pipelines.stable_diffusion.convert_from_ckpt import create_unet_diffusers_config
-from omegaconf import OmegaConf
-import requests
-
-
-def parse_args(input_args=None):
-    parser = argparse.ArgumentParser(description="Convert kohya lora to diffusers")
-    parser.add_argument(
-        "--lora_file",
-        type=str,
-        default=None,
-        required=True,
-        help="Lora file to convert",
-    )
-    parser.add_argument(
-        "--output_dir",
-        type=str,
-        default="models/lora",
-        help="The output directory where converted lora will be saved",
-    )
-
-    if input_args is not None:
-        args = parser.parse_args(input_args)
-    else:
-        args = parser.parse_args()
-
-    return args
-
-
-def replace_key_blocks(match_obj):
-    k = match_obj.groups()
-    return f"{k[0]}.{k[1]}"
-
-
-def replace_key_out(match_obj):
-    return f"to_out"
-
-
-def replace_key_main(match_obj):
-    k = match_obj.groups()
-    block = re.sub(r"(.+)_(\d+)", replace_key_blocks, k[0])
-    out = re.sub(r"to_out_(\d+)", replace_key_out, k[4])
-    return f"{block}.attentions.{k[1]}.transformer_blocks.{k[2]}.attn{k[3]}.processor.{out}_lora.{k[5]}"
-
-
-def main(args):
-    response = requests.get(
-        "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
-    )
-    original_config = OmegaConf.create(response.text)
-
-    new_dict = dict()
-
-    lora_file = Path(args.lora_file)
-    if lora_file.suffix == '.safetensors':
-        checkpoint = load_file(args.lora_file)
-    else:
-        checkpoint = torch.load(args.lora_file)
-
-    for idx, key in enumerate(checkpoint):
-        check = re.compile(r"lora_unet_(.+)_attentions_(\d+)_transformer_blocks_(\d+)_attn(\d+)_(.+).lora_(.+)")
-        if check.match(key):
-            new_key = check.sub(replace_key_main, key)
-            new_dict[new_key] = checkpoint[key]
-
-    unet_config = create_unet_diffusers_config(original_config, image_size=512)
-    unet = UNet2DConditionModel(**unet_config)
-    unet.load_attn_procs(new_dict)
-
-    output_dir = Path(args.output_dir, lora_file.name.split('.')[0])
-    unet.save_attn_procs(output_dir.absolute().as_posix())
-
-
-if __name__ == "__main__":
-    args = parse_args()
-    main(args)
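
For reference, the deleted script rewrote kohya-style key names into diffusers attention-processor names via the regexes above. A worked example of that mapping (traced by hand from the patterns, not captured script output):

# kohya key:
#   lora_unet_down_blocks_0_attentions_1_transformer_blocks_0_attn1_to_q.lora_down.weight
# diffusers attn-procs key produced by replace_key_main:
#   down_blocks.0.attentions.1.transformer_blocks.0.attn1.processor.to_q_lora.down.weight
# ("to_out_0" segments collapse to "to_out" via replace_key_out)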