From 82e4d5aed2ece8ce696a013014c187c910ef7432 Mon Sep 17 00:00:00 2001
From: Jordan
Date: Sun, 19 Feb 2023 17:33:24 -0700
Subject: [PATCH] change to a new method for loading safetensors LoRAs

---
 ldm/modules/lora_manager.py | 90 ++++++++++++++++++++++++++++---------
 scripts/convert_lora.py     | 86 -----------------------------------
 2 files changed, 70 insertions(+), 106 deletions(-)
 delete mode 100644 scripts/convert_lora.py

diff --git a/ldm/modules/lora_manager.py b/ldm/modules/lora_manager.py
index 844c2a554f..008864a4b0 100644
--- a/ldm/modules/lora_manager.py
+++ b/ldm/modules/lora_manager.py
@@ -1,7 +1,68 @@
 import re
 from pathlib import Path
 from ldm.invoke.globals import global_models_dir
-from lora_diffusion import tune_lora_scale, patch_pipe
+import torch
+from safetensors.torch import load_file
+
+# modified from the conversion script at https://github.com/huggingface/diffusers/pull/2403
+def merge_lora_into_pipe(pipeline, checkpoint_path, alpha):
+    # load the LoRA weights from the .safetensors checkpoint
+    state_dict = load_file(checkpoint_path, device=torch.cuda.current_device())
+
+    visited = []
+
+    # merge each LoRA up/down pair directly into the diffusers model weights
+    for key in state_dict:
+
+        # keys typically look like:
+        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
+
+        # skip stored alpha entries (the merge scale comes from the alpha argument) and keys already handled
+        if ".alpha" in key or key in visited:
+            continue
+        if "text" in key:
+            layer_infos = key.split(".")[0].split("lora_te" + "_")[-1].split("_")
+            curr_layer = pipeline.text_encoder
+        else:
+            layer_infos = key.split(".")[0].split("lora_unet" + "_")[-1].split("_")
+            curr_layer = pipeline.unet
+
+        # find the target layer by walking the underscore-separated attribute names
+        temp_name = layer_infos.pop(0)
+        while len(layer_infos) > -1:
+            try:
+                curr_layer = curr_layer.__getattr__(temp_name)
+                if len(layer_infos) > 0:
+                    temp_name = layer_infos.pop(0)
+                elif len(layer_infos) == 0:
+                    break
+            except Exception:
+                if len(temp_name) > 0:
+                    temp_name += "_" + layer_infos.pop(0)
+                else:
+                    temp_name = layer_infos.pop(0)
+
+        pair_keys = []
+        if "lora_down" in key:
+            pair_keys.append(key.replace("lora_down", "lora_up"))
+            pair_keys.append(key)
+        else:
+            pair_keys.append(key)
+            pair_keys.append(key.replace("lora_up", "lora_down"))
+
+        # merge the update: W += alpha * (up @ down), restoring the conv dims where needed
+        if len(state_dict[pair_keys[0]].shape) == 4:
+            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
+            weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
+            curr_layer.weight.data += float(alpha) * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+        else:
+            weight_up = state_dict[pair_keys[0]].to(torch.float32)
+            weight_down = state_dict[pair_keys[1]].to(torch.float32)
+            curr_layer.weight.data += float(alpha) * torch.mm(weight_up, weight_down)
+
+        # mark both halves of the up/down pair as processed
+        for item in pair_keys:
+            visited.append(item)
 
 
 class LoraManager:
@@ -16,32 +77,21 @@ class LoraManager:
 
 
     def apply_lora_model(self, args):
         args = args.split(':')
         name = args[0]
+
         path = Path(self.lora_path, name)
         file = Path(path, "pytorch_lora_weights.bin")
 
         if path.is_dir() and file.is_file():
-            print(f"loading lora: {path}")
+            print(f"loading diffusers lora: {path}")
             self.pipe.unet.load_attn_procs(path.absolute().as_posix())
-            if len(args) == 2:
-                self.weights[name] = args[1]
         else:
-            # converting and saving in diffusers format
-            path_file = Path(self.lora_path, f'{name}.ckpt')
-            if Path(self.lora_path, f'{name}.safetensors').exists():
-                path_file = Path(self.lora_path, f'{name}.safetensors')
+            file = Path(self.lora_path, f"{name}.safetensors")
print(f"loading lora: {file}") + alpha = 1 + if len(args) == 2: + alpha = args[1] - if path_file.is_file(): - print(f"loading lora: {path}") - patch_pipe( - self.pipe, - path_file.absolute().as_posix(), - patch_text=True, - patch_ti=True, - patch_unet=True, - ) - if len(args) == 2: - tune_lora_scale(self.pipe.unet, args[1]) - tune_lora_scale(self.pipe.text_encoder, args[1]) + merge_lora_into_pipe(self.pipe, file.absolute().as_posix(), alpha) def load_lora_from_prompt(self, prompt: str): diff --git a/scripts/convert_lora.py b/scripts/convert_lora.py deleted file mode 100644 index b30890c756..0000000000 --- a/scripts/convert_lora.py +++ /dev/null @@ -1,86 +0,0 @@ -#!/usr/bin/env python - -import re -from pathlib import Path -import torch -from safetensors.torch import load_file -import argparse -from diffusers import UNet2DConditionModel -from diffusers.pipelines.stable_diffusion.convert_from_ckpt import create_unet_diffusers_config -from omegaconf import OmegaConf -import requests - - -def parse_args(input_args=None): - parser = argparse.ArgumentParser(description="Convert kohya lora to diffusers") - parser.add_argument( - "--lora_file", - type=str, - default=None, - required=True, - help="Lora file to convert", - ) - parser.add_argument( - "--output_dir", - type=str, - default="models/lora", - help="The output directory where converted lora will be saved", - ) - - if input_args is not None: - args = parser.parse_args(input_args) - else: - args = parser.parse_args() - - return args - - -def replace_key_blocks(match_obj): - k = match_obj.groups() - - return f"{k[0]}.{k[1]}" - - -def replace_key_out(match_obj): - return f"to_out" - - -def replace_key_main(match_obj): - k = match_obj.groups() - block = re.sub(r"(.+)_(\d+)", replace_key_blocks, k[0]) - out = re.sub(r"to_out_(\d+)", replace_key_out, k[4]) - - return f"{block}.attentions.{k[1]}.transformer_blocks.{k[2]}.attn{k[3]}.processor.{out}_lora.{k[5]}" - - -def main(args): - response = requests.get( - "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" - ) - original_config = OmegaConf.create(response.text) - - new_dict = dict() - lora_file = Path(args.lora_file) - - if lora_file.suffix == '.safetensors': - checkpoint = load_file(args.lora_file) - else: - checkpoint = torch.load(args.lora_file) - - for idx, key in enumerate(checkpoint): - check = re.compile(r"lora_unet_(.+)_attentions_(\d+)_transformer_blocks_(\d+)_attn(\d+)_(.+).lora_(.+)") - if check.match(key): - new_key = check.sub(replace_key_main, key) - new_dict[new_key] = checkpoint[key] - - unet_config = create_unet_diffusers_config(original_config, image_size=512) - unet = UNet2DConditionModel(**unet_config) - unet.load_attn_procs(new_dict) - - output_dir = Path(args.output_dir, lora_file.name.split('.')[0]) - unet.save_attn_procs(output_dir.absolute().as_posix()) - - -if __name__ == "__main__": - args = parse_args() - main(args)