Mirror of https://github.com/invoke-ai/InvokeAI.git
Updated from main.
@@ -2,7 +2,7 @@ from enum import Enum
 from contextlib import contextmanager
 from dataclasses import dataclass
 from math import ceil
-from typing import Callable, Optional, Union, Any, Dict
+from typing import Any, Callable, Dict, Optional, Union

 import numpy as np
 import torch
@@ -10,19 +10,30 @@ from diffusers.models.cross_attention import AttnProcessor
 from einops import einops
 from typing_extensions import TypeAlias

-from ldm.invoke.globals import Globals
-from ldm.models.diffusion.cross_attention_control import Arguments, \
-    restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \
-    CrossAttentionType, SwapCrossAttnContext
-from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
+from invokeai.backend.globals import Globals
+
+from .cross_attention_control import (
+    Arguments,
+    Context,
+    CrossAttentionType,
+    SwapCrossAttnContext,
+    get_cross_attention_modules,
+    override_cross_attention,
+    restore_default_cross_attention,
+)
+from .cross_attention_map_saving import AttentionMapSaver

 ModelForwardCallback: TypeAlias = Union[
     # x, t, conditioning, Optional[cross-attention kwargs]
-    Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str, Any]]], torch.Tensor],
-    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
+    Callable[
+        [torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str, Any]]],
+        torch.Tensor,
+    ],
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
 ]


 SymmetryType = Enum('SymmetryType', ['MIRROR', 'FADE'])
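Note (illustrative, not part of the commit): a function satisfying the ModelForwardCallback alias only needs to accept latents, sigma/timestep and conditioning, plus an optional kwargs dict in the four-argument form. A minimal stand-in that matches the alias:

    from typing import Any, Dict, Optional

    import torch


    def dummy_forward_callback(
        x: torch.Tensor,
        sigma: torch.Tensor,
        conditioning: torch.Tensor,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        # A real callback would run the wrapped UNet here and return its
        # prediction; returning x unchanged keeps the sketch self-contained.
        return x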
@@ -36,20 +47,20 @@ class PostprocessingSettings:


 class InvokeAIDiffuserComponent:
-    '''
+    """
     The aim of this component is to provide a single place for code that can be applied identically to
     all InvokeAI diffusion procedures.

     At the moment it includes the following features:
     * Cross attention control ("prompt2prompt")
     * Hybrid conditioning (used for inpainting)
-    '''
+    """

     debug_thresholding = False
     sequential_guidance = False

     @dataclass
     class ExtraConditioningInfo:

         tokens_count_including_eos_bos: int
         cross_attention_control_args: Optional[Arguments] = None
@@ -57,10 +68,12 @@ class InvokeAIDiffuserComponent:
         def wants_cross_attention_control(self):
             return self.cross_attention_control_args is not None

-    def __init__(self, model, model_forward_callback: ModelForwardCallback,
-                 is_running_diffusers: bool=False,
-                 ):
+    def __init__(
+        self,
+        model,
+        model_forward_callback: ModelForwardCallback,
+        is_running_diffusers: bool = False,
+    ):
         """
         :param model: the unet model to pass through to cross attention control
         :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
@@ -73,23 +86,29 @@ class InvokeAIDiffuserComponent:
         self.sequential_guidance = Globals.sequential_guidance

     @contextmanager
-    def custom_attention_context(self,
-                                 extra_conditioning_info: Optional[ExtraConditioningInfo],
-                                 step_count: int):
-        do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
+    def custom_attention_context(
+        self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
+    ):
+        do_swap = (
+            extra_conditioning_info is not None
+            and extra_conditioning_info.wants_cross_attention_control
+        )
         old_attn_processor = None
         if do_swap:
-            old_attn_processor = self.override_cross_attention(extra_conditioning_info,
-                                                               step_count=step_count)
+            old_attn_processor = self.override_cross_attention(
+                extra_conditioning_info, step_count=step_count
+            )
         try:
             yield None
         finally:
             if old_attn_processor is not None:
                 self.restore_default_cross_attention(old_attn_processor)
                 # TODO resuscitate attention map saving
-                #self.remove_attention_map_saving()
+                # self.remove_attention_map_saving()

-    def override_cross_attention(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]:
+    def override_cross_attention(
+        self, conditioning: ExtraConditioningInfo, step_count: int
+    ) -> Dict[str, AttnProcessor]:
         """
         setup cross attention .swap control. for diffusers this replaces the attention processor, so
         the previous attention processor is returned so that the caller can restore it later.
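Note (illustrative, not from the diff): the context manager above packages the override/restore pair so callers no longer write the try/finally by hand; its expansion is roughly the following, where component, info and steps are placeholder names:

    old_attn_processor = None
    if info is not None and info.wants_cross_attention_control:
        old_attn_processor = component.override_cross_attention(info, step_count=steps)
    try:
        pass  # run the denoising loop here
    finally:
        if old_attn_processor is not None:
            component.restore_default_cross_attention(old_attn_processor)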
@@ -97,18 +116,24 @@ class InvokeAIDiffuserComponent:
         self.conditioning = conditioning
         self.cross_attention_control_context = Context(
             arguments=self.conditioning.cross_attention_control_args,
-            step_count=step_count
+            step_count=step_count,
+        )
+        return override_cross_attention(
+            self.model,
+            self.cross_attention_control_context,
+            is_running_diffusers=self.is_running_diffusers,
         )
-        return override_cross_attention(self.model,
-                                        self.cross_attention_control_context,
-                                        is_running_diffusers=self.is_running_diffusers)

-    def restore_default_cross_attention(self, restore_attention_processor: Optional['AttnProcessor']=None):
+    def restore_default_cross_attention(
+        self, restore_attention_processor: Optional["AttnProcessor"] = None
+    ):
         self.conditioning = None
         self.cross_attention_control_context = None
-        restore_default_cross_attention(self.model,
-                                        is_running_diffusers=self.is_running_diffusers,
-                                        restore_attention_processor=restore_attention_processor)
+        restore_default_cross_attention(
+            self.model,
+            is_running_diffusers=self.is_running_diffusers,
+            restore_attention_processor=restore_attention_processor,
+        )

     def setup_attention_map_saving(self, saver: AttentionMapSaver):
         def callback(slice, dim, offset, slice_size, key):
@@ -117,26 +142,40 @@ class InvokeAIDiffuserComponent:
                 return
             saver.add_attention_maps(slice, key)

-        tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
+        tokens_cross_attention_modules = get_cross_attention_modules(
+            self.model, CrossAttentionType.TOKENS
+        )
         for identifier, module in tokens_cross_attention_modules:
-            key = ('down' if identifier.startswith('down') else
-                   'up' if identifier.startswith('up') else
-                   'mid')
+            key = (
+                "down"
+                if identifier.startswith("down")
+                else "up"
+                if identifier.startswith("up")
+                else "mid"
+            )
             module.set_attention_slice_calculated_callback(
-                lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key))
+                lambda slice, dim, offset, slice_size, key=key: callback(
+                    slice, dim, offset, slice_size, key
+                )
+            )

     def remove_attention_map_saving(self):
-        tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
+        tokens_cross_attention_modules = get_cross_attention_modules(
+            self.model, CrossAttentionType.TOKENS
+        )
         for _, module in tokens_cross_attention_modules:
             module.set_attention_slice_calculated_callback(None)

-    def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
-                          unconditioning: Union[torch.Tensor,dict],
-                          conditioning: Union[torch.Tensor,dict],
-                          unconditional_guidance_scale: float,
-                          step_index: Optional[int]=None,
-                          total_step_count: Optional[int]=None,
-                          ):
+    def do_diffusion_step(
+        self,
+        x: torch.Tensor,
+        sigma: torch.Tensor,
+        unconditioning: Union[torch.Tensor, dict],
+        conditioning: Union[torch.Tensor, dict],
+        unconditional_guidance_scale: float,
+        step_index: Optional[int] = None,
+        total_step_count: Optional[int] = None,
+    ):
         """
         :param x: current latents
         :param sigma: aka t, passed to the internal model to control how much denoising will occur
@@ -147,33 +186,55 @@ class InvokeAIDiffuserComponent:
         :return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
         """

         cross_attention_control_types_to_do = []
         context: Context = self.cross_attention_control_context
         if self.cross_attention_control_context is not None:
-            percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
-            cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)
+            percent_through = self.calculate_percent_through(
+                sigma, step_index, total_step_count
+            )
+            cross_attention_control_types_to_do = (
+                context.get_active_cross_attention_control_types_for_step(
+                    percent_through
+                )
+            )

-        wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
+        wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
         wants_hybrid_conditioning = isinstance(conditioning, dict)

         if wants_hybrid_conditioning:
-            unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(x, sigma, unconditioning,
-                                                                                       conditioning)
+            unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
+                x, sigma, unconditioning, conditioning
+            )
         elif wants_cross_attention_control:
-            unconditioned_next_x, conditioned_next_x = self._apply_cross_attention_controlled_conditioning(x, sigma,
-                                                                                                           unconditioning,
-                                                                                                           conditioning,
-                                                                                                           cross_attention_control_types_to_do)
+            (
+                unconditioned_next_x,
+                conditioned_next_x,
+            ) = self._apply_cross_attention_controlled_conditioning(
+                x,
+                sigma,
+                unconditioning,
+                conditioning,
+                cross_attention_control_types_to_do,
+            )
         elif self.sequential_guidance:
-            unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning_sequentially(
-                x, sigma, unconditioning, conditioning)
+            (
+                unconditioned_next_x,
+                conditioned_next_x,
+            ) = self._apply_standard_conditioning_sequentially(
+                x, sigma, unconditioning, conditioning
+            )

         else:
-            unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning(
-                x, sigma, unconditioning, conditioning)
+            (
+                unconditioned_next_x,
+                conditioned_next_x,
+            ) = self._apply_standard_conditioning(
+                x, sigma, unconditioning, conditioning
+            )

-        combined_next_x = self._combine(unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale)
+        combined_next_x = self._combine(
+            unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
+        )

         return combined_next_x
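Note (illustrative assumption, since _combine itself is not shown in this diff): if it follows standard classifier-free guidance, the final merge of the two predictions reduces to moving from the unconditioned result toward the conditioned one by the guidance scale:

    import torch


    def combine_cfg(
        unconditioned_next_x: torch.Tensor,
        conditioned_next_x: torch.Tensor,
        guidance_scale: float,
    ) -> torch.Tensor:
        # Standard CFG: uncond + scale * (cond - uncond).
        delta = conditioned_next_x - unconditioned_next_x
        return unconditioned_next_x + guidance_scale * delta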
@@ -183,24 +244,33 @@ class InvokeAIDiffuserComponent:
         latents: torch.Tensor,
         sigma,
         step_index,
-        total_step_count
+        total_step_count,
     ) -> torch.Tensor:
         if postprocessing_settings is not None:
-            percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
-            latents = self.apply_threshold(postprocessing_settings, latents, percent_through)
-            latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
+            percent_through = self.calculate_percent_through(
+                sigma, step_index, total_step_count
+            )
+            latents = self.apply_threshold(
+                postprocessing_settings, latents, percent_through
+            )
+            latents = self.apply_symmetry(
+                postprocessing_settings, latents, percent_through
+            )
         return latents

     def calculate_percent_through(self, sigma, step_index, total_step_count):
         if step_index is not None and total_step_count is not None:
             # 🧨diffusers codepath
-            percent_through = step_index / total_step_count # will never reach 1.0 - this is deliberate
+            percent_through = (
+                step_index / total_step_count
+            )  # will never reach 1.0 - this is deliberate
         else:
             # legacy compvis codepath
             # TODO remove when compvis codepath support is dropped
             if step_index is None and sigma is None:
                 raise ValueError(
-                    f"Either step_index or sigma is required when doing cross attention control, but both are None.")
+                    f"Either step_index or sigma is required when doing cross attention control, but both are None."
+                )
             percent_through = self.estimate_percent_through(step_index, sigma)
         return percent_through
@@ -211,24 +281,30 @@ class InvokeAIDiffuserComponent:
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)
         both_conditionings = torch.cat([unconditioning, conditioning])
-        both_results = self.model_forward_callback(x_twice, sigma_twice, both_conditionings)
+        both_results = self.model_forward_callback(
+            x_twice, sigma_twice, both_conditionings
+        )
         unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
-        if conditioned_next_x.device.type == 'mps':
+        if conditioned_next_x.device.type == "mps":
             # prevent a result filled with zeros. seems to be a torch bug.
             conditioned_next_x = conditioned_next_x.clone()
         return unconditioned_next_x, conditioned_next_x

-    def _apply_standard_conditioning_sequentially(self, x: torch.Tensor, sigma, unconditioning: torch.Tensor, conditioning: torch.Tensor):
+    def _apply_standard_conditioning_sequentially(
+        self,
+        x: torch.Tensor,
+        sigma,
+        unconditioning: torch.Tensor,
+        conditioning: torch.Tensor,
+    ):
         # low-memory sequential path
         unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
         conditioned_next_x = self.model_forward_callback(x, sigma, conditioning)
-        if conditioned_next_x.device.type == 'mps':
+        if conditioned_next_x.device.type == "mps":
             # prevent a result filled with zeros. seems to be a torch bug.
             conditioned_next_x = conditioned_next_x.clone()
         return unconditioned_next_x, conditioned_next_x

     def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
         assert isinstance(conditioning, dict)
         assert isinstance(unconditioning, dict)
@@ -243,48 +319,80 @@ class InvokeAIDiffuserComponent:
                 ]
             else:
                 both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
-        unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2)
+        unconditioned_next_x, conditioned_next_x = self.model_forward_callback(
+            x_twice, sigma_twice, both_conditionings
+        ).chunk(2)
         return unconditioned_next_x, conditioned_next_x

-    def _apply_cross_attention_controlled_conditioning(self,
-                                                       x: torch.Tensor,
-                                                       sigma,
-                                                       unconditioning,
-                                                       conditioning,
-                                                       cross_attention_control_types_to_do):
+    def _apply_cross_attention_controlled_conditioning(
+        self,
+        x: torch.Tensor,
+        sigma,
+        unconditioning,
+        conditioning,
+        cross_attention_control_types_to_do,
+    ):
         if self.is_running_diffusers:
-            return self._apply_cross_attention_controlled_conditioning__diffusers(x, sigma, unconditioning,
-                                                                                  conditioning,
-                                                                                  cross_attention_control_types_to_do)
+            return self._apply_cross_attention_controlled_conditioning__diffusers(
+                x,
+                sigma,
+                unconditioning,
+                conditioning,
+                cross_attention_control_types_to_do,
+            )
         else:
-            return self._apply_cross_attention_controlled_conditioning__compvis(x, sigma, unconditioning, conditioning,
-                                                                                cross_attention_control_types_to_do)
+            return self._apply_cross_attention_controlled_conditioning__compvis(
+                x,
+                sigma,
+                unconditioning,
+                conditioning,
+                cross_attention_control_types_to_do,
+            )

-    def _apply_cross_attention_controlled_conditioning__diffusers(self,
-                                                                  x: torch.Tensor,
-                                                                  sigma,
-                                                                  unconditioning,
-                                                                  conditioning,
-                                                                  cross_attention_control_types_to_do):
+    def _apply_cross_attention_controlled_conditioning__diffusers(
+        self,
+        x: torch.Tensor,
+        sigma,
+        unconditioning,
+        conditioning,
+        cross_attention_control_types_to_do,
+    ):
         context: Context = self.cross_attention_control_context

-        cross_attn_processor_context = SwapCrossAttnContext(modified_text_embeddings=context.arguments.edited_conditioning,
-                                                            index_map=context.cross_attention_index_map,
-                                                            mask=context.cross_attention_mask,
-                                                            cross_attention_types_to_do=[])
+        cross_attn_processor_context = SwapCrossAttnContext(
+            modified_text_embeddings=context.arguments.edited_conditioning,
+            index_map=context.cross_attention_index_map,
+            mask=context.cross_attention_mask,
+            cross_attention_types_to_do=[],
+        )
         # no cross attention for unconditioning (negative prompt)
-        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning,
-                                                           {"swap_cross_attn_context": cross_attn_processor_context})
+        unconditioned_next_x = self.model_forward_callback(
+            x,
+            sigma,
+            unconditioning,
+            {"swap_cross_attn_context": cross_attn_processor_context},
+        )

         # do requested cross attention types for conditioning (positive prompt)
-        cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
-        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning,
-                                                         {"swap_cross_attn_context": cross_attn_processor_context})
+        cross_attn_processor_context.cross_attention_types_to_do = (
+            cross_attention_control_types_to_do
+        )
+        conditioned_next_x = self.model_forward_callback(
+            x,
+            sigma,
+            conditioning,
+            {"swap_cross_attn_context": cross_attn_processor_context},
+        )
         return unconditioned_next_x, conditioned_next_x

-    def _apply_cross_attention_controlled_conditioning__compvis(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
+    def _apply_cross_attention_controlled_conditioning__compvis(
+        self,
+        x: torch.Tensor,
+        sigma,
+        unconditioning,
+        conditioning,
+        cross_attention_control_types_to_do,
+    ):
         # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
         # slower non-batched path (20% slower on mac MPS)
         # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
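Note (illustrative, not part of the commit): in the diffusers path the fourth argument to model_forward_callback is a dict holding the SwapCrossAttnContext. A hedged sketch of a callback factory that forwards it to a diffusers UNet2DConditionModel as cross_attention_kwargs; the exact wrapper used by InvokeAI is not shown in this diff:

    from typing import Any, Dict, Optional

    import torch


    def make_diffusers_forward_callback(unet):
        # `unet` is assumed to be a diffusers UNet2DConditionModel (or compatible).
        def forward_callback(
            x: torch.Tensor,
            t: torch.Tensor,
            text_embeddings: torch.Tensor,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        ) -> torch.Tensor:
            return unet(
                x,
                t,
                encoder_hidden_states=text_embeddings,
                cross_attention_kwargs=cross_attention_kwargs,
            ).sample

        return forward_callback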
@@ -294,24 +402,28 @@ class InvokeAIDiffuserComponent:
         # representing batched uncond + cond, but then when it comes to applying the saved attention, the
         # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
         # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
-        context:Context = self.cross_attention_control_context
+        context: Context = self.cross_attention_control_context

         try:
             unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)

             # process x using the original prompt, saving the attention maps
-            #print("saving attention maps for", cross_attention_control_types_to_do)
+            # print("saving attention maps for", cross_attention_control_types_to_do)
             for ca_type in cross_attention_control_types_to_do:
                 context.request_save_attention_maps(ca_type)
             _ = self.model_forward_callback(x, sigma, conditioning)
             context.clear_requests(cleanup=False)

             # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
-            #print("applying saved attention maps for", cross_attention_control_types_to_do)
+            # print("applying saved attention maps for", cross_attention_control_types_to_do)
             for ca_type in cross_attention_control_types_to_do:
                 context.request_apply_saved_attention_maps(ca_type)
-            edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
-            conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
+            edited_conditioning = (
+                self.conditioning.cross_attention_control_args.edited_conditioning
+            )
+            conditioned_next_x = self.model_forward_callback(
+                x, sigma, edited_conditioning
+            )
             context.clear_requests(cleanup=True)

         except:
@@ -330,17 +442,21 @@ class InvokeAIDiffuserComponent:
         self,
         postprocessing_settings: PostprocessingSettings,
         latents: torch.Tensor,
-        percent_through: float
+        percent_through: float,
     ) -> torch.Tensor:

-        if postprocessing_settings.threshold is None or postprocessing_settings.threshold == 0.0:
+        if (
+            postprocessing_settings.threshold is None
+            or postprocessing_settings.threshold == 0.0
+        ):
             return latents

         threshold = postprocessing_settings.threshold
         warmup = postprocessing_settings.warmup

         if percent_through < warmup:
-            current_threshold = threshold + threshold * 5 * (1 - (percent_through / warmup))
+            current_threshold = threshold + threshold * 5 * (
+                1 - (percent_through / warmup)
+            )
         else:
             current_threshold = threshold
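Note (worked example, not part of the commit): with threshold=1.0 and warmup=0.2, a step at percent_through=0.1 gives 1.0 + 1.0 * 5 * (1 - 0.1 / 0.2) = 3.5, so clamping is loose early in sampling and tightens to the configured threshold once warmup is over:

    def warmup_threshold(threshold: float, warmup: float, percent_through: float) -> float:
        # Mirrors the branch above: widen the clamp during warmup, then settle at `threshold`.
        if percent_through < warmup:
            return threshold + threshold * 5 * (1 - (percent_through / warmup))
        return threshold


    assert warmup_threshold(1.0, 0.2, 0.1) == 3.5
    assert warmup_threshold(1.0, 0.2, 0.5) == 1.0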
@@ -354,10 +470,14 @@ class InvokeAIDiffuserComponent:

         if self.debug_thresholding:
             std, mean = [i.item() for i in torch.std_mean(latents)]
-            outside = torch.count_nonzero((latents < -current_threshold) | (latents > current_threshold))
-            print(f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
-                  f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
-                  f" | {outside / latents.numel() * 100:.2f}% values outside threshold")
+            outside = torch.count_nonzero(
+                (latents < -current_threshold) | (latents > current_threshold)
+            )
+            print(
+                f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
+                f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
+                f" | {outside / latents.numel() * 100:.2f}% values outside threshold"
+            )

         if maxval < current_threshold and minval > -current_threshold:
             return latents
@@ -370,17 +490,23 @@ class InvokeAIDiffuserComponent:
             latents = torch.clone(latents)
             maxval = np.clip(maxval * scale, 1, current_threshold)
             num_altered += torch.count_nonzero(latents > maxval)
-            latents[latents > maxval] = torch.rand_like(latents[latents > maxval]) * maxval
+            latents[latents > maxval] = (
+                torch.rand_like(latents[latents > maxval]) * maxval
+            )

         if minval < -current_threshold:
             latents = torch.clone(latents)
             minval = np.clip(minval * scale, -current_threshold, -1)
             num_altered += torch.count_nonzero(latents < minval)
-            latents[latents < minval] = torch.rand_like(latents[latents < minval]) * minval
+            latents[latents < minval] = (
+                torch.rand_like(latents[latents < minval]) * minval
+            )

         if self.debug_thresholding:
-            print(f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n"
-                  f" | {num_altered / latents.numel() * 100:.2f}% values altered")
+            print(
+                f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n"
+                f" | {num_altered / latents.numel() * 100:.2f}% values altered"
+            )

         return latents
@@ -388,9 +514,8 @@ class InvokeAIDiffuserComponent:
         self,
         postprocessing_settings: PostprocessingSettings,
         latents: torch.Tensor,
-        percent_through: float
+        percent_through: float,
     ) -> torch.Tensor:
-
         # Reset our last percent through if this is our first step.
         if percent_through == 0.0:
             self.last_percent_through = 0.0
@@ -400,11 +525,15 @@ class InvokeAIDiffuserComponent:

         # Check for out of bounds
         h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
-        if (h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0)):
+        if h_symmetry_time_pct is not None and (
+            h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0
+        ):
             h_symmetry_time_pct = None

         v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
-        if (v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0)):
+        if v_symmetry_time_pct is not None and (
+            v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0
+        ):
             v_symmetry_time_pct = None

         width = latents.shape[3]
@@ -413,24 +542,29 @@ class InvokeAIDiffuserComponent:
         dtype = latents.dtype
         symmetry_type = postprocessing_settings.symmetry_type or SymmetryType.FADE

-        latents.to(device='cpu')
+        latents.to(device="cpu")

-        def make_ramp(ease_in:int, total:int) -> torch.Tensor:
+        def make_ramp(ease_in: int, total: int) -> torch.Tensor:
             ramp1 = torch.linspace(start=1.0, end=0.5, steps=ease_in, device=dev)
             ramp2 = torch.linspace(start=0.5, end=1.0, steps=total - ease_in, device=dev)
             ramp = torch.cat((ramp1, ramp2))
             return ramp

         if (
-            h_symmetry_time_pct != None and
-            self.last_percent_through < h_symmetry_time_pct and
-            percent_through >= h_symmetry_time_pct
+            h_symmetry_time_pct != None
+            and self.last_percent_through < h_symmetry_time_pct
+            and percent_through >= h_symmetry_time_pct
         ):
             # Horizontal symmetry occurs on the 3rd dimension of the latent
             x_flipped = torch.flip(latents, dims=[3])
             if symmetry_type is SymmetryType.MIRROR:
                 # Use the first half of latents and then the flipped one on this dimension
-                latents = torch.cat([latents[:, :, :, 0:int(width/2)], x_flipped[:, :, :, int(width/2):int(width)]], dim=3)
+                latents = torch.cat(
+                    [
+                        latents[:, :, :, 0 : int(width / 2)],
+                        x_flipped[:, :, :, int(width / 2) : int(width)],
+                    ],
+                    dim=3,
+                )
             elif symmetry_type is SymmetryType.FADE:
                 apply_width = width // 2
                 # Create a linear ramp so the middle gets perfect symmetry but the edges retain their original latents
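Note (simplified restatement, not the file's exact code): MIRROR keeps the left half of the latent and pastes a flipped copy over the right half along the width dimension, which is what the torch.cat above does; FADE instead blends the latent with its flip using the ramp built by make_ramp:

    import torch


    def mirror_horizontal(latents: torch.Tensor) -> torch.Tensor:
        # latents: [batch, channels, height, width]
        width = latents.shape[3]
        flipped = torch.flip(latents, dims=[3])
        return torch.cat(
            [latents[:, :, :, : width // 2], flipped[:, :, :, width // 2 :]], dim=3
        )


    latents = torch.randn(1, 4, 8, 8)
    mirrored = mirror_horizontal(latents)
    # The result is left-right symmetric.
    assert torch.equal(mirrored, torch.flip(mirrored, dims=[3]))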
@@ -442,15 +576,20 @@ class InvokeAIDiffuserComponent:
                 latents = ((latents * fade1) + (x_flipped * fade0)) * multiplier

         if (
-            v_symmetry_time_pct != None and
-            self.last_percent_through < v_symmetry_time_pct and
-            percent_through >= v_symmetry_time_pct
+            v_symmetry_time_pct != None
+            and self.last_percent_through < v_symmetry_time_pct
+            and percent_through >= v_symmetry_time_pct
         ):
             # Vertical symmetry occurs on the 3rd dimension of the latent
             y_flipped = torch.flip(latents, dims=[2])
             if symmetry_type is SymmetryType.MIRROR:
                 # Vertical symmetry occurs on the 2nd dimension of the latent
-                latents = torch.cat([latents[:, :, 0:int(height / 2)], y_flipped[:, :, int(height / 2):int(height)]], dim=2)
+                latents = torch.cat(
+                    [
+                        latents[:, :, 0 : int(height / 2)],
+                        y_flipped[:, :, int(height / 2) : int(height)],
+                    ],
+                    dim=2,
+                )
             elif symmetry_type is SymmetryType.FADE:
                 apply_height = height // 2
                 # Create a linear ramp so the middle gets perfect symmetry but the edges retain their original latents
@@ -467,7 +606,9 @@ class InvokeAIDiffuserComponent:
     def estimate_percent_through(self, step_index, sigma):
         if step_index is not None and self.cross_attention_control_context is not None:
             # percent_through will never reach 1.0 (but this is intended)
-            return float(step_index) / float(self.cross_attention_control_context.step_count)
+            return float(step_index) / float(
+                self.cross_attention_control_context.step_count
+            )
         # find the best possible index of the current sigma in the sigma sequence
         smaller_sigmas = torch.nonzero(self.model.sigmas <= sigma)
         sigma_index = smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0
@@ -476,33 +617,38 @@ class InvokeAIDiffuserComponent:
         return 1.0 - float(sigma_index + 1) / float(self.model.sigmas.shape[0])
         # print('estimated percent_through', percent_through, 'from sigma', sigma.item())

     # todo: make this work
     @classmethod
-    def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
+    def apply_conjunction(
+        cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale
+    ):
         x_in = torch.cat([x] * 2)
-        t_in = torch.cat([t] * 2) # aka sigmas
+        t_in = torch.cat([t] * 2)  # aka sigmas

         deltas = None
         uncond_latents = None
-        weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]
+        weighted_cond_list = (
+            c_or_weighted_c_list
+            if type(c_or_weighted_c_list) is list
+            else [(c_or_weighted_c_list, 1)]
+        )

         # below is fugly omg
         num_actual_conditionings = len(c_or_weighted_c_list)
-        conditionings = [uc] + [c for c,weight in weighted_cond_list]
-        weights = [1] + [weight for c,weight in weighted_cond_list]
-        chunk_count = ceil(len(conditionings)/2)
+        conditionings = [uc] + [c for c, weight in weighted_cond_list]
+        weights = [1] + [weight for c, weight in weighted_cond_list]
+        chunk_count = ceil(len(conditionings) / 2)
         deltas = None
         for chunk_index in range(chunk_count):
-            offset = chunk_index*2
-            chunk_size = min(2, len(conditionings)-offset)
+            offset = chunk_index * 2
+            chunk_size = min(2, len(conditionings) - offset)

             if chunk_size == 1:
                 c_in = conditionings[offset]
                 latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
                 latents_b = None
             else:
-                c_in = torch.cat(conditionings[offset:offset+2])
+                c_in = torch.cat(conditionings[offset : offset + 2])
                 latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)

             # first chunk is guaranteed to be 2 entries: uncond_latents + first conditioining
@@ -515,11 +661,15 @@ class InvokeAIDiffuserComponent:
                 deltas = torch.cat((deltas, latents_b - uncond_latents))

         # merge the weighted deltas together into a single merged delta
-        per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
+        per_delta_weights = torch.tensor(
+            weights[1:], dtype=deltas.dtype, device=deltas.device
+        )
         normalize = False
         if normalize:
             per_delta_weights /= torch.sum(per_delta_weights)
-        reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
+        reshaped_weights = per_delta_weights.reshape(
+            per_delta_weights.shape + (1, 1, 1)
+        )
         deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)

         # old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
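Note (illustrative, not part of the commit): the weighted-delta merge above is a per-conditioning weighted sum of (cond - uncond) deltas, broadcast over the latent dimensions; the final combination with uncond_latents and global_guidance_scale falls outside the lines shown in this diff. A toy-shaped sketch of just the reduction:

    import torch

    # Toy shapes: 3 conditionings, latents of shape [4, 8, 8] each.
    deltas = torch.randn(3, 4, 8, 8)  # rows are cond_latents - uncond_latents
    per_delta_weights = torch.tensor([1.0, 0.5, 0.25])

    # Broadcast the per-conditioning weights over the latent dims, as the
    # reshape(per_delta_weights.shape + (1, 1, 1)) above does.
    reshaped_weights = per_delta_weights.reshape(-1, 1, 1, 1)
    deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
    assert deltas_merged.shape == (1, 4, 8, 8)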