from __future__ import annotations

import math
from typing import Any, Callable, Optional, Union

import torch
from typing_extensions import TypeAlias

from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    IPAdapterData,
    Range,
    TextConditioningData,
    TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData

ModelForwardCallback: TypeAlias = Union[
    # x, t, conditioning, Optional[cross-attention kwargs]
    Callable[
        [torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str, Any]]],
        torch.Tensor,
    ],
    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
]
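
# Example: a minimal sketch of a callback satisfying ModelForwardCallback, assuming a
# diffusers-style UNet2DConditionModel whose forward() returns an object with a `.sample`
# attribute (names here are illustrative only):
#
#     def unet_forward(x, t, conditioning, **kwargs):
#         return unet(x, t, encoder_hidden_states=conditioning, **kwargs).sample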


class InvokeAIDiffuserComponent:
    """
    The aim of this component is to provide a single place for code that can be applied identically to
    all InvokeAI diffusion procedures.

    At the moment it includes the following features:
    * Cross attention control ("prompt2prompt")
    * Hybrid conditioning (used for inpainting)
    """

    debug_thresholding = False
    sequential_guidance = False

    def __init__(
        self,
        model,
        model_forward_callback: ModelForwardCallback,
    ):
        """
        :param model: the unet model to pass through to cross attention control
        :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply).
            Will be called repeatedly; most likely, this should simply call model.forward(x, sigma, conditioning).
        """
        config = get_config()
        self.conditioning = None
        self.model = model
        self.model_forward_callback = model_forward_callback
        self.sequential_guidance = config.sequential_guidance

    def do_controlnet_step(
        self,
        control_data,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        step_index: int,
        total_step_count: int,
        conditioning_data: TextConditioningData,
    ):
        down_block_res_samples, mid_block_res_sample = None, None

        # control_data should be type List[ControlNetData]
        # this loop covers both ControlNet (one ControlNetData in list)
        # and MultiControlNet (multiple ControlNetData in list)
        for _i, control_datum in enumerate(control_data):
            control_mode = control_datum.control_mode
            # soft_injection and cfg_injection are the two ControlNet control_mode booleans
            # that are combined at a higher level to make the control_mode enum
            # soft_injection determines whether to do per-layer re-weighting adjustment (if True)
            # or default weighting (if False)
            soft_injection = control_mode == "more_prompt" or control_mode == "more_control"
            # cfg_injection determines whether to apply ControlNet to only the conditional (if True)
            # or to both the conditional and unconditional (the default, if False)
            cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"

            first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
            last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
            # only apply controlnet if current step is within the controlnet's begin/end step range
            if step_index >= first_control_step and step_index <= last_control_step:
                if cfg_injection:
                    sample_model_input = sample
                else:
                    # expand the latents input to the control model if doing classifier-free guidance
                    # (which, for now, appears to always be true; a conditional elsewhere stops execution if
                    # classifier_free_guidance is <= 1.0)
                    sample_model_input = torch.cat([sample] * 2)

                added_cond_kwargs = None

                if cfg_injection:  # only apply ControlNet to the conditional batch, not the unconditional
                    if conditioning_data.is_sdxl():
                        added_cond_kwargs = {
                            "text_embeds": conditioning_data.cond_text.pooled_embeds,
                            "time_ids": conditioning_data.cond_text.add_time_ids,
                        }
                    encoder_hidden_states = conditioning_data.cond_text.embeds
                    encoder_attention_mask = None
                else:
                    if conditioning_data.is_sdxl():
                        added_cond_kwargs = {
                            "text_embeds": torch.cat(
                                [
                                    # TODO: how to pad? just by zeros? or even truncate?
                                    conditioning_data.uncond_text.pooled_embeds,
                                    conditioning_data.cond_text.pooled_embeds,
                                ],
                                dim=0,
                            ),
                            "time_ids": torch.cat(
                                [
                                    conditioning_data.uncond_text.add_time_ids,
                                    conditioning_data.cond_text.add_time_ids,
                                ],
                                dim=0,
                            ),
                        }
                    (
                        encoder_hidden_states,
                        encoder_attention_mask,
                    ) = self._concat_conditionings_for_batch(
                        conditioning_data.uncond_text.embeds,
                        conditioning_data.cond_text.embeds,
                    )
                if isinstance(control_datum.weight, list):
                    # if controlnet has multiple weights, use the weight for the current step
                    controlnet_weight = control_datum.weight[step_index]
                else:
                    # if controlnet has a single weight, use it for all steps
                    controlnet_weight = control_datum.weight

                # controlnet(s) inference
                down_samples, mid_sample = control_datum.model(
                    sample=sample_model_input,
                    timestep=timestep,
                    encoder_hidden_states=encoder_hidden_states,
                    controlnet_cond=control_datum.image_tensor,
                    conditioning_scale=controlnet_weight,  # controlnet specific, NOT the guidance scale
                    encoder_attention_mask=encoder_attention_mask,
                    added_cond_kwargs=added_cond_kwargs,
                    guess_mode=soft_injection,  # this is still called guess_mode in diffusers ControlNetModel
                    return_dict=False,
                )
                if cfg_injection:
                    # ControlNet was inferred only for the conditional batch.
                    # To apply the output of ControlNet to both the unconditional and conditional batches,
                    # prepend zeros for the unconditional batch.
                    down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
                    mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])

                if down_block_res_samples is None and mid_block_res_sample is None:
                    down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
                else:
                    # add controlnet outputs together if there are multiple controlnets
                    down_block_res_samples = [
                        samples_prev + samples_curr
                        for samples_prev, samples_curr in zip(down_block_res_samples, down_samples, strict=True)
                    ]
                    mid_block_res_sample += mid_sample

        return down_block_res_samples, mid_block_res_sample
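
    # Usage sketch: the residuals returned by do_controlnet_step are meant to be passed straight
    # through to do_unet_step on the same step (the `diffuser` instance and variable names below
    # are illustrative only):
    #
    #     down_res, mid_res = diffuser.do_controlnet_step(
    #         control_data, sample, timestep, step_index, total_step_count, conditioning_data
    #     )
    #     uncond_x, cond_x = diffuser.do_unet_step(
    #         sample=sample,
    #         timestep=timestep,
    #         conditioning_data=conditioning_data,
    #         ip_adapter_data=None,
    #         step_index=step_index,
    #         total_step_count=total_step_count,
    #         down_block_additional_residuals=down_res,
    #         mid_block_additional_residual=mid_res,
    #     )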

    def do_unet_step(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        conditioning_data: TextConditioningData,
        ip_adapter_data: Optional[list[IPAdapterData]],
        step_index: int,
        total_step_count: int,
        down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
        mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
        down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
    ):
        if self.sequential_guidance:
            (
                unconditioned_next_x,
                conditioned_next_x,
            ) = self._apply_standard_conditioning_sequentially(
                x=sample,
                sigma=timestep,
                conditioning_data=conditioning_data,
                ip_adapter_data=ip_adapter_data,
                step_index=step_index,
                total_step_count=total_step_count,
                down_block_additional_residuals=down_block_additional_residuals,
                mid_block_additional_residual=mid_block_additional_residual,
                down_intrablock_additional_residuals=down_intrablock_additional_residuals,
            )
        else:
            (
                unconditioned_next_x,
                conditioned_next_x,
            ) = self._apply_standard_conditioning(
                x=sample,
                sigma=timestep,
                conditioning_data=conditioning_data,
                ip_adapter_data=ip_adapter_data,
                step_index=step_index,
                total_step_count=total_step_count,
                down_block_additional_residuals=down_block_additional_residuals,
                mid_block_additional_residual=mid_block_additional_residual,
                down_intrablock_additional_residuals=down_intrablock_additional_residuals,
            )

        return unconditioned_next_x, conditioned_next_x
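
    # Usage sketch: a caller can blend the two predictions with its classifier-free guidance scale,
    # e.g. via _combine defined below (the guidance_scale value here is illustrative only):
    #
    #     uncond_x, cond_x = diffuser.do_unet_step(...)
    #     noise_pred = diffuser._combine(uncond_x, cond_x, guidance_scale=7.5)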

    def _concat_conditionings_for_batch(self, unconditioning, conditioning):
        def _pad_conditioning(cond, target_len, encoder_attention_mask):
            conditioning_attention_mask = torch.ones(
                (cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype
            )

            if cond.shape[1] < max_len:
                conditioning_attention_mask = torch.cat(
                    [
                        conditioning_attention_mask,
                        torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
                    ],
                    dim=1,
                )

                cond = torch.cat(
                    [
                        cond,
                        torch.zeros(
                            (cond.shape[0], max_len - cond.shape[1], cond.shape[2]),
                            device=cond.device,
                            dtype=cond.dtype,
                        ),
                    ],
                    dim=1,
                )

            if encoder_attention_mask is None:
                encoder_attention_mask = conditioning_attention_mask
            else:
                encoder_attention_mask = torch.cat(
                    [
                        encoder_attention_mask,
                        conditioning_attention_mask,
                    ]
                )

            return cond, encoder_attention_mask

        encoder_attention_mask = None
        if unconditioning.shape[1] != conditioning.shape[1]:
            max_len = max(unconditioning.shape[1], conditioning.shape[1])
            unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
            conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)

        return torch.cat([unconditioning, conditioning]), encoder_attention_mask
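
    # Shape example (dimensions are illustrative): for SD1.5-style embeddings, an unconditioned
    # tensor of shape (1, 77, 768) and a conditioned tensor of shape (1, 154, 768) are zero-padded
    # to a common sequence length of 154 and concatenated into a (2, 154, 768) batch, with a
    # (2, 154) encoder attention mask whose zeros mark the padded positions.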

    # Methods below are called from do_unet_step and should be considered private to this class.

    def _apply_standard_conditioning(
        self,
        x: torch.Tensor,
        sigma: torch.Tensor,
        conditioning_data: TextConditioningData,
        ip_adapter_data: Optional[list[IPAdapterData]],
        step_index: int,
        total_step_count: int,
        down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
        mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
        down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
        the cost of higher memory usage.
        """
        x_twice = torch.cat([x] * 2)
        sigma_twice = torch.cat([sigma] * 2)

        cross_attention_kwargs = {}
        if ip_adapter_data is not None:
            ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
            # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
            image_prompt_embeds = [
                torch.stack([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
                for ipa_conditioning in ip_adapter_conditioning
            ]
            scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
            ip_masks = [ipa.mask for ipa in ip_adapter_data]
            regional_ip_data = RegionalIPData(
                image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
            )
            cross_attention_kwargs["regional_ip_data"] = regional_ip_data

        added_cond_kwargs = None
        if conditioning_data.is_sdxl():
            added_cond_kwargs = {
                "text_embeds": torch.cat(
                    [
                        # TODO: how to pad? just by zeros? or even truncate?
                        conditioning_data.uncond_text.pooled_embeds,
                        conditioning_data.cond_text.pooled_embeds,
                    ],
                    dim=0,
                ),
                "time_ids": torch.cat(
                    [
                        conditioning_data.uncond_text.add_time_ids,
                        conditioning_data.cond_text.add_time_ids,
                    ],
                    dim=0,
                ),
            }

        if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
            # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
            # and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
            # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
            # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
            # awkward to handle both standard conditioning and sequential conditioning further up the stack.
            regions = []
            for c, r in [
                (conditioning_data.uncond_text, conditioning_data.uncond_regions),
                (conditioning_data.cond_text, conditioning_data.cond_regions),
            ]:
                if r is None:
                    # Create a dummy mask and range for text conditioning that doesn't have region masks.
                    _, _, h, w = x.shape
                    r = TextConditioningRegions(
                        masks=torch.ones((1, 1, h, w), dtype=x.dtype),
                        ranges=[Range(start=0, end=c.embeds.shape[1])],
                    )
                regions.append(r)

            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
                regions=regions, device=x.device, dtype=x.dtype
            )
            cross_attention_kwargs["percent_through"] = step_index / total_step_count

        both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
            conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
        )
        both_results = self.model_forward_callback(
            x_twice,
            sigma_twice,
            both_conditionings,
            cross_attention_kwargs=cross_attention_kwargs,
            encoder_attention_mask=encoder_attention_mask,
            down_block_additional_residuals=down_block_additional_residuals,
            mid_block_additional_residual=mid_block_additional_residual,
            down_intrablock_additional_residuals=down_intrablock_additional_residuals,
            added_cond_kwargs=added_cond_kwargs,
        )
        unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
        return unconditioned_next_x, conditioned_next_x

    def _apply_standard_conditioning_sequentially(
        self,
        x: torch.Tensor,
        sigma,
        conditioning_data: TextConditioningData,
        ip_adapter_data: Optional[list[IPAdapterData]],
        step_index: int,
        total_step_count: int,
        down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
        mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
        down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
    ):
        """Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
        slower execution speed.
        """
        # Since we are running the conditioned and unconditioned passes sequentially, we need to split the ControlNet
        # and T2I-Adapter residuals into two chunks.
        uncond_down_block, cond_down_block = None, None
        if down_block_additional_residuals is not None:
            uncond_down_block, cond_down_block = [], []
            for down_block in down_block_additional_residuals:
                _uncond_down, _cond_down = down_block.chunk(2)
                uncond_down_block.append(_uncond_down)
                cond_down_block.append(_cond_down)

        uncond_down_intrablock, cond_down_intrablock = None, None
        if down_intrablock_additional_residuals is not None:
            uncond_down_intrablock, cond_down_intrablock = [], []
            for down_intrablock in down_intrablock_additional_residuals:
                _uncond_down, _cond_down = down_intrablock.chunk(2)
                uncond_down_intrablock.append(_uncond_down)
                cond_down_intrablock.append(_cond_down)

        uncond_mid_block, cond_mid_block = None, None
        if mid_block_additional_residual is not None:
            uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)

        #####################
        # Unconditioned pass
        #####################

        cross_attention_kwargs = {}

        # Prepare IP-Adapter cross-attention kwargs for the unconditioned pass.
        if ip_adapter_data is not None:
            ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
            # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
            image_prompt_embeds = [
                torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
                for ipa_conditioning in ip_adapter_conditioning
            ]

            scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
            ip_masks = [ipa.mask for ipa in ip_adapter_data]
            regional_ip_data = RegionalIPData(
                image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
            )
            cross_attention_kwargs["regional_ip_data"] = regional_ip_data

        # Prepare SDXL conditioning kwargs for the unconditioned pass.
        added_cond_kwargs = None
        if conditioning_data.is_sdxl():
            added_cond_kwargs = {
                "text_embeds": conditioning_data.uncond_text.pooled_embeds,
                "time_ids": conditioning_data.uncond_text.add_time_ids,
            }

        # Prepare prompt regions for the unconditioned pass.
        if conditioning_data.uncond_regions is not None:
            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
                regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype
            )
            cross_attention_kwargs["percent_through"] = step_index / total_step_count

        # Run unconditioned UNet denoising (i.e. negative prompt).
        unconditioned_next_x = self.model_forward_callback(
            x,
            sigma,
            conditioning_data.uncond_text.embeds,
            cross_attention_kwargs=cross_attention_kwargs,
            down_block_additional_residuals=uncond_down_block,
            mid_block_additional_residual=uncond_mid_block,
            down_intrablock_additional_residuals=uncond_down_intrablock,
            added_cond_kwargs=added_cond_kwargs,
        )

        ###################
        # Conditioned pass
        ###################

        cross_attention_kwargs = {}

        if ip_adapter_data is not None:
            ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
            # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
            image_prompt_embeds = [
                torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
                for ipa_conditioning in ip_adapter_conditioning
            ]

            scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
            ip_masks = [ipa.mask for ipa in ip_adapter_data]
            regional_ip_data = RegionalIPData(
                image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
            )
            cross_attention_kwargs["regional_ip_data"] = regional_ip_data

        # Prepare SDXL conditioning kwargs for the conditioned pass.
        added_cond_kwargs = None
        if conditioning_data.is_sdxl():
            added_cond_kwargs = {
                "text_embeds": conditioning_data.cond_text.pooled_embeds,
                "time_ids": conditioning_data.cond_text.add_time_ids,
            }

        # Prepare prompt regions for the conditioned pass.
        if conditioning_data.cond_regions is not None:
            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
                regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype
            )
            cross_attention_kwargs["percent_through"] = step_index / total_step_count

        # Run conditioned UNet denoising (i.e. positive prompt).
        conditioned_next_x = self.model_forward_callback(
            x,
            sigma,
            conditioning_data.cond_text.embeds,
            cross_attention_kwargs=cross_attention_kwargs,
            down_block_additional_residuals=cond_down_block,
            mid_block_additional_residual=cond_mid_block,
            down_intrablock_additional_residuals=cond_down_intrablock,
            added_cond_kwargs=added_cond_kwargs,
        )
        return unconditioned_next_x, conditioned_next_x

    def _combine(self, unconditioned_next_x, conditioned_next_x, guidance_scale):
        # To scale how much effect conditioning has, calculate the change it makes and scale that delta.
        scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
        combined_next_x = unconditioned_next_x + scaled_delta
        return combined_next_x
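
    # Note: the combination above is the standard classifier-free guidance formula,
    #     combined = uncond + g * (cond - uncond) = (1 - g) * uncond + g * cond,
    # where g is guidance_scale.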