mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
update prompt setup so lora's can be loaded in other ways
This commit is contained in:
@@ -132,7 +132,7 @@ class LoraManager:
|
||||
self.loras = {}
|
||||
self.applied_loras = {}
|
||||
self.hooks = []
|
||||
self.prompt = ""
|
||||
self.loras_to_load = []
|
||||
|
||||
def find_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> dict[str, torch.nn.Module]:
|
||||
mapping = {}
|
||||
@@ -170,9 +170,7 @@ class LoraManager:
|
||||
self.loras[name] = lora
|
||||
return lora
|
||||
|
||||
def apply_lora_model(self, args):
|
||||
args = args.split(':')
|
||||
name = args[0]
|
||||
def apply_lora_model(self, name, mult: float = 1.0):
|
||||
path = Path(self.lora_path, name)
|
||||
file = Path(path, "pytorch_lora_weights.bin")
|
||||
|
||||
@@ -189,10 +187,6 @@ class LoraManager:
|
||||
print(f">> Unable to find lora: {name}")
|
||||
return
|
||||
|
||||
mult = 1.0
|
||||
if len(args) == 2:
|
||||
mult = float(args[1])
|
||||
|
||||
lora = self.loras.get(name, None)
|
||||
if lora is None:
|
||||
lora = self._load_lora(name, path_file, mult)
|
||||
@@ -200,20 +194,36 @@ class LoraManager:
|
||||
lora.multiplier = mult
|
||||
self.applied_loras[name] = lora
|
||||
|
||||
def load_lora_from_prompt(self, prompt: str):
|
||||
for m in re.findall(self.lora_match, prompt):
|
||||
self.apply_lora_model(m)
|
||||
|
||||
def load_lora(self):
|
||||
self.load_lora_from_prompt(self.prompt)
|
||||
for lora_to_load in self.loras_to_load:
|
||||
self.apply_lora_model(lora_to_load["name"], lora_to_load["mult"])
|
||||
|
||||
def unload_lora(self, lora_name: str):
|
||||
if lora_name in self.loras:
|
||||
del self.loras[lora_name]
|
||||
|
||||
def set_lora(self, name, mult: float = 1.0):
|
||||
if name in self.loras_to_load:
|
||||
index = self.loras_to_load.index(name)
|
||||
self.loras_to_load[index]["mult"] = mult
|
||||
else:
|
||||
self.loras_to_load.append({"name": name, "mult": mult})
|
||||
|
||||
def set_lora_from_prompt(self, match):
|
||||
match = match.split(':')
|
||||
name = match[0]
|
||||
|
||||
mult = 1.0
|
||||
if len(match) == 2:
|
||||
mult = float(match[1])
|
||||
|
||||
self.set_lora(name, mult)
|
||||
|
||||
def configure_prompt(self, prompt: str) -> str:
|
||||
self.applied_loras = {}
|
||||
self.prompt = prompt
|
||||
|
||||
for match in re.findall(self.lora_match, prompt):
|
||||
self.set_lora_from_prompt(match)
|
||||
|
||||
def found(m):
|
||||
return ""
|
||||
@@ -228,3 +238,4 @@ class LoraManager:
|
||||
for cb in self.hooks:
|
||||
cb.remove()
|
||||
del self.hooks
|
||||
del self.loras_to_load
|
||||
|
||||
Reference in New Issue
Block a user