update prompt setup so lora's can be loaded in other ways

This commit is contained in:
Jordan
2023-02-20 16:06:30 -07:00
parent 3c6c18b34c
commit ac972ebbe3

View File

@@ -132,7 +132,7 @@ class LoraManager:
self.loras = {}
self.applied_loras = {}
self.hooks = []
self.prompt = ""
self.loras_to_load = []
def find_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> dict[str, torch.nn.Module]:
mapping = {}
@@ -170,9 +170,7 @@ class LoraManager:
self.loras[name] = lora
return lora
def apply_lora_model(self, args):
args = args.split(':')
name = args[0]
def apply_lora_model(self, name, mult: float = 1.0):
path = Path(self.lora_path, name)
file = Path(path, "pytorch_lora_weights.bin")
@@ -189,10 +187,6 @@ class LoraManager:
print(f">> Unable to find lora: {name}")
return
mult = 1.0
if len(args) == 2:
mult = float(args[1])
lora = self.loras.get(name, None)
if lora is None:
lora = self._load_lora(name, path_file, mult)
@@ -200,20 +194,36 @@ class LoraManager:
lora.multiplier = mult
self.applied_loras[name] = lora
def load_lora_from_prompt(self, prompt: str):
for m in re.findall(self.lora_match, prompt):
self.apply_lora_model(m)
def load_lora(self):
self.load_lora_from_prompt(self.prompt)
for lora_to_load in self.loras_to_load:
self.apply_lora_model(lora_to_load["name"], lora_to_load["mult"])
def unload_lora(self, lora_name: str):
if lora_name in self.loras:
del self.loras[lora_name]
def set_lora(self, name, mult: float = 1.0):
if name in self.loras_to_load:
index = self.loras_to_load.index(name)
self.loras_to_load[index]["mult"] = mult
else:
self.loras_to_load.append({"name": name, "mult": mult})
def set_lora_from_prompt(self, match):
match = match.split(':')
name = match[0]
mult = 1.0
if len(match) == 2:
mult = float(match[1])
self.set_lora(name, mult)
def configure_prompt(self, prompt: str) -> str:
self.applied_loras = {}
self.prompt = prompt
for match in re.findall(self.lora_match, prompt):
self.set_lora_from_prompt(match)
def found(m):
return ""
@@ -228,3 +238,4 @@ class LoraManager:
for cb in self.hooks:
cb.remove()
del self.hooks
del self.loras_to_load