setup cross conditioning for lora
@@ -55,12 +55,13 @@ def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_l
     positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
     legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
     positive_prompt: FlattenedPrompt|Blend
+    lora_conditions = None
     if legacy_blend is not None:
         positive_prompt = legacy_blend
     else:
         positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
     if model.lora_manager:
-        model.lora_manager.set_loras_compel(positive_prompt.lora_weights)
+        lora_conditions = model.lora_manager.set_loras_conditions(positive_prompt.lora_weights)
     negative_prompt: FlattenedPrompt|Blend = Compel.parse_prompt_string(negative_prompt_string)
 
     if log_tokens or getattr(Globals, "log_tokenization", False):
@@ -73,7 +74,8 @@ def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_l
 
     ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
                                                          cross_attention_control_args=options.get(
-                                                             'cross_attention_control', None))
+                                                             'cross_attention_control', None),
+                                                         lora_conditions=lora_conditions)
     return uc, c, ec
 
 
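Read together, the two hunks above change get_uc_and_c_and_ec so that LoRAs named in the positive prompt become LoraCondition objects carried on ExtraConditioningInfo, rather than being loaded eagerly through set_loras_compel. A minimal sketch of that flow, assuming the imports already present in this module (Compel, InvokeAIDiffuserComponent) and eliding tokenization, blends, and the negative prompt:

# Sketch only: assumes a pipeline-like `model` that exposes `lora_manager`, as in this diff.
def sketch_build_conditioning(prompt_string, model):
    positive_prompt = Compel.parse_prompt_string(prompt_string)

    lora_conditions = None
    if model.lora_manager:
        # Returns a list of LoraCondition objects, or None when the prompt names no LoRAs.
        lora_conditions = model.lora_manager.set_loras_conditions(positive_prompt.lora_weights)

    # The conditions ride along with the rest of the extra conditioning info.
    return InvokeAIDiffuserComponent.ExtraConditioningInfo(
        tokens_count_including_eos_bos=0,        # placeholder; the real code counts prompt tokens
        cross_attention_control_args=None,
        lora_conditions=lora_conditions,
    )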
@@ -291,10 +291,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             feature_extractor=feature_extractor,
         )
         self.lora_manager = LoraManager(self)
-        self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet,
-                                                           self._unet_forward,
-                                                           self.lora_manager,
-                                                           is_running_diffusers=True)
+        self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward, is_running_diffusers=True)
         use_full_precision = (precision == 'float32' or precision == 'autocast')
         self.textual_inversion_manager = TextualInversionManager(tokenizer=self.tokenizer,
                                                                  text_encoder=self.text_encoder,
@@ -23,6 +23,7 @@ Globals = Namespace()
 Globals.initfile = 'invokeai.init'
 Globals.models_file = 'models.yaml'
 Globals.models_dir = 'models'
+Globals.lora_models_dir = 'lora'
 Globals.config_dir = 'configs'
 Globals.autoscan_dir = 'weights'
 Globals.converted_ckpts_dir = 'converted_ckpts'
@@ -75,6 +76,9 @@ def global_config_dir()->Path:
 def global_models_dir()->Path:
     return Path(Globals.root, Globals.models_dir)
 
+def global_lora_models_dir()->Path:
+    return Path(global_models_dir(), Globals.lora_models_dir)
+
 def global_autoscan_dir()->Path:
     return Path(Globals.root, Globals.autoscan_dir)
 
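The two globals hunks give LoRA weights a well-known location under the models tree. A quick illustration of the resulting path (the root value here is made up; InvokeAI normally resolves Globals.root itself):

from ldm.invoke.globals import Globals, global_lora_models_dir

Globals.root = '/opt/invokeai'        # hypothetical install root, for illustration only
print(global_lora_models_dir())       # -> /opt/invokeai/models/lora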
@@ -13,7 +13,7 @@ from ldm.models.diffusion.cross_attention_control import Arguments, \
     restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \
     CrossAttentionType, SwapCrossAttnContext
 from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
-from ldm.modules.lora_manager import LoraManager
+from ldm.modules.lora_manager import LoraCondition
 
 ModelForwardCallback: TypeAlias = Union[
     # x, t, conditioning, Optional[cross-attention kwargs]
@@ -46,13 +46,21 @@ class InvokeAIDiffuserComponent:
 
         tokens_count_including_eos_bos: int
         cross_attention_control_args: Optional[Arguments] = None
+        lora_conditions: Optional[list[LoraCondition]] = None
 
         @property
         def wants_cross_attention_control(self):
             return self.cross_attention_control_args is not None
 
+        @property
+        def has_lora_conditions(self):
+            return self.lora_conditions is not None
 
-    def __init__(self, model, model_forward_callback: ModelForwardCallback, lora_manager: LoraManager,
+        @property
+        def should_do_swap(self):
+            return self.wants_cross_attention_control or self.has_lora_conditions
+
+    def __init__(self, model, model_forward_callback: ModelForwardCallback,
                  is_running_diffusers: bool=False,
                  ):
         """
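should_do_swap is the new gate checked by custom_attention_context below: either prompt-to-prompt cross-attention arguments or LoRA conditions are now enough to trigger the attention-processor override. A small, hypothetical example of the conditioning info in use (the LoRA name and weight are invented for illustration):

ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
    tokens_count_including_eos_bos=77,                       # illustrative token count
    cross_attention_control_args=None,                       # no prompt-to-prompt swap requested
    lora_conditions=[LoraCondition("my_style_lora", 0.8)],   # one hypothetical LoRA at weight 0.8
)

assert not ec.wants_cross_attention_control
assert ec.has_lora_conditions
assert ec.should_do_swap    # LoRA conditions alone now select the override path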
@@ -65,16 +73,13 @@ class InvokeAIDiffuserComponent:
         self.model_forward_callback = model_forward_callback
         self.cross_attention_control_context = None
         self.sequential_guidance = Globals.sequential_guidance
-        self.lora_manager = lora_manager
 
     @contextmanager
     def custom_attention_context(self,
                                  extra_conditioning_info: Optional[ExtraConditioningInfo],
                                  step_count: int):
-        do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
+        do_swap = extra_conditioning_info is not None and extra_conditioning_info.should_do_swap
         old_attn_processor = None
-        if self.lora_manager:
-            self.lora_manager.load_loras()
         if do_swap:
             old_attn_processor = self.override_cross_attention(extra_conditioning_info,
                                                                step_count=step_count)
@@ -93,6 +98,21 @@
         the previous attention processor is returned so that the caller can restore it later.
         """
         self.conditioning = conditioning
+
+        # If other modules do not want cross_attention_control then we should bypass setting up Context
+        old_attn_processors = None
+        if not self.conditioning.wants_cross_attention_control:
+            old_attn_processors = self.model.attn_processors
+
+        # Load lora conditions into the model
+        if self.conditioning.has_lora_conditions:
+            for condition in self.conditioning.lora_conditions:
+                condition(self.model)
+
+        # return old_attn_processors if there is nothing further to do here
+        if not self.conditioning.wants_cross_attention_control:
+            return old_attn_processors
+
         self.cross_attention_control_context = Context(
             arguments=self.conditioning.cross_attention_control_args,
             step_count=step_count
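The three hunks above relocate LoRA handling from the diffuser component's constructor into the per-generation attention override: custom_attention_context consults should_do_swap, and override_cross_attention applies each LoraCondition callable before deciding whether a cross-attention Context is needed at all. A condensed sketch of that control flow, not the literal method bodies; `diffuser` stands in for the InvokeAIDiffuserComponent instance:

def sketch_custom_attention_setup(diffuser, extra_conditioning_info, step_count):
    # should_do_swap: either cross-attention control or LoRA conditions triggers the override.
    do_swap = extra_conditioning_info is not None and extra_conditioning_info.should_do_swap
    if not do_swap:
        return None

    # If only LoRAs were requested, snapshot the current attention processors so they can be restored.
    old_attn_processors = None
    if not extra_conditioning_info.wants_cross_attention_control:
        old_attn_processors = diffuser.model.attn_processors

    # Each LoraCondition is callable; calling it loads that LoRA's attention processors into the model.
    if extra_conditioning_info.has_lora_conditions:
        for condition in extra_conditioning_info.lora_conditions:
            condition(diffuser.model)

    # With no cross-attention control requested there is nothing further to set up.
    if not extra_conditioning_info.wants_cross_attention_control:
        return old_attn_processors

    # Otherwise the real override_cross_attention continues by building a Context
    # from cross_attention_control_args, as in the last hunk above.
    ...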
@@ -1,40 +1,52 @@
 from pathlib import Path
-from ldm.invoke.globals import global_models_dir
+from ldm.invoke.globals import global_lora_models_dir
 from .legacy_lora_manager import LegacyLoraManager
 
 
-class LoraManager:
-    models: list[str]
+class LoraCondition:
+    name: str
+    weight: float
 
-    def __init__(self, pipe):
-        self.lora_path = Path(global_models_dir(), 'lora')
-        self.unet = pipe.unet
-        self.text_encoder = pipe.text_encoder
-        # Legacy class handles lora not generated through diffusers
-        self.legacy = LegacyLoraManager(pipe, self.lora_path)
-        self.models = []
+    def __init__(self, name, weight: float = 1.0):
+        self.name = name
+        self.weight = weight
 
-    def apply_lora_model(self, name):
-        path = Path(self.lora_path, name)
+    def __call__(self, model):
+        path = Path(global_lora_models_dir(), self.name)
         file = Path(path, "pytorch_lora_weights.bin")
 
         if path.is_dir() and file.is_file():
-            print(f">> Loading LoRA: {path}")
-            self.unet.load_attn_procs(path.absolute().as_posix())
+            if model.load_attn_procs:
+                print(f">> Loading LoRA: {path}")
+                model.load_attn_procs(path.absolute().as_posix())
+            else:
+                print(f">> Invalid Model to load LoRA")
         else:
             print(f">> Unable to find valid LoRA at: {path}")
 
-    def set_lora_model(self, name):
-        self.models.append(name)
-
-    def set_loras_compel(self, lora_weights: list):
+class LoraManager:
+    conditions: list[LoraCondition]
+
+    def __init__(self, pipe):
+        self.unet = pipe.unet
+        self.text_encoder = pipe.text_encoder
+        # Legacy class handles lora not generated through diffusers
+        self.legacy = LegacyLoraManager(pipe, global_lora_models_dir())
+        self.conditions = []
+
+    def set_lora_model(self, name, weight: float = 1.0):
+        self.conditions.append(LoraCondition(name, weight))
+
+    def set_loras_conditions(self, lora_weights: list):
         if len(lora_weights) > 0:
             for lora in lora_weights:
-                self.set_lora_model(lora.model)
+                self.set_lora_model(lora.model, lora.weight)
 
     def load_loras(self):
-        for name in self.models:
-            self.apply_lora_model(name)
+        if len(self.conditions) > 0:
+            return self.conditions
+
+        return None
 
     # Legacy functions, to pipe to LoraLegacyManager
     # To be removed once support for diffusers LoRA weights is high enough
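Finally, the rewritten module splits responsibilities: LoraManager only records which LoRAs a prompt asked for, while each LoraCondition knows how to load itself into a model when called. A hedged usage sketch (the pipeline object, LoRA name, and weight are hypothetical; only diffusers-format LoRAs with a pytorch_lora_weights.bin folder are handled by this path, everything else still goes through the legacy manager):

from ldm.modules.lora_manager import LoraManager

manager = LoraManager(pipe)                   # `pipe` is an assumed StableDiffusionGeneratorPipeline
manager.set_lora_model("my_style_lora", 0.8)  # hypothetical LoRA folder under the lora models dir

conditions = manager.load_loras()             # returns the accumulated LoraCondition list, or None
if conditions is not None:
    for condition in conditions:
        condition(manager.unet)               # loads pytorch_lora_weights.bin via load_attn_procs()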