mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-01 03:01:13 -04:00
add pending support for safetensors with cloneofsimo/lora
This commit is contained in:
@@ -293,7 +293,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
self.textual_inversion_manager = TextualInversionManager(tokenizer=self.tokenizer,
|
||||
text_encoder=self.text_encoder,
|
||||
full_precision=use_full_precision)
|
||||
self.lora_manager = LoraManager(self.unet)
|
||||
self.lora_manager = LoraManager(self)
|
||||
|
||||
# InvokeAI's interface for text embeddings and whatnot
|
||||
self.prompt_fragments_to_embeddings_converter = WeightedPromptFragmentsToEmbeddingsConverter(
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from ldm.invoke.globals import global_models_dir
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
from lora_diffusion import tune_lora_scale, patch_pipe
|
||||
|
||||
|
||||
class LoraManager:
|
||||
|
||||
def __init__(self, model: UNet2DConditionModel):
|
||||
def __init__(self, pipe):
|
||||
self.weights = {}
|
||||
self.model = model
|
||||
self.pipe = pipe
|
||||
self.lora_path = Path(global_models_dir(), 'lora')
|
||||
self.lora_match = re.compile(r"<lora:([^>]+)>")
|
||||
self.prompt = None
|
||||
@@ -17,13 +17,31 @@ class LoraManager:
|
||||
args = args.split(':')
|
||||
name = args[0]
|
||||
path = Path(self.lora_path, name)
|
||||
file = Path(path, "pytorch_lora_weights.bin")
|
||||
|
||||
if path.is_dir():
|
||||
if path.is_dir() and file.is_file():
|
||||
print(f"loading lora: {path}")
|
||||
self.model.load_attn_procs(path.absolute().as_posix())
|
||||
|
||||
self.pipe.unet.load_attn_procs(path.absolute().as_posix())
|
||||
if len(args) == 2:
|
||||
self.weights[name] = args[1]
|
||||
else:
|
||||
# converting and saving in diffusers format
|
||||
path_file = Path(self.lora_path, f'{name}.ckpt')
|
||||
if Path(self.lora_path, f'{name}.safetensors').exists():
|
||||
path_file = Path(self.lora_path, f'{name}.safetensors')
|
||||
|
||||
if path_file.is_file():
|
||||
print(f"loading lora: {path}")
|
||||
patch_pipe(
|
||||
self.pipe,
|
||||
path_file.absolute().as_posix(),
|
||||
patch_text=True,
|
||||
patch_ti=True,
|
||||
patch_unet=True,
|
||||
)
|
||||
if len(args) == 2:
|
||||
tune_lora_scale(self.pipe.unet, args[1])
|
||||
tune_lora_scale(self.pipe.text_encoder, args[1])
|
||||
|
||||
def load_lora_from_prompt(self, prompt: str):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user