mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
tweaks and small refactors
This commit is contained in:
@@ -289,10 +289,10 @@ class InvokeAICrossAttentionMixin:
|
||||
|
||||
|
||||
|
||||
def restore_default_cross_attention(model, is_running_diffusers: bool, restore_attention_processor: Optional[AttnProcessor]=None):
|
||||
def restore_default_cross_attention(model, is_running_diffusers: bool, processors_to_restore: Optional[AttnProcessor]=None):
|
||||
if is_running_diffusers:
|
||||
unet = model
|
||||
unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor())
|
||||
unet.set_attn_processor(processors_to_restore or CrossAttnProcessor())
|
||||
else:
|
||||
remove_attention_function(model)
|
||||
|
||||
@@ -334,11 +334,9 @@ def override_cross_attention(model, context: Context, is_running_diffusers = Fal
|
||||
default_slice_size = 4
|
||||
slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
|
||||
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
|
||||
return old_attn_processors
|
||||
else:
|
||||
context.register_cross_attention_modules(model)
|
||||
inject_attention_function(model, context)
|
||||
return None
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ class DDIMSampler(Sampler):
|
||||
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
|
||||
|
||||
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||
self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = all_timesteps_count)
|
||||
self.invokeai_diffuser.override_attention_processors(extra_conditioning_info, step_count = all_timesteps_count)
|
||||
else:
|
||||
self.invokeai_diffuser.restore_default_cross_attention()
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ class CFGDenoiser(nn.Module):
|
||||
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
|
||||
|
||||
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||
self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = t_enc)
|
||||
self.invokeai_diffuser.override_attention_processors(extra_conditioning_info, step_count = t_enc)
|
||||
else:
|
||||
self.invokeai_diffuser.restore_default_cross_attention()
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ class PLMSSampler(Sampler):
|
||||
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
|
||||
|
||||
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||
self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = all_timesteps_count)
|
||||
self.invokeai_diffuser.override_attention_processors(extra_conditioning_info, step_count = all_timesteps_count)
|
||||
else:
|
||||
self.invokeai_diffuser.restore_default_cross_attention()
|
||||
|
||||
|
||||
@@ -56,9 +56,6 @@ class InvokeAIDiffuserComponent:
|
||||
def has_lora_conditions(self):
|
||||
return self.lora_conditions is not None
|
||||
|
||||
@property
|
||||
def should_do_swap(self):
|
||||
return self.wants_cross_attention_control or self.has_lora_conditions
|
||||
|
||||
def __init__(self, model, model_forward_callback: ModelForwardCallback,
|
||||
is_running_diffusers: bool=False,
|
||||
@@ -78,11 +75,11 @@ class InvokeAIDiffuserComponent:
|
||||
def custom_attention_context(self,
|
||||
extra_conditioning_info: Optional[ExtraConditioningInfo],
|
||||
step_count: int):
|
||||
do_swap = extra_conditioning_info is not None and extra_conditioning_info.should_do_swap
|
||||
old_attn_processor = None
|
||||
if do_swap:
|
||||
old_attn_processor = self.override_cross_attention(extra_conditioning_info,
|
||||
step_count=step_count)
|
||||
if extra_conditioning_info.wants_cross_attention_control | extra_conditioning_info.has_lora_conditions:
|
||||
old_attn_processor = self.override_attention_processors(extra_conditioning_info,
|
||||
step_count=step_count)
|
||||
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
@@ -92,41 +89,34 @@ class InvokeAIDiffuserComponent:
|
||||
#self.remove_attention_map_saving()
|
||||
|
||||
|
||||
def override_cross_attention(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]:
|
||||
def override_attention_processors(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]:
|
||||
"""
|
||||
setup cross attention .swap control. for diffusers this replaces the attention processor, so
|
||||
the previous attention processor is returned so that the caller can restore it later.
|
||||
"""
|
||||
self.conditioning = conditioning
|
||||
|
||||
# If other modules do not want cross_attention_control then we should bypass setting up Context
|
||||
old_attn_processors = None
|
||||
if not self.conditioning.wants_cross_attention_control:
|
||||
old_attn_processors = self.model.attn_processors
|
||||
old_attn_processors = self.model.attn_processors
|
||||
|
||||
# Load lora conditions into the model
|
||||
if self.conditioning.has_lora_conditions:
|
||||
for condition in self.conditioning.lora_conditions:
|
||||
if conditioning.has_lora_conditions:
|
||||
for condition in conditioning.lora_conditions:
|
||||
condition(self.model)
|
||||
|
||||
# return old_attn_processors if there is nothing further to do here
|
||||
if not self.conditioning.wants_cross_attention_control:
|
||||
return old_attn_processors
|
||||
if conditioning.wants_cross_attention_control:
|
||||
self.cross_attention_control_context = Context(
|
||||
arguments=conditioning.cross_attention_control_args,
|
||||
step_count=step_count
|
||||
)
|
||||
override_cross_attention(self.model,
|
||||
self.cross_attention_control_context,
|
||||
is_running_diffusers=self.is_running_diffusers)
|
||||
return old_attn_processors
|
||||
|
||||
self.cross_attention_control_context = Context(
|
||||
arguments=self.conditioning.cross_attention_control_args,
|
||||
step_count=step_count
|
||||
)
|
||||
return override_cross_attention(self.model,
|
||||
self.cross_attention_control_context,
|
||||
is_running_diffusers=self.is_running_diffusers)
|
||||
|
||||
def restore_default_cross_attention(self, restore_attention_processor: Optional['AttnProcessor']=None):
|
||||
self.conditioning = None
|
||||
def restore_default_cross_attention(self, processors_to_restore: Optional[dict[str, 'AttnProcessor']]=None):
|
||||
self.cross_attention_control_context = None
|
||||
restore_default_cross_attention(self.model,
|
||||
is_running_diffusers=self.is_running_diffusers,
|
||||
restore_attention_processor=restore_attention_processor)
|
||||
processors_to_restore=processors_to_restore)
|
||||
|
||||
def setup_attention_map_saving(self, saver: AttentionMapSaver):
|
||||
def callback(slice, dim, offset, slice_size, key):
|
||||
@@ -328,7 +318,7 @@ class InvokeAIDiffuserComponent:
|
||||
#print("applying saved attention maps for", cross_attention_control_types_to_do)
|
||||
for ca_type in cross_attention_control_types_to_do:
|
||||
context.request_apply_saved_attention_maps(ca_type)
|
||||
edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
|
||||
edited_conditioning = context.arguments.edited_conditioning
|
||||
conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
|
||||
context.clear_requests(cleanup=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user