start of rewrite for add / remove

This commit is contained in:
Jordan
2023-02-20 02:37:44 -07:00
parent 82e4d5aed2
commit 096e1d3a5d
2 changed files with 191 additions and 48 deletions

View File

@@ -575,7 +575,13 @@ class Generate:
                 save_original = save_original,
                 image_callback = image_callback)
+            if self.model.lora_manager:
+                self.model.lora_manager.reset_lora()
         except KeyboardInterrupt:
+            if self.model.lora_manager:
+                self.model.lora_manager.reset_lora()
             # Clear the CUDA cache on an exception
             self.clear_cuda_cache()
@@ -584,6 +590,9 @@ class Generate:
             else:
                 raise KeyboardInterrupt
         except RuntimeError:
+            if self.model.lora_manager:
+                self.model.lora_manager.reset_lora()
             # Clear the CUDA cache on an exception
             self.clear_cuda_cache()

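The guards added above call reset_lora() in three places so that LoRA patches applied for one prompt never leak into the next generation: once after the generator returns normally, and once in each of the KeyboardInterrupt and RuntimeError handlers before the CUDA cache is cleared. The same guarantee is often expressed once with try/finally; a minimal self-contained sketch of that pattern (the _DummyLoraManager and generate_with_cleanup names are illustrative, not part of this commit):

class _DummyLoraManager:
    def reset_lora(self):
        print("LoRA weights reset")

def generate_with_cleanup(lora_manager, generate_fn):
    # Run the generation callable and always undo LoRA patching afterwards,
    # whether it returns normally or raises (KeyboardInterrupt, RuntimeError, ...).
    try:
        return generate_fn()
    finally:
        if lora_manager:
            lora_manager.reset_lora()

if __name__ == "__main__":
    print(generate_with_cleanup(_DummyLoraManager(), lambda: "image"))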
View File

@@ -3,44 +3,99 @@ from pathlib import Path
 from ldm.invoke.globals import global_models_dir
 import torch
 from safetensors.torch import load_file
-# modified from script at https://github.com/huggingface/diffusers/pull/2403
-def merge_lora_into_pipe(pipeline, checkpoint_path, alpha):
-    # load LoRA weight from .safetensors
-    state_dict = load_file(checkpoint_path, device=torch.cuda.current_device())
-    visited = []
-    # directly update weight in diffusers model
-    for key in state_dict:
-        # it is suggested to print out the key, it usually will be something like below
-        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
-        # as we have set the alpha beforehand, so just skip
-        if ".alpha" in key or key in visited:
-            continue
-        if "text" in key:
-            layer_infos = key.split(".")[0].split("lora_te" + "_")[-1].split("_")
-            curr_layer = pipeline.text_encoder
+from typing import List, Optional, Set, Type
+class LoraLinear(torch.nn.Module):
+    def __init__(
+        self, in_features, out_features, rank=4
+    ):
+        super().__init__()
+        if rank > min(in_features, out_features):
+            raise ValueError(
+                f"LoRA rank {rank} must be less than or equal to {min(in_features, out_features)}"
+            )
+        self.rank = rank
+        self.linear = torch.nn.Linear(in_features, out_features, bias=False)
+        self.lora = torch.nn.Linear(in_features, out_features, bias=False)
+    def forward(self, hidden_states):
+        orig_dtype = hidden_states.dtype
+        dtype = self.lora.weight.dtype
+        return self.lora(hidden_states.to(dtype)).to(orig_dtype)
+class LoraManager:
+    def __init__(self, pipe):
+        self.pipe = pipe
+        self.lora_path = Path(global_models_dir(), 'lora')
+        self.lora_match = re.compile(r"<lora:([^>]+)>")
+        self.prompt = None
+    def _process_lora(self, lora):
+        processed_lora = {
+            "unet": [],
+            "text_encoder": []
+        }
+        visited = []
+        for key in lora:
+            if ".alpha" in key or key in visited:
+                continue
+            if "text" in key:
+                lora_type, pair_keys = self._find_layer(
+                    "text_encoder",
+                    key.split(".")[0].split("lora_te" + "_")[-1].split("_"),
+                    key
+                )
+            else:
+                lora_type, pair_keys = self._find_layer(
+                    "unet",
+                    key.split(".")[0].split("lora_unet" + "_")[-1].split("_"),
+                    key
+                )
+            if len(lora[pair_keys[0]].shape) == 4:
+                weight_up = lora[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
+                weight_down = lora[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
+                weight = torch.mm(weight_up, weight_down)
+            else:
+                weight_up = lora[pair_keys[0]].to(torch.float32)
+                weight_down = lora[pair_keys[1]].to(torch.float32)
+                weight = torch.mm(weight_up, weight_down)
+            processed_lora[lora_type].append({
+                "weight": weight,
+                "rank": lora[pair_keys[1]].shape[0]
+            })
+            for item in pair_keys:
+                visited.append(item)
+        return processed_lora
+    def _find_layer(self, lora_type, layer_key, key):
+        temp_name = layer_key.pop(0)
+        if lora_type == "unet":
+            curr_layer = self.pipe.unet
+        elif lora_type == "text_encoder":
+            curr_layer = self.pipe.text_encoder
         else:
-            layer_infos = key.split(".")[0].split("lora_unet" + "_")[-1].split("_")
-            curr_layer = pipeline.unet
+            raise ValueError("Invalid Lora Type")
         # find the target layer
-        temp_name = layer_infos.pop(0)
-        while len(layer_infos) > -1:
+        while len(layer_key) > -1:
             try:
                 curr_layer = curr_layer.__getattr__(temp_name)
-                if len(layer_infos) > 0:
-                    temp_name = layer_infos.pop(0)
-                elif len(layer_infos) == 0:
+                if len(layer_key) > 0:
+                    temp_name = layer_key.pop(0)
+                elif len(layer_key) == 0:
                     break
             except Exception:
                 if len(temp_name) > 0:
-                    temp_name += "_" + layer_infos.pop(0)
+                    temp_name += "_" + layer_key.pop(0)
                 else:
-                    temp_name = layer_infos.pop(0)
+                    temp_name = layer_key.pop(0)
         pair_keys = []
         if "lora_down" in key:
@@ -50,29 +105,89 @@ def merge_lora_into_pipe(pipeline, checkpoint_path, alpha):
             pair_keys.append(key)
             pair_keys.append(key.replace("lora_up", "lora_down"))
-        # update weight
-        if len(state_dict[pair_keys[0]].shape) == 4:
-            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
-            weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
-            curr_layer.weight.data += float(alpha) * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+        return lora_type, pair_keys
+    @staticmethod
+    def _find_modules(
+        model,
+        ancestor_class: Optional[Set[str]] = None,
+        search_class: List[Type[torch.nn.Module]] = [torch.nn.Linear],
+        exclude_children_of: Optional[List[Type[torch.nn.Module]]] = [LoraLinear],
+    ):
+        """
+        Find all modules of a certain class (or union of classes) that are direct or
+        indirect descendants of other modules of a certain class (or union of classes).
+        Returns all matching modules, along with the parent of those modules and the
+        names they are referenced by.
+        """
+        # Get the targets we should replace all linears under
+        if ancestor_class is not None:
+            ancestors = (
+                module
+                for module in model.modules()
+                if module.__class__.__name__ in ancestor_class
+            )
         else:
-            weight_up = state_dict[pair_keys[0]].to(torch.float32)
-            weight_down = state_dict[pair_keys[1]].to(torch.float32)
-            curr_layer.weight.data += float(alpha) * torch.mm(weight_up, weight_down)
-        # update visited list
-        for item in pair_keys:
-            visited.append(item)
-class LoraManager:
-    def __init__(self, pipe):
-        self.weights = {}
-        self.pipe = pipe
-        self.lora_path = Path(global_models_dir(), 'lora')
-        self.lora_match = re.compile(r"<lora:([^>]+)>")
-        self.prompt = None
+            # this, in case you want to naively iterate over all modules.
+            ancestors = [module for module in model.modules()]
+        # For each target find every Linear module that isn't a child of a LoraLinear
+        for ancestor in ancestors:
+            for fullname, module in ancestor.named_modules():
+                if any([isinstance(module, _class) for _class in search_class]):
+                    # Find the direct parent if this is a descendant, not a child, of target
+                    *path, name = fullname.split(".")
+                    parent = ancestor
+                    while path:
+                        parent = parent.get_submodule(path.pop(0))
+                    # Skip this Linear if it's a child of a LoraLinear
+                    if exclude_children_of and any(
+                        [isinstance(parent, _class) for _class in exclude_children_of]
+                    ):
+                        continue
+                    # Otherwise, yield it
+                    yield parent, name, module
+    @staticmethod
+    def patch_module(lora_type, processed_lora, module, name, child_module, scale: float = 1.0):
+        _source = (
+            child_module.linear
+            if isinstance(child_module, LoraLinear)
+            else child_module
+        )
+        lora = processed_lora[lora_type].pop(0)
+        weight = _source.weight
+        _tmp = LoraLinear(
+            in_features=_source.in_features,
+            out_features=_source.out_features,
+            rank=lora["rank"]
+        )
+        _tmp.linear.weight = weight
+        # switch the module
+        module._modules[name] = _tmp
+        module._modules[name].lora.weight.data = lora["weight"]
+        module._modules[name].to(weight.device)
+    def patch_lora(self, lora_path, scale: float = 1.0):
+        lora = load_file(lora_path)
+        processed_lora = self._process_lora(lora)
+        for module, name, child_module in self._find_modules(
+            self.pipe.unet,
+            {"CrossAttention", "Attention", "GEGLU"},
+            search_class=[torch.nn.Linear, LoraLinear]
+        ):
+            self.patch_module("unet", processed_lora, module, name, child_module, scale)
+        for module, name, child_module in self._find_modules(
+            self.pipe.text_encoder,
+            {"CLIPAttention"},
+            search_class=[torch.nn.Linear, LoraLinear]
+        ):
+            self.patch_module("text_encoder", processed_lora, module, name, child_module, scale)
     def apply_lora_model(self, args):
         args = args.split(':')
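_process_lora turns each lora_up/lora_down pair into a single delta matrix with torch.mm and records the rank from the first dimension of the down weight; patch_lora then swaps every matching torch.nn.Linear under the attention blocks for a LoraLinear carrying that delta. The removed merge_lora_into_pipe added the same product, scaled by alpha, directly into the existing layer weights instead of swapping modules. A short sketch of the shapes involved (the tensor sizes are illustrative):

import torch

# Shapes follow the convention _process_lora relies on for Linear layers:
# "lora_up" is (out_features x rank), "lora_down" is (rank x in_features),
# so their matrix product is a full (out_features x in_features) delta.
in_features, out_features, rank = 320, 320, 4          # illustrative sizes
lora_up = torch.randn(out_features, rank)
lora_down = torch.randn(rank, in_features)

delta = torch.mm(lora_up.to(torch.float32), lora_down.to(torch.float32))
assert delta.shape == (out_features, in_features)
assert lora_down.shape[0] == rank        # the "rank" that _process_lora records

# What the removed merge_lora_into_pipe did with the same product:
base_weight = torch.randn(out_features, in_features)
alpha = 0.8
merged = base_weight + float(alpha) * delta
print(merged.shape)                      # torch.Size([320, 320])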
@@ -87,14 +202,34 @@ class LoraManager:
         else:
             file = Path(self.lora_path, f"{name}.safetensors")
         print(f"loading lora: {file}")
-        alpha = 1
+        scale = 1.0
         if len(args) == 2:
-            alpha = args[1]
+            scale = float(args[1])
-        merge_lora_into_pipe(self.pipe, file.absolute().as_posix(), alpha)
+        self.patch_lora(file.absolute().as_posix(), scale)
+    @staticmethod
+    def remove_lora(child_module):
+        _source = child_module.linear
+        weight = _source.weight
+        _tmp = torch.nn.Linear(_source.in_features, _source.out_features)
+        _tmp.weight = weight
+    def reset_lora(self):
+        for module, name, child_module in self._find_modules(
+            self.pipe.unet,
+            search_class=[LoraLinear]
+        ):
+            self.remove_lora(child_module)
+        for module, name, child_module in self._find_modules(
+            self.pipe.text_encoder,
+            search_class=[LoraLinear]
+        ):
+            self.remove_lora(child_module)
     def load_lora_from_prompt(self, prompt: str):
         for m in re.findall(self.lora_match, prompt):
             self.apply_lora_model(m)
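apply_lora_model receives the text captured by the <lora:([^>]+)> pattern and splits it on ':'; the first field selects the .safetensors file under the lora models directory and an optional second field becomes the scale handed to patch_lora (defaulting to 1.0). A standalone sketch of that prompt-side parsing (parse_lora_tokens and the example prompt are illustrative; the name is assumed to come from the first field, and the file lookup is omitted):

import re

lora_match = re.compile(r"<lora:([^>]+)>")   # same pattern the manager compiles

def parse_lora_tokens(prompt):
    # Returns (name, scale) pairs the way apply_lora_model interprets them:
    # "<lora:foo>" -> ("foo", 1.0), "<lora:foo:0.6>" -> ("foo", 0.6)
    out = []
    for token in re.findall(lora_match, prompt):
        parts = token.split(':')
        name = parts[0]
        scale = float(parts[1]) if len(parts) == 2 else 1.0
        out.append((name, scale))
    return out

print(parse_lora_tokens("a castle <lora:some_style:0.8> at dusk <lora:foo>"))
# [('some_style', 0.8), ('foo', 1.0)]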
@@ -108,4 +243,3 @@ class LoraManager:
return ""
return re.sub(self.lora_match, found, prompt)
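At this stage of the rewrite, LoraLinear.forward returns only the LoRA projection, patch_module does not yet use its scale argument, and remove_lora rebuilds a plain Linear without re-attaching it to the parent module. For comparison, the usual injected-LoRA pattern keeps the frozen base projection, adds a scaled low-rank update on top, and undoes the injection by putting the original Linear back into the parent's _modules. A minimal sketch under illustrative names (InjectedLoraLinear, inject, and remove are not part of this commit):

import torch

class InjectedLoraLinear(torch.nn.Module):
    # keeps the original projection and adds a scaled low-rank update on top
    def __init__(self, base: torch.nn.Linear, rank: int = 4, scale: float = 1.0):
        super().__init__()
        self.linear = base
        self.lora_down = torch.nn.Linear(base.in_features, rank, bias=False)
        self.lora_up = torch.nn.Linear(rank, base.out_features, bias=False)
        torch.nn.init.zeros_(self.lora_up.weight)   # start as a no-op
        self.scale = scale

    def forward(self, x):
        return self.linear(x) + self.scale * self.lora_up(self.lora_down(x))

def inject(parent: torch.nn.Module, name: str, rank: int = 4, scale: float = 1.0):
    base = parent._modules[name]
    parent._modules[name] = InjectedLoraLinear(base, rank=rank, scale=scale)

def remove(parent: torch.nn.Module, name: str):
    # swap the original Linear back into the parent, undoing inject()
    child = parent._modules[name]
    if isinstance(child, InjectedLoraLinear):
        parent._modules[name] = child.linear

if __name__ == "__main__":
    block = torch.nn.Sequential(torch.nn.Linear(16, 16, bias=False))
    x = torch.randn(2, 16)
    before = block(x)
    inject(block, "0")
    assert torch.allclose(block(x), before)   # zero-initialized LoRA is a no-op
    remove(block, "0")
    assert torch.allclose(block(x), before)

With the up matrix zero-initialized, the injection is a no-op until real LoRA weights are loaded, which makes the swap safe to apply before a generation and revert afterwards.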