mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-19 09:54:24 -05:00
62 lines
1.9 KiB
Python
62 lines
1.9 KiB
Python
import re
|
|
from pathlib import Path
|
|
from ldm.invoke.globals import global_models_dir
|
|
from lora_diffusion import tune_lora_scale, patch_pipe
|
|
|
|
|
|
class LoraManager:
|
|
|
|
def __init__(self, pipe):
|
|
self.weights = {}
|
|
self.pipe = pipe
|
|
self.lora_path = Path(global_models_dir(), 'lora')
|
|
self.lora_match = re.compile(r"<lora:([^>]+)>")
|
|
self.prompt = None
|
|
|
|
def apply_lora_model(self, args):
|
|
args = args.split(':')
|
|
name = args[0]
|
|
path = Path(self.lora_path, name)
|
|
file = Path(path, "pytorch_lora_weights.bin")
|
|
|
|
if path.is_dir() and file.is_file():
|
|
print(f"loading lora: {path}")
|
|
self.pipe.unet.load_attn_procs(path.absolute().as_posix())
|
|
if len(args) == 2:
|
|
self.weights[name] = args[1]
|
|
else:
|
|
# converting and saving in diffusers format
|
|
path_file = Path(self.lora_path, f'{name}.ckpt')
|
|
if Path(self.lora_path, f'{name}.safetensors').exists():
|
|
path_file = Path(self.lora_path, f'{name}.safetensors')
|
|
|
|
if path_file.is_file():
|
|
print(f"loading lora: {path}")
|
|
patch_pipe(
|
|
self.pipe,
|
|
path_file.absolute().as_posix(),
|
|
patch_text=True,
|
|
patch_ti=True,
|
|
patch_unet=True,
|
|
)
|
|
if len(args) == 2:
|
|
tune_lora_scale(self.pipe.unet, args[1])
|
|
tune_lora_scale(self.pipe.text_encoder, args[1])
|
|
|
|
def load_lora_from_prompt(self, prompt: str):
|
|
|
|
for m in re.findall(self.lora_match, prompt):
|
|
self.apply_lora_model(m)
|
|
|
|
def load_lora(self):
|
|
self.load_lora_from_prompt(self.prompt)
|
|
|
|
def configure_prompt(self, prompt: str) -> str:
|
|
self.prompt = prompt
|
|
|
|
def found(m):
|
|
return ""
|
|
|
|
return re.sub(self.lora_match, found, prompt)
|
|
|