mirror of https://github.com/invoke-ai/InvokeAI.git

initial setup of lora support
@@ -457,6 +457,9 @@ class Generate:
         self.sampler_name = sampler_name
         self._set_sampler()

+        if self.model.lora_manager:
+            prompt = self.model.lora_manager.configure_prompt(prompt)
+
         # apply the concepts library to the prompt
         prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(
             prompt,
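For context: `configure_prompt` (defined in the new `ldm/modules/lora_manager.py` below) strips `<lora:name:weight>` tags out of the prompt before the concepts library sees it, while remembering the raw prompt for later loading. A minimal sketch of that regex behavior, using the same pattern the commit compiles (the prompt string is illustrative):

    import re

    # Same pattern the new LoraManager compiles below.
    lora_match = re.compile(r"<lora:([^>]+)>")

    prompt = "a portrait photo <lora:film_grain:0.8> of an astronaut"

    # Tag bodies captured for later loading: ['film_grain:0.8']
    tags = re.findall(lora_match, prompt)

    # Tags removed so the tokenizer never sees them.
    clean = re.sub(lora_match, "", prompt)
    print(tags, clean)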
@@ -515,6 +518,9 @@ class Generate:
                 'extractor':self.safety_feature_extractor
             } if self.safety_checker else None

+            if self.model.lora_manager:
+                self.model.lora_manager.load_lora()
+
             results = generator.generate(
                 prompt,
                 iterations=iterations,
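Note the ordering implied by these two hunks: `configure_prompt` must run first, because `load_lora` replays the prompt stored on the manager rather than taking one as an argument. A stand-in sketch of that call sequence (the stub class and names are illustrative, not from the commit):

    import re

    class StubLoraManager:
        """Hypothetical stand-in mirroring the call order Generate uses."""
        lora_match = re.compile(r"<lora:([^>]+)>")

        def __init__(self):
            self.prompt = None

        def configure_prompt(self, prompt: str) -> str:
            self.prompt = prompt  # remember the raw prompt for load_lora()
            return re.sub(self.lora_match, "", prompt)

        def load_lora(self):
            # Replays the remembered prompt; a TypeError here means
            # configure_prompt never ran.
            for tag in re.findall(self.lora_match, self.prompt):
                print(f"would load lora: {tag}")

    mgr = StubLoraManager()
    clean = mgr.configure_prompt("castle at dusk <lora:pixel_art:0.7>")
    mgr.load_lora()  # prints: would load lora: pixel_art:0.7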
@@ -927,6 +933,7 @@ class Generate:

         self.model_name = model_name
+        self._set_sampler() # requires self.model_name to be set first

         return self.model

     def load_huggingface_concepts(self, concepts:list[str]):
@@ -28,6 +28,7 @@ from typing_extensions import ParamSpec
 from ldm.invoke.globals import Globals
 from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
 from ldm.modules.textual_inversion_manager import TextualInversionManager
+from ldm.modules.lora_manager import LoraManager
 from ..offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup
 from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
 from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
@@ -292,6 +293,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         self.textual_inversion_manager = TextualInversionManager(tokenizer=self.tokenizer,
                                                                  text_encoder=self.text_encoder,
                                                                  full_precision=use_full_precision)
+        self.lora_manager = LoraManager(self.unet)
+
         # InvokeAI's interface for text embeddings and whatnot
         self.prompt_fragments_to_embeddings_converter = WeightedPromptFragmentsToEmbeddingsConverter(
             tokenizer=self.tokenizer,
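The manager is handed the pipeline's UNet because it ultimately delegates to diffusers' `load_attn_procs`, which loads LoRA attention-processor weights into a `UNet2DConditionModel`. A minimal sketch of that underlying call outside InvokeAI (the model id and local directory are assumptions, for illustration only):

    from diffusers import StableDiffusionPipeline

    # Hypothetical base model and LoRA directory.
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    # Same diffusers API the new LoraManager calls on self.model:
    # loads LoRA attention-processor weights from a directory into the UNet.
    pipe.unet.load_attn_procs("models/lora/my_lora")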
ldm/modules/lora_manager.py (new file, 43 lines)
@@ -0,0 +1,43 @@
+import re
+from pathlib import Path
+
+from ldm.invoke.globals import global_models_dir
+from diffusers.models import UNet2DConditionModel
+
+
+class LoraManager:
+
+    def __init__(self, model: UNet2DConditionModel):
+        self.weights = {}
+        self.model = model
+        self.lora_path = Path(global_models_dir(), 'lora')
+        self.lora_match = re.compile(r"<lora:([^>]+)>")
+        self.prompt = None
+
+    def apply_lora_model(self, args):
+        args = args.split(':')
+        name = args[0]
+        path = Path(self.lora_path, name)
+
+        if path.is_dir():
+            print(f"loading lora: {path}")
+            self.model.load_attn_procs(path.absolute().as_posix())
+
+            if len(args) == 2:
+                self.weights[name] = args[1]
+
+    def load_lora_from_prompt(self, prompt: str):
+
+        for m in re.findall(self.lora_match, prompt):
+            self.apply_lora_model(m)
+
+    def load_lora(self):
+        self.load_lora_from_prompt(self.prompt)
+
+    def configure_prompt(self, prompt: str) -> str:
+        self.prompt = prompt
+
+        def found(m):
+            return ""
+
+        return re.sub(self.lora_match, found, prompt)
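Putting the pieces together, a usage sketch of the new class (the base model id is an assumption, and `global_models_dir()` must resolve to a models folder containing `lora/<name>/` for anything to load):

    from diffusers import StableDiffusionPipeline
    from ldm.modules.lora_manager import LoraManager

    # Build a UNet to hand to the manager (hypothetical model id).
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    manager = LoraManager(pipe.unet)

    # Step 1: strip <lora:...> tags and remember the raw prompt.
    prompt = manager.configure_prompt("a red fox <lora:watercolor:0.6> in snow")

    # Step 2: load every referenced adapter from <models>/lora/<name>/.
    # Note that in this initial commit the weight ("0.6") is recorded in
    # manager.weights but never applied to the attention processors.
    manager.load_lora()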