setup cross conditioning for lora

This commit is contained in:
Jordan
2023-02-23 19:27:45 -07:00
parent 68a3132d81
commit 4ce8b1ba21
5 changed files with 68 additions and 33 deletions

View File

@@ -55,12 +55,13 @@ def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_l
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
positive_prompt: FlattenedPrompt|Blend
lora_conditions = None
if legacy_blend is not None:
positive_prompt = legacy_blend
else:
positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
if model.lora_manager:
model.lora_manager.set_loras_compel(positive_prompt.lora_weights)
lora_conditions = model.lora_manager.set_loras_conditions(positive_prompt.lora_weights)
negative_prompt: FlattenedPrompt|Blend = Compel.parse_prompt_string(negative_prompt_string)
if log_tokens or getattr(Globals, "log_tokenization", False):
@@ -73,7 +74,8 @@ def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_l
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
cross_attention_control_args=options.get(
'cross_attention_control', None))
'cross_attention_control', None),
lora_conditions=lora_conditions)
return uc, c, ec

View File

@@ -291,10 +291,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
feature_extractor=feature_extractor,
)
self.lora_manager = LoraManager(self)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet,
self._unet_forward,
self.lora_manager,
is_running_diffusers=True)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward, is_running_diffusers=True)
use_full_precision = (precision == 'float32' or precision == 'autocast')
self.textual_inversion_manager = TextualInversionManager(tokenizer=self.tokenizer,
text_encoder=self.text_encoder,

View File

@@ -23,6 +23,7 @@ Globals = Namespace()
Globals.initfile = 'invokeai.init'
Globals.models_file = 'models.yaml'
Globals.models_dir = 'models'
Globals.lora_models_dir = 'lora'
Globals.config_dir = 'configs'
Globals.autoscan_dir = 'weights'
Globals.converted_ckpts_dir = 'converted_ckpts'
@@ -75,6 +76,9 @@ def global_config_dir()->Path:
def global_models_dir()->Path:
return Path(Globals.root, Globals.models_dir)
def global_lora_models_dir()->Path:
return Path(global_models_dir(), Globals.lora_models_dir)
def global_autoscan_dir()->Path:
return Path(Globals.root, Globals.autoscan_dir)

View File

@@ -13,7 +13,7 @@ from ldm.models.diffusion.cross_attention_control import Arguments, \
restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \
CrossAttentionType, SwapCrossAttnContext
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ldm.modules.lora_manager import LoraManager
from ldm.modules.lora_manager import LoraCondition
ModelForwardCallback: TypeAlias = Union[
# x, t, conditioning, Optional[cross-attention kwargs]
@@ -46,13 +46,21 @@ class InvokeAIDiffuserComponent:
tokens_count_including_eos_bos: int
cross_attention_control_args: Optional[Arguments] = None
lora_conditions: Optional[list[LoraCondition]] = None
@property
def wants_cross_attention_control(self):
return self.cross_attention_control_args is not None
@property
def has_lora_conditions(self):
return self.lora_conditions is not None
def __init__(self, model, model_forward_callback: ModelForwardCallback, lora_manager: LoraManager,
@property
def should_do_swap(self):
return self.wants_cross_attention_control or self.has_lora_conditions
def __init__(self, model, model_forward_callback: ModelForwardCallback,
is_running_diffusers: bool=False,
):
"""
@@ -65,16 +73,13 @@ class InvokeAIDiffuserComponent:
self.model_forward_callback = model_forward_callback
self.cross_attention_control_context = None
self.sequential_guidance = Globals.sequential_guidance
self.lora_manager = lora_manager
@contextmanager
def custom_attention_context(self,
extra_conditioning_info: Optional[ExtraConditioningInfo],
step_count: int):
do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
do_swap = extra_conditioning_info is not None and extra_conditioning_info.should_do_swap
old_attn_processor = None
if self.lora_manager:
self.lora_manager.load_loras()
if do_swap:
old_attn_processor = self.override_cross_attention(extra_conditioning_info,
step_count=step_count)
@@ -93,6 +98,21 @@ class InvokeAIDiffuserComponent:
the previous attention processor is returned so that the caller can restore it later.
"""
self.conditioning = conditioning
# If other modules do not want cross_attention_control then we should bypass setting up Context
old_attn_processors = None
if not self.conditioning.wants_cross_attention_control:
old_attn_processors = self.model.attn_processors
# Load lora conditions into the model
if self.conditioning.has_lora_conditions:
for condition in self.conditioning.lora_conditions:
condition(self.model)
# return old_attn_processors if there is nothing further to do here
if not self.conditioning.wants_cross_attention_control:
return old_attn_processors
self.cross_attention_control_context = Context(
arguments=self.conditioning.cross_attention_control_args,
step_count=step_count

View File

@@ -1,40 +1,52 @@
from pathlib import Path
from ldm.invoke.globals import global_models_dir
from ldm.invoke.globals import global_lora_models_dir
from .legacy_lora_manager import LegacyLoraManager
class LoraManager:
models: list[str]
class LoraCondition:
name: str
weight: float
def __init__(self, pipe):
self.lora_path = Path(global_models_dir(), 'lora')
self.unet = pipe.unet
self.text_encoder = pipe.text_encoder
# Legacy class handles lora not generated through diffusers
self.legacy = LegacyLoraManager(pipe, self.lora_path)
self.models = []
def __init__(self, name, weight: float = 1.0):
self.name = name
self.weight = weight
def apply_lora_model(self, name):
path = Path(self.lora_path, name)
def __call__(self, model):
path = Path(global_lora_models_dir(), self.name)
file = Path(path, "pytorch_lora_weights.bin")
if path.is_dir() and file.is_file():
print(f">> Loading LoRA: {path}")
self.unet.load_attn_procs(path.absolute().as_posix())
if model.load_attn_procs:
print(f">> Loading LoRA: {path}")
model.load_attn_procs(path.absolute().as_posix())
else:
print(f">> Invalid Model to load LoRA")
else:
print(f">> Unable to find valid LoRA at: {path}")
def set_lora_model(self, name):
self.models.append(name)
def set_loras_compel(self, lora_weights: list):
class LoraManager:
conditions: list[LoraCondition]
def __init__(self, pipe):
self.unet = pipe.unet
self.text_encoder = pipe.text_encoder
# Legacy class handles lora not generated through diffusers
self.legacy = LegacyLoraManager(pipe, global_lora_models_dir())
self.conditions = []
def set_lora_model(self, name, weight: float = 1.0):
self.conditions.append(LoraCondition(name, weight))
def set_loras_conditions(self, lora_weights: list):
if len(lora_weights) > 0:
for lora in lora_weights:
self.set_lora_model(lora.model)
self.set_lora_model(lora.model, lora.weight)
def load_loras(self):
for name in self.models:
self.apply_lora_model(name)
if len(self.conditions) > 0:
return self.conditions
return None
# Legacy functions, to pipe to LoraLegacyManager
# To be removed once support for diffusers LoRA weights is high enough