mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-17 07:47:56 -05:00
Compare commits
51 Commits
lstein/fea
...
ryan/multi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6bcf48aa37 | ||
|
|
b1bb1511fe | ||
|
|
99046a8145 | ||
|
|
72be7e71e3 | ||
|
|
35adaf1c17 | ||
|
|
865c2335de | ||
|
|
49ca42f84a | ||
|
|
493fcd8660 | ||
|
|
20322d781e | ||
|
|
889d13e02a | ||
|
|
6ccd2a867b | ||
|
|
5861fa1719 | ||
|
|
dfd4beb62b | ||
|
|
83df0c0df5 | ||
|
|
c58c4069a7 | ||
|
|
3937fffa94 | ||
|
|
bbf5f67691 | ||
|
|
2f5c147b84 | ||
|
|
bd2839b748 | ||
|
|
4f70dd7ce1 | ||
|
|
066672fbfd | ||
|
|
abefaee4d1 | ||
|
|
3254ba5904 | ||
|
|
73a8c55852 | ||
|
|
f82af7c22d | ||
|
|
3aef717ef4 | ||
|
|
c2cf1137e9 | ||
|
|
803a24bc0a | ||
|
|
7d24ad8ccd | ||
|
|
cb389063b2 | ||
|
|
81b8a69e1a | ||
|
|
7ee5db87ad | ||
|
|
66cf2c59bd | ||
|
|
3bad1367e9 | ||
|
|
867a7642a6 | ||
|
|
d9d1c8f9cb | ||
|
|
e03eb7fb45 | ||
|
|
85db33bc7e | ||
|
|
93e3a2b504 | ||
|
|
6a7a26f1bf | ||
|
|
08ca03ef9f | ||
|
|
ccf90b6bd6 | ||
|
|
753239b48d | ||
|
|
65fa4664c9 | ||
|
|
297570ded3 | ||
|
|
680fdcf293 | ||
|
|
5ff91f2c44 | ||
|
|
69aa7057e7 | ||
|
|
d3932f40de | ||
|
|
ee74cd7fab | ||
|
|
bda25b40c9 |
@@ -1328,7 +1328,7 @@ from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegist
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
ram_cache = ModelCache(
|
||||
max_cache_size=config.ram_cache_size, logger=logger
|
||||
max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
|
||||
)
|
||||
convert_cache = ModelConvertCache(
|
||||
cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
|
||||
|
||||
@@ -103,7 +103,6 @@ class CompelInvocation(BaseInvocation):
|
||||
textual_inversion_manager=ti_manager,
|
||||
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
|
||||
truncate_long_prompts=False,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
)
|
||||
|
||||
conjunction = Compel.parse_prompt_string(self.prompt)
|
||||
@@ -118,7 +117,6 @@ class CompelInvocation(BaseInvocation):
|
||||
conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)])
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
@@ -205,7 +203,6 @@ class SDXLPromptInvocationBase:
|
||||
truncate_long_prompts=False, # TODO:
|
||||
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
|
||||
requires_pooled=get_pooled,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
)
|
||||
|
||||
conjunction = Compel.parse_prompt_string(prompt)
|
||||
@@ -316,6 +313,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
|
||||
return ConditioningOutput(
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
import copy
|
||||
import inspect
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
@@ -56,6 +55,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.hotfixes import ControlNetModel
|
||||
from invokeai.backend.util.mask import to_standard_float_mask
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
@@ -66,6 +66,9 @@ def get_scheduler(
|
||||
scheduler_name: str,
|
||||
seed: int,
|
||||
) -> Scheduler:
|
||||
"""Load a scheduler and apply some scheduler-specific overrides."""
|
||||
# TODO(ryand): Silently falling back to ddim seems like a bad idea. Look into why this was added and remove if
|
||||
# possible.
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||
orig_scheduler_info = context.models.load(scheduler_info)
|
||||
with orig_scheduler_info as orig_scheduler:
|
||||
@@ -183,8 +186,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
return v
|
||||
|
||||
@staticmethod
|
||||
def _get_text_embeddings_and_masks(
|
||||
self,
|
||||
cond_list: list[ConditioningField],
|
||||
context: InvocationContext,
|
||||
device: torch.device,
|
||||
@@ -194,8 +197,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
|
||||
text_embeddings_masks: list[Optional[torch.Tensor]] = []
|
||||
for cond in cond_list:
|
||||
cond_data = copy.deepcopy(context.conditioning.load(cond.conditioning_name))
|
||||
cond_data = context.conditioning.load(cond.conditioning_name)
|
||||
text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
|
||||
|
||||
mask = cond.mask
|
||||
if mask is not None:
|
||||
mask = context.tensors.load(mask.tensor_name)
|
||||
@@ -203,8 +207,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
return text_embeddings, text_embeddings_masks
|
||||
|
||||
@staticmethod
|
||||
def _preprocess_regional_prompt_mask(
|
||||
self, mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
|
||||
mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
"""Preprocess a regional prompt mask to match the target height and width.
|
||||
If mask is None, returns a mask of all ones with the target height and width.
|
||||
@@ -226,11 +231,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
|
||||
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
|
||||
resized_mask = tf(mask)
|
||||
assert isinstance(resized_mask, torch.Tensor)
|
||||
return resized_mask
|
||||
|
||||
@staticmethod
|
||||
def _concat_regional_text_embeddings(
|
||||
self,
|
||||
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
|
||||
masks: Optional[list[Optional[torch.Tensor]]],
|
||||
latent_height: int,
|
||||
@@ -280,7 +284,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
)
|
||||
processed_masks.append(
|
||||
self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
|
||||
DenoiseLatentsInvocation._preprocess_regional_prompt_mask(
|
||||
mask, latent_height, latent_width, dtype=dtype
|
||||
)
|
||||
)
|
||||
|
||||
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
|
||||
@@ -302,36 +308,41 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
return BasicConditioningInfo(embeds=text_embedding), regions
|
||||
|
||||
@staticmethod
|
||||
def get_conditioning_data(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
|
||||
negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
|
||||
unet: UNet2DConditionModel,
|
||||
latent_height: int,
|
||||
latent_width: int,
|
||||
cfg_scale: float | list[float],
|
||||
steps: int,
|
||||
cfg_rescale_multiplier: float,
|
||||
) -> TextConditioningData:
|
||||
# Normalize self.positive_conditioning and self.negative_conditioning to lists.
|
||||
cond_list = self.positive_conditioning
|
||||
# Normalize positive_conditioning_field and negative_conditioning_field to lists.
|
||||
cond_list = positive_conditioning_field
|
||||
if not isinstance(cond_list, list):
|
||||
cond_list = [cond_list]
|
||||
uncond_list = self.negative_conditioning
|
||||
uncond_list = negative_conditioning_field
|
||||
if not isinstance(uncond_list, list):
|
||||
uncond_list = [uncond_list]
|
||||
|
||||
cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
|
||||
cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
|
||||
cond_list, context, unet.device, unet.dtype
|
||||
)
|
||||
uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
|
||||
uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
|
||||
uncond_list, context, unet.device, unet.dtype
|
||||
)
|
||||
|
||||
cond_text_embedding, cond_regions = self._concat_regional_text_embeddings(
|
||||
cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
|
||||
text_conditionings=cond_text_embeddings,
|
||||
masks=cond_text_embedding_masks,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
dtype=unet.dtype,
|
||||
)
|
||||
uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings(
|
||||
uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
|
||||
text_conditionings=uncond_text_embeddings,
|
||||
masks=uncond_text_embedding_masks,
|
||||
latent_height=latent_height,
|
||||
@@ -339,23 +350,21 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
dtype=unet.dtype,
|
||||
)
|
||||
|
||||
if isinstance(self.cfg_scale, list):
|
||||
assert (
|
||||
len(self.cfg_scale) == self.steps
|
||||
), "cfg_scale (list) must have the same length as the number of steps"
|
||||
if isinstance(cfg_scale, list):
|
||||
assert len(cfg_scale) == steps, "cfg_scale (list) must have the same length as the number of steps"
|
||||
|
||||
conditioning_data = TextConditioningData(
|
||||
uncond_text=uncond_text_embedding,
|
||||
cond_text=cond_text_embedding,
|
||||
uncond_regions=uncond_regions,
|
||||
cond_regions=cond_regions,
|
||||
guidance_scale=self.cfg_scale,
|
||||
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||
guidance_scale=cfg_scale,
|
||||
guidance_rescale_multiplier=cfg_rescale_multiplier,
|
||||
)
|
||||
return conditioning_data
|
||||
|
||||
@staticmethod
|
||||
def create_pipeline(
|
||||
self,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Scheduler,
|
||||
) -> StableDiffusionGeneratorPipeline:
|
||||
@@ -378,38 +387,38 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def prep_control_data(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
control_input: Optional[Union[ControlField, List[ControlField]]],
|
||||
control_input: ControlField | list[ControlField] | None,
|
||||
latents_shape: List[int],
|
||||
exit_stack: ExitStack,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
) -> Optional[List[ControlNetData]]:
|
||||
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
|
||||
control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR
|
||||
control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR
|
||||
if control_input is None:
|
||||
control_list = None
|
||||
elif isinstance(control_input, list) and len(control_input) == 0:
|
||||
control_list = None
|
||||
elif isinstance(control_input, ControlField):
|
||||
) -> list[ControlNetData] | None:
|
||||
# Normalize control_input to a list.
|
||||
control_list: list[ControlField]
|
||||
if isinstance(control_input, ControlField):
|
||||
control_list = [control_input]
|
||||
elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
|
||||
elif isinstance(control_input, list):
|
||||
control_list = control_input
|
||||
elif control_input is None:
|
||||
control_list = []
|
||||
else:
|
||||
control_list = None
|
||||
if control_list is None:
|
||||
return None
|
||||
# After above handling, any control that is not None should now be of type list[ControlField].
|
||||
raise ValueError(f"Unexpected control_input type: {type(control_input)}")
|
||||
|
||||
# FIXME: add checks to skip entry if model or image is None
|
||||
# and if weight is None, populate with default 1.0?
|
||||
controlnet_data = []
|
||||
if len(control_list) == 0:
|
||||
return None
|
||||
|
||||
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
|
||||
_, _, latent_height, latent_width = latents_shape
|
||||
control_height_resize = latent_height * LATENT_SCALE_FACTOR
|
||||
control_width_resize = latent_width * LATENT_SCALE_FACTOR
|
||||
|
||||
controlnet_data: list[ControlNetData] = []
|
||||
for control_info in control_list:
|
||||
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
|
||||
assert isinstance(control_model, ControlNetModel)
|
||||
|
||||
# control_models.append(control_model)
|
||||
control_image_field = control_info.image
|
||||
input_image = context.images.get_pil(control_image_field.image_name)
|
||||
# self.image.image_type, self.image.image_name
|
||||
@@ -430,7 +439,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
resize_mode=control_info.resize_mode,
|
||||
)
|
||||
control_item = ControlNetData(
|
||||
model=control_model, # model object
|
||||
model=control_model,
|
||||
image_tensor=control_image,
|
||||
weight=control_info.control_weight,
|
||||
begin_step_percent=control_info.begin_step_percent,
|
||||
@@ -584,15 +593,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
# original idea by https://github.com/AmericanPresidentJimmyCarter
|
||||
# TODO: research more for second order schedulers timesteps
|
||||
@staticmethod
|
||||
def init_scheduler(
|
||||
self,
|
||||
scheduler: Union[Scheduler, ConfigMixin],
|
||||
device: torch.device,
|
||||
steps: int,
|
||||
denoising_start: float,
|
||||
denoising_end: float,
|
||||
seed: int,
|
||||
) -> Tuple[int, List[int], int, Dict[str, Any]]:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
|
||||
assert isinstance(scheduler, ConfigMixin)
|
||||
if scheduler.config.get("cpu_only", False):
|
||||
scheduler.set_timesteps(steps, device="cpu")
|
||||
@@ -618,7 +627,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
init_timestep = timesteps[t_start_idx : t_start_idx + 1]
|
||||
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
|
||||
num_inference_steps = len(timesteps) // scheduler.order
|
||||
|
||||
scheduler_step_kwargs: Dict[str, Any] = {}
|
||||
scheduler_step_signature = inspect.signature(scheduler.step)
|
||||
@@ -640,7 +648,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
if isinstance(scheduler, TCDScheduler):
|
||||
scheduler_step_kwargs.update({"eta": 1.0})
|
||||
|
||||
return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs
|
||||
return timesteps, init_timestep, scheduler_step_kwargs
|
||||
|
||||
def prep_inpaint_mask(
|
||||
self, context: InvocationContext, latents: torch.Tensor
|
||||
@@ -657,31 +665,52 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
return 1 - mask, masked_latents, self.denoise_mask.gradient
|
||||
|
||||
@torch.no_grad()
|
||||
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
seed = None
|
||||
@staticmethod
|
||||
def prepare_noise_and_latents(
|
||||
context: InvocationContext, noise_field: LatentsField | None, latents_field: LatentsField | None
|
||||
) -> Tuple[int, torch.Tensor | None, torch.Tensor]:
|
||||
"""Depending on the workflow, we expect different combinations of noise and latents to be provided. This
|
||||
function handles preparing these values accordingly.
|
||||
|
||||
Expected workflows:
|
||||
- Text-to-Image Denoising: `noise` is provided, `latents` is not. `latents` is initialized to zeros.
|
||||
- Image-to-Image Denoising: `noise` and `latents` are both provided.
|
||||
- Text-to-Image SDXL Refiner Denoising: `latents` is provided, `noise` is not.
|
||||
- Image-to-Image SDXL Refiner Denoising: `latents` is provided, `noise` is not.
|
||||
|
||||
NOTE(ryand): I wrote this docstring, but I am not the original author of this code. There may be other workflows
|
||||
I haven't considered.
|
||||
"""
|
||||
noise = None
|
||||
if self.noise is not None:
|
||||
noise = context.tensors.load(self.noise.latents_name)
|
||||
seed = self.noise.seed
|
||||
|
||||
if self.latents is not None:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
if seed is None:
|
||||
seed = self.latents.seed
|
||||
|
||||
if noise is not None and noise.shape[1:] != latents.shape[1:]:
|
||||
raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
|
||||
if noise_field is not None:
|
||||
noise = context.tensors.load(noise_field.latents_name)
|
||||
|
||||
if latents_field is not None:
|
||||
latents = context.tensors.load(latents_field.latents_name)
|
||||
elif noise is not None:
|
||||
latents = torch.zeros_like(noise)
|
||||
else:
|
||||
raise Exception("'latents' or 'noise' must be provided!")
|
||||
raise ValueError("'latents' or 'noise' must be provided!")
|
||||
|
||||
if seed is None:
|
||||
if noise is not None and noise.shape[1:] != latents.shape[1:]:
|
||||
raise ValueError(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
|
||||
|
||||
# The seed comes from (in order of priority): the noise field, the latents field, or 0.
|
||||
seed = 0
|
||||
if noise_field is not None and noise_field.seed is not None:
|
||||
seed = noise_field.seed
|
||||
elif latents_field is not None and latents_field.seed is not None:
|
||||
seed = latents_field.seed
|
||||
else:
|
||||
seed = 0
|
||||
|
||||
return seed, noise, latents
|
||||
|
||||
@torch.no_grad()
|
||||
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||
|
||||
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
|
||||
|
||||
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
||||
@@ -755,7 +784,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
conditioning_data = self.get_conditioning_data(
|
||||
context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
|
||||
context=context,
|
||||
positive_conditioning_field=self.positive_conditioning,
|
||||
negative_conditioning_field=self.negative_conditioning,
|
||||
unet=unet,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
cfg_scale=self.cfg_scale,
|
||||
steps=self.steps,
|
||||
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||
)
|
||||
|
||||
controlnet_data = self.prep_control_data(
|
||||
@@ -777,7 +814,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
dtype=unet.dtype,
|
||||
)
|
||||
|
||||
num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
|
||||
timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
|
||||
scheduler,
|
||||
device=unet.device,
|
||||
steps=self.steps,
|
||||
@@ -794,8 +831,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
seed=seed,
|
||||
mask=mask,
|
||||
masked_latents=masked_latents,
|
||||
gradient_mask=gradient_mask,
|
||||
num_inference_steps=num_inference_steps,
|
||||
is_gradient_mask=gradient_mask,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=controlnet_data,
|
||||
|
||||
@@ -8,7 +8,7 @@ from diffusers.models.attention_processor import (
|
||||
)
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import DEFAULT_PRECISION
|
||||
@@ -23,6 +23,7 @@ from invokeai.app.invocations.fields import (
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.stable_diffusion import set_seamless
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@@ -48,16 +49,20 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
||||
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
|
||||
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
assert isinstance(vae, torch.nn.Module)
|
||||
@staticmethod
|
||||
def vae_decode(
|
||||
context: InvocationContext,
|
||||
vae_info: LoadedModel,
|
||||
seamless_axes: list[str],
|
||||
latents: torch.Tensor,
|
||||
use_fp32: bool,
|
||||
use_tiling: bool,
|
||||
) -> Image.Image:
|
||||
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
|
||||
with set_seamless(vae_info.model, seamless_axes), vae_info as vae:
|
||||
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
|
||||
latents = latents.to(vae.device)
|
||||
if self.fp32:
|
||||
if use_fp32:
|
||||
vae.to(dtype=torch.float32)
|
||||
|
||||
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
|
||||
@@ -82,7 +87,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
vae.to(dtype=torch.float16)
|
||||
latents = latents.half()
|
||||
|
||||
if self.tiled or context.config.get().force_tiled_decode:
|
||||
if use_tiling or context.config.get().force_tiled_decode:
|
||||
vae.enable_tiling()
|
||||
else:
|
||||
vae.disable_tiling()
|
||||
@@ -102,6 +107,21 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
return image
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
|
||||
image = self.vae_decode(
|
||||
context=context,
|
||||
vae_info=vae_info,
|
||||
seamless_axes=self.vae.seamless_axes,
|
||||
latents=latents,
|
||||
use_fp32=self.fp32,
|
||||
use_tiling=self.tiled,
|
||||
)
|
||||
image_dto = context.images.save(image=image)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
@@ -0,0 +1,268 @@
|
||||
import copy
|
||||
from contextlib import ExitStack
|
||||
from typing import Iterator, Tuple
|
||||
|
||||
import torch
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
from pydantic import field_validator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
|
||||
from invokeai.app.invocations.fields import (
|
||||
ConditioningField,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
UIType,
|
||||
)
|
||||
from invokeai.app.invocations.model import UNetField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData
|
||||
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
|
||||
MultiDiffusionPipeline,
|
||||
MultiDiffusionRegionConditioning,
|
||||
)
|
||||
from invokeai.backend.tiles.tiles import (
|
||||
calc_tiles_min_overlap,
|
||||
)
|
||||
from invokeai.backend.tiles.utils import TBLR
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
def crop_controlnet_data(control_data: ControlNetData, latent_region: TBLR) -> ControlNetData:
|
||||
"""Crop a ControlNetData object to a region."""
|
||||
# Create a shallow copy of the control_data object.
|
||||
control_data_copy = copy.copy(control_data)
|
||||
# The ControlNet reference image is the only attribute that needs to be cropped.
|
||||
control_data_copy.image_tensor = control_data.image_tensor[
|
||||
:,
|
||||
:,
|
||||
latent_region.top * LATENT_SCALE_FACTOR : latent_region.bottom * LATENT_SCALE_FACTOR,
|
||||
latent_region.left * LATENT_SCALE_FACTOR : latent_region.right * LATENT_SCALE_FACTOR,
|
||||
]
|
||||
return control_data_copy
|
||||
|
||||
|
||||
@invocation(
|
||||
"tiled_multi_diffusion_denoise_latents",
|
||||
title="Tiled Multi-Diffusion Denoise Latents",
|
||||
tags=["upscale", "denoise"],
|
||||
category="latents",
|
||||
# TODO(ryand): Reset to 1.0.0 right before release.
|
||||
version="1.0.0",
|
||||
)
|
||||
class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
||||
"""Tiled Multi-Diffusion denoising.
|
||||
|
||||
This node handles automatically tiling the input image. Future iterations of
|
||||
this node should allow the user to specify custom regions with different parameters for each region to harness the
|
||||
full power of Multi-Diffusion.
|
||||
|
||||
This node has a similar interface to the `DenoiseLatents` node, but it has a reduced feature set (no IP-Adapter,
|
||||
T2I-Adapter, masking, etc.).
|
||||
"""
|
||||
|
||||
positive_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||
)
|
||||
negative_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.negative_cond, input=Input.Connection
|
||||
)
|
||||
noise: LatentsField | None = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.noise,
|
||||
input=Input.Connection,
|
||||
)
|
||||
latents: LatentsField | None = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
# TODO(ryand): Add multiple-of validation.
|
||||
# TODO(ryand): Smaller defaults might make more sense.
|
||||
tile_height: int = InputField(default=112, gt=0, description="Height of the tiles in latent space.")
|
||||
tile_width: int = InputField(default=112, gt=0, description="Width of the tiles in latent space.")
|
||||
tile_min_overlap: int = InputField(
|
||||
default=16,
|
||||
gt=0,
|
||||
description="The minimum overlap between adjacent tiles in latent space. The actual overlap may be larger than "
|
||||
"this to evenly cover the entire image.",
|
||||
)
|
||||
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
|
||||
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
|
||||
# TODO(ryand): The default here should probably be 0.0.
|
||||
denoising_start: float = InputField(
|
||||
default=0.65,
|
||||
ge=0,
|
||||
le=1,
|
||||
description=FieldDescriptions.denoising_start,
|
||||
)
|
||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||
scheduler: SCHEDULER_NAME_VALUES = InputField(
|
||||
default="euler",
|
||||
description=FieldDescriptions.scheduler,
|
||||
ui_type=UIType.Scheduler,
|
||||
)
|
||||
unet: UNetField = InputField(
|
||||
description=FieldDescriptions.unet,
|
||||
input=Input.Connection,
|
||||
title="UNet",
|
||||
)
|
||||
cfg_rescale_multiplier: float = InputField(
|
||||
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
|
||||
)
|
||||
control: ControlField | list[ControlField] | None = InputField(
|
||||
default=None,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
@field_validator("cfg_scale")
|
||||
def ge_one(cls, v: list[float] | float) -> list[float] | float:
|
||||
"""Validate that all cfg_scale values are >= 1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < 1:
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
else:
|
||||
if v < 1:
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
return v
|
||||
|
||||
@staticmethod
|
||||
def create_pipeline(
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: SchedulerMixin,
|
||||
) -> MultiDiffusionPipeline:
|
||||
# TODO(ryand): Get rid of this FakeVae hack.
|
||||
class FakeVae:
|
||||
class FakeVaeConfig:
|
||||
def __init__(self) -> None:
|
||||
self.block_out_channels = [0]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.config = FakeVae.FakeVaeConfig()
|
||||
|
||||
return MultiDiffusionPipeline(
|
||||
vae=FakeVae(), # TODO: oh...
|
||||
text_encoder=None,
|
||||
tokenizer=None,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
|
||||
# Calculate the tile locations to cover the latent-space image.
|
||||
# TODO(ryand): Add constraints on the tile params. Is there a multiple-of constraint?
|
||||
tiles = calc_tiles_min_overlap(
|
||||
image_height=latent_height,
|
||||
image_width=latent_width,
|
||||
tile_height=self.tile_height,
|
||||
tile_width=self.tile_width,
|
||||
min_overlap=self.tile_min_overlap,
|
||||
)
|
||||
|
||||
# Prepare an iterator that yields the UNet's LoRA models and their weights.
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
|
||||
# Load the UNet model.
|
||||
unet_info = context.models.load(self.unet.unet)
|
||||
|
||||
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
|
||||
assert isinstance(unet, UNet2DConditionModel)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
if noise is not None:
|
||||
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
seed=seed,
|
||||
)
|
||||
pipeline = self.create_pipeline(unet=unet, scheduler=scheduler)
|
||||
|
||||
# Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
|
||||
conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
|
||||
context=context,
|
||||
positive_conditioning_field=self.positive_conditioning,
|
||||
negative_conditioning_field=self.negative_conditioning,
|
||||
unet=unet,
|
||||
latent_height=self.tile_height,
|
||||
latent_width=self.tile_width,
|
||||
cfg_scale=self.cfg_scale,
|
||||
steps=self.steps,
|
||||
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||
)
|
||||
|
||||
controlnet_data = DenoiseLatentsInvocation.prep_control_data(
|
||||
context=context,
|
||||
control_input=self.control,
|
||||
latents_shape=list(latents.shape),
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,
|
||||
exit_stack=exit_stack,
|
||||
)
|
||||
|
||||
# Split the controlnet_data into tiles.
|
||||
# controlnet_data_tiles[t][c] is the c'th control data for the t'th tile.
|
||||
controlnet_data_tiles: list[list[ControlNetData]] = []
|
||||
for tile in tiles:
|
||||
tile_controlnet_data = [crop_controlnet_data(cn, tile.coords) for cn in controlnet_data or []]
|
||||
controlnet_data_tiles.append(tile_controlnet_data)
|
||||
|
||||
# Prepare the MultiDiffusionRegionConditioning list.
|
||||
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning] = []
|
||||
for tile, tile_controlnet_data in zip(tiles, controlnet_data_tiles, strict=True):
|
||||
multi_diffusion_conditioning.append(
|
||||
MultiDiffusionRegionConditioning(
|
||||
region=tile.coords,
|
||||
text_conditioning_data=conditioning_data,
|
||||
control_data=tile_controlnet_data,
|
||||
)
|
||||
)
|
||||
|
||||
timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
|
||||
scheduler,
|
||||
device=unet.device,
|
||||
steps=self.steps,
|
||||
denoising_start=self.denoising_start,
|
||||
denoising_end=self.denoising_end,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
# Run Multi-Diffusion denoising.
|
||||
result_latents = pipeline.multi_diffusion_denoise(
|
||||
multi_diffusion_conditioning=multi_diffusion_conditioning,
|
||||
latents=latents,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
noise=noise,
|
||||
timesteps=timesteps,
|
||||
init_timestep=init_timestep,
|
||||
# TODO(ryand): Add proper callback.
|
||||
callback=lambda x: None,
|
||||
)
|
||||
|
||||
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
|
||||
result_latents = result_latents.to("cpu")
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
name = context.tensors.save(tensor=result_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
|
||||
380
invokeai/app/invocations/tiled_stable_diffusion_refine.py
Normal file
380
invokeai/app/invocations/tiled_stable_diffusion_refine.py
Normal file
@@ -0,0 +1,380 @@
|
||||
from contextlib import ExitStack
|
||||
from typing import Iterator, Tuple
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from PIL import Image
|
||||
from pydantic import field_validator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
|
||||
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
|
||||
from invokeai.app.invocations.fields import (
|
||||
ConditioningField,
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
UIType,
|
||||
)
|
||||
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
|
||||
from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation
|
||||
from invokeai.app.invocations.model import ModelIdentifierField, UNetField, VAEField
|
||||
from invokeai.app.invocations.noise import get_noise
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending
|
||||
from invokeai.backend.tiles.utils import Tile
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.hotfixes import ControlNetModel
|
||||
|
||||
|
||||
@invocation(
|
||||
"tiled_stable_diffusion_refine",
|
||||
title="Tiled Stable Diffusion Refine",
|
||||
tags=["upscale", "denoise"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class TiledStableDiffusionRefineInvocation(BaseInvocation):
|
||||
"""A tiled Stable Diffusion pipeline for refining high resolution images. This invocation is intended to be used to
|
||||
refine an image after upscaling i.e. it is the second step in a typical "tiled upscaling" workflow.
|
||||
"""
|
||||
|
||||
image: ImageField = InputField(description="Image to be refined.")
|
||||
|
||||
positive_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||
)
|
||||
negative_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.negative_cond, input=Input.Connection
|
||||
)
|
||||
# TODO(ryand): Add multiple-of validation.
|
||||
tile_height: int = InputField(default=512, gt=0, description="Height of the tiles.")
|
||||
tile_width: int = InputField(default=512, gt=0, description="Width of the tiles.")
|
||||
tile_overlap: int = InputField(
|
||||
default=16,
|
||||
gt=0,
|
||||
description="Target overlap between adjacent tiles (the last row/column may overlap more than this).",
|
||||
)
|
||||
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
|
||||
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
|
||||
denoising_start: float = InputField(
|
||||
default=0.65,
|
||||
ge=0,
|
||||
le=1,
|
||||
description=FieldDescriptions.denoising_start,
|
||||
)
|
||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||
scheduler: SCHEDULER_NAME_VALUES = InputField(
|
||||
default="euler",
|
||||
description=FieldDescriptions.scheduler,
|
||||
ui_type=UIType.Scheduler,
|
||||
)
|
||||
unet: UNetField = InputField(
|
||||
description=FieldDescriptions.unet,
|
||||
input=Input.Connection,
|
||||
title="UNet",
|
||||
)
|
||||
cfg_rescale_multiplier: float = InputField(
|
||||
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
|
||||
)
|
||||
vae: VAEField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
vae_fp32: bool = InputField(
|
||||
default=DEFAULT_PRECISION == torch.float32, description="Whether to use float32 precision when running the VAE."
|
||||
)
|
||||
# HACK(ryand): We probably want to allow the user to control all of the parameters in ControlField. But, we akwardly
|
||||
# don't want to use the image field. Figure out how best to handle this.
|
||||
# TODO(ryand): Currently, there is no ControlNet preprocessor applied to the tile images. In other words, we pretty
|
||||
# much assume that it is a tile ControlNet. We need to decide how we want to handle this. E.g. find a way to support
|
||||
# CN preprocessors, raise a clear warning when a non-tile CN model is selected, hardcode the supported CN models,
|
||||
# etc.
|
||||
control_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
|
||||
)
|
||||
control_weight: float = InputField(default=0.6)
|
||||
|
||||
@field_validator("cfg_scale")
|
||||
def ge_one(cls, v: list[float] | float) -> list[float] | float:
|
||||
"""Validate that all cfg_scale values are >= 1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < 1:
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
else:
|
||||
if v < 1:
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
return v
|
||||
|
||||
@staticmethod
|
||||
def crop_latents_to_tile(latents: torch.Tensor, image_tile: Tile) -> torch.Tensor:
|
||||
"""Crop the latent-space tensor to the area corresponding to the image-space tile.
|
||||
The tile coordinates must be divisible by the LATENT_SCALE_FACTOR.
|
||||
"""
|
||||
for coord in [image_tile.coords.top, image_tile.coords.left, image_tile.coords.right, image_tile.coords.bottom]:
|
||||
if coord % LATENT_SCALE_FACTOR != 0:
|
||||
raise ValueError(
|
||||
f"The tile coordinates must all be divisible by the latent scale factor"
|
||||
f" ({LATENT_SCALE_FACTOR}). {image_tile.coords=}."
|
||||
)
|
||||
assert latents.dim() == 4 # We expect: (batch_size, channels, height, width).
|
||||
|
||||
top = image_tile.coords.top // LATENT_SCALE_FACTOR
|
||||
left = image_tile.coords.left // LATENT_SCALE_FACTOR
|
||||
bottom = image_tile.coords.bottom // LATENT_SCALE_FACTOR
|
||||
right = image_tile.coords.right // LATENT_SCALE_FACTOR
|
||||
return latents[..., top:bottom, left:right]
|
||||
|
||||
def run_controlnet(
|
||||
self,
|
||||
image: Image.Image,
|
||||
controlnet_model: ControlNetModel,
|
||||
weight: float,
|
||||
do_classifier_free_guidance: bool,
|
||||
width: int,
|
||||
height: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
control_mode: CONTROLNET_MODE_VALUES = "balanced",
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
|
||||
) -> ControlNetData:
|
||||
control_image = prepare_control_image(
|
||||
image=image,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
width=width,
|
||||
height=height,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
control_mode=control_mode,
|
||||
resize_mode=resize_mode,
|
||||
)
|
||||
return ControlNetData(
|
||||
model=controlnet_model,
|
||||
image_tensor=control_image,
|
||||
weight=weight,
|
||||
begin_step_percent=0.0,
|
||||
end_step_percent=1.0,
|
||||
control_mode=control_mode,
|
||||
# Any resizing needed should currently be happening in prepare_control_image(), but adding resize_mode to
|
||||
# ControlNetData in case needed in the future.
|
||||
resize_mode=resize_mode,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# TODO(ryand): Expose the seed parameter.
|
||||
seed = 0
|
||||
|
||||
# Load the input image.
|
||||
input_image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
# Calculate the tile locations to cover the image.
|
||||
# We have selected this tiling strategy to make it easy to achieve tile coords that are multiples of 8. This
|
||||
# facilitates conversions between image space and latent space.
|
||||
# TODO(ryand): Expose these tiling parameters. (Keep in mind the multiple-of constraints on these params.)
|
||||
tiles = calc_tiles_with_overlap(
|
||||
image_height=input_image.height,
|
||||
image_width=input_image.width,
|
||||
tile_height=self.tile_height,
|
||||
tile_width=self.tile_width,
|
||||
overlap=self.tile_overlap,
|
||||
)
|
||||
|
||||
# Convert the input image to a torch.Tensor.
|
||||
input_image_torch = image_resized_to_grid_as_tensor(input_image.convert("RGB"), multiple_of=LATENT_SCALE_FACTOR)
|
||||
input_image_torch = input_image_torch.unsqueeze(0) # Add a batch dimension.
|
||||
# Validate our assumptions about the shape of input_image_torch.
|
||||
assert input_image_torch.dim() == 4 # We expect: (batch_size, channels, height, width).
|
||||
assert input_image_torch.shape[:2] == (1, 3)
|
||||
|
||||
# Split the input image into tiles in torch.Tensor format.
|
||||
image_tiles_torch: list[torch.Tensor] = []
|
||||
for tile in tiles:
|
||||
image_tile = input_image_torch[
|
||||
:,
|
||||
:,
|
||||
tile.coords.top : tile.coords.bottom,
|
||||
tile.coords.left : tile.coords.right,
|
||||
]
|
||||
image_tiles_torch.append(image_tile)
|
||||
|
||||
# Split the input image into tiles in numpy format.
|
||||
# TODO(ryand): We currently maintain both np.ndarray and torch.Tensor tiles. Ideally, all operations should work
|
||||
# with torch.Tensor tiles.
|
||||
input_image_np = np.array(input_image)
|
||||
image_tiles_np: list[npt.NDArray[np.uint8]] = []
|
||||
for tile in tiles:
|
||||
image_tile_np = input_image_np[
|
||||
tile.coords.top : tile.coords.bottom,
|
||||
tile.coords.left : tile.coords.right,
|
||||
:,
|
||||
]
|
||||
image_tiles_np.append(image_tile_np)
|
||||
|
||||
# VAE-encode each image tile independently.
|
||||
# TODO(ryand): Is there any advantage to VAE-encoding the entire image before splitting it into tiles? What
|
||||
# about for decoding?
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
latent_tiles: list[torch.Tensor] = []
|
||||
for image_tile_torch in image_tiles_torch:
|
||||
latent_tiles.append(
|
||||
ImageToLatentsInvocation.vae_encode(
|
||||
vae_info=vae_info, upcast=self.vae_fp32, tiled=False, image_tensor=image_tile_torch
|
||||
)
|
||||
)
|
||||
|
||||
# Generate noise with dimensions corresponding to the full image in latent space.
|
||||
# It is important that the noise tensor is generated at the full image dimension and then tiled, rather than
|
||||
# generating for each tile independently. This ensures that overlapping regions between tiles use the same
|
||||
# noise.
|
||||
assert input_image_torch.shape[2] % LATENT_SCALE_FACTOR == 0
|
||||
assert input_image_torch.shape[3] % LATENT_SCALE_FACTOR == 0
|
||||
global_noise = get_noise(
|
||||
width=input_image_torch.shape[3],
|
||||
height=input_image_torch.shape[2],
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
seed=seed,
|
||||
downsampling_factor=LATENT_SCALE_FACTOR,
|
||||
use_cpu=True,
|
||||
)
|
||||
|
||||
# Crop the global noise into tiles.
|
||||
noise_tiles = [self.crop_latents_to_tile(latents=global_noise, image_tile=t) for t in tiles]
|
||||
|
||||
# Prepare an iterator that yields the UNet's LoRA models and their weights.
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
|
||||
# Load the UNet model.
|
||||
unet_info = context.models.load(self.unet.unet)
|
||||
|
||||
refined_latent_tiles: list[torch.Tensor] = []
|
||||
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
|
||||
assert isinstance(unet, UNet2DConditionModel)
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
seed=seed,
|
||||
)
|
||||
pipeline = DenoiseLatentsInvocation.create_pipeline(unet=unet, scheduler=scheduler)
|
||||
|
||||
# Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
|
||||
# Assume that all tiles have the same shape.
|
||||
_, _, latent_height, latent_width = latent_tiles[0].shape
|
||||
conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
|
||||
context=context,
|
||||
positive_conditioning_field=self.positive_conditioning,
|
||||
negative_conditioning_field=self.negative_conditioning,
|
||||
unet=unet,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
cfg_scale=self.cfg_scale,
|
||||
steps=self.steps,
|
||||
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||
)
|
||||
|
||||
# Load the ControlNet model.
|
||||
# TODO(ryand): Support multiple ControlNet models.
|
||||
controlnet_model = exit_stack.enter_context(context.models.load(self.control_model))
|
||||
assert isinstance(controlnet_model, ControlNetModel)
|
||||
|
||||
# Denoise (i.e. "refine") each tile independently.
|
||||
for image_tile_np, latent_tile, noise_tile in zip(image_tiles_np, latent_tiles, noise_tiles, strict=True):
|
||||
assert latent_tile.shape == noise_tile.shape
|
||||
|
||||
# Prepare a PIL Image for ControlNet processing.
|
||||
# TODO(ryand): This is a bit awkward that we have to prepare both torch.Tensor and PIL.Image versions of
|
||||
# the tiles. Ideally, the ControlNet code should be able to work with Tensors.
|
||||
image_tile_pil = Image.fromarray(image_tile_np)
|
||||
|
||||
# Run the ControlNet on the image tile.
|
||||
height, width, _ = image_tile_np.shape
|
||||
# The height and width must be evenly divisible by LATENT_SCALE_FACTOR. This is enforced earlier, but we
|
||||
# validate this assumption here.
|
||||
assert height % LATENT_SCALE_FACTOR == 0
|
||||
assert width % LATENT_SCALE_FACTOR == 0
|
||||
controlnet_data = self.run_controlnet(
|
||||
image=image_tile_pil,
|
||||
controlnet_model=controlnet_model,
|
||||
weight=self.control_weight,
|
||||
do_classifier_free_guidance=True,
|
||||
width=width,
|
||||
height=height,
|
||||
device=controlnet_model.device,
|
||||
dtype=controlnet_model.dtype,
|
||||
control_mode="balanced",
|
||||
resize_mode="just_resize_simple",
|
||||
)
|
||||
|
||||
timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
|
||||
scheduler,
|
||||
device=unet.device,
|
||||
steps=self.steps,
|
||||
denoising_start=self.denoising_start,
|
||||
denoising_end=self.denoising_end,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
# TODO(ryand): Think about when/if latents/noise should be moved off of the device to save VRAM.
|
||||
latent_tile = latent_tile.to(device=unet.device, dtype=unet.dtype)
|
||||
noise_tile = noise_tile.to(device=unet.device, dtype=unet.dtype)
|
||||
refined_latent_tile = pipeline.latents_from_embeddings(
|
||||
latents=latent_tile,
|
||||
timesteps=timesteps,
|
||||
init_timestep=init_timestep,
|
||||
noise=noise_tile,
|
||||
seed=seed,
|
||||
mask=None,
|
||||
masked_latents=None,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=[controlnet_data],
|
||||
ip_adapter_data=None,
|
||||
t2i_adapter_data=None,
|
||||
callback=lambda x: None,
|
||||
)
|
||||
refined_latent_tiles.append(refined_latent_tile)
|
||||
|
||||
# VAE-decode each refined latent tile independently.
|
||||
refined_image_tiles: list[Image.Image] = []
|
||||
for refined_latent_tile in refined_latent_tiles:
|
||||
refined_image_tile = LatentsToImageInvocation.vae_decode(
|
||||
context=context,
|
||||
vae_info=vae_info,
|
||||
seamless_axes=self.vae.seamless_axes,
|
||||
latents=refined_latent_tile,
|
||||
use_fp32=self.vae_fp32,
|
||||
use_tiling=False,
|
||||
)
|
||||
refined_image_tiles.append(refined_image_tile)
|
||||
|
||||
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
# Merge the refined image tiles back into a single image.
|
||||
refined_image_tiles_np = [np.array(t) for t in refined_image_tiles]
|
||||
merged_image_np = np.zeros(shape=(input_image.height, input_image.width, 3), dtype=np.uint8)
|
||||
# TODO(ryand): Tune the blend_amount. Should this be exposed as a parameter?
|
||||
merge_tiles_with_linear_blending(
|
||||
dst_image=merged_image_np, tiles=tiles, tile_images=refined_image_tiles_np, blend_amount=self.tile_overlap
|
||||
)
|
||||
|
||||
# Save the refined image and return its reference.
|
||||
merged_image_pil = Image.fromarray(merged_image_np)
|
||||
image_dto = context.images.save(image=merged_image_pil)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
@@ -26,13 +26,13 @@ LEGACY_INIT_FILE = Path("invokeai.init")
|
||||
DEFAULT_RAM_CACHE = 10.0
|
||||
DEFAULT_VRAM_CACHE = 0.25
|
||||
DEFAULT_CONVERT_CACHE = 20.0
|
||||
DEVICE = Literal["auto", "cpu", "cuda:0", "cuda:1", "cuda:2", "cuda:3", "cuda:4", "cuda:5", "cuda:6", "cuda:7", "mps"]
|
||||
PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
|
||||
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
|
||||
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
|
||||
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
|
||||
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
|
||||
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
|
||||
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
|
||||
CONFIG_SCHEMA_VERSION = "4.0.2"
|
||||
CONFIG_SCHEMA_VERSION = "4.0.1"
|
||||
|
||||
|
||||
def get_default_ram_cache_size() -> float:
|
||||
@@ -105,17 +105,14 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
convert_cache: Maximum size of on-disk converted models cache (GB).
|
||||
lazy_offload: Keep models in VRAM until their space is needed.
|
||||
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
|
||||
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda:0`, `cuda:1`, `cuda:2`, `cuda:3`, `cuda:4`, `cuda:5`, `cuda:6`, `cuda:7`, `mps`
|
||||
devices: List of execution devices; will override default device selected.
|
||||
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
|
||||
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
|
||||
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
|
||||
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
|
||||
attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
|
||||
attention_slice_size: Slice size, valid when attention_type=="sliced".<br>Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
|
||||
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
|
||||
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
|
||||
max_queue_size: Maximum number of items in the session queue.
|
||||
max_threads: Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.
|
||||
clear_queue_on_startup: Empties session queue on startup.
|
||||
allow_nodes: List of nodes to allow. Omit to allow all.
|
||||
deny_nodes: List of nodes to deny. Omit to deny none.
|
||||
node_cache_size: How many cached nodes to keep in memory.
|
||||
@@ -180,7 +177,6 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
|
||||
# DEVICE
|
||||
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
|
||||
devices: Optional[list[DEVICE]] = Field(default=None, description="List of execution devices; will override default device selected.")
|
||||
precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.")
|
||||
|
||||
# GENERATION
|
||||
@@ -190,8 +186,6 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
|
||||
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
|
||||
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
|
||||
max_threads: Optional[int] = Field(default=None, description="Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.")
|
||||
clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.")
|
||||
|
||||
# NODES
|
||||
allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.")
|
||||
@@ -380,6 +374,9 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
|
||||
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
|
||||
if k == "max_cache_size" and "ram" not in category_dict:
|
||||
parsed_config_dict["ram"] = v
|
||||
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
|
||||
if k == "max_vram_cache_size" and "vram" not in category_dict:
|
||||
parsed_config_dict["vram"] = v
|
||||
# autocast was removed in v4.0.1
|
||||
if k == "precision" and v == "autocast":
|
||||
parsed_config_dict["precision"] = "auto"
|
||||
@@ -427,27 +424,6 @@ def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig
|
||||
return config
|
||||
|
||||
|
||||
def migrate_v4_0_1_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
|
||||
"""Migrate v4.0.1 config dictionary to a current config object.
|
||||
|
||||
A few new multi-GPU options were added in 4.0.2, and this simply
|
||||
updates the schema label.
|
||||
|
||||
Args:
|
||||
config_dict: A dictionary of settings from a v4.0.1 config file.
|
||||
|
||||
Returns:
|
||||
An instance of `InvokeAIAppConfig` with the migrated settings.
|
||||
"""
|
||||
parsed_config_dict: dict[str, Any] = {}
|
||||
for k, _ in config_dict.items():
|
||||
if k == "schema_version":
|
||||
parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
|
||||
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
|
||||
return config
|
||||
|
||||
|
||||
# TO DO: replace this with a formal registration and migration system
|
||||
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
||||
"""Load and migrate a config file to the latest version.
|
||||
|
||||
@@ -479,10 +455,6 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
||||
loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
|
||||
loaded_config_dict.write_file(config_path)
|
||||
|
||||
elif loaded_config_dict["schema_version"] == "4.0.1":
|
||||
loaded_config_dict = migrate_v4_0_1_config_dict(loaded_config_dict)
|
||||
loaded_config_dict.write_file(config_path)
|
||||
|
||||
# Attempt to load as a v4 config file
|
||||
try:
|
||||
# Meta is not included in the model fields, so we need to validate it separately
|
||||
|
||||
@@ -53,11 +53,11 @@ class InvocationServices:
|
||||
model_images: "ModelImageFileStorageBase",
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
download_queue: "DownloadQueueServiceBase",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
session_queue: "SessionQueueBase",
|
||||
session_processor: "SessionProcessorBase",
|
||||
invocation_cache: "InvocationCacheBase",
|
||||
names: "NameServiceBase",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
urls: "UrlServiceBase",
|
||||
workflow_records: "WorkflowRecordsStorageBase",
|
||||
tensors: "ObjectSerializerBase[torch.Tensor]",
|
||||
@@ -77,11 +77,11 @@ class InvocationServices:
|
||||
self.model_images = model_images
|
||||
self.model_manager = model_manager
|
||||
self.download_queue = download_queue
|
||||
self.performance_statistics = performance_statistics
|
||||
self.session_queue = session_queue
|
||||
self.session_processor = session_processor
|
||||
self.invocation_cache = invocation_cache
|
||||
self.names = names
|
||||
self.performance_statistics = performance_statistics
|
||||
self.urls = urls
|
||||
self.workflow_records = workflow_records
|
||||
self.tensors = tensors
|
||||
|
||||
@@ -74,9 +74,9 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
)
|
||||
self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
|
||||
|
||||
def reset_stats(self, graph_execution_state_id: str):
|
||||
self._stats.pop(graph_execution_state_id)
|
||||
self._cache_stats.pop(graph_execution_state_id)
|
||||
def reset_stats(self):
|
||||
self._stats = {}
|
||||
self._cache_stats = {}
|
||||
|
||||
def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
|
||||
graph_stats_summary = self._get_graph_summary(graph_execution_state_id)
|
||||
|
||||
@@ -284,14 +284,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
unfinished_jobs = [x for x in self._install_jobs if not x.in_terminal_state]
|
||||
self._install_jobs = unfinished_jobs
|
||||
|
||||
def _migrate_yaml(self, rename_yaml: Optional[bool] = True, overwrite_db: Optional[bool] = False) -> None:
|
||||
def _migrate_yaml(self) -> None:
|
||||
db_models = self.record_store.all_models()
|
||||
|
||||
if overwrite_db:
|
||||
for model in db_models:
|
||||
self.record_store.del_model(model.key)
|
||||
db_models = self.record_store.all_models()
|
||||
|
||||
legacy_models_yaml_path = (
|
||||
self._app_config.legacy_models_yaml_path or self._app_config.root_path / "configs" / "models.yaml"
|
||||
)
|
||||
@@ -341,8 +336,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._logger.warning(f"Model at {model_path} could not be migrated: {e}")
|
||||
|
||||
# Rename `models.yaml` to `models.yaml.bak` to prevent re-migration
|
||||
if rename_yaml:
|
||||
legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))
|
||||
legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))
|
||||
|
||||
# Unset the path - we are done with it either way
|
||||
self._app_config.legacy_models_yaml_path = None
|
||||
|
||||
@@ -33,11 +33,6 @@ class ModelLoadServiceBase(ABC):
|
||||
def convert_cache(self) -> ModelConvertCacheBase:
|
||||
"""Return the checkpoint convert cache used by this loader."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def gpu_count(self) -> int:
|
||||
"""Return the number of GPUs we are configured to use."""
|
||||
|
||||
@abstractmethod
|
||||
def load_model_from_path(
|
||||
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
|
||||
|
||||
@@ -46,7 +46,6 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
self._registry = registry
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
"""Start the service."""
|
||||
self._invoker = invoker
|
||||
|
||||
@property
|
||||
@@ -54,11 +53,6 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
"""Return the RAM cache used by this loader."""
|
||||
return self._ram_cache
|
||||
|
||||
@property
|
||||
def gpu_count(self) -> int:
|
||||
"""Return the number of GPUs available for our uses."""
|
||||
return len(self._ram_cache.execution_devices)
|
||||
|
||||
@property
|
||||
def convert_cache(self) -> ModelConvertCacheBase:
|
||||
"""Return the checkpoint convert cache used by this loader."""
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Set
|
||||
|
||||
import torch
|
||||
from typing_extensions import Self
|
||||
@@ -32,7 +31,7 @@ class ModelManagerServiceBase(ABC):
|
||||
model_record_service: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
events: EventServiceBase,
|
||||
execution_devices: Optional[Set[torch.device]] = None,
|
||||
execution_device: torch.device,
|
||||
) -> Self:
|
||||
"""
|
||||
Construct the model manager service instance.
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
"""Implementation of ModelManagerServiceBase."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from ..config import InvokeAIAppConfig
|
||||
@@ -65,6 +69,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
model_record_service: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
events: EventServiceBase,
|
||||
execution_device: Optional[torch.device] = None,
|
||||
) -> Self:
|
||||
"""
|
||||
Construct the model manager service instance.
|
||||
@@ -77,7 +82,9 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
ram_cache = ModelCache(
|
||||
max_cache_size=app_config.ram,
|
||||
max_vram_cache_size=app_config.vram,
|
||||
lazy_offloading=app_config.lazy_offload,
|
||||
logger=logger,
|
||||
execution_device=execution_device or TorchDevice.choose_torch_device(),
|
||||
)
|
||||
convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
|
||||
loader = ModelLoadService(
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import shutil
|
||||
import tempfile
|
||||
import threading
|
||||
import typing
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional, TypeVar
|
||||
@@ -10,7 +9,6 @@ import torch
|
||||
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
|
||||
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
@@ -72,10 +70,7 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
|
||||
return self._output_dir / name
|
||||
|
||||
def _new_name(self) -> str:
|
||||
tid = threading.current_thread().ident
|
||||
# Add tid to the object name because uuid4 not thread-safe on windows
|
||||
# See https://stackoverflow.com/questions/2759644/python-multiprocessing-doesnt-play-nicely-with-uuid-uuid4
|
||||
return f"{self._obj_class_name}_{tid}-{uuid_string()}"
|
||||
return f"{self._obj_class_name}_{uuid_string()}"
|
||||
|
||||
def _tempdir_cleanup(self) -> None:
|
||||
"""Calls `cleanup` on the temporary directory, if it exists."""
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import traceback
|
||||
from contextlib import suppress
|
||||
from queue import Queue
|
||||
from threading import BoundedSemaphore, Lock, Thread
|
||||
from threading import BoundedSemaphore, Thread
|
||||
from threading import Event as ThreadEvent
|
||||
from typing import Optional, Set
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
from invokeai.app.services.events.events_common import (
|
||||
@@ -27,7 +26,6 @@ from invokeai.app.services.session_queue.session_queue_common import SessionQueu
|
||||
from invokeai.app.services.shared.graph import NodeInputError
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
|
||||
from invokeai.app.util.profiler import Profiler
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from ..invoker import Invoker
|
||||
from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase
|
||||
@@ -59,11 +57,8 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
self._on_after_run_node_callbacks = on_after_run_node_callbacks or []
|
||||
self._on_node_error_callbacks = on_node_error_callbacks or []
|
||||
self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
|
||||
self._process_lock = Lock()
|
||||
|
||||
def start(
|
||||
self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None
|
||||
) -> None:
|
||||
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
|
||||
self._services = services
|
||||
self._cancel_event = cancel_event
|
||||
self._profiler = profiler
|
||||
@@ -81,8 +76,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
# Loop over invocations until the session is complete or canceled
|
||||
while True:
|
||||
try:
|
||||
with self._process_lock:
|
||||
invocation = queue_item.session.next()
|
||||
invocation = queue_item.session.next()
|
||||
# Anything other than a `NodeInputError` is handled as a processor error
|
||||
except NodeInputError as e:
|
||||
error_type = e.__class__.__name__
|
||||
@@ -114,7 +108,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
|
||||
self._on_after_run_session(queue_item=queue_item)
|
||||
|
||||
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
|
||||
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||
try:
|
||||
# Any unhandled exception in this scope is an invocation error & will fail the graph
|
||||
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
|
||||
@@ -216,7 +210,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
# we don't care about that - suppress the error.
|
||||
with suppress(GESStatsNotFoundError):
|
||||
self._services.performance_statistics.log_stats(queue_item.session.id)
|
||||
self._services.performance_statistics.reset_stats(queue_item.session.id)
|
||||
self._services.performance_statistics.reset_stats()
|
||||
|
||||
for callback in self._on_after_run_session_callbacks:
|
||||
callback(queue_item=queue_item)
|
||||
@@ -330,7 +324,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker: Invoker = invoker
|
||||
self._active_queue_items: Set[SessionQueueItem] = set()
|
||||
self._queue_item: Optional[SessionQueueItem] = None
|
||||
self._invocation: Optional[BaseInvocation] = None
|
||||
|
||||
self._resume_event = ThreadEvent()
|
||||
@@ -356,14 +350,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
else None
|
||||
)
|
||||
|
||||
self._worker_thread_count = self._invoker.services.configuration.max_threads or len(
|
||||
TorchDevice.execution_devices()
|
||||
)
|
||||
|
||||
self._session_worker_queue: Queue[SessionQueueItem] = Queue()
|
||||
|
||||
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event, profiler=self._profiler)
|
||||
# Session processor - singlethreaded
|
||||
self._thread = Thread(
|
||||
name="session_processor",
|
||||
target=self._process,
|
||||
@@ -376,16 +363,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
)
|
||||
self._thread.start()
|
||||
|
||||
# Session processor workers - multithreaded
|
||||
self._invoker.services.logger.debug(f"Starting {self._worker_thread_count} session processing threads.")
|
||||
for _i in range(0, self._worker_thread_count):
|
||||
worker = Thread(
|
||||
name="session_worker",
|
||||
target=self._process_next_session,
|
||||
daemon=True,
|
||||
)
|
||||
worker.start()
|
||||
|
||||
def stop(self, *args, **kwargs) -> None:
|
||||
self._stop_event.set()
|
||||
|
||||
@@ -393,7 +370,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self._poll_now_event.set()
|
||||
|
||||
async def _on_queue_cleared(self, event: FastAPIEvent[QueueClearedEvent]) -> None:
|
||||
if any(item.queue_id == event[1].queue_id for item in self._active_queue_items):
|
||||
if self._queue_item and self._queue_item.queue_id == event[1].queue_id:
|
||||
self._cancel_event.set()
|
||||
self._poll_now()
|
||||
|
||||
@@ -401,7 +378,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self._poll_now()
|
||||
|
||||
async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None:
|
||||
if self._active_queue_items and event[1].status in ["completed", "failed", "canceled"]:
|
||||
if self._queue_item and event[1].status in ["completed", "failed", "canceled"]:
|
||||
# When the queue item is canceled via HTTP, the queue item status is set to `"canceled"` and this event is
|
||||
# emitted. We need to respond to this event and stop graph execution. This is done by setting the cancel
|
||||
# event, which the session runner checks between invocations. If set, the session runner loop is broken.
|
||||
@@ -426,7 +403,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
def get_status(self) -> SessionProcessorStatus:
|
||||
return SessionProcessorStatus(
|
||||
is_started=self._resume_event.is_set(),
|
||||
is_processing=len(self._active_queue_items) > 0,
|
||||
is_processing=self._queue_item is not None,
|
||||
)
|
||||
|
||||
def _process(
|
||||
@@ -451,22 +428,30 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
resume_event.wait()
|
||||
|
||||
# Get the next session to process
|
||||
queue_item = self._invoker.services.session_queue.dequeue()
|
||||
self._queue_item = self._invoker.services.session_queue.dequeue()
|
||||
|
||||
if queue_item is None:
|
||||
if self._queue_item is None:
|
||||
# The queue was empty, wait for next polling interval or event to try again
|
||||
self._invoker.services.logger.debug("Waiting for next polling interval or event")
|
||||
poll_now_event.wait(self._polling_interval)
|
||||
continue
|
||||
|
||||
self._session_worker_queue.put(queue_item)
|
||||
self._invoker.services.logger.debug(f"Scheduling queue item {queue_item.item_id} to run")
|
||||
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
|
||||
cancel_event.clear()
|
||||
|
||||
# Run the graph
|
||||
# self.session_runner.run(queue_item=self._queue_item)
|
||||
self.session_runner.run(queue_item=self._queue_item)
|
||||
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
error_type = e.__class__.__name__
|
||||
error_message = str(e)
|
||||
error_traceback = traceback.format_exc()
|
||||
self._on_non_fatal_processor_error(
|
||||
queue_item=self._queue_item,
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
error_traceback=error_traceback,
|
||||
)
|
||||
# Wait for next polling interval or event to try again
|
||||
poll_now_event.wait(self._polling_interval)
|
||||
continue
|
||||
@@ -481,25 +466,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
finally:
|
||||
stop_event.clear()
|
||||
poll_now_event.clear()
|
||||
self._queue_item = None
|
||||
self._thread_semaphore.release()
|
||||
|
||||
def _process_next_session(self) -> None:
|
||||
while True:
|
||||
self._resume_event.wait()
|
||||
queue_item = self._session_worker_queue.get()
|
||||
if queue_item.status == "canceled":
|
||||
continue
|
||||
try:
|
||||
self._active_queue_items.add(queue_item)
|
||||
# reserve a GPU for this session - may block
|
||||
with self._invoker.services.model_manager.load.ram_cache.reserve_execution_device():
|
||||
# Run the session on the reserved GPU
|
||||
self.session_runner.run(queue_item=queue_item)
|
||||
except Exception:
|
||||
continue
|
||||
finally:
|
||||
self._active_queue_items.remove(queue_item)
|
||||
|
||||
def _on_non_fatal_processor_error(
|
||||
self,
|
||||
queue_item: Optional[SessionQueueItem],
|
||||
|
||||
@@ -236,9 +236,6 @@ class SessionQueueItemWithoutGraph(BaseModel):
|
||||
}
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.item_id
|
||||
|
||||
|
||||
class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
|
||||
pass
|
||||
|
||||
@@ -37,14 +37,10 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self.__invoker = invoker
|
||||
self._set_in_progress_to_canceled()
|
||||
if self.__invoker.services.configuration.clear_queue_on_startup:
|
||||
clear_result = self.clear(DEFAULT_QUEUE_ID)
|
||||
if clear_result.deleted > 0:
|
||||
self.__invoker.services.logger.info(f"Cleared all {clear_result.deleted} queue items")
|
||||
else:
|
||||
prune_result = self.prune(DEFAULT_QUEUE_ID)
|
||||
if prune_result.deleted > 0:
|
||||
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
|
||||
prune_result = self.prune(DEFAULT_QUEUE_ID)
|
||||
|
||||
if prune_result.deleted > 0:
|
||||
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -652,7 +652,7 @@ class Graph(BaseModel):
|
||||
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
|
||||
|
||||
# Input type must be a list
|
||||
if get_origin(input_field) is not list:
|
||||
if get_origin(input_field) != list:
|
||||
return False
|
||||
|
||||
# Validate that all outputs match the input type
|
||||
|
||||
@@ -2,7 +2,6 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from torch import Tensor
|
||||
@@ -27,13 +26,11 @@ from invokeai.backend.model_manager.config import (
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
|
||||
"""
|
||||
The InvocationContext provides access to various services and data about the current invocation.
|
||||
@@ -326,6 +323,7 @@ class ConditioningInterface(InvocationContextInterface):
|
||||
Returns:
|
||||
The loaded conditioning data.
|
||||
"""
|
||||
|
||||
return self._services.conditioning.load(name)
|
||||
|
||||
|
||||
@@ -559,28 +557,6 @@ class UtilInterface(InvocationContextInterface):
|
||||
is_canceled=self.is_canceled,
|
||||
)
|
||||
|
||||
def torch_device(self) -> torch.device:
|
||||
"""
|
||||
Return a torch device to use in the current invocation.
|
||||
|
||||
Returns:
|
||||
A torch.device not currently in use by the system.
|
||||
"""
|
||||
ram_cache: "ModelCacheBase[AnyModel]" = self._services.model_manager.load.ram_cache
|
||||
return ram_cache.get_execution_device()
|
||||
|
||||
def torch_dtype(self, device: Optional[torch.device] = None) -> torch.dtype:
|
||||
"""
|
||||
Return a precision type to use with the current invocation and torch device.
|
||||
|
||||
Args:
|
||||
device: Optional device.
|
||||
|
||||
Returns:
|
||||
A torch.dtype suited for the current device.
|
||||
"""
|
||||
return TorchDevice.choose_torch_dtype(device)
|
||||
|
||||
|
||||
class InvocationContext:
|
||||
"""Provides access to various services and data for the current invocation.
|
||||
|
||||
@@ -289,7 +289,7 @@ def prepare_control_image(
|
||||
width: int,
|
||||
height: int,
|
||||
num_channels: int = 3,
|
||||
device: str = "cuda",
|
||||
device: str | torch.device = "cuda",
|
||||
dtype: torch.dtype = torch.float16,
|
||||
control_mode: CONTROLNET_MODE_VALUES = "balanced",
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
|
||||
@@ -304,7 +304,7 @@ def prepare_control_image(
|
||||
num_channels (int, optional): The target number of image channels. This is achieved by converting the input
|
||||
image to RGB, then naively taking the first `num_channels` channels. The primary use case is converting a
|
||||
RGB image to a single-channel grayscale image. Raises if `num_channels` cannot be achieved. Defaults to 3.
|
||||
device (str, optional): The target device for the output image. Defaults to "cuda".
|
||||
device (str | torch.Device, optional): The target device for the output image. Defaults to "cuda".
|
||||
dtype (_type_, optional): The dtype for the output image. Defaults to torch.float16.
|
||||
do_classifier_free_guidance (bool, optional): If True, repeat the output image along the batch dimension.
|
||||
Defaults to True.
|
||||
|
||||
@@ -25,7 +25,6 @@ from enum import Enum
|
||||
from typing import Literal, Optional, Type, TypeAlias, Union
|
||||
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
||||
from typing_extensions import Annotated, Any, Dict
|
||||
@@ -38,7 +37,7 @@ from ..raw_model import RawModel
|
||||
|
||||
# ModelMixin is the base class for all diffusers and transformers models
|
||||
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
|
||||
AnyModel = Union[ConfigMixin, ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]
|
||||
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]
|
||||
|
||||
|
||||
class InvalidModelConfigException(Exception):
|
||||
@@ -178,7 +177,6 @@ class ModelConfigBase(BaseModel):
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
"""Extend the pydantic schema from a json."""
|
||||
schema["required"].extend(["key", "type", "format"])
|
||||
|
||||
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
|
||||
@@ -445,7 +443,7 @@ class ModelConfigFactory(object):
|
||||
model = dest_class.model_validate(model_data)
|
||||
else:
|
||||
# mypy doesn't typecheck TypeAdapters well?
|
||||
model = AnyModelConfigValidator.validate_python(model_data)
|
||||
model = AnyModelConfigValidator.validate_python(model_data) # type: ignore
|
||||
assert model is not None
|
||||
if key:
|
||||
model.key = key
|
||||
|
||||
@@ -65,7 +65,8 @@ class LoadedModelWithoutConfig:
|
||||
|
||||
def __enter__(self) -> AnyModel:
|
||||
"""Context entry."""
|
||||
return self._locker.lock()
|
||||
self._locker.lock()
|
||||
return self.model
|
||||
|
||||
def __exit__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Context exit."""
|
||||
|
||||
@@ -8,10 +8,9 @@ model will be cleared and (re)loaded from disk when next needed.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from logging import Logger
|
||||
from typing import Dict, Generator, Generic, Optional, Set, TypeVar
|
||||
from typing import Dict, Generic, Optional, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
@@ -52,13 +51,44 @@ class CacheRecord(Generic[T]):
|
||||
Elements of the cache:
|
||||
|
||||
key: Unique key for each model, same as used in the models database.
|
||||
model: Read-only copy of the model *without weights* residing in the "meta device"
|
||||
model: Model in memory.
|
||||
state_dict: A read-only copy of the model's state dict in RAM. It will be
|
||||
used as a template for creating a copy in the VRAM.
|
||||
size: Size of the model
|
||||
loaded: True if the model's state dict is currently in VRAM
|
||||
|
||||
Before a model is executed, the state_dict template is copied into VRAM,
|
||||
and then injected into the model. When the model is finished, the VRAM
|
||||
copy of the state dict is deleted, and the RAM version is reinjected
|
||||
into the model.
|
||||
|
||||
The state_dict should be treated as a read-only attribute. Do not attempt
|
||||
to patch or otherwise modify it. Instead, patch the copy of the state_dict
|
||||
after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
|
||||
context manager call `model_on_device()`.
|
||||
"""
|
||||
|
||||
key: str
|
||||
size: int
|
||||
model: T
|
||||
device: torch.device
|
||||
state_dict: Optional[Dict[str, torch.Tensor]]
|
||||
size: int
|
||||
loaded: bool = False
|
||||
_locks: int = 0
|
||||
|
||||
def lock(self) -> None:
|
||||
"""Lock this record."""
|
||||
self._locks += 1
|
||||
|
||||
def unlock(self) -> None:
|
||||
"""Unlock this record."""
|
||||
self._locks -= 1
|
||||
assert self._locks >= 0
|
||||
|
||||
@property
|
||||
def locked(self) -> bool:
|
||||
"""Return true if record is locked."""
|
||||
return self._locks > 0
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -85,27 +115,14 @@ class ModelCacheBase(ABC, Generic[T]):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def execution_devices(self) -> Set[torch.device]:
|
||||
"""Return the set of available execution devices."""
|
||||
def execution_device(self) -> torch.device:
|
||||
"""Return the exection device (e.g. "cuda" for VRAM)."""
|
||||
pass
|
||||
|
||||
@contextmanager
|
||||
@property
|
||||
@abstractmethod
|
||||
def reserve_execution_device(self, timeout: int = 0) -> Generator[torch.device, None, None]:
|
||||
"""Reserve an execution device (GPU) under the current thread id."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_execution_device(self) -> torch.device:
|
||||
"""
|
||||
Return an execution device that has been reserved for current thread.
|
||||
|
||||
Note that reservations are done using the current thread's TID.
|
||||
It might be better to do this using the session ID, but that involves
|
||||
too many detailed changes to model manager calls.
|
||||
|
||||
May generate a ValueError if no GPU has been reserved.
|
||||
"""
|
||||
def lazy_offloading(self) -> bool:
|
||||
"""Return true if the cache is configured to lazily offload models in VRAM."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@@ -114,6 +131,16 @@ class ModelCacheBase(ABC, Generic[T]):
|
||||
"""Return true if the cache is configured to lazily offload models in VRAM."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def offload_unlocked_models(self, size_required: int) -> None:
|
||||
"""Offload from VRAM any models not actively in use."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
|
||||
"""Move model into the indicated device."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def stats(self) -> Optional[CacheStats]:
|
||||
@@ -175,11 +202,6 @@ class ModelCacheBase(ABC, Generic[T]):
|
||||
"""Return true if the model identified by key and submodel_type is in the cache."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> AnyModel:
|
||||
"""Move a copy of the model into the indicated device and return it."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cache_size(self) -> int:
|
||||
"""Get the total size of the models currently cached."""
|
||||
|
||||
@@ -18,19 +18,17 @@ context. Use like this:
|
||||
|
||||
"""
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import sys
|
||||
import threading
|
||||
from contextlib import contextmanager, suppress
|
||||
import math
|
||||
import time
|
||||
from contextlib import suppress
|
||||
from logging import Logger
|
||||
from threading import BoundedSemaphore
|
||||
from typing import Dict, Generator, List, Optional, Set
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, SubModelType
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
@@ -41,7 +39,9 @@ from .model_locker import ModelLocker
|
||||
# Maximum size of the cache, in gigs
|
||||
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
|
||||
DEFAULT_MAX_CACHE_SIZE = 6.0
|
||||
DEFAULT_MAX_VRAM_CACHE_SIZE = 0.25
|
||||
|
||||
# amount of GPU memory to hold in reserve for use by generations (GB)
|
||||
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
|
||||
|
||||
# actual size of a gig
|
||||
GIG = 1073741824
|
||||
@@ -57,8 +57,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
self,
|
||||
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
|
||||
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
|
||||
execution_device: torch.device = torch.device("cuda"),
|
||||
storage_device: torch.device = torch.device("cpu"),
|
||||
precision: torch.dtype = torch.float16,
|
||||
sequential_offload: bool = False,
|
||||
lazy_offloading: bool = True,
|
||||
sha_chunksize: int = 16777216,
|
||||
log_memory_usage: bool = False,
|
||||
logger: Optional[Logger] = None,
|
||||
):
|
||||
@@ -66,19 +70,23 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
Initialize the model RAM cache.
|
||||
|
||||
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
|
||||
:param execution_device: Torch device to load active model into [torch.device('cuda')]
|
||||
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
||||
:param precision: Precision for loaded models [torch.float16]
|
||||
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
|
||||
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
|
||||
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
|
||||
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
|
||||
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
||||
behaviour.
|
||||
"""
|
||||
# allow lazy offloading only when vram cache enabled
|
||||
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
||||
self._precision: torch.dtype = precision
|
||||
self._max_cache_size: float = max_cache_size
|
||||
self._max_vram_cache_size: float = max_vram_cache_size
|
||||
self._execution_device: torch.device = execution_device
|
||||
self._storage_device: torch.device = storage_device
|
||||
self._ram_lock = threading.Lock()
|
||||
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
|
||||
self._log_memory_usage = log_memory_usage
|
||||
self._stats: Optional[CacheStats] = None
|
||||
@@ -86,87 +94,25 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
|
||||
self._cache_stack: List[str] = []
|
||||
|
||||
# device to thread id
|
||||
self._device_lock = threading.Lock()
|
||||
self._execution_devices: Dict[torch.device, int] = {x: 0 for x in TorchDevice.execution_devices()}
|
||||
self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
|
||||
|
||||
self.logger.info(
|
||||
f"Using rendering device(s): {', '.join(sorted([str(x) for x in self._execution_devices.keys()]))}"
|
||||
)
|
||||
|
||||
@property
|
||||
def logger(self) -> Logger:
|
||||
"""Return the logger used by the cache."""
|
||||
return self._logger
|
||||
|
||||
@property
|
||||
def lazy_offloading(self) -> bool:
|
||||
"""Return true if the cache is configured to lazily offload models in VRAM."""
|
||||
return self._lazy_offloading
|
||||
|
||||
@property
|
||||
def storage_device(self) -> torch.device:
|
||||
"""Return the storage device (e.g. "CPU" for RAM)."""
|
||||
return self._storage_device
|
||||
|
||||
@property
|
||||
def execution_devices(self) -> Set[torch.device]:
|
||||
"""Return the set of available execution devices."""
|
||||
devices = self._execution_devices.keys()
|
||||
return set(devices)
|
||||
|
||||
def get_execution_device(self) -> torch.device:
|
||||
"""
|
||||
Return an execution device that has been reserved for current thread.
|
||||
|
||||
Note that reservations are done using the current thread's TID.
|
||||
It would be better to do this using the session ID, but that involves
|
||||
too many detailed changes to model manager calls.
|
||||
|
||||
May generate a ValueError if no GPU has been reserved.
|
||||
"""
|
||||
current_thread = threading.current_thread().ident
|
||||
assert current_thread is not None
|
||||
assigned = [x for x, tid in self._execution_devices.items() if current_thread == tid]
|
||||
if not assigned:
|
||||
raise ValueError(f"No GPU has been reserved for the use of thread {current_thread}")
|
||||
return assigned[0]
|
||||
|
||||
@contextmanager
|
||||
def reserve_execution_device(self, timeout: Optional[int] = None) -> Generator[torch.device, None, None]:
|
||||
"""Reserve an execution device (e.g. GPU) for exclusive use by a generation thread.
|
||||
|
||||
Note that the reservation is done using the current thread's TID.
|
||||
It would be better to do this using the session ID, but that involves
|
||||
too many detailed changes to model manager calls.
|
||||
"""
|
||||
device = None
|
||||
with self._device_lock:
|
||||
current_thread = threading.current_thread().ident
|
||||
assert current_thread is not None
|
||||
|
||||
# look for a device that has already been assigned to this thread
|
||||
assigned = [x for x, tid in self._execution_devices.items() if current_thread == tid]
|
||||
if assigned:
|
||||
device = assigned[0]
|
||||
|
||||
# no device already assigned. Get one.
|
||||
if device is None:
|
||||
self._free_execution_device.acquire(timeout=timeout)
|
||||
with self._device_lock:
|
||||
free_device = [x for x, tid in self._execution_devices.items() if tid == 0]
|
||||
self._execution_devices[free_device[0]] = current_thread
|
||||
device = free_device[0]
|
||||
|
||||
# we are outside the lock region now
|
||||
self.logger.info(f"{current_thread} Reserved torch device {device}")
|
||||
|
||||
# Tell TorchDevice to use this object to get the torch device.
|
||||
TorchDevice.set_model_cache(self)
|
||||
try:
|
||||
yield device
|
||||
finally:
|
||||
with self._device_lock:
|
||||
self.logger.info(f"{current_thread} Released torch device {device}")
|
||||
self._execution_devices[device] = 0
|
||||
self._free_execution_device.release()
|
||||
torch.cuda.empty_cache()
|
||||
def execution_device(self) -> torch.device:
|
||||
"""Return the exection device (e.g. "cuda" for VRAM)."""
|
||||
return self._execution_device
|
||||
|
||||
@property
|
||||
def max_cache_size(self) -> float:
|
||||
@@ -211,16 +157,16 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
"""Store model under key and optional submodel_type."""
|
||||
with self._ram_lock:
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
if key in self._cached_models:
|
||||
return
|
||||
size = calc_model_size_by_data(model)
|
||||
self.make_room(size)
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
if key in self._cached_models:
|
||||
return
|
||||
size = calc_model_size_by_data(model)
|
||||
self.make_room(size)
|
||||
|
||||
cache_record = CacheRecord(key=key, model=model, size=size)
|
||||
self._cached_models[key] = cache_record
|
||||
self._cache_stack.append(key)
|
||||
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
|
||||
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
|
||||
self._cached_models[key] = cache_record
|
||||
self._cache_stack.append(key)
|
||||
|
||||
def get(
|
||||
self,
|
||||
@@ -238,37 +184,36 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
|
||||
This may raise an IndexError if the model is not in the cache.
|
||||
"""
|
||||
with self._ram_lock:
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
if key in self._cached_models:
|
||||
if self.stats:
|
||||
self.stats.hits += 1
|
||||
else:
|
||||
if self.stats:
|
||||
self.stats.misses += 1
|
||||
raise IndexError(f"The model with key {key} is not in the cache.")
|
||||
|
||||
cache_entry = self._cached_models[key]
|
||||
|
||||
# more stats
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
if key in self._cached_models:
|
||||
if self.stats:
|
||||
stats_name = stats_name or key
|
||||
self.stats.cache_size = int(self._max_cache_size * GIG)
|
||||
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
|
||||
self.stats.in_cache = len(self._cached_models)
|
||||
self.stats.loaded_model_sizes[stats_name] = max(
|
||||
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
|
||||
)
|
||||
self.stats.hits += 1
|
||||
else:
|
||||
if self.stats:
|
||||
self.stats.misses += 1
|
||||
raise IndexError(f"The model with key {key} is not in the cache.")
|
||||
|
||||
# this moves the entry to the top (right end) of the stack
|
||||
with suppress(Exception):
|
||||
self._cache_stack.remove(key)
|
||||
self._cache_stack.append(key)
|
||||
return ModelLocker(
|
||||
cache=self,
|
||||
cache_entry=cache_entry,
|
||||
cache_entry = self._cached_models[key]
|
||||
|
||||
# more stats
|
||||
if self.stats:
|
||||
stats_name = stats_name or key
|
||||
self.stats.cache_size = int(self._max_cache_size * GIG)
|
||||
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
|
||||
self.stats.in_cache = len(self._cached_models)
|
||||
self.stats.loaded_model_sizes[stats_name] = max(
|
||||
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
|
||||
)
|
||||
|
||||
# this moves the entry to the top (right end) of the stack
|
||||
with suppress(Exception):
|
||||
self._cache_stack.remove(key)
|
||||
self._cache_stack.append(key)
|
||||
return ModelLocker(
|
||||
cache=self,
|
||||
cache_entry=cache_entry,
|
||||
)
|
||||
|
||||
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
|
||||
if self._log_memory_usage:
|
||||
return MemorySnapshot.capture()
|
||||
@@ -280,34 +225,127 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
else:
|
||||
return model_key
|
||||
|
||||
def model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> AnyModel:
|
||||
"""Move a copy of the model into the indicated device and return it.
|
||||
def offload_unlocked_models(self, size_required: int) -> None:
|
||||
"""Move any unused models from VRAM."""
|
||||
reserved = self._max_vram_cache_size * GIG
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
|
||||
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
|
||||
if vram_in_use <= reserved:
|
||||
break
|
||||
if not cache_entry.loaded:
|
||||
continue
|
||||
if not cache_entry.locked:
|
||||
self.move_model_to_device(cache_entry, self.storage_device)
|
||||
cache_entry.loaded = False
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(
|
||||
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
|
||||
)
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
|
||||
"""Move model into the indicated device.
|
||||
|
||||
:param cache_entry: The CacheRecord for the model
|
||||
:param target_device: The torch.device to move the model into
|
||||
|
||||
May raise a torch.cuda.OutOfMemoryError
|
||||
"""
|
||||
with self._ram_lock:
|
||||
self.logger.debug(f"Called to move {cache_entry.key} ({type(cache_entry.model)=}) to {target_device}")
|
||||
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
|
||||
source_device = cache_entry.device
|
||||
|
||||
# Some models don't have a state dictionary, in which case the
|
||||
# stored model will still reside in CPU
|
||||
if hasattr(cache_entry.model, "to"):
|
||||
model_in_gpu = copy.deepcopy(cache_entry.model)
|
||||
assert hasattr(model_in_gpu, "to")
|
||||
model_in_gpu.to(target_device)
|
||||
return model_in_gpu
|
||||
else:
|
||||
return cache_entry.model # what happens in CPU stays in CPU
|
||||
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
|
||||
# This would need to be revised to support multi-GPU.
|
||||
if torch.device(source_device).type == torch.device(target_device).type:
|
||||
return
|
||||
|
||||
# Some models don't have a `to` method, in which case they run in RAM/CPU.
|
||||
if not hasattr(cache_entry.model, "to"):
|
||||
return
|
||||
|
||||
# This roundabout method for moving the model around is done to avoid
|
||||
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
|
||||
# When moving to VRAM, we copy (not move) each element of the state dict from
|
||||
# RAM to a new state dict in VRAM, and then inject it into the model.
|
||||
# This operation is slightly faster than running `to()` on the whole model.
|
||||
#
|
||||
# When the model needs to be removed from VRAM we simply delete the copy
|
||||
# of the state dict in VRAM, and reinject the state dict that is cached
|
||||
# in RAM into the model. So this operation is very fast.
|
||||
start_model_to_time = time.time()
|
||||
snapshot_before = self._capture_memory_snapshot()
|
||||
|
||||
try:
|
||||
if cache_entry.state_dict is not None:
|
||||
assert hasattr(cache_entry.model, "load_state_dict")
|
||||
if target_device == self.storage_device:
|
||||
cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
|
||||
else:
|
||||
new_dict: Dict[str, torch.Tensor] = {}
|
||||
for k, v in cache_entry.state_dict.items():
|
||||
new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True)
|
||||
cache_entry.model.load_state_dict(new_dict, assign=True)
|
||||
cache_entry.model.to(target_device, non_blocking=True)
|
||||
cache_entry.device = target_device
|
||||
except Exception as e: # blow away cache entry
|
||||
self._delete_cache_entry(cache_entry)
|
||||
raise e
|
||||
|
||||
snapshot_after = self._capture_memory_snapshot()
|
||||
end_model_to_time = time.time()
|
||||
self.logger.debug(
|
||||
f"Moved model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
|
||||
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
if (
|
||||
snapshot_before is not None
|
||||
and snapshot_after is not None
|
||||
and snapshot_before.vram is not None
|
||||
and snapshot_after.vram is not None
|
||||
):
|
||||
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
|
||||
|
||||
# If the estimated model size does not match the change in VRAM, log a warning.
|
||||
if not math.isclose(
|
||||
vram_change,
|
||||
cache_entry.size,
|
||||
rel_tol=0.1,
|
||||
abs_tol=10 * MB,
|
||||
):
|
||||
self.logger.debug(
|
||||
f"Moving model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} caused an unexpected change in VRAM usage. The model's"
|
||||
" estimated size may be incorrect. Estimated model size:"
|
||||
f" {(cache_entry.size/GIG):.3f} GB.\n"
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
def print_cuda_stats(self) -> None:
|
||||
"""Log CUDA diagnostics."""
|
||||
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
|
||||
ram = "%4.2fG" % (self.cache_size() / GIG)
|
||||
|
||||
in_ram_models = len(self._cached_models)
|
||||
self.logger.debug(f"Current VRAM/RAM usage for {in_ram_models} models: {vram}/{ram}")
|
||||
in_ram_models = 0
|
||||
in_vram_models = 0
|
||||
locked_in_vram_models = 0
|
||||
for cache_record in self._cached_models.values():
|
||||
if hasattr(cache_record.model, "device"):
|
||||
if cache_record.model.device == self.storage_device:
|
||||
in_ram_models += 1
|
||||
else:
|
||||
in_vram_models += 1
|
||||
if cache_record.locked:
|
||||
locked_in_vram_models += 1
|
||||
|
||||
self.logger.debug(
|
||||
f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
|
||||
f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
|
||||
)
|
||||
|
||||
def make_room(self, size: int) -> None:
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size."""
|
||||
@@ -330,14 +368,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
|
||||
model_key = self._cache_stack[pos]
|
||||
cache_entry = self._cached_models[model_key]
|
||||
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
|
||||
self.logger.debug(
|
||||
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
|
||||
)
|
||||
|
||||
refs = sys.getrefcount(cache_entry.model)
|
||||
|
||||
# Expected refs:
|
||||
# 1 from cache_entry
|
||||
# 1 from getrefcount function
|
||||
# 1 from onnx runtime object
|
||||
if refs <= (3 if "onnx" in model_key else 2):
|
||||
if not cache_entry.locked:
|
||||
self.logger.debug(
|
||||
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
||||
)
|
||||
@@ -364,26 +400,10 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
if self.stats:
|
||||
self.stats.cleared = models_cleared
|
||||
gc.collect()
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
|
||||
|
||||
def _check_free_vram(self, target_device: torch.device, needed_size: int) -> None:
|
||||
if target_device.type != "cuda":
|
||||
return
|
||||
vram_device = ( # mem_get_info() needs an indexed device
|
||||
target_device if target_device.index is not None else torch.device(str(target_device), index=0)
|
||||
)
|
||||
free_mem, _ = torch.cuda.mem_get_info(torch.device(vram_device))
|
||||
if needed_size > free_mem:
|
||||
raise torch.cuda.OutOfMemoryError
|
||||
|
||||
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
|
||||
try:
|
||||
self._cache_stack.remove(cache_entry.key)
|
||||
del self._cached_models[cache_entry.key]
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _device_name(device: torch.device) -> str:
|
||||
return f"{device.type}:{device.index}"
|
||||
self._cache_stack.remove(cache_entry.key)
|
||||
del self._cached_models[cache_entry.key]
|
||||
|
||||
@@ -10,8 +10,6 @@ from invokeai.backend.model_manager import AnyModel
|
||||
|
||||
from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase
|
||||
|
||||
MAX_GPU_WAIT = 600 # wait up to 10 minutes for a GPU to become free
|
||||
|
||||
|
||||
class ModelLocker(ModelLockerBase):
|
||||
"""Internal class that mediates movement in and out of GPU."""
|
||||
@@ -31,29 +29,33 @@ class ModelLocker(ModelLockerBase):
|
||||
"""Return the model without moving it around."""
|
||||
return self._cache_entry.model
|
||||
|
||||
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
|
||||
"""Return the state dict (if any) for the cached model."""
|
||||
return self._cache_entry.state_dict
|
||||
|
||||
def lock(self) -> AnyModel:
|
||||
"""Move the model into the execution device (GPU) and lock it."""
|
||||
self._cache_entry.lock()
|
||||
try:
|
||||
device = self._cache.get_execution_device()
|
||||
model_on_device = self._cache.model_to_device(self._cache_entry, device)
|
||||
self._cache.logger.debug(f"Moved {self._cache_entry.key} to {device}")
|
||||
if self._cache.lazy_offloading:
|
||||
self._cache.offload_unlocked_models(self._cache_entry.size)
|
||||
self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
|
||||
self._cache_entry.loaded = True
|
||||
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
|
||||
self._cache.print_cuda_stats()
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
|
||||
self._cache_entry.unlock()
|
||||
raise
|
||||
except Exception:
|
||||
self._cache_entry.unlock()
|
||||
raise
|
||||
|
||||
return model_on_device
|
||||
return self.model
|
||||
|
||||
# It is no longer necessary to move the model out of VRAM
|
||||
# because it will be removed when it goes out of scope
|
||||
# in the caller's context
|
||||
def unlock(self) -> None:
|
||||
"""Call upon exit from context."""
|
||||
self._cache.print_cuda_stats()
|
||||
|
||||
# This is no longer in use in MGPU.
|
||||
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
|
||||
"""Return the state dict (if any) for the cached model."""
|
||||
return None
|
||||
self._cache_entry.unlock()
|
||||
if not self._cache.lazy_offloading:
|
||||
self._cache.offload_unlocked_models(0)
|
||||
self._cache.print_cuda_stats()
|
||||
|
||||
@@ -22,7 +22,8 @@ from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
||||
class VAELoader(GenericDiffusersLoader):
|
||||
"""Class to load VAE models."""
|
||||
|
||||
@@ -39,8 +40,12 @@ class VAELoader(GenericDiffusersLoader):
|
||||
return True
|
||||
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
|
||||
assert isinstance(config, CheckpointConfigBase)
|
||||
config_file = self._app_config.legacy_conf_path / config.config_path
|
||||
# TODO(MM2): check whether sdxl VAE models convert.
|
||||
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
||||
raise Exception(f"VAE conversion not supported for model type: {config.base}")
|
||||
else:
|
||||
assert isinstance(config, CheckpointConfigBase)
|
||||
config_file = self._app_config.legacy_conf_path / config.config_path
|
||||
|
||||
if model_path.suffix == ".safetensors":
|
||||
checkpoint = safetensors_load_file(model_path, device="cpu")
|
||||
|
||||
@@ -451,16 +451,8 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
|
||||
class VaeCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
# VAEs of all base types have the same structure, so we wimp out and
|
||||
# guess using the name.
|
||||
for regexp, basetype in [
|
||||
(r"xl", BaseModelType.StableDiffusionXL),
|
||||
(r"sd2", BaseModelType.StableDiffusion2),
|
||||
(r"vae", BaseModelType.StableDiffusion1),
|
||||
]:
|
||||
if re.search(regexp, self.model_path.name, re.IGNORECASE):
|
||||
return basetype
|
||||
raise InvalidModelConfigException("Cannot determine base type")
|
||||
# I can't find any standalone 2.X VAEs to test with!
|
||||
return BaseModelType.StableDiffusion1
|
||||
|
||||
|
||||
class LoRACheckpointProbe(CheckpointProbeBase):
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
@@ -35,8 +34,6 @@ with LoRAHelper.apply_lora_unet(unet, loras):
|
||||
|
||||
# TODO: rename smth like ModelPatcher and add TI method?
|
||||
class ModelPatcher:
|
||||
_thread_lock = threading.Lock()
|
||||
|
||||
@staticmethod
|
||||
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
||||
assert "." not in lora_key
|
||||
@@ -109,7 +106,7 @@ class ModelPatcher:
|
||||
"""
|
||||
original_weights = {}
|
||||
try:
|
||||
with torch.no_grad(), cls._thread_lock:
|
||||
with torch.no_grad():
|
||||
for lora, lora_weight in loras:
|
||||
# assert lora.device.type == "cpu"
|
||||
for layer_key, layer in lora.layers.items():
|
||||
@@ -132,7 +129,9 @@ class ModelPatcher:
|
||||
dtype = module.weight.dtype
|
||||
|
||||
if module_key not in original_weights:
|
||||
if model_state_dict is None: # no CPU copy of the state dict was provided
|
||||
if model_state_dict is not None: # we were provided with the CPU copy of the state dict
|
||||
original_weights[module_key] = model_state_dict[module_key + ".weight"]
|
||||
else:
|
||||
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
|
||||
|
||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||
|
||||
@@ -10,12 +10,11 @@ import PIL.Image
|
||||
import psutil
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.controlnet import ControlNetModel
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from pydantic import Field
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
@@ -26,6 +25,7 @@ from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion impor
|
||||
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
|
||||
from invokeai.backend.util.attention import auto_detect_slice_size
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.hotfixes import ControlNetModel
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -38,56 +38,18 @@ class PipelineIntermediateState:
|
||||
predicted_original: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AddsMaskLatents:
|
||||
"""Add the channels required for inpainting model input.
|
||||
|
||||
The inpainting model takes the normal latent channels as input, _plus_ a one-channel mask
|
||||
and the latent encoding of the base image.
|
||||
|
||||
This class assumes the same mask and base image should apply to all items in the batch.
|
||||
"""
|
||||
|
||||
forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
|
||||
mask: torch.Tensor
|
||||
initial_image_latents: torch.Tensor
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
text_embeddings: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
model_input = self.add_mask_channels(latents)
|
||||
return self.forward(model_input, t, text_embeddings, **kwargs)
|
||||
|
||||
def add_mask_channels(self, latents):
|
||||
batch_size = latents.size(0)
|
||||
# duplicate mask and latents for each batch
|
||||
mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||
image_latents = einops.repeat(self.initial_image_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||
# add mask and image as additional channels
|
||||
model_input, _ = einops.pack([latents, mask, image_latents], "b * h w")
|
||||
return model_input
|
||||
|
||||
|
||||
def are_like_tensors(a: torch.Tensor, b: object) -> bool:
|
||||
return isinstance(b, torch.Tensor) and (a.size() == b.size())
|
||||
|
||||
|
||||
@dataclass
|
||||
class AddsMaskGuidance:
|
||||
mask: torch.FloatTensor
|
||||
mask_latents: torch.FloatTensor
|
||||
mask: torch.Tensor
|
||||
mask_latents: torch.Tensor
|
||||
scheduler: SchedulerMixin
|
||||
noise: torch.Tensor
|
||||
gradient_mask: bool
|
||||
is_gradient_mask: bool
|
||||
|
||||
def __call__(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
||||
return self.apply_mask(latents, t)
|
||||
|
||||
def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
|
||||
def apply_mask(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
||||
batch_size = latents.size(0)
|
||||
mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||
if t.dim() == 0:
|
||||
@@ -100,7 +62,7 @@ class AddsMaskGuidance:
|
||||
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
|
||||
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
|
||||
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||
if self.gradient_mask:
|
||||
if self.is_gradient_mask:
|
||||
threshhold = (t.item()) / self.scheduler.config.num_train_timesteps
|
||||
mask_bool = mask > threshhold # I don't know when mask got inverted, but it did
|
||||
masked_input = torch.where(mask_bool, latents, mask_latents)
|
||||
@@ -200,7 +162,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
safety_checker: Optional[StableDiffusionSafetyChecker],
|
||||
feature_extractor: Optional[CLIPFeatureExtractor],
|
||||
requires_safety_checker: bool = False,
|
||||
control_model: ControlNetModel = None,
|
||||
):
|
||||
super().__init__(
|
||||
vae=vae,
|
||||
@@ -214,8 +175,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
)
|
||||
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
|
||||
self.control_model = control_model
|
||||
self.use_ip_adapter = False
|
||||
|
||||
def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
|
||||
"""
|
||||
@@ -280,116 +239,131 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
|
||||
raise Exception("Should not be called")
|
||||
|
||||
def add_inpainting_channels_to_latents(
|
||||
self, latents: torch.Tensor, masked_ref_image_latents: torch.Tensor, inpainting_mask: torch.Tensor
|
||||
):
|
||||
"""Given a `latents` tensor, adds the mask and image latents channels required for inpainting.
|
||||
|
||||
Standard (non-inpainting) SD UNet models expect an input with shape (N, 4, H, W). Inpainting models expect an
|
||||
input of shape (N, 9, H, W). The 9 channels are defined as follows:
|
||||
- Channel 0-3: The latents being denoised.
|
||||
- Channel 4: The mask indicating which parts of the image are being inpainted.
|
||||
- Channel 5-8: The latent representation of the masked reference image being inpainted.
|
||||
|
||||
This function assumes that the same mask and base image should apply to all items in the batch.
|
||||
"""
|
||||
# Validate assumptions about input tensor shapes.
|
||||
batch_size, latent_channels, latent_height, latent_width = latents.shape
|
||||
assert latent_channels == 4
|
||||
assert masked_ref_image_latents.shape == [1, 4, latent_height, latent_width]
|
||||
assert inpainting_mask == [1, 1, latent_height, latent_width]
|
||||
|
||||
# Repeat original_image_latents and inpainting_mask to match the latents batch size.
|
||||
original_image_latents = masked_ref_image_latents.expand(batch_size, -1, -1, -1)
|
||||
inpainting_mask = inpainting_mask.expand(batch_size, -1, -1, -1)
|
||||
|
||||
# Concatenate along the channel dimension.
|
||||
return torch.cat([latents, inpainting_mask, original_image_latents], dim=1)
|
||||
|
||||
def latents_from_embeddings(
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
num_inference_steps: int,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
conditioning_data: TextConditioningData,
|
||||
*,
|
||||
noise: Optional[torch.Tensor],
|
||||
seed: int,
|
||||
timesteps: torch.Tensor,
|
||||
init_timestep: torch.Tensor,
|
||||
additional_guidance: List[Callable] = None,
|
||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
callback: Callable[[PipelineIntermediateState], None],
|
||||
control_data: list[ControlNetData] | None = None,
|
||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
||||
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
masked_latents: Optional[torch.Tensor] = None,
|
||||
gradient_mask: Optional[bool] = False,
|
||||
seed: int,
|
||||
is_gradient_mask: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if init_timestep.shape[0] == 0:
|
||||
return latents
|
||||
"""Denoise the latents.
|
||||
|
||||
if additional_guidance is None:
|
||||
additional_guidance = []
|
||||
Args:
|
||||
latents: The latent-space image to denoise.
|
||||
- If we are inpainting, this is the initial latent image before noise has been added.
|
||||
- If we are generating a new image, this should be initialized to zeros.
|
||||
- In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner).
|
||||
scheduler_step_kwargs: kwargs forwarded to the scheduler.step() method.
|
||||
conditioning_data: Text conditionging data.
|
||||
noise: Noise used for two purposes:
|
||||
1. Used by the scheduler to noise the initial `latents` before denoising.
|
||||
2. Used to noise the `masked_latents` when inpainting.
|
||||
`noise` should be None if the `latents` tensor has already been noised.
|
||||
seed: The seed used to generate the noise for the denoising process.
|
||||
HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
|
||||
same noise used earlier in the pipeline. This should really be handled in a clearer way.
|
||||
timesteps: The timestep schedule for the denoising process.
|
||||
init_timestep: The first timestep in the schedule.
|
||||
TODO(ryand): I'm pretty sure this should always be the same as timesteps[0:1]. Confirm that that is the
|
||||
case, and remove this duplicate param.
|
||||
callback: A callback function that is called to report progress during the denoising process.
|
||||
control_data: ControlNet data.
|
||||
ip_adapter_data: IP-Adapter data.
|
||||
t2i_adapter_data: T2I-Adapter data.
|
||||
mask: A mask indicating which parts of the image are being inpainted. The presence of mask is used to
|
||||
determine whether we are inpainting or not. `mask` should have the same spatial dimensions as the
|
||||
`latents` tensor.
|
||||
TODO(ryand): Check and document the expected dtype, range, and values used to represent
|
||||
foreground/background.
|
||||
masked_latents: A latent-space representation of a masked inpainting reference image. This tensor is only
|
||||
used if an *inpainting* model is being used i.e. this tensor is not used when inpainting with a standard
|
||||
SD UNet model.
|
||||
is_gradient_mask: A flag indicating whether `mask` is a gradient mask or not.
|
||||
"""
|
||||
# TODO(ryand): Figure out why this condition is necessary, and document it. My guess is that it's to handle
|
||||
# cases where densoisings_start and denoising_end are set such that there are no timesteps.
|
||||
if init_timestep.shape[0] == 0 or timesteps.shape[0] == 0:
|
||||
return latents
|
||||
|
||||
orig_latents = latents.clone()
|
||||
|
||||
batch_size = latents.shape[0]
|
||||
batched_t = init_timestep.expand(batch_size)
|
||||
batched_init_timestep = init_timestep.expand(batch_size)
|
||||
|
||||
# noise can be None if the latents have already been noised (e.g. when running the SDXL refiner).
|
||||
if noise is not None:
|
||||
# TODO(ryand): I'm pretty sure we should be applying init_noise_sigma in cases where we are starting with
|
||||
# full noise. Investigate the history of why this got commented out.
|
||||
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
|
||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||
latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
|
||||
|
||||
if mask is not None:
|
||||
if is_inpainting_model(self.unet):
|
||||
if masked_latents is None:
|
||||
raise Exception("Source image required for inpaint mask when inpaint model used!")
|
||||
|
||||
self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(
|
||||
self._unet_forward, mask, masked_latents
|
||||
)
|
||||
else:
|
||||
# if no noise provided, noisify unmasked area based on seed
|
||||
if noise is None:
|
||||
noise = torch.randn(
|
||||
orig_latents.shape,
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
generator=torch.Generator(device="cpu").manual_seed(seed),
|
||||
).to(device=orig_latents.device, dtype=orig_latents.dtype)
|
||||
|
||||
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
|
||||
|
||||
try:
|
||||
latents = self.generate_latents_from_embeddings(
|
||||
latents,
|
||||
timesteps,
|
||||
conditioning_data,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
additional_guidance=additional_guidance,
|
||||
control_data=control_data,
|
||||
ip_adapter_data=ip_adapter_data,
|
||||
t2i_adapter_data=t2i_adapter_data,
|
||||
callback=callback,
|
||||
)
|
||||
finally:
|
||||
self.invokeai_diffuser.model_forward_callback = self._unet_forward
|
||||
|
||||
# restore unmasked part after the last step is completed
|
||||
# in-process masking happens before each step
|
||||
if mask is not None:
|
||||
if gradient_mask:
|
||||
latents = torch.where(mask > 0, latents, orig_latents)
|
||||
else:
|
||||
latents = torch.lerp(
|
||||
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
|
||||
)
|
||||
|
||||
return latents
|
||||
|
||||
def generate_latents_from_embeddings(
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
timesteps,
|
||||
conditioning_data: TextConditioningData,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
*,
|
||||
additional_guidance: List[Callable] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
||||
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
) -> torch.Tensor:
|
||||
self._adjust_memory_efficient_attention(latents)
|
||||
if additional_guidance is None:
|
||||
additional_guidance = []
|
||||
|
||||
batch_size = latents.shape[0]
|
||||
# Handle mask guidance (a.k.a. inpainting).
|
||||
mask_guidance: AddsMaskGuidance | None = None
|
||||
if mask is not None and not is_inpainting_model(self.unet):
|
||||
# We are doing inpainting, since a mask is provided, but we are not using an inpainting model, so we will
|
||||
# apply mask guidance to the latents.
|
||||
|
||||
if timesteps.shape[0] == 0:
|
||||
return latents
|
||||
# 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
|
||||
# We still need noise for inpainting, so we generate it from the seed here.
|
||||
if noise is None:
|
||||
noise = torch.randn(
|
||||
orig_latents.shape,
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
generator=torch.Generator(device="cpu").manual_seed(seed),
|
||||
).to(device=orig_latents.device, dtype=orig_latents.dtype)
|
||||
|
||||
mask_guidance = AddsMaskGuidance(
|
||||
mask=mask,
|
||||
mask_latents=orig_latents,
|
||||
scheduler=self.scheduler,
|
||||
noise=noise,
|
||||
is_gradient_mask=is_gradient_mask,
|
||||
)
|
||||
|
||||
use_ip_adapter = ip_adapter_data is not None
|
||||
use_regional_prompting = (
|
||||
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
|
||||
)
|
||||
unet_attention_patcher = None
|
||||
self.use_ip_adapter = use_ip_adapter
|
||||
attn_ctx = nullcontext()
|
||||
|
||||
if use_ip_adapter or use_regional_prompting:
|
||||
@@ -402,28 +376,28 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
||||
|
||||
with attn_ctx:
|
||||
if callback is not None:
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
step=-1,
|
||||
order=self.scheduler.order,
|
||||
total_steps=len(timesteps),
|
||||
timestep=self.scheduler.config.num_train_timesteps,
|
||||
latents=latents,
|
||||
)
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
step=-1,
|
||||
order=self.scheduler.order,
|
||||
total_steps=len(timesteps),
|
||||
timestep=self.scheduler.config.num_train_timesteps,
|
||||
latents=latents,
|
||||
)
|
||||
)
|
||||
|
||||
# print("timesteps:", timesteps)
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
batched_t = t.expand(batch_size)
|
||||
step_output = self.step(
|
||||
batched_t,
|
||||
latents,
|
||||
conditioning_data,
|
||||
t=batched_t,
|
||||
latents=latents,
|
||||
conditioning_data=conditioning_data,
|
||||
step_index=i,
|
||||
total_step_count=len(timesteps),
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
additional_guidance=additional_guidance,
|
||||
mask_guidance=mask_guidance,
|
||||
mask=mask,
|
||||
masked_latents=masked_latents,
|
||||
control_data=control_data,
|
||||
ip_adapter_data=ip_adapter_data,
|
||||
t2i_adapter_data=t2i_adapter_data,
|
||||
@@ -431,19 +405,28 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
latents = step_output.prev_sample
|
||||
predicted_original = getattr(step_output, "pred_original_sample", None)
|
||||
|
||||
if callback is not None:
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
step=i,
|
||||
order=self.scheduler.order,
|
||||
total_steps=len(timesteps),
|
||||
timestep=int(t),
|
||||
latents=latents,
|
||||
predicted_original=predicted_original,
|
||||
)
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
step=i,
|
||||
order=self.scheduler.order,
|
||||
total_steps=len(timesteps),
|
||||
timestep=int(t),
|
||||
latents=latents,
|
||||
predicted_original=predicted_original,
|
||||
)
|
||||
)
|
||||
|
||||
return latents
|
||||
# restore unmasked part after the last step is completed
|
||||
# in-process masking happens before each step
|
||||
if mask is not None:
|
||||
if is_gradient_mask:
|
||||
latents = torch.where(mask > 0, latents, orig_latents)
|
||||
else:
|
||||
latents = torch.lerp(
|
||||
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
|
||||
)
|
||||
|
||||
return latents
|
||||
|
||||
@torch.inference_mode()
|
||||
def step(
|
||||
@@ -454,19 +437,20 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
step_index: int,
|
||||
total_step_count: int,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
additional_guidance: List[Callable] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
mask_guidance: AddsMaskGuidance | None,
|
||||
mask: torch.Tensor | None,
|
||||
masked_latents: torch.Tensor | None,
|
||||
control_data: list[ControlNetData] | None = None,
|
||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
||||
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||
):
|
||||
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
||||
timestep = t[0]
|
||||
if additional_guidance is None:
|
||||
additional_guidance = []
|
||||
|
||||
# one day we will expand this extension point, but for now it just does denoise masking
|
||||
for guidance in additional_guidance:
|
||||
latents = guidance(latents, timestep)
|
||||
# Handle masked image-to-image (a.k.a inpainting).
|
||||
if mask_guidance is not None:
|
||||
# NOTE: This is intentionally done *before* self.scheduler.scale_model_input(...).
|
||||
latents = mask_guidance(latents, timestep)
|
||||
|
||||
# TODO: should this scaling happen here or inside self._unet_forward?
|
||||
# i.e. before or after passing it to InvokeAIDiffuserComponent
|
||||
@@ -514,6 +498,31 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
|
||||
down_intrablock_additional_residuals = accum_adapter_state
|
||||
|
||||
# Handle inpainting models.
|
||||
if is_inpainting_model(self.unet):
|
||||
# NOTE: These calls to add_inpainting_channels_to_latents(...) are intentionally done *after*
|
||||
# self.scheduler.scale_model_input(...) so that the scaling is not applied to the mask or reference image
|
||||
# latents.
|
||||
if mask is not None:
|
||||
if masked_latents is None:
|
||||
raise ValueError("Source image required for inpaint mask when inpaint model used!")
|
||||
latent_model_input = self.add_inpainting_channels_to_latents(
|
||||
latents=latent_model_input, masked_ref_image_latents=masked_latents, inpainting_mask=mask
|
||||
)
|
||||
else:
|
||||
# We are using an inpainting model, but no mask was provided, so we are not really "inpainting".
|
||||
# We generate a global mask and empty original image so that we can still generate in this
|
||||
# configuration.
|
||||
# TODO(ryand): Should we just raise an exception here instead? I can't think of a use case for wanting
|
||||
# to do this.
|
||||
# TODO(ryand): If we decide that there is a good reason to keep this, then we should generate the 'fake'
|
||||
# mask and original image once rather than on every denoising step.
|
||||
latent_model_input = self.add_inpainting_channels_to_latents(
|
||||
latents=latent_model_input,
|
||||
masked_ref_image_latents=torch.zeros_like(latent_model_input[:1]),
|
||||
inpainting_mask=torch.ones_like(latent_model_input[:1, :1]),
|
||||
)
|
||||
|
||||
uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
|
||||
sample=latent_model_input,
|
||||
timestep=t, # TODO: debug how handled batched and non batched timesteps
|
||||
@@ -542,17 +551,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)
|
||||
|
||||
# TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting again.
|
||||
for guidance in additional_guidance:
|
||||
# apply the mask to any "denoised" or "pred_original_sample" fields
|
||||
# TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting
|
||||
# again.
|
||||
if mask_guidance is not None:
|
||||
# Apply the mask to any "denoised" or "pred_original_sample" fields.
|
||||
if hasattr(step_output, "denoised"):
|
||||
step_output.pred_original_sample = guidance(step_output.denoised, self.scheduler.timesteps[-1])
|
||||
step_output.pred_original_sample = mask_guidance(step_output.denoised, self.scheduler.timesteps[-1])
|
||||
elif hasattr(step_output, "pred_original_sample"):
|
||||
step_output.pred_original_sample = guidance(
|
||||
step_output.pred_original_sample = mask_guidance(
|
||||
step_output.pred_original_sample, self.scheduler.timesteps[-1]
|
||||
)
|
||||
else:
|
||||
step_output.pred_original_sample = guidance(latents, self.scheduler.timesteps[-1])
|
||||
step_output.pred_original_sample = mask_guidance(latents, self.scheduler.timesteps[-1])
|
||||
|
||||
return step_output
|
||||
|
||||
@@ -575,17 +585,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
**kwargs,
|
||||
):
|
||||
"""predict the noise residual"""
|
||||
if is_inpainting_model(self.unet) and latents.size(1) == 4:
|
||||
# Pad out normal non-inpainting inputs for an inpainting model.
|
||||
# FIXME: There are too many layers of functions and we have too many different ways of
|
||||
# overriding things! This should get handled in a way more consistent with the other
|
||||
# use of AddsMaskLatents.
|
||||
latents = AddsMaskLatents(
|
||||
self._unet_forward,
|
||||
mask=torch.ones_like(latents[:1, :1], device=latents.device, dtype=latents.dtype),
|
||||
initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype),
|
||||
).add_mask_channels(latents)
|
||||
|
||||
# First three args should be positional, not keywords, so torch hooks can see them.
|
||||
return self.unet(
|
||||
latents,
|
||||
|
||||
@@ -32,11 +32,8 @@ class SDXLConditioningInfo(BasicConditioningInfo):
|
||||
|
||||
def to(self, device, dtype=None):
|
||||
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
|
||||
assert self.pooled_embeds.device == device
|
||||
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
|
||||
result = super().to(device=device, dtype=dtype)
|
||||
assert self.embeds.device == device
|
||||
return result
|
||||
return super().to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import threading
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
@@ -294,31 +293,24 @@ class InvokeAIDiffuserComponent:
|
||||
cross_attention_kwargs["regional_ip_data"] = regional_ip_data
|
||||
|
||||
added_cond_kwargs = None
|
||||
try:
|
||||
if conditioning_data.is_sdxl():
|
||||
# tid = threading.current_thread().ident
|
||||
# print(f'DEBUG {tid} {conditioning_data.uncond_text.pooled_embeds.device=} {conditioning_data.cond_text.pooled_embeds.device=}', flush=True),
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": torch.cat(
|
||||
[
|
||||
# TODO: how to pad? just by zeros? or even truncate?
|
||||
conditioning_data.uncond_text.pooled_embeds,
|
||||
conditioning_data.cond_text.pooled_embeds,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
"time_ids": torch.cat(
|
||||
[
|
||||
conditioning_data.uncond_text.add_time_ids,
|
||||
conditioning_data.cond_text.add_time_ids,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
}
|
||||
except Exception as e:
|
||||
tid = threading.current_thread().ident
|
||||
print(f"DEBUG: {tid} {str(e)}")
|
||||
raise e
|
||||
if conditioning_data.is_sdxl():
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": torch.cat(
|
||||
[
|
||||
# TODO: how to pad? just by zeros? or even truncate?
|
||||
conditioning_data.uncond_text.pooled_embeds,
|
||||
conditioning_data.cond_text.pooled_embeds,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
"time_ids": torch.cat(
|
||||
[
|
||||
conditioning_data.uncond_text.add_time_ids,
|
||||
conditioning_data.cond_text.add_time_ids,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
}
|
||||
|
||||
if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
|
||||
# TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
|
||||
|
||||
242
invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py
Normal file
242
invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py
Normal file
@@ -0,0 +1,242 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
||||
ControlNetData,
|
||||
PipelineIntermediateState,
|
||||
StableDiffusionGeneratorPipeline,
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
|
||||
from invokeai.backend.tiles.utils import TBLR
|
||||
|
||||
# The maximum number of regions with compatible sizes that will be batched together.
|
||||
# Larger batch sizes improve speed, but require more device memory.
|
||||
MAX_REGION_BATCH_SIZE = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiDiffusionRegionConditioning:
|
||||
# Region coords in latent space.
|
||||
region: TBLR
|
||||
text_conditioning_data: TextConditioningData
|
||||
control_data: list[ControlNetData]
|
||||
|
||||
|
||||
class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
||||
"""A Stable Diffusion pipeline that uses Multi-Diffusion (https://arxiv.org/pdf/2302.08113) for denoising."""
|
||||
|
||||
def _split_into_region_batches(
|
||||
self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning]
|
||||
) -> list[list[MultiDiffusionRegionConditioning]]:
|
||||
# Group the regions by shape. Only regions with the same shape can be batched together.
|
||||
conditioning_by_shape: dict[tuple[int, int], list[MultiDiffusionRegionConditioning]] = {}
|
||||
for region_conditioning in multi_diffusion_conditioning:
|
||||
shape_hw = (
|
||||
region_conditioning.region.bottom - region_conditioning.region.top,
|
||||
region_conditioning.region.right - region_conditioning.region.left,
|
||||
)
|
||||
# In python, a tuple of hashable objects is hashable, so can be used as a key in a dict.
|
||||
if shape_hw not in conditioning_by_shape:
|
||||
conditioning_by_shape[shape_hw] = []
|
||||
conditioning_by_shape[shape_hw].append(region_conditioning)
|
||||
|
||||
# Split the regions into batches, respecting the MAX_REGION_BATCH_SIZE constraint.
|
||||
region_conditioning_batches = []
|
||||
for region_conditioning_batch in conditioning_by_shape.values():
|
||||
for i in range(0, len(region_conditioning_batch), MAX_REGION_BATCH_SIZE):
|
||||
region_conditioning_batches.append(region_conditioning_batch[i : i + MAX_REGION_BATCH_SIZE])
|
||||
|
||||
return region_conditioning_batches
|
||||
|
||||
def _check_regional_prompting(self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning]):
|
||||
"""Check the input conditioning and confirm that regional prompting is not used."""
|
||||
for region_conditioning in multi_diffusion_conditioning:
|
||||
if (
|
||||
region_conditioning.text_conditioning_data.cond_regions is not None
|
||||
or region_conditioning.text_conditioning_data.uncond_regions is not None
|
||||
):
|
||||
raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.")
|
||||
|
||||
def multi_diffusion_denoise(
|
||||
self,
|
||||
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning],
|
||||
latents: torch.Tensor,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
noise: Optional[torch.Tensor],
|
||||
timesteps: torch.Tensor,
|
||||
init_timestep: torch.Tensor,
|
||||
callback: Callable[[PipelineIntermediateState], None],
|
||||
) -> torch.Tensor:
|
||||
self._check_regional_prompting(multi_diffusion_conditioning)
|
||||
|
||||
# TODO(ryand): Figure out why this condition is necessary, and document it. My guess is that it's to handle
|
||||
# cases where densoisings_start and denoising_end are set such that there are no timesteps.
|
||||
if init_timestep.shape[0] == 0 or timesteps.shape[0] == 0:
|
||||
return latents
|
||||
|
||||
batch_size, _, latent_height, latent_width = latents.shape
|
||||
batched_init_timestep = init_timestep.expand(batch_size)
|
||||
|
||||
# noise can be None if the latents have already been noised (e.g. when running the SDXL refiner).
|
||||
if noise is not None:
|
||||
# TODO(ryand): I'm pretty sure we should be applying init_noise_sigma in cases where we are starting with
|
||||
# full noise. Investigate the history of why this got commented out.
|
||||
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
|
||||
latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
|
||||
|
||||
# TODO(ryand): Look into the implications of passing in latents here that are larger than they will be after
|
||||
# cropping into regions.
|
||||
self._adjust_memory_efficient_attention(latents)
|
||||
|
||||
# Populate a weighted mask that will be used to combine the results from each region after every step.
|
||||
# For now, we assume that each region has the same weight (1.0).
|
||||
region_weight_mask = torch.zeros(
|
||||
(1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype
|
||||
)
|
||||
for region_conditioning in multi_diffusion_conditioning:
|
||||
region = region_conditioning.region
|
||||
region_weight_mask[:, :, region.top : region.bottom, region.left : region.right] += 1.0
|
||||
|
||||
# Group the region conditioning into batches for faster processing.
|
||||
# region_conditioning_batches[b][r] is the r'th region in the b'th batch.
|
||||
region_conditioning_batches = self._split_into_region_batches(multi_diffusion_conditioning)
|
||||
|
||||
# Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since
|
||||
# we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a
|
||||
# separate scheduler state for each region batch.
|
||||
region_batch_schedulers: list[SchedulerMixin] = [
|
||||
copy.deepcopy(self.scheduler) for _ in region_conditioning_batches
|
||||
]
|
||||
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
step=-1,
|
||||
order=self.scheduler.order,
|
||||
total_steps=len(timesteps),
|
||||
timestep=self.scheduler.config.num_train_timesteps,
|
||||
latents=latents,
|
||||
)
|
||||
)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
batched_t = t.expand(batch_size)
|
||||
|
||||
merged_latents = torch.zeros_like(latents)
|
||||
merged_pred_original: torch.Tensor | None = None
|
||||
for region_batch_idx, region_conditioning_batch in enumerate(region_conditioning_batches):
|
||||
# Switch to the scheduler for the region batch.
|
||||
self.scheduler = region_batch_schedulers[region_batch_idx]
|
||||
|
||||
# TODO(ryand): This logic has not yet been tested with input latents with a batch_size > 1.
|
||||
|
||||
# Prepare the latents for the region batch.
|
||||
batch_latents = torch.cat(
|
||||
[
|
||||
latents[
|
||||
:,
|
||||
:,
|
||||
region_conditioning.region.top : region_conditioning.region.bottom,
|
||||
region_conditioning.region.left : region_conditioning.region.right,
|
||||
]
|
||||
for region_conditioning in region_conditioning_batch
|
||||
],
|
||||
)
|
||||
|
||||
# TODO(ryand): Do we have to repeat the text_conditioning_data to match the batch size? Or does step()
|
||||
# handle broadcasting properly?
|
||||
|
||||
# TODO(ryand): Resume here!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
||||
# Run the denoising step on the region.
|
||||
step_output = self.step(
|
||||
t=batched_t,
|
||||
latents=batch_latents,
|
||||
conditioning_data=region_conditioning.text_conditioning_data,
|
||||
step_index=i,
|
||||
total_step_count=total_step_count,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
mask_guidance=None,
|
||||
mask=None,
|
||||
masked_latents=None,
|
||||
control_data=region_conditioning.control_data,
|
||||
)
|
||||
# Run a denoising step on the region.
|
||||
# step_output = self._region_step(
|
||||
# region_conditioning=region_conditioning,
|
||||
# t=batched_t,
|
||||
# latents=latents,
|
||||
# step_index=i,
|
||||
# total_step_count=len(timesteps),
|
||||
# scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
# )
|
||||
|
||||
# Store the results from the region.
|
||||
region = region_conditioning.region
|
||||
merged_latents[:, :, region.top : region.bottom, region.left : region.right] += step_output.prev_sample
|
||||
pred_orig_sample = getattr(step_output, "pred_original_sample", None)
|
||||
if pred_orig_sample is not None:
|
||||
# If one region has pred_original_sample, then we can assume that all regions will have it, because
|
||||
# they all use the same scheduler.
|
||||
if merged_pred_original is None:
|
||||
merged_pred_original = torch.zeros_like(latents)
|
||||
merged_pred_original[:, :, region.top : region.bottom, region.left : region.right] += (
|
||||
pred_orig_sample
|
||||
)
|
||||
|
||||
# Normalize the merged results.
|
||||
latents = torch.where(region_weight_mask > 0, merged_latents / region_weight_mask, merged_latents)
|
||||
predicted_original = None
|
||||
if merged_pred_original is not None:
|
||||
predicted_original = torch.where(
|
||||
region_weight_mask > 0, merged_pred_original / region_weight_mask, merged_pred_original
|
||||
)
|
||||
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
step=i,
|
||||
order=self.scheduler.order,
|
||||
total_steps=len(timesteps),
|
||||
timestep=int(t),
|
||||
latents=latents,
|
||||
predicted_original=predicted_original,
|
||||
)
|
||||
)
|
||||
|
||||
return latents
|
||||
|
||||
@torch.inference_mode()
|
||||
def _region_batch_step(
|
||||
self,
|
||||
region_conditioning: MultiDiffusionRegionConditioning,
|
||||
t: torch.Tensor,
|
||||
latents: torch.Tensor,
|
||||
step_index: int,
|
||||
total_step_count: int,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
):
|
||||
# Crop the inputs to the region.
|
||||
region_latents = latents[
|
||||
:,
|
||||
:,
|
||||
region_conditioning.region.top : region_conditioning.region.bottom,
|
||||
region_conditioning.region.left : region_conditioning.region.right,
|
||||
]
|
||||
|
||||
# Run the denoising step on the region.
|
||||
return self.step(
|
||||
t=t,
|
||||
latents=region_latents,
|
||||
conditioning_data=region_conditioning.text_conditioning_data,
|
||||
step_index=step_index,
|
||||
total_step_count=total_step_count,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
mask_guidance=None,
|
||||
mask=None,
|
||||
masked_latents=None,
|
||||
control_data=region_conditioning.control_data,
|
||||
)
|
||||
@@ -1,16 +1,10 @@
|
||||
"""Torch Device class provides torch device selection services."""
|
||||
|
||||
from typing import TYPE_CHECKING, Dict, Literal, Optional, Set, Union
|
||||
from typing import Dict, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from deprecated import deprecated
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.backend.model_manager.config import AnyModel
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
|
||||
# legacy APIs
|
||||
TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
|
||||
CPU_DEVICE = torch.device("cpu")
|
||||
@@ -48,23 +42,9 @@ PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {v: k for k, v in NA
|
||||
class TorchDevice:
|
||||
"""Abstraction layer for torch devices."""
|
||||
|
||||
_model_cache: Optional["ModelCacheBase[AnyModel]"] = None
|
||||
|
||||
@classmethod
|
||||
def set_model_cache(cls, cache: "ModelCacheBase[AnyModel]"):
|
||||
"""Set the current model cache."""
|
||||
cls._model_cache = cache
|
||||
|
||||
@classmethod
|
||||
def choose_torch_device(cls) -> torch.device:
|
||||
"""Return the torch.device to use for accelerated inference."""
|
||||
if cls._model_cache:
|
||||
return cls._model_cache.get_execution_device()
|
||||
else:
|
||||
return cls._choose_device()
|
||||
|
||||
@classmethod
|
||||
def _choose_device(cls) -> torch.device:
|
||||
app_config = get_config()
|
||||
if app_config.device != "auto":
|
||||
device = torch.device(app_config.device)
|
||||
@@ -76,19 +56,11 @@ class TorchDevice:
|
||||
device = CPU_DEVICE
|
||||
return cls.normalize(device)
|
||||
|
||||
@classmethod
|
||||
def execution_devices(cls) -> Set[torch.device]:
|
||||
"""Return a list of torch.devices that can be used for accelerated inference."""
|
||||
app_config = get_config()
|
||||
if app_config.devices is None:
|
||||
return cls._lookup_execution_devices()
|
||||
return {torch.device(x) for x in app_config.devices}
|
||||
|
||||
@classmethod
|
||||
def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
|
||||
"""Return the precision to use for accelerated inference."""
|
||||
device = device or cls.choose_torch_device()
|
||||
config = get_config()
|
||||
device = device or cls._choose_device()
|
||||
if device.type == "cuda" and torch.cuda.is_available():
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
|
||||
@@ -136,13 +108,3 @@ class TorchDevice:
|
||||
@classmethod
|
||||
def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
|
||||
return NAME_TO_PRECISION[precision_name]
|
||||
|
||||
@classmethod
|
||||
def _lookup_execution_devices(cls) -> Set[torch.device]:
|
||||
if torch.cuda.is_available():
|
||||
devices = {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
|
||||
elif torch.backends.mps.is_available():
|
||||
devices = {torch.device("mps")}
|
||||
else:
|
||||
devices = {torch.device("cpu")}
|
||||
return devices
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
#!/bin/env python
|
||||
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig, get_config
|
||||
from invokeai.app.services.download import DownloadQueueService
|
||||
from invokeai.app.services.model_install import ModelInstallService
|
||||
from invokeai.app.services.model_records import ModelRecordServiceSQL
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
|
||||
def get_args() -> Namespace:
|
||||
parser = ArgumentParser(description="Update models database from yaml file")
|
||||
parser.add_argument("--root", type=Path, required=False, default=None)
|
||||
parser.add_argument("--yaml_file", type=Path, required=False, default=None)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def populate_config() -> InvokeAIAppConfig:
|
||||
args = get_args()
|
||||
config = get_config()
|
||||
if args.root:
|
||||
config._root = args.root
|
||||
if args.yaml_file:
|
||||
config.legacy_models_yaml_path = args.yaml_file
|
||||
else:
|
||||
config.legacy_models_yaml_path = config.root_path / "configs/models.yaml"
|
||||
return config
|
||||
|
||||
|
||||
def initialize_installer(config: InvokeAIAppConfig) -> ModelInstallService:
|
||||
logger = InvokeAILogger.get_logger(config=config)
|
||||
db = SqliteDatabase(config.db_path, logger)
|
||||
record_store = ModelRecordServiceSQL(db)
|
||||
queue = DownloadQueueService()
|
||||
queue.start()
|
||||
installer = ModelInstallService(app_config=config, record_store=record_store, download_queue=queue)
|
||||
return installer
|
||||
|
||||
|
||||
def main() -> None:
|
||||
config = populate_config()
|
||||
installer = initialize_installer(config)
|
||||
installer._migrate_yaml(rename_yaml=False, overwrite_db=True)
|
||||
print("\n<INSTALLED MODELS>")
|
||||
print("\t".join(["key", "name", "type", "path"]))
|
||||
for model in installer.record_store.all_models():
|
||||
print("\t".join([model.key, model.name, model.type, (config.models_path / model.path).as_posix()]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -14,14 +14,13 @@ def test_loading(mm2_model_manager: ModelManagerServiceBase, embedding_file: Pat
|
||||
matches = store.search_by_attr(model_name="test_embedding")
|
||||
assert len(matches) == 0
|
||||
key = mm2_model_manager.install.register_path(embedding_file)
|
||||
with mm2_model_manager.load.ram_cache.reserve_execution_device():
|
||||
loaded_model = mm2_model_manager.load.load_model(store.get_model(key))
|
||||
assert loaded_model is not None
|
||||
assert loaded_model.config.key == key
|
||||
with loaded_model as model:
|
||||
assert isinstance(model, TextualInversionModelRaw)
|
||||
loaded_model = mm2_model_manager.load.load_model(store.get_model(key))
|
||||
assert loaded_model is not None
|
||||
assert loaded_model.config.key == key
|
||||
with loaded_model as model:
|
||||
assert isinstance(model, TextualInversionModelRaw)
|
||||
|
||||
config = mm2_model_manager.store.get_model(key)
|
||||
loaded_model_2 = mm2_model_manager.load.load_model(config)
|
||||
config = mm2_model_manager.store.get_model(key)
|
||||
loaded_model_2 = mm2_model_manager.load.load_model(config)
|
||||
|
||||
assert loaded_model.config.key == loaded_model_2.config.key
|
||||
assert loaded_model.config.key == loaded_model_2.config.key
|
||||
|
||||
@@ -89,10 +89,11 @@ def mm2_download_queue(mm2_session: Session) -> DownloadQueueServiceBase:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_loader(mm2_app_config: InvokeAIAppConfig) -> ModelLoadServiceBase:
|
||||
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
|
||||
ram_cache = ModelCache(
|
||||
logger=InvokeAILogger.get_logger(),
|
||||
max_cache_size=mm2_app_config.ram,
|
||||
max_vram_cache_size=mm2_app_config.vram,
|
||||
)
|
||||
convert_cache = ModelConvertCache(mm2_app_config.convert_cache_path)
|
||||
return ModelLoadService(
|
||||
|
||||
@@ -8,9 +8,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.backend.model_manager.load import ModelCache
|
||||
from invokeai.backend.util.devices import TorchDevice, choose_precision, choose_torch_device, torch_dtype
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
|
||||
devices = ["cpu", "cuda:0", "cuda:1", "mps"]
|
||||
device_types_cpu = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float32)]
|
||||
@@ -22,7 +20,6 @@ device_types_mps = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", t
|
||||
def test_device_choice(device_name):
|
||||
config = get_config()
|
||||
config.device = device_name
|
||||
TorchDevice.set_model_cache(None) # disable dynamic selection of GPU device
|
||||
torch_device = TorchDevice.choose_torch_device()
|
||||
assert torch_device == torch.device(device_name)
|
||||
|
||||
@@ -133,32 +130,3 @@ def test_legacy_precision_name():
|
||||
assert "float16" == choose_precision(torch.device("cuda"))
|
||||
assert "float16" == choose_precision(torch.device("mps"))
|
||||
assert "float32" == choose_precision(torch.device("cpu"))
|
||||
|
||||
|
||||
def test_multi_device_support_1():
|
||||
config = get_config()
|
||||
config.devices = ["cuda:0", "cuda:1"]
|
||||
assert TorchDevice.execution_devices() == {torch.device("cuda:0"), torch.device("cuda:1")}
|
||||
|
||||
|
||||
def test_multi_device_support_2():
|
||||
config = get_config()
|
||||
config.devices = None
|
||||
with (
|
||||
patch("torch.cuda.device_count", return_value=3),
|
||||
patch("torch.cuda.is_available", return_value=True),
|
||||
):
|
||||
assert TorchDevice.execution_devices() == {
|
||||
torch.device("cuda:0"),
|
||||
torch.device("cuda:1"),
|
||||
torch.device("cuda:2"),
|
||||
}
|
||||
|
||||
|
||||
def test_multi_device_support_3():
|
||||
config = get_config()
|
||||
config.devices = ["cuda:0", "cuda:1"]
|
||||
cache = ModelCache()
|
||||
with cache.reserve_execution_device() as gpu:
|
||||
assert gpu in [torch.device(x) for x in config.devices]
|
||||
assert TorchDevice.choose_torch_device() == gpu
|
||||
|
||||
@@ -17,6 +17,7 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.images.images_default import ImageService
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa: F403
|
||||
@@ -48,13 +49,13 @@ def mock_services() -> InvocationServices:
|
||||
model_manager=None, # type: ignore
|
||||
download_queue=None, # type: ignore
|
||||
names=None, # type: ignore
|
||||
performance_statistics=InvocationStatsService(),
|
||||
session_processor=None, # type: ignore
|
||||
session_queue=None, # type: ignore
|
||||
urls=None, # type: ignore
|
||||
workflow_records=None, # type: ignore
|
||||
tensors=None, # type: ignore
|
||||
conditioning=None, # type: ignore
|
||||
performance_statistics=None, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -92,6 +92,7 @@ def test_migrate_v3_config_from_file(tmp_path: Path, patch_rootdir: None):
|
||||
assert config.host == "192.168.1.1"
|
||||
assert config.port == 8080
|
||||
assert config.ram == 100
|
||||
assert config.vram == 50
|
||||
assert config.legacy_models_yaml_path == Path("/custom/models.yaml")
|
||||
# This should be stripped out
|
||||
assert not hasattr(config, "esrgan")
|
||||
|
||||
Reference in New Issue
Block a user