From a4aea1540b906faee6687a819dc9a2d007ca9bfa Mon Sep 17 00:00:00 2001
From: Damian Stewart
Date: Wed, 25 Jan 2023 14:51:08 +0100
Subject: [PATCH] more wip sliced attention (.swap doesn't work yet)

---
 ldm/invoke/generator/diffusers_pipeline.py   | 60 +++++++++----------
 .../diffusion/cross_attention_control.py     | 25 ++++----
 .../diffusion/shared_invokeai_diffusion.py   | 39 +++++++++---
 3 files changed, 75 insertions(+), 49 deletions(-)

diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py
index 6f3cd14550..e5ce403cb7 100644
--- a/ldm/invoke/generator/diffusers_pipeline.py
+++ b/ldm/invoke/generator/diffusers_pipeline.py
@@ -306,6 +306,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 
         if is_xformers_available() and not Globals.disable_xformers:
             self.enable_xformers_memory_efficient_attention()
+        else:
+            slice_size = 2
+            self.enable_attention_slicing(slice_size=slice_size)
 
     def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                               conditioning_data: ConditioningData,
@@ -370,43 +373,40 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if additional_guidance is None:
             additional_guidance = []
         extra_conditioning_info = conditioning_data.extra
-        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info,
-                                                                 step_count=len(self.scheduler.timesteps))
-        else:
-            self.invokeai_diffuser.remove_cross_attention_control()
+        with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info,
+                                                             step_count=len(self.scheduler.timesteps),
+                                                             do_attention_map_saving=False):
 
-        yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
-                                        latents=latents)
+            yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
+                                            latents=latents)
 
-        batch_size = latents.shape[0]
-        batched_t = torch.full((batch_size,), timesteps[0],
-                               dtype=timesteps.dtype, device=self.unet.device)
-        latents = self.scheduler.add_noise(latents, noise, batched_t)
+            batch_size = latents.shape[0]
+            batched_t = torch.full((batch_size,), timesteps[0],
+                                   dtype=timesteps.dtype, device=self.unet.device)
+            latents = self.scheduler.add_noise(latents, noise, batched_t)
 
-        attention_map_saver: Optional[AttentionMapSaver] = None
-        self.invokeai_diffuser.remove_attention_map_saving()
+            attention_map_saver: Optional[AttentionMapSaver] = None
 
-        for i, t in enumerate(self.progress_bar(timesteps)):
-            batched_t.fill_(t)
-            step_output = self.step(batched_t, latents, conditioning_data,
-                                    step_index=i,
-                                    total_step_count=len(timesteps),
-                                    additional_guidance=additional_guidance)
-            latents = step_output.prev_sample
-            predicted_original = getattr(step_output, 'pred_original_sample', None)
+            for i, t in enumerate(self.progress_bar(timesteps)):
+                batched_t.fill_(t)
+                step_output = self.step(batched_t, latents, conditioning_data,
+                                        step_index=i,
+                                        total_step_count=len(timesteps),
+                                        additional_guidance=additional_guidance)
+                latents = step_output.prev_sample
+                predicted_original = getattr(step_output, 'pred_original_sample', None)
 
-            if i == len(timesteps)-1 and extra_conditioning_info is not None:
-                eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
-                attention_map_token_ids = range(1, eos_token_index)
-                attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
-                self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
+                # TODO resuscitate attention map saving
+                #if i == len(timesteps)-1 and extra_conditioning_info is not None:
+                #    eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
+                #    attention_map_token_ids = range(1, eos_token_index)
+                #    attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
+                #    self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
 
-            yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
-                                            predicted_original=predicted_original, attention_map_saver=attention_map_saver)
+                yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
+                                                predicted_original=predicted_original, attention_map_saver=attention_map_saver)
 
-        self.invokeai_diffuser.remove_attention_map_saving()
-        return latents, attention_map_saver
+            return latents, attention_map_saver
 
     @torch.inference_mode()
     def step(self, t: torch.Tensor, latents: torch.Tensor,
diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py
index b1c1cd63d9..08c62060c9 100644
--- a/ldm/models/diffusion/cross_attention_control.py
+++ b/ldm/models/diffusion/cross_attention_control.py
@@ -7,6 +7,7 @@
 import torch
 import diffusers
 from torch import nn
 from diffusers.models.unet_2d_condition import UNet2DConditionModel
+from diffusers.models.cross_attention import AttnProcessor
 from ldm.invoke.devices import torch_dtype
 
@@ -305,11 +306,10 @@ class InvokeAICrossAttentionMixin:
 
 
 
-def remove_cross_attention_control(model, is_running_diffusers: bool):
+def remove_cross_attention_control(model, is_running_diffusers: bool, restore_attention_processor: Optional[AttnProcessor]=None):
     if is_running_diffusers:
         unet = model
-        print("** need to know what cross attn processor to use by default, None in the following line is wrong")
-        unet.set_attn_processor(CrossAttnProcessor())
+        unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor())
     else:
         remove_attention_function(model)
 
@@ -343,10 +343,16 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers
     context.cross_attention_index_map = indices.to(device)
     if is_running_diffusers:
         unet = model
-        unet.set_attn_processor(SwapCrossAttnProcessor())
+        old_attn_processors = unet.attn_processors
+        # try to re-use an existing slice size
+        default_slice_size = 4
+        slice_size = next((p for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
+        unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
+        return old_attn_processors
     else:
         context.register_cross_attention_modules(model)
         inject_attention_function(model, context)
+        return None
 
 
 
@@ -509,13 +515,11 @@ class CrossAttnProcessor:
         return hidden_states
 """
 
-import enum
 from dataclasses import field, dataclass
 
 import torch
 
-from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor
-from ldm.models.diffusion.cross_attention_control import CrossAttentionType
+from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor, AttnProcessor
 
 
 @dataclass
@@ -523,7 +527,7 @@ class SwapCrossAttnContext:
     modified_text_embeddings: torch.Tensor
     index_map: torch.Tensor # maps from original prompt token indices to the equivalent tokens in the modified prompt
     mask: torch.Tensor # in the target space of the index_map
-    cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=[])
+    cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list)
 
     def __int__(self,
                 cac_types_to_do: [CrossAttentionType],
@@ -629,9 +633,6 @@ class SwapCrossAttnProcessor(CrossAttnProcessor):
 
 
 class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
 
-    def __init__(self, slice_size = 1e6):
-        self.slice_count = slice_size
-
     # TODO: dynamically pick slice size based on memory conditions
     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
@@ -660,7 +661,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
         original_text_key = attn.head_to_batch_dim(original_text_key)
         modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
         modified_text_key = attn.to_k(modified_text_embeddings)
-        modified_text_key = attn.head_to_batch_dim(original_text_key)
+        modified_text_key = attn.head_to_batch_dim(modified_text_key)
 
         # for the "value" just use the modified text embeddings.
         value = attn.to_v(modified_text_embeddings)
diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py
index e4932f6ad8..0c91df9528 100644
--- a/ldm/models/diffusion/shared_invokeai_diffusion.py
+++ b/ldm/models/diffusion/shared_invokeai_diffusion.py
@@ -1,11 +1,13 @@
 import math
+from contextlib import contextmanager
 from dataclasses import dataclass
 from math import ceil
-from typing import Callable, Optional, Union, Any
+from typing import Callable, Optional, Union, Any, Dict
 
 import numpy as np
 import torch
+from diffusers.models.cross_attention import AttnProcessor
 
 from ldm.models.diffusion.cross_attention_control import Arguments, \
     remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, \
     CrossAttentionType, SwapCrossAttnContext
@@ -55,20 +57,43 @@ class InvokeAIDiffuserComponent:
         self.model_forward_callback = model_forward_callback
         self.cross_attention_control_context = None
 
-    def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int):
+    @contextmanager
+    def custom_attention_context(self,
+                                 extra_conditioning_info: Optional[ExtraConditioningInfo],
+                                 step_count: int,
+                                 do_attention_map_saving: bool):
+        do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
+        old_attn_processor = None
+        if do_swap:
+            old_attn_processor = self.setup_cross_attention_control(extra_conditioning_info,
+                                                                    step_count=step_count)
+        try:
+            yield None
+        finally:
+            self.remove_cross_attention_control(old_attn_processor)
+            # TODO resuscitate attention map saving
+            #self.remove_attention_map_saving()
+
+    def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]:
+        """
+        setup cross attention .swap control. for diffusers this replaces the attention processor, so
+        the previous attention processor is returned so that the caller can restore it later.
+        """
         self.conditioning = conditioning
         self.cross_attention_control_context = Context(
             arguments=self.conditioning.cross_attention_control_args,
             step_count=step_count
         )
-        setup_cross_attention_control(self.model,
-                                      self.cross_attention_control_context,
-                                      is_running_diffusers=self.is_running_diffusers)
+        return setup_cross_attention_control(self.model,
+                                             self.cross_attention_control_context,
+                                             is_running_diffusers=self.is_running_diffusers)
 
-    def remove_cross_attention_control(self):
+    def remove_cross_attention_control(self, restore_attention_processor: Optional['AttnProcessor']=None):
         self.conditioning = None
         self.cross_attention_control_context = None
-        remove_cross_attention_control(self.model, is_running_diffusers=self.is_running_diffusers)
+        remove_cross_attention_control(self.model,
+                                       is_running_diffusers=self.is_running_diffusers,
+                                       restore_attention_processor=restore_attention_processor)
 
     def setup_attention_map_saving(self, saver: AttentionMapSaver):
         def callback(slice, dim, offset, slice_size, key):