diff --git a/ldm/modules/lora_manager.py b/ldm/modules/lora_manager.py index d24316ddc6..c1719a6db6 100644 --- a/ldm/modules/lora_manager.py +++ b/ldm/modules/lora_manager.py @@ -120,6 +120,7 @@ def load_lora( class LoraManager: loras: dict[str, LoRA] applied_loras: dict[str, LoRA] + loras_to_load: dict[str, dict] hooks: list[RemovableHandle] def __init__(self, pipe): @@ -132,7 +133,7 @@ class LoraManager: self.loras = {} self.applied_loras = {} self.hooks = [] - self.loras_to_load = [] + self.loras_to_load = {} def find_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> dict[str, torch.nn.Module]: mapping = {} @@ -153,13 +154,13 @@ class LoraManager: LORA_PREFIX_UNET, self.unet, UNET_TARGET_REPLACE_MODULE) def _make_hook(self, layer: str): - def hook(module, input, output): + def hook(module, input_h, output): for lora in self.applied_loras.values(): lora_layer = lora.layers.get(layer, None) if lora_layer is None: continue output = output + \ - lora_layer.up(lora_layer.down(*input)) * \ + lora_layer.up(lora_layer.down(*input_h)) * \ lora.multiplier * lora_layer.scale return output return hook @@ -195,19 +196,20 @@ class LoraManager: self.applied_loras[name] = lora def load_lora(self): - for lora_to_load in self.loras_to_load: - self.apply_lora_model(lora_to_load["name"], lora_to_load["mult"]) + for name, data in self.loras_to_load.items(): + self.apply_lora_model(name, data["mult"]) + + # unload any lora's not defined by loras_to_load + for name in list(self.loras.keys()): + if name not in self.loras_to_load: + self.unload_lora(name) def unload_lora(self, lora_name: str): if lora_name in self.loras: del self.loras[lora_name] def set_lora(self, name, mult: float = 1.0): - if name in self.loras_to_load: - index = self.loras_to_load.index(name) - self.loras_to_load[index]["mult"] = mult - else: - self.loras_to_load.append({"name": name, "mult": mult}) + self.loras_to_load[name] = {"mult": mult} def set_lora_from_prompt(self, match): match = match.split(':') @@ -221,6 +223,7 @@ class LoraManager: def configure_prompt(self, prompt: str) -> str: self.applied_loras = {} + self.loras_to_load = {} for match in re.findall(self.lora_match, prompt): self.set_lora_from_prompt(match)