diff --git a/ldm/generate.py b/ldm/generate.py index ba6dec9d5b..9a7badb42e 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -488,10 +488,11 @@ class Generate: self.sampler_name = sampler_name self._set_sampler() + # To try and load lora not trained through diffusers if self.model.lora_manager: - prompt = self.model.lora_manager.configure_prompt(prompt) + prompt = self.model.lora_manager.configure_prompt_legacy(prompt) # lora MUST process prompt before conditioning - self.model.lora_manager.load_lora() + self.model.lora_manager.load_lora_legacy() # apply the concepts library to the prompt prompt = self.huggingface_concepts_library.replace_concepts_with_triggers( diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index db50ce4076..94e90cab55 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -59,6 +59,8 @@ def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_l positive_prompt = legacy_blend else: positive_prompt = Compel.parse_prompt_string(positive_prompt_string) + if model.lora_manager: + model.lora_manager.load_lora_compel(positive_prompt.lora_weights) negative_prompt: FlattenedPrompt|Blend = Compel.parse_prompt_string(negative_prompt_string) if log_tokens or getattr(Globals, "log_tokenization", False): diff --git a/ldm/modules/lora_manager.py b/ldm/modules/lora_manager.py index 902a2cb578..5c4d52d01f 100644 --- a/ldm/modules/lora_manager.py +++ b/ldm/modules/lora_manager.py @@ -273,7 +273,7 @@ class LoRA: return -class LegacyLora: +class LegacyLoraManager: def __init__(self, pipe, lora_path): self.unet = pipe.unet self.lora_path = lora_path @@ -281,6 +281,7 @@ class LegacyLora: self.text_encoder = pipe.text_encoder self.device = torch.device(choose_torch_device()) self.dtype = pipe.unet.dtype + self.loras_to_load = {} def load_lora_module(self, name, path_file, multiplier: float = 1.0): # can be used instead to load through diffusers, once enough support is added @@ -298,6 +299,26 @@ class LegacyLora: return lora + def configure_prompt(self, prompt: str) -> str: + self.clear_loras() + + # lora_match = re.compile(r"]+)>") + lora_match = re.compile(r"withLoraLegacy\(([a-zA-Z\,\d]+)\)") + + for match in re.findall(lora_match, prompt): + # match = match.split(':') + match = match.split(',') + name = match[0].strip() + + mult = 1.0 + if len(match) == 2: + mult = float(match[1].strip()) + + self.set_lora(name, mult) + + # remove lora and return prompt to avoid the lora prompt causing issues in inference + return re.sub(lora_match, "", prompt) + def apply_lora_model(self, name, mult: float = 1.0): path_file = Path(self.lora_path, f'{name}.ckpt') if Path(self.lora_path, f'{name}.safetensors').exists(): @@ -314,6 +335,10 @@ class LegacyLora: lora.multiplier = mult self.wrapper.applied_loras[name] = lora + def load_lora(self): + for name, multiplier in self.loras_to_load.items(): + self.apply_lora_model(name, multiplier) + def unload_applied_loras(self, loras_to_load): # unload any lora's not defined by loras_to_load for name in list(self.wrapper.applied_loras.keys()): @@ -329,71 +354,45 @@ class LegacyLora: del self.wrapper.loaded_loras[lora_name] def set_lora(self, name, multiplier: float = 1.0): + self.loras_to_load[name] = multiplier + # update the multiplier if the lora was already loaded if name in self.wrapper.loaded_loras: self.wrapper.loaded_loras[name].multiplier = multiplier def clear_loras(self): + self.loras_to_load = {} self.wrapper.clear_applied_loras() + def __del__(self): + del self.loras_to_load + class LoraManager: - loras_to_load: dict[str, float] - def __init__(self, pipe): self.lora_path = Path(global_models_dir(), 'lora') self.unet = pipe.unet - self.loras_to_load = {} # Legacy class handles lora not generated through diffusers - self.legacy = LegacyLora(pipe, self.lora_path) + self.legacy = LegacyLoraManager(pipe, self.lora_path) - def apply_lora_model(self, name, mult: float = 1.0): + def apply_lora_model(self, name): path = Path(self.lora_path, name) file = Path(path, "pytorch_lora_weights.bin") if path.is_dir() and file.is_file(): - print(f"loading lora: {path}") + print(f">> Loading LoRA: {path}") self.unet.load_attn_procs(path.absolute().as_posix()) else: - self.legacy.apply_lora_model(name, mult) + print(f">> Unable to find valid LoRA at: {path}") - def load_lora(self): - for name, multiplier in self.loras_to_load.items(): - self.apply_lora_model(name, multiplier) + def load_lora_compel(self, lora_weights: list): + if len(lora_weights) > 0: + for lora in lora_weights: + self.apply_lora_model(lora.model) - self.legacy.unload_applied_loras(self.loras_to_load) + # Legacy functions, to pipe to LoraLegacyManager + def configure_prompt_legacy(self, prompt: str) -> str: + return self.legacy.configure_prompt(prompt) - # Define a lora to be loaded - # Can be used to define a lora to be loaded outside of prompts - def set_lora(self, name, multiplier: float = 1.0): - self.loras_to_load[name] = multiplier - self.legacy.set_lora(name, multiplier) - - # Load the lora from a prompt, syntax is - # Multiplier should be a value between 0.0 and 1.0 - def configure_prompt(self, prompt: str) -> str: - self.clear_loras() - - # lora_match = re.compile(r"]+)>") - lora_match = re.compile(r"withLora\(([a-zA-Z\,\d]+)\)") - - for match in re.findall(lora_match, prompt): - # match = match.split(':') - match = match.split(',') - name = match[0].strip() - - mult = 1.0 - if len(match) == 2: - mult = float(match[1].strip()) - - self.set_lora(name, mult) - - # remove lora and return prompt to avoid the lora prompt causing issues in inference - return re.sub(lora_match, "", prompt) - - def clear_loras(self): - self.loras_to_load = {} - self.legacy.clear_loras() - - def __del__(self): - del self.loras_to_load + def load_lora_legacy(self): + self.legacy.load_lora()