mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
switch all none diffusers stuff to legacy, and load through compel prompts
This commit is contained in:
@@ -488,10 +488,11 @@ class Generate:
|
|||||||
self.sampler_name = sampler_name
|
self.sampler_name = sampler_name
|
||||||
self._set_sampler()
|
self._set_sampler()
|
||||||
|
|
||||||
|
# To try and load lora not trained through diffusers
|
||||||
if self.model.lora_manager:
|
if self.model.lora_manager:
|
||||||
prompt = self.model.lora_manager.configure_prompt(prompt)
|
prompt = self.model.lora_manager.configure_prompt_legacy(prompt)
|
||||||
# lora MUST process prompt before conditioning
|
# lora MUST process prompt before conditioning
|
||||||
self.model.lora_manager.load_lora()
|
self.model.lora_manager.load_lora_legacy()
|
||||||
|
|
||||||
# apply the concepts library to the prompt
|
# apply the concepts library to the prompt
|
||||||
prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(
|
prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(
|
||||||
|
|||||||
@@ -59,6 +59,8 @@ def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_l
|
|||||||
positive_prompt = legacy_blend
|
positive_prompt = legacy_blend
|
||||||
else:
|
else:
|
||||||
positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
|
positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
|
||||||
|
if model.lora_manager:
|
||||||
|
model.lora_manager.load_lora_compel(positive_prompt.lora_weights)
|
||||||
negative_prompt: FlattenedPrompt|Blend = Compel.parse_prompt_string(negative_prompt_string)
|
negative_prompt: FlattenedPrompt|Blend = Compel.parse_prompt_string(negative_prompt_string)
|
||||||
|
|
||||||
if log_tokens or getattr(Globals, "log_tokenization", False):
|
if log_tokens or getattr(Globals, "log_tokenization", False):
|
||||||
|
|||||||
@@ -273,7 +273,7 @@ class LoRA:
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
class LegacyLora:
|
class LegacyLoraManager:
|
||||||
def __init__(self, pipe, lora_path):
|
def __init__(self, pipe, lora_path):
|
||||||
self.unet = pipe.unet
|
self.unet = pipe.unet
|
||||||
self.lora_path = lora_path
|
self.lora_path = lora_path
|
||||||
@@ -281,6 +281,7 @@ class LegacyLora:
|
|||||||
self.text_encoder = pipe.text_encoder
|
self.text_encoder = pipe.text_encoder
|
||||||
self.device = torch.device(choose_torch_device())
|
self.device = torch.device(choose_torch_device())
|
||||||
self.dtype = pipe.unet.dtype
|
self.dtype = pipe.unet.dtype
|
||||||
|
self.loras_to_load = {}
|
||||||
|
|
||||||
def load_lora_module(self, name, path_file, multiplier: float = 1.0):
|
def load_lora_module(self, name, path_file, multiplier: float = 1.0):
|
||||||
# can be used instead to load through diffusers, once enough support is added
|
# can be used instead to load through diffusers, once enough support is added
|
||||||
@@ -298,6 +299,26 @@ class LegacyLora:
|
|||||||
|
|
||||||
return lora
|
return lora
|
||||||
|
|
||||||
|
def configure_prompt(self, prompt: str) -> str:
|
||||||
|
self.clear_loras()
|
||||||
|
|
||||||
|
# lora_match = re.compile(r"<lora:([^>]+)>")
|
||||||
|
lora_match = re.compile(r"withLoraLegacy\(([a-zA-Z\,\d]+)\)")
|
||||||
|
|
||||||
|
for match in re.findall(lora_match, prompt):
|
||||||
|
# match = match.split(':')
|
||||||
|
match = match.split(',')
|
||||||
|
name = match[0].strip()
|
||||||
|
|
||||||
|
mult = 1.0
|
||||||
|
if len(match) == 2:
|
||||||
|
mult = float(match[1].strip())
|
||||||
|
|
||||||
|
self.set_lora(name, mult)
|
||||||
|
|
||||||
|
# remove lora and return prompt to avoid the lora prompt causing issues in inference
|
||||||
|
return re.sub(lora_match, "", prompt)
|
||||||
|
|
||||||
def apply_lora_model(self, name, mult: float = 1.0):
|
def apply_lora_model(self, name, mult: float = 1.0):
|
||||||
path_file = Path(self.lora_path, f'{name}.ckpt')
|
path_file = Path(self.lora_path, f'{name}.ckpt')
|
||||||
if Path(self.lora_path, f'{name}.safetensors').exists():
|
if Path(self.lora_path, f'{name}.safetensors').exists():
|
||||||
@@ -314,6 +335,10 @@ class LegacyLora:
|
|||||||
lora.multiplier = mult
|
lora.multiplier = mult
|
||||||
self.wrapper.applied_loras[name] = lora
|
self.wrapper.applied_loras[name] = lora
|
||||||
|
|
||||||
|
def load_lora(self):
|
||||||
|
for name, multiplier in self.loras_to_load.items():
|
||||||
|
self.apply_lora_model(name, multiplier)
|
||||||
|
|
||||||
def unload_applied_loras(self, loras_to_load):
|
def unload_applied_loras(self, loras_to_load):
|
||||||
# unload any lora's not defined by loras_to_load
|
# unload any lora's not defined by loras_to_load
|
||||||
for name in list(self.wrapper.applied_loras.keys()):
|
for name in list(self.wrapper.applied_loras.keys()):
|
||||||
@@ -329,71 +354,45 @@ class LegacyLora:
|
|||||||
del self.wrapper.loaded_loras[lora_name]
|
del self.wrapper.loaded_loras[lora_name]
|
||||||
|
|
||||||
def set_lora(self, name, multiplier: float = 1.0):
|
def set_lora(self, name, multiplier: float = 1.0):
|
||||||
|
self.loras_to_load[name] = multiplier
|
||||||
|
|
||||||
# update the multiplier if the lora was already loaded
|
# update the multiplier if the lora was already loaded
|
||||||
if name in self.wrapper.loaded_loras:
|
if name in self.wrapper.loaded_loras:
|
||||||
self.wrapper.loaded_loras[name].multiplier = multiplier
|
self.wrapper.loaded_loras[name].multiplier = multiplier
|
||||||
|
|
||||||
def clear_loras(self):
|
def clear_loras(self):
|
||||||
|
self.loras_to_load = {}
|
||||||
self.wrapper.clear_applied_loras()
|
self.wrapper.clear_applied_loras()
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
del self.loras_to_load
|
||||||
|
|
||||||
|
|
||||||
class LoraManager:
|
class LoraManager:
|
||||||
loras_to_load: dict[str, float]
|
|
||||||
|
|
||||||
def __init__(self, pipe):
|
def __init__(self, pipe):
|
||||||
self.lora_path = Path(global_models_dir(), 'lora')
|
self.lora_path = Path(global_models_dir(), 'lora')
|
||||||
self.unet = pipe.unet
|
self.unet = pipe.unet
|
||||||
self.loras_to_load = {}
|
|
||||||
# Legacy class handles lora not generated through diffusers
|
# Legacy class handles lora not generated through diffusers
|
||||||
self.legacy = LegacyLora(pipe, self.lora_path)
|
self.legacy = LegacyLoraManager(pipe, self.lora_path)
|
||||||
|
|
||||||
def apply_lora_model(self, name, mult: float = 1.0):
|
def apply_lora_model(self, name):
|
||||||
path = Path(self.lora_path, name)
|
path = Path(self.lora_path, name)
|
||||||
file = Path(path, "pytorch_lora_weights.bin")
|
file = Path(path, "pytorch_lora_weights.bin")
|
||||||
|
|
||||||
if path.is_dir() and file.is_file():
|
if path.is_dir() and file.is_file():
|
||||||
print(f"loading lora: {path}")
|
print(f">> Loading LoRA: {path}")
|
||||||
self.unet.load_attn_procs(path.absolute().as_posix())
|
self.unet.load_attn_procs(path.absolute().as_posix())
|
||||||
else:
|
else:
|
||||||
self.legacy.apply_lora_model(name, mult)
|
print(f">> Unable to find valid LoRA at: {path}")
|
||||||
|
|
||||||
def load_lora(self):
|
def load_lora_compel(self, lora_weights: list):
|
||||||
for name, multiplier in self.loras_to_load.items():
|
if len(lora_weights) > 0:
|
||||||
self.apply_lora_model(name, multiplier)
|
for lora in lora_weights:
|
||||||
|
self.apply_lora_model(lora.model)
|
||||||
|
|
||||||
self.legacy.unload_applied_loras(self.loras_to_load)
|
# Legacy functions, to pipe to LoraLegacyManager
|
||||||
|
def configure_prompt_legacy(self, prompt: str) -> str:
|
||||||
|
return self.legacy.configure_prompt(prompt)
|
||||||
|
|
||||||
# Define a lora to be loaded
|
def load_lora_legacy(self):
|
||||||
# Can be used to define a lora to be loaded outside of prompts
|
self.legacy.load_lora()
|
||||||
def set_lora(self, name, multiplier: float = 1.0):
|
|
||||||
self.loras_to_load[name] = multiplier
|
|
||||||
self.legacy.set_lora(name, multiplier)
|
|
||||||
|
|
||||||
# Load the lora from a prompt, syntax is <lora:lora_name:multiplier>
|
|
||||||
# Multiplier should be a value between 0.0 and 1.0
|
|
||||||
def configure_prompt(self, prompt: str) -> str:
|
|
||||||
self.clear_loras()
|
|
||||||
|
|
||||||
# lora_match = re.compile(r"<lora:([^>]+)>")
|
|
||||||
lora_match = re.compile(r"withLora\(([a-zA-Z\,\d]+)\)")
|
|
||||||
|
|
||||||
for match in re.findall(lora_match, prompt):
|
|
||||||
# match = match.split(':')
|
|
||||||
match = match.split(',')
|
|
||||||
name = match[0].strip()
|
|
||||||
|
|
||||||
mult = 1.0
|
|
||||||
if len(match) == 2:
|
|
||||||
mult = float(match[1].strip())
|
|
||||||
|
|
||||||
self.set_lora(name, mult)
|
|
||||||
|
|
||||||
# remove lora and return prompt to avoid the lora prompt causing issues in inference
|
|
||||||
return re.sub(lora_match, "", prompt)
|
|
||||||
|
|
||||||
def clear_loras(self):
|
|
||||||
self.loras_to_load = {}
|
|
||||||
self.legacy.clear_loras()
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
del self.loras_to_load
|
|
||||||
|
|||||||
Reference in New Issue
Block a user