rename override/restore methods to better reflect what they actually do

This commit is contained in:
Damian Stewart
2023-01-30 16:23:44 +01:00
parent 17d73d09c0
commit d044d4c577
5 changed files with 20 additions and 20 deletions

View File

@@ -9,7 +9,7 @@ import torch
from diffusers.models.cross_attention import AttnProcessor
from ldm.models.diffusion.cross_attention_control import Arguments, \
remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, \
restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \
CrossAttentionType, SwapCrossAttnContext
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
@@ -64,17 +64,17 @@ class InvokeAIDiffuserComponent:
do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
old_attn_processor = None
if do_swap:
old_attn_processor = self.setup_cross_attention_control(extra_conditioning_info,
step_count=step_count)
old_attn_processor = self.override_cross_attention(extra_conditioning_info,
step_count=step_count)
try:
yield None
finally:
if old_attn_processor is not None:
self.remove_cross_attention_control(old_attn_processor)
self.restore_default_cross_attention(old_attn_processor)
# TODO resuscitate attention map saving
#self.remove_attention_map_saving()
def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]:
def override_cross_attention(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]:
"""
setup cross attention .swap control. for diffusers this replaces the attention processor, so
the previous attention processor is returned so that the caller can restore it later.
@@ -84,16 +84,16 @@ class InvokeAIDiffuserComponent:
arguments=self.conditioning.cross_attention_control_args,
step_count=step_count
)
return setup_cross_attention_control(self.model,
self.cross_attention_control_context,
is_running_diffusers=self.is_running_diffusers)
return override_cross_attention(self.model,
self.cross_attention_control_context,
is_running_diffusers=self.is_running_diffusers)
def remove_cross_attention_control(self, restore_attention_processor: Optional['AttnProcessor']=None):
def restore_default_cross_attention(self, restore_attention_processor: Optional['AttnProcessor']=None):
self.conditioning = None
self.cross_attention_control_context = None
remove_cross_attention_control(self.model,
is_running_diffusers=self.is_running_diffusers,
restore_attention_processor=restore_attention_processor)
restore_default_cross_attention(self.model,
is_running_diffusers=self.is_running_diffusers,
restore_attention_processor=restore_attention_processor)
def setup_attention_map_saving(self, saver: AttentionMapSaver):
def callback(slice, dim, offset, slice_size, key):