mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-02-19 09:54:24 -05:00)

Commit: start of rewrite for add / remove
@@ -575,7 +575,13 @@ class Generate:
                 save_original = save_original,
                 image_callback = image_callback)
 
+            if self.model.lora_manager:
+                self.model.lora_manager.reset_lora()
+
         except KeyboardInterrupt:
+            if self.model.lora_manager:
+                self.model.lora_manager.reset_lora()
+
             # Clear the CUDA cache on an exception
             self.clear_cuda_cache()
@@ -584,6 +590,9 @@ class Generate:
             else:
                 raise KeyboardInterrupt
         except RuntimeError:
+            if self.model.lora_manager:
+                self.model.lora_manager.reset_lora()
+
             # Clear the CUDA cache on an exception
             self.clear_cuda_cache()
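Both hunks above follow the same pattern: whether generation finishes normally or is cut short by a KeyboardInterrupt or RuntimeError, any LoRA-patched layers are reset before the CUDA cache is cleared. A minimal sketch of the same cleanup guarantee written with try/finally; the wrapper name and call signature are illustrative, not part of the commit:

    def generate_with_lora_cleanup(model, generate_fn, **kwargs):
        try:
            return generate_fn(**kwargs)
        finally:
            # Restore the original nn.Linear layers even when generation is
            # interrupted, mirroring the except blocks in the diff above.
            if model.lora_manager:
                model.lora_manager.reset_lora()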
@@ -3,44 +3,99 @@ from pathlib import Path
 from ldm.invoke.globals import global_models_dir
 import torch
 from safetensors.torch import load_file
+from typing import List, Optional, Set, Type
 
 
-# modified from script at https://github.com/huggingface/diffusers/pull/2403
-def merge_lora_into_pipe(pipeline, checkpoint_path, alpha):
-    # load LoRA weight from .safetensors
-    state_dict = load_file(checkpoint_path, device=torch.cuda.current_device())
-
-    visited = []
-
-    # directly update weight in diffusers model
-    for key in state_dict:
-        # it is suggested to print out the key, it usually will be something like below
-        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
-
-        # as we have set the alpha beforehand, so just skip
-        if ".alpha" in key or key in visited:
-            continue
-
-        if "text" in key:
-            layer_infos = key.split(".")[0].split("lora_te" + "_")[-1].split("_")
-            curr_layer = pipeline.text_encoder
+class LoraLinear(torch.nn.Module):
+    def __init__(
+        self, in_features, out_features, rank=4
+    ):
+        super().__init__()
+
+        if rank > min(in_features, out_features):
+            raise ValueError(
+                f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}"
+            )
+        self.rank = rank
+        self.linear = torch.nn.Linear(in_features, out_features, bias=False)
+        self.lora = torch.nn.Linear(in_features, out_features, bias=False)
+
+    def forward(self, hidden_states):
+        orig_dtype = hidden_states.dtype
+        dtype = self.lora.weight.dtype
+        return self.lora(hidden_states.to(dtype)).to(orig_dtype)
+
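In this first pass, LoraLinear.forward returns only the low-rank term; the original projection is preserved in self.linear but is not added back into the output. For comparison, a minimal sketch of the combined behaviour a LoRA-injected linear layer usually has, base output plus a scaled low-rank delta. CombinedLoraLinear and its scale attribute are illustrative assumptions, not part of the commit:

    import torch

    class CombinedLoraLinear(torch.nn.Module):
        def __init__(self, in_features, out_features, scale=1.0):
            super().__init__()
            self.linear = torch.nn.Linear(in_features, out_features, bias=False)
            # Holds the precomputed delta (lora_up @ lora_down), as built below
            # by LoraManager._process_lora.
            self.lora = torch.nn.Linear(in_features, out_features, bias=False)
            self.scale = scale

        def forward(self, hidden_states):
            orig_dtype = hidden_states.dtype
            dtype = self.lora.weight.dtype
            delta = self.lora(hidden_states.to(dtype)).to(orig_dtype)
            return self.linear(hidden_states) + self.scale * delta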
+class LoraManager:
+
+    def __init__(self, pipe):
+        self.pipe = pipe
+        self.lora_path = Path(global_models_dir(), 'lora')
+        self.lora_match = re.compile(r"<lora:([^>]+)>")
+        self.prompt = None
+
+    def _process_lora(self, lora):
+        processed_lora = {
+            "unet": [],
+            "text_encoder": []
+        }
+        visited = []
+        for key in lora:
+            if ".alpha" in key or key in visited:
+                continue
+            if "text" in key:
+                lora_type, pair_keys = self._find_layer(
+                    "text_encoder",
+                    key.split(".")[0].split("lora_te" + "_")[-1].split("_"),
+                    key
+                )
+            else:
+                lora_type, pair_keys = self._find_layer(
+                    "unet",
+                    key.split(".")[0].split("lora_unet" + "_")[-1].split("_"),
+                    key
+                )
+
+            if len(lora[pair_keys[0]].shape) == 4:
+                weight_up = lora[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
+                weight_down = lora[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
+                weight = torch.mm(weight_up, weight_down)
+            else:
+                weight_up = lora[pair_keys[0]].to(torch.float32)
+                weight_down = lora[pair_keys[1]].to(torch.float32)
+                weight = torch.mm(weight_up, weight_down)
+
+            processed_lora[lora_type].append({
+                "weight": weight,
+                "rank": lora[pair_keys[1]].shape[0]
+            })
+
+            for item in pair_keys:
+                visited.append(item)
+
+        return processed_lora
+
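_process_lora collapses each up/down pair into a single full-size delta matrix before any module is touched. Keys in a LoRA .safetensors file typically come in pairs ending in .lora_up.weight and .lora_down.weight, with lora_up of shape (out_features, rank) and lora_down of shape (rank, in_features). A small self-contained sketch of the same shape arithmetic; the dimensions are made up for illustration:

    import torch

    rank, in_features, out_features = 4, 320, 320
    lora_up = torch.randn(out_features, rank)     # value of the ".lora_up.weight" key
    lora_down = torch.randn(rank, in_features)    # value of the ".lora_down.weight" key

    # Same operation as _process_lora: pre-multiply the pair into one delta matrix.
    delta = torch.mm(lora_up.to(torch.float32), lora_down.to(torch.float32))
    assert delta.shape == (out_features, in_features)

    # The stored rank is read from the second pair key, i.e. lora[pair_keys[1]].shape[0].
    assert lora_down.shape[0] == rank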
+    def _find_layer(self, lora_type, layer_key, key):
+        temp_name = layer_key.pop(0)
+        if lora_type == "unet":
+            curr_layer = self.pipe.unet
+        elif lora_type == "text_encoder":
+            curr_layer = self.pipe.text_encoder
         else:
-            layer_infos = key.split(".")[0].split("lora_unet" + "_")[-1].split("_")
-            curr_layer = pipeline.unet
+            raise ValueError("Invalid Lora Type")
 
-        # find the target layer
-        temp_name = layer_infos.pop(0)
-        while len(layer_infos) > -1:
+        while len(layer_key) > -1:
             try:
                 curr_layer = curr_layer.__getattr__(temp_name)
-                if len(layer_infos) > 0:
-                    temp_name = layer_infos.pop(0)
-                elif len(layer_infos) == 0:
+                if len(layer_key) > 0:
+                    temp_name = layer_key.pop(0)
+                elif len(layer_key) == 0:
                     break
             except Exception:
                 if len(temp_name) > 0:
-                    temp_name += "_" + layer_infos.pop(0)
+                    temp_name += "_" + layer_key.pop(0)
                 else:
-                    temp_name = layer_infos.pop(0)
+                    temp_name = layer_key.pop(0)
 
         pair_keys = []
         if "lora_down" in key:
@@ -50,29 +105,89 @@ def merge_lora_into_pipe(pipeline, checkpoint_path, alpha):
             pair_keys.append(key)
             pair_keys.append(key.replace("lora_up", "lora_down"))
 
-        # update weight
-        if len(state_dict[pair_keys[0]].shape) == 4:
-            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
-            weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
-            curr_layer.weight.data += float(alpha) * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+        return lora_type, pair_keys
 
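Both the removed merge path and the new _find_layer resolve the target submodule by walking the underscore-separated tokens of the key, re-joining tokens whenever an attribute lookup fails (down_blocks, for example, is a single attribute even though it splits into two tokens). A rough, self-contained illustration of that walk; _Demo and the token list are made up for the example:

    import torch

    class _Demo(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.down_blocks = torch.nn.ModuleList([torch.nn.Linear(4, 4)])

    # a key like "lora_unet_down_blocks_0..." strips down to tokens like these
    tokens = ["down", "blocks", "0"]
    current = _Demo()
    name = tokens.pop(0)
    while tokens or name:
        try:
            current = getattr(current, name)          # "down" fails, "down_blocks" works
            name = tokens.pop(0) if tokens else None
        except AttributeError:
            name = name + "_" + tokens.pop(0)         # re-join split attribute names
    print(type(current).__name__)                     # the Linear held inside down_blocks[0]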
+    @staticmethod
+    def _find_modules(
+        model,
+        ancestor_class: Optional[Set[str]] = None,
+        search_class: List[Type[torch.nn.Module]] = [torch.nn.Linear],
+        exclude_children_of: Optional[List[Type[torch.nn.Module]]] = [LoraLinear],
+    ):
+        """
+        Find all modules of a certain class (or union of classes) that are direct or
+        indirect descendants of other modules of a certain class (or union of classes).
+        Returns all matching modules, along with the parent of those modules and the
+        names they are referenced by.
+        """
+
+        # Get the targets we should replace all linears under
+        if ancestor_class is not None:
+            ancestors = (
+                module
+                for module in model.modules()
+                if module.__class__.__name__ in ancestor_class
+            )
         else:
-            weight_up = state_dict[pair_keys[0]].to(torch.float32)
-            weight_down = state_dict[pair_keys[1]].to(torch.float32)
-            curr_layer.weight.data += float(alpha) * torch.mm(weight_up, weight_down)
+            # this, incase you want to naively iterate over all modules.
+            ancestors = [module for module in model.modules()]
 
-        # update visited list
-        for item in pair_keys:
-            visited.append(item)
+        # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
+        for ancestor in ancestors:
+            for fullname, module in ancestor.named_modules():
+                if any([isinstance(module, _class) for _class in search_class]):
+                    # Find the direct parent if this is a descendant, not a child, of target
+                    *path, name = fullname.split(".")
+                    parent = ancestor
+                    while path:
+                        parent = parent.get_submodule(path.pop(0))
+                    # Skip this linear if it's a child of a LoraInjectedLinear
+                    if exclude_children_of and any(
+                        [isinstance(parent, _class) for _class in exclude_children_of]
+                    ):
+                        continue
+                    # Otherwise, yield it
+                    yield parent, name, module
+
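Because _find_modules is a lazy generator over (parent, attribute name, module) triples, the same helper serves both patch_lora (searching for plain nn.Linear layers under attention blocks) and reset_lora (searching for LoraLinear layers anywhere). A minimal usage sketch; pipe stands for the diffusers-style pipeline object the manager wraps and is not defined by the commit:

    import torch

    candidates = list(
        LoraManager._find_modules(
            pipe.unet,
            ancestor_class={"CrossAttention", "Attention", "GEGLU"},
            search_class=[torch.nn.Linear, LoraLinear],
        )
    )
    print(f"{len(candidates)} patchable linear layers under the UNet attention blocks")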
-class LoraManager:
-
-    def __init__(self, pipe):
-        self.weights = {}
-        self.pipe = pipe
-        self.lora_path = Path(global_models_dir(), 'lora')
-        self.lora_match = re.compile(r"<lora:([^>]+)>")
-        self.prompt = None
+    @staticmethod
+    def patch_module(lora_type, processed_lora, module, name, child_module, scale: float = 1.0):
+        _source = (
+            child_module.linear
+            if isinstance(child_module, LoraLinear)
+            else child_module
+        )
+
+        lora = processed_lora[lora_type].pop(0)
+
+        weight = _source.weight
+        _tmp = LoraLinear(
+            in_features=_source.in_features,
+            out_features=_source.out_features,
+            rank=lora["rank"]
+        )
+        _tmp.linear.weight = weight
+
+        # switch the module
+        module._modules[name] = _tmp
+        module._modules[name].lora.weight.data = lora["weight"]
+        module._modules[name].to(weight.device)
+
+    def patch_lora(self, lora_path, scale: float = 1.0):
+        lora = load_file(lora_path)
+        processed_lora = self._process_lora(lora)
+        for module, name, child_module in self._find_modules(
+            self.pipe.unet,
+            {"CrossAttention", "Attention", "GEGLU"},
+            search_class=[torch.nn.Linear, LoraLinear]
+        ):
+            self.patch_module("unet", processed_lora, module, name, child_module, scale)
+
+        for module, name, child_module in self._find_modules(
+            self.pipe.text_encoder,
+            {"CLIPAttention"},
+            search_class=[torch.nn.Linear, LoraLinear]
+        ):
+            self.patch_module("text_encoder", processed_lora, module, name, child_module, scale)
+
     def apply_lora_model(self, args):
         args = args.split(':')
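patch_lora therefore swaps every matching nn.Linear inside the UNet attention blocks and the text encoder for a LoraLinear carrying the precomputed delta, and reset_lora (added further down) is meant to undo that swap. A minimal end-to-end sketch of driving the manager directly; pipe and the file name are illustrative:

    from pathlib import Path

    manager = LoraManager(pipe)   # pipe: assumed diffusers-style pipeline with .unet / .text_encoder
    lora_file = Path(manager.lora_path, "my_style.safetensors")   # illustrative file name

    manager.patch_lora(lora_file.absolute().as_posix(), scale=0.8)
    # ... run generation with the patched unet / text encoder ...
    manager.reset_lora()          # restore the original layers afterwards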
@@ -87,14 +202,34 @@ class LoraManager:
         else:
             file = Path(self.lora_path, f"{name}.safetensors")
         print(f"loading lora: {file}")
-        alpha = 1
+        scale = 1.0
         if len(args) == 2:
-            alpha = args[1]
+            scale = float(args[1])
 
-        merge_lora_into_pipe(self.pipe, file.absolute().as_posix(), alpha)
+        self.patch_lora(file.absolute().as_posix(), scale)
+
+    @staticmethod
+    def remove_lora(child_module):
+        _source = child_module.linear
+        weight = _source.weight
+
+        _tmp = torch.nn.Linear(_source.in_features, _source.out_features)
+        _tmp.weight = weight
+
+    def reset_lora(self):
+        for module, name, child_module in self._find_modules(
+            self.pipe.unet,
+            search_class=[LoraLinear]
+        ):
+            self.remove_lora(child_module)
+
+        for module, name, child_module in self._find_modules(
+            self.pipe.text_encoder,
+            search_class=[LoraLinear]
+        ):
+            self.remove_lora(child_module)
+
     def load_lora_from_prompt(self, prompt: str):
 
         for m in re.findall(self.lora_match, prompt):
             self.apply_lora_model(m)
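load_lora_from_prompt ties the manager to the prompt syntax: every <lora:name> or <lora:name:scale> tag matched by self.lora_match is passed to apply_lora_model, which resolves name.safetensors under the lora folder of the models directory and patches it in. A short usage sketch; the prompt text and LoRA name are illustrative:

    prompt = "a watercolor landscape <lora:watercolor_v1:0.7>"

    manager = LoraManager(pipe)            # pipe: assumed diffusers-style pipeline
    manager.load_lora_from_prompt(prompt)
    # -> apply_lora_model("watercolor_v1:0.7")
    #    -> patch_lora("<models dir>/lora/watercolor_v1.safetensors", 0.7)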
@@ -108,4 +243,3 @@ class LoraManager:
                 return ""
-
         return re.sub(self.lora_match, found, prompt)
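The method this final hunk belongs to (only partially visible here) strips the <lora:...> tags back out of the prompt with re.sub, so the text handed to the model no longer contains them. On its own, the substitution behaves like this; lora_match is the same pattern compiled in __init__:

    import re

    lora_match = re.compile(r"<lora:([^>]+)>")
    cleaned = re.sub(lora_match, lambda m: "", "a watercolor landscape <lora:watercolor_v1:0.7>")
    print(repr(cleaned))   # 'a watercolor landscape '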