Change to new method for loading safetensors

Jordan
2023-02-19 17:33:24 -07:00
parent 5a7145c485
commit 82e4d5aed2
2 changed files with 70 additions and 106 deletions


@@ -1,7 +1,68 @@
 import re
 from pathlib import Path
 from ldm.invoke.globals import global_models_dir
-from lora_diffusion import tune_lora_scale, patch_pipe
+import torch
+from safetensors.torch import load_file
+
+
+# modified from script at https://github.com/huggingface/diffusers/pull/2403
+def merge_lora_into_pipe(pipeline, checkpoint_path, alpha):
+    # load LoRA weight from .safetensors
+    state_dict = load_file(checkpoint_path, device=torch.cuda.current_device())
+
+    visited = []
+
+    # directly update weight in diffusers model
+    for key in state_dict:
+        # it is suggested to print out the key, it usually will be something like below
+        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
+
+        # as we have set the alpha beforehand, so just skip
+        if ".alpha" in key or key in visited:
+            continue
+
+        if "text" in key:
+            layer_infos = key.split(".")[0].split("lora_te" + "_")[-1].split("_")
+            curr_layer = pipeline.text_encoder
+        else:
+            layer_infos = key.split(".")[0].split("lora_unet" + "_")[-1].split("_")
+            curr_layer = pipeline.unet
+
+        # find the target layer
+        temp_name = layer_infos.pop(0)
+        while len(layer_infos) > -1:
+            try:
+                curr_layer = curr_layer.__getattr__(temp_name)
+                if len(layer_infos) > 0:
+                    temp_name = layer_infos.pop(0)
+                elif len(layer_infos) == 0:
+                    break
+            except Exception:
+                if len(temp_name) > 0:
+                    temp_name += "_" + layer_infos.pop(0)
+                else:
+                    temp_name = layer_infos.pop(0)
+
+        pair_keys = []
+        if "lora_down" in key:
+            pair_keys.append(key.replace("lora_down", "lora_up"))
+            pair_keys.append(key)
+        else:
+            pair_keys.append(key)
+            pair_keys.append(key.replace("lora_up", "lora_down"))
+
+        # update weight
+        if len(state_dict[pair_keys[0]].shape) == 4:
+            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
+            weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
+            curr_layer.weight.data += float(alpha) * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+        else:
+            weight_up = state_dict[pair_keys[0]].to(torch.float32)
+            weight_down = state_dict[pair_keys[1]].to(torch.float32)
+            curr_layer.weight.data += float(alpha) * torch.mm(weight_up, weight_down)
+
+        # update visited list
+        for item in pair_keys:
+            visited.append(item)
+
+
 class LoraManager:
@@ -16,32 +77,21 @@ class LoraManager:
     def apply_lora_model(self, args):
         args = args.split(':')
         name = args[0]
         path = Path(self.lora_path, name)
         file = Path(path, "pytorch_lora_weights.bin")
         if path.is_dir() and file.is_file():
-            print(f"loading lora: {path}")
+            print(f"loading diffusers lora: {path}")
             self.pipe.unet.load_attn_procs(path.absolute().as_posix())
             if len(args) == 2:
                 self.weights[name] = args[1]
         else:
-            # converting and saving in diffusers format
-            path_file = Path(self.lora_path, f'{name}.ckpt')
-            if Path(self.lora_path, f'{name}.safetensors').exists():
-                path_file = Path(self.lora_path, f'{name}.safetensors')
+            file = Path(self.lora_path, f"{name}.safetensors")
+            print(f"loading lora: {file}")
+            alpha = 1
+            if len(args) == 2:
+                alpha = args[1]
-            if path_file.is_file():
-                print(f"loading lora: {path}")
-                patch_pipe(
-                    self.pipe,
-                    path_file.absolute().as_posix(),
-                    patch_text=True,
-                    patch_ti=True,
-                    patch_unet=True,
-                )
-                if len(args) == 2:
-                    tune_lora_scale(self.pipe.unet, args[1])
-                    tune_lora_scale(self.pipe.text_encoder, args[1])
+            merge_lora_into_pipe(self.pipe, file.absolute().as_posix(), alpha)
 
     def load_lora_from_prompt(self, prompt: str):
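
For context on the merge rule this commit introduces, here is a minimal, self-contained sketch of the update performed for a single LoRA up/down pair. The rank, feature sizes, and alpha value are hypothetical, and the tensors are random stand-ins for what would come out of a real .safetensors checkpoint:

```python
import torch

# Hypothetical LoRA pair for one linear layer: rank 4, 320 output / 320 input features.
# In a real checkpoint these would be the "...lora_up.weight" and "...lora_down.weight"
# tensors sharing the same key prefix.
weight_up = torch.randn(320, 4)
weight_down = torch.randn(4, 320)
alpha = 0.75  # scale supplied by the caller, e.g. parsed from the "name:weight" argument

# Base weight of the target layer (zeros here, purely for illustration).
base_weight = torch.zeros(320, 320)

# Same update rule as merge_lora_into_pipe uses for the 2-D (linear) case:
# W <- W + alpha * (up @ down)
base_weight += float(alpha) * torch.mm(weight_up, weight_down)

# For 4-D conv weights the diff squeezes the trailing 1x1 dims before the matmul
# and unsqueezes them back afterwards.
print(base_weight.shape)  # torch.Size([320, 320])
```

Because the delta is added directly into `weight.data`, the LoRA is merged into the pipeline's weights in place rather than attached as a separate module, which is why the new code no longer needs `patch_pipe` or `tune_lora_scale` from `lora_diffusion`.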