add notes and adjust functions

This commit is contained in:
Jordan
2023-02-20 16:49:59 -07:00
parent 884a5543c7
commit c3edede73f

View File

@@ -98,7 +98,7 @@ def load_lora(
value.shape[1], value.shape[0], (1, 1), bias=False)
else:
print(
f">> Encoundered unknown lora layer module in {name}: {type(value).__name__}")
f">> Encountered unknown lora layer module in {name}: {type(value).__name__}")
continue
with torch.no_grad():
@@ -125,7 +125,6 @@ class LoraManager:
def __init__(self, pipe):
self.lora_path = Path(global_models_dir(), 'lora')
self.lora_match = re.compile(r"<lora:([^>]+)>")
self.unet = pipe.unet
self.text_encoder = pipe.text_encoder
self.device = torch.device(choose_torch_device())
@@ -208,30 +207,31 @@ class LoraManager:
if lora_name in self.loras:
del self.loras[lora_name]
# Define a lora to be loaded
# Can be used to define a lora to be loaded outside of prompts
def set_lora(self, name, mult: float = 1.0):
self.loras_to_load[name] = {"mult": mult}
def set_lora_from_prompt(self, match):
match = match.split(':')
name = match[0]
mult = 1.0
if len(match) == 2:
mult = float(match[1])
self.set_lora(name, mult)
# Load the lora from a prompt, syntax is <lora:lora_name:multiplier>
# Multiplier should be a value between 0.0 and 1.0
def configure_prompt(self, prompt: str) -> str:
self.applied_loras = {}
self.loras_to_load = {}
for match in re.findall(self.lora_match, prompt):
self.set_lora_from_prompt(match)
lora_match = re.compile(r"<lora:([^>]+)>")
def found(m):
return ""
for match in re.findall(lora_match, prompt):
match = match.split(':')
name = match[0]
return re.sub(self.lora_match, found, prompt)
mult = 1.0
if len(match) == 2:
mult = float(match[1])
self.set_lora(name, mult)
# remove lora and return prompt to avoid the lora prompt causing issues in inference
return re.sub(lora_match, "", prompt)
def __del__(self):
del self.loras