From 9ec6fbfee0e7ea8dc173e8cdedc69ab2e79a5dd3 Mon Sep 17 00:00:00 2001
From: JPPhoto
Date: Fri, 3 Mar 2023 20:28:03 -0600
Subject: [PATCH] Updated from main.

---
 .../diffusion/shared_invokeai_diffusion.py | 444 ++++++++++++------
 1 file changed, 297 insertions(+), 147 deletions(-)

diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index e950bd2969..c0dbcbcf18 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -2,7 +2,7 @@ from enum import Enum
 from contextlib import contextmanager
 from dataclasses import dataclass
 from math import ceil
-from typing import Callable, Optional, Union, Any, Dict
+from typing import Any, Callable, Dict, Optional, Union
 
 import numpy as np
 import torch
@@ -10,19 +10,30 @@ from diffusers.models.cross_attention import AttnProcessor
 from einops import einops
 from typing_extensions import TypeAlias
 
-from ldm.invoke.globals import Globals
-from ldm.models.diffusion.cross_attention_control import Arguments, \
-    restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \
-    CrossAttentionType, SwapCrossAttnContext
-from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
+from invokeai.backend.globals import Globals
+
+from .cross_attention_control import (
+    Arguments,
+    Context,
+    CrossAttentionType,
+    SwapCrossAttnContext,
+    get_cross_attention_modules,
+    override_cross_attention,
+    restore_default_cross_attention,
+)
+from .cross_attention_map_saving import AttentionMapSaver
 
 ModelForwardCallback: TypeAlias = Union[
     # x, t, conditioning, Optional[cross-attention kwargs]
-    Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str, Any]]], torch.Tensor],
-    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
+    Callable[
+        [torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str, Any]]],
+        torch.Tensor,
+    ],
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
 ]
 
+
 SymmetryType = Enum('SymmetryType', ['MIRROR', 'FADE'])
 
 
@@ -36,20 +47,20 @@ class PostprocessingSettings:
 
 
 class InvokeAIDiffuserComponent:
-    '''
+    """
     The aim of this component is to provide a single place for code that can be applied identically to all InvokeAI
     diffusion procedures.
 
     At the moment it includes the following features:
     * Cross attention control ("prompt2prompt")
    * Hybrid conditioning (used for inpainting)
-    '''
+    """
+
     debug_thresholding = False
     sequential_guidance = False
 
     @dataclass
     class ExtraConditioningInfo:
-
         tokens_count_including_eos_bos: int
         cross_attention_control_args: Optional[Arguments] = None
 
@@ -57,10 +68,12 @@ class InvokeAIDiffuserComponent:
         def wants_cross_attention_control(self):
             return self.cross_attention_control_args is not None
 
-
-    def __init__(self, model, model_forward_callback: ModelForwardCallback,
-                 is_running_diffusers: bool=False,
-                 ):
+    def __init__(
+        self,
+        model,
+        model_forward_callback: ModelForwardCallback,
+        is_running_diffusers: bool = False,
+    ):
         """
         :param model: the unet model to pass through to cross attention control
         :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly.
         most likely, this should simply call model.forward(x, sigma, conditioning)
@@ -73,23 +86,29 @@ class InvokeAIDiffuserComponent:
         self.sequential_guidance = Globals.sequential_guidance
 
     @contextmanager
-    def custom_attention_context(self,
-                                 extra_conditioning_info: Optional[ExtraConditioningInfo],
-                                 step_count: int):
-        do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
+    def custom_attention_context(
+        self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
+    ):
+        do_swap = (
+            extra_conditioning_info is not None
+            and extra_conditioning_info.wants_cross_attention_control
+        )
         old_attn_processor = None
         if do_swap:
-            old_attn_processor = self.override_cross_attention(extra_conditioning_info,
-                                                               step_count=step_count)
+            old_attn_processor = self.override_cross_attention(
+                extra_conditioning_info, step_count=step_count
+            )
         try:
             yield None
         finally:
             if old_attn_processor is not None:
                 self.restore_default_cross_attention(old_attn_processor)
             # TODO resuscitate attention map saving
-            #self.remove_attention_map_saving()
+            # self.remove_attention_map_saving()
 
-    def override_cross_attention(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]:
+    def override_cross_attention(
+        self, conditioning: ExtraConditioningInfo, step_count: int
+    ) -> Dict[str, AttnProcessor]:
         """
         setup cross attention .swap control. for diffusers this replaces the attention
         processor, so the previous attention processor is returned so that the caller can restore it later.
@@ -97,18 +116,24 @@ class InvokeAIDiffuserComponent:
         self.conditioning = conditioning
         self.cross_attention_control_context = Context(
             arguments=self.conditioning.cross_attention_control_args,
-            step_count=step_count
+            step_count=step_count,
+        )
+        return override_cross_attention(
+            self.model,
+            self.cross_attention_control_context,
+            is_running_diffusers=self.is_running_diffusers,
         )
-        return override_cross_attention(self.model,
-                                        self.cross_attention_control_context,
-                                        is_running_diffusers=self.is_running_diffusers)
 
-    def restore_default_cross_attention(self, restore_attention_processor: Optional['AttnProcessor']=None):
+    def restore_default_cross_attention(
+        self, restore_attention_processor: Optional["AttnProcessor"] = None
+    ):
         self.conditioning = None
         self.cross_attention_control_context = None
-        restore_default_cross_attention(self.model,
-                                        is_running_diffusers=self.is_running_diffusers,
-                                        restore_attention_processor=restore_attention_processor)
+        restore_default_cross_attention(
+            self.model,
+            is_running_diffusers=self.is_running_diffusers,
+            restore_attention_processor=restore_attention_processor,
+        )
 
     def setup_attention_map_saving(self, saver: AttentionMapSaver):
         def callback(slice, dim, offset, slice_size, key):
@@ -117,26 +142,40 @@ class InvokeAIDiffuserComponent:
                 return
             saver.add_attention_maps(slice, key)
 
-        tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
+        tokens_cross_attention_modules = get_cross_attention_modules(
+            self.model, CrossAttentionType.TOKENS
+        )
         for identifier, module in tokens_cross_attention_modules:
-            key = ('down' if identifier.startswith('down') else
-                   'up' if identifier.startswith('up') else
-                   'mid')
+            key = (
+                "down"
+                if identifier.startswith("down")
+                else "up"
+                if identifier.startswith("up")
+                else "mid"
+            )
             module.set_attention_slice_calculated_callback(
-                lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key))
+                lambda slice, dim, offset, slice_size, key=key: callback(
+                    slice, dim, offset, slice_size, key
+                )
+            )
 
     def remove_attention_map_saving(self):
-        tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
+        tokens_cross_attention_modules = get_cross_attention_modules(
+            self.model, CrossAttentionType.TOKENS
+        )
         for _, module in tokens_cross_attention_modules:
             module.set_attention_slice_calculated_callback(None)
 
-    def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
-                          unconditioning: Union[torch.Tensor,dict],
-                          conditioning: Union[torch.Tensor,dict],
-                          unconditional_guidance_scale: float,
-                          step_index: Optional[int]=None,
-                          total_step_count: Optional[int]=None,
-                          ):
+    def do_diffusion_step(
+        self,
+        x: torch.Tensor,
+        sigma: torch.Tensor,
+        unconditioning: Union[torch.Tensor, dict],
+        conditioning: Union[torch.Tensor, dict],
+        unconditional_guidance_scale: float,
+        step_index: Optional[int] = None,
+        total_step_count: Optional[int] = None,
+    ):
         """
         :param x: current latents
         :param sigma: aka t, passed to the internal model to control how much denoising will occur
@@ -147,33 +186,55 @@ class InvokeAIDiffuserComponent:
         :return: the new latents after applying the model to x using unscaled unconditioning
         and CFG-scaled conditioning.
         """
-
         cross_attention_control_types_to_do = []
         context: Context = self.cross_attention_control_context
         if self.cross_attention_control_context is not None:
-            percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
-            cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)
+            percent_through = self.calculate_percent_through(
+                sigma, step_index, total_step_count
+            )
+            cross_attention_control_types_to_do = (
+                context.get_active_cross_attention_control_types_for_step(
+                    percent_through
+                )
+            )
 
-        wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
+        wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
         wants_hybrid_conditioning = isinstance(conditioning, dict)
 
         if wants_hybrid_conditioning:
-            unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(x, sigma, unconditioning,
-                                                                                       conditioning)
+            unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
+                x, sigma, unconditioning, conditioning
+            )
         elif wants_cross_attention_control:
-            unconditioned_next_x, conditioned_next_x = self._apply_cross_attention_controlled_conditioning(x, sigma,
-                                                                                                           unconditioning,
-                                                                                                           conditioning,
-                                                                                                           cross_attention_control_types_to_do)
+            (
+                unconditioned_next_x,
+                conditioned_next_x,
+            ) = self._apply_cross_attention_controlled_conditioning(
+                x,
+                sigma,
+                unconditioning,
+                conditioning,
+                cross_attention_control_types_to_do,
+            )
         elif self.sequential_guidance:
-            unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning_sequentially(
-                x, sigma, unconditioning, conditioning)
+            (
+                unconditioned_next_x,
+                conditioned_next_x,
+            ) = self._apply_standard_conditioning_sequentially(
+                x, sigma, unconditioning, conditioning
+            )
 
         else:
-            unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning(
-                x, sigma, unconditioning, conditioning)
+            (
+                unconditioned_next_x,
+                conditioned_next_x,
+            ) = self._apply_standard_conditioning(
+                x, sigma, unconditioning, conditioning
+            )
 
-        combined_next_x = self._combine(unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale)
+        combined_next_x = self._combine(
+            unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
+        )
 
         return combined_next_x
 
@@ -183,24 +244,33 @@ class InvokeAIDiffuserComponent:
         latents: torch.Tensor,
         sigma,
         step_index,
-        total_step_count
+        total_step_count,
     ) -> torch.Tensor:
         if postprocessing_settings is not None:
-            percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
-            latents = self.apply_threshold(postprocessing_settings, latents, percent_through)
-            latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
+            percent_through = self.calculate_percent_through(
+                sigma, step_index, total_step_count
+            )
+            latents = self.apply_threshold(
+                postprocessing_settings, latents, percent_through
+            )
+            latents = self.apply_symmetry(
+                postprocessing_settings, latents, percent_through
+            )
         return latents
 
     def calculate_percent_through(self, sigma, step_index, total_step_count):
         if step_index is not None and total_step_count is not None:
             # 🧨diffusers codepath
-            percent_through = step_index / total_step_count # will never reach 1.0 - this is deliberate
+            percent_through = (
+                step_index / total_step_count
+            )  # will never reach 1.0 - this is deliberate
         else:
             # legacy compvis codepath
             # TODO remove when compvis codepath support is dropped
             if step_index is None and sigma is None:
                 raise ValueError(
-                    f"Either step_index or sigma is required when doing cross attention control, but both are None.")
+                    f"Either step_index or sigma is required when doing cross attention control, but both are None."
+                )
             percent_through = self.estimate_percent_through(step_index, sigma)
         return percent_through
 
@@ -211,24 +281,30 @@ class InvokeAIDiffuserComponent:
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)
         both_conditionings = torch.cat([unconditioning, conditioning])
-        both_results = self.model_forward_callback(x_twice, sigma_twice, both_conditionings)
+        both_results = self.model_forward_callback(
+            x_twice, sigma_twice, both_conditionings
+        )
         unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
-        if conditioned_next_x.device.type == 'mps':
+        if conditioned_next_x.device.type == "mps":
             # prevent a result filled with zeros. seems to be a torch bug.
             conditioned_next_x = conditioned_next_x.clone()
         return unconditioned_next_x, conditioned_next_x
 
-
-    def _apply_standard_conditioning_sequentially(self, x: torch.Tensor, sigma, unconditioning: torch.Tensor, conditioning: torch.Tensor):
+    def _apply_standard_conditioning_sequentially(
+        self,
+        x: torch.Tensor,
+        sigma,
+        unconditioning: torch.Tensor,
+        conditioning: torch.Tensor,
+    ):
         # low-memory sequential path
         unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
         conditioned_next_x = self.model_forward_callback(x, sigma, conditioning)
-        if conditioned_next_x.device.type == 'mps':
+        if conditioned_next_x.device.type == "mps":
             # prevent a result filled with zeros. seems to be a torch bug.
             conditioned_next_x = conditioned_next_x.clone()
         return unconditioned_next_x, conditioned_next_x
 
-
     def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
         assert isinstance(conditioning, dict)
         assert isinstance(unconditioning, dict)
@@ -243,48 +319,80 @@ class InvokeAIDiffuserComponent:
             ]
         else:
             both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
-        unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2)
+        unconditioned_next_x, conditioned_next_x = self.model_forward_callback(
+            x_twice, sigma_twice, both_conditionings
+        ).chunk(2)
         return unconditioned_next_x, conditioned_next_x
 
-
-    def _apply_cross_attention_controlled_conditioning(self,
-                                                       x: torch.Tensor,
-                                                       sigma,
-                                                       unconditioning,
-                                                       conditioning,
-                                                       cross_attention_control_types_to_do):
+    def _apply_cross_attention_controlled_conditioning(
+        self,
+        x: torch.Tensor,
+        sigma,
+        unconditioning,
+        conditioning,
+        cross_attention_control_types_to_do,
+    ):
         if self.is_running_diffusers:
-            return self._apply_cross_attention_controlled_conditioning__diffusers(x, sigma, unconditioning,
-                                                                                  conditioning,
-                                                                                  cross_attention_control_types_to_do)
+            return self._apply_cross_attention_controlled_conditioning__diffusers(
+                x,
+                sigma,
+                unconditioning,
+                conditioning,
+                cross_attention_control_types_to_do,
+            )
         else:
-            return self._apply_cross_attention_controlled_conditioning__compvis(x, sigma, unconditioning, conditioning,
-                                                                                cross_attention_control_types_to_do)
+            return self._apply_cross_attention_controlled_conditioning__compvis(
+                x,
+                sigma,
+                unconditioning,
+                conditioning,
+                cross_attention_control_types_to_do,
+            )
 
-    def _apply_cross_attention_controlled_conditioning__diffusers(self,
-                                                                  x: torch.Tensor,
-                                                                  sigma,
-                                                                  unconditioning,
-                                                                  conditioning,
-                                                                  cross_attention_control_types_to_do):
+    def _apply_cross_attention_controlled_conditioning__diffusers(
+        self,
+        x: torch.Tensor,
+        sigma,
+        unconditioning,
+        conditioning,
+        cross_attention_control_types_to_do,
+    ):
         context: Context = self.cross_attention_control_context
-        cross_attn_processor_context = SwapCrossAttnContext(modified_text_embeddings=context.arguments.edited_conditioning,
-                                                            index_map=context.cross_attention_index_map,
-                                                            mask=context.cross_attention_mask,
-                                                            cross_attention_types_to_do=[])
+        cross_attn_processor_context = SwapCrossAttnContext(
+            modified_text_embeddings=context.arguments.edited_conditioning,
+            index_map=context.cross_attention_index_map,
+            mask=context.cross_attention_mask,
+            cross_attention_types_to_do=[],
+        )
         # no cross attention for unconditioning (negative prompt)
-        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning,
-                                                           {"swap_cross_attn_context": cross_attn_processor_context})
+        unconditioned_next_x = self.model_forward_callback(
+            x,
+            sigma,
+            unconditioning,
+            {"swap_cross_attn_context": cross_attn_processor_context},
+        )
 
         # do requested cross attention types for conditioning (positive prompt)
-        cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
-        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning,
-                                                         {"swap_cross_attn_context": cross_attn_processor_context})
+        cross_attn_processor_context.cross_attention_types_to_do = (
+            cross_attention_control_types_to_do
+        )
+        conditioned_next_x = self.model_forward_callback(
+            x,
+            sigma,
+            conditioning,
+            {"swap_cross_attn_context": cross_attn_processor_context},
+        )
 
         return unconditioned_next_x, conditioned_next_x
 
-
-    def _apply_cross_attention_controlled_conditioning__compvis(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
+    def _apply_cross_attention_controlled_conditioning__compvis(
+        self,
+        x: torch.Tensor,
+        sigma,
+        unconditioning,
+        conditioning,
+        cross_attention_control_types_to_do,
+    ):
         # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
         # slower non-batched path (20% slower on mac MPS)
         # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
@@ -294,24 +402,28 @@ class InvokeAIDiffuserComponent:
         # representing batched uncond + cond, but then when it comes to applying the saved attention, the
         # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
         # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
-        context:Context = self.cross_attention_control_context
+        context: Context = self.cross_attention_control_context
 
         try:
             unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
 
             # process x using the original prompt, saving the attention maps
-            #print("saving attention maps for", cross_attention_control_types_to_do)
+            # print("saving attention maps for", cross_attention_control_types_to_do)
             for ca_type in cross_attention_control_types_to_do:
                 context.request_save_attention_maps(ca_type)
             _ = self.model_forward_callback(x, sigma, conditioning)
             context.clear_requests(cleanup=False)
 
             # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
-            #print("applying saved attention maps for", cross_attention_control_types_to_do)
+            # print("applying saved attention maps for", cross_attention_control_types_to_do)
             for ca_type in cross_attention_control_types_to_do:
                 context.request_apply_saved_attention_maps(ca_type)
-            edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
-            conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
+            edited_conditioning = (
+                self.conditioning.cross_attention_control_args.edited_conditioning
+            )
+            conditioned_next_x = self.model_forward_callback(
+                x, sigma, edited_conditioning
+            )
             context.clear_requests(cleanup=True)
 
         except:
@@ -330,17 +442,21 @@ class InvokeAIDiffuserComponent:
         self,
         postprocessing_settings: PostprocessingSettings,
         latents: torch.Tensor,
-        percent_through: float
+        percent_through: float,
     ) -> torch.Tensor:
-
-        if postprocessing_settings.threshold is None or postprocessing_settings.threshold == 0.0:
+        if (
+            postprocessing_settings.threshold is None
+            or postprocessing_settings.threshold == 0.0
+        ):
             return latents
 
         threshold = postprocessing_settings.threshold
         warmup = postprocessing_settings.warmup
 
         if percent_through < warmup:
-            current_threshold = threshold + threshold * 5 * (1 - (percent_through / warmup))
+            current_threshold = threshold + threshold * 5 * (
+                1 - (percent_through / warmup)
+            )
         else:
             current_threshold = threshold
 
@@ -354,10 +470,14 @@ class InvokeAIDiffuserComponent:
 
         if self.debug_thresholding:
             std, mean = [i.item() for i in torch.std_mean(latents)]
-            outside = torch.count_nonzero((latents < -current_threshold) | (latents > current_threshold))
-            print(f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
-                  f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
values outside threshold") + outside = torch.count_nonzero( + (latents < -current_threshold) | (latents > current_threshold) + ) + print( + f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n" + f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n" + f" | {outside / latents.numel() * 100:.2f}% values outside threshold" + ) if maxval < current_threshold and minval > -current_threshold: return latents @@ -370,17 +490,23 @@ class InvokeAIDiffuserComponent: latents = torch.clone(latents) maxval = np.clip(maxval * scale, 1, current_threshold) num_altered += torch.count_nonzero(latents > maxval) - latents[latents > maxval] = torch.rand_like(latents[latents > maxval]) * maxval + latents[latents > maxval] = ( + torch.rand_like(latents[latents > maxval]) * maxval + ) if minval < -current_threshold: latents = torch.clone(latents) minval = np.clip(minval * scale, -current_threshold, -1) num_altered += torch.count_nonzero(latents < minval) - latents[latents < minval] = torch.rand_like(latents[latents < minval]) * minval + latents[latents < minval] = ( + torch.rand_like(latents[latents < minval]) * minval + ) if self.debug_thresholding: - print(f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n" - f" | {num_altered / latents.numel() * 100:.2f}% values altered") + print( + f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n" + f" | {num_altered / latents.numel() * 100:.2f}% values altered" + ) return latents @@ -388,9 +514,8 @@ class InvokeAIDiffuserComponent: self, postprocessing_settings: PostprocessingSettings, latents: torch.Tensor, - percent_through: float + percent_through: float, ) -> torch.Tensor: - # Reset our last percent through if this is our first step. 
         if percent_through == 0.0:
             self.last_percent_through = 0.0
 
@@ -400,11 +525,15 @@ class InvokeAIDiffuserComponent:
 
         # Check for out of bounds
         h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
-        if (h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0)):
+        if h_symmetry_time_pct is not None and (
+            h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0
+        ):
             h_symmetry_time_pct = None
 
         v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
-        if (v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0)):
+        if v_symmetry_time_pct is not None and (
+            v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0
+        ):
             v_symmetry_time_pct = None
 
         width = latents.shape[3]
@@ -413,24 +542,29 @@ class InvokeAIDiffuserComponent:
         dtype = latents.dtype
 
         symmetry_type = postprocessing_settings.symmetry_type or SymmetryType.FADE
-        latents.to(device='cpu')
+        latents.to(device="cpu")
 
-        def make_ramp(ease_in:int, total:int) -> torch.Tensor:
+        def make_ramp(ease_in: int, total: int) -> torch.Tensor:
             ramp1 = torch.linspace(start=1.0, end=0.5, steps=ease_in, device=dev)
             ramp2 = torch.linspace(start=0.5, end=1.0, steps=total - ease_in, device=dev)
             ramp = torch.cat((ramp1, ramp2))
             return ramp
 
         if (
-            h_symmetry_time_pct != None and
-            self.last_percent_through < h_symmetry_time_pct and
-            percent_through >= h_symmetry_time_pct
+            h_symmetry_time_pct != None
+            and self.last_percent_through < h_symmetry_time_pct
+            and percent_through >= h_symmetry_time_pct
         ):
             # Horizontal symmetry occurs on the 3rd dimension of the latent
             x_flipped = torch.flip(latents, dims=[3])
             if symmetry_type is SymmetryType.MIRROR:
-                # Use the first half of latents and then the flipped one on this dimension
-                latents = torch.cat([latents[:, :, :, 0:int(width/2)], x_flipped[:, :, :, int(width/2):int(width)]], dim=3)
+                latents = torch.cat(
+                    [
+                        latents[:, :, :, 0 : int(width / 2)],
+                        x_flipped[:, :, :, int(width / 2) : int(width)],
+                    ],
+                    dim=3,
+                )
             elif symmetry_type is SymmetryType.FADE:
                 apply_width = width // 2
                 # Create a linear ramp so the middle gets perfect symmetry but the edges retain their original latents
@@ -442,15 +576,20 @@ class InvokeAIDiffuserComponent:
                 latents = ((latents * fade1) + (x_flipped * fade0)) * multiplier
 
         if (
-            v_symmetry_time_pct != None and
-            self.last_percent_through < v_symmetry_time_pct and
-            percent_through >= v_symmetry_time_pct
+            v_symmetry_time_pct != None
+            and self.last_percent_through < v_symmetry_time_pct
+            and percent_through >= v_symmetry_time_pct
        ):
             # Vertical symmetry occurs on the 3rd dimension of the latent
             y_flipped = torch.flip(latents, dims=[2])
             if symmetry_type is SymmetryType.MIRROR:
-                # Vertical symmetry occurs on the 2nd dimension of the latent
-                latents = torch.cat([latents[:, :, 0:int(height / 2)], y_flipped[:, :, int(height / 2):int(height)]], dim=2)
+                latents = torch.cat(
+                    [
+                        latents[:, :, 0 : int(height / 2)],
+                        y_flipped[:, :, int(height / 2) : int(height)],
+                    ],
+                    dim=2,
+                )
             elif symmetry_type is SymmetryType.FADE:
                 apply_height = height // 2
                 # Create a linear ramp so the middle gets perfect symmetry but the edges retain their original latents
@@ -467,7 +606,9 @@ class InvokeAIDiffuserComponent:
     def estimate_percent_through(self, step_index, sigma):
         if step_index is not None and self.cross_attention_control_context is not None:
             # percent_through will never reach 1.0 (but this is intended)
-            return float(step_index) / float(self.cross_attention_control_context.step_count)
+            return float(step_index) / float(
+                self.cross_attention_control_context.step_count
+            )
         # find the best possible index of the current sigma in the sigma sequence
         smaller_sigmas = torch.nonzero(self.model.sigmas <= sigma)
         sigma_index = smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0
@@ -476,33 +617,38 @@ class InvokeAIDiffuserComponent:
         return 1.0 - float(sigma_index + 1) / float(self.model.sigmas.shape[0])
         # print('estimated percent_through', percent_through, 'from sigma', sigma.item())
 
-    # todo: make this work
     @classmethod
-    def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
+    def apply_conjunction(
+        cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale
+    ):
         x_in = torch.cat([x] * 2)
-        t_in = torch.cat([t] * 2) # aka sigmas
+        t_in = torch.cat([t] * 2)  # aka sigmas
 
         deltas = None
         uncond_latents = None
-        weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]
+        weighted_cond_list = (
+            c_or_weighted_c_list
+            if type(c_or_weighted_c_list) is list
+            else [(c_or_weighted_c_list, 1)]
+        )
 
         # below is fugly omg
         num_actual_conditionings = len(c_or_weighted_c_list)
-        conditionings = [uc] + [c for c,weight in weighted_cond_list]
-        weights = [1] + [weight for c,weight in weighted_cond_list]
-        chunk_count = ceil(len(conditionings)/2)
+        conditionings = [uc] + [c for c, weight in weighted_cond_list]
+        weights = [1] + [weight for c, weight in weighted_cond_list]
+        chunk_count = ceil(len(conditionings) / 2)
         deltas = None
         for chunk_index in range(chunk_count):
-            offset = chunk_index*2
-            chunk_size = min(2, len(conditionings)-offset)
+            offset = chunk_index * 2
+            chunk_size = min(2, len(conditionings) - offset)
 
             if chunk_size == 1:
                 c_in = conditionings[offset]
                 latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
                 latents_b = None
             else:
-                c_in = torch.cat(conditionings[offset:offset+2])
+                c_in = torch.cat(conditionings[offset : offset + 2])
                 latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)
 
             # first chunk is guaranteed to be 2 entries: uncond_latents + first conditioining
@@ -515,11 +661,15 @@ class InvokeAIDiffuserComponent:
                 deltas = torch.cat((deltas, latents_b - uncond_latents))
 
         # merge the weighted deltas together into a single merged delta
-        per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
+        per_delta_weights = torch.tensor(
+            weights[1:], dtype=deltas.dtype, device=deltas.device
+        )
         normalize = False
         if normalize:
             per_delta_weights /= torch.sum(per_delta_weights)
-        reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
+        reshaped_weights = per_delta_weights.reshape(
+            per_delta_weights.shape + (1, 1, 1)
+        )
         deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
 
         # old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
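Reviewer note (not part of the patch): the MIRROR branch of apply_symmetry keeps the first half of the latent along the mirrored axis and replaces the second half with the flipped copy. A minimal standalone sketch of that operation, assuming an NCHW latent tensor with hypothetical shapes not taken from this diff:

    import torch

    latents = torch.randn(1, 4, 64, 64)  # batch, channels, height, width (NCHW)
    width = latents.shape[3]

    # Flip along the width axis (dim 3), then concatenate the original left half
    # with the mirrored right half -- the same torch.cat(..., dim=3) as in the
    # patch's apply_symmetry MIRROR branch.
    x_flipped = torch.flip(latents, dims=[3])
    mirrored = torch.cat(
        [latents[:, :, :, 0 : width // 2], x_flipped[:, :, :, width // 2 : width]],
        dim=3,
    )

    # The result is left-right symmetric: the right half is the left half reversed.
    assert torch.equal(
        mirrored[:, :, :, : width // 2].flip(dims=[3]),
        mirrored[:, :, :, width // 2 :],
    )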