switch all none diffusers stuff to legacy, and load through compel prompts

This commit is contained in:
Jordan
2023-02-23 16:48:33 -07:00
parent 8e1fd92e7f
commit 6a1129ab64
3 changed files with 50 additions and 48 deletions

View File

@@ -488,10 +488,11 @@ class Generate:
self.sampler_name = sampler_name
self._set_sampler()
# To try and load lora not trained through diffusers
if self.model.lora_manager:
prompt = self.model.lora_manager.configure_prompt(prompt)
prompt = self.model.lora_manager.configure_prompt_legacy(prompt)
# lora MUST process prompt before conditioning
self.model.lora_manager.load_lora()
self.model.lora_manager.load_lora_legacy()
# apply the concepts library to the prompt
prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(

View File

@@ -59,6 +59,8 @@ def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_l
positive_prompt = legacy_blend
else:
positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
if model.lora_manager:
model.lora_manager.load_lora_compel(positive_prompt.lora_weights)
negative_prompt: FlattenedPrompt|Blend = Compel.parse_prompt_string(negative_prompt_string)
if log_tokens or getattr(Globals, "log_tokenization", False):

View File

@@ -273,7 +273,7 @@ class LoRA:
return
class LegacyLora:
class LegacyLoraManager:
def __init__(self, pipe, lora_path):
self.unet = pipe.unet
self.lora_path = lora_path
@@ -281,6 +281,7 @@ class LegacyLora:
self.text_encoder = pipe.text_encoder
self.device = torch.device(choose_torch_device())
self.dtype = pipe.unet.dtype
self.loras_to_load = {}
def load_lora_module(self, name, path_file, multiplier: float = 1.0):
# can be used instead to load through diffusers, once enough support is added
@@ -298,6 +299,26 @@ class LegacyLora:
return lora
def configure_prompt(self, prompt: str) -> str:
self.clear_loras()
# lora_match = re.compile(r"<lora:([^>]+)>")
lora_match = re.compile(r"withLoraLegacy\(([a-zA-Z\,\d]+)\)")
for match in re.findall(lora_match, prompt):
# match = match.split(':')
match = match.split(',')
name = match[0].strip()
mult = 1.0
if len(match) == 2:
mult = float(match[1].strip())
self.set_lora(name, mult)
# remove lora and return prompt to avoid the lora prompt causing issues in inference
return re.sub(lora_match, "", prompt)
def apply_lora_model(self, name, mult: float = 1.0):
path_file = Path(self.lora_path, f'{name}.ckpt')
if Path(self.lora_path, f'{name}.safetensors').exists():
@@ -314,6 +335,10 @@ class LegacyLora:
lora.multiplier = mult
self.wrapper.applied_loras[name] = lora
def load_lora(self):
for name, multiplier in self.loras_to_load.items():
self.apply_lora_model(name, multiplier)
def unload_applied_loras(self, loras_to_load):
# unload any lora's not defined by loras_to_load
for name in list(self.wrapper.applied_loras.keys()):
@@ -329,71 +354,45 @@ class LegacyLora:
del self.wrapper.loaded_loras[lora_name]
def set_lora(self, name, multiplier: float = 1.0):
self.loras_to_load[name] = multiplier
# update the multiplier if the lora was already loaded
if name in self.wrapper.loaded_loras:
self.wrapper.loaded_loras[name].multiplier = multiplier
def clear_loras(self):
self.loras_to_load = {}
self.wrapper.clear_applied_loras()
def __del__(self):
del self.loras_to_load
class LoraManager:
loras_to_load: dict[str, float]
def __init__(self, pipe):
self.lora_path = Path(global_models_dir(), 'lora')
self.unet = pipe.unet
self.loras_to_load = {}
# Legacy class handles lora not generated through diffusers
self.legacy = LegacyLora(pipe, self.lora_path)
self.legacy = LegacyLoraManager(pipe, self.lora_path)
def apply_lora_model(self, name, mult: float = 1.0):
def apply_lora_model(self, name):
path = Path(self.lora_path, name)
file = Path(path, "pytorch_lora_weights.bin")
if path.is_dir() and file.is_file():
print(f"loading lora: {path}")
print(f">> Loading LoRA: {path}")
self.unet.load_attn_procs(path.absolute().as_posix())
else:
self.legacy.apply_lora_model(name, mult)
print(f">> Unable to find valid LoRA at: {path}")
def load_lora(self):
for name, multiplier in self.loras_to_load.items():
self.apply_lora_model(name, multiplier)
def load_lora_compel(self, lora_weights: list):
if len(lora_weights) > 0:
for lora in lora_weights:
self.apply_lora_model(lora.model)
self.legacy.unload_applied_loras(self.loras_to_load)
# Legacy functions, to pipe to LoraLegacyManager
def configure_prompt_legacy(self, prompt: str) -> str:
return self.legacy.configure_prompt(prompt)
# Define a lora to be loaded
# Can be used to define a lora to be loaded outside of prompts
def set_lora(self, name, multiplier: float = 1.0):
self.loras_to_load[name] = multiplier
self.legacy.set_lora(name, multiplier)
# Load the lora from a prompt, syntax is <lora:lora_name:multiplier>
# Multiplier should be a value between 0.0 and 1.0
def configure_prompt(self, prompt: str) -> str:
self.clear_loras()
# lora_match = re.compile(r"<lora:([^>]+)>")
lora_match = re.compile(r"withLora\(([a-zA-Z\,\d]+)\)")
for match in re.findall(lora_match, prompt):
# match = match.split(':')
match = match.split(',')
name = match[0].strip()
mult = 1.0
if len(match) == 2:
mult = float(match[1].strip())
self.set_lora(name, mult)
# remove lora and return prompt to avoid the lora prompt causing issues in inference
return re.sub(lora_match, "", prompt)
def clear_loras(self):
self.loras_to_load = {}
self.legacy.clear_loras()
def __del__(self):
del self.loras_to_load
def load_lora_legacy(self):
self.legacy.load_lora()