fix crash when no extra conditioning provided (redux)

This commit is contained in:
Lincoln Stein
2023-04-02 19:43:56 -04:00
parent 0a0e44b51e
commit afcb278e66

View File

@@ -9,18 +9,28 @@ from diffusers.models.cross_attention import AttnProcessor
from typing_extensions import TypeAlias
from ldm.invoke.globals import Globals
from ldm.models.diffusion.cross_attention_control import Arguments, \
restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \
CrossAttentionType, SwapCrossAttnContext
from ldm.models.diffusion.cross_attention_control import (
Arguments,
restore_default_cross_attention,
override_cross_attention,
Context,
get_cross_attention_modules,
CrossAttentionType,
SwapCrossAttnContext,
)
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ldm.modules.lora_manager import LoraCondition
ModelForwardCallback: TypeAlias = Union[
# x, t, conditioning, Optional[cross-attention kwargs]
Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str, Any]]], torch.Tensor],
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
Callable[
[torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str, Any]]],
torch.Tensor,
],
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
]
@dataclass(frozen=True)
class PostprocessingSettings:
threshold: float
@@ -30,20 +40,20 @@ class PostprocessingSettings:
class InvokeAIDiffuserComponent:
'''
"""
The aim of this component is to provide a single place for code that can be applied identically to
all InvokeAI diffusion procedures.
At the moment it includes the following features:
* Cross attention control ("prompt2prompt")
* Hybrid conditioning (used for inpainting)
'''
"""
debug_thresholding = False
sequential_guidance = False
@dataclass
class ExtraConditioningInfo:
tokens_count_including_eos_bos: int
cross_attention_control_args: Optional[Arguments] = None
lora_conditions: Optional[list[LoraCondition]] = None
@@ -56,10 +66,12 @@ class InvokeAIDiffuserComponent:
def has_lora_conditions(self):
return self.lora_conditions is not None
def __init__(self, model, model_forward_callback: ModelForwardCallback,
is_running_diffusers: bool=False,
):
def __init__(
self,
model,
model_forward_callback: ModelForwardCallback,
is_running_diffusers: bool = False,
):
"""
:param model: the unet model to pass through to cross attention control
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
@@ -72,13 +84,17 @@ class InvokeAIDiffuserComponent:
self.sequential_guidance = Globals.sequential_guidance
@contextmanager
def custom_attention_context(self,
extra_conditioning_info: Optional[ExtraConditioningInfo],
step_count: int):
def custom_attention_context(
self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
):
old_attn_processor = None
if extra_conditioning_info.wants_cross_attention_control | extra_conditioning_info.has_lora_conditions:
old_attn_processor = self.override_attention_processors(extra_conditioning_info,
step_count=step_count)
if extra_conditioning_info and (
extra_conditioning_info.wants_cross_attention_control
| extra_conditioning_info.has_lora_conditions
):
old_attn_processor = self.override_attention_processors(
extra_conditioning_info, step_count=step_count
)
try:
yield None
@@ -89,10 +105,11 @@ class InvokeAIDiffuserComponent:
for lora_condition in extra_conditioning_info.lora_conditions:
lora_condition.unload()
# TODO resuscitate attention map saving
#self.remove_attention_map_saving()
# self.remove_attention_map_saving()
def override_attention_processors(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]:
def override_attention_processors(
self, conditioning: ExtraConditioningInfo, step_count: int
) -> Dict[str, AttnProcessor]:
"""
setup cross attention .swap control. for diffusers this replaces the attention processor, so
the previous attention processor is returned so that the caller can restore it later.
@@ -107,19 +124,24 @@ class InvokeAIDiffuserComponent:
if conditioning.wants_cross_attention_control:
self.cross_attention_control_context = Context(
arguments=conditioning.cross_attention_control_args,
step_count=step_count
step_count=step_count,
)
override_cross_attention(
self.model,
self.cross_attention_control_context,
is_running_diffusers=self.is_running_diffusers,
)
override_cross_attention(self.model,
self.cross_attention_control_context,
is_running_diffusers=self.is_running_diffusers)
return old_attn_processors
def restore_default_cross_attention(self, processors_to_restore: Optional[dict[str, 'AttnProcessor']]=None):
def restore_default_cross_attention(
self, processors_to_restore: Optional[dict[str, "AttnProcessor"]] = None
):
self.cross_attention_control_context = None
restore_default_cross_attention(self.model,
is_running_diffusers=self.is_running_diffusers,
processors_to_restore=processors_to_restore)
restore_default_cross_attention(
self.model,
is_running_diffusers=self.is_running_diffusers,
processors_to_restore=processors_to_restore,
)
def setup_attention_map_saving(self, saver: AttentionMapSaver):
def callback(slice, dim, offset, slice_size, key):
@@ -128,26 +150,40 @@ class InvokeAIDiffuserComponent:
return
saver.add_attention_maps(slice, key)
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
tokens_cross_attention_modules = get_cross_attention_modules(
self.model, CrossAttentionType.TOKENS
)
for identifier, module in tokens_cross_attention_modules:
key = ('down' if identifier.startswith('down') else
'up' if identifier.startswith('up') else
'mid')
key = (
"down"
if identifier.startswith("down")
else "up"
if identifier.startswith("up")
else "mid"
)
module.set_attention_slice_calculated_callback(
lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key))
lambda slice, dim, offset, slice_size, key=key: callback(
slice, dim, offset, slice_size, key
)
)
def remove_attention_map_saving(self):
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
tokens_cross_attention_modules = get_cross_attention_modules(
self.model, CrossAttentionType.TOKENS
)
for _, module in tokens_cross_attention_modules:
module.set_attention_slice_calculated_callback(None)
def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
unconditioning: Union[torch.Tensor,dict],
conditioning: Union[torch.Tensor,dict],
unconditional_guidance_scale: float,
step_index: Optional[int]=None,
total_step_count: Optional[int]=None,
):
def do_diffusion_step(
self,
x: torch.Tensor,
sigma: torch.Tensor,
unconditioning: Union[torch.Tensor, dict],
conditioning: Union[torch.Tensor, dict],
unconditional_guidance_scale: float,
step_index: Optional[int] = None,
total_step_count: Optional[int] = None,
):
"""
:param x: current latents
:param sigma: aka t, passed to the internal model to control how much denoising will occur
@@ -158,33 +194,55 @@ class InvokeAIDiffuserComponent:
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
"""
cross_attention_control_types_to_do = []
context: Context = self.cross_attention_control_context
if self.cross_attention_control_context is not None:
percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)
percent_through = self.calculate_percent_through(
sigma, step_index, total_step_count
)
cross_attention_control_types_to_do = (
context.get_active_cross_attention_control_types_for_step(
percent_through
)
)
wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
wants_hybrid_conditioning = isinstance(conditioning, dict)
if wants_hybrid_conditioning:
unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(x, sigma, unconditioning,
conditioning)
unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
x, sigma, unconditioning, conditioning
)
elif wants_cross_attention_control:
unconditioned_next_x, conditioned_next_x = self._apply_cross_attention_controlled_conditioning(x, sigma,
unconditioning,
conditioning,
cross_attention_control_types_to_do)
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_cross_attention_controlled_conditioning(
x,
sigma,
unconditioning,
conditioning,
cross_attention_control_types_to_do,
)
elif self.sequential_guidance:
unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning_sequentially(
x, sigma, unconditioning, conditioning)
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning_sequentially(
x, sigma, unconditioning, conditioning
)
else:
unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning(
x, sigma, unconditioning, conditioning)
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning(
x, sigma, unconditioning, conditioning
)
combined_next_x = self._combine(unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale)
combined_next_x = self._combine(
unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
)
return combined_next_x
@@ -194,24 +252,33 @@ class InvokeAIDiffuserComponent:
latents: torch.Tensor,
sigma,
step_index,
total_step_count
total_step_count,
) -> torch.Tensor:
if postprocessing_settings is not None:
percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
latents = self.apply_threshold(postprocessing_settings, latents, percent_through)
latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
percent_through = self.calculate_percent_through(
sigma, step_index, total_step_count
)
latents = self.apply_threshold(
postprocessing_settings, latents, percent_through
)
latents = self.apply_symmetry(
postprocessing_settings, latents, percent_through
)
return latents
def calculate_percent_through(self, sigma, step_index, total_step_count):
if step_index is not None and total_step_count is not None:
# 🧨diffusers codepath
percent_through = step_index / total_step_count # will never reach 1.0 - this is deliberate
percent_through = (
step_index / total_step_count
) # will never reach 1.0 - this is deliberate
else:
# legacy compvis codepath
# TODO remove when compvis codepath support is dropped
if step_index is None and sigma is None:
raise ValueError(
f"Either step_index or sigma is required when doing cross attention control, but both are None.")
f"Either step_index or sigma is required when doing cross attention control, but both are None."
)
percent_through = self.estimate_percent_through(step_index, sigma)
return percent_through
@@ -222,24 +289,30 @@ class InvokeAIDiffuserComponent:
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
both_conditionings = torch.cat([unconditioning, conditioning])
both_results = self.model_forward_callback(x_twice, sigma_twice, both_conditionings)
both_results = self.model_forward_callback(
x_twice, sigma_twice, both_conditionings
)
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
if conditioned_next_x.device.type == 'mps':
if conditioned_next_x.device.type == "mps":
# prevent a result filled with zeros. seems to be a torch bug.
conditioned_next_x = conditioned_next_x.clone()
return unconditioned_next_x, conditioned_next_x
def _apply_standard_conditioning_sequentially(self, x: torch.Tensor, sigma, unconditioning: torch.Tensor, conditioning: torch.Tensor):
def _apply_standard_conditioning_sequentially(
self,
x: torch.Tensor,
sigma,
unconditioning: torch.Tensor,
conditioning: torch.Tensor,
):
# low-memory sequential path
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning)
if conditioned_next_x.device.type == 'mps':
if conditioned_next_x.device.type == "mps":
# prevent a result filled with zeros. seems to be a torch bug.
conditioned_next_x = conditioned_next_x.clone()
return unconditioned_next_x, conditioned_next_x
def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
assert isinstance(conditioning, dict)
assert isinstance(unconditioning, dict)
@@ -254,48 +327,80 @@ class InvokeAIDiffuserComponent:
]
else:
both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2)
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(
x_twice, sigma_twice, both_conditionings
).chunk(2)
return unconditioned_next_x, conditioned_next_x
def _apply_cross_attention_controlled_conditioning(self,
x: torch.Tensor,
sigma,
unconditioning,
conditioning,
cross_attention_control_types_to_do):
def _apply_cross_attention_controlled_conditioning(
self,
x: torch.Tensor,
sigma,
unconditioning,
conditioning,
cross_attention_control_types_to_do,
):
if self.is_running_diffusers:
return self._apply_cross_attention_controlled_conditioning__diffusers(x, sigma, unconditioning,
conditioning,
cross_attention_control_types_to_do)
return self._apply_cross_attention_controlled_conditioning__diffusers(
x,
sigma,
unconditioning,
conditioning,
cross_attention_control_types_to_do,
)
else:
return self._apply_cross_attention_controlled_conditioning__compvis(x, sigma, unconditioning, conditioning,
cross_attention_control_types_to_do)
return self._apply_cross_attention_controlled_conditioning__compvis(
x,
sigma,
unconditioning,
conditioning,
cross_attention_control_types_to_do,
)
def _apply_cross_attention_controlled_conditioning__diffusers(self,
x: torch.Tensor,
sigma,
unconditioning,
conditioning,
cross_attention_control_types_to_do):
def _apply_cross_attention_controlled_conditioning__diffusers(
self,
x: torch.Tensor,
sigma,
unconditioning,
conditioning,
cross_attention_control_types_to_do,
):
context: Context = self.cross_attention_control_context
cross_attn_processor_context = SwapCrossAttnContext(modified_text_embeddings=context.arguments.edited_conditioning,
index_map=context.cross_attention_index_map,
mask=context.cross_attention_mask,
cross_attention_types_to_do=[])
cross_attn_processor_context = SwapCrossAttnContext(
modified_text_embeddings=context.arguments.edited_conditioning,
index_map=context.cross_attention_index_map,
mask=context.cross_attention_mask,
cross_attention_types_to_do=[],
)
# no cross attention for unconditioning (negative prompt)
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning,
{"swap_cross_attn_context": cross_attn_processor_context})
unconditioned_next_x = self.model_forward_callback(
x,
sigma,
unconditioning,
{"swap_cross_attn_context": cross_attn_processor_context},
)
# do requested cross attention types for conditioning (positive prompt)
cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning,
{"swap_cross_attn_context": cross_attn_processor_context})
cross_attn_processor_context.cross_attention_types_to_do = (
cross_attention_control_types_to_do
)
conditioned_next_x = self.model_forward_callback(
x,
sigma,
conditioning,
{"swap_cross_attn_context": cross_attn_processor_context},
)
return unconditioned_next_x, conditioned_next_x
def _apply_cross_attention_controlled_conditioning__compvis(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
def _apply_cross_attention_controlled_conditioning__compvis(
self,
x: torch.Tensor,
sigma,
unconditioning,
conditioning,
cross_attention_control_types_to_do,
):
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
# slower non-batched path (20% slower on mac MPS)
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
@@ -305,24 +410,26 @@ class InvokeAIDiffuserComponent:
# representing batched uncond + cond, but then when it comes to applying the saved attention, the
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
context:Context = self.cross_attention_control_context
context: Context = self.cross_attention_control_context
try:
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
# process x using the original prompt, saving the attention maps
#print("saving attention maps for", cross_attention_control_types_to_do)
# print("saving attention maps for", cross_attention_control_types_to_do)
for ca_type in cross_attention_control_types_to_do:
context.request_save_attention_maps(ca_type)
_ = self.model_forward_callback(x, sigma, conditioning)
context.clear_requests(cleanup=False)
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
#print("applying saved attention maps for", cross_attention_control_types_to_do)
# print("applying saved attention maps for", cross_attention_control_types_to_do)
for ca_type in cross_attention_control_types_to_do:
context.request_apply_saved_attention_maps(ca_type)
edited_conditioning = context.arguments.edited_conditioning
conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
conditioned_next_x = self.model_forward_callback(
x, sigma, edited_conditioning
)
context.clear_requests(cleanup=True)
except:
@@ -341,17 +448,21 @@ class InvokeAIDiffuserComponent:
self,
postprocessing_settings: PostprocessingSettings,
latents: torch.Tensor,
percent_through: float
percent_through: float,
) -> torch.Tensor:
if postprocessing_settings.threshold is None or postprocessing_settings.threshold == 0.0:
if (
postprocessing_settings.threshold is None
or postprocessing_settings.threshold == 0.0
):
return latents
threshold = postprocessing_settings.threshold
warmup = postprocessing_settings.warmup
if percent_through < warmup:
current_threshold = threshold + threshold * 5 * (1 - (percent_through / warmup))
current_threshold = threshold + threshold * 5 * (
1 - (percent_through / warmup)
)
else:
current_threshold = threshold
@@ -365,10 +476,14 @@ class InvokeAIDiffuserComponent:
if self.debug_thresholding:
std, mean = [i.item() for i in torch.std_mean(latents)]
outside = torch.count_nonzero((latents < -current_threshold) | (latents > current_threshold))
print(f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
f" | {outside / latents.numel() * 100:.2f}% values outside threshold")
outside = torch.count_nonzero(
(latents < -current_threshold) | (latents > current_threshold)
)
print(
f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
f" | {outside / latents.numel() * 100:.2f}% values outside threshold"
)
if maxval < current_threshold and minval > -current_threshold:
return latents
@@ -381,17 +496,23 @@ class InvokeAIDiffuserComponent:
latents = torch.clone(latents)
maxval = np.clip(maxval * scale, 1, current_threshold)
num_altered += torch.count_nonzero(latents > maxval)
latents[latents > maxval] = torch.rand_like(latents[latents > maxval]) * maxval
latents[latents > maxval] = (
torch.rand_like(latents[latents > maxval]) * maxval
)
if minval < -current_threshold:
latents = torch.clone(latents)
minval = np.clip(minval * scale, -current_threshold, -1)
num_altered += torch.count_nonzero(latents < minval)
latents[latents < minval] = torch.rand_like(latents[latents < minval]) * minval
latents[latents < minval] = (
torch.rand_like(latents[latents < minval]) * minval
)
if self.debug_thresholding:
print(f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n"
f" | {num_altered / latents.numel() * 100:.2f}% values altered")
print(
f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n"
f" | {num_altered / latents.numel() * 100:.2f}% values altered"
)
return latents
@@ -399,9 +520,8 @@ class InvokeAIDiffuserComponent:
self,
postprocessing_settings: PostprocessingSettings,
latents: torch.Tensor,
percent_through: float
percent_through: float,
) -> torch.Tensor:
# Reset our last percent through if this is our first step.
if percent_through == 0.0:
self.last_percent_through = 0.0
@@ -411,36 +531,52 @@ class InvokeAIDiffuserComponent:
# Check for out of bounds
h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
if (h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0)):
if h_symmetry_time_pct is not None and (
h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0
):
h_symmetry_time_pct = None
v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
if (v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0)):
if v_symmetry_time_pct is not None and (
v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0
):
v_symmetry_time_pct = None
dev = latents.device.type
latents.to(device='cpu')
latents.to(device="cpu")
if (
h_symmetry_time_pct != None and
self.last_percent_through < h_symmetry_time_pct and
percent_through >= h_symmetry_time_pct
h_symmetry_time_pct != None
and self.last_percent_through < h_symmetry_time_pct
and percent_through >= h_symmetry_time_pct
):
# Horizontal symmetry occurs on the 3rd dimension of the latent
width = latents.shape[3]
x_flipped = torch.flip(latents, dims=[3])
latents = torch.cat([latents[:, :, :, 0:int(width/2)], x_flipped[:, :, :, int(width/2):int(width)]], dim=3)
latents = torch.cat(
[
latents[:, :, :, 0 : int(width / 2)],
x_flipped[:, :, :, int(width / 2) : int(width)],
],
dim=3,
)
if (
v_symmetry_time_pct != None and
self.last_percent_through < v_symmetry_time_pct and
percent_through >= v_symmetry_time_pct
v_symmetry_time_pct != None
and self.last_percent_through < v_symmetry_time_pct
and percent_through >= v_symmetry_time_pct
):
# Vertical symmetry occurs on the 2nd dimension of the latent
height = latents.shape[2]
y_flipped = torch.flip(latents, dims=[2])
latents = torch.cat([latents[:, :, 0:int(height / 2)], y_flipped[:, :, int(height / 2):int(height)]], dim=2)
latents = torch.cat(
[
latents[:, :, 0 : int(height / 2)],
y_flipped[:, :, int(height / 2) : int(height)],
],
dim=2,
)
self.last_percent_through = percent_through
return latents.to(device=dev)
@@ -448,7 +584,9 @@ class InvokeAIDiffuserComponent:
def estimate_percent_through(self, step_index, sigma):
if step_index is not None and self.cross_attention_control_context is not None:
# percent_through will never reach 1.0 (but this is intended)
return float(step_index) / float(self.cross_attention_control_context.step_count)
return float(step_index) / float(
self.cross_attention_control_context.step_count
)
# find the best possible index of the current sigma in the sigma sequence
smaller_sigmas = torch.nonzero(self.model.sigmas <= sigma)
sigma_index = smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0
@@ -457,33 +595,38 @@ class InvokeAIDiffuserComponent:
return 1.0 - float(sigma_index + 1) / float(self.model.sigmas.shape[0])
# print('estimated percent_through', percent_through, 'from sigma', sigma.item())
# todo: make this work
@classmethod
def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
def apply_conjunction(
cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale
):
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2) # aka sigmas
t_in = torch.cat([t] * 2) # aka sigmas
deltas = None
uncond_latents = None
weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]
weighted_cond_list = (
c_or_weighted_c_list
if type(c_or_weighted_c_list) is list
else [(c_or_weighted_c_list, 1)]
)
# below is fugly omg
num_actual_conditionings = len(c_or_weighted_c_list)
conditionings = [uc] + [c for c,weight in weighted_cond_list]
weights = [1] + [weight for c,weight in weighted_cond_list]
chunk_count = ceil(len(conditionings)/2)
conditionings = [uc] + [c for c, weight in weighted_cond_list]
weights = [1] + [weight for c, weight in weighted_cond_list]
chunk_count = ceil(len(conditionings) / 2)
deltas = None
for chunk_index in range(chunk_count):
offset = chunk_index*2
chunk_size = min(2, len(conditionings)-offset)
offset = chunk_index * 2
chunk_size = min(2, len(conditionings) - offset)
if chunk_size == 1:
c_in = conditionings[offset]
latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
latents_b = None
else:
c_in = torch.cat(conditionings[offset:offset+2])
c_in = torch.cat(conditionings[offset : offset + 2])
latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)
# first chunk is guaranteed to be 2 entries: uncond_latents + first conditioining
@@ -496,11 +639,15 @@ class InvokeAIDiffuserComponent:
deltas = torch.cat((deltas, latents_b - uncond_latents))
# merge the weighted deltas together into a single merged delta
per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
per_delta_weights = torch.tensor(
weights[1:], dtype=deltas.dtype, device=deltas.device
)
normalize = False
if normalize:
per_delta_weights /= torch.sum(per_delta_weights)
reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
reshaped_weights = per_delta_weights.reshape(
per_delta_weights.shape + (1, 1, 1)
)
deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
# old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)