commit 9ec6fbfee0
parent 4765e707bf
Author: JPPhoto
Date:   2023-03-03 20:28:03 -06:00

    Updated from main.


@@ -2,7 +2,7 @@ from enum import Enum
 from contextlib import contextmanager
 from dataclasses import dataclass
 from math import ceil
-from typing import Callable, Optional, Union, Any, Dict
+from typing import Any, Callable, Dict, Optional, Union
 
 import numpy as np
 import torch
@@ -10,19 +10,30 @@ from diffusers.models.cross_attention import AttnProcessor
 from einops import einops
 from typing_extensions import TypeAlias
 
-from ldm.invoke.globals import Globals
-from ldm.models.diffusion.cross_attention_control import Arguments, \
-    restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \
-    CrossAttentionType, SwapCrossAttnContext
-from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
+from invokeai.backend.globals import Globals
+
+from .cross_attention_control import (
+    Arguments,
+    Context,
+    CrossAttentionType,
+    SwapCrossAttnContext,
+    get_cross_attention_modules,
+    override_cross_attention,
+    restore_default_cross_attention,
+)
+from .cross_attention_map_saving import AttentionMapSaver
 
 ModelForwardCallback: TypeAlias = Union[
     # x, t, conditioning, Optional[cross-attention kwargs]
-    Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str, Any]]], torch.Tensor],
-    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
+    Callable[
+        [torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str, Any]]],
+        torch.Tensor,
+    ],
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
 ]
 
 SymmetryType = Enum('SymmetryType', ['MIRROR', 'FADE'])
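For reference, the ModelForwardCallback alias above admits two callback shapes: a plain (x, t, conditioning) callable, or one that also accepts an optional cross-attention kwargs dict. A minimal sketch of a conforming wrapper around a diffusers UNet — the wrap_unet helper and its use of cross_attention_kwargs are illustrative assumptions, not part of this commit:

import torch

def wrap_unet(unet):
    # matches the 4-argument variant of ModelForwardCallback:
    # (x, t, conditioning, optional cross-attention kwargs) -> prediction
    def forward(x, t, conditioning, cross_attention_kwargs=None):
        return unet(
            x, t,
            encoder_hidden_states=conditioning,
            cross_attention_kwargs=cross_attention_kwargs,
        ).sample
    return forward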
@@ -36,20 +47,20 @@ class PostprocessingSettings:
 
 class InvokeAIDiffuserComponent:
-    '''
+    """
     The aim of this component is to provide a single place for code that can be applied identically to
     all InvokeAI diffusion procedures.
 
     At the moment it includes the following features:
     * Cross attention control ("prompt2prompt")
     * Hybrid conditioning (used for inpainting)
-    '''
+    """
 
     debug_thresholding = False
     sequential_guidance = False
 
     @dataclass
     class ExtraConditioningInfo:
         tokens_count_including_eos_bos: int
         cross_attention_control_args: Optional[Arguments] = None
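ExtraConditioningInfo carries the token count plus optional .swap arguments, and wants_cross_attention_control (in the next hunk) is true whenever the latter is set. A hypothetical construction of the component, with names assumed rather than taken from this diff:

component = InvokeAIDiffuserComponent(
    model=unet,  # the UNet whose attention processors will be patched
    model_forward_callback=wrap_unet(unet),  # see the sketch above
    is_running_diffusers=True,
)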
@@ -57,10 +68,12 @@ class InvokeAIDiffuserComponent:
         def wants_cross_attention_control(self):
             return self.cross_attention_control_args is not None
 
-    def __init__(self, model, model_forward_callback: ModelForwardCallback,
-                 is_running_diffusers: bool=False,
-                 ):
+    def __init__(
+        self,
+        model,
+        model_forward_callback: ModelForwardCallback,
+        is_running_diffusers: bool = False,
+    ):
         """
         :param model: the unet model to pass through to cross attention control
         :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
@@ -73,23 +86,29 @@ class InvokeAIDiffuserComponent:
         self.sequential_guidance = Globals.sequential_guidance
 
     @contextmanager
-    def custom_attention_context(self,
-                                 extra_conditioning_info: Optional[ExtraConditioningInfo],
-                                 step_count: int):
-        do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
+    def custom_attention_context(
+        self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
+    ):
+        do_swap = (
+            extra_conditioning_info is not None
+            and extra_conditioning_info.wants_cross_attention_control
+        )
         old_attn_processor = None
         if do_swap:
-            old_attn_processor = self.override_cross_attention(extra_conditioning_info,
-                                                               step_count=step_count)
+            old_attn_processor = self.override_cross_attention(
+                extra_conditioning_info, step_count=step_count
+            )
         try:
             yield None
         finally:
             if old_attn_processor is not None:
                 self.restore_default_cross_attention(old_attn_processor)
             # TODO resuscitate attention map saving
-            #self.remove_attention_map_saving()
+            # self.remove_attention_map_saving()
 
-    def override_cross_attention(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]:
+    def override_cross_attention(
+        self, conditioning: ExtraConditioningInfo, step_count: int
+    ) -> Dict[str, AttnProcessor]:
         """
         setup cross attention .swap control. for diffusers this replaces the attention processor, so
         the previous attention processor is returned so that the caller can restore it later.
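Because custom_attention_context restores the old attention processor in its finally block, callers get cleanup even when a step raises. A usage sketch under assumed surrounding code (the sampling loop and embedding names are illustrative):

with component.custom_attention_context(extra_conditioning_info, step_count=30):
    for i, sigma in enumerate(sigmas):
        latents = component.do_diffusion_step(
            latents, sigma,
            unconditioning=uncond_embeddings,
            conditioning=cond_embeddings,
            unconditional_guidance_scale=7.5,
            step_index=i,
            total_step_count=30,
        )
# the original attention processors are restored here, raise or not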
@@ -97,18 +116,24 @@ class InvokeAIDiffuserComponent:
         self.conditioning = conditioning
         self.cross_attention_control_context = Context(
             arguments=self.conditioning.cross_attention_control_args,
-            step_count=step_count
+            step_count=step_count,
         )
-        return override_cross_attention(self.model,
-                                        self.cross_attention_control_context,
-                                        is_running_diffusers=self.is_running_diffusers)
+        return override_cross_attention(
+            self.model,
+            self.cross_attention_control_context,
+            is_running_diffusers=self.is_running_diffusers,
+        )
 
-    def restore_default_cross_attention(self, restore_attention_processor: Optional['AttnProcessor']=None):
+    def restore_default_cross_attention(
+        self, restore_attention_processor: Optional["AttnProcessor"] = None
+    ):
         self.conditioning = None
         self.cross_attention_control_context = None
-        restore_default_cross_attention(self.model,
-                                        is_running_diffusers=self.is_running_diffusers,
-                                        restore_attention_processor=restore_attention_processor)
+        restore_default_cross_attention(
+            self.model,
+            is_running_diffusers=self.is_running_diffusers,
+            restore_attention_processor=restore_attention_processor,
+        )
 
     def setup_attention_map_saving(self, saver: AttentionMapSaver):
         def callback(slice, dim, offset, slice_size, key):
@@ -117,26 +142,40 @@ class InvokeAIDiffuserComponent:
                 return
             saver.add_attention_maps(slice, key)
 
-        tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
+        tokens_cross_attention_modules = get_cross_attention_modules(
+            self.model, CrossAttentionType.TOKENS
+        )
         for identifier, module in tokens_cross_attention_modules:
-            key = ('down' if identifier.startswith('down') else
-                   'up' if identifier.startswith('up') else
-                   'mid')
+            key = (
+                "down"
+                if identifier.startswith("down")
+                else "up"
+                if identifier.startswith("up")
+                else "mid"
+            )
             module.set_attention_slice_calculated_callback(
-                lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key))
+                lambda slice, dim, offset, slice_size, key=key: callback(
+                    slice, dim, offset, slice_size, key
+                )
+            )
 
     def remove_attention_map_saving(self):
-        tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
+        tokens_cross_attention_modules = get_cross_attention_modules(
+            self.model, CrossAttentionType.TOKENS
+        )
         for _, module in tokens_cross_attention_modules:
             module.set_attention_slice_calculated_callback(None)
 
-    def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
-                          unconditioning: Union[torch.Tensor,dict],
-                          conditioning: Union[torch.Tensor,dict],
-                          unconditional_guidance_scale: float,
-                          step_index: Optional[int]=None,
-                          total_step_count: Optional[int]=None,
-                          ):
+    def do_diffusion_step(
+        self,
+        x: torch.Tensor,
+        sigma: torch.Tensor,
+        unconditioning: Union[torch.Tensor, dict],
+        conditioning: Union[torch.Tensor, dict],
+        unconditional_guidance_scale: float,
+        step_index: Optional[int] = None,
+        total_step_count: Optional[int] = None,
+    ):
         """
         :param x: current latents
         :param sigma: aka t, passed to the internal model to control how much denoising will occur
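Two details in the attention-map-saving code earlier in this hunk are easy to miss: the conditional expression buckets each module identifier into a down/up/mid key, and the lambda's key=key default freezes the per-iteration value that a plain closure would not. A small illustration with made-up identifiers:

for identifier in ("down_blocks.0.attentions.0", "up_blocks.2.attentions.1", "mid_block.attentions.0"):
    key = (
        "down"
        if identifier.startswith("down")
        else "up"
        if identifier.startswith("up")
        else "mid"
    )
    print(identifier, "->", key)  # down, up, mid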
@@ -147,33 +186,55 @@ class InvokeAIDiffuserComponent:
         :return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
         """
         cross_attention_control_types_to_do = []
         context: Context = self.cross_attention_control_context
         if self.cross_attention_control_context is not None:
-            percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
-            cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)
+            percent_through = self.calculate_percent_through(
+                sigma, step_index, total_step_count
+            )
+            cross_attention_control_types_to_do = (
+                context.get_active_cross_attention_control_types_for_step(
+                    percent_through
+                )
+            )
 
-        wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
+        wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
         wants_hybrid_conditioning = isinstance(conditioning, dict)
 
         if wants_hybrid_conditioning:
-            unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(x, sigma, unconditioning,
-                                                                                       conditioning)
+            unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
+                x, sigma, unconditioning, conditioning
+            )
         elif wants_cross_attention_control:
-            unconditioned_next_x, conditioned_next_x = self._apply_cross_attention_controlled_conditioning(x, sigma,
-                                                                                                            unconditioning,
-                                                                                                            conditioning,
-                                                                                                            cross_attention_control_types_to_do)
+            (
+                unconditioned_next_x,
+                conditioned_next_x,
+            ) = self._apply_cross_attention_controlled_conditioning(
+                x,
+                sigma,
+                unconditioning,
+                conditioning,
+                cross_attention_control_types_to_do,
+            )
         elif self.sequential_guidance:
-            unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning_sequentially(
-                x, sigma, unconditioning, conditioning)
+            (
+                unconditioned_next_x,
+                conditioned_next_x,
+            ) = self._apply_standard_conditioning_sequentially(
+                x, sigma, unconditioning, conditioning
+            )
         else:
-            unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning(
-                x, sigma, unconditioning, conditioning)
+            (
+                unconditioned_next_x,
+                conditioned_next_x,
+            ) = self._apply_standard_conditioning(
+                x, sigma, unconditioning, conditioning
+            )
 
-        combined_next_x = self._combine(unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale)
+        combined_next_x = self._combine(
+            unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
+        )
 
         return combined_next_x
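Every branch above funnels into _combine, whose body is not part of this diff; the conventional classifier-free guidance combination it is expected to perform is uncond + scale * (cond - uncond). A worked sketch:

def cfg_combine(uncond, cond, scale):
    # scale = 1.0 reproduces cond; larger values extrapolate away from uncond
    return uncond + scale * (cond - uncond)

# with scalars standing in for latent tensors:
# uncond = 0.2, cond = 0.8, scale = 7.5  ->  0.2 + 7.5 * 0.6 = 4.7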
@@ -183,24 +244,33 @@ class InvokeAIDiffuserComponent:
         latents: torch.Tensor,
         sigma,
         step_index,
-        total_step_count
+        total_step_count,
     ) -> torch.Tensor:
         if postprocessing_settings is not None:
-            percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
-            latents = self.apply_threshold(postprocessing_settings, latents, percent_through)
-            latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
+            percent_through = self.calculate_percent_through(
+                sigma, step_index, total_step_count
+            )
+            latents = self.apply_threshold(
+                postprocessing_settings, latents, percent_through
+            )
+            latents = self.apply_symmetry(
+                postprocessing_settings, latents, percent_through
+            )
         return latents
 
     def calculate_percent_through(self, sigma, step_index, total_step_count):
         if step_index is not None and total_step_count is not None:
             # 🧨diffusers codepath
-            percent_through = step_index / total_step_count  # will never reach 1.0 - this is deliberate
+            percent_through = (
+                step_index / total_step_count
+            )  # will never reach 1.0 - this is deliberate
         else:
             # legacy compvis codepath
             # TODO remove when compvis codepath support is dropped
             if step_index is None and sigma is None:
                 raise ValueError(
-                    f"Either step_index or sigma is required when doing cross attention control, but both are None.")
+                    f"Either step_index or sigma is required when doing cross attention control, but both are None."
+                )
             percent_through = self.estimate_percent_through(step_index, sigma)
         return percent_through
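On the diffusers codepath percent_through is plain linear progress, e.g.:

total_step_count = 20
for step_index in (0, 5, 19):
    print(step_index / total_step_count)  # 0.0, 0.25, 0.95
# the last step yields 19/20 = 0.95, never 1.0 — the deliberate
# shortfall noted in the comment above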
@@ -211,24 +281,30 @@ class InvokeAIDiffuserComponent:
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)
         both_conditionings = torch.cat([unconditioning, conditioning])
-        both_results = self.model_forward_callback(x_twice, sigma_twice, both_conditionings)
+        both_results = self.model_forward_callback(
+            x_twice, sigma_twice, both_conditionings
+        )
         unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
-        if conditioned_next_x.device.type == 'mps':
+        if conditioned_next_x.device.type == "mps":
             # prevent a result filled with zeros. seems to be a torch bug.
             conditioned_next_x = conditioned_next_x.clone()
         return unconditioned_next_x, conditioned_next_x
 
-    def _apply_standard_conditioning_sequentially(self, x: torch.Tensor, sigma, unconditioning: torch.Tensor, conditioning: torch.Tensor):
+    def _apply_standard_conditioning_sequentially(
+        self,
+        x: torch.Tensor,
+        sigma,
+        unconditioning: torch.Tensor,
+        conditioning: torch.Tensor,
+    ):
         # low-memory sequential path
         unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
         conditioned_next_x = self.model_forward_callback(x, sigma, conditioning)
-        if conditioned_next_x.device.type == 'mps':
+        if conditioned_next_x.device.type == "mps":
             # prevent a result filled with zeros. seems to be a torch bug.
             conditioned_next_x = conditioned_next_x.clone()
         return unconditioned_next_x, conditioned_next_x
 
     def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
         assert isinstance(conditioning, dict)
         assert isinstance(unconditioning, dict)
@@ -243,48 +319,80 @@ class InvokeAIDiffuserComponent:
                 ]
             else:
                 both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
-        unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2)
+        unconditioned_next_x, conditioned_next_x = self.model_forward_callback(
+            x_twice, sigma_twice, both_conditionings
+        ).chunk(2)
         return unconditioned_next_x, conditioned_next_x
 
-    def _apply_cross_attention_controlled_conditioning(self,
-                                                       x: torch.Tensor,
-                                                       sigma,
-                                                       unconditioning,
-                                                       conditioning,
-                                                       cross_attention_control_types_to_do):
+    def _apply_cross_attention_controlled_conditioning(
+        self,
+        x: torch.Tensor,
+        sigma,
+        unconditioning,
+        conditioning,
+        cross_attention_control_types_to_do,
+    ):
         if self.is_running_diffusers:
-            return self._apply_cross_attention_controlled_conditioning__diffusers(x, sigma, unconditioning,
-                                                                                  conditioning,
-                                                                                  cross_attention_control_types_to_do)
+            return self._apply_cross_attention_controlled_conditioning__diffusers(
+                x,
+                sigma,
+                unconditioning,
+                conditioning,
+                cross_attention_control_types_to_do,
+            )
         else:
-            return self._apply_cross_attention_controlled_conditioning__compvis(x, sigma, unconditioning, conditioning,
-                                                                                cross_attention_control_types_to_do)
+            return self._apply_cross_attention_controlled_conditioning__compvis(
+                x,
+                sigma,
+                unconditioning,
+                conditioning,
+                cross_attention_control_types_to_do,
+            )
 
-    def _apply_cross_attention_controlled_conditioning__diffusers(self,
-                                                                  x: torch.Tensor,
-                                                                  sigma,
-                                                                  unconditioning,
-                                                                  conditioning,
-                                                                  cross_attention_control_types_to_do):
+    def _apply_cross_attention_controlled_conditioning__diffusers(
+        self,
+        x: torch.Tensor,
+        sigma,
+        unconditioning,
+        conditioning,
+        cross_attention_control_types_to_do,
+    ):
         context: Context = self.cross_attention_control_context
 
-        cross_attn_processor_context = SwapCrossAttnContext(modified_text_embeddings=context.arguments.edited_conditioning,
-                                                            index_map=context.cross_attention_index_map,
-                                                            mask=context.cross_attention_mask,
-                                                            cross_attention_types_to_do=[])
+        cross_attn_processor_context = SwapCrossAttnContext(
+            modified_text_embeddings=context.arguments.edited_conditioning,
+            index_map=context.cross_attention_index_map,
+            mask=context.cross_attention_mask,
+            cross_attention_types_to_do=[],
+        )
         # no cross attention for unconditioning (negative prompt)
-        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning,
-                                                           {"swap_cross_attn_context": cross_attn_processor_context})
+        unconditioned_next_x = self.model_forward_callback(
+            x,
+            sigma,
+            unconditioning,
+            {"swap_cross_attn_context": cross_attn_processor_context},
+        )
 
         # do requested cross attention types for conditioning (positive prompt)
-        cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
-        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning,
-                                                         {"swap_cross_attn_context": cross_attn_processor_context})
+        cross_attn_processor_context.cross_attention_types_to_do = (
+            cross_attention_control_types_to_do
+        )
+        conditioned_next_x = self.model_forward_callback(
+            x,
+            sigma,
+            conditioning,
+            {"swap_cross_attn_context": cross_attn_processor_context},
+        )
         return unconditioned_next_x, conditioned_next_x
 
-    def _apply_cross_attention_controlled_conditioning__compvis(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
+    def _apply_cross_attention_controlled_conditioning__compvis(
+        self,
+        x: torch.Tensor,
+        sigma,
+        unconditioning,
+        conditioning,
+        cross_attention_control_types_to_do,
+    ):
         # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
         # slower non-batched path (20% slower on mac MPS)
         # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
@@ -294,24 +402,28 @@ class InvokeAIDiffuserComponent:
         # representing batched uncond + cond, but then when it comes to applying the saved attention, the
         # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
         # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
-        context:Context = self.cross_attention_control_context
+        context: Context = self.cross_attention_control_context
 
         try:
             unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
 
             # process x using the original prompt, saving the attention maps
-            #print("saving attention maps for", cross_attention_control_types_to_do)
+            # print("saving attention maps for", cross_attention_control_types_to_do)
             for ca_type in cross_attention_control_types_to_do:
                 context.request_save_attention_maps(ca_type)
             _ = self.model_forward_callback(x, sigma, conditioning)
             context.clear_requests(cleanup=False)
 
             # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
-            #print("applying saved attention maps for", cross_attention_control_types_to_do)
+            # print("applying saved attention maps for", cross_attention_control_types_to_do)
             for ca_type in cross_attention_control_types_to_do:
                 context.request_apply_saved_attention_maps(ca_type)
-            edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
-            conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
+            edited_conditioning = (
+                self.conditioning.cross_attention_control_args.edited_conditioning
+            )
+            conditioned_next_x = self.model_forward_callback(
+                x, sigma, edited_conditioning
+            )
             context.clear_requests(cleanup=True)
 
         except:
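Condensed, the compvis path above is a three-pass scheme per step: an uncontrolled uncond pass, a cond pass that records attention maps, and an edited-conditioning pass that replays them. A restatement using the same Context API, with the helper name invented for illustration:

def prompt2prompt_step(ctx, forward, x, sigma, uncond, cond, edited, ca_types):
    uncond_next = forward(x, sigma, uncond)
    for ca_type in ca_types:
        ctx.request_save_attention_maps(ca_type)  # record maps from the original prompt
    _ = forward(x, sigma, cond)
    ctx.clear_requests(cleanup=False)
    for ca_type in ca_types:
        ctx.request_apply_saved_attention_maps(ca_type)  # replay them for the edit
    cond_next = forward(x, sigma, edited)
    ctx.clear_requests(cleanup=True)
    return uncond_next, cond_next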
@@ -330,17 +442,21 @@ class InvokeAIDiffuserComponent:
         self,
         postprocessing_settings: PostprocessingSettings,
         latents: torch.Tensor,
-        percent_through: float
+        percent_through: float,
     ) -> torch.Tensor:
-        if postprocessing_settings.threshold is None or postprocessing_settings.threshold == 0.0:
+        if (
+            postprocessing_settings.threshold is None
+            or postprocessing_settings.threshold == 0.0
+        ):
             return latents
 
         threshold = postprocessing_settings.threshold
         warmup = postprocessing_settings.warmup
 
         if percent_through < warmup:
-            current_threshold = threshold + threshold * 5 * (1 - (percent_through / warmup))
+            current_threshold = threshold + threshold * 5 * (
+                1 - (percent_through / warmup)
+            )
         else:
             current_threshold = threshold
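The warmup branch loosens the clamp by up to a factor of six and tightens it linearly. With threshold=1.0 and warmup=0.2:

threshold, warmup = 1.0, 0.2
for pct in (0.0, 0.1, 0.2, 0.5):
    if pct < warmup:
        current = threshold + threshold * 5 * (1 - (pct / warmup))
    else:
        current = threshold
    print(pct, current)  # 6.0, 3.5, 1.0, 1.0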
@@ -354,10 +470,14 @@ class InvokeAIDiffuserComponent:
         if self.debug_thresholding:
             std, mean = [i.item() for i in torch.std_mean(latents)]
-            outside = torch.count_nonzero((latents < -current_threshold) | (latents > current_threshold))
-            print(f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
-                  f"  | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
-                  f"  | {outside / latents.numel() * 100:.2f}% values outside threshold")
+            outside = torch.count_nonzero(
+                (latents < -current_threshold) | (latents > current_threshold)
+            )
+            print(
+                f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
+                f"  | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
+                f"  | {outside / latents.numel() * 100:.2f}% values outside threshold"
+            )
 
         if maxval < current_threshold and minval > -current_threshold:
             return latents
@@ -370,17 +490,23 @@ class InvokeAIDiffuserComponent:
             latents = torch.clone(latents)
             maxval = np.clip(maxval * scale, 1, current_threshold)
             num_altered += torch.count_nonzero(latents > maxval)
-            latents[latents > maxval] = torch.rand_like(latents[latents > maxval]) * maxval
+            latents[latents > maxval] = (
+                torch.rand_like(latents[latents > maxval]) * maxval
+            )
 
         if minval < -current_threshold:
             latents = torch.clone(latents)
             minval = np.clip(minval * scale, -current_threshold, -1)
             num_altered += torch.count_nonzero(latents < minval)
-            latents[latents < minval] = torch.rand_like(latents[latents < minval]) * minval
+            latents[latents < minval] = (
+                torch.rand_like(latents[latents < minval]) * minval
+            )
 
         if self.debug_thresholding:
-            print(f"  | min,     , max = {minval:.3f},      , {maxval:.3f}\t(scaled by {scale})\n"
-                  f"  | {num_altered / latents.numel() * 100:.2f}% values altered")
+            print(
+                f"  | min,     , max = {minval:.3f},      , {maxval:.3f}\t(scaled by {scale})\n"
+                f"  | {num_altered / latents.numel() * 100:.2f}% values altered"
+            )
 
         return latents
@@ -388,9 +514,8 @@ class InvokeAIDiffuserComponent:
         self,
         postprocessing_settings: PostprocessingSettings,
         latents: torch.Tensor,
-        percent_through: float
+        percent_through: float,
     ) -> torch.Tensor:
-
         # Reset our last percent through if this is our first step.
         if percent_through == 0.0:
             self.last_percent_through = 0.0
@@ -400,11 +525,15 @@ class InvokeAIDiffuserComponent:
         # Check for out of bounds
         h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
-        if (h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0)):
+        if h_symmetry_time_pct is not None and (
+            h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0
+        ):
             h_symmetry_time_pct = None
 
         v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
-        if (v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0)):
+        if v_symmetry_time_pct is not None and (
+            v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0
+        ):
             v_symmetry_time_pct = None
 
         width = latents.shape[3]
@@ -413,24 +542,29 @@ class InvokeAIDiffuserComponent:
         dtype = latents.dtype
 
         symmetry_type = postprocessing_settings.symmetry_type or SymmetryType.FADE
-        latents.to(device='cpu')
+        latents.to(device="cpu")
 
-        def make_ramp(ease_in:int, total:int) -> torch.Tensor:
+        def make_ramp(ease_in: int, total: int) -> torch.Tensor:
             ramp1 = torch.linspace(start=1.0, end=0.5, steps=ease_in, device=dev)
             ramp2 = torch.linspace(start=0.5, end=1.0, steps=total - ease_in, device=dev)
             ramp = torch.cat((ramp1, ramp2))
             return ramp
 
         if (
-            h_symmetry_time_pct != None and
-            self.last_percent_through < h_symmetry_time_pct and
-            percent_through >= h_symmetry_time_pct
+            h_symmetry_time_pct != None
+            and self.last_percent_through < h_symmetry_time_pct
+            and percent_through >= h_symmetry_time_pct
         ):
             # Horizontal symmetry occurs on the 3rd dimension of the latent
             x_flipped = torch.flip(latents, dims=[3])
             if symmetry_type is SymmetryType.MIRROR:
                 # Use the first half of latents and then the flipped one on this dimension
-                latents = torch.cat([latents[:, :, :, 0:int(width/2)], x_flipped[:, :, :, int(width/2):int(width)]], dim=3)
+                latents = torch.cat(
+                    [
+                        latents[:, :, :, 0 : int(width / 2)],
+                        x_flipped[:, :, :, int(width / 2) : int(width)],
+                    ],
+                    dim=3,
+                )
             elif symmetry_type is SymmetryType.FADE:
                 apply_width = width // 2
                 # Create a linear ramp so the middle gets perfect symmetry but the edges retain their original latents
@@ -442,15 +576,20 @@ class InvokeAIDiffuserComponent:
                 latents = ((latents * fade1) + (x_flipped * fade0)) * multiplier
 
         if (
-            v_symmetry_time_pct != None and
-            self.last_percent_through < v_symmetry_time_pct and
-            percent_through >= v_symmetry_time_pct
+            v_symmetry_time_pct != None
+            and self.last_percent_through < v_symmetry_time_pct
+            and percent_through >= v_symmetry_time_pct
         ):
             # Vertical symmetry occurs on the 3rd dimension of the latent
             y_flipped = torch.flip(latents, dims=[2])
             if symmetry_type is SymmetryType.MIRROR:
                 # Vertical symmetry occurs on the 2nd dimension of the latent
-                latents = torch.cat([latents[:, :, 0:int(height / 2)], y_flipped[:, :, int(height / 2):int(height)]], dim=2)
+                latents = torch.cat(
+                    [
+                        latents[:, :, 0 : int(height / 2)],
+                        y_flipped[:, :, int(height / 2) : int(height)],
+                    ],
+                    dim=2,
+                )
             elif symmetry_type is SymmetryType.FADE:
                 apply_height = height // 2
                 # Create a linear ramp so the middle gets perfect symmetry but the edges retain their original latents
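For SymmetryType.MIRROR, the first half of the latent is stitched to the matching half of the flipped copy, giving an exact reflection about the centerline. A shape-level sketch for a 64-wide latent:

import torch

latents = torch.randn(1, 4, 64, 64)        # (batch, channels, height, width)
x_flipped = torch.flip(latents, dims=[3])  # mirror along width
mirrored = torch.cat([latents[:, :, :, :32], x_flipped[:, :, :, 32:]], dim=3)
assert mirrored.shape == latents.shape
# the right half is now the left half reversed:
assert torch.equal(mirrored[:, :, :, 32:], torch.flip(mirrored[:, :, :, :32], dims=[3]))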
@@ -467,7 +606,9 @@ class InvokeAIDiffuserComponent:
     def estimate_percent_through(self, step_index, sigma):
         if step_index is not None and self.cross_attention_control_context is not None:
             # percent_through will never reach 1.0 (but this is intended)
-            return float(step_index) / float(self.cross_attention_control_context.step_count)
+            return float(step_index) / float(
+                self.cross_attention_control_context.step_count
+            )
         # find the best possible index of the current sigma in the sigma sequence
         smaller_sigmas = torch.nonzero(self.model.sigmas <= sigma)
         sigma_index = smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0
@@ -476,33 +617,38 @@ class InvokeAIDiffuserComponent:
         return 1.0 - float(sigma_index + 1) / float(self.model.sigmas.shape[0])
         # print('estimated percent_through', percent_through, 'from sigma', sigma.item())
 
     # todo: make this work
     @classmethod
-    def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
+    def apply_conjunction(
+        cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale
+    ):
         x_in = torch.cat([x] * 2)
-        t_in = torch.cat([t] * 2) # aka sigmas
+        t_in = torch.cat([t] * 2)  # aka sigmas
 
         deltas = None
         uncond_latents = None
-        weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]
+        weighted_cond_list = (
+            c_or_weighted_c_list
+            if type(c_or_weighted_c_list) is list
+            else [(c_or_weighted_c_list, 1)]
+        )
 
         # below is fugly omg
         num_actual_conditionings = len(c_or_weighted_c_list)
-        conditionings = [uc] + [c for c,weight in weighted_cond_list]
-        weights = [1] + [weight for c,weight in weighted_cond_list]
-        chunk_count = ceil(len(conditionings)/2)
+        conditionings = [uc] + [c for c, weight in weighted_cond_list]
+        weights = [1] + [weight for c, weight in weighted_cond_list]
+        chunk_count = ceil(len(conditionings) / 2)
         deltas = None
         for chunk_index in range(chunk_count):
-            offset = chunk_index*2
-            chunk_size = min(2, len(conditionings)-offset)
+            offset = chunk_index * 2
+            chunk_size = min(2, len(conditionings) - offset)
 
             if chunk_size == 1:
                 c_in = conditionings[offset]
                 latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
                 latents_b = None
             else:
-                c_in = torch.cat(conditionings[offset:offset+2])
+                c_in = torch.cat(conditionings[offset : offset + 2])
                 latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)
 
             # first chunk is guaranteed to be 2 entries: uncond_latents + first conditioining
@@ -515,11 +661,15 @@ class InvokeAIDiffuserComponent:
                 deltas = torch.cat((deltas, latents_b - uncond_latents))
 
         # merge the weighted deltas together into a single merged delta
-        per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
+        per_delta_weights = torch.tensor(
+            weights[1:], dtype=deltas.dtype, device=deltas.device
+        )
         normalize = False
         if normalize:
             per_delta_weights /= torch.sum(per_delta_weights)
-        reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
+        reshaped_weights = per_delta_weights.reshape(
+            per_delta_weights.shape + (1, 1, 1)
+        )
         deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
 
         # old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
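Each conditioning's delta from the unconditioned latents is weighted and summed; reshaping the weights to (N, 1, 1, 1) lets them broadcast over the latent dimensions. A toy check:

import torch

deltas = torch.stack([torch.ones(4, 8, 8), 2 * torch.ones(4, 8, 8)])  # two conditionings
weights = torch.tensor([0.75, 0.25])
reshaped = weights.reshape(weights.shape + (1, 1, 1))  # (2, 1, 1, 1)
merged = torch.sum(deltas * reshaped, dim=0, keepdim=True)  # (1, 4, 8, 8)
print(merged[0, 0, 0, 0].item())  # 0.75 * 1 + 0.25 * 2 = 1.25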