Merge branch 'feat/lora-support-2.3' of github.com:invoke-ai/InvokeAI into feat/lora-support-2.3

This commit is contained in:
Lincoln Stein
2023-04-05 22:02:01 -04:00
3 changed files with 186 additions and 75 deletions

View File

@@ -22,6 +22,7 @@ import transformers
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.utils.import_utils import is_xformers_available
from omegaconf import OmegaConf
from pathlib import Path
from PIL import Image, ImageOps
from pytorch_lightning import logging, seed_everything
@@ -992,9 +993,17 @@ class Generate:
self.model_name = model_name
self._set_sampler() # requires self.model_name to be set first
self._save_last_used_model(model_name)
return self.model
def _save_last_used_model(self, model_name: str):
    """
    Record the name of the most recently used model.

    The name is written to the `.last_model` file located directly under
    `Globals.root`; any previous contents of that file are overwritten.
    """
    last_model_file = Path(Globals.root) / '.last_model'
    with last_model_file.open('w') as handle:
        handle.write(model_name)
# Pass the given list of concept strings to load_huggingface_concepts() on
# the current model's textual inversion manager.
def load_huggingface_concepts(self, concepts: list[str]):
self.model.textual_inversion_manager.load_huggingface_concepts(concepts)

View File

@@ -183,7 +183,6 @@ def main():
# web server loops forever
if opt.web or opt.gui:
invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan)
save_last_used_model(gen.model_name)
sys.exit(0)
if not infile:
@@ -504,7 +503,6 @@ def main_loop(gen, opt, completer):
print(
f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}'
)
save_last_used_model(gen.model_name)
# TO DO: remove repetitive code and the awkward command.replace() trope
@@ -1301,14 +1299,6 @@ def retrieve_last_used_model()->str:
with open(model_file_path,'r') as f:
return f.readline()
def save_last_used_model(model_name: str):
    """
    Record the name of the most recently used model.

    The name is written to the `.last_model` file located directly under
    `Globals.root`; any previous contents of that file are overwritten.
    """
    last_model_file = Path(Globals.root) / '.last_model'
    with last_model_file.open('w') as handle:
        handle.write(model_name)
# This routine performs any patch-ups needed after installation
def run_patches():
install_missing_config_files()

View File

@@ -1,5 +1,6 @@
import re
from pathlib import Path
from typing import Optional
import torch
from compel import Compel
@@ -20,7 +21,9 @@ class LoRALayer:
lora_name: str
name: str
scale: float
up: torch.nn.Module
mid: Optional[torch.nn.Module] = None
down: torch.nn.Module
def __init__(self, lora_name: str, name: str, rank=4, alpha=1.0):
@@ -28,6 +31,70 @@ class LoRALayer:
self.name = name
self.scale = alpha / rank if (alpha and rank) else 1.0
def forward(self, lora, input_h, output):
if self.mid is None:
output = (
output
+ self.up(self.down(*input_h)) * lora.multiplier * self.scale
)
else:
output = (
output
+ self.up(self.mid(self.down(*input_h))) * lora.multiplier * self.scale
)
return output
class LoHALayer:
    """
    One LoHA layer attached to a wrapped Linear or Conv2d module.

    The effective weight is rebuilt on every forward call as the
    element-wise product of two low-rank factorizations. The tensors and
    `org_module` are not set by `__init__`; the loader is expected to
    assign them before `forward` is called.
    """

    lora_name: str
    name: str
    scale: float

    # Low-rank factors; the weight is (w1_a @ w1_b) * (w2_a @ w2_b) when no
    # t1/t2 tensors are present.
    w1_a: torch.Tensor
    w1_b: torch.Tensor
    w2_a: torch.Tensor
    w2_b: torch.Tensor
    # Optional 4-D tensors combined with the factors via einsum.
    t1: Optional[torch.Tensor] = None
    t2: Optional[torch.Tensor] = None
    # Optional tensor added to the rebuilt weight before it is applied.
    bias: Optional[torch.Tensor] = None

    # The wrapped module; supplies the target weight shape and, for
    # convolutions, the stride/padding/dilation/groups settings.
    org_module: torch.nn.Module

    def __init__(self, lora_name: str, name: str, rank=4, alpha=1.0):
        """
        Store the identifying names and derive the scale factor.

        `scale` is `alpha / rank` when both are truthy, otherwise 1.0.
        """
        self.lora_name = lora_name
        self.name = name
        self.scale = alpha / rank if (alpha and rank) else 1.0

    def forward(self, lora, input_h, output):
        """
        Return `output` plus this layer's scaled contribution.

        `input_h` is unpacked as the leading positional arguments of the
        functional op (conv2d for Conv2d modules, linear otherwise). The
        op result is multiplied by `lora.multiplier` and `self.scale`.
        """
        # isinstance() so that Conv2d subclasses take the convolution path
        # instead of falling through to linear with a 4-D weight.
        if isinstance(self.org_module, torch.nn.Conv2d):
            op = torch.nn.functional.conv2d
            extra_args = {
                "stride": self.org_module.stride,
                "padding": self.org_module.padding,
                "dilation": self.org_module.dilation,
                "groups": self.org_module.groups,
            }
        else:
            op = torch.nn.functional.linear
            extra_args = {}

        if self.t1 is None:
            weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
        else:
            rebuild1 = torch.einsum('i j k l, j r, i p -> p r k l', self.t1, self.w1_b, self.w1_a)
            rebuild2 = torch.einsum('i j k l, j r, i p -> p r k l', self.t2, self.w2_b, self.w2_a)
            weight = rebuild1 * rebuild2

        bias = self.bias if self.bias is not None else 0
        # reshape() rather than view(): it yields the same result but does
        # not raise when the rebuilt weight is non-contiguous.
        delta_weight = (weight + bias).reshape(self.org_module.weight.shape)
        return output + op(
            *input_h,
            delta_weight,
            None,
            **extra_args,
        ) * lora.multiplier * self.scale
class LoRAModuleWrapper:
unet: UNet2DConditionModel
@@ -44,8 +111,8 @@ class LoRAModuleWrapper:
self.applied_loras = {}
self.loaded_loras = {}
self.UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
self.TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
self.UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention", "ResnetBlock2D", "Downsample2D", "Upsample2D", "SpatialTransformer"]
self.TEXT_ENCODER_TARGET_REPLACE_MODULE = ["ResidualAttentionBlock", "CLIPAttention", "CLIPMLP"]
self.LORA_PREFIX_UNET = "lora_unet"
self.LORA_PREFIX_TEXT_ENCODER = "lora_te"
@@ -60,7 +127,7 @@ class LoRAModuleWrapper:
layer_type = child_module.__class__.__name__
if layer_type == "Linear" or (
layer_type == "Conv2d"
and child_module.kernel_size == (1, 1)
and child_module.kernel_size in [(1, 1), (3, 3)]
):
lora_name = prefix + "." + name + "." + child_name
lora_name = lora_name.replace(".", "_")
@@ -92,10 +159,7 @@ class LoRAModuleWrapper:
layer = lora.layers.get(name, None)
if layer is None:
continue
output = (
output
+ layer.up(layer.down(*input_h)) * lora.multiplier * layer.scale
)
output = layer.forward(lora, input_h, output)
return output
return lora_forward
@@ -131,78 +195,126 @@ class LoRA:
self.device = device
self.dtype = dtype
self.wrapper = wrapper
self.rank = None
self.alpha = None
def load_from_dict(self, state_dict):
state_dict_groupped = dict()
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_groupped:
state_dict_groupped[stem] = dict()
state_dict_groupped[stem][leaf] = value
if leaf.endswith("alpha"):
if self.alpha is None:
self.alpha = value.item()
continue
for stem, values in state_dict_groupped.items():
if stem.startswith(self.wrapper.LORA_PREFIX_TEXT_ENCODER):
wrapped = self.wrapper.text_modules.get(stem, None)
if wrapped is None:
print(f">> Missing layer: {stem}")
continue
if (
self.rank is None
and leaf == "lora_down.weight"
and len(value.size()) == 2
):
self.rank = value.shape[0]
self.load_lora_layer(stem, leaf, value, wrapped)
continue
elif stem.startswith(self.wrapper.LORA_PREFIX_UNET):
wrapped = self.wrapper.unet_modules.get(stem, None)
if wrapped is None:
print(f">> Missing layer: {stem}")
continue
if (
self.rank is None
and leaf == "lora_down.weight"
and len(value.size()) == 2
):
self.rank = value.shape[0]
self.load_lora_layer(stem, leaf, value, wrapped)
continue
else:
continue
def load_lora_layer(self, stem: str, leaf: str, value, wrapped: torch.nn.Module):
layer = self.layers.get(stem, None)
if layer is None:
layer = LoRALayer(self.name, stem, self.rank, self.alpha)
if wrapped is None:
print(f">> Missing layer: {stem}")
continue
# TODO: diff key
bias = None
alpha = None
if "alpha" in values:
alpha = values["alpha"].item()
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
bias = torch.sparse_coo_tensor(
values["bias_indices"],
values["bias_values"],
tuple(values["bias_size"]),
).to(device=self.device, dtype=self.dtype)
# lora and locon
if "lora_down.weight" in values:
value_down = values["lora_down.weight"]
value_mid = values.get("lora_mid.weight", None)
value_up = values["lora_up.weight"]
if type(wrapped) == torch.nn.Conv2d:
if value_mid is not None:
layer_down = torch.nn.Conv2d(value_down.shape[1], value_down.shape[0], (1, 1), bias=False)
layer_mid = torch.nn.Conv2d(value_mid.shape[1], value_mid.shape[0], wrapped.kernel_size, wrapped.stride, wrapped.padding, bias=False)
else:
layer_down = torch.nn.Conv2d(value_down.shape[1], value_down.shape[0], wrapped.kernel_size, wrapped.stride, wrapped.padding, bias=False)
layer_mid = None
layer_up = torch.nn.Conv2d(value_up.shape[1], value_up.shape[0], (1, 1), bias=False)
elif type(wrapped) == torch.nn.Linear:
layer_down = torch.nn.Linear(value_down.shape[1], value_down.shape[0], bias=False)
layer_mid = None
layer_up = torch.nn.Linear(value_up.shape[1], value_up.shape[0], bias=False)
else:
print(
f">> Encountered unknown lora layer module in {self.name}: {stem} - {type(wrapped).__name__}"
)
return
with torch.no_grad():
layer_down.weight.copy_(value_down)
if layer_mid is not None:
layer_mid.weight.copy_(value_mid)
layer_up.weight.copy_(value_up)
layer_down.to(device=self.device, dtype=self.dtype)
if layer_mid is not None:
layer_mid.to(device=self.device, dtype=self.dtype)
layer_up.to(device=self.device, dtype=self.dtype)
rank = value_down.shape[0]
layer = LoRALayer(self.name, stem, rank, alpha)
#layer.bias = bias # TODO: find and debug lora/locon with bias
layer.down = layer_down
layer.mid = layer_mid
layer.up = layer_up
# loha
elif "hada_w1_b" in values:
rank = values["hada_w1_b"].shape[0]
layer = LoHALayer(self.name, stem, rank, alpha)
layer.org_module = wrapped
layer.bias = bias
layer.w1_a = values["hada_w1_a"].to(device=self.device, dtype=self.dtype)
layer.w1_b = values["hada_w1_b"].to(device=self.device, dtype=self.dtype)
layer.w2_a = values["hada_w2_a"].to(device=self.device, dtype=self.dtype)
layer.w2_b = values["hada_w2_b"].to(device=self.device, dtype=self.dtype)
if "hada_t1" in values:
layer.t1 = values["hada_t1"].to(device=self.device, dtype=self.dtype)
else:
layer.t1 = None
if "hada_t2" in values:
layer.t2 = values["hada_t2"].to(device=self.device, dtype=self.dtype)
else:
layer.t2 = None
else:
print(
f">> Encountered unknown lora layer module in {self.name}: {stem} - {type(wrapped).__name__}"
)
return
self.layers[stem] = layer
if type(wrapped) == torch.nn.Linear:
module = torch.nn.Linear(value.shape[1], value.shape[0], bias=False)
elif type(wrapped) == torch.nn.Conv2d:
module = torch.nn.Conv2d(value.shape[1], value.shape[0], (1, 1), bias=False)
else:
print(
f">> Encountered unknown lora layer module in {self.name}: {type(value).__name__}"
)
return
with torch.no_grad():
module.weight.copy_(value)
module.to(device=self.device, dtype=self.dtype)
if leaf == "lora_up.weight":
layer.up = module
elif leaf == "lora_down.weight":
layer.down = module
else:
print(f">> Encountered unknown layer in lora {self.name}: {leaf}")
return
class KohyaLoraManager:
def __init__(self, pipe, lora_path):