mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
more wip sliced attention (.swap doesn't work yet)
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
import math
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from math import ceil
|
||||
from typing import Callable, Optional, Union, Any
|
||||
from typing import Callable, Optional, Union, Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers.models.cross_attention import AttnProcessor
|
||||
from ldm.models.diffusion.cross_attention_control import Arguments, \
|
||||
remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, \
|
||||
CrossAttentionType, SwapCrossAttnContext
|
||||
@@ -55,20 +57,43 @@ class InvokeAIDiffuserComponent:
|
||||
self.model_forward_callback = model_forward_callback
|
||||
self.cross_attention_control_context = None
|
||||
|
||||
def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int):
|
||||
@contextmanager
|
||||
def custom_attention_context(self,
|
||||
extra_conditioning_info: Optional[ExtraConditioningInfo],
|
||||
step_count: int,
|
||||
do_attention_map_saving: bool):
|
||||
do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
|
||||
old_attn_processor = None
|
||||
if do_swap:
|
||||
old_attn_processor = self.setup_cross_attention_control(extra_conditioning_info,
|
||||
step_count=step_count)
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
self.remove_cross_attention_control(old_attn_processor)
|
||||
# TODO resuscitate attention map saving
|
||||
#self.remove_attention_map_saving()
|
||||
|
||||
def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]:
|
||||
"""
|
||||
setup cross attention .swap control. for diffusers this replaces the attention processor, so
|
||||
the previous attention processor is returned so that the caller can restore it later.
|
||||
"""
|
||||
self.conditioning = conditioning
|
||||
self.cross_attention_control_context = Context(
|
||||
arguments=self.conditioning.cross_attention_control_args,
|
||||
step_count=step_count
|
||||
)
|
||||
setup_cross_attention_control(self.model,
|
||||
self.cross_attention_control_context,
|
||||
is_running_diffusers=self.is_running_diffusers)
|
||||
return setup_cross_attention_control(self.model,
|
||||
self.cross_attention_control_context,
|
||||
is_running_diffusers=self.is_running_diffusers)
|
||||
|
||||
def remove_cross_attention_control(self):
|
||||
def remove_cross_attention_control(self, restore_attention_processor: Optional['AttnProcessor']=None):
|
||||
self.conditioning = None
|
||||
self.cross_attention_control_context = None
|
||||
remove_cross_attention_control(self.model, is_running_diffusers=self.is_running_diffusers)
|
||||
remove_cross_attention_control(self.model,
|
||||
is_running_diffusers=self.is_running_diffusers,
|
||||
restore_attention_processor=restore_attention_processor)
|
||||
|
||||
def setup_attention_map_saving(self, saver: AttentionMapSaver):
|
||||
def callback(slice, dim, offset, slice_size, key):
|
||||
|
||||
Reference in New Issue
Block a user