adjust loader to use a settings dict

This commit is contained in:
Jordan
2023-02-20 16:33:53 -07:00
parent ac972ebbe3
commit 884a5543c7

View File

@@ -120,6 +120,7 @@ def load_lora(
class LoraManager:
loras: dict[str, LoRA]
applied_loras: dict[str, LoRA]
loras_to_load: dict[str, dict]
hooks: list[RemovableHandle]
def __init__(self, pipe):
@@ -132,7 +133,7 @@ class LoraManager:
self.loras = {}
self.applied_loras = {}
self.hooks = []
self.loras_to_load = []
self.loras_to_load = {}
def find_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> dict[str, torch.nn.Module]:
mapping = {}
@@ -153,13 +154,13 @@ class LoraManager:
LORA_PREFIX_UNET, self.unet, UNET_TARGET_REPLACE_MODULE)
def _make_hook(self, layer: str):
def hook(module, input, output):
def hook(module, input_h, output):
for lora in self.applied_loras.values():
lora_layer = lora.layers.get(layer, None)
if lora_layer is None:
continue
output = output + \
lora_layer.up(lora_layer.down(*input)) * \
lora_layer.up(lora_layer.down(*input_h)) * \
lora.multiplier * lora_layer.scale
return output
return hook
@@ -195,19 +196,20 @@ class LoraManager:
self.applied_loras[name] = lora
def load_lora(self):
for lora_to_load in self.loras_to_load:
self.apply_lora_model(lora_to_load["name"], lora_to_load["mult"])
for name, data in self.loras_to_load.items():
self.apply_lora_model(name, data["mult"])
# unload any lora's not defined by loras_to_load
for name in list(self.loras.keys()):
if name not in self.loras_to_load:
self.unload_lora(name)
def unload_lora(self, lora_name: str):
if lora_name in self.loras:
del self.loras[lora_name]
def set_lora(self, name, mult: float = 1.0):
if name in self.loras_to_load:
index = self.loras_to_load.index(name)
self.loras_to_load[index]["mult"] = mult
else:
self.loras_to_load.append({"name": name, "mult": mult})
self.loras_to_load[name] = {"mult": mult}
def set_lora_from_prompt(self, match):
match = match.split(':')
@@ -221,6 +223,7 @@ class LoraManager:
def configure_prompt(self, prompt: str) -> str:
self.applied_loras = {}
self.loras_to_load = {}
for match in re.findall(self.lora_match, prompt):
self.set_lora_from_prompt(match)