mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-01 03:01:13 -04:00
initial setup of cross attention
This commit is contained in:
@@ -60,7 +60,7 @@ def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_l
|
||||
else:
|
||||
positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
|
||||
if model.lora_manager:
|
||||
model.lora_manager.load_lora_compel(positive_prompt.lora_weights)
|
||||
model.lora_manager.set_loras_compel(positive_prompt.lora_weights)
|
||||
negative_prompt: FlattenedPrompt|Blend = Compel.parse_prompt_string(negative_prompt_string)
|
||||
|
||||
if log_tokens or getattr(Globals, "log_tokenization", False):
|
||||
|
||||
@@ -290,12 +290,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward, is_running_diffusers=True)
|
||||
self.lora_manager = LoraManager(self)
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet,
|
||||
self._unet_forward,
|
||||
self.lora_manager,
|
||||
is_running_diffusers=True)
|
||||
use_full_precision = (precision == 'float32' or precision == 'autocast')
|
||||
self.textual_inversion_manager = TextualInversionManager(tokenizer=self.tokenizer,
|
||||
text_encoder=self.text_encoder,
|
||||
full_precision=use_full_precision)
|
||||
self.lora_manager = LoraManager(self)
|
||||
|
||||
# InvokeAI's interface for text embeddings and whatnot
|
||||
self.embeddings_provider = EmbeddingsProvider(
|
||||
|
||||
@@ -13,6 +13,7 @@ from ldm.models.diffusion.cross_attention_control import Arguments, \
|
||||
restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \
|
||||
CrossAttentionType, SwapCrossAttnContext
|
||||
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||
from ldm.modules.lora_manager import LoraManager
|
||||
|
||||
ModelForwardCallback: TypeAlias = Union[
|
||||
# x, t, conditioning, Optional[cross-attention kwargs]
|
||||
@@ -51,7 +52,7 @@ class InvokeAIDiffuserComponent:
|
||||
return self.cross_attention_control_args is not None
|
||||
|
||||
|
||||
def __init__(self, model, model_forward_callback: ModelForwardCallback,
|
||||
def __init__(self, model, model_forward_callback: ModelForwardCallback, lora_manager: LoraManager,
|
||||
is_running_diffusers: bool=False,
|
||||
):
|
||||
"""
|
||||
@@ -64,6 +65,7 @@ class InvokeAIDiffuserComponent:
|
||||
self.model_forward_callback = model_forward_callback
|
||||
self.cross_attention_control_context = None
|
||||
self.sequential_guidance = Globals.sequential_guidance
|
||||
self.lora_manager = lora_manager
|
||||
|
||||
@contextmanager
|
||||
def custom_attention_context(self,
|
||||
@@ -71,6 +73,8 @@ class InvokeAIDiffuserComponent:
|
||||
step_count: int):
|
||||
do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
|
||||
old_attn_processor = None
|
||||
if self.lora_manager:
|
||||
self.lora_manager.load_loras()
|
||||
if do_swap:
|
||||
old_attn_processor = self.override_cross_attention(extra_conditioning_info,
|
||||
step_count=step_count)
|
||||
@@ -82,6 +86,7 @@ class InvokeAIDiffuserComponent:
|
||||
# TODO resuscitate attention map saving
|
||||
#self.remove_attention_map_saving()
|
||||
|
||||
|
||||
def override_cross_attention(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]:
|
||||
"""
|
||||
setup cross attention .swap control. for diffusers this replaces the attention processor, so
|
||||
|
||||
@@ -369,11 +369,15 @@ class LegacyLoraManager:
|
||||
|
||||
|
||||
class LoraManager:
|
||||
models: list[str]
|
||||
|
||||
def __init__(self, pipe):
|
||||
self.lora_path = Path(global_models_dir(), 'lora')
|
||||
self.unet = pipe.unet
|
||||
self.text_encoder = pipe.text_encoder
|
||||
# Legacy class handles lora not generated through diffusers
|
||||
self.legacy = LegacyLoraManager(pipe, self.lora_path)
|
||||
self.models = []
|
||||
|
||||
def apply_lora_model(self, name):
|
||||
path = Path(self.lora_path, name)
|
||||
@@ -385,10 +389,17 @@ class LoraManager:
|
||||
else:
|
||||
print(f">> Unable to find valid LoRA at: {path}")
|
||||
|
||||
def load_lora_compel(self, lora_weights: list):
|
||||
def set_lora_model(self, name):
|
||||
self.models.append(name)
|
||||
|
||||
def set_loras_compel(self, lora_weights: list):
|
||||
if len(lora_weights) > 0:
|
||||
for lora in lora_weights:
|
||||
self.apply_lora_model(lora.model)
|
||||
self.set_lora_model(lora.model)
|
||||
|
||||
def load_loras(self):
|
||||
for name in self.models:
|
||||
self.apply_lora_model(name)
|
||||
|
||||
# Legacy functions, to pipe to LoraLegacyManager
|
||||
def configure_prompt_legacy(self, prompt: str) -> str:
|
||||
|
||||
Reference in New Issue
Block a user