Initial setup of cross attention

This commit is contained in:
Jordan
2023-02-23 17:30:34 -07:00
parent 6a1129ab64
commit b69f9d4af1
4 changed files with 25 additions and 6 deletions

View File

@@ -60,7 +60,7 @@ def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_l
else:
positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
if model.lora_manager:
model.lora_manager.load_lora_compel(positive_prompt.lora_weights)
model.lora_manager.set_loras_compel(positive_prompt.lora_weights)
negative_prompt: FlattenedPrompt|Blend = Compel.parse_prompt_string(negative_prompt_string)
if log_tokens or getattr(Globals, "log_tokenization", False):

View File

@@ -290,12 +290,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward, is_running_diffusers=True)
self.lora_manager = LoraManager(self)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet,
self._unet_forward,
self.lora_manager,
is_running_diffusers=True)
use_full_precision = (precision == 'float32' or precision == 'autocast')
self.textual_inversion_manager = TextualInversionManager(tokenizer=self.tokenizer,
text_encoder=self.text_encoder,
full_precision=use_full_precision)
self.lora_manager = LoraManager(self)
# InvokeAI's interface for text embeddings and whatnot
self.embeddings_provider = EmbeddingsProvider(

View File

@@ -13,6 +13,7 @@ from ldm.models.diffusion.cross_attention_control import Arguments, \
restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \
CrossAttentionType, SwapCrossAttnContext
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ldm.modules.lora_manager import LoraManager
ModelForwardCallback: TypeAlias = Union[
# x, t, conditioning, Optional[cross-attention kwargs]
@@ -51,7 +52,7 @@ class InvokeAIDiffuserComponent:
return self.cross_attention_control_args is not None
def __init__(self, model, model_forward_callback: ModelForwardCallback,
def __init__(self, model, model_forward_callback: ModelForwardCallback, lora_manager: LoraManager,
is_running_diffusers: bool=False,
):
"""
@@ -64,6 +65,7 @@ class InvokeAIDiffuserComponent:
self.model_forward_callback = model_forward_callback
self.cross_attention_control_context = None
self.sequential_guidance = Globals.sequential_guidance
self.lora_manager = lora_manager
@contextmanager
def custom_attention_context(self,
@@ -71,6 +73,8 @@ class InvokeAIDiffuserComponent:
step_count: int):
do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
old_attn_processor = None
if self.lora_manager:
self.lora_manager.load_loras()
if do_swap:
old_attn_processor = self.override_cross_attention(extra_conditioning_info,
step_count=step_count)
@@ -82,6 +86,7 @@ class InvokeAIDiffuserComponent:
# TODO resuscitate attention map saving
#self.remove_attention_map_saving()
def override_cross_attention(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]:
"""
setup cross attention .swap control. for diffusers this replaces the attention processor, so

View File

@@ -369,11 +369,15 @@ class LegacyLoraManager:
class LoraManager:
models: list[str]
def __init__(self, pipe):
    """Track LoRA model names for a diffusers pipeline and apply them on demand.

    pipe: a diffusers pipeline exposing .unet and .text_encoder.
    """
    self.unet = pipe.unet
    self.text_encoder = pipe.text_encoder
    # LoRA files live under the global models directory.
    self.lora_path = Path(global_models_dir(), 'lora')
    # Non-diffusers (legacy) LoRA formats are delegated to the legacy manager.
    self.legacy = LegacyLoraManager(pipe, self.lora_path)
    # Names queued via set_lora_model(), applied later by load_loras().
    self.models = []
def apply_lora_model(self, name):
path = Path(self.lora_path, name)
@@ -385,10 +389,17 @@ class LoraManager:
else:
print(f">> Unable to find valid LoRA at: {path}")
def load_lora_compel(self, lora_weights: list):
def set_lora_model(self, name):
    """Queue *name* so a later load_loras() call will apply it."""
    self.models += [name]
def set_loras_compel(self, lora_weights: list):
    """Queue every LoRA model referenced by a compel-parsed prompt.

    lora_weights: parsed weight objects, each carrying a .model name.
    """
    # An empty list simply queues nothing — no explicit length guard needed.
    for lora in lora_weights:
        self.set_lora_model(lora.model)
def load_loras(self):
    """Apply every LoRA model previously queued via set_lora_model()."""
    for queued_name in self.models:
        self.apply_lora_model(queued_name)
# Legacy functions, to pipe to LoraLegacyManager
def configure_prompt_legacy(self, prompt: str) -> str: