mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-19 09:54:24 -05:00
add notes and adjust functions
This commit is contained in:
@@ -98,7 +98,7 @@ def load_lora(
|
||||
value.shape[1], value.shape[0], (1, 1), bias=False)
|
||||
else:
|
||||
print(
|
||||
f">> Encoundered unknown lora layer module in {name}: {type(value).__name__}")
|
||||
f">> Encountered unknown lora layer module in {name}: {type(value).__name__}")
|
||||
continue
|
||||
|
||||
with torch.no_grad():
|
||||
@@ -125,7 +125,6 @@ class LoraManager:
|
||||
|
||||
def __init__(self, pipe):
|
||||
self.lora_path = Path(global_models_dir(), 'lora')
|
||||
self.lora_match = re.compile(r"<lora:([^>]+)>")
|
||||
self.unet = pipe.unet
|
||||
self.text_encoder = pipe.text_encoder
|
||||
self.device = torch.device(choose_torch_device())
|
||||
@@ -208,30 +207,31 @@ class LoraManager:
|
||||
if lora_name in self.loras:
|
||||
del self.loras[lora_name]
|
||||
|
||||
# Define a lora to be loaded
|
||||
# Can be used to define a lora to be loaded outside of prompts
|
||||
def set_lora(self, name, mult: float = 1.0):
|
||||
self.loras_to_load[name] = {"mult": mult}
|
||||
|
||||
def set_lora_from_prompt(self, match):
|
||||
match = match.split(':')
|
||||
name = match[0]
|
||||
|
||||
mult = 1.0
|
||||
if len(match) == 2:
|
||||
mult = float(match[1])
|
||||
|
||||
self.set_lora(name, mult)
|
||||
|
||||
# Load the lora from a prompt, syntax is <lora:lora_name:multiplier>
|
||||
# Multiplier should be a value between 0.0 and 1.0
|
||||
def configure_prompt(self, prompt: str) -> str:
|
||||
self.applied_loras = {}
|
||||
self.loras_to_load = {}
|
||||
|
||||
for match in re.findall(self.lora_match, prompt):
|
||||
self.set_lora_from_prompt(match)
|
||||
lora_match = re.compile(r"<lora:([^>]+)>")
|
||||
|
||||
def found(m):
|
||||
return ""
|
||||
for match in re.findall(lora_match, prompt):
|
||||
match = match.split(':')
|
||||
name = match[0]
|
||||
|
||||
return re.sub(self.lora_match, found, prompt)
|
||||
mult = 1.0
|
||||
if len(match) == 2:
|
||||
mult = float(match[1])
|
||||
|
||||
self.set_lora(name, mult)
|
||||
|
||||
# remove lora and return prompt to avoid the lora prompt causing issues in inference
|
||||
return re.sub(lora_match, "", prompt)
|
||||
|
||||
def __del__(self):
|
||||
del self.loras
|
||||
|
||||
Reference in New Issue
Block a user