Merge branch 'feat/lora-support-2.3' of github.com:invoke-ai/InvokeAI into feat/lora-support-2.3
@@ -22,6 +22,7 @@ import transformers
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.utils.import_utils import is_xformers_available
from omegaconf import OmegaConf
from pathlib import Path
from PIL import Image, ImageOps
from pytorch_lightning import logging, seed_everything
@@ -992,9 +993,17 @@ class Generate:
        self.model_name = model_name
        self._set_sampler()  # requires self.model_name to be set first

        self._save_last_used_model(model_name)
        return self.model

    def _save_last_used_model(self, model_name: str):
        """
        Save name of the last model used.
        """
        model_file_path = Path(Globals.root, '.last_model')
        with open(model_file_path, 'w') as f:
            f.write(model_name)

    def load_huggingface_concepts(self, concepts: list[str]):
        self.model.textual_inversion_manager.load_huggingface_concepts(concepts)
@@ -183,7 +183,6 @@ def main():
    # web server loops forever
    if opt.web or opt.gui:
        invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan)
        save_last_used_model(gen.model_name)
        sys.exit(0)

    if not infile:
@@ -504,7 +503,6 @@ def main_loop(gen, opt, completer):
    print(
        f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}'
    )
    save_last_used_model(gen.model_name)


# TO DO: remove repetitive code and the awkward command.replace() trope
@@ -1301,14 +1299,6 @@ def retrieve_last_used_model()->str:
    with open(model_file_path, 'r') as f:
        return f.readline()

def save_last_used_model(model_name: str):
    """
    Save name of the last model used.
    """
    model_file_path = Path(Globals.root, '.last_model')
    with open(model_file_path, 'w') as f:
        f.write(model_name)

# This routine performs any patch-ups needed after installation
def run_patches():
    install_missing_config_files()
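save_last_used_model and retrieve_last_used_model form a small persistence pair around a one-line .last_model file in the runtime directory; this commit moves the saving side into Generate. A minimal sketch of the round trip, using a temporary directory as a stand-in for Globals.root and an example model name:

from pathlib import Path
import tempfile

root = Path(tempfile.mkdtemp())              # stand-in for Globals.root
model_file_path = Path(root, '.last_model')

with open(model_file_path, 'w') as f:        # what save_last_used_model() does
    f.write('stable-diffusion-1.5')          # example model name

with open(model_file_path, 'r') as f:        # what retrieve_last_used_model() does
    print(f.readline())                      # -> stable-diffusion-1.5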
@@ -1,5 +1,6 @@
import re
from pathlib import Path
from typing import Optional

import torch
from compel import Compel
@@ -20,7 +21,9 @@ class LoRALayer:
    lora_name: str
    name: str
    scale: float

    up: torch.nn.Module
    mid: Optional[torch.nn.Module] = None
    down: torch.nn.Module

    def __init__(self, lora_name: str, name: str, rank=4, alpha=1.0):
@@ -28,6 +31,70 @@ class LoRALayer:
        self.name = name
        self.scale = alpha / rank if (alpha and rank) else 1.0

    def forward(self, lora, input_h, output):
        if self.mid is None:
            output = (
                output
                + self.up(self.down(*input_h)) * lora.multiplier * self.scale
            )
        else:
            output = (
                output
                + self.up(self.mid(self.down(*input_h))) * lora.multiplier * self.scale
            )
        return output

class LoHALayer:
    lora_name: str
    name: str
    scale: float

    w1_a: torch.Tensor
    w1_b: torch.Tensor
    w2_a: torch.Tensor
    w2_b: torch.Tensor
    t1: Optional[torch.Tensor] = None
    t2: Optional[torch.Tensor] = None
    bias: Optional[torch.Tensor] = None

    org_module: torch.nn.Module

    def __init__(self, lora_name: str, name: str, rank=4, alpha=1.0):
        self.lora_name = lora_name
        self.name = name
        self.scale = alpha / rank if (alpha and rank) else 1.0

    def forward(self, lora, input_h, output):

        if type(self.org_module) == torch.nn.Conv2d:
            op = torch.nn.functional.conv2d
            extra_args = dict(
                stride=self.org_module.stride,
                padding=self.org_module.padding,
                dilation=self.org_module.dilation,
                groups=self.org_module.groups,
            )

        else:
            op = torch.nn.functional.linear
            extra_args = {}

        if self.t1 is None:
            weight = ((self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b))

        else:
            rebuild1 = torch.einsum('i j k l, j r, i p -> p r k l', self.t1, self.w1_b, self.w1_a)
            rebuild2 = torch.einsum('i j k l, j r, i p -> p r k l', self.t2, self.w2_b, self.w2_a)
            weight = rebuild1 * rebuild2

        bias = self.bias if self.bias is not None else 0
        return output + op(
            *input_h,
            (weight + bias).view(self.org_module.weight.shape),
            None,
            **extra_args,
        ) * lora.multiplier * self.scale


class LoRAModuleWrapper:
    unet: UNet2DConditionModel
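LoHALayer.forward rebuilds a full-size weight delta from two low-rank factor pairs and multiplies them elementwise (the Hadamard product that gives LoHa its name), then applies the delta with the wrapped module's own op and scales it. A minimal shape-only sketch of the non-Tucker branch for a Linear layer, with made-up dimensions:

import torch

rank, out_features, in_features = 4, 320, 768   # hypothetical sizes
w1_a = torch.randn(out_features, rank)           # plays the role of hada_w1_a
w1_b = torch.randn(rank, in_features)            # plays the role of hada_w1_b
w2_a = torch.randn(out_features, rank)           # plays the role of hada_w2_a
w2_b = torch.randn(rank, in_features)            # plays the role of hada_w2_b

# Each factor pair expands to a full (out, in) matrix; their elementwise product is the delta.
weight = (w1_a @ w1_b) * (w2_a @ w2_b)
assert weight.shape == (out_features, in_features)

# The delta is applied like the original layer and scaled, mirroring
# output + op(*input_h, weight, None) * lora.multiplier * self.scale above.
x = torch.randn(1, in_features)
multiplier, alpha = 1.0, 1.0                     # example values
scale = alpha / rank
delta = torch.nn.functional.linear(x, weight, None) * multiplier * scale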
@@ -44,8 +111,8 @@ class LoRAModuleWrapper:
        self.applied_loras = {}
        self.loaded_loras = {}

        self.UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
        self.TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
        self.UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention", "ResnetBlock2D", "Downsample2D", "Upsample2D", "SpatialTransformer"]
        self.TEXT_ENCODER_TARGET_REPLACE_MODULE = ["ResidualAttentionBlock", "CLIPAttention", "CLIPMLP"]
        self.LORA_PREFIX_UNET = "lora_unet"
        self.LORA_PREFIX_TEXT_ENCODER = "lora_te"
@@ -60,7 +127,7 @@ class LoRAModuleWrapper:
                layer_type = child_module.__class__.__name__
                if layer_type == "Linear" or (
                    layer_type == "Conv2d"
                    and child_module.kernel_size == (1, 1)
                    and child_module.kernel_size in [(1, 1), (3, 3)]
                ):
                    lora_name = prefix + "." + name + "." + child_name
                    lora_name = lora_name.replace(".", "_")
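The widened target lists and the relaxed kernel-size check above change which modules get wrapped; each match is then given a Kohya-style name built from the prefix plus the dotted module path, with dots flattened to underscores, which is what lets entries in a LoRA checkpoint line up with wrapped modules. A small sketch of that naming, using a hypothetical UNet submodule path:

prefix = "lora_unet"                                               # self.LORA_PREFIX_UNET above
name = "down_blocks.0.attentions.0.transformer_blocks.0.attn1"    # hypothetical parent path
child_name = "to_q"                                                # hypothetical child module

lora_name = (prefix + "." + name + "." + child_name).replace(".", "_")
# -> lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q
# matching checkpoint keys such as
#    lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight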
@@ -92,10 +159,7 @@ class LoRAModuleWrapper:
                layer = lora.layers.get(name, None)
                if layer is None:
                    continue
                output = (
                    output
                    + layer.up(layer.down(*input_h)) * lora.multiplier * layer.scale
                )
                output = layer.forward(lora, input_h, output)
            return output

        return lora_forward
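The hook body now delegates to each layer object instead of hard-coding the up(down(x)) math, so plain LoRA, LoCon and LoHa layers can all contribute through the same code path. A rough sketch of how such a closure could be assembled, with assumed names (orig_forward, module_name, applied_loras) for the pieces not shown in this hunk:

def make_lora_forward(orig_forward, module_name, applied_loras):
    def lora_forward(*input_h):
        output = orig_forward(*input_h)
        for lora in applied_loras.values():
            layer = lora.layers.get(module_name, None)
            if layer is None:
                continue
            # each layer type (LoRALayer, LoHALayer) adds its own delta
            output = layer.forward(lora, input_h, output)
        return output
    return lora_forward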
@@ -131,78 +195,126 @@ class LoRA:
        self.device = device
        self.dtype = dtype
        self.wrapper = wrapper
        self.rank = None
        self.alpha = None

    def load_from_dict(self, state_dict):
        state_dict_groupped = dict()

        for key, value in state_dict.items():
            stem, leaf = key.split(".", 1)
            if stem not in state_dict_groupped:
                state_dict_groupped[stem] = dict()
            state_dict_groupped[stem][leaf] = value

            if leaf.endswith("alpha"):
                if self.alpha is None:
                    self.alpha = value.item()
                continue

        for stem, values in state_dict_groupped.items():
            if stem.startswith(self.wrapper.LORA_PREFIX_TEXT_ENCODER):
                wrapped = self.wrapper.text_modules.get(stem, None)
                if wrapped is None:
                    print(f">> Missing layer: {stem}")
                    continue

                if (
                    self.rank is None
                    and leaf == "lora_down.weight"
                    and len(value.size()) == 2
                ):
                    self.rank = value.shape[0]
                self.load_lora_layer(stem, leaf, value, wrapped)
                continue
            elif stem.startswith(self.wrapper.LORA_PREFIX_UNET):
                wrapped = self.wrapper.unet_modules.get(stem, None)
                if wrapped is None:
                    print(f">> Missing layer: {stem}")
                    continue

                if (
                    self.rank is None
                    and leaf == "lora_down.weight"
                    and len(value.size()) == 2
                ):
                    self.rank = value.shape[0]
                self.load_lora_layer(stem, leaf, value, wrapped)
                continue
            else:
                continue

    def load_lora_layer(self, stem: str, leaf: str, value, wrapped: torch.nn.Module):
        layer = self.layers.get(stem, None)
        if layer is None:
            layer = LoRALayer(self.name, stem, self.rank, self.alpha)
        if wrapped is None:
            print(f">> Missing layer: {stem}")
            continue

        # TODO: diff key

        bias = None
        alpha = None

        if "alpha" in values:
            alpha = values["alpha"].item()

        if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
            bias = torch.sparse_coo_tensor(
                values["bias_indices"],
                values["bias_values"],
                tuple(values["bias_size"]),
            ).to(device=self.device, dtype=self.dtype)


        # lora and locon
        if "lora_down.weight" in values:
            value_down = values["lora_down.weight"]
            value_mid = values.get("lora_mid.weight", None)
            value_up = values["lora_up.weight"]

            if type(wrapped) == torch.nn.Conv2d:
                if value_mid is not None:
                    layer_down = torch.nn.Conv2d(value_down.shape[1], value_down.shape[0], (1, 1), bias=False)
                    layer_mid = torch.nn.Conv2d(value_mid.shape[1], value_mid.shape[0], wrapped.kernel_size, wrapped.stride, wrapped.padding, bias=False)
                else:
                    layer_down = torch.nn.Conv2d(value_down.shape[1], value_down.shape[0], wrapped.kernel_size, wrapped.stride, wrapped.padding, bias=False)
                    layer_mid = None

                layer_up = torch.nn.Conv2d(value_up.shape[1], value_up.shape[0], (1, 1), bias=False)

            elif type(wrapped) == torch.nn.Linear:
                layer_down = torch.nn.Linear(value_down.shape[1], value_down.shape[0], bias=False)
                layer_mid = None
                layer_up = torch.nn.Linear(value_up.shape[1], value_up.shape[0], bias=False)

            else:
                print(
                    f">> Encountered unknown lora layer module in {self.name}: {stem} - {type(wrapped).__name__}"
                )
                return


            with torch.no_grad():
                layer_down.weight.copy_(value_down)
                if layer_mid is not None:
                    layer_mid.weight.copy_(value_mid)
                layer_up.weight.copy_(value_up)


            layer_down.to(device=self.device, dtype=self.dtype)
            if layer_mid is not None:
                layer_mid.to(device=self.device, dtype=self.dtype)
            layer_up.to(device=self.device, dtype=self.dtype)


            rank = value_down.shape[0]

            layer = LoRALayer(self.name, stem, rank, alpha)
            #layer.bias = bias # TODO: find and debug lora/locon with bias
            layer.down = layer_down
            layer.mid = layer_mid
            layer.up = layer_up

        # loha
        elif "hada_w1_b" in values:

            rank = values["hada_w1_b"].shape[0]

            layer = LoHALayer(self.name, stem, rank, alpha)
            layer.org_module = wrapped
            layer.bias = bias

            layer.w1_a = values["hada_w1_a"].to(device=self.device, dtype=self.dtype)
            layer.w1_b = values["hada_w1_b"].to(device=self.device, dtype=self.dtype)
            layer.w2_a = values["hada_w2_a"].to(device=self.device, dtype=self.dtype)
            layer.w2_b = values["hada_w2_b"].to(device=self.device, dtype=self.dtype)

            if "hada_t1" in values:
                layer.t1 = values["hada_t1"].to(device=self.device, dtype=self.dtype)
            else:
                layer.t1 = None

            if "hada_t2" in values:
                layer.t2 = values["hada_t2"].to(device=self.device, dtype=self.dtype)
            else:
                layer.t2 = None

        else:
            print(
                f">> Encountered unknown lora layer module in {self.name}: {stem} - {type(wrapped).__name__}"
            )
            return

        self.layers[stem] = layer

        if type(wrapped) == torch.nn.Linear:
            module = torch.nn.Linear(value.shape[1], value.shape[0], bias=False)
        elif type(wrapped) == torch.nn.Conv2d:
            module = torch.nn.Conv2d(value.shape[1], value.shape[0], (1, 1), bias=False)
        else:
            print(
                f">> Encountered unknown lora layer module in {self.name}: {type(value).__name__}"
            )
            return

        with torch.no_grad():
            module.weight.copy_(value)

        module.to(device=self.device, dtype=self.dtype)

        if leaf == "lora_up.weight":
            layer.up = module
        elif leaf == "lora_down.weight":
            layer.down = module
        else:
            print(f">> Encountered unknown layer in lora {self.name}: {leaf}")
            return


class KohyaLoraManager:
    def __init__(self, pipe, lora_path):
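Upstream of all of this, load_from_dict buckets the checkpoint by splitting each key into a module "stem" and a tensor "leaf", so that one module's lora_down/lora_up/mid/alpha (or hada_*) tensors arrive at load_lora_layer together. A minimal sketch of that grouping, with hypothetical checkpoint keys and placeholder values:

keys = [
    "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight",
    "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_up.weight",
    "lora_te_text_model_encoder_layers_0_mlp_fc1.alpha",
]

state_dict_groupped = {}
for key in keys:
    stem, leaf = key.split(".", 1)                          # stem = module name, leaf = tensor name
    state_dict_groupped.setdefault(stem, {})[leaf] = f"<tensor for {key}>"

# {'lora_te_text_model_encoder_layers_0_mlp_fc1':
#     {'lora_down.weight': ..., 'lora_up.weight': ..., 'alpha': ...}}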