From 141be95c2cb4bd6611f36766e087777a4425a3eb Mon Sep 17 00:00:00 2001
From: Jordan
Date: Sat, 18 Feb 2023 05:29:04 -0700
Subject: [PATCH] initial setup of lora support

---
 binary_installer/requirements.in           |  2 +-
 ldm/generate.py                            |  7 ++++
 ldm/invoke/generator/diffusers_pipeline.py |  3 ++
 ldm/modules/lora_manager.py                | 43 ++++++++++++++++++++++
 pyproject.toml                             |  2 +-
 5 files changed, 55 insertions(+), 2 deletions(-)
 create mode 100644 ldm/modules/lora_manager.py

diff --git a/binary_installer/requirements.in b/binary_installer/requirements.in
index 66e0618f78..5922b6f04e 100644
--- a/binary_installer/requirements.in
+++ b/binary_installer/requirements.in
@@ -4,7 +4,7 @@
 --trusted-host https://download.pytorch.org
 accelerate~=0.15
 albumentations
-diffusers[torch]~=0.11
+diffusers[torch]~=0.13
 einops
 eventlet
 flask_cors
diff --git a/ldm/generate.py b/ldm/generate.py
index 1b07122628..8e21f39f33 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -457,6 +457,9 @@ class Generate:
         self.sampler_name = sampler_name
         self._set_sampler()
 
+        if self.model.lora_manager:
+            prompt = self.model.lora_manager.configure_prompt(prompt)
+
         # apply the concepts library to the prompt
         prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(
             prompt,
@@ -515,6 +518,9 @@ class Generate:
             'extractor':self.safety_feature_extractor
         } if self.safety_checker else None
 
+        if self.model.lora_manager:
+            self.model.lora_manager.load_lora()
+
         results = generator.generate(
             prompt,
             iterations=iterations,
@@ -927,6 +933,7 @@ class Generate:
         self.model_name = model_name
         self._set_sampler() # requires self.model_name to be set first
 
+        return self.model
 
     def load_huggingface_concepts(self, concepts:list[str]):
diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py
index 5990eb42a1..82e7f9afc9 100644
--- a/ldm/invoke/generator/diffusers_pipeline.py
+++ b/ldm/invoke/generator/diffusers_pipeline.py
@@ -28,6 +28,7 @@ from typing_extensions import ParamSpec
 from ldm.invoke.globals import Globals
 from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
 from ldm.modules.textual_inversion_manager import TextualInversionManager
+from ldm.modules.lora_manager import LoraManager
 from ..offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup
 from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
 from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
@@ -292,6 +293,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         self.textual_inversion_manager = TextualInversionManager(tokenizer=self.tokenizer,
                                                                  text_encoder=self.text_encoder,
                                                                  full_precision=use_full_precision)
+        self.lora_manager = LoraManager(self.unet)
+
         # InvokeAI's interface for text embeddings and whatnot
         self.prompt_fragments_to_embeddings_converter = WeightedPromptFragmentsToEmbeddingsConverter(
             tokenizer=self.tokenizer,
diff --git a/ldm/modules/lora_manager.py b/ldm/modules/lora_manager.py
new file mode 100644
index 0000000000..05e3e3d8f0
--- /dev/null
+++ b/ldm/modules/lora_manager.py
@@ -0,0 +1,43 @@
+import re
+from pathlib import Path
+
+from ldm.invoke.globals import global_models_dir
+from diffusers.models import UNet2DConditionModel
+
+class LoraManager:
+
+    def __init__(self, model: UNet2DConditionModel):
+        self.weights = {}
+        self.model = model
+        self.lora_path = Path(global_models_dir(), 'lora')
+        self.lora_match = re.compile(r"<lora:([^>]+)>")
+        self.prompt = None
+
+    def apply_lora_model(self, args):
+        args = args.split(':')
+        name = args[0]
+        path = Path(self.lora_path, name)
+
+        if path.is_dir():
+            print(f"loading lora: {path}")
+            self.model.load_attn_procs(path.absolute().as_posix())
+
+        if len(args) == 2:
+            self.weights[name] = args[1]
+
+    def load_lora_from_prompt(self, prompt: str):
+
+        for m in re.findall(self.lora_match, prompt):
+            self.apply_lora_model(m)
+
+    def load_lora(self):
+        self.load_lora_from_prompt(self.prompt)
+
+    def configure_prompt(self, prompt: str) -> str:
+        self.prompt = prompt
+
+        def found(m):
+            return ""
+
+        return re.sub(self.lora_match, found, prompt)
+
diff --git a/pyproject.toml b/pyproject.toml
index f3dfa69b91..f6e428924d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,7 +39,7 @@ dependencies = [
   "click",
   "clip_anytorch",  # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
   "datasets",
-  "diffusers[torch]~=0.11",
+  "diffusers[torch]~=0.13",
   "dnspython==2.2.1",
   "einops",
   "eventlet",
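-- 
Usage sketch (illustrative note, not part of the patch): the manager is
driven entirely by prompt tags of the form <lora:name> or <lora:name:weight>.
Assuming a diffusers attention-processor checkpoint (for example one written
by UNet2DConditionModel.save_attn_procs) sits in a directory named "my_style"
under the models/lora folder, where both the name "my_style" and the weight
0.8 below are hypothetical, the flow wired into ldm/generate.py amounts to:

    from ldm.modules.lora_manager import LoraManager

    # `pipe` stands for an already-loaded StableDiffusionGeneratorPipeline
    manager = LoraManager(pipe.unet)

    # configure_prompt() records the raw prompt and returns it with every
    # <lora:...> tag stripped, so the tags never reach the tokenizer.
    prompt = manager.configure_prompt("a watercolor city <lora:my_style:0.8>")
    # prompt is now "a watercolor city " (each tag is replaced with "")

    # load_lora() re-scans the recorded prompt and calls
    # unet.load_attn_procs() for each tag whose directory exists under
    # models/lora; the optional weight is stored in self.weights but is
    # not applied anywhere yet in this initial version.
    manager.load_lora()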