diff --git a/ldm/modules/lora_manager.py b/ldm/modules/lora_manager.py
index c1719a6db6..bb63f501e7 100644
--- a/ldm/modules/lora_manager.py
+++ b/ldm/modules/lora_manager.py
@@ -98,7 +98,7 @@ def load_lora(
                 value.shape[1], value.shape[0], (1, 1), bias=False)
         else:
             print(
-                f">> Encoundered unknown lora layer module in {name}: {type(value).__name__}")
+                f">> Encountered unknown lora layer module in {name}: {type(value).__name__}")
             continue
 
         with torch.no_grad():
@@ -125,7 +125,6 @@ class LoraManager:
 
     def __init__(self, pipe):
         self.lora_path = Path(global_models_dir(), 'lora')
-        self.lora_match = re.compile(r"<lora:([^>]+)>")
         self.unet = pipe.unet
         self.text_encoder = pipe.text_encoder
         self.device = torch.device(choose_torch_device())
@@ -208,30 +207,31 @@ class LoraManager:
         if lora_name in self.loras:
             del self.loras[lora_name]
 
+    # Register a lora to be loaded with the given multiplier
+    # Can be called directly to load a lora outside of prompt parsing
     def set_lora(self, name, mult: float = 1.0):
        self.loras_to_load[name] = {"mult": mult}
 
-    def set_lora_from_prompt(self, match):
-        match = match.split(':')
-        name = match[0]
-
-        mult = 1.0
-        if len(match) == 2:
-            mult = float(match[1])
-
-        self.set_lora(name, mult)
-
+    # Parse loras from the prompt; syntax is <lora:name:multiplier>
+    # Multiplier should be a value between 0.0 and 1.0
     def configure_prompt(self, prompt: str) -> str:
         self.applied_loras = {}
         self.loras_to_load = {}
 
-        for match in re.findall(self.lora_match, prompt):
-            self.set_lora_from_prompt(match)
+        lora_match = re.compile(r"<lora:([^>]+)>")
 
-        def found(m):
-            return ""
+        for match in re.findall(lora_match, prompt):
+            match = match.split(':')
+            name = match[0]
 
-        return re.sub(self.lora_match, found, prompt)
+            mult = 1.0
+            if len(match) == 2:
+                mult = float(match[1])
+
+            self.set_lora(name, mult)
+
+        # strip the lora syntax from the prompt so it does not interfere with inference
+        return re.sub(lora_match, "", prompt)
 
     def __del__(self):
         del self.loras
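
For context, here is a minimal standalone sketch of what the new `configure_prompt` flow does to a prompt, assuming the a1111-style `<lora:name:multiplier>` tag syntax used above; the prompt text and the `example_lora` name are illustrative only, and the `LoraManager` bookkeeping is reduced to a plain dict:

```python
import re

# Same pattern configure_prompt() now builds internally
lora_match = re.compile(r"<lora:([^>]+)>")

def parse_prompt(prompt: str):
    """Collect name -> multiplier pairs and strip the lora tags from the prompt."""
    loras_to_load = {}
    for match in re.findall(lora_match, prompt):
        parts = match.split(':')
        name = parts[0]
        # multiplier is optional and defaults to 1.0
        mult = float(parts[1]) if len(parts) == 2 else 1.0
        loras_to_load[name] = {"mult": mult}
    # the cleaned prompt is what gets passed on to inference
    return re.sub(lora_match, "", prompt), loras_to_load

# parse_prompt("a photo of a cat <lora:example_lora:0.7>")
# -> ("a photo of a cat ", {"example_lora": {"mult": 0.7}})
```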