switch all non-diffusers stuff to legacy, and load through compel prompts
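In effect, prompt-driven LoRA loading now splits into two routes: diffusers-format LoRAs are discovered by compel during prompt parsing (`withLora`), while checkpoint and safetensors LoRAs use a separate legacy tag (`withLoraLegacy`) that is handled by regex and stripped before conditioning. The prompts below are hypothetical illustrations of the two syntaxes that appear in this diff:

```python
# Hypothetical prompts illustrating the two routes in this commit.
# Diffusers-format LoRAs ride through compel's prompt parsing:
diffusers_prompt = "a castle at dusk withLora(myDiffusersLora,0.8)"
# Non-diffusers (ckpt/safetensors) LoRAs use the new legacy tag, which is
# stripped from the prompt before conditioning:
legacy_prompt = "a castle at dusk withLoraLegacy(myCkptLora,1)"
```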
@@ -488,10 +488,11 @@ class Generate:
         self.sampler_name = sampler_name
         self._set_sampler()
 
+        # To try and load lora not trained through diffusers
         if self.model.lora_manager:
-            prompt = self.model.lora_manager.configure_prompt(prompt)
+            prompt = self.model.lora_manager.configure_prompt_legacy(prompt)
             # lora MUST process prompt before conditioning
-            self.model.lora_manager.load_lora()
+            self.model.lora_manager.load_lora_legacy()
 
         # apply the concepts library to the prompt
         prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(
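The legacy route has to rewrite the prompt before conditioning: `configure_prompt_legacy` strips each `withLoraLegacy(...)` tag and queues the named LoRA, so the tag text never reaches the text encoder. Below is a minimal runnable sketch of that contract using the same regex as the diff; the stub class and names are hypothetical, not InvokeAI's API:

```python
import re

# Hypothetical stub mirroring LegacyLoraManager.configure_prompt's contract:
# strip withLoraLegacy(...) tags and queue each named lora before conditioning.
class StubLegacyLoraManager:
    def __init__(self):
        self.loras_to_load: dict[str, float] = {}

    def configure_prompt(self, prompt: str) -> str:
        # Same pattern as the diff; note the character class has no '.',
        # so fractional multipliers such as 0.8 would not be captured.
        lora_match = re.compile(r"withLoraLegacy\(([a-zA-Z\,\d]+)\)")
        for match in re.findall(lora_match, prompt):
            parts = match.split(',')
            name = parts[0].strip()
            mult = float(parts[1].strip()) if len(parts) == 2 else 1.0
            self.loras_to_load[name] = mult
        # remove the tag so it cannot leak into the embeddings
        return re.sub(lora_match, "", prompt)

stub = StubLegacyLoraManager()
print(stub.configure_prompt("a castle at dusk withLoraLegacy(myLora,1)"))
# -> "a castle at dusk "        (tag removed)
print(stub.loras_to_load)       # -> {'myLora': 1.0}
```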
@@ -59,6 +59,8 @@ def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_l
         positive_prompt = legacy_blend
     else:
         positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
+    if model.lora_manager:
+        model.lora_manager.load_lora_compel(positive_prompt.lora_weights)
     negative_prompt: FlattenedPrompt|Blend = Compel.parse_prompt_string(negative_prompt_string)
 
     if log_tokens or getattr(Globals, "log_tokenization", False):
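On the compel route, `parse_prompt_string` is expected to surface any `withLora(...)` references on the parsed prompt, and `load_lora_compel` simply walks `positive_prompt.lora_weights` and applies each `lora.model`. A minimal sketch of that hand-off follows; the `LoraWeight` stand-in and its field names are assumptions for illustration, not compel's actual types:

```python
from dataclasses import dataclass
from typing import Callable

# Stand-in for the parsed lora entries compel attaches to a prompt;
# the class and field names here are assumptions, not compel's real types.
@dataclass
class LoraWeight:
    model: str     # lora name as written in withLora(name, weight)
    weight: float

def load_lora_compel(apply_lora_model: Callable[[str], None],
                     lora_weights: list) -> None:
    # Mirrors the diff: apply every lora the prompt parser discovered.
    if len(lora_weights) > 0:
        for lora in lora_weights:
            apply_lora_model(lora.model)

load_lora_compel(lambda name: print(f">> Loading LoRA: {name}"),
                 [LoraWeight(model="myDiffusersLora", weight=0.8)])
```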
@@ -273,7 +273,7 @@ class LoRA:
         return
 
 
-class LegacyLora:
+class LegacyLoraManager:
     def __init__(self, pipe, lora_path):
         self.unet = pipe.unet
         self.lora_path = lora_path
@@ -281,6 +281,7 @@ class LegacyLora:
         self.text_encoder = pipe.text_encoder
         self.device = torch.device(choose_torch_device())
         self.dtype = pipe.unet.dtype
+        self.loras_to_load = {}
 
     def load_lora_module(self, name, path_file, multiplier: float = 1.0):
         # can be used instead to load through diffusers, once enough support is added
@@ -298,6 +299,26 @@ class LegacyLora:
 
         return lora
 
+    def configure_prompt(self, prompt: str) -> str:
+        self.clear_loras()
+
+        # lora_match = re.compile(r"<lora:([^>]+)>")
+        lora_match = re.compile(r"withLoraLegacy\(([a-zA-Z\,\d]+)\)")
+
+        for match in re.findall(lora_match, prompt):
+            # match = match.split(':')
+            match = match.split(',')
+            name = match[0].strip()
+
+            mult = 1.0
+            if len(match) == 2:
+                mult = float(match[1].strip())
+
+            self.set_lora(name, mult)
+
+        # remove lora and return prompt to avoid the lora prompt causing issues in inference
+        return re.sub(lora_match, "", prompt)
+
     def apply_lora_model(self, name, mult: float = 1.0):
         path_file = Path(self.lora_path, f'{name}.ckpt')
         if Path(self.lora_path, f'{name}.safetensors').exists():
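`apply_lora_model` on the legacy side defaults to `<name>.ckpt` but, per the truncated check above, appears to prefer a `<name>.safetensors` file when one exists. A minimal sketch of that resolution order, assuming that reading of the truncated hunk (the directory layout is hypothetical):

```python
from pathlib import Path

def resolve_lora_file(lora_dir: Path, name: str) -> Path:
    # Default to the .ckpt file, but prefer .safetensors when present --
    # the order implied by the (truncated) check in the hunk above.
    path_file = lora_dir / f"{name}.ckpt"
    safetensors_file = lora_dir / f"{name}.safetensors"
    if safetensors_file.exists():
        path_file = safetensors_file
    return path_file

print(resolve_lora_file(Path("models/lora"), "myLora"))
```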
@@ -314,6 +335,10 @@ class LegacyLora:
             lora.multiplier = mult
             self.wrapper.applied_loras[name] = lora
 
+    def load_lora(self):
+        for name, multiplier in self.loras_to_load.items():
+            self.apply_lora_model(name, multiplier)
+
     def unload_applied_loras(self, loras_to_load):
         # unload any lora's not defined by loras_to_load
         for name in list(self.wrapper.applied_loras.keys()):
@@ -329,71 +354,45 @@ class LegacyLora:
             del self.wrapper.loaded_loras[lora_name]
 
     def set_lora(self, name, multiplier: float = 1.0):
         self.loras_to_load[name] = multiplier
 
         # update the multiplier if the lora was already loaded
         if name in self.wrapper.loaded_loras:
             self.wrapper.loaded_loras[name].multiplier = multiplier
 
     def clear_loras(self):
         self.loras_to_load = {}
         self.wrapper.clear_applied_loras()
 
     def __del__(self):
         del self.loras_to_load
 
 
 class LoraManager:
     loras_to_load: dict[str, float]
 
     def __init__(self, pipe):
         self.lora_path = Path(global_models_dir(), 'lora')
         self.unet = pipe.unet
         self.loras_to_load = {}
         # Legacy class handles lora not generated through diffusers
-        self.legacy = LegacyLora(pipe, self.lora_path)
+        self.legacy = LegacyLoraManager(pipe, self.lora_path)
 
-    def apply_lora_model(self, name, mult: float = 1.0):
+    def apply_lora_model(self, name):
         path = Path(self.lora_path, name)
         file = Path(path, "pytorch_lora_weights.bin")
 
         if path.is_dir() and file.is_file():
-            print(f"loading lora: {path}")
+            print(f">> Loading LoRA: {path}")
             self.unet.load_attn_procs(path.absolute().as_posix())
         else:
-            self.legacy.apply_lora_model(name, mult)
+            print(f">> Unable to find valid LoRA at: {path}")
 
-    def load_lora(self):
-        for name, multiplier in self.loras_to_load.items():
-            self.apply_lora_model(name, multiplier)
+    def load_lora_compel(self, lora_weights: list):
+        if len(lora_weights) > 0:
+            for lora in lora_weights:
+                self.apply_lora_model(lora.model)
 
-        self.legacy.unload_applied_loras(self.loras_to_load)
+    # Legacy functions, to pipe to LoraLegacyManager
+    def configure_prompt_legacy(self, prompt: str) -> str:
+        return self.legacy.configure_prompt(prompt)
 
-    # Define a lora to be loaded
+    # Can be used to define a lora to be loaded outside of prompts
     def set_lora(self, name, multiplier: float = 1.0):
         self.loras_to_load[name] = multiplier
         self.legacy.set_lora(name, multiplier)
 
-    # Load the lora from a prompt, syntax is <lora:lora_name:multiplier>
-    # Multiplier should be a value between 0.0 and 1.0
-    def configure_prompt(self, prompt: str) -> str:
-        self.clear_loras()
-
-        # lora_match = re.compile(r"<lora:([^>]+)>")
-        lora_match = re.compile(r"withLora\(([a-zA-Z\,\d]+)\)")
-
-        for match in re.findall(lora_match, prompt):
-            # match = match.split(':')
-            match = match.split(',')
-            name = match[0].strip()
-
-            mult = 1.0
-            if len(match) == 2:
-                mult = float(match[1].strip())
-
-            self.set_lora(name, mult)
-
-        # remove lora and return prompt to avoid the lora prompt causing issues in inference
-        return re.sub(lora_match, "", prompt)
-
     def clear_loras(self):
         self.loras_to_load = {}
         self.legacy.clear_loras()
 
     def __del__(self):
         del self.loras_to_load
+
+    def load_lora_legacy(self):
+        self.legacy.load_lora()
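Taken together, the reworked LoraManager also allows queuing non-diffusers LoRAs without any prompt syntax: `set_lora` records the name and multiplier, and `load_lora_legacy` routes the queue to `LegacyLoraManager.load_lora`. A runnable sketch of that calling convention (StubManager stands in for LoraManager; no real weights are loaded):

```python
# Hypothetical stub of the public calling convention added above;
# StubManager only records and reports, it loads no real weights.
class StubManager:
    def __init__(self):
        self.loras_to_load: dict[str, float] = {}

    def set_lora(self, name: str, multiplier: float = 1.0):
        # queue a lora outside of prompt syntax
        self.loras_to_load[name] = multiplier

    def load_lora_legacy(self):
        # in InvokeAI this forwards to LegacyLoraManager.load_lora()
        for name, mult in self.loras_to_load.items():
            print(f">> Loading LoRA: {name} (multiplier {mult})")

manager = StubManager()
manager.set_lora("myCkptLora", 0.8)
manager.load_lora_legacy()
```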