Mirror of https://github.com/invoke-ai/InvokeAI.git
adjust loader to use a settings dict
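This change reworks how the loader tracks pending LoRAs: `loras_to_load` was a list of `{"name": ..., "mult": ...}` entries and becomes a dict keyed by LoRA name, so `set_lora` inserts or updates a single entry per name, `load_lora` iterates `name -> settings` pairs, and `configure_prompt` resets the dict before re-parsing the prompt. The forward hook's `input` argument is also renamed to `input_h`. A short usage sketch follows the diff below.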
@@ -120,6 +120,7 @@ def load_lora(
 class LoraManager:
     loras: dict[str, LoRA]
     applied_loras: dict[str, LoRA]
+    loras_to_load: dict[str, dict]
     hooks: list[RemovableHandle]

     def __init__(self, pipe):
@@ -132,7 +133,7 @@ class LoraManager:
         self.loras = {}
         self.applied_loras = {}
         self.hooks = []
-        self.loras_to_load = []
+        self.loras_to_load = {}

         def find_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> dict[str, torch.nn.Module]:
             mapping = {}
@@ -153,13 +154,13 @@ class LoraManager:
             LORA_PREFIX_UNET, self.unet, UNET_TARGET_REPLACE_MODULE)

     def _make_hook(self, layer: str):
-        def hook(module, input, output):
+        def hook(module, input_h, output):
            for lora in self.applied_loras.values():
                lora_layer = lora.layers.get(layer, None)
                if lora_layer is None:
                    continue
                output = output + \
-                    lora_layer.up(lora_layer.down(*input)) * \
+                    lora_layer.up(lora_layer.down(*input_h)) * \
                    lora.multiplier * lora_layer.scale
            return output
        return hook
@@ -195,19 +196,20 @@ class LoraManager:
         self.applied_loras[name] = lora

     def load_lora(self):
-        for lora_to_load in self.loras_to_load:
-            self.apply_lora_model(lora_to_load["name"], lora_to_load["mult"])
+        for name, data in self.loras_to_load.items():
+            self.apply_lora_model(name, data["mult"])

         # unload any lora's not defined by loras_to_load
         for name in list(self.loras.keys()):
             if name not in self.loras_to_load:
                 self.unload_lora(name)

     def unload_lora(self, lora_name: str):
         if lora_name in self.loras:
             del self.loras[lora_name]

     def set_lora(self, name, mult: float = 1.0):
         if name in self.loras_to_load:
-            index = self.loras_to_load.index(name)
-            self.loras_to_load[index]["mult"] = mult
+            self.loras_to_load[name]["mult"] = mult
         else:
-            self.loras_to_load.append({"name": name, "mult": mult})
+            self.loras_to_load[name] = {"mult": mult}

     def set_lora_from_prompt(self, match):
         match = match.split(':')
@@ -221,6 +223,7 @@ class LoraManager:

     def configure_prompt(self, prompt: str) -> str:
         self.applied_loras = {}
+        self.loras_to_load = {}

         for match in re.findall(self.lora_match, prompt):
             self.set_lora_from_prompt(match)
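For readers skimming the change, here is a minimal, self-contained sketch of the bookkeeping pattern the commit moves to. It is not the InvokeAI LoraManager itself; the set_lora/load_loras helpers, the "detail_tweaker" name, and the plain-float stand-ins for LoRA models are illustrative only, but the list-to-dict logic mirrors the diff above.

# Illustrative sketch only (not InvokeAI code): dict-based LoRA bookkeeping,
# with plain floats standing in for real LoRA models.
loras_to_load: dict[str, dict] = {}   # name -> {"mult": float}
applied: dict[str, float] = {}        # currently applied name -> multiplier


def set_lora(name: str, mult: float = 1.0) -> None:
    # One entry per name; calling again just updates the multiplier.
    if name in loras_to_load:
        loras_to_load[name]["mult"] = mult
    else:
        loras_to_load[name] = {"mult": mult}


def load_loras() -> None:
    # Apply everything requested, then drop anything no longer requested.
    for name, data in loras_to_load.items():
        applied[name] = data["mult"]
    for name in list(applied):
        if name not in loras_to_load:
            del applied[name]


set_lora("detail_tweaker", 0.8)
set_lora("detail_tweaker", 0.5)   # updates the existing entry, no duplicate
load_loras()
assert applied == {"detail_tweaker": 0.5}

Keying the pending settings by name is what lets load_lora both apply requested LoRAs and unload any LoRA whose name no longer appears in loras_to_load with a plain membership test.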