from __future__ import annotations

from contextlib import contextmanager
from dataclasses import dataclass
import math
from typing import Any, Callable, Dict, Optional, Union, List

import numpy as np
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttentionProcessor
from typing_extensions import TypeAlias

import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig

from .cross_attention_control import (
    Arguments,
    Context,
    CrossAttentionType,
    SwapCrossAttnContext,
    get_cross_attention_modules,
    setup_cross_attention_control_attention_processors,
)
from .cross_attention_map_saving import AttentionMapSaver

ModelForwardCallback: TypeAlias = Union[
    # x, t, conditioning, Optional[cross-attention kwargs]
    Callable[
        [torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str, Any]]],
        torch.Tensor,
    ],
    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
]


@dataclass
class BasicConditioningInfo:
    embeds: torch.Tensor
    extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
    # weight: float
    # mode: ConditioningAlgo

    def to(self, device, dtype=None):
        self.embeds = self.embeds.to(device=device, dtype=dtype)
        return self


@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
    pooled_embeds: torch.Tensor
    add_time_ids: torch.Tensor

    def to(self, device, dtype=None):
        self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
        self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
        return super().to(device=device, dtype=dtype)


@dataclass(frozen=True)
class PostprocessingSettings:
    threshold: float
    warmup: float
    h_symmetry_time_pct: Optional[float]
    v_symmetry_time_pct: Optional[float]


class InvokeAIDiffuserComponent:
    """
    The aim of this component is to provide a single place for code that can be applied identically to
    all InvokeAI diffusion procedures.

    At the moment it includes the following features:
    * Cross attention control ("prompt2prompt")
    * Hybrid conditioning (used for inpainting)
    """

    debug_thresholding = False
    sequential_guidance = False

    @dataclass
    class ExtraConditioningInfo:
        tokens_count_including_eos_bos: int
        cross_attention_control_args: Optional[Arguments] = None

        @property
        def wants_cross_attention_control(self):
            return self.cross_attention_control_args is not None

    def __init__(
        self,
        model,
        model_forward_callback: ModelForwardCallback,
    ):
        """
        :param model: the unet model to pass through to cross attention control
        :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply).
            will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
        """
        config = InvokeAIAppConfig.get_config()

        self.conditioning = None
        self.model = model
        self.model_forward_callback = model_forward_callback
        self.cross_attention_control_context = None
        self.sequential_guidance = config.sequential_guidance
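    # Illustrative sketch (not part of this module's API): a minimal model_forward_callback
    # for a diffusers UNet2DConditionModel could look roughly like the following. The names
    # `unet` and `_unet_forward` are assumptions for the example only.
    #
    #     def _unet_forward(x, t, conditioning, cross_attention_kwargs=None, **kwargs):
    #         return unet(x, t, encoder_hidden_states=conditioning,
    #                     cross_attention_kwargs=cross_attention_kwargs, **kwargs).sample
    #
    #     diffuser = InvokeAIDiffuserComponent(unet, _unet_forward)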
    @contextmanager
    def custom_attention_context(
        self,
        unet: UNet2DConditionModel,  # note: also may futz with the text encoder depending on requested LoRAs
        extra_conditioning_info: Optional[ExtraConditioningInfo],
        step_count: int,
    ):
        old_attn_processors = None
        if extra_conditioning_info and (extra_conditioning_info.wants_cross_attention_control):
            old_attn_processors = unet.attn_processors
            # Set up cross-attention control ("prompt2prompt") attention processors on the unet.
            if extra_conditioning_info.wants_cross_attention_control:
                self.cross_attention_control_context = Context(
                    arguments=extra_conditioning_info.cross_attention_control_args,
                    step_count=step_count,
                )
                setup_cross_attention_control_attention_processors(
                    unet,
                    self.cross_attention_control_context,
                )

        try:
            yield None
        finally:
            self.cross_attention_control_context = None
            if old_attn_processors is not None:
                unet.set_attn_processor(old_attn_processors)
            # TODO resuscitate attention map saving
            # self.remove_attention_map_saving()

    def setup_attention_map_saving(self, saver: AttentionMapSaver):
        def callback(slice, dim, offset, slice_size, key):
            if dim is not None:
                # sliced tokens attention map saving is not implemented
                return
            saver.add_attention_maps(slice, key)

        tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
        for identifier, module in tokens_cross_attention_modules:
            key = "down" if identifier.startswith("down") else "up" if identifier.startswith("up") else "mid"
            module.set_attention_slice_calculated_callback(
                lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key)
            )

    def remove_attention_map_saving(self):
        tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
        for _, module in tokens_cross_attention_modules:
            module.set_attention_slice_calculated_callback(None)
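    # Illustrative sketch: custom_attention_context above is intended to wrap a denoising
    # loop so that the original attention processors are restored when the block exits.
    # `diffuser`, `unet` and `timesteps` are assumed names for the example only.
    #
    #     with diffuser.custom_attention_context(unet, extra_conditioning_info, step_count=len(timesteps)):
    #         for step_index, t in enumerate(timesteps):
    #             ...  # do_unet_step / scheduler.step happen in here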
    def do_controlnet_step(
        self,
        control_data,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        step_index: int,
        total_step_count: int,
        conditioning_data,
    ):
        down_block_res_samples, mid_block_res_sample = None, None

        # control_data should be type List[ControlNetData]
        # this loop covers both ControlNet (one ControlNetData in list)
        # and MultiControlNet (multiple ControlNetData in list)
        for i, control_datum in enumerate(control_data):
            control_mode = control_datum.control_mode
            # soft_injection and cfg_injection are the two ControlNet control_mode booleans
            #     that are combined at a higher level to make the control_mode enum
            # soft_injection determines whether to do per-layer re-weighting adjustment (if True)
            #     or default weighting (if False)
            soft_injection = control_mode == "more_prompt" or control_mode == "more_control"
            # cfg_injection determines whether to apply ControlNet to only the conditional (if True)
            #     or to both the conditional and unconditional (if False)
            cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"

            first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
            last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
            # only apply controlnet if current step is within the controlnet's begin/end step range
            if step_index >= first_control_step and step_index <= last_control_step:
                if cfg_injection:
                    sample_model_input = sample
                else:
                    # expand the latents input to the control model if doing classifier free guidance
                    #    (which I think is always true for now; there is a conditional elsewhere that stops
                    #     execution if classifier_free_guidance is <= 1.0)
                    sample_model_input = torch.cat([sample] * 2)

                added_cond_kwargs = None

                if cfg_injection:  # only apply ControlNet to the conditional batch, not the unconditional one
                    if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
                        added_cond_kwargs = {
                            "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
                            "time_ids": conditioning_data.text_embeddings.add_time_ids,
                        }
                    encoder_hidden_states = conditioning_data.text_embeddings.embeds
                    encoder_attention_mask = None
                else:
                    if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
                        added_cond_kwargs = {
                            "text_embeds": torch.cat(
                                [
                                    # TODO: how to pad? just by zeros? or even truncate?
                                    conditioning_data.unconditioned_embeddings.pooled_embeds,
                                    conditioning_data.text_embeddings.pooled_embeds,
                                ],
                                dim=0,
                            ),
                            "time_ids": torch.cat(
                                [
                                    conditioning_data.unconditioned_embeddings.add_time_ids,
                                    conditioning_data.text_embeddings.add_time_ids,
                                ],
                                dim=0,
                            ),
                        }
                    (
                        encoder_hidden_states,
                        encoder_attention_mask,
                    ) = self._concat_conditionings_for_batch(
                        conditioning_data.unconditioned_embeddings.embeds,
                        conditioning_data.text_embeddings.embeds,
                    )
                if isinstance(control_datum.weight, list):
                    # if controlnet has multiple weights, use the weight for the current step
                    controlnet_weight = control_datum.weight[step_index]
                else:
                    # if controlnet has a single weight, use it for all steps
                    controlnet_weight = control_datum.weight

                # controlnet(s) inference
                down_samples, mid_sample = control_datum.model(
                    sample=sample_model_input,
                    timestep=timestep,
                    encoder_hidden_states=encoder_hidden_states,
                    controlnet_cond=control_datum.image_tensor,
                    conditioning_scale=controlnet_weight,  # controlnet specific, NOT the guidance scale
                    encoder_attention_mask=encoder_attention_mask,
                    added_cond_kwargs=added_cond_kwargs,  # SDXL pooled embeds and time ids (None for non-SDXL)
                    guess_mode=soft_injection,  # this is still called guess_mode in diffusers ControlNetModel
                    return_dict=False,
                )
                if cfg_injection:
                    # Inferred ControlNet only for the conditional batch.
                    # To apply the output of ControlNet to both the unconditional and conditional batches,
                    #   prepend zeros for the unconditional batch.
                    down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
                    mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])

                if down_block_res_samples is None and mid_block_res_sample is None:
                    down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
                else:
                    # add controlnet outputs together if have multiple controlnets
                    down_block_res_samples = [
                        samples_prev + samples_curr
                        for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                    ]
                    mid_block_res_sample += mid_sample

        return down_block_res_samples, mid_block_res_sample
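    # Illustrative notes on do_controlnet_step above (assumed typical values, for example only):
    #   * control_mode maps onto the two booleans as
    #       "more_prompt"  -> soft_injection=True,  cfg_injection=False
    #       "more_control" -> soft_injection=True,  cfg_injection=True
    #       "unbalanced"   -> soft_injection=False, cfg_injection=True
    #     and any other value (e.g. a "balanced" mode) falls through to False/False.
    #   * with total_step_count=30, begin_step_percent=0.0 and end_step_percent=0.5,
    #     first_control_step = floor(0.0 * 30) = 0 and last_control_step = ceil(0.5 * 30) = 15,
    #     so the ControlNet contributes on steps 0..15 inclusive and is skipped afterwards.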
    def do_unet_step(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        conditioning_data,  # TODO: type
        step_index: int,
        total_step_count: int,
        **kwargs,
    ):
        cross_attention_control_types_to_do = []
        context: Context = self.cross_attention_control_context
        if self.cross_attention_control_context is not None:
            percent_through = step_index / total_step_count
            cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(
                percent_through
            )
        wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0

        if wants_cross_attention_control:
            (
                unconditioned_next_x,
                conditioned_next_x,
            ) = self._apply_cross_attention_controlled_conditioning(
                sample,
                timestep,
                conditioning_data,
                cross_attention_control_types_to_do,
                **kwargs,
            )
        elif self.sequential_guidance:
            (
                unconditioned_next_x,
                conditioned_next_x,
            ) = self._apply_standard_conditioning_sequentially(
                sample,
                timestep,
                conditioning_data,
                **kwargs,
            )
        else:
            (
                unconditioned_next_x,
                conditioned_next_x,
            ) = self._apply_standard_conditioning(
                sample,
                timestep,
                conditioning_data,
                **kwargs,
            )

        return unconditioned_next_x, conditioned_next_x

    def do_latent_postprocessing(
        self,
        postprocessing_settings: PostprocessingSettings,
        latents: torch.Tensor,
        sigma,
        step_index,
        total_step_count,
    ) -> torch.Tensor:
        if postprocessing_settings is not None:
            percent_through = step_index / total_step_count
            latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
        return latents

    def _concat_conditionings_for_batch(self, unconditioning, conditioning):
        def _pad_conditioning(cond, target_len, encoder_attention_mask):
            conditioning_attention_mask = torch.ones(
                (cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype
            )

            if cond.shape[1] < target_len:
                conditioning_attention_mask = torch.cat(
                    [
                        conditioning_attention_mask,
                        torch.zeros((cond.shape[0], target_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
                    ],
                    dim=1,
                )

                cond = torch.cat(
                    [
                        cond,
                        torch.zeros(
                            (cond.shape[0], target_len - cond.shape[1], cond.shape[2]),
                            device=cond.device,
                            dtype=cond.dtype,
                        ),
                    ],
                    dim=1,
                )

            if encoder_attention_mask is None:
                encoder_attention_mask = conditioning_attention_mask
            else:
                encoder_attention_mask = torch.cat(
                    [
                        encoder_attention_mask,
                        conditioning_attention_mask,
                    ]
                )

            return cond, encoder_attention_mask

        encoder_attention_mask = None
        if unconditioning.shape[1] != conditioning.shape[1]:
            max_len = max(unconditioning.shape[1], conditioning.shape[1])
            unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
            conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)

        return torch.cat([unconditioning, conditioning]), encoder_attention_mask
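    # Worked example for the padding above (shapes are an assumption for illustration,
    # typical of an SD 1.x text encoder): with unconditioning of shape [1, 77, 768] and
    # conditioning of shape [1, 154, 768] (a long prompt encoded in two 77-token chunks),
    # the unconditioned embeds are zero-padded to [1, 154, 768]; the returned batch is
    # [2, 154, 768] and the encoder attention mask is [2, 154], with 1s over real tokens
    # and 0s over the padded positions so cross-attention ignores them. When both inputs
    # already have the same length, no mask is built and None is returned for it.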
    # methods below are called from do_unet_step and should be considered private to this class.

    def _apply_standard_conditioning(self, x, sigma, conditioning_data, **kwargs):
        # fast batched path
        x_twice = torch.cat([x] * 2)
        sigma_twice = torch.cat([sigma] * 2)

        added_cond_kwargs = None
        if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
            added_cond_kwargs = {
                "text_embeds": torch.cat(
                    [
                        # TODO: how to pad? just by zeros? or even truncate?
                        conditioning_data.unconditioned_embeddings.pooled_embeds,
                        conditioning_data.text_embeddings.pooled_embeds,
                    ],
                    dim=0,
                ),
                "time_ids": torch.cat(
                    [
                        conditioning_data.unconditioned_embeddings.add_time_ids,
                        conditioning_data.text_embeddings.add_time_ids,
                    ],
                    dim=0,
                ),
            }

        both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
            conditioning_data.unconditioned_embeddings.embeds, conditioning_data.text_embeddings.embeds
        )
        both_results = self.model_forward_callback(
            x_twice,
            sigma_twice,
            both_conditionings,
            encoder_attention_mask=encoder_attention_mask,
            added_cond_kwargs=added_cond_kwargs,
            **kwargs,
        )
        unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
        return unconditioned_next_x, conditioned_next_x

    def _apply_standard_conditioning_sequentially(
        self,
        x: torch.Tensor,
        sigma,
        conditioning_data,
        **kwargs,
    ):
        # low-memory sequential path
        uncond_down_block, cond_down_block = None, None
        down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
        if down_block_additional_residuals is not None:
            uncond_down_block, cond_down_block = [], []
            for down_block in down_block_additional_residuals:
                _uncond_down, _cond_down = down_block.chunk(2)
                uncond_down_block.append(_uncond_down)
                cond_down_block.append(_cond_down)

        uncond_mid_block, cond_mid_block = None, None
        mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
        if mid_block_additional_residual is not None:
            uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)

        added_cond_kwargs = None
        is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
        if is_sdxl:
            added_cond_kwargs = {
                "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
                "time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
            }

        unconditioned_next_x = self.model_forward_callback(
            x,
            sigma,
            conditioning_data.unconditioned_embeddings.embeds,
            down_block_additional_residuals=uncond_down_block,
            mid_block_additional_residual=uncond_mid_block,
            added_cond_kwargs=added_cond_kwargs,
            **kwargs,
        )

        if is_sdxl:
            added_cond_kwargs = {
                "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
                "time_ids": conditioning_data.text_embeddings.add_time_ids,
            }

        conditioned_next_x = self.model_forward_callback(
            x,
            sigma,
            conditioning_data.text_embeddings.embeds,
            down_block_additional_residuals=cond_down_block,
            mid_block_additional_residual=cond_mid_block,
            added_cond_kwargs=added_cond_kwargs,
            **kwargs,
        )
        return unconditioned_next_x, conditioned_next_x
kwargs.pop("mid_block_additional_residual", None) if mid_block_additional_residual is not None: uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2) cross_attn_processor_context = SwapCrossAttnContext( modified_text_embeddings=context.arguments.edited_conditioning, index_map=context.cross_attention_index_map, mask=context.cross_attention_mask, cross_attention_types_to_do=[], ) added_cond_kwargs = None is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo if is_sdxl: added_cond_kwargs = { "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds, "time_ids": conditioning_data.unconditioned_embeddings.add_time_ids, } # no cross attention for unconditioning (negative prompt) unconditioned_next_x = self.model_forward_callback( x, sigma, conditioning_data.unconditioned_embeddings.embeds, {"swap_cross_attn_context": cross_attn_processor_context}, down_block_additional_residuals=uncond_down_block, mid_block_additional_residual=uncond_mid_block, added_cond_kwargs=added_cond_kwargs, **kwargs, ) if is_sdxl: added_cond_kwargs = { "text_embeds": conditioning_data.text_embeddings.pooled_embeds, "time_ids": conditioning_data.text_embeddings.add_time_ids, } # do requested cross attention types for conditioning (positive prompt) cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do conditioned_next_x = self.model_forward_callback( x, sigma, conditioning_data.text_embeddings.embeds, {"swap_cross_attn_context": cross_attn_processor_context}, down_block_additional_residuals=cond_down_block, mid_block_additional_residual=cond_mid_block, added_cond_kwargs=added_cond_kwargs, **kwargs, ) return unconditioned_next_x, conditioned_next_x def _combine(self, unconditioned_next_x, conditioned_next_x, guidance_scale): # to scale how much effect conditioning has, calculate the changes it does and then scale that scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale combined_next_x = unconditioned_next_x + scaled_delta return combined_next_x def apply_symmetry( self, postprocessing_settings: PostprocessingSettings, latents: torch.Tensor, percent_through: float, ) -> torch.Tensor: # Reset our last percent through if this is our first step. 
    def apply_symmetry(
        self,
        postprocessing_settings: PostprocessingSettings,
        latents: torch.Tensor,
        percent_through: float,
    ) -> torch.Tensor:
        # Reset our last percent through if this is our first step.
        if percent_through == 0.0:
            self.last_percent_through = 0.0

        if postprocessing_settings is None:
            return latents

        # Check for out of bounds
        h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
        if h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0):
            h_symmetry_time_pct = None

        v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
        if v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0):
            v_symmetry_time_pct = None

        dev = latents.device.type
        latents = latents.to(device="cpu")  # do the symmetry ops on the CPU, then move back at the end

        if (
            h_symmetry_time_pct is not None
            and self.last_percent_through < h_symmetry_time_pct
            and percent_through >= h_symmetry_time_pct
        ):
            # Horizontal symmetry occurs on the width dimension (dim 3) of the latent
            width = latents.shape[3]
            x_flipped = torch.flip(latents, dims=[3])
            latents = torch.cat(
                [
                    latents[:, :, :, 0 : int(width / 2)],
                    x_flipped[:, :, :, int(width / 2) : int(width)],
                ],
                dim=3,
            )

        if (
            v_symmetry_time_pct is not None
            and self.last_percent_through < v_symmetry_time_pct
            and percent_through >= v_symmetry_time_pct
        ):
            # Vertical symmetry occurs on the height dimension (dim 2) of the latent
            height = latents.shape[2]
            y_flipped = torch.flip(latents, dims=[2])
            latents = torch.cat(
                [
                    latents[:, :, 0 : int(height / 2)],
                    y_flipped[:, :, int(height / 2) : int(height)],
                ],
                dim=2,
            )

        self.last_percent_through = percent_through
        return latents.to(device=dev)

    # todo: make this work
    @classmethod
    def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t] * 2)  # aka sigmas

        deltas = None
        uncond_latents = None
        weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]

        # below is fugly omg
        conditionings = [uc] + [c for c, weight in weighted_cond_list]
        weights = [1] + [weight for c, weight in weighted_cond_list]
        chunk_count = math.ceil(len(conditionings) / 2)
        deltas = None
        for chunk_index in range(chunk_count):
            offset = chunk_index * 2
            chunk_size = min(2, len(conditionings) - offset)

            if chunk_size == 1:
                c_in = conditionings[offset]
                latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
                latents_b = None
            else:
                c_in = torch.cat(conditionings[offset : offset + 2])
                latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)

            # first chunk is guaranteed to be 2 entries: uncond_latents + first conditioning
            if chunk_index == 0:
                uncond_latents = latents_a
                deltas = latents_b - uncond_latents
            else:
                deltas = torch.cat((deltas, latents_a - uncond_latents))
                if latents_b is not None:
                    deltas = torch.cat((deltas, latents_b - uncond_latents))

        # merge the weighted deltas together into a single merged delta
        per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
        normalize = False
        if normalize:
            per_delta_weights /= torch.sum(per_delta_weights)
        reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
        deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)

        # old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
        # assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
        return uncond_latents + deltas_merged * global_guidance_scale