Compare commits


51 Commits

Author SHA1 Message Date
Ryan Dick 6bcf48aa37 WIP - Started working towards MultiDiffusion batching. 2024-06-18 15:44:39 -04:00
Ryan Dick b1bb1511fe Delete rough notes. 2024-06-18 15:36:36 -04:00
Ryan Dick 99046a8145 Fix advanced scheduler behaviour in MultiDiffusionPipeline. 2024-06-18 15:36:36 -04:00
Ryan Dick 72be7e71e3 Fix handling of stateful schedulers in MultiDiffusionPipeline. 2024-06-18 15:36:36 -04:00
Ryan Dick 35adaf1c17 Connect TiledMultiDiffusionDenoiseLatents to the MultiDiffusionPipeline backend. 2024-06-18 15:36:34 -04:00
Ryan Dick 865c2335de Remove regional conditioning logic from MultiDiffusionPipeline - it is not yet supported. 2024-06-18 15:35:52 -04:00
Ryan Dick 49ca42f84a Initial (untested) implementation of MultiDiffusionPipeline. 2024-06-18 15:35:52 -04:00
Ryan Dick 493fcd8660 Remove inpainting support from MultiDiffusionPipeline. 2024-06-18 15:35:52 -04:00
Ryan Dick 20322d781e Remove IP-Adapter and T2I-Adapter support from MultiDiffusionPipeline. 2024-06-18 15:35:52 -04:00
Ryan Dick 889d13e02a Document plan for the rest of the MultiDiffusion implementation. 2024-06-18 15:35:52 -04:00
Ryan Dick 6ccd2a867b Add detailed docstring to latents_from_embeddings(). 2024-06-18 15:35:52 -04:00
Ryan Dick 5861fa1719 Copy StableDiffusionGeneratorPipeline as a starting point for a new MultiDiffusionPipeline. 2024-06-18 15:35:52 -04:00
Ryan Dick dfd4beb62b Simplify handling of inpainting models. Improve the in-code documentation around inpainting. 2024-06-18 15:35:52 -04:00
Ryan Dick 83df0c0df5 Minor tidying of latents_from_embeddings(...). 2024-06-18 15:35:52 -04:00
Ryan Dick c58c4069a7 Consolidate latents_from_embeddings(...) and generate_latents_from_embeddings(...) into a single function. 2024-06-18 15:35:52 -04:00
Ryan Dick 3937fffa94 Fix invocation name of tiled_multi_diffusion_denoise_latents. 2024-06-18 15:35:52 -04:00
Ryan Dick bbf5f67691 Improve clarity of comments regarded when 'noise' and 'latents' are expected to be set. 2024-06-18 15:35:52 -04:00
Ryan Dick 2f5c147b84 Fix static check errors on imports in diffusers_pipeline.py. 2024-06-18 15:35:52 -04:00
Ryan Dick bd2839b748 Remove a condition for handling inpainting models that never resolves to True. The same logic is already applied earlier by AddsMaskLatents. 2024-06-18 15:35:52 -04:00
Ryan Dick 4f70dd7ce1 Add clarifying comment to explain why noise might be None in latents_from_embedding(). 2024-06-18 15:35:52 -04:00
Ryan Dick 066672fbfd Remove unused are_like_tensors() function. 2024-06-18 15:35:52 -04:00
Ryan Dick abefaee4d1 Remove unused StableDiffusionGeneratorPipeline.use_ip_adapter member. 2024-06-18 15:35:52 -04:00
Ryan Dick 3254ba5904 Remove unused StableDiffusionGeneratorPipeline.control_model. 2024-06-18 15:35:52 -04:00
Ryan Dick 73a8c55852 Stricter typing for the is_gradient_mask: bool. 2024-06-18 15:35:52 -04:00
Ryan Dick f82af7c22d Fix typing of control_data to reflect that it can be None. 2024-06-18 15:35:52 -04:00
Ryan Dick 3aef717ef4 Fix typing of timesteps and init_timestep. 2024-06-18 15:35:52 -04:00
Ryan Dick c2cf1137e9 Fix typing to reflect that the callback arg to latents_from_embeddings is never None. 2024-06-18 15:35:52 -04:00
Ryan Dick 803a24bc0a Move seed above optional params. 2024-06-18 15:35:52 -04:00
Ryan Dick 7d24ad8ccd Simplify handling of AddsMaskGuidance, and fix some related type errors. 2024-06-18 15:35:52 -04:00
Ryan Dick cb389063b2 Remove unused num_inference_steps. 2024-06-18 15:35:52 -04:00
Ryan Dick 81b8a69e1a WIP TiledMultiDiffusionDenoiseLatents. Updated parameter list and first half of the logic. 2024-06-18 15:35:50 -04:00
Ryan Dick 7ee5db87ad Tidy DenoiseLatentsInvocation.prep_control_data(...) and fix some type errors. 2024-06-18 15:34:30 -04:00
Ryan Dick 66cf2c59bd Make DenoiseLatentsInvocation.prep_control_data(...) a staticmethod so that it can be called externally. 2024-06-18 15:34:30 -04:00
Ryan Dick 3bad1367e9 Copy TiledStableDiffusionRefineInvocation as a starting point for TiledMultiDiffusionDenoiseLatents.py 2024-06-18 15:34:22 -04:00
Ryan Dick 867a7642a6 Change tiling strategy to make TiledStableDiffusionRefineInvocation work with more tile shapes and overlaps. 2024-06-18 15:31:58 -04:00
Ryan Dick d9d1c8f9cb Expose a few more params from TiledStableDiffusionRefineInvocation. 2024-06-18 15:31:58 -04:00
Ryan Dick e03eb7fb45 Add support for LoRA models in TiledStableDiffusionRefineInvocation. 2024-06-18 15:31:58 -04:00
Ryan Dick 85db33bc7e Add naive ControlNet support to TiledStableDiffusionRefineInvocation 2024-06-18 15:31:58 -04:00
Ryan Dick 93e3a2b504 Fix ControlNetModel type hint import source. 2024-06-18 15:31:58 -04:00
Ryan Dick 6a7a26f1bf Rough prototype of TiledStableDiffusionRefineInvocation is working. 2024-06-18 15:31:58 -04:00
Ryan Dick 08ca03ef9f WIP - TiledStableDiffusionRefine 2024-06-18 15:31:54 -04:00
Ryan Dick ccf90b6bd6 Minor improvements to LatentsToImageInvocation type hints. 2024-06-18 15:31:21 -04:00
Ryan Dick 753239b48d Expose vae_decode(...) as a staticmethod on LatentsToImageInvocation. 2024-06-18 15:31:21 -04:00
Ryan Dick 65fa4664c9 Fix return type of prepare_noise_and_latents(...). 2024-06-18 15:31:21 -04:00
Ryan Dick 297570ded3 Make init_scheduler() a staticmethod on DenoiseLatentsInvocation so that it can be called externally. 2024-06-18 15:31:21 -04:00
Ryan Dick 680fdcf293 Only allow a single positive/negative prompt conditioning input for tiled refine. 2024-06-18 15:31:21 -04:00
Ryan Dick 5ff91f2c44 WIP on TiledStableDiffusionRefine 2024-06-18 15:31:14 -04:00
Ryan Dick 69aa7057e7 Convert several methods in DenoiseLatentsInvocation to staticmethods so that they can be called externally. 2024-06-18 15:25:08 -04:00
Ryan Dick d3932f40de Simplify the logic in prepare_noise_and_latents(...). 2024-06-18 15:25:08 -04:00
Ryan Dick ee74cd7fab Split out the prepare_noise_and_latents(...) logic in DenoiseLatentsInvocation so that it can be called from other invocations. 2024-06-18 15:25:08 -04:00
Ryan Dick bda25b40c9 (minor) Add a TODO note to get_scheduler(...). 2024-06-18 15:25:08 -04:00
40 changed files with 1555 additions and 812 deletions

View File

@@ -1328,7 +1328,7 @@ from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegist
config = InvokeAIAppConfig.get_config()
ram_cache = ModelCache(
max_cache_size=config.ram_cache_size, logger=logger
max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
)
convert_cache = ModelConvertCache(
cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size

View File

@@ -103,7 +103,6 @@ class CompelInvocation(BaseInvocation):
textual_inversion_manager=ti_manager,
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
truncate_long_prompts=False,
device=TorchDevice.choose_torch_device(),
)
conjunction = Compel.parse_prompt_string(self.prompt)
@@ -118,7 +117,6 @@ class CompelInvocation(BaseInvocation):
conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)])
conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
@@ -205,7 +203,6 @@ class SDXLPromptInvocationBase:
truncate_long_prompts=False, # TODO:
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
requires_pooled=get_pooled,
device=TorchDevice.choose_torch_device(),
)
conjunction = Compel.parse_prompt_string(prompt)
@@ -316,6 +313,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
)
]
)
conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput(

View File

@@ -1,5 +1,4 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import copy
import inspect
from contextlib import ExitStack
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
@@ -56,6 +55,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
)
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
from invokeai.backend.util.mask import to_standard_float_mask
from invokeai.backend.util.silence_warnings import SilenceWarnings
@@ -66,6 +66,9 @@ def get_scheduler(
scheduler_name: str,
seed: int,
) -> Scheduler:
"""Load a scheduler and apply some scheduler-specific overrides."""
# TODO(ryand): Silently falling back to ddim seems like a bad idea. Look into why this was added and remove if
# possible.
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.models.load(scheduler_info)
with orig_scheduler_info as orig_scheduler:
@@ -183,8 +186,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
raise ValueError("cfg_scale must be greater than 1")
return v
@staticmethod
def _get_text_embeddings_and_masks(
self,
cond_list: list[ConditioningField],
context: InvocationContext,
device: torch.device,
@@ -194,8 +197,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
text_embeddings_masks: list[Optional[torch.Tensor]] = []
for cond in cond_list:
cond_data = copy.deepcopy(context.conditioning.load(cond.conditioning_name))
cond_data = context.conditioning.load(cond.conditioning_name)
text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
mask = cond.mask
if mask is not None:
mask = context.tensors.load(mask.tensor_name)
@@ -203,8 +207,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
return text_embeddings, text_embeddings_masks
@staticmethod
def _preprocess_regional_prompt_mask(
self, mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
) -> torch.Tensor:
"""Preprocess a regional prompt mask to match the target height and width.
If mask is None, returns a mask of all ones with the target height and width.
@@ -226,11 +231,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
resized_mask = tf(mask)
assert isinstance(resized_mask, torch.Tensor)
return resized_mask
@staticmethod
def _concat_regional_text_embeddings(
self,
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
masks: Optional[list[Optional[torch.Tensor]]],
latent_height: int,
@@ -280,7 +284,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
)
processed_masks.append(
self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
DenoiseLatentsInvocation._preprocess_regional_prompt_mask(
mask, latent_height, latent_width, dtype=dtype
)
)
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
@@ -302,36 +308,41 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
return BasicConditioningInfo(embeds=text_embedding), regions
@staticmethod
def get_conditioning_data(
self,
context: InvocationContext,
positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
unet: UNet2DConditionModel,
latent_height: int,
latent_width: int,
cfg_scale: float | list[float],
steps: int,
cfg_rescale_multiplier: float,
) -> TextConditioningData:
# Normalize self.positive_conditioning and self.negative_conditioning to lists.
cond_list = self.positive_conditioning
# Normalize positive_conditioning_field and negative_conditioning_field to lists.
cond_list = positive_conditioning_field
if not isinstance(cond_list, list):
cond_list = [cond_list]
uncond_list = self.negative_conditioning
uncond_list = negative_conditioning_field
if not isinstance(uncond_list, list):
uncond_list = [uncond_list]
cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
cond_list, context, unet.device, unet.dtype
)
uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
uncond_list, context, unet.device, unet.dtype
)
cond_text_embedding, cond_regions = self._concat_regional_text_embeddings(
cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=cond_text_embeddings,
masks=cond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
)
uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings(
uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks,
latent_height=latent_height,
@@ -339,23 +350,21 @@ class DenoiseLatentsInvocation(BaseInvocation):
dtype=unet.dtype,
)
if isinstance(self.cfg_scale, list):
assert (
len(self.cfg_scale) == self.steps
), "cfg_scale (list) must have the same length as the number of steps"
if isinstance(cfg_scale, list):
assert len(cfg_scale) == steps, "cfg_scale (list) must have the same length as the number of steps"
conditioning_data = TextConditioningData(
uncond_text=uncond_text_embedding,
cond_text=cond_text_embedding,
uncond_regions=uncond_regions,
cond_regions=cond_regions,
guidance_scale=self.cfg_scale,
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
guidance_scale=cfg_scale,
guidance_rescale_multiplier=cfg_rescale_multiplier,
)
return conditioning_data
@staticmethod
def create_pipeline(
self,
unet: UNet2DConditionModel,
scheduler: Scheduler,
) -> StableDiffusionGeneratorPipeline:
@@ -378,38 +387,38 @@ class DenoiseLatentsInvocation(BaseInvocation):
requires_safety_checker=False,
)
@staticmethod
def prep_control_data(
self,
context: InvocationContext,
control_input: Optional[Union[ControlField, List[ControlField]]],
control_input: ControlField | list[ControlField] | None,
latents_shape: List[int],
exit_stack: ExitStack,
do_classifier_free_guidance: bool = True,
) -> Optional[List[ControlNetData]]:
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR
control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR
if control_input is None:
control_list = None
elif isinstance(control_input, list) and len(control_input) == 0:
control_list = None
elif isinstance(control_input, ControlField):
) -> list[ControlNetData] | None:
# Normalize control_input to a list.
control_list: list[ControlField]
if isinstance(control_input, ControlField):
control_list = [control_input]
elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
elif isinstance(control_input, list):
control_list = control_input
elif control_input is None:
control_list = []
else:
control_list = None
if control_list is None:
return None
# After above handling, any control that is not None should now be of type list[ControlField].
raise ValueError(f"Unexpected control_input type: {type(control_input)}")
# FIXME: add checks to skip entry if model or image is None
# and if weight is None, populate with default 1.0?
controlnet_data = []
if len(control_list) == 0:
return None
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
_, _, latent_height, latent_width = latents_shape
control_height_resize = latent_height * LATENT_SCALE_FACTOR
control_width_resize = latent_width * LATENT_SCALE_FACTOR
controlnet_data: list[ControlNetData] = []
for control_info in control_list:
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
assert isinstance(control_model, ControlNetModel)
# control_models.append(control_model)
control_image_field = control_info.image
input_image = context.images.get_pil(control_image_field.image_name)
# self.image.image_type, self.image.image_name
@@ -430,7 +439,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
resize_mode=control_info.resize_mode,
)
control_item = ControlNetData(
model=control_model, # model object
model=control_model,
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
@@ -584,15 +593,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
# original idea by https://github.com/AmericanPresidentJimmyCarter
# TODO: research more for second order schedulers timesteps
@staticmethod
def init_scheduler(
self,
scheduler: Union[Scheduler, ConfigMixin],
device: torch.device,
steps: int,
denoising_start: float,
denoising_end: float,
seed: int,
) -> Tuple[int, List[int], int, Dict[str, Any]]:
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
assert isinstance(scheduler, ConfigMixin)
if scheduler.config.get("cpu_only", False):
scheduler.set_timesteps(steps, device="cpu")
@@ -618,7 +627,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
init_timestep = timesteps[t_start_idx : t_start_idx + 1]
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
num_inference_steps = len(timesteps) // scheduler.order
scheduler_step_kwargs: Dict[str, Any] = {}
scheduler_step_signature = inspect.signature(scheduler.step)
@@ -640,7 +648,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
if isinstance(scheduler, TCDScheduler):
scheduler_step_kwargs.update({"eta": 1.0})
return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs
return timesteps, init_timestep, scheduler_step_kwargs
def prep_inpaint_mask(
self, context: InvocationContext, latents: torch.Tensor
@@ -657,31 +665,52 @@ class DenoiseLatentsInvocation(BaseInvocation):
return 1 - mask, masked_latents, self.denoise_mask.gradient
@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def invoke(self, context: InvocationContext) -> LatentsOutput:
seed = None
@staticmethod
def prepare_noise_and_latents(
context: InvocationContext, noise_field: LatentsField | None, latents_field: LatentsField | None
) -> Tuple[int, torch.Tensor | None, torch.Tensor]:
"""Depending on the workflow, we expect different combinations of noise and latents to be provided. This
function handles preparing these values accordingly.
Expected workflows:
- Text-to-Image Denoising: `noise` is provided, `latents` is not. `latents` is initialized to zeros.
- Image-to-Image Denoising: `noise` and `latents` are both provided.
- Text-to-Image SDXL Refiner Denoising: `latents` is provided, `noise` is not.
- Image-to-Image SDXL Refiner Denoising: `latents` is provided, `noise` is not.
NOTE(ryand): I wrote this docstring, but I am not the original author of this code. There may be other workflows
I haven't considered.
"""
noise = None
if self.noise is not None:
noise = context.tensors.load(self.noise.latents_name)
seed = self.noise.seed
if self.latents is not None:
latents = context.tensors.load(self.latents.latents_name)
if seed is None:
seed = self.latents.seed
if noise is not None and noise.shape[1:] != latents.shape[1:]:
raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
if noise_field is not None:
noise = context.tensors.load(noise_field.latents_name)
if latents_field is not None:
latents = context.tensors.load(latents_field.latents_name)
elif noise is not None:
latents = torch.zeros_like(noise)
else:
raise Exception("'latents' or 'noise' must be provided!")
raise ValueError("'latents' or 'noise' must be provided!")
if seed is None:
if noise is not None and noise.shape[1:] != latents.shape[1:]:
raise ValueError(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
# The seed comes from (in order of priority): the noise field, the latents field, or 0.
seed = 0
if noise_field is not None and noise_field.seed is not None:
seed = noise_field.seed
elif latents_field is not None and latents_field.seed is not None:
seed = latents_field.seed
else:
seed = 0
return seed, noise, latents
@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def invoke(self, context: InvocationContext) -> LatentsOutput:
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
@@ -755,7 +784,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
_, _, latent_height, latent_width = latents.shape
conditioning_data = self.get_conditioning_data(
context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
latent_height=latent_height,
latent_width=latent_width,
cfg_scale=self.cfg_scale,
steps=self.steps,
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)
controlnet_data = self.prep_control_data(
@@ -777,7 +814,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
dtype=unet.dtype,
)
num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
scheduler,
device=unet.device,
steps=self.steps,
@@ -794,8 +831,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
seed=seed,
mask=mask,
masked_latents=masked_latents,
gradient_mask=gradient_mask,
num_inference_steps=num_inference_steps,
is_gradient_mask=gradient_mask,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
control_data=controlnet_data,
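
Taken together, the changes in this file turn the denoising helpers into reusable staticmethods and shrink init_scheduler()'s return value to three items. A minimal sketch of the resulting call pattern from another invocation (mirroring the tiled invocations later in this diff; `context`, `unet`, and the `self.*` fields are assumed to belong to the calling node):

```python
# Sketch only: reusing DenoiseLatentsInvocation's staticmethods from another invocation.
# All inputs (context, self.noise, self.latents, self.unet, unet, self.steps, ...) are
# assumed to come from the calling node, as in the tiled invocations below.
seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(
    context, self.noise, self.latents
)

scheduler = get_scheduler(
    context=context,
    scheduler_info=self.unet.scheduler,
    scheduler_name=self.scheduler,
    seed=seed,
)

# init_scheduler() no longer returns num_inference_steps; callers unpack three values.
timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
    scheduler,
    device=unet.device,
    steps=self.steps,
    denoising_start=self.denoising_start,
    denoising_end=self.denoising_end,
    seed=seed,
)
```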

View File

@@ -8,7 +8,7 @@ from diffusers.models.attention_processor import (
)
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION
@@ -23,6 +23,7 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion import set_seamless
from invokeai.backend.util.devices import TorchDevice
@@ -48,16 +49,20 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, torch.nn.Module)
@staticmethod
def vae_decode(
context: InvocationContext,
vae_info: LoadedModel,
seamless_axes: list[str],
latents: torch.Tensor,
use_fp32: bool,
use_tiling: bool,
) -> Image.Image:
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
with set_seamless(vae_info.model, seamless_axes), vae_info as vae:
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
latents = latents.to(vae.device)
if self.fp32:
if use_fp32:
vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
@@ -82,7 +87,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae.to(dtype=torch.float16)
latents = latents.half()
if self.tiled or context.config.get().force_tiled_decode:
if use_tiling or context.config.get().force_tiled_decode:
vae.enable_tiling()
else:
vae.disable_tiling()
@@ -102,6 +107,21 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
TorchDevice.empty_cache()
return image
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
image = self.vae_decode(
context=context,
vae_info=vae_info,
seamless_axes=self.vae.seamless_axes,
latents=latents,
use_fp32=self.fp32,
use_tiling=self.tiled,
)
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)
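
With vae_decode(...) exposed as a staticmethod, other invocations can decode latents without going through LatentsToImageInvocation.invoke(). A minimal sketch (assuming the caller has a `context`, a VAEField named `vae_field`, and a latents tensor):

```python
# Sketch only: external use of the new staticmethod. vae_field and latents are
# assumed to be provided by the calling invocation.
vae_info = context.models.load(vae_field.vae)
image = LatentsToImageInvocation.vae_decode(
    context=context,
    vae_info=vae_info,
    seamless_axes=vae_field.seamless_axes,
    latents=latents,
    use_fp32=False,
    use_tiling=False,
)
image_dto = context.images.save(image=image)
```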

View File

@@ -0,0 +1,268 @@
import copy
from contextlib import ExitStack
from typing import Iterator, Tuple
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from pydantic import field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
from invokeai.app.invocations.fields import (
ConditioningField,
FieldDescriptions,
Input,
InputField,
LatentsField,
UIType,
)
from invokeai.app.invocations.model import UNetField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
MultiDiffusionPipeline,
MultiDiffusionRegionConditioning,
)
from invokeai.backend.tiles.tiles import (
calc_tiles_min_overlap,
)
from invokeai.backend.tiles.utils import TBLR
from invokeai.backend.util.devices import TorchDevice
def crop_controlnet_data(control_data: ControlNetData, latent_region: TBLR) -> ControlNetData:
"""Crop a ControlNetData object to a region."""
# Create a shallow copy of the control_data object.
control_data_copy = copy.copy(control_data)
# The ControlNet reference image is the only attribute that needs to be cropped.
control_data_copy.image_tensor = control_data.image_tensor[
:,
:,
latent_region.top * LATENT_SCALE_FACTOR : latent_region.bottom * LATENT_SCALE_FACTOR,
latent_region.left * LATENT_SCALE_FACTOR : latent_region.right * LATENT_SCALE_FACTOR,
]
return control_data_copy
@invocation(
"tiled_multi_diffusion_denoise_latents",
title="Tiled Multi-Diffusion Denoise Latents",
tags=["upscale", "denoise"],
category="latents",
# TODO(ryand): Reset to 1.0.0 right before release.
version="1.0.0",
)
class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
"""Tiled Multi-Diffusion denoising.
This node handles automatically tiling the input image. Future iterations of
this node should allow the user to specify custom regions with different parameters for each region to harness the
full power of Multi-Diffusion.
This node has a similar interface to the `DenoiseLatents` node, but it has a reduced feature set (no IP-Adapter,
T2I-Adapter, masking, etc.).
"""
positive_conditioning: ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_conditioning: ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection
)
noise: LatentsField | None = InputField(
default=None,
description=FieldDescriptions.noise,
input=Input.Connection,
)
latents: LatentsField | None = InputField(
default=None,
description=FieldDescriptions.latents,
input=Input.Connection,
)
# TODO(ryand): Add multiple-of validation.
# TODO(ryand): Smaller defaults might make more sense.
tile_height: int = InputField(default=112, gt=0, description="Height of the tiles in latent space.")
tile_width: int = InputField(default=112, gt=0, description="Width of the tiles in latent space.")
tile_min_overlap: int = InputField(
default=16,
gt=0,
description="The minimum overlap between adjacent tiles in latent space. The actual overlap may be larger than "
"this to evenly cover the entire image.",
)
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
# TODO(ryand): The default here should probably be 0.0.
denoising_start: float = InputField(
default=0.65,
ge=0,
le=1,
description=FieldDescriptions.denoising_start,
)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
scheduler: SCHEDULER_NAME_VALUES = InputField(
default="euler",
description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler,
)
unet: UNetField = InputField(
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
)
cfg_rescale_multiplier: float = InputField(
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
)
control: ControlField | list[ControlField] | None = InputField(
default=None,
input=Input.Connection,
)
@field_validator("cfg_scale")
def ge_one(cls, v: list[float] | float) -> list[float] | float:
"""Validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError("cfg_scale must be greater than 1")
else:
if v < 1:
raise ValueError("cfg_scale must be greater than 1")
return v
@staticmethod
def create_pipeline(
unet: UNet2DConditionModel,
scheduler: SchedulerMixin,
) -> MultiDiffusionPipeline:
# TODO(ryand): Get rid of this FakeVae hack.
class FakeVae:
class FakeVaeConfig:
def __init__(self) -> None:
self.block_out_channels = [0]
def __init__(self) -> None:
self.config = FakeVae.FakeVaeConfig()
return MultiDiffusionPipeline(
vae=FakeVae(), # TODO: oh...
text_encoder=None,
tokenizer=None,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents)
_, _, latent_height, latent_width = latents.shape
# Calculate the tile locations to cover the latent-space image.
# TODO(ryand): Add constraints on the tile params. Is there a multiple-of constraint?
tiles = calc_tiles_min_overlap(
image_height=latent_height,
image_width=latent_width,
tile_height=self.tile_height,
tile_width=self.tile_width,
min_overlap=self.tile_min_overlap,
)
# Prepare an iterator that yields the UNet's LoRA models and their weights.
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
# Load the UNet model.
unet_info = context.models.load(self.unet.unet)
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:
noise = noise.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
)
pipeline = self.create_pipeline(unet=unet, scheduler=scheduler)
# Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
latent_height=self.tile_height,
latent_width=self.tile_width,
cfg_scale=self.cfg_scale,
steps=self.steps,
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)
controlnet_data = DenoiseLatentsInvocation.prep_control_data(
context=context,
control_input=self.control,
latents_shape=list(latents.shape),
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
# Split the controlnet_data into tiles.
# controlnet_data_tiles[t][c] is the c'th control data for the t'th tile.
controlnet_data_tiles: list[list[ControlNetData]] = []
for tile in tiles:
tile_controlnet_data = [crop_controlnet_data(cn, tile.coords) for cn in controlnet_data or []]
controlnet_data_tiles.append(tile_controlnet_data)
# Prepare the MultiDiffusionRegionConditioning list.
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning] = []
for tile, tile_controlnet_data in zip(tiles, controlnet_data_tiles, strict=True):
multi_diffusion_conditioning.append(
MultiDiffusionRegionConditioning(
region=tile.coords,
text_conditioning_data=conditioning_data,
control_data=tile_controlnet_data,
)
)
timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
scheduler,
device=unet.device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
seed=seed,
)
# Run Multi-Diffusion denoising.
result_latents = pipeline.multi_diffusion_denoise(
multi_diffusion_conditioning=multi_diffusion_conditioning,
latents=latents,
scheduler_step_kwargs=scheduler_step_kwargs,
noise=noise,
timesteps=timesteps,
init_timestep=init_timestep,
# TODO(ryand): Add proper callback.
callback=lambda x: None,
)
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
result_latents = result_latents.to("cpu")
TorchDevice.empty_cache()
name = context.tensors.save(tensor=result_latents)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
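
For context, a rough sketch of the tile bookkeeping this node relies on. Tile locations are computed in latent space and mapped back to pixel space via LATENT_SCALE_FACTOR (8 for SD latents), which is exactly how crop_controlnet_data() crops the ControlNet reference image. The concrete sizes below are illustrative only:

```python
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap

# A 1024x1024 image corresponds to a 128x128 latent grid (1024 / 8).
tiles = calc_tiles_min_overlap(
    image_height=128,
    image_width=128,
    tile_height=112,
    tile_width=112,
    min_overlap=16,
)
for tile in tiles:
    # tile.coords is a TBLR in latent space; multiplying by LATENT_SCALE_FACTOR gives
    # the matching pixel-space region used to crop the ControlNet image.
    print(tile.coords, tile.coords.top * LATENT_SCALE_FACTOR, tile.coords.left * LATENT_SCALE_FACTOR)
```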

View File

@@ -0,0 +1,380 @@
from contextlib import ExitStack
from typing import Iterator, Tuple
import numpy as np
import numpy.typing as npt
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from PIL import Image
from pydantic import field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
from invokeai.app.invocations.fields import (
ConditioningField,
FieldDescriptions,
ImageField,
Input,
InputField,
UIType,
)
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation
from invokeai.app.invocations.model import ModelIdentifierField, UNetField, VAEField
from invokeai.app.invocations.noise import get_noise
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, image_resized_to_grid_as_tensor
from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending
from invokeai.backend.tiles.utils import Tile
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
@invocation(
"tiled_stable_diffusion_refine",
title="Tiled Stable Diffusion Refine",
tags=["upscale", "denoise"],
category="latents",
version="1.0.0",
)
class TiledStableDiffusionRefineInvocation(BaseInvocation):
"""A tiled Stable Diffusion pipeline for refining high resolution images. This invocation is intended to be used to
refine an image after upscaling i.e. it is the second step in a typical "tiled upscaling" workflow.
"""
image: ImageField = InputField(description="Image to be refined.")
positive_conditioning: ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_conditioning: ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection
)
# TODO(ryand): Add multiple-of validation.
tile_height: int = InputField(default=512, gt=0, description="Height of the tiles.")
tile_width: int = InputField(default=512, gt=0, description="Width of the tiles.")
tile_overlap: int = InputField(
default=16,
gt=0,
description="Target overlap between adjacent tiles (the last row/column may overlap more than this).",
)
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
denoising_start: float = InputField(
default=0.65,
ge=0,
le=1,
description=FieldDescriptions.denoising_start,
)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
scheduler: SCHEDULER_NAME_VALUES = InputField(
default="euler",
description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler,
)
unet: UNetField = InputField(
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
)
cfg_rescale_multiplier: float = InputField(
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
vae_fp32: bool = InputField(
default=DEFAULT_PRECISION == torch.float32, description="Whether to use float32 precision when running the VAE."
)
# HACK(ryand): We probably want to allow the user to control all of the parameters in ControlField. But, we akwardly
# don't want to use the image field. Figure out how best to handle this.
# TODO(ryand): Currently, there is no ControlNet preprocessor applied to the tile images. In other words, we pretty
# much assume that it is a tile ControlNet. We need to decide how we want to handle this. E.g. find a way to support
# CN preprocessors, raise a clear warning when a non-tile CN model is selected, hardcode the supported CN models,
# etc.
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
)
control_weight: float = InputField(default=0.6)
@field_validator("cfg_scale")
def ge_one(cls, v: list[float] | float) -> list[float] | float:
"""Validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError("cfg_scale must be greater than 1")
else:
if v < 1:
raise ValueError("cfg_scale must be greater than 1")
return v
@staticmethod
def crop_latents_to_tile(latents: torch.Tensor, image_tile: Tile) -> torch.Tensor:
"""Crop the latent-space tensor to the area corresponding to the image-space tile.
The tile coordinates must be divisible by the LATENT_SCALE_FACTOR.
"""
for coord in [image_tile.coords.top, image_tile.coords.left, image_tile.coords.right, image_tile.coords.bottom]:
if coord % LATENT_SCALE_FACTOR != 0:
raise ValueError(
f"The tile coordinates must all be divisible by the latent scale factor"
f" ({LATENT_SCALE_FACTOR}). {image_tile.coords=}."
)
assert latents.dim() == 4 # We expect: (batch_size, channels, height, width).
top = image_tile.coords.top // LATENT_SCALE_FACTOR
left = image_tile.coords.left // LATENT_SCALE_FACTOR
bottom = image_tile.coords.bottom // LATENT_SCALE_FACTOR
right = image_tile.coords.right // LATENT_SCALE_FACTOR
return latents[..., top:bottom, left:right]
def run_controlnet(
self,
image: Image.Image,
controlnet_model: ControlNetModel,
weight: float,
do_classifier_free_guidance: bool,
width: int,
height: int,
device: torch.device,
dtype: torch.dtype,
control_mode: CONTROLNET_MODE_VALUES = "balanced",
resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
) -> ControlNetData:
control_image = prepare_control_image(
image=image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=width,
height=height,
device=device,
dtype=dtype,
control_mode=control_mode,
resize_mode=resize_mode,
)
return ControlNetData(
model=controlnet_model,
image_tensor=control_image,
weight=weight,
begin_step_percent=0.0,
end_step_percent=1.0,
control_mode=control_mode,
# Any resizing needed should currently be happening in prepare_control_image(), but adding resize_mode to
# ControlNetData in case needed in the future.
resize_mode=resize_mode,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
# TODO(ryand): Expose the seed parameter.
seed = 0
# Load the input image.
input_image = context.images.get_pil(self.image.image_name)
# Calculate the tile locations to cover the image.
# We have selected this tiling strategy to make it easy to achieve tile coords that are multiples of 8. This
# facilitates conversions between image space and latent space.
# TODO(ryand): Expose these tiling parameters. (Keep in mind the multiple-of constraints on these params.)
tiles = calc_tiles_with_overlap(
image_height=input_image.height,
image_width=input_image.width,
tile_height=self.tile_height,
tile_width=self.tile_width,
overlap=self.tile_overlap,
)
# Convert the input image to a torch.Tensor.
input_image_torch = image_resized_to_grid_as_tensor(input_image.convert("RGB"), multiple_of=LATENT_SCALE_FACTOR)
input_image_torch = input_image_torch.unsqueeze(0) # Add a batch dimension.
# Validate our assumptions about the shape of input_image_torch.
assert input_image_torch.dim() == 4 # We expect: (batch_size, channels, height, width).
assert input_image_torch.shape[:2] == (1, 3)
# Split the input image into tiles in torch.Tensor format.
image_tiles_torch: list[torch.Tensor] = []
for tile in tiles:
image_tile = input_image_torch[
:,
:,
tile.coords.top : tile.coords.bottom,
tile.coords.left : tile.coords.right,
]
image_tiles_torch.append(image_tile)
# Split the input image into tiles in numpy format.
# TODO(ryand): We currently maintain both np.ndarray and torch.Tensor tiles. Ideally, all operations should work
# with torch.Tensor tiles.
input_image_np = np.array(input_image)
image_tiles_np: list[npt.NDArray[np.uint8]] = []
for tile in tiles:
image_tile_np = input_image_np[
tile.coords.top : tile.coords.bottom,
tile.coords.left : tile.coords.right,
:,
]
image_tiles_np.append(image_tile_np)
# VAE-encode each image tile independently.
# TODO(ryand): Is there any advantage to VAE-encoding the entire image before splitting it into tiles? What
# about for decoding?
vae_info = context.models.load(self.vae.vae)
latent_tiles: list[torch.Tensor] = []
for image_tile_torch in image_tiles_torch:
latent_tiles.append(
ImageToLatentsInvocation.vae_encode(
vae_info=vae_info, upcast=self.vae_fp32, tiled=False, image_tensor=image_tile_torch
)
)
# Generate noise with dimensions corresponding to the full image in latent space.
# It is important that the noise tensor is generated at the full image dimension and then tiled, rather than
# generating for each tile independently. This ensures that overlapping regions between tiles use the same
# noise.
assert input_image_torch.shape[2] % LATENT_SCALE_FACTOR == 0
assert input_image_torch.shape[3] % LATENT_SCALE_FACTOR == 0
global_noise = get_noise(
width=input_image_torch.shape[3],
height=input_image_torch.shape[2],
device=TorchDevice.choose_torch_device(),
seed=seed,
downsampling_factor=LATENT_SCALE_FACTOR,
use_cpu=True,
)
# Crop the global noise into tiles.
noise_tiles = [self.crop_latents_to_tile(latents=global_noise, image_tile=t) for t in tiles]
# Prepare an iterator that yields the UNet's LoRA models and their weights.
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
# Load the UNet model.
unet_info = context.models.load(self.unet.unet)
refined_latent_tiles: list[torch.Tensor] = []
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
assert isinstance(unet, UNet2DConditionModel)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
)
pipeline = DenoiseLatentsInvocation.create_pipeline(unet=unet, scheduler=scheduler)
# Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
# Assume that all tiles have the same shape.
_, _, latent_height, latent_width = latent_tiles[0].shape
conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
latent_height=latent_height,
latent_width=latent_width,
cfg_scale=self.cfg_scale,
steps=self.steps,
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)
# Load the ControlNet model.
# TODO(ryand): Support multiple ControlNet models.
controlnet_model = exit_stack.enter_context(context.models.load(self.control_model))
assert isinstance(controlnet_model, ControlNetModel)
# Denoise (i.e. "refine") each tile independently.
for image_tile_np, latent_tile, noise_tile in zip(image_tiles_np, latent_tiles, noise_tiles, strict=True):
assert latent_tile.shape == noise_tile.shape
# Prepare a PIL Image for ControlNet processing.
# TODO(ryand): This is a bit awkward that we have to prepare both torch.Tensor and PIL.Image versions of
# the tiles. Ideally, the ControlNet code should be able to work with Tensors.
image_tile_pil = Image.fromarray(image_tile_np)
# Run the ControlNet on the image tile.
height, width, _ = image_tile_np.shape
# The height and width must be evenly divisible by LATENT_SCALE_FACTOR. This is enforced earlier, but we
# validate this assumption here.
assert height % LATENT_SCALE_FACTOR == 0
assert width % LATENT_SCALE_FACTOR == 0
controlnet_data = self.run_controlnet(
image=image_tile_pil,
controlnet_model=controlnet_model,
weight=self.control_weight,
do_classifier_free_guidance=True,
width=width,
height=height,
device=controlnet_model.device,
dtype=controlnet_model.dtype,
control_mode="balanced",
resize_mode="just_resize_simple",
)
timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
scheduler,
device=unet.device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
seed=seed,
)
# TODO(ryand): Think about when/if latents/noise should be moved off of the device to save VRAM.
latent_tile = latent_tile.to(device=unet.device, dtype=unet.dtype)
noise_tile = noise_tile.to(device=unet.device, dtype=unet.dtype)
refined_latent_tile = pipeline.latents_from_embeddings(
latents=latent_tile,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise_tile,
seed=seed,
mask=None,
masked_latents=None,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
control_data=[controlnet_data],
ip_adapter_data=None,
t2i_adapter_data=None,
callback=lambda x: None,
)
refined_latent_tiles.append(refined_latent_tile)
# VAE-decode each refined latent tile independently.
refined_image_tiles: list[Image.Image] = []
for refined_latent_tile in refined_latent_tiles:
refined_image_tile = LatentsToImageInvocation.vae_decode(
context=context,
vae_info=vae_info,
seamless_axes=self.vae.seamless_axes,
latents=refined_latent_tile,
use_fp32=self.vae_fp32,
use_tiling=False,
)
refined_image_tiles.append(refined_image_tile)
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
TorchDevice.empty_cache()
# Merge the refined image tiles back into a single image.
refined_image_tiles_np = [np.array(t) for t in refined_image_tiles]
merged_image_np = np.zeros(shape=(input_image.height, input_image.width, 3), dtype=np.uint8)
# TODO(ryand): Tune the blend_amount. Should this be exposed as a parameter?
merge_tiles_with_linear_blending(
dst_image=merged_image_np, tiles=tiles, tile_images=refined_image_tiles_np, blend_amount=self.tile_overlap
)
# Save the refined image and return its reference.
merged_image_pil = Image.fromarray(merged_image_np)
image_dto = context.images.save(image=merged_image_pil)
return ImageOutput.build(image_dto)
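
Two details of this invocation are worth calling out: noise is generated once for the full latent image and then cropped per tile (so overlapping tile regions see identical noise), and image-space tile coordinates are converted to latent-space slices by dividing by LATENT_SCALE_FACTOR. A worked sketch of that coordinate arithmetic (assuming LATENT_SCALE_FACTOR == 8; the numbers and `global_latents` are illustrative):

```python
# Illustrative only: an image tile spanning rows 0..512 and columns 496..1008
# (all multiples of 8, as crop_latents_to_tile() requires) selects this latent slice.
top, bottom, left, right = 0, 512, 496, 1008
latent_rows = slice(top // 8, bottom // 8)    # 0:64
latent_cols = slice(left // 8, right // 8)    # 62:126
tile_latents = global_latents[..., latent_rows, latent_cols]  # global_latents is assumed
```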

View File

@@ -26,13 +26,13 @@ LEGACY_INIT_FILE = Path("invokeai.init")
DEFAULT_RAM_CACHE = 10.0
DEFAULT_VRAM_CACHE = 0.25
DEFAULT_CONVERT_CACHE = 20.0
DEVICE = Literal["auto", "cpu", "cuda:0", "cuda:1", "cuda:2", "cuda:3", "cuda:4", "cuda:5", "cuda:6", "cuda:7", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
CONFIG_SCHEMA_VERSION = "4.0.2"
CONFIG_SCHEMA_VERSION = "4.0.1"
def get_default_ram_cache_size() -> float:
@@ -105,17 +105,14 @@ class InvokeAIAppConfig(BaseSettings):
convert_cache: Maximum size of on-disk converted models cache (GB).
lazy_offload: Keep models in VRAM until their space is needed.
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda:0`, `cuda:1`, `cuda:2`, `cuda:3`, `cuda:4`, `cuda:5`, `cuda:6`, `cuda:7`, `mps`
devices: List of execution devices; will override default device selected.
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
attention_slice_size: Slice size, valid when attention_type=="sliced".<br>Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
max_queue_size: Maximum number of items in the session queue.
max_threads: Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.
clear_queue_on_startup: Empties session queue on startup.
allow_nodes: List of nodes to allow. Omit to allow all.
deny_nodes: List of nodes to deny. Omit to deny none.
node_cache_size: How many cached nodes to keep in memory.
@@ -180,7 +177,6 @@ class InvokeAIAppConfig(BaseSettings):
# DEVICE
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
devices: Optional[list[DEVICE]] = Field(default=None, description="List of execution devices; will override default device selected.")
precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.")
# GENERATION
@@ -190,8 +186,6 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
max_threads: Optional[int] = Field(default=None, description="Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.")
clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.")
# NODES
allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.")
@@ -380,6 +374,9 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
if k == "max_cache_size" and "ram" not in category_dict:
parsed_config_dict["ram"] = v
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
if k == "max_vram_cache_size" and "vram" not in category_dict:
parsed_config_dict["vram"] = v
# autocast was removed in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
@@ -427,27 +424,6 @@ def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig
return config
def migrate_v4_0_1_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate v4.0.1 config dictionary to a current config object.
A few new multi-GPU options were added in 4.0.2, and this simply
updates the schema label.
Args:
config_dict: A dictionary of settings from a v4.0.1 config file.
Returns:
An instance of `InvokeAIAppConfig` with the migrated settings.
"""
parsed_config_dict: dict[str, Any] = {}
for k, _ in config_dict.items():
if k == "schema_version":
parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
return config
# TO DO: replace this with a formal registration and migration system
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""Load and migrate a config file to the latest version.
@@ -479,10 +455,6 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
loaded_config_dict.write_file(config_path)
elif loaded_config_dict["schema_version"] == "4.0.1":
loaded_config_dict = migrate_v4_0_1_config_dict(loaded_config_dict)
loaded_config_dict.write_file(config_path)
# Attempt to load as a v4 config file
try:
# Meta is not included in the model fields, so we need to validate it separately
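
A small, hypothetical sketch of the v3 key migration visible in this hunk (the input dict is made up; the real code iterates every category of the legacy config):

```python
# Hypothetical v3 settings as they might appear after parsing the legacy YAML.
category_dict = {"max_cache_size": 10.0, "max_vram_cache_size": 0.25, "precision": "autocast"}

parsed_config_dict: dict = {}
for k, v in category_dict.items():
    # `max_cache_size` was renamed to `ram`, `max_vram_cache_size` to `vram`.
    if k == "max_cache_size" and "ram" not in category_dict:
        parsed_config_dict["ram"] = v
    if k == "max_vram_cache_size" and "vram" not in category_dict:
        parsed_config_dict["vram"] = v
    # `autocast` precision was removed in v4.0.1.
    if k == "precision" and v == "autocast":
        parsed_config_dict["precision"] = "auto"

# parsed_config_dict -> {"ram": 10.0, "vram": 0.25, "precision": "auto"}
```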

View File

@@ -53,11 +53,11 @@ class InvocationServices:
model_images: "ModelImageFileStorageBase",
model_manager: "ModelManagerServiceBase",
download_queue: "DownloadQueueServiceBase",
performance_statistics: "InvocationStatsServiceBase",
session_queue: "SessionQueueBase",
session_processor: "SessionProcessorBase",
invocation_cache: "InvocationCacheBase",
names: "NameServiceBase",
performance_statistics: "InvocationStatsServiceBase",
urls: "UrlServiceBase",
workflow_records: "WorkflowRecordsStorageBase",
tensors: "ObjectSerializerBase[torch.Tensor]",
@@ -77,11 +77,11 @@ class InvocationServices:
self.model_images = model_images
self.model_manager = model_manager
self.download_queue = download_queue
self.performance_statistics = performance_statistics
self.session_queue = session_queue
self.session_processor = session_processor
self.invocation_cache = invocation_cache
self.names = names
self.performance_statistics = performance_statistics
self.urls = urls
self.workflow_records = workflow_records
self.tensors = tensors

View File

@@ -74,9 +74,9 @@ class InvocationStatsService(InvocationStatsServiceBase):
)
self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
def reset_stats(self, graph_execution_state_id: str):
self._stats.pop(graph_execution_state_id)
self._cache_stats.pop(graph_execution_state_id)
def reset_stats(self):
self._stats = {}
self._cache_stats = {}
def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
graph_stats_summary = self._get_graph_summary(graph_execution_state_id)

View File

@@ -284,14 +284,9 @@ class ModelInstallService(ModelInstallServiceBase):
unfinished_jobs = [x for x in self._install_jobs if not x.in_terminal_state]
self._install_jobs = unfinished_jobs
def _migrate_yaml(self, rename_yaml: Optional[bool] = True, overwrite_db: Optional[bool] = False) -> None:
def _migrate_yaml(self) -> None:
db_models = self.record_store.all_models()
if overwrite_db:
for model in db_models:
self.record_store.del_model(model.key)
db_models = self.record_store.all_models()
legacy_models_yaml_path = (
self._app_config.legacy_models_yaml_path or self._app_config.root_path / "configs" / "models.yaml"
)
@@ -341,8 +336,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.warning(f"Model at {model_path} could not be migrated: {e}")
# Rename `models.yaml` to `models.yaml.bak` to prevent re-migration
if rename_yaml:
legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))
legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))
# Unset the path - we are done with it either way
self._app_config.legacy_models_yaml_path = None

View File

@@ -33,11 +33,6 @@ class ModelLoadServiceBase(ABC):
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""
@property
@abstractmethod
def gpu_count(self) -> int:
"""Return the number of GPUs we are configured to use."""
@abstractmethod
def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None

View File

@@ -46,7 +46,6 @@ class ModelLoadService(ModelLoadServiceBase):
self._registry = registry
def start(self, invoker: Invoker) -> None:
"""Start the service."""
self._invoker = invoker
@property
@@ -54,11 +53,6 @@ class ModelLoadService(ModelLoadServiceBase):
"""Return the RAM cache used by this loader."""
return self._ram_cache
@property
def gpu_count(self) -> int:
"""Return the number of GPUs available for our uses."""
return len(self._ram_cache.execution_devices)
@property
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""

View File

@@ -1,7 +1,6 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
from abc import ABC, abstractmethod
from typing import Optional, Set
import torch
from typing_extensions import Self
@@ -32,7 +31,7 @@ class ModelManagerServiceBase(ABC):
model_record_service: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
events: EventServiceBase,
execution_devices: Optional[Set[torch.device]] = None,
execution_device: torch.device,
) -> Self:
"""
Construct the model manager service instance.

View File

@@ -1,10 +1,14 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase."""
from typing import Optional
import torch
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
from ..config import InvokeAIAppConfig
@@ -65,6 +69,7 @@ class ModelManagerService(ModelManagerServiceBase):
model_record_service: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
events: EventServiceBase,
execution_device: Optional[torch.device] = None,
) -> Self:
"""
Construct the model manager service instance.
@@ -77,7 +82,9 @@ class ModelManagerService(ModelManagerServiceBase):
ram_cache = ModelCache(
max_cache_size=app_config.ram,
max_vram_cache_size=app_config.vram,
lazy_offloading=app_config.lazy_offload,
logger=logger,
execution_device=execution_device or TorchDevice.choose_torch_device(),
)
convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
loader = ModelLoadService(

View File

@@ -1,6 +1,5 @@
import shutil
import tempfile
import threading
import typing
from pathlib import Path
from typing import TYPE_CHECKING, Optional, TypeVar
@@ -10,7 +9,6 @@ import torch
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
from invokeai.app.util.misc import uuid_string
from invokeai.backend.util.devices import TorchDevice
if TYPE_CHECKING:
from invokeai.app.services.invoker import Invoker
@@ -72,10 +70,7 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
return self._output_dir / name
def _new_name(self) -> str:
tid = threading.current_thread().ident
# Add tid to the object name because uuid4 is not thread-safe on Windows
# See https://stackoverflow.com/questions/2759644/python-multiprocessing-doesnt-play-nicely-with-uuid-uuid4
return f"{self._obj_class_name}_{tid}-{uuid_string()}"
return f"{self._obj_class_name}_{uuid_string()}"
def _tempdir_cleanup(self) -> None:
"""Calls `cleanup` on the temporary directory, if it exists."""

View File

@@ -1,9 +1,8 @@
import traceback
from contextlib import suppress
from queue import Queue
from threading import BoundedSemaphore, Lock, Thread
from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent
from typing import Optional, Set
from typing import Optional
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.events.events_common import (
@@ -27,7 +26,6 @@ from invokeai.app.services.session_queue.session_queue_common import SessionQueu
from invokeai.app.services.shared.graph import NodeInputError
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
from invokeai.app.util.profiler import Profiler
from invokeai.backend.util.devices import TorchDevice
from ..invoker import Invoker
from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase
@@ -59,11 +57,8 @@ class DefaultSessionRunner(SessionRunnerBase):
self._on_after_run_node_callbacks = on_after_run_node_callbacks or []
self._on_node_error_callbacks = on_node_error_callbacks or []
self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
self._process_lock = Lock()
def start(
self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None
) -> None:
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
self._services = services
self._cancel_event = cancel_event
self._profiler = profiler
@@ -81,8 +76,7 @@ class DefaultSessionRunner(SessionRunnerBase):
# Loop over invocations until the session is complete or canceled
while True:
try:
with self._process_lock:
invocation = queue_item.session.next()
invocation = queue_item.session.next()
# Anything other than a `NodeInputError` is handled as a processor error
except NodeInputError as e:
error_type = e.__class__.__name__
@@ -114,7 +108,7 @@ class DefaultSessionRunner(SessionRunnerBase):
self._on_after_run_session(queue_item=queue_item)
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
try:
# Any unhandled exception in this scope is an invocation error & will fail the graph
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
@@ -216,7 +210,7 @@ class DefaultSessionRunner(SessionRunnerBase):
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._services.performance_statistics.log_stats(queue_item.session.id)
self._services.performance_statistics.reset_stats(queue_item.session.id)
self._services.performance_statistics.reset_stats()
for callback in self._on_after_run_session_callbacks:
callback(queue_item=queue_item)
@@ -330,7 +324,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
def start(self, invoker: Invoker) -> None:
self._invoker: Invoker = invoker
self._active_queue_items: Set[SessionQueueItem] = set()
self._queue_item: Optional[SessionQueueItem] = None
self._invocation: Optional[BaseInvocation] = None
self._resume_event = ThreadEvent()
@@ -356,14 +350,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
else None
)
self._worker_thread_count = self._invoker.services.configuration.max_threads or len(
TorchDevice.execution_devices()
)
self._session_worker_queue: Queue[SessionQueueItem] = Queue()
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event, profiler=self._profiler)
# Session processor - singlethreaded
self._thread = Thread(
name="session_processor",
target=self._process,
@@ -376,16 +363,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
)
self._thread.start()
# Session processor workers - multithreaded
self._invoker.services.logger.debug(f"Starting {self._worker_thread_count} session processing threads.")
for _i in range(0, self._worker_thread_count):
worker = Thread(
name="session_worker",
target=self._process_next_session,
daemon=True,
)
worker.start()
def stop(self, *args, **kwargs) -> None:
self._stop_event.set()
@@ -393,7 +370,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._poll_now_event.set()
async def _on_queue_cleared(self, event: FastAPIEvent[QueueClearedEvent]) -> None:
if any(item.queue_id == event[1].queue_id for item in self._active_queue_items):
if self._queue_item and self._queue_item.queue_id == event[1].queue_id:
self._cancel_event.set()
self._poll_now()
@@ -401,7 +378,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._poll_now()
async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None:
if self._active_queue_items and event[1].status in ["completed", "failed", "canceled"]:
if self._queue_item and event[1].status in ["completed", "failed", "canceled"]:
# When the queue item is canceled via HTTP, the queue item status is set to `"canceled"` and this event is
# emitted. We need to respond to this event and stop graph execution. This is done by setting the cancel
# event, which the session runner checks between invocations. If set, the session runner loop is broken.
@@ -426,7 +403,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
def get_status(self) -> SessionProcessorStatus:
return SessionProcessorStatus(
is_started=self._resume_event.is_set(),
is_processing=len(self._active_queue_items) > 0,
is_processing=self._queue_item is not None,
)
def _process(
@@ -451,22 +428,30 @@ class DefaultSessionProcessor(SessionProcessorBase):
resume_event.wait()
# Get the next session to process
queue_item = self._invoker.services.session_queue.dequeue()
self._queue_item = self._invoker.services.session_queue.dequeue()
if queue_item is None:
if self._queue_item is None:
# The queue was empty, wait for next polling interval or event to try again
self._invoker.services.logger.debug("Waiting for next polling interval or event")
poll_now_event.wait(self._polling_interval)
continue
self._session_worker_queue.put(queue_item)
self._invoker.services.logger.debug(f"Scheduling queue item {queue_item.item_id} to run")
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
cancel_event.clear()
# Run the graph
# self.session_runner.run(queue_item=self._queue_item)
self.session_runner.run(queue_item=self._queue_item)
except Exception:
except Exception as e:
error_type = e.__class__.__name__
error_message = str(e)
error_traceback = traceback.format_exc()
self._on_non_fatal_processor_error(
queue_item=self._queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
# Wait for next polling interval or event to try again
poll_now_event.wait(self._polling_interval)
continue
@@ -481,25 +466,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
finally:
stop_event.clear()
poll_now_event.clear()
self._queue_item = None
self._thread_semaphore.release()
def _process_next_session(self) -> None:
while True:
self._resume_event.wait()
queue_item = self._session_worker_queue.get()
if queue_item.status == "canceled":
continue
try:
self._active_queue_items.add(queue_item)
# reserve a GPU for this session - may block
with self._invoker.services.model_manager.load.ram_cache.reserve_execution_device():
# Run the session on the reserved GPU
self.session_runner.run(queue_item=queue_item)
except Exception:
continue
finally:
self._active_queue_items.remove(queue_item)
def _on_non_fatal_processor_error(
self,
queue_item: Optional[SessionQueueItem],

View File

@@ -236,9 +236,6 @@ class SessionQueueItemWithoutGraph(BaseModel):
}
)
def __hash__(self) -> int:
return self.item_id
class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
pass

View File

@@ -37,14 +37,10 @@ class SqliteSessionQueue(SessionQueueBase):
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
self._set_in_progress_to_canceled()
if self.__invoker.services.configuration.clear_queue_on_startup:
clear_result = self.clear(DEFAULT_QUEUE_ID)
if clear_result.deleted > 0:
self.__invoker.services.logger.info(f"Cleared all {clear_result.deleted} queue items")
else:
prune_result = self.prune(DEFAULT_QUEUE_ID)
if prune_result.deleted > 0:
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
prune_result = self.prune(DEFAULT_QUEUE_ID)
if prune_result.deleted > 0:
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()

View File

@@ -652,7 +652,7 @@ class Graph(BaseModel):
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
# Input type must be a list
if get_origin(input_field) is not list:
if get_origin(input_field) != list:
return False
# Validate that all outputs match the input type

View File

@@ -2,7 +2,6 @@ from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Optional, Union
import torch
from PIL.Image import Image
from pydantic.networks import AnyHttpUrl
from torch import Tensor
@@ -27,13 +26,11 @@ from invokeai.backend.model_manager.config import (
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
from invokeai.backend.util.devices import TorchDevice
if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
"""
The InvocationContext provides access to various services and data about the current invocation.
@@ -326,6 +323,7 @@ class ConditioningInterface(InvocationContextInterface):
Returns:
The loaded conditioning data.
"""
return self._services.conditioning.load(name)
@@ -559,28 +557,6 @@ class UtilInterface(InvocationContextInterface):
is_canceled=self.is_canceled,
)
def torch_device(self) -> torch.device:
"""
Return a torch device to use in the current invocation.
Returns:
A torch.device not currently in use by the system.
"""
ram_cache: "ModelCacheBase[AnyModel]" = self._services.model_manager.load.ram_cache
return ram_cache.get_execution_device()
def torch_dtype(self, device: Optional[torch.device] = None) -> torch.dtype:
"""
Return a precision type to use with the current invocation and torch device.
Args:
device: Optional device.
Returns:
A torch.dtype suited for the current device.
"""
return TorchDevice.choose_torch_dtype(device)
class InvocationContext:
"""Provides access to various services and data for the current invocation.

View File

@@ -289,7 +289,7 @@ def prepare_control_image(
width: int,
height: int,
num_channels: int = 3,
device: str = "cuda",
device: str | torch.device = "cuda",
dtype: torch.dtype = torch.float16,
control_mode: CONTROLNET_MODE_VALUES = "balanced",
resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
@@ -304,7 +304,7 @@ def prepare_control_image(
num_channels (int, optional): The target number of image channels. This is achieved by converting the input
image to RGB, then naively taking the first `num_channels` channels. The primary use case is converting a
RGB image to a single-channel grayscale image. Raises if `num_channels` cannot be achieved. Defaults to 3.
device (str, optional): The target device for the output image. Defaults to "cuda".
device (str | torch.device, optional): The target device for the output image. Defaults to "cuda".
dtype (torch.dtype, optional): The dtype for the output image. Defaults to torch.float16.
do_classifier_free_guidance (bool, optional): If True, repeat the output image along the batch dimension.
Defaults to True.

View File

@@ -25,7 +25,6 @@ from enum import Enum
from typing import Literal, Optional, Type, TypeAlias, Union
import torch
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict
@@ -38,7 +37,7 @@ from ..raw_model import RawModel
# ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
AnyModel = Union[ConfigMixin, ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]
class InvalidModelConfigException(Exception):
@@ -178,7 +177,6 @@ class ModelConfigBase(BaseModel):
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
"""Extend the pydantic schema from a json."""
schema["required"].extend(["key", "type", "format"])
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
@@ -445,7 +443,7 @@ class ModelConfigFactory(object):
model = dest_class.model_validate(model_data)
else:
# mypy doesn't typecheck TypeAdapters well?
model = AnyModelConfigValidator.validate_python(model_data)
model = AnyModelConfigValidator.validate_python(model_data) # type: ignore
assert model is not None
if key:
model.key = key

View File

@@ -65,7 +65,8 @@ class LoadedModelWithoutConfig:
def __enter__(self) -> AnyModel:
"""Context entry."""
return self._locker.lock()
self._locker.lock()
return self.model
def __exit__(self, *args: Any, **kwargs: Any) -> None:
"""Context exit."""

View File

@@ -8,10 +8,9 @@ model will be cleared and (re)loaded from disk when next needed.
"""
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, field
from logging import Logger
from typing import Dict, Generator, Generic, Optional, Set, TypeVar
from typing import Dict, Generic, Optional, TypeVar
import torch
@@ -52,13 +51,44 @@ class CacheRecord(Generic[T]):
Elements of the cache:
key: Unique key for each model, same as used in the models database.
model: Read-only copy of the model *without weights* residing in the "meta device"
model: Model in memory.
state_dict: A read-only copy of the model's state dict in RAM. It will be
used as a template for creating a copy in the VRAM.
size: Size of the model
loaded: True if the model's state dict is currently in VRAM
Before a model is executed, the state_dict template is copied into VRAM,
and then injected into the model. When the model is finished, the VRAM
copy of the state dict is deleted, and the RAM version is reinjected
into the model.
The state_dict should be treated as a read-only attribute. Do not attempt
to patch or otherwise modify it. Instead, patch the copy of the state_dict
after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
context manager call `model_on_device()`.
"""
key: str
size: int
model: T
device: torch.device
state_dict: Optional[Dict[str, torch.Tensor]]
size: int
loaded: bool = False
_locks: int = 0
def lock(self) -> None:
"""Lock this record."""
self._locks += 1
def unlock(self) -> None:
"""Unlock this record."""
self._locks -= 1
assert self._locks >= 0
@property
def locked(self) -> bool:
"""Return true if record is locked."""
return self._locks > 0
@dataclass
@@ -85,27 +115,14 @@ class ModelCacheBase(ABC, Generic[T]):
@property
@abstractmethod
def execution_devices(self) -> Set[torch.device]:
"""Return the set of available execution devices."""
def execution_device(self) -> torch.device:
"""Return the exection device (e.g. "cuda" for VRAM)."""
pass
@contextmanager
@property
@abstractmethod
def reserve_execution_device(self, timeout: int = 0) -> Generator[torch.device, None, None]:
"""Reserve an execution device (GPU) under the current thread id."""
pass
@abstractmethod
def get_execution_device(self) -> torch.device:
"""
Return an execution device that has been reserved for current thread.
Note that reservations are done using the current thread's TID.
It might be better to do this using the session ID, but that involves
too many detailed changes to model manager calls.
May generate a ValueError if no GPU has been reserved.
"""
def lazy_offloading(self) -> bool:
"""Return true if the cache is configured to lazily offload models in VRAM."""
pass
@property
@@ -114,6 +131,16 @@ class ModelCacheBase(ABC, Generic[T]):
"""Return true if the cache is configured to lazily offload models in VRAM."""
pass
@abstractmethod
def offload_unlocked_models(self, size_required: int) -> None:
"""Offload from VRAM any models not actively in use."""
pass
@abstractmethod
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
"""Move model into the indicated device."""
pass
@property
@abstractmethod
def stats(self) -> Optional[CacheStats]:
@@ -175,11 +202,6 @@ class ModelCacheBase(ABC, Generic[T]):
"""Return true if the model identified by key and submodel_type is in the cache."""
pass
@abstractmethod
def model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> AnyModel:
"""Move a copy of the model into the indicated device and return it."""
pass
@abstractmethod
def cache_size(self) -> int:
"""Get the total size of the models currently cached."""

View File

@@ -18,19 +18,17 @@ context. Use like this:
"""
import copy
import gc
import sys
import threading
from contextlib import contextmanager, suppress
import math
import time
from contextlib import suppress
from logging import Logger
from threading import BoundedSemaphore
from typing import Dict, Generator, List, Optional, Set
from typing import Dict, List, Optional
import torch
from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
@@ -41,7 +39,9 @@ from .model_locker import ModelLocker
# Maximum size of the cache, in gigs
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
DEFAULT_MAX_CACHE_SIZE = 6.0
DEFAULT_MAX_VRAM_CACHE_SIZE = 0.25
# amount of GPU memory to hold in reserve for use by generations (GB)
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
# actual size of a gig
GIG = 1073741824
@@ -57,8 +57,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
self,
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
execution_device: torch.device = torch.device("cuda"),
storage_device: torch.device = torch.device("cpu"),
precision: torch.dtype = torch.float16,
sequential_offload: bool = False,
lazy_offloading: bool = True,
sha_chunksize: int = 16777216,
log_memory_usage: bool = False,
logger: Optional[Logger] = None,
):
@@ -66,19 +70,23 @@ class ModelCache(ModelCacheBase[AnyModel]):
Initialize the model RAM cache.
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
:param execution_device: Torch device to load active model into [torch.device('cuda')]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param precision: Precision for loaded models [torch.float16]
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
behaviour.
"""
# allow lazy offloading only when vram cache enabled
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
self._precision: torch.dtype = precision
self._max_cache_size: float = max_cache_size
self._max_vram_cache_size: float = max_vram_cache_size
self._execution_device: torch.device = execution_device
self._storage_device: torch.device = storage_device
self._ram_lock = threading.Lock()
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
self._log_memory_usage = log_memory_usage
self._stats: Optional[CacheStats] = None
@@ -86,87 +94,25 @@ class ModelCache(ModelCacheBase[AnyModel]):
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
self._cache_stack: List[str] = []
# device to thread id
self._device_lock = threading.Lock()
self._execution_devices: Dict[torch.device, int] = {x: 0 for x in TorchDevice.execution_devices()}
self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
self.logger.info(
f"Using rendering device(s): {', '.join(sorted([str(x) for x in self._execution_devices.keys()]))}"
)
@property
def logger(self) -> Logger:
"""Return the logger used by the cache."""
return self._logger
@property
def lazy_offloading(self) -> bool:
"""Return true if the cache is configured to lazily offload models in VRAM."""
return self._lazy_offloading
@property
def storage_device(self) -> torch.device:
"""Return the storage device (e.g. "CPU" for RAM)."""
return self._storage_device
@property
def execution_devices(self) -> Set[torch.device]:
"""Return the set of available execution devices."""
devices = self._execution_devices.keys()
return set(devices)
def get_execution_device(self) -> torch.device:
"""
Return an execution device that has been reserved for current thread.
Note that reservations are done using the current thread's TID.
It would be better to do this using the session ID, but that involves
too many detailed changes to model manager calls.
May generate a ValueError if no GPU has been reserved.
"""
current_thread = threading.current_thread().ident
assert current_thread is not None
assigned = [x for x, tid in self._execution_devices.items() if current_thread == tid]
if not assigned:
raise ValueError(f"No GPU has been reserved for the use of thread {current_thread}")
return assigned[0]
@contextmanager
def reserve_execution_device(self, timeout: Optional[int] = None) -> Generator[torch.device, None, None]:
"""Reserve an execution device (e.g. GPU) for exclusive use by a generation thread.
Note that the reservation is done using the current thread's TID.
It would be better to do this using the session ID, but that involves
too many detailed changes to model manager calls.
"""
device = None
with self._device_lock:
current_thread = threading.current_thread().ident
assert current_thread is not None
# look for a device that has already been assigned to this thread
assigned = [x for x, tid in self._execution_devices.items() if current_thread == tid]
if assigned:
device = assigned[0]
# no device already assigned. Get one.
if device is None:
self._free_execution_device.acquire(timeout=timeout)
with self._device_lock:
free_device = [x for x, tid in self._execution_devices.items() if tid == 0]
self._execution_devices[free_device[0]] = current_thread
device = free_device[0]
# we are outside the lock region now
self.logger.info(f"{current_thread} Reserved torch device {device}")
# Tell TorchDevice to use this object to get the torch device.
TorchDevice.set_model_cache(self)
try:
yield device
finally:
with self._device_lock:
self.logger.info(f"{current_thread} Released torch device {device}")
self._execution_devices[device] = 0
self._free_execution_device.release()
torch.cuda.empty_cache()
def execution_device(self) -> torch.device:
"""Return the exection device (e.g. "cuda" for VRAM)."""
return self._execution_device
@property
def max_cache_size(self) -> float:
@@ -211,16 +157,16 @@ class ModelCache(ModelCacheBase[AnyModel]):
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Store model under key and optional submodel_type."""
with self._ram_lock:
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
return
size = calc_model_size_by_data(model)
self.make_room(size)
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
return
size = calc_model_size_by_data(model)
self.make_room(size)
cache_record = CacheRecord(key=key, model=model, size=size)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
def get(
self,
@@ -238,37 +184,36 @@ class ModelCache(ModelCacheBase[AnyModel]):
This may raise an IndexError if the model is not in the cache.
"""
with self._ram_lock:
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
if self.stats:
self.stats.hits += 1
else:
if self.stats:
self.stats.misses += 1
raise IndexError(f"The model with key {key} is not in the cache.")
cache_entry = self._cached_models[key]
# more stats
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
if self.stats:
stats_name = stats_name or key
self.stats.cache_size = int(self._max_cache_size * GIG)
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[stats_name] = max(
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
)
self.stats.hits += 1
else:
if self.stats:
self.stats.misses += 1
raise IndexError(f"The model with key {key} is not in the cache.")
# this moves the entry to the top (right end) of the stack
with suppress(Exception):
self._cache_stack.remove(key)
self._cache_stack.append(key)
return ModelLocker(
cache=self,
cache_entry=cache_entry,
cache_entry = self._cached_models[key]
# more stats
if self.stats:
stats_name = stats_name or key
self.stats.cache_size = int(self._max_cache_size * GIG)
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[stats_name] = max(
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
)
# this moves the entry to the top (right end) of the stack
with suppress(Exception):
self._cache_stack.remove(key)
self._cache_stack.append(key)
return ModelLocker(
cache=self,
cache_entry=cache_entry,
)
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
if self._log_memory_usage:
return MemorySnapshot.capture()
@@ -280,34 +225,127 @@ class ModelCache(ModelCacheBase[AnyModel]):
else:
return model_key
def model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> AnyModel:
"""Move a copy of the model into the indicated device and return it.
def offload_unlocked_models(self, size_required: int) -> None:
"""Move any unused models from VRAM."""
reserved = self._max_vram_cache_size * GIG
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved:
break
if not cache_entry.loaded:
continue
if not cache_entry.locked:
self.move_model_to_device(cache_entry, self.storage_device)
cache_entry.loaded = False
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
)
TorchDevice.empty_cache()
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
"""Move model into the indicated device.
:param cache_entry: The CacheRecord for the model
:param target_device: The torch.device to move the model into
May raise a torch.cuda.OutOfMemoryError
"""
with self._ram_lock:
self.logger.debug(f"Called to move {cache_entry.key} ({type(cache_entry.model)=}) to {target_device}")
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
source_device = cache_entry.device
# Some models don't have a state dictionary, in which case the
# stored model will still reside in CPU
if hasattr(cache_entry.model, "to"):
model_in_gpu = copy.deepcopy(cache_entry.model)
assert hasattr(model_in_gpu, "to")
model_in_gpu.to(target_device)
return model_in_gpu
else:
return cache_entry.model # what happens in CPU stays in CPU
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
# This would need to be revised to support multi-GPU.
if torch.device(source_device).type == torch.device(target_device).type:
return
# Some models don't have a `to` method, in which case they run in RAM/CPU.
if not hasattr(cache_entry.model, "to"):
return
# This roundabout method for moving the model around is done to avoid
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
# When moving to VRAM, we copy (not move) each element of the state dict from
# RAM to a new state dict in VRAM, and then inject it into the model.
# This operation is slightly faster than running `to()` on the whole model.
#
# When the model needs to be removed from VRAM we simply delete the copy
# of the state dict in VRAM, and reinject the state dict that is cached
# in RAM into the model. So this operation is very fast.
start_model_to_time = time.time()
snapshot_before = self._capture_memory_snapshot()
try:
if cache_entry.state_dict is not None:
assert hasattr(cache_entry.model, "load_state_dict")
if target_device == self.storage_device:
cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
else:
new_dict: Dict[str, torch.Tensor] = {}
for k, v in cache_entry.state_dict.items():
new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True)
cache_entry.model.load_state_dict(new_dict, assign=True)
cache_entry.model.to(target_device, non_blocking=True)
cache_entry.device = target_device
except Exception as e: # blow away cache entry
self._delete_cache_entry(cache_entry)
raise e
snapshot_after = self._capture_memory_snapshot()
end_model_to_time = time.time()
self.logger.debug(
f"Moved model '{cache_entry.key}' from {source_device} to"
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
if (
snapshot_before is not None
and snapshot_after is not None
and snapshot_before.vram is not None
and snapshot_after.vram is not None
):
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
# If the estimated model size does not match the change in VRAM, log a warning.
if not math.isclose(
vram_change,
cache_entry.size,
rel_tol=0.1,
abs_tol=10 * MB,
):
self.logger.debug(
f"Moving model '{cache_entry.key}' from {source_device} to"
f" {target_device} caused an unexpected change in VRAM usage. The model's"
" estimated size may be incorrect. Estimated model size:"
f" {(cache_entry.size/GIG):.3f} GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
def print_cuda_stats(self) -> None:
"""Log CUDA diagnostics."""
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
ram = "%4.2fG" % (self.cache_size() / GIG)
in_ram_models = len(self._cached_models)
self.logger.debug(f"Current VRAM/RAM usage for {in_ram_models} models: {vram}/{ram}")
in_ram_models = 0
in_vram_models = 0
locked_in_vram_models = 0
for cache_record in self._cached_models.values():
if hasattr(cache_record.model, "device"):
if cache_record.model.device == self.storage_device:
in_ram_models += 1
else:
in_vram_models += 1
if cache_record.locked:
locked_in_vram_models += 1
self.logger.debug(
f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
)
def make_room(self, size: int) -> None:
"""Make enough room in the cache to accommodate a new model of indicated size."""
@@ -330,14 +368,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
model_key = self._cache_stack[pos]
cache_entry = self._cached_models[model_key]
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
self.logger.debug(
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
)
refs = sys.getrefcount(cache_entry.model)
# Expected refs:
# 1 from cache_entry
# 1 from getrefcount function
# 1 from onnx runtime object
if refs <= (3 if "onnx" in model_key else 2):
if not cache_entry.locked:
self.logger.debug(
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
)
@@ -364,26 +400,10 @@ class ModelCache(ModelCacheBase[AnyModel]):
if self.stats:
self.stats.cleared = models_cleared
gc.collect()
TorchDevice.empty_cache()
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
def _check_free_vram(self, target_device: torch.device, needed_size: int) -> None:
if target_device.type != "cuda":
return
vram_device = ( # mem_get_info() needs an indexed device
target_device if target_device.index is not None else torch.device(str(target_device), index=0)
)
free_mem, _ = torch.cuda.mem_get_info(torch.device(vram_device))
if needed_size > free_mem:
raise torch.cuda.OutOfMemoryError
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
try:
self._cache_stack.remove(cache_entry.key)
del self._cached_models[cache_entry.key]
except ValueError:
pass
@staticmethod
def _device_name(device: torch.device) -> str:
return f"{device.type}:{device.index}"
self._cache_stack.remove(cache_entry.key)
del self._cached_models[cache_entry.key]
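One side of this hunk maintains a pool of execution devices and hands each worker thread a reserved GPU through a semaphore-guarded context manager. A small standalone sketch of that reservation pattern (hypothetical class, not the real ModelCache):

```python
# Standalone sketch of the per-thread execution-device reservation pattern that
# appears in the hunk above. Hypothetical class, not the real ModelCache.
import threading
from contextlib import contextmanager
from typing import Dict, Iterator, Optional

import torch


class DeviceReservations:
    def __init__(self, devices: list[torch.device]) -> None:
        self._lock = threading.Lock()
        # Map each device to the thread id currently holding it (0 = free).
        self._owners: Dict[torch.device, int] = {d: 0 for d in devices}
        self._free = threading.BoundedSemaphore(len(devices))

    @contextmanager
    def reserve(self, timeout: Optional[float] = None) -> Iterator[torch.device]:
        tid = threading.current_thread().ident
        assert tid is not None
        if not self._free.acquire(timeout=timeout):
            raise TimeoutError("No execution device became free in time")
        with self._lock:
            device = next(d for d, owner in self._owners.items() if owner == 0)
            self._owners[device] = tid
        try:
            yield device
        finally:
            with self._lock:
                self._owners[device] = 0
            self._free.release()


if __name__ == "__main__":
    pool = DeviceReservations([torch.device("cpu")])
    with pool.reserve() as dev:
        print(f"thread {threading.get_ident()} reserved {dev}")
```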

View File

@@ -10,8 +10,6 @@ from invokeai.backend.model_manager import AnyModel
from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase
MAX_GPU_WAIT = 600 # wait up to 10 minutes for a GPU to become free
class ModelLocker(ModelLockerBase):
"""Internal class that mediates movement in and out of GPU."""
@@ -31,29 +29,33 @@ class ModelLocker(ModelLockerBase):
"""Return the model without moving it around."""
return self._cache_entry.model
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
"""Return the state dict (if any) for the cached model."""
return self._cache_entry.state_dict
def lock(self) -> AnyModel:
"""Move the model into the execution device (GPU) and lock it."""
self._cache_entry.lock()
try:
device = self._cache.get_execution_device()
model_on_device = self._cache.model_to_device(self._cache_entry, device)
self._cache.logger.debug(f"Moved {self._cache_entry.key} to {device}")
if self._cache.lazy_offloading:
self._cache.offload_unlocked_models(self._cache_entry.size)
self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
self._cache_entry.loaded = True
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
self._cache.print_cuda_stats()
except torch.cuda.OutOfMemoryError:
self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
self._cache_entry.unlock()
raise
except Exception:
self._cache_entry.unlock()
raise
return model_on_device
return self.model
# It is no longer necessary to move the model out of VRAM
# because it will be removed when it goes out of scope
# in the caller's context
def unlock(self) -> None:
"""Call upon exit from context."""
self._cache.print_cuda_stats()
# This is no longer in use in MGPU.
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
"""Return the state dict (if any) for the cached model."""
return None
self._cache_entry.unlock()
if not self._cache.lazy_offloading:
self._cache.offload_unlocked_models(0)
self._cache.print_cuda_stats()
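The locker above is driven by a context manager: `LoadedModel.__enter__` calls `lock()` and `__exit__` calls `unlock()`. A toy sketch of that lock-counting contract with generic stand-in classes (nothing is actually moved between devices here):

```python
# Toy sketch of the lock()/unlock() contract that a LoadedModel-style context
# manager drives on the locker. Generic stand-ins, not the real cache objects.
from contextlib import contextmanager
from typing import Iterator

import torch


class TinyLocker:
    """Counts locks on a model; a real locker would also move it onto the GPU."""

    def __init__(self, model: torch.nn.Module) -> None:
        self.model = model
        self._locks = 0

    def lock(self) -> torch.nn.Module:
        self._locks += 1
        return self.model

    def unlock(self) -> None:
        self._locks -= 1
        assert self._locks >= 0


@contextmanager
def loaded_model(locker: TinyLocker) -> Iterator[torch.nn.Module]:
    model = locker.lock()
    try:
        yield model
    finally:
        locker.unlock()


if __name__ == "__main__":
    locker = TinyLocker(torch.nn.Linear(2, 2))
    with loaded_model(locker) as m:
        m(torch.zeros(1, 2))
```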

View File

@@ -22,7 +22,8 @@ from .generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
class VAELoader(GenericDiffusersLoader):
"""Class to load VAE models."""
@@ -39,8 +40,12 @@ class VAELoader(GenericDiffusersLoader):
return True
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
assert isinstance(config, CheckpointConfigBase)
config_file = self._app_config.legacy_conf_path / config.config_path
# TODO(MM2): check whether sdxl VAE models convert.
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"VAE conversion not supported for model type: {config.base}")
else:
assert isinstance(config, CheckpointConfigBase)
config_file = self._app_config.legacy_conf_path / config.config_path
if model_path.suffix == ".safetensors":
checkpoint = safetensors_load_file(model_path, device="cpu")

View File

@@ -451,16 +451,8 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
class VaeCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
# VAEs of all base types have the same structure, so we wimp out and
# guess using the name.
for regexp, basetype in [
(r"xl", BaseModelType.StableDiffusionXL),
(r"sd2", BaseModelType.StableDiffusion2),
(r"vae", BaseModelType.StableDiffusion1),
]:
if re.search(regexp, self.model_path.name, re.IGNORECASE):
return basetype
raise InvalidModelConfigException("Cannot determine base type")
# I can't find any standalone 2.X VAEs to test with!
return BaseModelType.StableDiffusion1
class LoRACheckpointProbe(CheckpointProbeBase):

View File

@@ -4,7 +4,6 @@
from __future__ import annotations
import pickle
import threading
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
@@ -35,8 +34,6 @@ with LoRAHelper.apply_lora_unet(unet, loras):
# TODO: rename smth like ModelPatcher and add TI method?
class ModelPatcher:
_thread_lock = threading.Lock()
@staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key
@@ -109,7 +106,7 @@ class ModelPatcher:
"""
original_weights = {}
try:
with torch.no_grad(), cls._thread_lock:
with torch.no_grad():
for lora, lora_weight in loras:
# assert lora.device.type == "cpu"
for layer_key, layer in lora.layers.items():
@@ -132,7 +129,9 @@ class ModelPatcher:
dtype = module.weight.dtype
if module_key not in original_weights:
if model_state_dict is None: # no CPU copy of the state dict was provided
if model_state_dict is not None: # we were provided with the CPU copy of the state dict
original_weights[module_key] = model_state_dict[module_key + ".weight"]
else:
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
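The LoRA-patching hunk above stores each touched module's original weight (taken from the provided CPU state dict when available, otherwise detached to a CPU copy) so that it can be restored once generation finishes. A minimal save/patch/restore sketch for a single linear layer, using a toy delta rather than a real LoRA layer:

```python
# Minimal save/patch/restore sketch mirroring the flow above: keep a CPU copy of
# the touched weight, add a scaled delta in place, restore afterwards.
from contextlib import contextmanager
from typing import Dict, Iterator

import torch


@contextmanager
def patch_linear(
    module: torch.nn.Linear,
    delta: torch.Tensor,
    lora_weight: float = 1.0,
    alpha: float = 8.0,
    rank: int = 8,
) -> Iterator[None]:
    original: Dict[str, torch.Tensor] = {}
    with torch.no_grad():
        # Save a CPU copy of the original weight so it can be restored later.
        original["weight"] = module.weight.detach().to(device="cpu", copy=True)
        layer_scale = alpha / rank if (alpha and rank) else 1.0
        module.weight += delta.to(module.weight) * (lora_weight * layer_scale)
    try:
        yield
    finally:
        with torch.no_grad():
            module.weight.copy_(original["weight"].to(module.weight.device))


if __name__ == "__main__":
    lin = torch.nn.Linear(4, 4)
    before = lin.weight.detach().clone()
    with patch_linear(lin, torch.randn(4, 4), lora_weight=0.5):
        pass
    assert torch.allclose(before, lin.weight)
```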

View File

@@ -10,12 +10,11 @@ import PIL.Image
import psutil
import torch
import torchvision.transforms as T
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
from diffusers.utils.import_utils import is_xformers_available
from pydantic import Field
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -26,6 +25,7 @@ from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion impor
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
from invokeai.backend.util.attention import auto_detect_slice_size
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
@dataclass
@@ -38,56 +38,18 @@ class PipelineIntermediateState:
predicted_original: Optional[torch.Tensor] = None
@dataclass
class AddsMaskLatents:
"""Add the channels required for inpainting model input.
The inpainting model takes the normal latent channels as input, _plus_ a one-channel mask
and the latent encoding of the base image.
This class assumes the same mask and base image should apply to all items in the batch.
"""
forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
mask: torch.Tensor
initial_image_latents: torch.Tensor
def __call__(
self,
latents: torch.Tensor,
t: torch.Tensor,
text_embeddings: torch.Tensor,
**kwargs,
) -> torch.Tensor:
model_input = self.add_mask_channels(latents)
return self.forward(model_input, t, text_embeddings, **kwargs)
def add_mask_channels(self, latents):
batch_size = latents.size(0)
# duplicate mask and latents for each batch
mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
image_latents = einops.repeat(self.initial_image_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
# add mask and image as additional channels
model_input, _ = einops.pack([latents, mask, image_latents], "b * h w")
return model_input
def are_like_tensors(a: torch.Tensor, b: object) -> bool:
return isinstance(b, torch.Tensor) and (a.size() == b.size())
@dataclass
class AddsMaskGuidance:
mask: torch.FloatTensor
mask_latents: torch.FloatTensor
mask: torch.Tensor
mask_latents: torch.Tensor
scheduler: SchedulerMixin
noise: torch.Tensor
gradient_mask: bool
is_gradient_mask: bool
def __call__(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
return self.apply_mask(latents, t)
def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
def apply_mask(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
batch_size = latents.size(0)
mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
if t.dim() == 0:
@@ -100,7 +62,7 @@ class AddsMaskGuidance:
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
if self.gradient_mask:
if self.is_gradient_mask:
threshold = (t.item()) / self.scheduler.config.num_train_timesteps
mask_bool = mask > threshold  # I don't know when mask got inverted, but it did
masked_input = torch.where(mask_bool, latents, mask_latents)
@@ -200,7 +162,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
safety_checker: Optional[StableDiffusionSafetyChecker],
feature_extractor: Optional[CLIPFeatureExtractor],
requires_safety_checker: bool = False,
control_model: ControlNetModel = None,
):
super().__init__(
vae=vae,
@@ -214,8 +175,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
self.control_model = control_model
self.use_ip_adapter = False
def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
"""
@@ -280,116 +239,131 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
raise Exception("Should not be called")
def add_inpainting_channels_to_latents(
self, latents: torch.Tensor, masked_ref_image_latents: torch.Tensor, inpainting_mask: torch.Tensor
):
"""Given a `latents` tensor, adds the mask and image latents channels required for inpainting.
Standard (non-inpainting) SD UNet models expect an input with shape (N, 4, H, W). Inpainting models expect an
input of shape (N, 9, H, W). The 9 channels are defined as follows:
- Channel 0-3: The latents being denoised.
- Channel 4: The mask indicating which parts of the image are being inpainted.
- Channel 5-8: The latent representation of the masked reference image being inpainted.
This function assumes that the same mask and base image should apply to all items in the batch.
"""
# Validate assumptions about input tensor shapes.
batch_size, latent_channels, latent_height, latent_width = latents.shape
assert latent_channels == 4
assert masked_ref_image_latents.shape == (1, 4, latent_height, latent_width)
assert inpainting_mask.shape == (1, 1, latent_height, latent_width)
# Repeat original_image_latents and inpainting_mask to match the latents batch size.
original_image_latents = masked_ref_image_latents.expand(batch_size, -1, -1, -1)
inpainting_mask = inpainting_mask.expand(batch_size, -1, -1, -1)
# Concatenate along the channel dimension.
return torch.cat([latents, inpainting_mask, original_image_latents], dim=1)
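The docstring above defines the 9-channel input expected by inpainting UNets. A small shape-level illustration of the same expand-and-concatenate step, using standalone tensors rather than the pipeline method:

```python
# Shape-level illustration of the 9-channel inpainting input described above,
# built from standalone tensors rather than the pipeline method.
import torch

batch, height, width = 2, 64, 64
latents = torch.randn(batch, 4, height, width)            # channels 0-3: latents being denoised
inpainting_mask = torch.zeros(1, 1, height, width)        # channel 4: inpainting mask
masked_image_latents = torch.randn(1, 4, height, width)   # channels 5-8: masked reference image latents

# Broadcast the single mask/reference pair across the batch, then concatenate
# along the channel dimension.
unet_input = torch.cat(
    [
        latents,
        inpainting_mask.expand(batch, -1, -1, -1),
        masked_image_latents.expand(batch, -1, -1, -1),
    ],
    dim=1,
)
assert unet_input.shape == (batch, 9, height, width)
```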
def latents_from_embeddings(
self,
latents: torch.Tensor,
num_inference_steps: int,
scheduler_step_kwargs: dict[str, Any],
conditioning_data: TextConditioningData,
*,
noise: Optional[torch.Tensor],
seed: int,
timesteps: torch.Tensor,
init_timestep: torch.Tensor,
additional_guidance: List[Callable] = None,
callback: Callable[[PipelineIntermediateState], None] = None,
control_data: List[ControlNetData] = None,
callback: Callable[[PipelineIntermediateState], None],
control_data: list[ControlNetData] | None = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
mask: Optional[torch.Tensor] = None,
masked_latents: Optional[torch.Tensor] = None,
gradient_mask: Optional[bool] = False,
seed: int,
is_gradient_mask: bool = False,
) -> torch.Tensor:
if init_timestep.shape[0] == 0:
return latents
"""Denoise the latents.
if additional_guidance is None:
additional_guidance = []
Args:
latents: The latent-space image to denoise.
- If we are inpainting, this is the initial latent image before noise has been added.
- If we are generating a new image, this should be initialized to zeros.
- In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner).
scheduler_step_kwargs: kwargs forwarded to the scheduler.step() method.
conditioning_data: Text conditioning data.
noise: Noise used for two purposes:
1. Used by the scheduler to noise the initial `latents` before denoising.
2. Used to noise the `masked_latents` when inpainting.
`noise` should be None if the `latents` tensor has already been noised.
seed: The seed used to generate the noise for the denoising process.
HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
same noise used earlier in the pipeline. This should really be handled in a clearer way.
timesteps: The timestep schedule for the denoising process.
init_timestep: The first timestep in the schedule.
TODO(ryand): I'm pretty sure this should always be the same as timesteps[0:1]. Confirm that that is the
case, and remove this duplicate param.
callback: A callback function that is called to report progress during the denoising process.
control_data: ControlNet data.
ip_adapter_data: IP-Adapter data.
t2i_adapter_data: T2I-Adapter data.
mask: A mask indicating which parts of the image are being inpainted. The presence of mask is used to
determine whether we are inpainting or not. `mask` should have the same spatial dimensions as the
`latents` tensor.
TODO(ryand): Check and document the expected dtype, range, and values used to represent
foreground/background.
masked_latents: A latent-space representation of a masked inpainting reference image. This tensor is only
used if an *inpainting* model is being used, i.e. this tensor is not used when inpainting with a standard
SD UNet model.
is_gradient_mask: A flag indicating whether `mask` is a gradient mask or not.
"""
# TODO(ryand): Figure out why this condition is necessary, and document it. My guess is that it's to handle
# cases where denoising_start and denoising_end are set such that there are no timesteps.
if init_timestep.shape[0] == 0 or timesteps.shape[0] == 0:
return latents
orig_latents = latents.clone()
batch_size = latents.shape[0]
batched_t = init_timestep.expand(batch_size)
batched_init_timestep = init_timestep.expand(batch_size)
# noise can be None if the latents have already been noised (e.g. when running the SDXL refiner).
if noise is not None:
# TODO(ryand): I'm pretty sure we should be applying init_noise_sigma in cases where we are starting with
# full noise. Investigate the history of why this got commented out.
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
latents = self.scheduler.add_noise(latents, noise, batched_t)
latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
if mask is not None:
if is_inpainting_model(self.unet):
if masked_latents is None:
raise Exception("Source image required for inpaint mask when inpaint model used!")
self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(
self._unet_forward, mask, masked_latents
)
else:
# if no noise provided, noisify unmasked area based on seed
if noise is None:
noise = torch.randn(
orig_latents.shape,
dtype=torch.float32,
device="cpu",
generator=torch.Generator(device="cpu").manual_seed(seed),
).to(device=orig_latents.device, dtype=orig_latents.dtype)
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
try:
latents = self.generate_latents_from_embeddings(
latents,
timesteps,
conditioning_data,
scheduler_step_kwargs=scheduler_step_kwargs,
additional_guidance=additional_guidance,
control_data=control_data,
ip_adapter_data=ip_adapter_data,
t2i_adapter_data=t2i_adapter_data,
callback=callback,
)
finally:
self.invokeai_diffuser.model_forward_callback = self._unet_forward
# restore unmasked part after the last step is completed
# in-process masking happens before each step
if mask is not None:
if gradient_mask:
latents = torch.where(mask > 0, latents, orig_latents)
else:
latents = torch.lerp(
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
)
return latents
def generate_latents_from_embeddings(
self,
latents: torch.Tensor,
timesteps,
conditioning_data: TextConditioningData,
scheduler_step_kwargs: dict[str, Any],
*,
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
callback: Callable[[PipelineIntermediateState], None] = None,
) -> torch.Tensor:
self._adjust_memory_efficient_attention(latents)
if additional_guidance is None:
additional_guidance = []
batch_size = latents.shape[0]
# Handle mask guidance (a.k.a. inpainting).
mask_guidance: AddsMaskGuidance | None = None
if mask is not None and not is_inpainting_model(self.unet):
# We are doing inpainting, since a mask is provided, but we are not using an inpainting model, so we will
# apply mask guidance to the latents.
if timesteps.shape[0] == 0:
return latents
# 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
# We still need noise for inpainting, so we generate it from the seed here.
if noise is None:
noise = torch.randn(
orig_latents.shape,
dtype=torch.float32,
device="cpu",
generator=torch.Generator(device="cpu").manual_seed(seed),
).to(device=orig_latents.device, dtype=orig_latents.dtype)
mask_guidance = AddsMaskGuidance(
mask=mask,
mask_latents=orig_latents,
scheduler=self.scheduler,
noise=noise,
is_gradient_mask=is_gradient_mask,
)
use_ip_adapter = ip_adapter_data is not None
use_regional_prompting = (
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
)
unet_attention_patcher = None
self.use_ip_adapter = use_ip_adapter
attn_ctx = nullcontext()
if use_ip_adapter or use_regional_prompting:
@@ -402,28 +376,28 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
with attn_ctx:
callback(
PipelineIntermediateState(
step=-1,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=self.scheduler.config.num_train_timesteps,
latents=latents,
)
)
# print("timesteps:", timesteps)
for i, t in enumerate(self.progress_bar(timesteps)):
batched_t = t.expand(batch_size)
step_output = self.step(
t=batched_t,
latents=latents,
conditioning_data=conditioning_data,
step_index=i,
total_step_count=len(timesteps),
scheduler_step_kwargs=scheduler_step_kwargs,
mask_guidance=mask_guidance,
mask=mask,
masked_latents=masked_latents,
control_data=control_data,
ip_adapter_data=ip_adapter_data,
t2i_adapter_data=t2i_adapter_data,
@@ -431,19 +405,28 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latents = step_output.prev_sample
predicted_original = getattr(step_output, "pred_original_sample", None)
callback(
PipelineIntermediateState(
step=i,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=int(t),
latents=latents,
predicted_original=predicted_original,
)
)
# restore unmasked part after the last step is completed
# in-process masking happens before each step
if mask is not None:
if is_gradient_mask:
latents = torch.where(mask > 0, latents, orig_latents)
else:
latents = torch.lerp(
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
)
return latents
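The compositing above restores the unmasked area from the original latents once denoising finishes. A self-contained sketch of the two blend modes, using hypothetical tensor names and the same torch.where / torch.lerp convention as the code above:

import torch

def blend_unmasked_region(
    orig_latents: torch.Tensor,
    denoised_latents: torch.Tensor,
    mask: torch.Tensor,
    is_gradient_mask: bool,
) -> torch.Tensor:
    # mask is assumed to share the latents' spatial dimensions.
    if is_gradient_mask:
        # Keep the denoised latents wherever the mask is positive.
        return torch.where(mask > 0, denoised_latents, orig_latents)
    # Linearly interpolate between the original and denoised latents.
    return torch.lerp(
        orig_latents,
        denoised_latents.to(dtype=orig_latents.dtype),
        mask.to(dtype=orig_latents.dtype),
    )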
@torch.inference_mode()
def step(
@@ -454,19 +437,20 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_index: int,
total_step_count: int,
scheduler_step_kwargs: dict[str, Any],
mask_guidance: AddsMaskGuidance | None,
mask: torch.Tensor | None,
masked_latents: torch.Tensor | None,
control_data: list[ControlNetData] | None = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0]
# Handle masked image-to-image (a.k.a inpainting).
if mask_guidance is not None:
# NOTE: This is intentionally done *before* self.scheduler.scale_model_input(...).
latents = mask_guidance(latents, timestep)
# TODO: should this scaling happen here or inside self._unet_forward?
# i.e. before or after passing it to InvokeAIDiffuserComponent
@@ -514,6 +498,31 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
down_intrablock_additional_residuals = accum_adapter_state
# Handle inpainting models.
if is_inpainting_model(self.unet):
# NOTE: These calls to add_inpainting_channels_to_latents(...) are intentionally done *after*
# self.scheduler.scale_model_input(...) so that the scaling is not applied to the mask or reference image
# latents.
if mask is not None:
if masked_latents is None:
raise ValueError("Source image required for inpaint mask when inpaint model used!")
latent_model_input = self.add_inpainting_channels_to_latents(
latents=latent_model_input, masked_ref_image_latents=masked_latents, inpainting_mask=mask
)
else:
# We are using an inpainting model, but no mask was provided, so we are not really "inpainting".
# We generate a global mask and empty original image so that we can still generate in this
# configuration.
# TODO(ryand): Should we just raise an exception here instead? I can't think of a use case for wanting
# to do this.
# TODO(ryand): If we decide that there is a good reason to keep this, then we should generate the 'fake'
# mask and original image once rather than on every denoising step.
latent_model_input = self.add_inpainting_channels_to_latents(
latents=latent_model_input,
masked_ref_image_latents=torch.zeros_like(latent_model_input[:1]),
inpainting_mask=torch.ones_like(latent_model_input[:1, :1]),
)
uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
sample=latent_model_input,
timestep=t, # TODO: debug how handled batched and non batched timesteps
@@ -542,17 +551,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# compute the previous noisy sample x_t -> x_t-1
step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)
# TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting
# again.
if mask_guidance is not None:
# Apply the mask to any "denoised" or "pred_original_sample" fields.
if hasattr(step_output, "denoised"):
step_output.pred_original_sample = mask_guidance(step_output.denoised, self.scheduler.timesteps[-1])
elif hasattr(step_output, "pred_original_sample"):
step_output.pred_original_sample = mask_guidance(
step_output.pred_original_sample, self.scheduler.timesteps[-1]
)
else:
step_output.pred_original_sample = mask_guidance(latents, self.scheduler.timesteps[-1])
return step_output
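add_inpainting_channels_to_latents(...) itself is not shown in this diff. The sketch below is an assumption about what such a helper does for a standard 9-channel SD inpainting UNet (4 noise-latent channels + 1 mask channel + 4 masked-reference-image channels); it is not the repository's implementation.

import torch

def add_inpainting_channels_to_latents_sketch(
    latents: torch.Tensor,                   # (B, 4, H, W) scaled noise latents
    masked_ref_image_latents: torch.Tensor,  # (1 or B, 4, H, W)
    inpainting_mask: torch.Tensor,           # (1 or B, 1, H, W)
) -> torch.Tensor:
    b = latents.shape[0]
    mask = inpainting_mask.expand(b, -1, -1, -1)
    ref = masked_ref_image_latents.expand(b, -1, -1, -1)
    # Assumed channel order: noise latents, mask, masked reference image latents.
    return torch.cat([latents, mask, ref], dim=1)  # (B, 9, H, W)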
@@ -575,17 +585,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
**kwargs,
):
"""predict the noise residual"""
if is_inpainting_model(self.unet) and latents.size(1) == 4:
# Pad out normal non-inpainting inputs for an inpainting model.
# FIXME: There are too many layers of functions and we have too many different ways of
# overriding things! This should get handled in a way more consistent with the other
# use of AddsMaskLatents.
latents = AddsMaskLatents(
self._unet_forward,
mask=torch.ones_like(latents[:1, :1], device=latents.device, dtype=latents.dtype),
initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype),
).add_mask_channels(latents)
# First three args should be positional, not keywords, so torch hooks can see them.
return self.unet(
latents,

View File

@@ -32,11 +32,8 @@ class SDXLConditioningInfo(BasicConditioningInfo):
def to(self, device, dtype=None):
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
return super().to(device=device, dtype=dtype)
@dataclass

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import math
import threading
from typing import Any, Callable, Optional, Union
import torch
@@ -294,31 +293,24 @@ class InvokeAIDiffuserComponent:
cross_attention_kwargs["regional_ip_data"] = regional_ip_data
added_cond_kwargs = None
if conditioning_data.is_sdxl():
added_cond_kwargs = {
"text_embeds": torch.cat(
[
# TODO: how to pad? just by zeros? or even truncate?
conditioning_data.uncond_text.pooled_embeds,
conditioning_data.cond_text.pooled_embeds,
],
dim=0,
),
"time_ids": torch.cat(
[
conditioning_data.uncond_text.add_time_ids,
conditioning_data.cond_text.add_time_ids,
],
dim=0,
),
}
if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
# TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
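For context, the added_cond_kwargs dict above stacks the unconditional and conditional pooled embeddings and micro-conditioning time IDs along the batch dimension for classifier-free guidance. A toy sketch with the usual SDXL shapes (1280-dim pooled embeddings, six time-ID values); all values are made up:

import torch

uncond_pooled = torch.zeros(1, 1280)
cond_pooled = torch.randn(1, 1280)
uncond_time_ids = torch.tensor([[1024.0, 1024.0, 0.0, 0.0, 1024.0, 1024.0]])
cond_time_ids = torch.tensor([[1024.0, 1024.0, 0.0, 0.0, 1024.0, 1024.0]])

added_cond_kwargs = {
    "text_embeds": torch.cat([uncond_pooled, cond_pooled], dim=0),   # (2, 1280)
    "time_ids": torch.cat([uncond_time_ids, cond_time_ids], dim=0),  # (2, 6)
}
# The dict is then passed to the SDXL UNet alongside the doubled latent batch.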

View File

@@ -0,0 +1,242 @@
from __future__ import annotations
import copy
from dataclasses import dataclass
from typing import Any, Callable, Optional
import torch
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
PipelineIntermediateState,
StableDiffusionGeneratorPipeline,
)
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
from invokeai.backend.tiles.utils import TBLR
# The maximum number of regions with compatible sizes that will be batched together.
# Larger batch sizes improve speed, but require more device memory.
MAX_REGION_BATCH_SIZE = 4
@dataclass
class MultiDiffusionRegionConditioning:
# Region coords in latent space.
region: TBLR
text_conditioning_data: TextConditioningData
control_data: list[ControlNetData]
class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
"""A Stable Diffusion pipeline that uses Multi-Diffusion (https://arxiv.org/pdf/2302.08113) for denoising."""
def _split_into_region_batches(
self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning]
) -> list[list[MultiDiffusionRegionConditioning]]:
# Group the regions by shape. Only regions with the same shape can be batched together.
conditioning_by_shape: dict[tuple[int, int], list[MultiDiffusionRegionConditioning]] = {}
for region_conditioning in multi_diffusion_conditioning:
shape_hw = (
region_conditioning.region.bottom - region_conditioning.region.top,
region_conditioning.region.right - region_conditioning.region.left,
)
# In Python, a tuple of hashable objects is itself hashable, so it can be used as a dict key.
if shape_hw not in conditioning_by_shape:
conditioning_by_shape[shape_hw] = []
conditioning_by_shape[shape_hw].append(region_conditioning)
# Split the regions into batches, respecting the MAX_REGION_BATCH_SIZE constraint.
region_conditioning_batches = []
for region_conditioning_batch in conditioning_by_shape.values():
for i in range(0, len(region_conditioning_batch), MAX_REGION_BATCH_SIZE):
region_conditioning_batches.append(region_conditioning_batch[i : i + MAX_REGION_BATCH_SIZE])
return region_conditioning_batches
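A standalone toy illustration of the grouping rule implemented above, using plain (height, width) tuples in place of real region conditioning objects:

# Five 64x64 regions and one 64x32 region, with MAX_REGION_BATCH_SIZE = 4.
MAX_REGION_BATCH_SIZE = 4
region_shapes = [(64, 64)] * 5 + [(64, 32)]

by_shape: dict[tuple[int, int], list[int]] = {}
for idx, shape in enumerate(region_shapes):
    by_shape.setdefault(shape, []).append(idx)

batches = [
    group[i : i + MAX_REGION_BATCH_SIZE]
    for group in by_shape.values()
    for i in range(0, len(group), MAX_REGION_BATCH_SIZE)
]
print(batches)  # [[0, 1, 2, 3], [4], [5]]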
def _check_regional_prompting(self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning]):
"""Check the input conditioning and confirm that regional prompting is not used."""
for region_conditioning in multi_diffusion_conditioning:
if (
region_conditioning.text_conditioning_data.cond_regions is not None
or region_conditioning.text_conditioning_data.uncond_regions is not None
):
raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.")
def multi_diffusion_denoise(
self,
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning],
latents: torch.Tensor,
scheduler_step_kwargs: dict[str, Any],
noise: Optional[torch.Tensor],
timesteps: torch.Tensor,
init_timestep: torch.Tensor,
callback: Callable[[PipelineIntermediateState], None],
) -> torch.Tensor:
self._check_regional_prompting(multi_diffusion_conditioning)
# TODO(ryand): Figure out why this condition is necessary, and document it. My guess is that it's to handle
# cases where denoising_start and denoising_end are set such that there are no timesteps.
if init_timestep.shape[0] == 0 or timesteps.shape[0] == 0:
return latents
batch_size, _, latent_height, latent_width = latents.shape
batched_init_timestep = init_timestep.expand(batch_size)
# noise can be None if the latents have already been noised (e.g. when running the SDXL refiner).
if noise is not None:
# TODO(ryand): I'm pretty sure we should be applying init_noise_sigma in cases where we are starting with
# full noise. Investigate the history of why this got commented out.
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
# TODO(ryand): Look into the implications of passing in latents here that are larger than they will be after
# cropping into regions.
self._adjust_memory_efficient_attention(latents)
# Populate a weighted mask that will be used to combine the results from each region after every step.
# For now, we assume that each region has the same weight (1.0).
region_weight_mask = torch.zeros(
(1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype
)
for region_conditioning in multi_diffusion_conditioning:
region = region_conditioning.region
region_weight_mask[:, :, region.top : region.bottom, region.left : region.right] += 1.0
# Group the region conditioning into batches for faster processing.
# region_conditioning_batches[b][r] is the r'th region in the b'th batch.
region_conditioning_batches = self._split_into_region_batches(multi_diffusion_conditioning)
# Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since
# we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a
# separate scheduler state for each region batch.
region_batch_schedulers: list[SchedulerMixin] = [
copy.deepcopy(self.scheduler) for _ in region_conditioning_batches
]
callback(
PipelineIntermediateState(
step=-1,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=self.scheduler.config.num_train_timesteps,
latents=latents,
)
)
for i, t in enumerate(self.progress_bar(timesteps)):
batched_t = t.expand(batch_size)
merged_latents = torch.zeros_like(latents)
merged_pred_original: torch.Tensor | None = None
for region_batch_idx, region_conditioning_batch in enumerate(region_conditioning_batches):
# Switch to the scheduler for the region batch.
self.scheduler = region_batch_schedulers[region_batch_idx]
# TODO(ryand): This logic has not yet been tested with input latents with a batch_size > 1.
# Prepare the latents for the region batch.
batch_latents = torch.cat(
[
latents[
:,
:,
region_conditioning.region.top : region_conditioning.region.bottom,
region_conditioning.region.left : region_conditioning.region.right,
]
for region_conditioning in region_conditioning_batch
],
)
# TODO(ryand): Do we have to repeat the text_conditioning_data to match the batch size? Or does step()
# handle broadcasting properly?
# TODO(ryand): Resume here!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
# Run the denoising step on the region.
step_output = self.step(
t=batched_t,
latents=batch_latents,
conditioning_data=region_conditioning.text_conditioning_data,
step_index=i,
total_step_count=len(timesteps),
scheduler_step_kwargs=scheduler_step_kwargs,
mask_guidance=None,
mask=None,
masked_latents=None,
control_data=region_conditioning.control_data,
)
# Run a denoising step on the region.
# step_output = self._region_step(
# region_conditioning=region_conditioning,
# t=batched_t,
# latents=latents,
# step_index=i,
# total_step_count=len(timesteps),
# scheduler_step_kwargs=scheduler_step_kwargs,
# )
# Store the results from the region.
region = region_conditioning.region
merged_latents[:, :, region.top : region.bottom, region.left : region.right] += step_output.prev_sample
pred_orig_sample = getattr(step_output, "pred_original_sample", None)
if pred_orig_sample is not None:
# If one region has pred_original_sample, then we can assume that all regions will have it, because
# they all use the same scheduler.
if merged_pred_original is None:
merged_pred_original = torch.zeros_like(latents)
merged_pred_original[:, :, region.top : region.bottom, region.left : region.right] += (
pred_orig_sample
)
# Normalize the merged results.
latents = torch.where(region_weight_mask > 0, merged_latents / region_weight_mask, merged_latents)
predicted_original = None
if merged_pred_original is not None:
predicted_original = torch.where(
region_weight_mask > 0, merged_pred_original / region_weight_mask, merged_pred_original
)
callback(
PipelineIntermediateState(
step=i,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=int(t),
latents=latents,
predicted_original=predicted_original,
)
)
return latents
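For reference, the core Multi-Diffusion merge performed in the loop above reduces to: accumulate each region's denoised tile into a full-size canvas, track how many regions touched each location, and normalize. A self-contained sketch with made-up region coordinates:

import torch

latent_h, latent_w = 96, 96
regions = [(0, 64, 0, 64), (32, 96, 32, 96)]  # (top, bottom, left, right), overlapping
region_outputs = [torch.randn(1, 4, 64, 64) for _ in regions]

merged = torch.zeros(1, 4, latent_h, latent_w)
weight = torch.zeros(1, 1, latent_h, latent_w)
for (top, bottom, left, right), out in zip(regions, region_outputs):
    merged[:, :, top:bottom, left:right] += out
    weight[:, :, top:bottom, left:right] += 1.0

# Normalize only where at least one region contributed (same torch.where pattern as above).
merged = torch.where(weight > 0, merged / weight, merged)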
@torch.inference_mode()
def _region_batch_step(
self,
region_conditioning: MultiDiffusionRegionConditioning,
t: torch.Tensor,
latents: torch.Tensor,
step_index: int,
total_step_count: int,
scheduler_step_kwargs: dict[str, Any],
):
# Crop the inputs to the region.
region_latents = latents[
:,
:,
region_conditioning.region.top : region_conditioning.region.bottom,
region_conditioning.region.left : region_conditioning.region.right,
]
# Run the denoising step on the region.
return self.step(
t=t,
latents=region_latents,
conditioning_data=region_conditioning.text_conditioning_data,
step_index=step_index,
total_step_count=total_step_count,
scheduler_step_kwargs=scheduler_step_kwargs,
mask_guidance=None,
mask=None,
masked_latents=None,
control_data=region_conditioning.control_data,
)
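The deep-copied region_batch_schedulers above are needed because many diffusers schedulers mutate internal state (e.g. a step counter) on every step() call, and here step() runs several times per timestep. A minimal sketch of that pattern, assuming a stock diffusers scheduler:

import copy
from diffusers import EulerDiscreteScheduler

base_scheduler = EulerDiscreteScheduler()
base_scheduler.set_timesteps(30)

num_region_batches = 3
# One independent scheduler per region batch so that repeated step() calls at the
# same timestep do not advance a shared step counter more than once.
region_batch_schedulers = [copy.deepcopy(base_scheduler) for _ in range(num_region_batches)]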

View File

@@ -1,16 +1,10 @@
"""Torch Device class provides torch device selection services."""
from typing import TYPE_CHECKING, Dict, Literal, Optional, Set, Union
from typing import Dict, Literal, Optional, Union
import torch
from deprecated import deprecated
from invokeai.app.services.config.config_default import get_config
if TYPE_CHECKING:
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
# legacy APIs
TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
CPU_DEVICE = torch.device("cpu")
@@ -48,23 +42,9 @@ PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {v: k for k, v in NA
class TorchDevice:
"""Abstraction layer for torch devices."""
_model_cache: Optional["ModelCacheBase[AnyModel]"] = None
@classmethod
def set_model_cache(cls, cache: "ModelCacheBase[AnyModel]"):
"""Set the current model cache."""
cls._model_cache = cache
@classmethod
def choose_torch_device(cls) -> torch.device:
"""Return the torch.device to use for accelerated inference."""
if cls._model_cache:
return cls._model_cache.get_execution_device()
else:
return cls._choose_device()
@classmethod
def _choose_device(cls) -> torch.device:
app_config = get_config()
if app_config.device != "auto":
device = torch.device(app_config.device)
@@ -76,19 +56,11 @@ class TorchDevice:
device = CPU_DEVICE
return cls.normalize(device)
@classmethod
def execution_devices(cls) -> Set[torch.device]:
"""Return a list of torch.devices that can be used for accelerated inference."""
app_config = get_config()
if app_config.devices is None:
return cls._lookup_execution_devices()
return {torch.device(x) for x in app_config.devices}
@classmethod
def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
"""Return the precision to use for accelerated inference."""
device = device or cls.choose_torch_device()
config = get_config()
device = device or cls._choose_device()
if device.type == "cuda" and torch.cuda.is_available():
device_name = torch.cuda.get_device_name(device)
if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
@@ -136,13 +108,3 @@ class TorchDevice:
@classmethod
def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
return NAME_TO_PRECISION[precision_name]
@classmethod
def _lookup_execution_devices(cls) -> Set[torch.device]:
if torch.cuda.is_available():
devices = {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
elif torch.backends.mps.is_available():
devices = {torch.device("mps")}
else:
devices = {torch.device("cpu")}
return devices
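After this change, callers pick a device and dtype directly from the remaining classmethods. A minimal usage sketch:

import torch
from invokeai.backend.util.devices import TorchDevice

device = TorchDevice.choose_torch_device()
dtype = TorchDevice.choose_torch_dtype(device)
tensor = torch.zeros(1, device=device, dtype=dtype)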

View File

@@ -1,54 +0,0 @@
#!/bin/env python
from argparse import ArgumentParser, Namespace
from pathlib import Path
from invokeai.app.services.config import InvokeAIAppConfig, get_config
from invokeai.app.services.download import DownloadQueueService
from invokeai.app.services.model_install import ModelInstallService
from invokeai.app.services.model_records import ModelRecordServiceSQL
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.util.logging import InvokeAILogger
def get_args() -> Namespace:
parser = ArgumentParser(description="Update models database from yaml file")
parser.add_argument("--root", type=Path, required=False, default=None)
parser.add_argument("--yaml_file", type=Path, required=False, default=None)
return parser.parse_args()
def populate_config() -> InvokeAIAppConfig:
args = get_args()
config = get_config()
if args.root:
config._root = args.root
if args.yaml_file:
config.legacy_models_yaml_path = args.yaml_file
else:
config.legacy_models_yaml_path = config.root_path / "configs/models.yaml"
return config
def initialize_installer(config: InvokeAIAppConfig) -> ModelInstallService:
logger = InvokeAILogger.get_logger(config=config)
db = SqliteDatabase(config.db_path, logger)
record_store = ModelRecordServiceSQL(db)
queue = DownloadQueueService()
queue.start()
installer = ModelInstallService(app_config=config, record_store=record_store, download_queue=queue)
return installer
def main() -> None:
config = populate_config()
installer = initialize_installer(config)
installer._migrate_yaml(rename_yaml=False, overwrite_db=True)
print("\n<INSTALLED MODELS>")
print("\t".join(["key", "name", "type", "path"]))
for model in installer.record_store.all_models():
print("\t".join([model.key, model.name, model.type, (config.models_path / model.path).as_posix()]))
if __name__ == "__main__":
main()

View File

@@ -14,14 +14,13 @@ def test_loading(mm2_model_manager: ModelManagerServiceBase, embedding_file: Pat
matches = store.search_by_attr(model_name="test_embedding")
assert len(matches) == 0
key = mm2_model_manager.install.register_path(embedding_file)
loaded_model = mm2_model_manager.load.load_model(store.get_model(key))
assert loaded_model is not None
assert loaded_model.config.key == key
with loaded_model as model:
assert isinstance(model, TextualInversionModelRaw)
config = mm2_model_manager.store.get_model(key)
loaded_model_2 = mm2_model_manager.load.load_model(config)
assert loaded_model.config.key == loaded_model_2.config.key

View File

@@ -89,10 +89,11 @@ def mm2_download_queue(mm2_session: Session) -> DownloadQueueServiceBase:
@pytest.fixture
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
ram_cache = ModelCache(
logger=InvokeAILogger.get_logger(),
max_cache_size=mm2_app_config.ram,
max_vram_cache_size=mm2_app_config.vram,
)
convert_cache = ModelConvertCache(mm2_app_config.convert_cache_path)
return ModelLoadService(

View File

@@ -8,9 +8,7 @@ import pytest
import torch
from invokeai.app.services.config import get_config
from invokeai.backend.model_manager.load import ModelCache
from invokeai.backend.util.devices import TorchDevice, choose_precision, choose_torch_device, torch_dtype
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
devices = ["cpu", "cuda:0", "cuda:1", "mps"]
device_types_cpu = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float32)]
@@ -22,7 +20,6 @@ device_types_mps = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", t
def test_device_choice(device_name):
config = get_config()
config.device = device_name
TorchDevice.set_model_cache(None) # disable dynamic selection of GPU device
torch_device = TorchDevice.choose_torch_device()
assert torch_device == torch.device(device_name)
@@ -133,32 +130,3 @@ def test_legacy_precision_name():
assert "float16" == choose_precision(torch.device("cuda"))
assert "float16" == choose_precision(torch.device("mps"))
assert "float32" == choose_precision(torch.device("cpu"))
def test_multi_device_support_1():
config = get_config()
config.devices = ["cuda:0", "cuda:1"]
assert TorchDevice.execution_devices() == {torch.device("cuda:0"), torch.device("cuda:1")}
def test_multi_device_support_2():
config = get_config()
config.devices = None
with (
patch("torch.cuda.device_count", return_value=3),
patch("torch.cuda.is_available", return_value=True),
):
assert TorchDevice.execution_devices() == {
torch.device("cuda:0"),
torch.device("cuda:1"),
torch.device("cuda:2"),
}
def test_multi_device_support_3():
config = get_config()
config.devices = ["cuda:0", "cuda:1"]
cache = ModelCache()
with cache.reserve_execution_device() as gpu:
assert gpu in [torch.device(x) for x in config.devices]
assert TorchDevice.choose_torch_device() == gpu

View File

@@ -17,6 +17,7 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.images.images_default import ImageService
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.invoker import Invoker
from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager.model_manager_fixtures import * # noqa: F403
@@ -48,13 +49,13 @@ def mock_services() -> InvocationServices:
model_manager=None, # type: ignore
download_queue=None, # type: ignore
names=None, # type: ignore
performance_statistics=InvocationStatsService(),
session_processor=None, # type: ignore
session_queue=None, # type: ignore
urls=None, # type: ignore
workflow_records=None, # type: ignore
tensors=None, # type: ignore
conditioning=None, # type: ignore
)

View File

@@ -92,6 +92,7 @@ def test_migrate_v3_config_from_file(tmp_path: Path, patch_rootdir: None):
assert config.host == "192.168.1.1"
assert config.port == 8080
assert config.ram == 100
assert config.vram == 50
assert config.legacy_models_yaml_path == Path("/custom/models.yaml")
# This should be stripped out
assert not hasattr(config, "esrgan")