more wip sliced attention (.swap doesn't work yet)

Damian Stewart
2023-01-25 14:51:08 +01:00
parent 63c6019f92
commit a4aea1540b
3 changed files with 75 additions and 49 deletions


@@ -1,11 +1,13 @@
 import math
+from contextlib import contextmanager
 from dataclasses import dataclass
 from math import ceil
-from typing import Callable, Optional, Union, Any
+from typing import Callable, Optional, Union, Any, Dict
 
 import numpy as np
 import torch
+from diffusers.models.cross_attention import AttnProcessor
 
 from ldm.models.diffusion.cross_attention_control import Arguments, \
     remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, \
     CrossAttentionType, SwapCrossAttnContext
@@ -55,20 +57,43 @@ class InvokeAIDiffuserComponent:
         self.model_forward_callback = model_forward_callback
         self.cross_attention_control_context = None
 
-    def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int):
+    @contextmanager
+    def custom_attention_context(self,
+                                 extra_conditioning_info: Optional[ExtraConditioningInfo],
+                                 step_count: int,
+                                 do_attention_map_saving: bool):
+        do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
+        old_attn_processor = None
+        if do_swap:
+            old_attn_processor = self.setup_cross_attention_control(extra_conditioning_info,
+                                                                    step_count=step_count)
+        try:
+            yield None
+        finally:
+            self.remove_cross_attention_control(old_attn_processor)
+            # TODO resuscitate attention map saving
+            #self.remove_attention_map_saving()
+
+    def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]:
+        """
+        setup cross attention .swap control. for diffusers this replaces the attention processor, so
+        the previous attention processor is returned so that the caller can restore it later.
+        """
         self.conditioning = conditioning
         self.cross_attention_control_context = Context(
             arguments=self.conditioning.cross_attention_control_args,
             step_count=step_count
         )
-        setup_cross_attention_control(self.model,
-                                      self.cross_attention_control_context,
-                                      is_running_diffusers=self.is_running_diffusers)
+        return setup_cross_attention_control(self.model,
+                                             self.cross_attention_control_context,
+                                             is_running_diffusers=self.is_running_diffusers)
 
-    def remove_cross_attention_control(self):
+    def remove_cross_attention_control(self, restore_attention_processor: Optional['AttnProcessor']=None):
         self.conditioning = None
         self.cross_attention_control_context = None
-        remove_cross_attention_control(self.model, is_running_diffusers=self.is_running_diffusers)
+        remove_cross_attention_control(self.model,
+                                       is_running_diffusers=self.is_running_diffusers,
+                                       restore_attention_processor=restore_attention_processor)
 
     def setup_attention_map_saving(self, saver: AttentionMapSaver):
         def callback(slice, dim, offset, slice_size, key):
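
For context on how the new context manager is meant to be driven, here is a minimal caller-side sketch. Only custom_attention_context and its parameters come from this diff; the component construction, the denoising loop, and helper names such as do_denoise_step are illustrative assumptions.

# Hypothetical usage sketch; everything except custom_attention_context is assumed.
component = InvokeAIDiffuserComponent(model, model_forward_callback)
with component.custom_attention_context(extra_conditioning_info,
                                         step_count=len(timesteps),
                                         do_attention_map_saving=False):
    for t in timesteps:
        # .swap attention processors are installed for the duration of the with-block
        # and restored in the finally-clause even if a step raises.
        latents = do_denoise_step(latents, t)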
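
The new Dict[str, AttnProcessor] return value exists so the original processors can be put back after a .swap run. A rough sketch of that save/restore pattern follows; it assumes a diffusers UNet exposing attn_processors and set_attn_processor, and whether setup_cross_attention_control does exactly this internally is not shown in this diff.

# Save/restore sketch under the assumptions above; SlicedAttnProcessor is used only
# as an example replacement, echoing the "sliced attention" theme of the commit.
from diffusers.models.cross_attention import SlicedAttnProcessor

def install_swap_processors(unet) -> dict:
    old_attn_processors = dict(unet.attn_processors)           # snapshot current processors by module name
    unet.set_attn_processor(SlicedAttnProcessor(slice_size=1))  # install the replacement everywhere
    return old_attn_processors

def restore_attn_processors(unet, old_attn_processors: dict):
    unet.set_attn_processor(old_attn_processors)                # put the originals back by name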