mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
implement caching scheme for vector length
This commit is contained in:
@@ -1,24 +1,28 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
from filelock import FileLock, Timeout
|
||||
from safetensors.torch import load_file
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
from transformers import CLIPTextModel
|
||||
|
||||
from ..invoke.globals import global_lora_models_dir
|
||||
from ..invoke.devices import choose_torch_device
|
||||
|
||||
|
||||
"""
|
||||
This module supports loading LoRA weights trained with https://github.com/kohya-ss/sd-scripts
|
||||
To be removed once support for diffusers LoRA weights is well supported
|
||||
"""
|
||||
|
||||
|
||||
class IncompatibleModelException(Exception):
|
||||
"Raised when there is an attempt to load a LoRA into a model that is incompatible with it"
|
||||
pass
|
||||
|
||||
|
||||
class LoRALayer:
|
||||
lora_name: str
|
||||
name: str
|
||||
@@ -36,8 +40,7 @@ class LoRALayer:
|
||||
def forward(self, lora, input_h, output):
|
||||
if self.mid is None:
|
||||
output = (
|
||||
output
|
||||
+ self.up(self.down(*input_h)) * lora.multiplier * self.scale
|
||||
output + self.up(self.down(*input_h)) * lora.multiplier * self.scale
|
||||
)
|
||||
else:
|
||||
output = (
|
||||
@@ -46,6 +49,7 @@ class LoRALayer:
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
class LoHALayer:
|
||||
lora_name: str
|
||||
name: str
|
||||
@@ -67,7 +71,6 @@ class LoHALayer:
|
||||
self.scale = alpha / rank if (alpha and rank) else 1.0
|
||||
|
||||
def forward(self, lora, input_h, output):
|
||||
|
||||
if type(self.org_module) == torch.nn.Conv2d:
|
||||
op = torch.nn.functional.conv2d
|
||||
extra_args = dict(
|
||||
@@ -82,20 +85,29 @@ class LoHALayer:
|
||||
extra_args = {}
|
||||
|
||||
if self.t1 is None:
|
||||
weight = ((self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b))
|
||||
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
||||
|
||||
else:
|
||||
rebuild1 = torch.einsum('i j k l, j r, i p -> p r k l', self.t1, self.w1_b, self.w1_a)
|
||||
rebuild2 = torch.einsum('i j k l, j r, i p -> p r k l', self.t2, self.w2_b, self.w2_a)
|
||||
rebuild1 = torch.einsum(
|
||||
"i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a
|
||||
)
|
||||
rebuild2 = torch.einsum(
|
||||
"i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a
|
||||
)
|
||||
weight = rebuild1 * rebuild2
|
||||
|
||||
|
||||
bias = self.bias if self.bias is not None else 0
|
||||
return output + op(
|
||||
*input_h,
|
||||
(weight + bias).view(self.org_module.weight.shape),
|
||||
None,
|
||||
**extra_args,
|
||||
) * lora.multiplier * self.scale
|
||||
return (
|
||||
output
|
||||
+ op(
|
||||
*input_h,
|
||||
(weight + bias).view(self.org_module.weight.shape),
|
||||
None,
|
||||
**extra_args,
|
||||
)
|
||||
* lora.multiplier
|
||||
* self.scale
|
||||
)
|
||||
|
||||
|
||||
class LoRAModuleWrapper:
|
||||
@@ -113,12 +125,22 @@ class LoRAModuleWrapper:
|
||||
self.applied_loras = {}
|
||||
self.loaded_loras = {}
|
||||
|
||||
self.UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention", "ResnetBlock2D", "Downsample2D", "Upsample2D", "SpatialTransformer"]
|
||||
self.TEXT_ENCODER_TARGET_REPLACE_MODULE = ["ResidualAttentionBlock", "CLIPAttention", "CLIPMLP"]
|
||||
self.UNET_TARGET_REPLACE_MODULE = [
|
||||
"Transformer2DModel",
|
||||
"Attention",
|
||||
"ResnetBlock2D",
|
||||
"Downsample2D",
|
||||
"Upsample2D",
|
||||
"SpatialTransformer",
|
||||
]
|
||||
self.TEXT_ENCODER_TARGET_REPLACE_MODULE = [
|
||||
"ResidualAttentionBlock",
|
||||
"CLIPAttention",
|
||||
"CLIPMLP",
|
||||
]
|
||||
self.LORA_PREFIX_UNET = "lora_unet"
|
||||
self.LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
||||
|
||||
|
||||
def find_modules(
|
||||
prefix, root_module: torch.nn.Module, target_replace_modules
|
||||
) -> dict[str, torch.nn.Module]:
|
||||
@@ -149,7 +171,6 @@ class LoRAModuleWrapper:
|
||||
self.LORA_PREFIX_UNET, unet, self.UNET_TARGET_REPLACE_MODULE
|
||||
)
|
||||
|
||||
|
||||
def lora_forward_hook(self, name):
|
||||
wrapper = self
|
||||
|
||||
@@ -182,6 +203,7 @@ class LoRAModuleWrapper:
|
||||
def clear_loaded_loras(self):
|
||||
self.loaded_loras.clear()
|
||||
|
||||
|
||||
class LoRA:
|
||||
name: str
|
||||
layers: dict[str, LoRALayer]
|
||||
@@ -207,7 +229,6 @@ class LoRA:
|
||||
state_dict_groupped[stem] = dict()
|
||||
state_dict_groupped[stem][leaf] = value
|
||||
|
||||
|
||||
for stem, values in state_dict_groupped.items():
|
||||
if stem.startswith(self.wrapper.LORA_PREFIX_TEXT_ENCODER):
|
||||
wrapped = self.wrapper.text_modules.get(stem, None)
|
||||
@@ -228,34 +249,59 @@ class LoRA:
|
||||
if "alpha" in values:
|
||||
alpha = values["alpha"].item()
|
||||
|
||||
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
|
||||
if (
|
||||
"bias_indices" in values
|
||||
and "bias_values" in values
|
||||
and "bias_size" in values
|
||||
):
|
||||
bias = torch.sparse_coo_tensor(
|
||||
values["bias_indices"],
|
||||
values["bias_values"],
|
||||
tuple(values["bias_size"]),
|
||||
).to(device=self.device, dtype=self.dtype)
|
||||
|
||||
|
||||
# lora and locon
|
||||
if "lora_down.weight" in values:
|
||||
value_down = values["lora_down.weight"]
|
||||
value_mid = values.get("lora_mid.weight", None)
|
||||
value_up = values["lora_up.weight"]
|
||||
value_mid = values.get("lora_mid.weight", None)
|
||||
value_up = values["lora_up.weight"]
|
||||
|
||||
if type(wrapped) == torch.nn.Conv2d:
|
||||
if value_mid is not None:
|
||||
layer_down = torch.nn.Conv2d(value_down.shape[1], value_down.shape[0], (1, 1), bias=False)
|
||||
layer_mid = torch.nn.Conv2d(value_mid.shape[1], value_mid.shape[0], wrapped.kernel_size, wrapped.stride, wrapped.padding, bias=False)
|
||||
layer_down = torch.nn.Conv2d(
|
||||
value_down.shape[1], value_down.shape[0], (1, 1), bias=False
|
||||
)
|
||||
layer_mid = torch.nn.Conv2d(
|
||||
value_mid.shape[1],
|
||||
value_mid.shape[0],
|
||||
wrapped.kernel_size,
|
||||
wrapped.stride,
|
||||
wrapped.padding,
|
||||
bias=False,
|
||||
)
|
||||
else:
|
||||
layer_down = torch.nn.Conv2d(value_down.shape[1], value_down.shape[0], wrapped.kernel_size, wrapped.stride, wrapped.padding, bias=False)
|
||||
layer_mid = None
|
||||
layer_down = torch.nn.Conv2d(
|
||||
value_down.shape[1],
|
||||
value_down.shape[0],
|
||||
wrapped.kernel_size,
|
||||
wrapped.stride,
|
||||
wrapped.padding,
|
||||
bias=False,
|
||||
)
|
||||
layer_mid = None
|
||||
|
||||
layer_up = torch.nn.Conv2d(value_up.shape[1], value_up.shape[0], (1, 1), bias=False)
|
||||
layer_up = torch.nn.Conv2d(
|
||||
value_up.shape[1], value_up.shape[0], (1, 1), bias=False
|
||||
)
|
||||
|
||||
elif type(wrapped) == torch.nn.Linear:
|
||||
layer_down = torch.nn.Linear(value_down.shape[1], value_down.shape[0], bias=False)
|
||||
layer_mid = None
|
||||
layer_up = torch.nn.Linear(value_up.shape[1], value_up.shape[0], bias=False)
|
||||
layer_down = torch.nn.Linear(
|
||||
value_down.shape[1], value_down.shape[0], bias=False
|
||||
)
|
||||
layer_mid = None
|
||||
layer_up = torch.nn.Linear(
|
||||
value_up.shape[1], value_up.shape[0], bias=False
|
||||
)
|
||||
|
||||
else:
|
||||
print(
|
||||
@@ -263,49 +309,57 @@ class LoRA:
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
layer_down.weight.copy_(value_down)
|
||||
if layer_mid is not None:
|
||||
layer_mid.weight.copy_(value_mid)
|
||||
layer_up.weight.copy_(value_up)
|
||||
|
||||
|
||||
layer_down.to(device=self.device, dtype=self.dtype)
|
||||
if layer_mid is not None:
|
||||
layer_mid.to(device=self.device, dtype=self.dtype)
|
||||
layer_up.to(device=self.device, dtype=self.dtype)
|
||||
|
||||
|
||||
rank = value_down.shape[0]
|
||||
|
||||
layer = LoRALayer(self.name, stem, rank, alpha)
|
||||
#layer.bias = bias # TODO: find and debug lora/locon with bias
|
||||
# layer.bias = bias # TODO: find and debug lora/locon with bias
|
||||
layer.down = layer_down
|
||||
layer.mid = layer_mid
|
||||
layer.up = layer_up
|
||||
|
||||
# loha
|
||||
elif "hada_w1_b" in values:
|
||||
|
||||
rank = values["hada_w1_b"].shape[0]
|
||||
|
||||
layer = LoHALayer(self.name, stem, rank, alpha)
|
||||
layer.org_module = wrapped
|
||||
layer.bias = bias
|
||||
|
||||
layer.w1_a = values["hada_w1_a"].to(device=self.device, dtype=self.dtype)
|
||||
layer.w1_b = values["hada_w1_b"].to(device=self.device, dtype=self.dtype)
|
||||
layer.w2_a = values["hada_w2_a"].to(device=self.device, dtype=self.dtype)
|
||||
layer.w2_b = values["hada_w2_b"].to(device=self.device, dtype=self.dtype)
|
||||
layer.w1_a = values["hada_w1_a"].to(
|
||||
device=self.device, dtype=self.dtype
|
||||
)
|
||||
layer.w1_b = values["hada_w1_b"].to(
|
||||
device=self.device, dtype=self.dtype
|
||||
)
|
||||
layer.w2_a = values["hada_w2_a"].to(
|
||||
device=self.device, dtype=self.dtype
|
||||
)
|
||||
layer.w2_b = values["hada_w2_b"].to(
|
||||
device=self.device, dtype=self.dtype
|
||||
)
|
||||
|
||||
if "hada_t1" in values:
|
||||
layer.t1 = values["hada_t1"].to(device=self.device, dtype=self.dtype)
|
||||
layer.t1 = values["hada_t1"].to(
|
||||
device=self.device, dtype=self.dtype
|
||||
)
|
||||
else:
|
||||
layer.t1 = None
|
||||
|
||||
if "hada_t2" in values:
|
||||
layer.t2 = values["hada_t2"].to(device=self.device, dtype=self.dtype)
|
||||
layer.t2 = values["hada_t2"].to(
|
||||
device=self.device, dtype=self.dtype
|
||||
)
|
||||
else:
|
||||
layer.t2 = None
|
||||
|
||||
@@ -319,9 +373,11 @@ class LoRA:
|
||||
|
||||
|
||||
class KohyaLoraManager:
|
||||
def __init__(self, pipe, lora_path):
|
||||
lora_path = Path(global_lora_models_dir())
|
||||
vector_length_cache_path = lora_path / '.vectorlength.cache'
|
||||
|
||||
def __init__(self, pipe):
|
||||
self.unet = pipe.unet
|
||||
self.lora_path = lora_path
|
||||
self.wrapper = LoRAModuleWrapper(pipe.unet, pipe.text_encoder)
|
||||
self.text_encoder = pipe.text_encoder
|
||||
self.device = torch.device(choose_torch_device())
|
||||
@@ -333,10 +389,10 @@ class KohyaLoraManager:
|
||||
checkpoint = load_file(path_file.absolute().as_posix(), device="cpu")
|
||||
else:
|
||||
checkpoint = torch.load(path_file, map_location="cpu")
|
||||
|
||||
|
||||
if not self.check_model_compatibility(checkpoint):
|
||||
raise IncompatibleModelException
|
||||
|
||||
|
||||
lora = LoRA(name, self.device, self.dtype, self.wrapper, multiplier)
|
||||
lora.load_from_dict(checkpoint)
|
||||
self.wrapper.loaded_loras[name] = lora
|
||||
@@ -362,17 +418,17 @@ class KohyaLoraManager:
|
||||
lora.multiplier = mult
|
||||
self.wrapper.applied_loras[name] = lora
|
||||
|
||||
def unload_applied_lora(self, lora_name: str)->bool:
|
||||
'''If the indicated LoRA has previously been applied then
|
||||
def unload_applied_lora(self, lora_name: str) -> bool:
|
||||
"""If the indicated LoRA has previously been applied then
|
||||
unload it and return True. Return False if the LoRA was
|
||||
not previously applied (for status reporting)
|
||||
'''
|
||||
"""
|
||||
if lora_name in self.wrapper.applied_loras:
|
||||
del self.wrapper.applied_loras[lora_name]
|
||||
return True
|
||||
return False
|
||||
|
||||
def unload_lora(self, lora_name: str)->bool:
|
||||
def unload_lora(self, lora_name: str) -> bool:
|
||||
if lora_name in self.wrapper.loaded_loras:
|
||||
del self.wrapper.loaded_loras[lora_name]
|
||||
return True
|
||||
@@ -381,34 +437,70 @@ class KohyaLoraManager:
|
||||
def clear_loras(self):
|
||||
self.wrapper.clear_applied_loras()
|
||||
|
||||
def check_model_compatibility(self, checkpoint)->bool:
|
||||
'''Checks whether the LoRA checkpoint is compatible with the token vector
|
||||
def check_model_compatibility(self, checkpoint) -> bool:
|
||||
"""Checks whether the LoRA checkpoint is compatible with the token vector
|
||||
length of the model that this manager is associated with.
|
||||
'''
|
||||
model_token_vector_length = self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
|
||||
"""
|
||||
model_token_vector_length = (
|
||||
self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
|
||||
)
|
||||
lora_token_vector_length = self.vector_length_from_checkpoint(checkpoint)
|
||||
return model_token_vector_length == lora_token_vector_length
|
||||
|
||||
|
||||
@staticmethod
|
||||
def vector_length_from_checkpoint(checkpoint:dict)->int:
|
||||
'''Return the vector token length for the passed LoRA checkpoint object.
|
||||
def vector_length_from_checkpoint(checkpoint: dict) -> int:
|
||||
"""Return the vector token length for the passed LoRA checkpoint object.
|
||||
This is used to determine which SD model version the LoRA was based on.
|
||||
768 -> SDv1
|
||||
1024-> SDv2
|
||||
'''
|
||||
key1 = 'lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight'
|
||||
key2 = 'lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a'
|
||||
lora_token_vector_length = checkpoint[key1].shape[1] if key1 in checkpoint \
|
||||
else checkpoint[key2].shape[0] if key2 in checkpoint \
|
||||
else 768
|
||||
"""
|
||||
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
|
||||
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
|
||||
lora_token_vector_length = (
|
||||
checkpoint[key1].shape[1]
|
||||
if key1 in checkpoint
|
||||
else checkpoint[key2].shape[0]
|
||||
if key2 in checkpoint
|
||||
else 768
|
||||
)
|
||||
return lora_token_vector_length
|
||||
|
||||
@staticmethod
|
||||
def vector_length_from_checkpoint_file(checkpoint_path: Path)->int:
|
||||
if checkpoint_path.suffix == ".safetensors":
|
||||
checkpoint = load_file(checkpoint_path.absolute().as_posix(), device="cpu")
|
||||
else:
|
||||
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
||||
return KohyaLoraManager.vector_length_from_checkpoint(checkpoint)
|
||||
@classmethod
|
||||
def vector_length_from_checkpoint_file(self, checkpoint_path: Path) -> int:
|
||||
with LoraVectorLengthCache(self.vector_length_cache_path) as cache:
|
||||
if str(checkpoint_path) not in cache:
|
||||
if checkpoint_path.suffix == ".safetensors":
|
||||
checkpoint = load_file(
|
||||
checkpoint_path.absolute().as_posix(), device="cpu"
|
||||
)
|
||||
else:
|
||||
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
||||
cache[str(checkpoint_path)] = KohyaLoraManager.vector_length_from_checkpoint(
|
||||
checkpoint
|
||||
)
|
||||
return cache[str(checkpoint_path)]
|
||||
|
||||
|
||||
class LoraVectorLengthCache(object):
|
||||
def __init__(self, cache_path: Path):
|
||||
self.cache_path = cache_path
|
||||
self.lock = FileLock(Path(cache_path.parent, ".cachelock"))
|
||||
self.cache = {}
|
||||
|
||||
def __enter__(self):
|
||||
self.lock.acquire(timeout=10)
|
||||
try:
|
||||
if self.cache_path.exists():
|
||||
with open(self.cache_path, "r") as json_file:
|
||||
self.cache = json.load(json_file)
|
||||
except Timeout:
|
||||
print(
|
||||
"** Can't acquire lock on lora vector length cache. Operations will be slower"
|
||||
)
|
||||
except (json.JSONDecodeError, OSError):
|
||||
self.cache_path.unlink()
|
||||
return self.cache
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
with open(self.cache_path, "w") as json_file:
|
||||
json.dump(self.cache, json_file)
|
||||
self.lock.release()
|
||||
|
||||
@@ -43,7 +43,7 @@ class LoraCondition:
|
||||
class LoraManager:
|
||||
def __init__(self, pipe):
|
||||
# Kohya class handles lora not generated through diffusers
|
||||
self.kohya = KohyaLoraManager(pipe, global_lora_models_dir())
|
||||
self.kohya = KohyaLoraManager(pipe)
|
||||
|
||||
def set_loras_conditions(self, lora_weights: list):
|
||||
conditions = []
|
||||
|
||||
Reference in New Issue
Block a user