Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-19 18:58:10 -05:00)
Compare commits: 72 commits in range v4.2.6...ryan/promp
Commits in this comparison (SHA1):
8c09b345ec, b7a1086325, 3062fe2752, 76a65d30cd, d0435575c1, f9c61f1b6c, a8cc5caf96, 930ff559e4,
473f4cc1c3, 78d2b1b650, 39e10d894c, e16faa6370, 83a86abce2, 0c56d4a581, 97a7f51721, 710dc6b487,
2ef3b49a79, 3f79467f7b, 2c2ec8f0bc, 79e35bd0d3, 137202b77c, 03e22c257b, ae6d4fbc78, cd1bc1595a,
0583101c1c, f866b49255, b7c6c63005, 95e9f5323b, 6b0ca88177, 7ad32dcad2, 81991e072b, cec345cb5c,
608cbe3f5c, 7905a46ca4, 38343917f8, 9f088d1bf5, fd8d1c12d4, d623bd429b, 499e4d4fde, e961dd1dec,
7e00526999, 3a9dda9177, bd8ae5d896, 87e96e1be2, 0bc60378d3, 9cc852cf7f, 0428ce73a9, d0d2955992,
d868d5d584, ab775726b7, 650902dc29, 7b5d4935b4, ecbff2aa44, 0ce6ec634d, d09999736c, a405f14ea2,
9d3739244f, 534528b85a, 114320ee69, 6161aa73af, 1ab20f43c8, 9328c17ded, c1c8e55e8e, 504a42fe61,
29c8ddfb88, 95079dc7d4, 2a1514272f, 59ce9cf41c, e6abea7bc5, c335f92345, c1afe35704, e4813f800a
@@ -1,5 +1,6 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
 import inspect
+import os
 from contextlib import ExitStack
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

@@ -39,6 +40,7 @@ from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager import BaseModelType
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
+from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
 from invokeai.backend.stable_diffusion.diffusers_pipeline import (
     ControlNetData,
     StableDiffusionGeneratorPipeline,
@@ -53,6 +55,11 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     TextConditioningData,
     TextConditioningRegions,
 )
+from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
+from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
+from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
+from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
 from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
 from invokeai.backend.util.devices import TorchDevice
@@ -314,9 +321,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
         context: InvocationContext,
         positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
         negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
-        unet: UNet2DConditionModel,
         latent_height: int,
         latent_width: int,
+        device: torch.device,
+        dtype: torch.dtype,
         cfg_scale: float | list[float],
         steps: int,
         cfg_rescale_multiplier: float,
@@ -330,10 +338,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
             uncond_list = [uncond_list]

         cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
-            cond_list, context, unet.device, unet.dtype
+            cond_list, context, device, dtype
         )
         uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
-            uncond_list, context, unet.device, unet.dtype
+            uncond_list, context, device, dtype
         )

         cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
@@ -341,14 +349,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
             masks=cond_text_embedding_masks,
             latent_height=latent_height,
             latent_width=latent_width,
-            dtype=unet.dtype,
+            dtype=dtype,
         )
         uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
             text_conditionings=uncond_text_embeddings,
             masks=uncond_text_embedding_masks,
             latent_height=latent_height,
             latent_width=latent_width,
-            dtype=unet.dtype,
+            dtype=dtype,
         )

         if isinstance(cfg_scale, list):
@@ -707,9 +715,108 @@ class DenoiseLatentsInvocation(BaseInvocation):

         return seed, noise, latents

+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        if os.environ.get("USE_MODULAR_DENOISE", False):
+            return self._new_invoke(context)
+        else:
+            return self._old_invoke(context)
+
     @torch.no_grad()
     @SilenceWarnings()  # This quenches the NSFW nag from diffusers.
-    def invoke(self, context: InvocationContext) -> LatentsOutput:
+    def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
+        ext_manager = ExtensionsManager(is_canceled=context.util.is_canceled)
+
+        device = TorchDevice.choose_torch_device()
+        dtype = TorchDevice.choose_torch_dtype()
+
+        seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
+        latents = latents.to(device=device, dtype=dtype)
+        if noise is not None:
+            noise = noise.to(device=device, dtype=dtype)
+
+        _, _, latent_height, latent_width = latents.shape
+
+        conditioning_data = self.get_conditioning_data(
+            context=context,
+            positive_conditioning_field=self.positive_conditioning,
+            negative_conditioning_field=self.negative_conditioning,
+            cfg_scale=self.cfg_scale,
+            steps=self.steps,
+            latent_height=latent_height,
+            latent_width=latent_width,
+            device=device,
+            dtype=dtype,
+            # TODO: old backend, remove
+            cfg_rescale_multiplier=self.cfg_rescale_multiplier,
+        )
+
+        scheduler = get_scheduler(
+            context=context,
+            scheduler_info=self.unet.scheduler,
+            scheduler_name=self.scheduler,
+            seed=seed,
+        )
+
+        timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
+            scheduler,
+            seed=seed,
+            device=device,
+            steps=self.steps,
+            denoising_start=self.denoising_start,
+            denoising_end=self.denoising_end,
+        )
+
+        denoise_ctx = DenoiseContext(
+            inputs=DenoiseInputs(
+                orig_latents=latents,
+                timesteps=timesteps,
+                init_timestep=init_timestep,
+                noise=noise,
+                seed=seed,
+                scheduler_step_kwargs=scheduler_step_kwargs,
+                conditioning_data=conditioning_data,
+                attention_processor_cls=CustomAttnProcessor2_0,
+            ),
+            unet=None,
+            scheduler=scheduler,
+        )
+
+        # get the unet's config so that we can pass the base to sd_step_callback()
+        unet_config = context.models.get_config(self.unet.unet.key)
+
+        ### preview
+        def step_callback(state: PipelineIntermediateState) -> None:
+            context.util.sd_step_callback(state, unet_config.base)
+
+        ext_manager.add_extension(PreviewExt(step_callback))
+
+        # ext: t2i/ip adapter
+        ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
+
+        unet_info = context.models.load(self.unet.unet)
+        assert isinstance(unet_info.model, UNet2DConditionModel)
+        with (
+            unet_info.model_on_device() as (model_state_dict, unet),
+            ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
+            # ext: controlnet
+            ext_manager.patch_extensions(unet),
+            # ext: freeu, seamless, ip adapter, lora
+            ext_manager.patch_unet(model_state_dict, unet),
+        ):
+            sd_backend = StableDiffusionBackend(unet, scheduler)
+            denoise_ctx.unet = unet
+            result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
+
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        result_latents = result_latents.detach().to("cpu")
+        TorchDevice.empty_cache()
+
+        name = context.tensors.save(tensor=result_latents)
+        return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
+
+    @torch.no_grad()
+    @SilenceWarnings()  # This quenches the NSFW nag from diffusers.
+    def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
         seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)

         mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
@@ -788,7 +895,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 context=context,
                 positive_conditioning_field=self.positive_conditioning,
                 negative_conditioning_field=self.negative_conditioning,
-                unet=unet,
+                device=unet.device,
+                dtype=unet.dtype,
                 latent_height=latent_height,
                 latent_width=latent_width,
                 cfg_scale=self.cfg_scale,
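The `_new_invoke` path above is opt-in: setting the `USE_MODULAR_DENOISE` environment variable routes `invoke()` through the new extensions-based backend, while the default path falls through to `_old_invoke`. A minimal sketch of the preview wiring this hunk introduces, using only names added in this branch (the callback body is illustrative):

```python
from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState, PreviewExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager

ext_manager = ExtensionsManager(is_canceled=lambda: False)

def on_step(state: PipelineIntermediateState) -> None:
    # Called once per denoising step with the in-progress latents.
    print(f"step {state.step}/{state.total_steps} (timestep {state.timestep})")

ext_manager.add_extension(PreviewExt(on_step))
```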
@@ -48,6 +48,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
     ControlNetModel = "ControlNetModelField"
     IPAdapterModel = "IPAdapterModelField"
     T2IAdapterModel = "T2IAdapterModelField"
+    SpandrelImageToImageModel = "SpandrelImageToImageModelField"
     # endregion

     # region Misc Field Types
@@ -134,6 +135,7 @@ class FieldDescriptions:
     sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
     sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
     onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
+    spandrel_image_to_image_model = "Image-to-Image model"
     lora_weight = "The weight at which the LoRA is applied to each model"
     compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
     raw_prompt = "Raw prompt text (no parsing)"
invokeai/app/invocations/spandrel_image_to_image.py (new file, 144 lines)
@@ -0,0 +1,144 @@
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    ImageField,
    InputField,
    UIType,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
from invokeai.backend.tiles.utils import TBLR, Tile


@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.1.0")
class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""

    image: ImageField = InputField(description="The input image")
    image_to_image_model: ModelIdentifierField = InputField(
        title="Image-to-Image Model",
        description=FieldDescriptions.spandrel_image_to_image_model,
        ui_type=UIType.SpandrelImageToImageModel,
    )
    tile_size: int = InputField(
        default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
    )

    def _scale_tile(self, tile: Tile, scale: int) -> Tile:
        return Tile(
            coords=TBLR(
                top=tile.coords.top * scale,
                bottom=tile.coords.bottom * scale,
                left=tile.coords.left * scale,
                right=tile.coords.right * scale,
            ),
            overlap=TBLR(
                top=tile.overlap.top * scale,
                bottom=tile.overlap.bottom * scale,
                left=tile.overlap.left * scale,
                right=tile.overlap.right * scale,
            ),
        )

    @torch.inference_mode()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        # Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
        # revisit this.
        image = context.images.get_pil(self.image.image_name, mode="RGB")

        # Compute the image tiles.
        if self.tile_size > 0:
            min_overlap = 20
            tiles = calc_tiles_min_overlap(
                image_height=image.height,
                image_width=image.width,
                tile_height=self.tile_size,
                tile_width=self.tile_size,
                min_overlap=min_overlap,
            )
        else:
            # No tiling. Generate a single tile that covers the entire image.
            min_overlap = 0
            tiles = [
                Tile(
                    coords=TBLR(top=0, bottom=image.height, left=0, right=image.width),
                    overlap=TBLR(top=0, bottom=0, left=0, right=0),
                )
            ]

        # Sort tiles first by left x coordinate, then by top y coordinate. During tile processing, we want to iterate
        # over tiles left-to-right, top-to-bottom.
        tiles = sorted(tiles, key=lambda x: x.coords.left)
        tiles = sorted(tiles, key=lambda x: x.coords.top)

        # Prepare input image for inference.
        image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)

        # Load the model.
        spandrel_model_info = context.models.load(self.image_to_image_model)

        # Run the model on each tile.
        with spandrel_model_info as spandrel_model:
            assert isinstance(spandrel_model, SpandrelImageToImageModel)

            # Scale the tiles for re-assembling the final image.
            scale = spandrel_model.scale
            scaled_tiles = [self._scale_tile(tile, scale=scale) for tile in tiles]

            # Prepare the output tensor.
            _, channels, height, width = image_tensor.shape
            output_tensor = torch.zeros(
                (height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
            )

            image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)

            for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
                # Exit early if the invocation has been canceled.
                if context.util.is_canceled():
                    raise CanceledException

                # Extract the current tile from the input tensor.
                input_tile = image_tensor[
                    :, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
                ].to(device=spandrel_model.device, dtype=spandrel_model.dtype)

                # Run the model on the tile.
                output_tile = spandrel_model.run(input_tile)

                # Convert the output tile into the output tensor's format.
                # (N, C, H, W) -> (C, H, W)
                output_tile = output_tile.squeeze(0)
                # (C, H, W) -> (H, W, C)
                output_tile = output_tile.permute(1, 2, 0)
                output_tile = output_tile.clamp(0, 1)
                output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu"))

                # Merge the output tile into the output tensor.
                # We only keep half of the overlap on the top and left side of the tile. We do this in case there are
                # edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers
                # it seems unnecessary, but we may find a need in the future.
                top_overlap = scaled_tile.overlap.top // 2
                left_overlap = scaled_tile.overlap.left // 2
                output_tensor[
                    scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom,
                    scaled_tile.coords.left + left_overlap : scaled_tile.coords.right,
                    :,
                ] = output_tile[top_overlap:, left_overlap:, :]

        # Convert the output tensor to a PIL image.
        np_image = output_tensor.detach().numpy().astype(np.uint8)
        pil_image = Image.fromarray(np_image)
        image_dto = context.images.save(image=pil_image)
        return ImageOutput.build(image_dto)
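Worked example of the tile bookkeeping above. `Tile` and `TBLR` come from `invokeai.backend.tiles.utils`; the coordinate values are illustrative:

```python
from invokeai.backend.tiles.utils import TBLR, Tile

# A 512x512 input tile whose left edge overlaps its neighbour by 20px.
tile = Tile(
    coords=TBLR(top=0, bottom=512, left=492, right=1004),
    overlap=TBLR(top=0, bottom=0, left=20, right=0),
)

# For a 4x model, every coordinate and overlap is multiplied by the scale,
# exactly as _scale_tile() does above.
scale = 4
scaled_left_overlap = tile.overlap.left * scale  # 80

# When pasting, only half of the top/left overlap is skipped, so the seam is
# covered by pixels from the interior of the neighbouring tile.
assert scaled_left_overlap // 2 == 40
```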
@@ -175,6 +175,10 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
         _, _, latent_height, latent_width = latents.shape

         # Calculate the tile locations to cover the latent-space image.
+        # TODO(ryand): In the future, we may want to revisit the tile overlap strategy. Things to consider:
+        # - How much overlap 'context' to provide for each denoising step.
+        # - How much overlap to use during merging/blending.
+        # - Should we 'jitter' the tile locations in each step so that the seams are in different places?
         tiles = calc_tiles_min_overlap(
             image_height=latent_height,
             image_width=latent_width,
@@ -218,7 +222,8 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
                 context=context,
                 positive_conditioning_field=self.positive_conditioning,
                 negative_conditioning_field=self.negative_conditioning,
-                unet=unet,
+                device=unet.device,
+                dtype=unet.dtype,
                 latent_height=latent_tile_height,
                 latent_width=latent_tile_width,
                 cfg_scale=self.cfg_scale,
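For reference, a sketch of the tiling call used above, with illustrative numbers (the signature matches the `calc_tiles_min_overlap` usage in both hunks):

```python
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap

# Tile a 128x128 latent image (a 1024px image at 8x latent compression).
tiles = calc_tiles_min_overlap(
    image_height=128,
    image_width=128,
    tile_height=96,
    tile_width=96,
    min_overlap=8,
)
for t in tiles:
    print(t.coords.top, t.coords.left, t.coords.bottom, t.coords.right)
```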
@@ -124,16 +124,14 @@ class IPAdapter(RawModel):
             self.device, dtype=self.dtype
         )

-    def to(
-        self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, non_blocking: bool = False
-    ):
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
         if device is not None:
             self.device = device
         if dtype is not None:
             self.dtype = dtype

-        self._image_proj_model.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
-        self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
+        self._image_proj_model.to(device=self.device, dtype=self.dtype)
+        self.attn_weights.to(device=self.device, dtype=self.dtype)

     def calc_size(self) -> int:
         # HACK(ryand): Fix this issue with circular imports.
@@ -11,7 +11,6 @@ from typing_extensions import Self

 from invokeai.backend.model_manager import BaseModelType
 from invokeai.backend.raw_model import RawModel
-from invokeai.backend.util.devices import TorchDevice


 class LoRALayerBase:
@@ -57,14 +56,9 @@ class LoRALayerBase:
             model_size += val.nelement() * val.element_size()
         return model_size

-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         if self.bias is not None:
-            self.bias = self.bias.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.bias = self.bias.to(device=device, dtype=dtype)


 # TODO: find and debug lora/locon with bias
@@ -106,19 +100,14 @@ class LoRALayer(LoRALayerBase):
             model_size += val.nelement() * val.element_size()
         return model_size

-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
-        super().to(device=device, dtype=dtype, non_blocking=non_blocking)
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
+        super().to(device=device, dtype=dtype)

-        self.up = self.up.to(device=device, dtype=dtype, non_blocking=non_blocking)
-        self.down = self.down.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.up = self.up.to(device=device, dtype=dtype)
+        self.down = self.down.to(device=device, dtype=dtype)

         if self.mid is not None:
-            self.mid = self.mid.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.mid = self.mid.to(device=device, dtype=dtype)


 class LoHALayer(LoRALayerBase):
@@ -167,23 +156,18 @@ class LoHALayer(LoRALayerBase):
             model_size += val.nelement() * val.element_size()
         return model_size

-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         super().to(device=device, dtype=dtype)

-        self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
-        self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.w1_a = self.w1_a.to(device=device, dtype=dtype)
+        self.w1_b = self.w1_b.to(device=device, dtype=dtype)
         if self.t1 is not None:
-            self.t1 = self.t1.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.t1 = self.t1.to(device=device, dtype=dtype)

-        self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
-        self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.w2_a = self.w2_a.to(device=device, dtype=dtype)
+        self.w2_b = self.w2_b.to(device=device, dtype=dtype)
         if self.t2 is not None:
-            self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.t2 = self.t2.to(device=device, dtype=dtype)


 class LoKRLayer(LoRALayerBase):
@@ -264,12 +248,7 @@ class LoKRLayer(LoRALayerBase):
             model_size += val.nelement() * val.element_size()
         return model_size

-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         super().to(device=device, dtype=dtype)

         if self.w1 is not None:
@@ -277,19 +256,19 @@ class LoKRLayer(LoRALayerBase):
         else:
             assert self.w1_a is not None
             assert self.w1_b is not None
-            self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
-            self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.w1_a = self.w1_a.to(device=device, dtype=dtype)
+            self.w1_b = self.w1_b.to(device=device, dtype=dtype)

         if self.w2 is not None:
-            self.w2 = self.w2.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.w2 = self.w2.to(device=device, dtype=dtype)
         else:
             assert self.w2_a is not None
             assert self.w2_b is not None
-            self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
-            self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.w2_a = self.w2_a.to(device=device, dtype=dtype)
+            self.w2_b = self.w2_b.to(device=device, dtype=dtype)

         if self.t2 is not None:
-            self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.t2 = self.t2.to(device=device, dtype=dtype)


 class FullLayer(LoRALayerBase):
@@ -319,15 +298,10 @@ class FullLayer(LoRALayerBase):
         model_size += self.weight.nelement() * self.weight.element_size()
         return model_size

-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         super().to(device=device, dtype=dtype)

-        self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.weight = self.weight.to(device=device, dtype=dtype)


 class IA3Layer(LoRALayerBase):
@@ -359,16 +333,11 @@ class IA3Layer(LoRALayerBase):
         model_size += self.on_input.nelement() * self.on_input.element_size()
         return model_size

-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ):
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
         super().to(device=device, dtype=dtype)

-        self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
-        self.on_input = self.on_input.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.weight = self.weight.to(device=device, dtype=dtype)
+        self.on_input = self.on_input.to(device=device, dtype=dtype)


 AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
@@ -390,15 +359,10 @@ class LoRAModelRaw(RawModel):  # (torch.nn.Module):
     def name(self) -> str:
         return self._name

-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         # TODO: try revert if exception?
         for _key, layer in self.layers.items():
-            layer.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            layer.to(device=device, dtype=dtype)

     def calc_size(self) -> int:
         model_size = 0
@@ -521,7 +485,7 @@ class LoRAModelRaw(RawModel):  # (torch.nn.Module):
             # lower memory consumption by removing already parsed layer values
             state_dict[layer_key].clear()

-            layer.to(device=device, dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))
+            layer.to(device=device, dtype=dtype)
             model.layers[layer_key] = layer

         return model
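Every `to(...)` override in this file loses its `non_blocking` parameter, so device moves are now plain synchronous copies. A sketch of the resulting call sites (the helper names here are illustrative, not part of the diff):

```python
import torch

def load_lora_for_inference(lora_model) -> None:  # lora_model: LoRAModelRaw
    # Move every layer to the execution device; no non_blocking knob remains.
    lora_model.to(device=torch.device("cuda"), dtype=torch.float16)

def offload_lora(lora_model) -> None:
    lora_model.to(device=torch.device("cpu"))
```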
@@ -67,6 +67,7 @@ class ModelType(str, Enum):
     IPAdapter = "ip_adapter"
     CLIPVision = "clip_vision"
     T2IAdapter = "t2i_adapter"
+    SpandrelImageToImage = "spandrel_image_to_image"


 class SubModelType(str, Enum):
@@ -371,6 +372,17 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase):
         return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")


+class SpandrelImageToImageConfig(ModelConfigBase):
+    """Model config for Spandrel Image to Image models."""
+
+    type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
+    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
+
+    @staticmethod
+    def get_tag() -> Tag:
+        return Tag(f"{ModelType.SpandrelImageToImage.value}.{ModelFormat.Checkpoint.value}")
+
+
 def get_model_discriminator_value(v: Any) -> str:
     """
     Computes the discriminator value for a model config.
@@ -407,6 +419,7 @@ AnyModelConfig = Annotated[
         Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
         Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
         Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
+        Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
         Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
     ],
     Discriminator(get_model_discriminator_value),
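The new config slots into the existing tagged union: the discriminator string is `"<type>.<format>"`, so a Spandrel checkpoint resolves to the tag computed below (enum values as defined in this branch):

```python
# ModelType.SpandrelImageToImage.value and ModelFormat.Checkpoint.value
type_value = "spandrel_image_to_image"
format_value = "checkpoint"

# Mirrors SpandrelImageToImageConfig.get_tag() above.
assert f"{type_value}.{format_value}" == "spandrel_image_to_image.checkpoint"
```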
@@ -167,7 +167,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
         size = calc_model_size_by_data(self.logger, model)
         self.make_room(size)

-        state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
+        running_on_cpu = self.execution_device == torch.device("cpu")
+        state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
         cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
         self._cached_models[key] = cache_record
         self._cache_stack.append(key)
@@ -289,11 +290,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
             else:
                 new_dict: Dict[str, torch.Tensor] = {}
                 for k, v in cache_entry.state_dict.items():
-                    new_dict[k] = v.to(
-                        target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device)
-                    )
+                    new_dict[k] = v.to(target_device, copy=True)
                 cache_entry.model.load_state_dict(new_dict, assign=True)
-            cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device))
+            cache_entry.model.to(target_device)
             cache_entry.device = target_device
         except Exception as e:  # blow away cache entry
             self._delete_cache_entry(cache_entry)
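The first hunk's guard, in isolation: when execution happens on the CPU, the storage and execution devices coincide, so caching a second CPU copy of the state dict would only waste RAM. A minimal sketch (the helper name is illustrative):

```python
from typing import Dict, Optional

import torch

def capture_state_dict(model, execution_device: torch.device) -> Optional[Dict[str, torch.Tensor]]:
    running_on_cpu = execution_device == torch.device("cpu")
    if isinstance(model, torch.nn.Module) and not running_on_cpu:
        return model.state_dict()
    return None
```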
@@ -0,0 +1,45 @@ (new file)
from pathlib import Path
from typing import Optional

import torch

from invokeai.backend.model_manager.config import (
    AnyModel,
    AnyModelConfig,
    BaseModelType,
    ModelFormat,
    ModelType,
    SubModelType,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel


@ModelLoaderRegistry.register(
    base=BaseModelType.Any, type=ModelType.SpandrelImageToImage, format=ModelFormat.Checkpoint
)
class SpandrelImageToImageModelLoader(ModelLoader):
    """Class for loading Spandrel Image-to-Image models (i.e. models wrapped by spandrel.ImageModelDescriptor)."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        if submodel_type is not None:
            raise ValueError("Unexpected submodel requested for Spandrel model.")

        model_path = Path(config.path)
        model = SpandrelImageToImageModel.load_from_file(model_path)

        torch_dtype = self._torch_dtype
        if not model.supports_dtype(torch_dtype):
            self._logger.warning(
                f"The configured dtype ('{self._torch_dtype}') is not supported by the {model.get_model_type_name()} "
                "model. Falling back to 'float32'."
            )
            torch_dtype = torch.float32
        model.to(dtype=torch_dtype)

        return model
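A sketch of the dtype fallback the loader performs, using the `supports_dtype()` API introduced on `SpandrelImageToImageModel` in this branch (the checkpoint path is illustrative):

```python
import torch
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel

model = SpandrelImageToImageModel.load_from_file("RealESRGAN_x4plus.pth")
dtype = torch.float16
if not model.supports_dtype(dtype):
    dtype = torch.float32  # every spandrel model supports float32
model.to(dtype=dtype)
```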
@@ -15,6 +15,7 @@ from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
 from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager.config import AnyModel
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
+from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
 from invokeai.backend.textual_inversion import TextualInversionModelRaw


@@ -33,7 +34,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
     elif isinstance(model, CLIPTokenizer):
         # TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
         return 0
-    elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw)):
+    elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw, SpandrelImageToImageModel)):
         return model.calc_size()
     else:
         # TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
@@ -4,6 +4,7 @@ from pathlib import Path
 from typing import Any, Dict, Literal, Optional, Union

 import safetensors.torch
+import spandrel
 import torch
 from picklescan.scanner import scan_file_path

@@ -25,6 +26,7 @@ from invokeai.backend.model_manager.config import (
     SchedulerPredictionType,
 )
 from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
+from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
 from invokeai.backend.util.silence_warnings import SilenceWarnings

 CkptType = Dict[str | int, Any]
@@ -220,24 +222,46 @@ class ModelProbe(object):
         ckpt = ckpt.get("state_dict", ckpt)

         for key in [str(k) for k in ckpt.keys()]:
-            if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
+            if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.")):
                 return ModelType.Main
-            elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
+            elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
                 return ModelType.VAE
-            elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
+            elif key.startswith(("lora_te_", "lora_unet_")):
                 return ModelType.LoRA
-            elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
+            elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight")):
                 return ModelType.LoRA
-            elif any(key.startswith(v) for v in {"controlnet", "control_model", "input_blocks"}):
+            elif key.startswith(("controlnet", "control_model", "input_blocks")):
                 return ModelType.ControlNet
-            elif any(key.startswith(v) for v in {"image_proj.", "ip_adapter."}):
+            elif key.startswith(("image_proj.", "ip_adapter.")):
                 return ModelType.IPAdapter
             elif key in {"emb_params", "string_to_param"}:
                 return ModelType.TextualInversion
-            else:
-                # diffusers-ti
-                if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
-                    return ModelType.TextualInversion
+
+        # diffusers-ti
+        if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
+            return ModelType.TextualInversion
+
+        # Check if the model can be loaded as a SpandrelImageToImageModel.
+        # This check is intentionally performed last, as it can be expensive (it requires loading the model from disk).
+        try:
+            # It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
+            # explored to avoid this:
+            # 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta
+            #    device. Unfortunately, some Spandrel models perform operations during initialization that are not
+            #    supported on meta tensors.
+            # 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model.
+            #    This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
+            #    maintain it, and the risk of false positive detections is higher.
+            SpandrelImageToImageModel.load_from_file(model_path)
+            return ModelType.SpandrelImageToImage
+        except spandrel.UnsupportedModelError:
+            pass
+        except RuntimeError as e:
+            if "No such file or directory" in str(e):
+                # This error is expected if the model_path does not exist (which is the case in some unit tests).
+                pass
+            else:
+                raise e

         raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")
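The refactor in the hunk above leans on `str.startswith`/`str.endswith` accepting a tuple of affixes, which short-circuits exactly like the old `any(...)` generator:

```python
key = "lora_unet_down_blocks_0"
assert key.startswith(("lora_te_", "lora_unet_"))                  # new form
assert any(key.startswith(v) for v in {"lora_te_", "lora_unet_"})  # old form
```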
@@ -569,6 +593,11 @@ class T2IAdapterCheckpointProbe(CheckpointProbeBase):
         raise NotImplementedError()


+class SpandrelImageToImageCheckpointProbe(CheckpointProbeBase):
+    def get_base_type(self) -> BaseModelType:
+        return BaseModelType.Any
+
+
 ########################################################
 # classes for probing folders
 #######################################################
@@ -776,6 +805,11 @@ class CLIPVisionFolderProbe(FolderProbeBase):
         return BaseModelType.Any


+class SpandrelImageToImageFolderProbe(FolderProbeBase):
+    def get_base_type(self) -> BaseModelType:
+        raise NotImplementedError()
+
+
 class T2IAdapterFolderProbe(FolderProbeBase):
     def get_base_type(self) -> BaseModelType:
         config_file = self.model_path / "config.json"
@@ -805,6 +839,7 @@ ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
 ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
 ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
 ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
+ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)

 ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
@@ -814,5 +849,6 @@ ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
+ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)

 ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
@@ -340,6 +340,13 @@ STARTER_MODELS: list[StarterModel] = [
         description="Controlnet weights trained on sdxl-1.0 compatible with various lineart processors and black/white sketches by Xinsir.",
         type=ModelType.ControlNet,
     ),
+    StarterModel(
+        name="tile-sdxl",
+        base=BaseModelType.StableDiffusionXL,
+        source="xinsir/controlnet-tile-sdxl-1.0",
+        description="Controlnet weights trained on sdxl-1.0 with tiled image conditioning",
+        type=ModelType.ControlNet,
+    ),
     # endregion
     # region T2I Adapter
     StarterModel(
@@ -399,6 +406,43 @@ STARTER_MODELS: list[StarterModel] = [
         type=ModelType.T2IAdapter,
     ),
     # endregion
+    # region SpandrelImageToImage
+    StarterModel(
+        name="RealESRGAN_x4plus_anime_6B",
+        base=BaseModelType.Any,
+        source="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
+        description="A Real-ESRGAN 4x upscaling model (optimized for anime images).",
+        type=ModelType.SpandrelImageToImage,
+    ),
+    StarterModel(
+        name="RealESRGAN_x4plus",
+        base=BaseModelType.Any,
+        source="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
+        description="A Real-ESRGAN 4x upscaling model (general-purpose).",
+        type=ModelType.SpandrelImageToImage,
+    ),
+    StarterModel(
+        name="ESRGAN_SRx4_DF2KOST_official",
+        base=BaseModelType.Any,
+        source="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
+        description="The official ESRGAN 4x upscaling model.",
+        type=ModelType.SpandrelImageToImage,
+    ),
+    StarterModel(
+        name="RealESRGAN_x2plus",
+        base=BaseModelType.Any,
+        source="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
+        description="A Real-ESRGAN 2x upscaling model (general-purpose).",
+        type=ModelType.SpandrelImageToImage,
+    ),
+    StarterModel(
+        name="SwinIR - realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN",
+        base=BaseModelType.Any,
+        source="https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN-with-dict-keys-params-and-params_ema.pth",
+        description="A SwinIR 4x upscaling model.",
+        type=ModelType.SpandrelImageToImage,
+    ),
+    # endregion
 ]

 assert len(STARTER_MODELS) == len({m.source for m in STARTER_MODELS}), "Duplicate starter models"
@@ -5,7 +5,7 @@ from __future__ import annotations

 import pickle
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union

 import numpy as np
 import torch
@@ -32,8 +32,27 @@ with LoRAHelper.apply_lora_unet(unet, loras):
 """


 # TODO: rename smth like ModelPatcher and add TI method?
 class ModelPatcher:
+    @staticmethod
+    @contextmanager
+    def patch_unet_attention_processor(unet: UNet2DConditionModel, processor_cls: Type[Any]):
+        """A context manager that patches `unet` with the provided attention processor.
+
+        Args:
+            unet (UNet2DConditionModel): The UNet model to patch.
+            processor (Type[Any]): Class which will be initialized for each key and passed to set_attn_processor(...).
+        """
+        unet_orig_processors = unet.attn_processors
+
+        # create separate instance for each attention, to be able modify each attention separately
+        unet_new_processors = {key: processor_cls() for key in unet_orig_processors.keys()}
+        try:
+            unet.set_attn_processor(unet_new_processors)
+            yield None
+
+        finally:
+            unet.set_attn_processor(unet_orig_processors)
+
     @staticmethod
     def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
         assert "." not in lora_key
@@ -139,15 +158,12 @@ class ModelPatcher:
                     # We intentionally move to the target device first, then cast. Experimentally, this was found to
                     # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
                     # same thing in a single call to '.to(...)'.
-                    layer.to(device=device, non_blocking=TorchDevice.get_non_blocking(device))
-                    layer.to(dtype=torch.float32, non_blocking=TorchDevice.get_non_blocking(device))
+                    layer.to(device=device)
+                    layer.to(dtype=torch.float32)
                     # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
                     # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
                     layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
-                    layer.to(
-                        device=TorchDevice.CPU_DEVICE,
-                        non_blocking=TorchDevice.get_non_blocking(TorchDevice.CPU_DEVICE),
-                    )
+                    layer.to(device=TorchDevice.CPU_DEVICE)

                     assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
                     if module.weight.shape != layer_weight.shape:
@@ -156,7 +172,7 @@ class ModelPatcher:
                         layer_weight = layer_weight.reshape(module.weight.shape)

                     assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
-                    module.weight += layer_weight.to(dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))
+                    module.weight += layer_weight.to(dtype=dtype)

             yield  # wait for context manager exit

@@ -164,9 +180,7 @@ class ModelPatcher:
             assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
             with torch.no_grad():
                 for module_key, weight in original_weights.items():
-                    model.get_submodule(module_key).weight.copy_(
-                        weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
-                    )
+                    model.get_submodule(module_key).weight.copy_(weight)

     @classmethod
     @contextmanager
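Usage sketch for the new context manager; `unet` is assumed to be an already loaded `UNet2DConditionModel`. The original processors are restored on exit even if the body raises:

```python
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0

with ModelPatcher.patch_unet_attention_processor(unet, CustomAttnProcessor2_0):
    # Every attention layer now has its own processor instance, so per-layer
    # state (e.g. regional masks) can be attached independently.
    ...
```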
@@ -190,12 +190,7 @@ class IAIOnnxRuntimeModel(RawModel):
         return self.session.run(None, inputs)

     # compatability with RawModel ABC
-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         pass

     # compatability with diffusers load code
@@ -1,15 +1,3 @@
-"""Base class for 'Raw' models.
-
-The RawModel class is the base class of LoRAModelRaw and TextualInversionModelRaw,
-and is used for type checking of calls to the model patcher. Its main purpose
-is to avoid a circular import issues when lora.py tries to import BaseModelType
-from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
-from lora.py.
-
-The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
-that adds additional methods and attributes.
-"""
-
 from abc import ABC, abstractmethod
 from typing import Optional

@@ -17,13 +5,18 @@
 import torch


 class RawModel(ABC):
-    """Abstract base class for 'Raw' model wrappers."""
+    """Base class for 'Raw' models.
+
+    The RawModel class is the base class of LoRAModelRaw, TextualInversionModelRaw, etc.
+    and is used for type checking of calls to the model patcher. Its main purpose
+    is to avoid a circular import issues when lora.py tries to import BaseModelType
+    from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
+    from lora.py.
+
+    The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
+    that adds additional methods and attributes.
+    """

     @abstractmethod
-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         pass
invokeai/backend/spandrel_image_to_image_model.py (new file, 139 lines)
@@ -0,0 +1,139 @@
from pathlib import Path
from typing import Any, Optional

import numpy as np
import torch
from PIL import Image
from spandrel import ImageModelDescriptor, ModelLoader

from invokeai.backend.raw_model import RawModel


class SpandrelImageToImageModel(RawModel):
    """A wrapper for a Spandrel Image-to-Image model.

    The main reason for having a wrapper class is to integrate with the type handling of RawModel.
    """

    def __init__(self, spandrel_model: ImageModelDescriptor[Any]):
        self._spandrel_model = spandrel_model

    @staticmethod
    def pil_to_tensor(image: Image.Image) -> torch.Tensor:
        """Convert PIL Image to the torch.Tensor format expected by SpandrelImageToImageModel.run().

        Args:
            image (Image.Image): A PIL Image with shape (H, W, C) and values in the range [0, 255].

        Returns:
            torch.Tensor: A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
        """
        image_np = np.array(image)
        # (H, W, C) -> (C, H, W)
        image_np = np.transpose(image_np, (2, 0, 1))
        image_np = image_np / 255
        image_tensor = torch.from_numpy(image_np).float()
        # (C, H, W) -> (N, C, H, W)
        image_tensor = image_tensor.unsqueeze(0)
        return image_tensor

    @staticmethod
    def tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
        """Convert a torch.Tensor produced by SpandrelImageToImageModel.run() to a PIL Image.

        Args:
            tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].

        Returns:
            Image.Image: A PIL Image with shape (H, W, C) and values in the range [0, 255].
        """
        # (N, C, H, W) -> (C, H, W)
        tensor = tensor.squeeze(0)
        # (C, H, W) -> (H, W, C)
        tensor = tensor.permute(1, 2, 0)
        tensor = tensor.clamp(0, 1)
        tensor = (tensor * 255).cpu().detach().numpy().astype(np.uint8)
        image = Image.fromarray(tensor)
        return image

    def run(self, image_tensor: torch.Tensor) -> torch.Tensor:
        """Run the image-to-image model.

        Args:
            image_tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
        """
        return self._spandrel_model(image_tensor)

    @classmethod
    def load_from_file(cls, file_path: str | Path):
        model = ModelLoader().load_from_file(file_path)
        if not isinstance(model, ImageModelDescriptor):
            raise ValueError(
                f"Loaded a spandrel model of type '{type(model)}'. Only image-to-image models are supported "
                "('ImageModelDescriptor')."
            )

        return cls(spandrel_model=model)

    @classmethod
    def load_from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
        model = ModelLoader().load_from_state_dict(state_dict)
        if not isinstance(model, ImageModelDescriptor):
            raise ValueError(
                f"Loaded a spandrel model of type '{type(model)}'. Only image-to-image models are supported "
                "('ImageModelDescriptor')."
            )

        return cls(spandrel_model=model)

    def supports_dtype(self, dtype: torch.dtype) -> bool:
        """Check if the model supports the given dtype."""
        if dtype == torch.float16:
            return self._spandrel_model.supports_half
        elif dtype == torch.bfloat16:
            return self._spandrel_model.supports_bfloat16
        elif dtype == torch.float32:
            # All models support float32.
            return True
        else:
            raise ValueError(f"Unexpected dtype '{dtype}'.")

    def get_model_type_name(self) -> str:
        """The model type name. Intended for logging / debugging purposes. Do not rely on this field remaining
        consistent over time.
        """
        return str(type(self._spandrel_model.model))

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
    ) -> None:
        """Note: Some models have limited dtype support. Call supports_dtype(...) to check if the dtype is supported.
        Note: The non_blocking parameter is currently ignored."""
        # TODO(ryand): spandrel.ImageModelDescriptor.to(...) does not support non_blocking. We will have to access the
        # model directly if we want to apply this optimization.
        self._spandrel_model.to(device=device, dtype=dtype)

    @property
    def device(self) -> torch.device:
        """The device of the underlying model."""
        return self._spandrel_model.device

    @property
    def dtype(self) -> torch.dtype:
        """The dtype of the underlying model."""
        return self._spandrel_model.dtype

    @property
    def scale(self) -> int:
        """The scale of the model (e.g. 1x, 2x, 4x, etc.)."""
        return self._spandrel_model.scale

    def calc_size(self) -> int:
        """Get size of the model in memory in bytes."""
        # HACK(ryand): Fix this issue with circular imports.
        from invokeai.backend.model_manager.load.model_util import calc_module_size

        return calc_module_size(self._spandrel_model.model)
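A round-trip sketch using the helpers defined above (file names are illustrative):

```python
import torch
from PIL import Image
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel

model = SpandrelImageToImageModel.load_from_file("RealESRGAN_x2plus.pth")
image = Image.open("input.png").convert("RGB")

with torch.inference_mode():
    tensor = SpandrelImageToImageModel.pil_to_tensor(image)   # (1, C, H, W), values in [0, 1]
    tensor = tensor.to(device=model.device, dtype=model.dtype)
    output = model.run(tensor)                                # (1, C, H*scale, W*scale)

SpandrelImageToImageModel.tensor_to_pil(output).save("output.png")
```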
131
invokeai/backend/stable_diffusion/denoise_context.py
Normal file
131
invokeai/backend/stable_diffusion/denoise_context.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode, TextConditioningData
|
||||
|
||||
|
||||
@dataclass
|
||||
class UNetKwargs:
|
||||
    sample: torch.Tensor
    timestep: Union[torch.Tensor, float, int]
    encoder_hidden_states: torch.Tensor

    class_labels: Optional[torch.Tensor] = None
    timestep_cond: Optional[torch.Tensor] = None
    attention_mask: Optional[torch.Tensor] = None
    cross_attention_kwargs: Optional[Dict[str, Any]] = None
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None
    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None
    mid_block_additional_residual: Optional[torch.Tensor] = None
    down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None
    encoder_attention_mask: Optional[torch.Tensor] = None
    # return_dict: bool = True


@dataclass
class DenoiseInputs:
    """Initial variables passed to denoise. Supposed to be unchanged."""

    # The latent-space image to denoise.
    # Shape: [batch, channels, latent_height, latent_width]
    # - If we are inpainting, this is the initial latent image before noise has been added.
    # - If we are generating a new image, this should be initialized to zeros.
    # - In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner).
    orig_latents: torch.Tensor

    # kwargs forwarded to the scheduler.step() method.
    scheduler_step_kwargs: dict[str, Any]

    # Text conditioning data.
    conditioning_data: TextConditioningData

    # Noise used for two purposes:
    # 1. Used by the scheduler to noise the initial `latents` before denoising.
    # 2. Used to noise the `masked_latents` when inpainting.
    # `noise` should be None if the `latents` tensor has already been noised.
    # Shape: [1 or batch, channels, latent_height, latent_width]
    noise: Optional[torch.Tensor]

    # The seed used to generate the noise for the denoising process.
    # HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
    # same noise used earlier in the pipeline. This should really be handled in a clearer way.
    seed: int

    # The timestep schedule for the denoising process.
    timesteps: torch.Tensor

    # The first timestep in the schedule. This is used to determine the initial noise level, so it
    # should be populated if you want noise applied *even* if `timesteps` is empty.
    init_timestep: torch.Tensor

    # Class of attention processor that is used.
    attention_processor_cls: Type[Any]


@dataclass
class DenoiseContext:
    """Context with all variables used in the denoising process."""

    # Initial variables passed to denoise. Supposed to be unchanged.
    inputs: DenoiseInputs

    # Scheduler which is used to apply noise predictions.
    scheduler: SchedulerMixin

    # UNet model.
    unet: Optional[UNet2DConditionModel] = None

    # Current state of the latent-space image in the denoising process.
    # None until the `pre_denoise_loop` callback.
    # Shape: [batch, channels, latent_height, latent_width]
    latents: Optional[torch.Tensor] = None

    # Current denoising step index.
    # None until the `pre_step` callback.
    step_index: Optional[int] = None

    # Current denoising step timestep.
    # None until the `pre_step` callback.
    timestep: Optional[torch.Tensor] = None

    # Arguments which will be passed to the UNet model.
    # Available in the `pre_unet`/`post_unet` callbacks, otherwise will be None.
    unet_kwargs: Optional[UNetKwargs] = None

    # SchedulerOutput returned from the step function (normally generated by the scheduler).
    # Supposed to be used only in the `post_step` callback, otherwise can be None.
    step_output: Optional[SchedulerOutput] = None

    # Scaled version of `latents`, which will be passed to unet_kwargs initialization.
    # Available in events inside step (between `pre_step` and `post_step`).
    # Shape: [batch, channels, latent_height, latent_width]
    latent_model_input: Optional[torch.Tensor] = None

    # [TMP] Defines which conditionings the current UNet call will be run on.
    # Available in the `pre_unet`/`post_unet` callbacks, otherwise will be None.
    conditioning_mode: Optional[ConditioningMode] = None

    # [TMP] Noise predictions from negative conditioning.
    # Available in the `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
    # Shape: [batch, channels, latent_height, latent_width]
    negative_noise_pred: Optional[torch.Tensor] = None

    # [TMP] Noise predictions from positive conditioning.
    # Available in the `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
    # Shape: [batch, channels, latent_height, latent_width]
    positive_noise_pred: Optional[torch.Tensor] = None

    # Combined noise prediction from the passed conditionings.
    # Available in the `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
    # Shape: [batch, channels, latent_height, latent_width]
    noise_pred: Optional[torch.Tensor] = None

    # Dictionary for extensions to pass extra info about the denoise process to other extensions.
    extra: dict = field(default_factory=dict)
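For orientation, here is a minimal, hypothetical sketch of how these two dataclasses are assembled. The `conditioning_data` value is left as a placeholder (a real `TextConditioningData` is required), and `DDIMScheduler` and the stand-in attention processor class are illustrative choices, not what the pipeline necessarily uses:

# Hypothetical assembly sketch, not part of the diff. Shapes and values are illustrative.
import torch
from diffusers import DDIMScheduler
from diffusers.models.attention_processor import AttnProcessor2_0  # stand-in processor class

scheduler = DDIMScheduler()
scheduler.set_timesteps(20)

inputs = DenoiseInputs(
    orig_latents=torch.zeros(1, 4, 64, 64),  # txt2img: start from zeros
    scheduler_step_kwargs={},
    conditioning_data=None,  # placeholder; a real TextConditioningData goes here
    noise=torch.randn(1, 4, 64, 64),
    seed=42,
    timesteps=scheduler.timesteps,
    init_timestep=scheduler.timesteps[0:1],
    attention_processor_cls=AttnProcessor2_0,
)
ctx = DenoiseContext(inputs=inputs, scheduler=scheduler)  # the UNet is attached later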
@@ -23,21 +23,12 @@ from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState
from invokeai.backend.util.attention import auto_detect_slice_size
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel


@dataclass
class PipelineIntermediateState:
    step: int
    order: int
    total_steps: int
    timestep: int
    latents: torch.Tensor
    predicted_original: Optional[torch.Tensor] = None


@dataclass
class AddsMaskGuidance:
    mask: torch.Tensor
@@ -1,10 +1,17 @@
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import List, Optional, Union
from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import torch

from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData

if TYPE_CHECKING:
    from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
    from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs


@dataclass
@@ -95,6 +102,12 @@ class TextConditioningRegions:
        assert self.masks.shape[1] == len(self.ranges)


class ConditioningMode(Enum):
    Both = "both"
    Negative = "negative"
    Positive = "positive"


class TextConditioningData:
    def __init__(
        self,
@@ -103,7 +116,7 @@ class TextConditioningData:
        uncond_regions: Optional[TextConditioningRegions],
        cond_regions: Optional[TextConditioningRegions],
        guidance_scale: Union[float, List[float]],
        guidance_rescale_multiplier: float = 0,
        guidance_rescale_multiplier: float = 0,  # TODO: old backend, remove
    ):
        self.uncond_text = uncond_text
        self.cond_text = cond_text
@@ -114,6 +127,7 @@ class TextConditioningData:
        # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
        # images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
        self.guidance_scale = guidance_scale
        # TODO: old backend, remove
        # For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7.
        # See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
        self.guidance_rescale_multiplier = guidance_rescale_multiplier
@@ -121,3 +135,114 @@ class TextConditioningData:
    def is_sdxl(self):
        assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
        return isinstance(self.cond_text, SDXLConditioningInfo)

    def to_unet_kwargs(self, unet_kwargs: UNetKwargs, conditioning_mode: ConditioningMode):
        """Fills the UNet arguments with data from the provided conditionings.

        Args:
            unet_kwargs (UNetKwargs): Object which stores the UNet model arguments.
            conditioning_mode (ConditioningMode): Describes which conditionings should be used.
        """
        _, _, h, w = unet_kwargs.sample.shape
        device = unet_kwargs.sample.device
        dtype = unet_kwargs.sample.dtype

        # TODO: combine regions with conditionings
        if conditioning_mode == ConditioningMode.Both:
            conditionings = [self.uncond_text, self.cond_text]
            c_regions = [self.uncond_regions, self.cond_regions]
        elif conditioning_mode == ConditioningMode.Positive:
            conditionings = [self.cond_text]
            c_regions = [self.cond_regions]
        elif conditioning_mode == ConditioningMode.Negative:
            conditionings = [self.uncond_text]
            c_regions = [self.uncond_regions]
        else:
            raise ValueError(f"Unexpected conditioning mode: {conditioning_mode}")

        encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(
            [c.embeds for c in conditionings]
        )

        unet_kwargs.encoder_hidden_states = encoder_hidden_states
        unet_kwargs.encoder_attention_mask = encoder_attention_mask

        if self.is_sdxl():
            added_cond_kwargs = dict(  # noqa: C408
                text_embeds=torch.cat([c.pooled_embeds for c in conditionings]),
                time_ids=torch.cat([c.add_time_ids for c in conditionings]),
            )

            unet_kwargs.added_cond_kwargs = added_cond_kwargs

        if any(r is not None for r in c_regions):
            tmp_regions = []
            for c, r in zip(conditionings, c_regions, strict=True):
                if r is None:
                    r = TextConditioningRegions(
                        masks=torch.ones((1, 1, h, w), dtype=dtype),
                        ranges=[Range(start=0, end=c.embeds.shape[1])],
                    )
                tmp_regions.append(r)

            if unet_kwargs.cross_attention_kwargs is None:
                unet_kwargs.cross_attention_kwargs = {}

            unet_kwargs.cross_attention_kwargs.update(
                regional_prompt_data=RegionalPromptData(regions=tmp_regions, device=device, dtype=dtype),
            )

    @staticmethod
    def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int) -> torch.Tensor:
        return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)

    @classmethod
    def _pad_conditioning(
        cls,
        cond: torch.Tensor,
        target_len: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pad the provided conditioning tensor to target_len with zeros and return a mask of the unpadded tokens.

        Args:
            cond (torch.Tensor): Conditioning tensor to pad with zeros.
            target_len (int): Target length (token count) to pad the tensor to.
        """
        conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)

        if cond.shape[1] < target_len:
            conditioning_attention_mask = cls._pad_zeros(
                conditioning_attention_mask,
                pad_shape=(cond.shape[0], target_len - cond.shape[1]),
                dim=1,
            )

            cond = cls._pad_zeros(
                cond,
                pad_shape=(cond.shape[0], target_len - cond.shape[1], cond.shape[2]),
                dim=1,
            )

        return cond, conditioning_attention_mask

    @classmethod
    def _concat_conditionings_for_batch(
        cls,
        conditionings: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Concatenate the provided conditioning tensors into one batched tensor.
        If the tensors have different lengths, pad them with zeros and create an
        encoder_attention_mask to exclude the padding from attention.

        Args:
            conditionings (List[torch.Tensor]): List of conditioning tensors to concatenate.
        """
        encoder_attention_mask = None
        max_len = max([c.shape[1] for c in conditionings])
        if any(c.shape[1] != max_len for c in conditionings):
            encoder_attention_masks = [None] * len(conditionings)
            for i in range(len(conditionings)):
                conditionings[i], encoder_attention_masks[i] = cls._pad_conditioning(conditionings[i], max_len)
            encoder_attention_mask = torch.cat(encoder_attention_masks)

        return torch.cat(conditionings), encoder_attention_mask
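As a toy illustration of the pad-and-concat behavior above (not part of the diff): two conditioning tensors with different token counts are zero-padded to a common length, and the mask excludes the padding from attention:

# Toy check of the pad-and-concat behavior; shapes are illustrative only.
import torch

a = torch.randn(1, 77, 768)   # e.g. a 77-token embedding
b = torch.randn(1, 154, 768)  # e.g. a longer, 154-token embedding

max_len = max(a.shape[1], b.shape[1])
padded_a = torch.cat([a, torch.zeros(1, max_len - a.shape[1], 768)], dim=1)
mask_a = torch.cat([torch.ones(1, a.shape[1]), torch.zeros(1, max_len - a.shape[1])], dim=1)

batch = torch.cat([padded_a, b])                       # shape: [2, 154, 768]
mask = torch.cat([mask_a, torch.ones(1, b.shape[1])])  # shape: [2, 154]
assert batch.shape == (2, 154, 768) and mask.shape == (2, 154)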
@@ -1,9 +1,14 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import torch
import torch.nn.functional as F

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    TextConditioningRegions,
)
if TYPE_CHECKING:
    from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
        TextConditioningRegions,
    )


class RegionalPromptData:
invokeai/backend/stable_diffusion/diffusion_backend.py (new file, 140 lines)
@@ -0,0 +1,140 @@
from __future__ import annotations

import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from tqdm.auto import tqdm

from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UNetKwargs
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager


class StableDiffusionBackend:
    def __init__(
        self,
        unet: UNet2DConditionModel,
        scheduler: SchedulerMixin,
    ):
        self.unet = unet
        self.scheduler = scheduler
        config = get_config()
        self._sequential_guidance = config.sequential_guidance

    def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
        if ctx.inputs.init_timestep.shape[0] == 0:
            return ctx.inputs.orig_latents

        ctx.latents = ctx.inputs.orig_latents.clone()

        if ctx.inputs.noise is not None:
            batch_size = ctx.latents.shape[0]
            # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
            ctx.latents = ctx.scheduler.add_noise(
                ctx.latents, ctx.inputs.noise, ctx.inputs.init_timestep.expand(batch_size)
            )

        # if no work to do, return latents
        if ctx.inputs.timesteps.shape[0] == 0:
            return ctx.latents

        # ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it's needed)
        # ext: preview[pre_denoise_loop, priority=low]
        ext_manager.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, ctx)

        for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.inputs.timesteps)):  # noqa: B020
            # ext: inpaint (apply mask to latents on non-inpaint models)
            ext_manager.run_callback(ExtensionCallbackType.PRE_STEP, ctx)

            # ext: tiles? [override: step]
            ctx.step_output = self.step(ctx, ext_manager)

            # ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models)
            # ext: preview[post_step, priority=low]
            ext_manager.run_callback(ExtensionCallbackType.POST_STEP, ctx)

            ctx.latents = ctx.step_output.prev_sample

        # ext: inpaint[post_denoise_loop] (restore unmasked part)
        ext_manager.run_callback(ExtensionCallbackType.POST_DENOISE_LOOP, ctx)
        return ctx.latents

    @torch.inference_mode()
    def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> SchedulerOutput:
        ctx.latent_model_input = ctx.scheduler.scale_model_input(ctx.latents, ctx.timestep)

        # TODO: conditionings as list (conditioning_data.to_unet_kwargs - ready)
        # Note: The current handling of conditioning doesn't feel very future-proof.
        # This might change in the future as new requirements come up, but for now,
        # this is the rough plan.
        if self._sequential_guidance:
            ctx.negative_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Negative)
            ctx.positive_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Positive)
        else:
            both_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Both)
            ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)

        # ext: override apply_cfg
        ctx.noise_pred = self.apply_cfg(ctx)

        # ext: cfg_rescale [modify_noise_prediction]
        # TODO: rename
        ext_manager.run_callback(ExtensionCallbackType.POST_APPLY_CFG, ctx)

        # compute the previous noisy sample x_t -> x_t-1
        step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs)

        # clean up locals
        ctx.latent_model_input = None
        ctx.negative_noise_pred = None
        ctx.positive_noise_pred = None
        ctx.noise_pred = None

        return step_output

    @staticmethod
    def apply_cfg(ctx: DenoiseContext) -> torch.Tensor:
        guidance_scale = ctx.inputs.conditioning_data.guidance_scale
        if isinstance(guidance_scale, list):
            guidance_scale = guidance_scale[ctx.step_index]

        return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
        # return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)

    def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: ConditioningMode):
        sample = ctx.latent_model_input
        if conditioning_mode == ConditioningMode.Both:
            sample = torch.cat([sample] * 2)

        ctx.unet_kwargs = UNetKwargs(
            sample=sample,
            timestep=ctx.timestep,
            encoder_hidden_states=None,  # set later by conditioning
            cross_attention_kwargs=dict(  # noqa: C408
                percent_through=ctx.step_index / len(ctx.inputs.timesteps),
            ),
        )

        ctx.conditioning_mode = conditioning_mode
        ctx.inputs.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)

        # ext: controlnet/ip/t2i [pre_unet]
        ext_manager.run_callback(ExtensionCallbackType.PRE_UNET, ctx)

        # ext: inpaint [pre_unet, priority=low]
        # or
        # ext: inpaint [override: unet_forward]
        noise_pred = self._unet_forward(**vars(ctx.unet_kwargs))

        ext_manager.run_callback(ExtensionCallbackType.POST_UNET, ctx)

        # clean up locals
        ctx.unet_kwargs = None
        ctx.conditioning_mode = None

        return noise_pred

    def _unet_forward(self, **kwargs) -> torch.Tensor:
        return self.unet(**kwargs).sample
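Note that `apply_cfg` expresses classifier-free guidance with `torch.lerp`, which is algebraically identical to the commented-out classic form `neg + g * (pos - neg)`. A quick illustrative sanity check (not part of the diff):

# Sanity check: torch.lerp matches the classic CFG formula n + g * (p - n).
import torch

neg = torch.randn(1, 4, 8, 8)
pos = torch.randn(1, 4, 8, 8)
g = 7.5  # lerp happily extrapolates for weights > 1
assert torch.allclose(torch.lerp(neg, pos, g), neg + g * (pos - neg))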
invokeai/backend/stable_diffusion/extension_callback_type.py (new file, 12 lines)
@@ -0,0 +1,12 @@
from enum import Enum


class ExtensionCallbackType(Enum):
    SETUP = "setup"
    PRE_DENOISE_LOOP = "pre_denoise_loop"
    POST_DENOISE_LOOP = "post_denoise_loop"
    PRE_STEP = "pre_step"
    POST_STEP = "post_step"
    PRE_UNET = "pre_unet"
    POST_UNET = "post_unet"
    POST_APPLY_CFG = "post_apply_cfg"
invokeai/backend/stable_diffusion/extensions/base.py (new file, 60 lines)
@@ -0,0 +1,60 @@
from __future__ import annotations

from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Dict, List

import torch
from diffusers import UNet2DConditionModel

if TYPE_CHECKING:
    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
    from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType


@dataclass
class CallbackMetadata:
    callback_type: ExtensionCallbackType
    order: int


@dataclass
class CallbackFunctionWithMetadata:
    metadata: CallbackMetadata
    function: Callable[[DenoiseContext], None]


def callback(callback_type: ExtensionCallbackType, order: int = 0):
    def _decorator(function):
        function._ext_metadata = CallbackMetadata(
            callback_type=callback_type,
            order=order,
        )
        return function

    return _decorator


class ExtensionBase:
    def __init__(self):
        self._callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}

        # Register all of the callback methods for this instance.
        for func_name in dir(self):
            func = getattr(self, func_name)
            metadata = getattr(func, "_ext_metadata", None)
            if metadata is not None and isinstance(metadata, CallbackMetadata):
                if metadata.callback_type not in self._callbacks:
                    self._callbacks[metadata.callback_type] = []
                self._callbacks[metadata.callback_type].append(CallbackFunctionWithMetadata(metadata, func))

    def get_callbacks(self):
        return self._callbacks

    @contextmanager
    def patch_extension(self, context: DenoiseContext):
        yield None

    @contextmanager
    def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
        yield None
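Because `ExtensionBase.__init__` scans the instance for methods carrying `_ext_metadata`, a concrete extension only has to tag its methods with `@callback`. A hypothetical example (not in the diff) that logs the latents before each step:

# Hypothetical extension sketch illustrating the @callback registration pattern.
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback


class LatentsLoggerExt(ExtensionBase):
    """Logs the norm of the latents before each step (illustrative only)."""

    @callback(ExtensionCallbackType.PRE_STEP, order=0)
    def log_latents(self, ctx):
        print(f"step {ctx.step_index}: |latents| = {ctx.latents.norm().item():.3f}")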
invokeai/backend/stable_diffusion/extensions/preview.py (new file, 63 lines)
@@ -0,0 +1,63 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional

import torch

from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback

if TYPE_CHECKING:
    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext


# TODO: change event to accept image instead of latents
@dataclass
class PipelineIntermediateState:
    step: int
    order: int
    total_steps: int
    timestep: int
    latents: torch.Tensor
    predicted_original: Optional[torch.Tensor] = None


class PreviewExt(ExtensionBase):
    def __init__(self, callback: Callable[[PipelineIntermediateState], None]):
        super().__init__()
        self.callback = callback

    # Run last so that all other changes are reflected.
    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
    def initial_preview(self, ctx: DenoiseContext):
        self.callback(
            PipelineIntermediateState(
                step=-1,
                order=ctx.scheduler.order,
                total_steps=len(ctx.inputs.timesteps),
                timestep=int(ctx.scheduler.config.num_train_timesteps),  # TODO: is there any code which uses it?
                latents=ctx.latents,
            )
        )

    # Run last so that all other changes are reflected.
    @callback(ExtensionCallbackType.POST_STEP, order=1000)
    def step_preview(self, ctx: DenoiseContext):
        if hasattr(ctx.step_output, "denoised"):
            predicted_original = ctx.step_output.denoised
        elif hasattr(ctx.step_output, "pred_original_sample"):
            predicted_original = ctx.step_output.pred_original_sample
        else:
            predicted_original = ctx.step_output.prev_sample

        self.callback(
            PipelineIntermediateState(
                step=ctx.step_index,
                order=ctx.scheduler.order,
                total_steps=len(ctx.inputs.timesteps),
                timestep=int(ctx.timestep),  # TODO: is there any code which uses it?
                latents=ctx.step_output.prev_sample,
                predicted_original=predicted_original,  # TODO: is there any reason for additional field?
            )
        )
invokeai/backend/stable_diffusion/extensions_manager.py (new file, 71 lines)
@@ -0,0 +1,71 @@
from __future__ import annotations

from contextlib import ExitStack, contextmanager
from typing import TYPE_CHECKING, Callable, Dict, List, Optional

import torch
from diffusers import UNet2DConditionModel

from invokeai.app.services.session_processor.session_processor_common import CanceledException

if TYPE_CHECKING:
    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
    from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
    from invokeai.backend.stable_diffusion.extensions.base import CallbackFunctionWithMetadata, ExtensionBase


class ExtensionsManager:
    def __init__(self, is_canceled: Optional[Callable[[], bool]] = None):
        self._is_canceled = is_canceled

        # A list of extensions in the order that they were added to the ExtensionsManager.
        self._extensions: List[ExtensionBase] = []
        self._ordered_callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}

    def add_extension(self, extension: ExtensionBase):
        self._extensions.append(extension)
        self._regenerate_ordered_callbacks()

    def _regenerate_ordered_callbacks(self):
        """Regenerates self._ordered_callbacks. Intended to be called each time a new extension is added."""
        self._ordered_callbacks = {}

        # Fill the ordered callbacks dictionary.
        for extension in self._extensions:
            for callback_type, callbacks in extension.get_callbacks().items():
                if callback_type not in self._ordered_callbacks:
                    self._ordered_callbacks[callback_type] = []
                self._ordered_callbacks[callback_type].extend(callbacks)

        # Sort each callback list.
        for callback_type, callbacks in self._ordered_callbacks.items():
            # Note that sorted() is stable, so if two callbacks have the same order, the order in which their
            # extensions were added will be preserved.
            self._ordered_callbacks[callback_type] = sorted(callbacks, key=lambda x: x.metadata.order)

    def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext):
        if self._is_canceled and self._is_canceled():
            raise CanceledException

        callbacks = self._ordered_callbacks.get(callback_type, [])
        for cb in callbacks:
            cb.function(ctx)

    @contextmanager
    def patch_extensions(self, context: DenoiseContext):
        if self._is_canceled and self._is_canceled():
            raise CanceledException

        with ExitStack() as exit_stack:
            for ext in self._extensions:
                exit_stack.enter_context(ext.patch_extension(context))

            yield None

    @contextmanager
    def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
        if self._is_canceled and self._is_canceled():
            raise CanceledException

        # TODO: create logic in PR with extension which uses it
        yield None
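Putting the pieces together, the wiring looks roughly like this; `unet`, `scheduler`, and a populated `ctx` (a `DenoiseContext`) are assumed to already exist, so this is a sketch rather than a complete program:

# Hypothetical wiring sketch: manager + preview extension + backend.
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager

ext_manager = ExtensionsManager(is_canceled=lambda: False)
ext_manager.add_extension(PreviewExt(lambda state: print(f"step {state.step}/{state.total_steps}")))

backend = StableDiffusionBackend(unet, scheduler)  # unet/scheduler assumed to exist
with ext_manager.patch_extensions(ctx):  # ctx: a populated DenoiseContext
    latents = backend.latents_from_embeddings(ctx, ext_manager)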
@@ -61,6 +61,7 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
            # full noise. Investigate the history of why this got commented out.
            # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
            latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
            assert isinstance(latents, torch.Tensor)  # For static type checking.

        # TODO(ryand): Look into the implications of passing in latents here that are larger than they will be after
        # cropping into regions.
@@ -122,19 +123,42 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
                control_data=region_conditioning.control_data,
            )

            # Store the results from the region.
            # If two tiles overlap by more than the target overlap amount, crop the left and top edges of the
            # affected tiles to achieve the target overlap.
            # Build a region_weight matrix that applies gradient blending to the edges of the region.
            region = region_conditioning.region
            top_adjustment = max(0, region.overlap.top - target_overlap)
            left_adjustment = max(0, region.overlap.left - target_overlap)
            region_height_slice = slice(region.coords.top + top_adjustment, region.coords.bottom)
            region_width_slice = slice(region.coords.left + left_adjustment, region.coords.right)
            merged_latents[:, :, region_height_slice, region_width_slice] += step_output.prev_sample[
                :, :, top_adjustment:, left_adjustment:
            ]
            # For now, we treat every region as having the same weight.
            merged_latents_weights[:, :, region_height_slice, region_width_slice] += 1.0
            _, _, region_height, region_width = step_output.prev_sample.shape
            region_weight = torch.ones(
                (1, 1, region_height, region_width),
                dtype=latents.dtype,
                device=latents.device,
            )
            if region.overlap.left > 0:
                left_grad = torch.linspace(
                    0, 1, region.overlap.left, device=latents.device, dtype=latents.dtype
                ).view((1, 1, 1, -1))
                region_weight[:, :, :, : region.overlap.left] *= left_grad
            if region.overlap.top > 0:
                top_grad = torch.linspace(
                    0, 1, region.overlap.top, device=latents.device, dtype=latents.dtype
                ).view((1, 1, -1, 1))
                region_weight[:, :, : region.overlap.top, :] *= top_grad
            if region.overlap.right > 0:
                right_grad = torch.linspace(
                    1, 0, region.overlap.right, device=latents.device, dtype=latents.dtype
                ).view((1, 1, 1, -1))
                region_weight[:, :, :, -region.overlap.right :] *= right_grad
            if region.overlap.bottom > 0:
                bottom_grad = torch.linspace(
                    1, 0, region.overlap.bottom, device=latents.device, dtype=latents.dtype
                ).view((1, 1, -1, 1))
                region_weight[:, :, -region.overlap.bottom :, :] *= bottom_grad

            # Update the merged results with the region results.
            merged_latents[
                :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
            ] += step_output.prev_sample * region_weight
            merged_latents_weights[
                :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
            ] += region_weight

            pred_orig_sample = getattr(step_output, "pred_original_sample", None)
            if pred_orig_sample is not None:
@@ -142,9 +166,9 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
                # they all use the same scheduler.
                if merged_pred_original is None:
                    merged_pred_original = torch.zeros_like(latents)
                merged_pred_original[:, :, region_height_slice, region_width_slice] += pred_orig_sample[
                    :, :, top_adjustment:, left_adjustment:
                ]
                merged_pred_original[
                    :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
                ] += pred_orig_sample

        # Normalize the merged results.
        latents = torch.where(merged_latents_weights > 0, merged_latents / merged_latents_weights, merged_latents)
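The `linspace` ramps above implement a cross-fade between overlapping tiles: each overlapping edge fades its weight across the overlap, and the final division by the accumulated weights normalizes the blend. A one-dimensional toy sketch of the same idea (illustrative only, not the pipeline code):

# 1-D toy model of the overlap blending: two 8-pixel tiles overlapping by 4.
import torch

overlap = 4
tile_a = torch.full((8,), 1.0)  # stand-ins for per-tile latents
tile_b = torch.full((8,), 3.0)

w_a = torch.ones(8)
w_a[-overlap:] *= torch.linspace(1, 0, overlap)  # fade out on the right edge
w_b = torch.ones(8)
w_b[:overlap] *= torch.linspace(0, 1, overlap)   # fade in on the left edge

merged = torch.zeros(12)
weights = torch.zeros(12)
merged[:8] += tile_a * w_a
weights[:8] += w_a
merged[4:] += tile_b * w_b
weights[4:] += w_b

blended = torch.where(weights > 0, merged / weights, merged)
# The overlap region transitions smoothly from 1.0 toward 3.0.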
@@ -65,17 +65,12 @@ class TextualInversionModelRaw(RawModel):

        return result

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
    ) -> None:
    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        if not torch.cuda.is_available():
            return
        for emb in [self.embedding, self.embedding_2]:
            if emb is not None:
                emb.to(device=device, dtype=dtype, non_blocking=non_blocking)
                emb.to(device=device, dtype=dtype)

    def calc_size(self) -> int:
        """Get the size of this model in bytes."""
@@ -112,15 +112,3 @@ class TorchDevice:
    @classmethod
    def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
        return NAME_TO_PRECISION[precision_name]

    @staticmethod
    def get_non_blocking(to_device: torch.device) -> bool:
        """Return the non_blocking flag to be used when moving a tensor to a given device.
        MPS may have unexpected errors with non-blocking operations - we should not use non-blocking when moving _to_ MPS.
        When moving _from_ MPS, we can use non-blocking operations.

        See:
        - https://github.com/pytorch/pytorch/issues/107455
        - https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/28
        """
        return False if to_device.type == "mps" else True
@@ -11,6 +11,7 @@ import {
  useLoRAModels,
  useMainModels,
  useRefinerModels,
  useSpandrelImageToImageModels,
  useT2IAdapterModels,
  useVAEModels,
} from 'services/api/hooks/modelsByType';
@@ -71,6 +72,13 @@ const ModelList = () => {
    [vaeModels, searchTerm, filteredModelType]
  );

  const [spandrelImageToImageModels, { isLoading: isLoadingSpandrelImageToImageModels }] =
    useSpandrelImageToImageModels();
  const filteredSpandrelImageToImageModels = useMemo(
    () => modelsFilter(spandrelImageToImageModels, searchTerm, filteredModelType),
    [spandrelImageToImageModels, searchTerm, filteredModelType]
  );

  const totalFilteredModels = useMemo(() => {
    return (
      filteredMainModels.length +
@@ -80,7 +88,8 @@ const ModelList = () => {
      filteredControlNetModels.length +
      filteredT2IAdapterModels.length +
      filteredIPAdapterModels.length +
      filteredVAEModels.length
      filteredVAEModels.length +
      filteredSpandrelImageToImageModels.length
    );
  }, [
    filteredControlNetModels.length,
@@ -91,6 +100,7 @@ const ModelList = () => {
    filteredRefinerModels.length,
    filteredT2IAdapterModels.length,
    filteredVAEModels.length,
    filteredSpandrelImageToImageModels.length,
  ]);

  return (
@@ -143,6 +153,17 @@ const ModelList = () => {
      {!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && (
        <ModelListWrapper title={t('common.t2iAdapter')} modelList={filteredT2IAdapterModels} key="t2i-adapters" />
      )}
      {/* Spandrel Image to Image List */}
      {isLoadingSpandrelImageToImageModels && (
        <FetchingModelsLoader loadingMessage="Loading Image-to-Image Models..." />
      )}
      {!isLoadingSpandrelImageToImageModels && filteredSpandrelImageToImageModels.length > 0 && (
        <ModelListWrapper
          title="Image-to-Image"
          modelList={filteredSpandrelImageToImageModels}
          key="spandrel-image-to-image"
        />
      )}
      {totalFilteredModels === 0 && (
        <Flex w="full" h="full" alignItems="center" justifyContent="center">
          <Text>{t('modelManager.noMatchingModels')}</Text>

@@ -21,6 +21,7 @@ export const ModelTypeFilter = () => {
      t2i_adapter: t('common.t2iAdapter'),
      ip_adapter: t('common.ipAdapter'),
      clip_vision: 'Clip Vision',
      spandrel_image_to_image: 'Image-to-Image',
    }),
    [t]
  );

@@ -32,6 +32,8 @@ import {
  isSDXLMainModelFieldInputTemplate,
  isSDXLRefinerModelFieldInputInstance,
  isSDXLRefinerModelFieldInputTemplate,
  isSpandrelImageToImageModelFieldInputInstance,
  isSpandrelImageToImageModelFieldInputTemplate,
  isStringFieldInputInstance,
  isStringFieldInputTemplate,
  isT2IAdapterModelFieldInputInstance,
@@ -54,6 +56,7 @@ import NumberFieldInputComponent from './inputs/NumberFieldInputComponent';
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
@@ -125,6 +128,20 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
  if (isT2IAdapterModelFieldInputInstance(fieldInstance) && isT2IAdapterModelFieldInputTemplate(fieldTemplate)) {
    return <T2IAdapterModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
  }

  if (
    isSpandrelImageToImageModelFieldInputInstance(fieldInstance) &&
    isSpandrelImageToImageModelFieldInputTemplate(fieldTemplate)
  ) {
    return (
      <SpandrelImageToImageModelFieldInputComponent
        nodeId={nodeId}
        field={fieldInstance}
        fieldTemplate={fieldTemplate}
      />
    );
  }

  if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
    return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
  }

@@ -0,0 +1,55 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldSpandrelImageToImageModelValueChanged } from 'features/nodes/store/nodesSlice';
import type {
  SpandrelImageToImageModelFieldInputInstance,
  SpandrelImageToImageModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useSpandrelImageToImageModels } from 'services/api/hooks/modelsByType';
import type { SpandrelImageToImageModelConfig } from 'services/api/types';

import type { FieldComponentProps } from './types';

const SpandrelImageToImageModelFieldInputComponent = (
  props: FieldComponentProps<SpandrelImageToImageModelFieldInputInstance, SpandrelImageToImageModelFieldInputTemplate>
) => {
  const { nodeId, field } = props;
  const dispatch = useAppDispatch();

  const [modelConfigs, { isLoading }] = useSpandrelImageToImageModels();

  const _onChange = useCallback(
    (value: SpandrelImageToImageModelConfig | null) => {
      if (!value) {
        return;
      }
      dispatch(
        fieldSpandrelImageToImageModelValueChanged({
          nodeId,
          fieldName: field.name,
          value,
        })
      );
    },
    [dispatch, field.name, nodeId]
  );

  const { options, value, onChange } = useGroupedModelCombobox({
    modelConfigs,
    onChange: _onChange,
    selectedModel: field.value,
    isLoading,
  });

  return (
    <Tooltip label={value?.description}>
      <FormControl className="nowheel nodrag" isInvalid={!value}>
        <Combobox value={value} placeholder="Pick one" options={options} onChange={onChange} />
      </FormControl>
    </Tooltip>
  );
};

export default memo(SpandrelImageToImageModelFieldInputComponent);

@@ -19,6 +19,7 @@ import type {
  ModelIdentifierFieldValue,
  SchedulerFieldValue,
  SDXLRefinerModelFieldValue,
  SpandrelImageToImageModelFieldValue,
  StatefulFieldValue,
  StringFieldValue,
  T2IAdapterModelFieldValue,
@@ -39,6 +40,7 @@ import {
  zModelIdentifierFieldValue,
  zSchedulerFieldValue,
  zSDXLRefinerModelFieldValue,
  zSpandrelImageToImageModelFieldValue,
  zStatefulFieldValue,
  zStringFieldValue,
  zT2IAdapterModelFieldValue,
@@ -333,6 +335,12 @@ export const nodesSlice = createSlice({
    fieldT2IAdapterModelValueChanged: (state, action: FieldValueAction<T2IAdapterModelFieldValue>) => {
      fieldValueReducer(state, action, zT2IAdapterModelFieldValue);
    },
    fieldSpandrelImageToImageModelValueChanged: (
      state,
      action: FieldValueAction<SpandrelImageToImageModelFieldValue>
    ) => {
      fieldValueReducer(state, action, zSpandrelImageToImageModelFieldValue);
    },
    fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
      fieldValueReducer(state, action, zEnumFieldValue);
    },
@@ -384,6 +392,7 @@ export const {
  fieldImageValueChanged,
  fieldIPAdapterModelValueChanged,
  fieldT2IAdapterModelValueChanged,
  fieldSpandrelImageToImageModelValueChanged,
  fieldLabelChanged,
  fieldLoRAModelValueChanged,
  fieldModelIdentifierValueChanged,

@@ -66,6 +66,7 @@ const zModelType = z.enum([
  'embedding',
  'onnx',
  'clip_vision',
  'spandrel_image_to_image',
]);
const zSubModelType = z.enum([
  'unet',

@@ -38,6 +38,7 @@ export const MODEL_TYPES = [
  'VAEField',
  'CLIPField',
  'T2IAdapterModelField',
  'SpandrelImageToImageModelField',
];

/**
@@ -62,6 +63,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
  MainModelField: 'teal.500',
  SDXLMainModelField: 'teal.500',
  SDXLRefinerModelField: 'teal.500',
  SpandrelImageToImageModelField: 'teal.500',
  StringField: 'yellow.500',
  T2IAdapterField: 'teal.500',
  T2IAdapterModelField: 'teal.500',

@@ -139,6 +139,10 @@ const zT2IAdapterModelFieldType = zFieldTypeBase.extend({
  name: z.literal('T2IAdapterModelField'),
  originalType: zStatelessFieldType.optional(),
});
const zSpandrelImageToImageModelFieldType = zFieldTypeBase.extend({
  name: z.literal('SpandrelImageToImageModelField'),
  originalType: zStatelessFieldType.optional(),
});
const zSchedulerFieldType = zFieldTypeBase.extend({
  name: z.literal('SchedulerField'),
  originalType: zStatelessFieldType.optional(),
@@ -160,6 +164,7 @@ const zStatefulFieldType = z.union([
  zControlNetModelFieldType,
  zIPAdapterModelFieldType,
  zT2IAdapterModelFieldType,
  zSpandrelImageToImageModelFieldType,
  zColorFieldType,
  zSchedulerFieldType,
]);
@@ -581,6 +586,33 @@ export const isT2IAdapterModelFieldInputTemplate = (val: unknown): val is T2IAda
  zT2IAdapterModelFieldInputTemplate.safeParse(val).success;
// #endregion

// #region SpandrelModelToModelField

export const zSpandrelImageToImageModelFieldValue = zModelIdentifierField.optional();
const zSpandrelImageToImageModelFieldInputInstance = zFieldInputInstanceBase.extend({
  value: zSpandrelImageToImageModelFieldValue,
});
const zSpandrelImageToImageModelFieldInputTemplate = zFieldInputTemplateBase.extend({
  type: zSpandrelImageToImageModelFieldType,
  originalType: zFieldType.optional(),
  default: zSpandrelImageToImageModelFieldValue,
});
const zSpandrelImageToImageModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
  type: zSpandrelImageToImageModelFieldType,
});
export type SpandrelImageToImageModelFieldValue = z.infer<typeof zSpandrelImageToImageModelFieldValue>;
export type SpandrelImageToImageModelFieldInputInstance = z.infer<typeof zSpandrelImageToImageModelFieldInputInstance>;
export type SpandrelImageToImageModelFieldInputTemplate = z.infer<typeof zSpandrelImageToImageModelFieldInputTemplate>;
export const isSpandrelImageToImageModelFieldInputInstance = (
  val: unknown
): val is SpandrelImageToImageModelFieldInputInstance =>
  zSpandrelImageToImageModelFieldInputInstance.safeParse(val).success;
export const isSpandrelImageToImageModelFieldInputTemplate = (
  val: unknown
): val is SpandrelImageToImageModelFieldInputTemplate =>
  zSpandrelImageToImageModelFieldInputTemplate.safeParse(val).success;
// #endregion

// #region SchedulerField

export const zSchedulerFieldValue = zSchedulerField.optional();
@@ -667,6 +699,7 @@ export const zStatefulFieldValue = z.union([
  zControlNetModelFieldValue,
  zIPAdapterModelFieldValue,
  zT2IAdapterModelFieldValue,
  zSpandrelImageToImageModelFieldValue,
  zColorFieldValue,
  zSchedulerFieldValue,
]);
@@ -694,6 +727,7 @@ const zStatefulFieldInputInstance = z.union([
  zControlNetModelFieldInputInstance,
  zIPAdapterModelFieldInputInstance,
  zT2IAdapterModelFieldInputInstance,
  zSpandrelImageToImageModelFieldInputInstance,
  zColorFieldInputInstance,
  zSchedulerFieldInputInstance,
]);
@@ -722,6 +756,7 @@ const zStatefulFieldInputTemplate = z.union([
  zControlNetModelFieldInputTemplate,
  zIPAdapterModelFieldInputTemplate,
  zT2IAdapterModelFieldInputTemplate,
  zSpandrelImageToImageModelFieldInputTemplate,
  zColorFieldInputTemplate,
  zSchedulerFieldInputTemplate,
  zStatelessFieldInputTemplate,
@@ -751,6 +786,7 @@ const zStatefulFieldOutputTemplate = z.union([
  zControlNetModelFieldOutputTemplate,
  zIPAdapterModelFieldOutputTemplate,
  zT2IAdapterModelFieldOutputTemplate,
  zSpandrelImageToImageModelFieldOutputTemplate,
  zColorFieldOutputTemplate,
  zSchedulerFieldOutputTemplate,
]);

@@ -18,6 +18,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
  SDXLRefinerModelField: undefined,
  StringField: '',
  T2IAdapterModelField: undefined,
  SpandrelImageToImageModelField: undefined,
  VAEModelField: undefined,
  ControlNetModelField: undefined,
};

@@ -17,6 +17,7 @@ import type {
  SchedulerFieldInputTemplate,
  SDXLMainModelFieldInputTemplate,
  SDXLRefinerModelFieldInputTemplate,
  SpandrelImageToImageModelFieldInputTemplate,
  StatefulFieldType,
  StatelessFieldInputTemplate,
  StringFieldInputTemplate,
@@ -263,6 +264,17 @@ const buildT2IAdapterModelFieldInputTemplate: FieldInputTemplateBuilder<T2IAdapt
  return template;
};

const buildSpandrelImageToImageModelFieldInputTemplate: FieldInputTemplateBuilder<
  SpandrelImageToImageModelFieldInputTemplate
> = ({ schemaObject, baseField, fieldType }) => {
  const template: SpandrelImageToImageModelFieldInputTemplate = {
    ...baseField,
    type: fieldType,
    default: schemaObject.default ?? undefined,
  };

  return template;
};
const buildBoardFieldInputTemplate: FieldInputTemplateBuilder<BoardFieldInputTemplate> = ({
  schemaObject,
  baseField,
@@ -377,6 +389,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
  SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
  StringField: buildStringFieldInputTemplate,
  T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
  SpandrelImageToImageModelField: buildSpandrelImageToImageModelFieldInputTemplate,
  VAEModelField: buildVAEModelFieldInputTemplate,
} as const;

@@ -35,6 +35,7 @@ const MODEL_FIELD_TYPES = [
  'ControlNetModelField',
  'IPAdapterModelField',
  'T2IAdapterModelField',
  'SpandrelImageToImageModelField',
];

/**

@@ -11,6 +11,7 @@ import {
  isNonSDXLMainModelConfig,
  isRefinerMainModelModelConfig,
  isSDXLMainModelModelConfig,
  isSpandrelImageToImageModelConfig,
  isT2IAdapterModelConfig,
  isTIModelConfig,
  isVAEModelConfig,
@@ -39,6 +40,7 @@ export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig);
export const useSpandrelImageToImageModels = buildModelsHook(isSpandrelImageToImageModelConfig);
export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig);
export const useEmbeddingModels = buildModelsHook(isTIModelConfig);
export const useVAEModels = buildModelsHook(isVAEModelConfig);
File diff suppressed because one or more lines are too long
@@ -51,6 +51,7 @@ export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
export type IPAdapterModelConfig = S['IPAdapterInvokeAIConfig'] | S['IPAdapterCheckpointConfig'];
export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
export type SpandrelImageToImageModelConfig = S['SpandrelImageToImageConfig'];
type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
type DiffusersModelConfig = S['MainDiffusersConfig'];
type CheckpointModelConfig = S['MainCheckpointConfig'];
@@ -62,6 +63,7 @@ export type AnyModelConfig =
  | ControlNetModelConfig
  | IPAdapterModelConfig
  | T2IAdapterModelConfig
  | SpandrelImageToImageModelConfig
  | TextualInversionModelConfig
  | MainModelConfig
  | CLIPVisionDiffusersConfig;
@@ -86,6 +88,12 @@ export const isT2IAdapterModelConfig = (config: AnyModelConfig): config is T2IAd
  return config.type === 't2i_adapter';
};

export const isSpandrelImageToImageModelConfig = (
  config: AnyModelConfig
): config is SpandrelImageToImageModelConfig => {
  return config.type === 'spandrel_image_to_image';
};

export const isControlAdapterModelConfig = (
  config: AnyModelConfig
): config is ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig => {
@@ -1 +1 @@
__version__ = "4.2.6"
__version__ = "4.2.6post1"
prompt_augmentation/augment_prompts_gpt2.py (new file, 45 lines)
@@ -0,0 +1,45 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2LMHeadModel, GPT2TokenizerFast


def generate_augmented_prompt(
    model: GPT2LMHeadModel, tokenizer: GPT2TokenizerFast, initial_prompt, device: torch.device
):
    instruction_prompt = f'I have a short image caption. I want you to expand it to some more detailed captions. Use two sentences for each new caption. Try to use creative language and unique or distinct words that relate to the subject of interest. Generate 3 new captions, each one interpreting the original caption in a unique way i.e. each caption should clearly describe a different image. The captions do not need to be realistic, they are allowed to describe weird or unusual scenes. The short caption is: "{initial_prompt}"'
    input_ids = tokenizer.encode(instruction_prompt, return_tensors="pt")
    assert isinstance(input_ids, torch.Tensor)
    input_ids = input_ids.to(device=device)

    gen_tokens = model.generate(
        input_ids,
        do_sample=True,
        temperature=0.9,
        # max_length=100,
        max_new_tokens=200,
    )
    gen_text = tokenizer.batch_decode(gen_tokens)[0]

    print(f"Original Prompt: '{initial_prompt}'")
    print(f"Generated text: '{gen_text}'")
    print("\n\n")


def main():
    device = torch.device("cuda")
    model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16)
    model = model.to(device=device)
    assert isinstance(model, GPT2LMHeadModel)
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    assert isinstance(tokenizer, GPT2TokenizerFast)

    test_captions = [
        "a cat on a table",
        "medieval armor",
    ]

    for test_caption in test_captions:
        generate_augmented_prompt(model, tokenizer, test_caption, device=device)


if __name__ == "__main__":
    main()
168
prompt_augmentation/augment_prompts_phi3.py
Normal file
168
prompt_augmentation/augment_prompts_phi3.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
def run_model(model, tokenizer, instruction, device="cuda"):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": instruction,
|
||||
}
|
||||
]
|
||||
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
|
||||
inputs = inputs.to(device)
|
||||
|
||||
outputs = model.generate(
|
||||
inputs,
|
||||
max_new_tokens=200,
|
||||
temperature=0.9,
|
||||
do_sample=True,
|
||||
)
|
||||
text = tokenizer.batch_decode(outputs)[0]
|
||||
assert isinstance(text, str)
|
||||
|
||||
output = text.split("<|assistant|>")[-1].strip()
|
||||
output = output.split("<|end|>")[0].strip()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def generate_style_prompt(model, tokenizer, initial_prompt, detailed_prompt):
|
||||
instruction = f"""Your task is to generate a list of words describing the style of an image. You will be given two captions for the same image, and must imagine the likely style of the image and generate a list of words that describe the style.
|
||||
|
||||
Here are some examples of style words for various styles:
|
||||
Example 1 style words: photography, high resolution, masterpiece, film grain, 8k, dynamic lighting
|
||||
Example 2 style words: 2D animation, cartoon, digital art, vibrant colors
|
||||
Example 3 style words: 3D animation, realistic, detailed, render, soft shadows, subsurface scattering
|
||||
Example 4 style words: painting, brush strokes, impressionism, oil painting
|
||||
Example 5 style words: drawing, sketch, pencil, charcoal, shading
|
||||
|
||||
Here are the captions describing the image:
|
||||
Simple caption: "{initial_prompt}"
|
||||
Detailed caption: "{detailed_prompt}"
|
||||
|
||||
Now imagine a possible style for the image and generate a list of five words separated by commas that describe the style.
|
||||
Style words:"""
|
||||
|
||||
output = run_model(model, tokenizer, instruction)
|
||||
return output
|
||||
|
||||
|
||||
def generate_augmented_prompt(model, tokenizer, initial_prompt):
|
||||
instruction = f"""Your task is to translate a short image caption and a style caption to a more detailed caption for the same image. The detailed caption should adhere to the following:
|
||||
- be 1 sentence long
|
||||
- use descriptive language that relates to the subject of interest
|
||||
- it may add new details, but shouldn't change the subject of the original caption
|
||||
|
||||
Here are some examples:
|
||||
Original caption: "A cat on a table"
|
||||
Detailed caption: "A fluffy cat with a curious expression, sitting on a wooden table next to a vase of flowers."
|
||||
|
||||
Original caption: "medieval armor"
|
||||
Detailed caption: "The gleaming suit of medieval armor stands proudly in the museum, its intricate engravings telling tales of long-forgotten battles and chivalry."
|
||||
|
||||
Original caption: "A panda bear as a mad scientist"
|
||||
Detailed caption: "Clad in a tiny lab coat and goggles, the panda bear feverishly mixes colorful potions, embodying the eccentricity of a mad scientist in its whimsical laboratory."
|
||||
|
||||
Here is the prompt to translate:
|
||||
Original caption: "{initial_prompt}"
|
||||
Detailed caption:"""
|
||||
output_prompt = run_model(model, tokenizer, instruction)
|
||||
return output_prompt
|
||||
|
||||
|
||||
STYLES = [
|
||||
"photography, RAW, high resolution, masterpiece, film grain, 8k, dynamic lighting",
|
||||
"2D animation, cartoon, digital art, vibrant colors",
|
||||
"3D animation, animated, cartoon, render, soft shadows, subsurface scattering",
|
||||
"painting, brush strokes, impressionism, oil painting",
|
||||
"sci-fi, futuristic, neon, cyberpunk, dystopian, gritty",
|
||||
# "Pop Art": "vibrant, bold, commercial, graphic, iconic",
|
||||
# "Impressionism": "soft, brushstroke, atmospheric, pastel, fleeting",
|
||||
# "Minimalism": "clean, simple, monochrome, sparse, understated",
|
||||
# "Surrealism": "dreamlike, bizarre, illogical, subconscious, imaginative",
|
||||
# "Cyberpunk": "neon, dystopian, futuristic, urban, gritty",
|
||||
# "Gothic Art": "dark, medieval, ornate, macabre, dramatic",
|
||||
# "Art Nouveau": "organic, flowing, decorative, floral, curvilinear",
|
||||
# "Baroque": "ornate, dramatic, dynamic, extravagant, rich",
|
||||
# "Renaissance Art": "classical, realistic, balanced, harmonious, humanistic",
|
||||
# "Street Art": "graffiti, urban, bold, rebellious, contemporary",
|
||||
# "Impressionist Photography": "blurry, atmospheric, evocative, impressionistic, dreamy",
|
||||
# "Retro Futurism": "nostalgic, sci-fi, retro, utopian, colorful",
|
||||
# "Anime Style": "exaggerated, cel-shaded, expressive, stylized, colorful",
|
||||
# "Manga Style": "graphic, dynamic, expressive, inked, stylized",
|
||||
# "Watercolor Painting": "translucent, flowing, wash, gradient, ethereal",
|
||||
# "Pixel Art": "retro, low-res, grid-based, nostalgic, 8-bit",
|
||||
# "Fantasy Art": "mythical, epic, magical, imaginative, detailed",
|
||||
# "Comic Book Style": "bold, inked, dynamic, panel, narrative",
|
||||
# "Photorealism": "meticulous, detailed, lifelike, precise, high-definition",
|
||||
# "Vintage Poster": "retro, nostalgic, graphic, bold, stylized",
|
||||
# "Fantasy Illustration": "enchanting, whimsical, colorful, imaginative, ethereal",
|
||||
# "Anime Realism": "hybrid, stylized, polished, detailed, emotive",
|
||||
# "Hyperrealism": "ultra-detailed, uncanny, lifelike, precise, meticulous",
|
||||
# "Sci-Fi Illustration": "futuristic, imaginative, detailed, cosmic, speculative",
|
||||
# "Anime Chibi": "cute, exaggerated, simplified, colorful, whimsical",
|
||||
# "Vintage Photography": "sepia, nostalgic, timeless, film, classic",
|
||||
]
|
||||
|
||||
|
||||
def main():
    device = torch.device("cuda")
    model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct", torch_dtype=torch.float16)
    model = model.to(device=device)
    tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")

    test_prompts = [
        # Simple
        # "medieval armor",
        # "make a calendar",
        "a hairless mouse with human ears",
        # "A panda bear as a mad scientist",
        "Male portrait photo",
        "Apocalyptic scenes of a meteor storm over a volcano.",
        # "cinematic still of a stainless steel robot swimming in a pool",
        # "Fantasy castle on a hilltop, sunset",
        # "A bee devouring the world",
        # "transparent ghost of a soldier in a cemetry",
        # "Space dog",
        # "A mermaid playing chess with a dolphin",
        # "F18 hornet",
        # "Dogs playing poker",
        # "a strong and muscular warrior with a bow",
        # "a masterpiece painting of a crocodile wearing a hoodie sitting on the roof of a car",
        # "cozy Danish interior design with wooden floor modern realistic archviz scandinavian",
        # "A curious cat exploring a haunted mansion",
        # "toilet design toilet in style of dodge charger toilet, black, photo",
        # "swirling water tornados epic fantasy",
        # "Painting of melted gemstones metallic sculpture with electrifying god rays brane bejeweled style",
        # "olympic swimming pool",
        # # Detailed subject, simple style
        # "a man looking like a lobster, with lobster hands, smiling, looking straight into the camera with arms wide spread",
        # "A mythical monstrous black furry nocturnal dog with bear claws, green glistening scaled wings and glowing crimson eyes. Several heavy chains hang from its body and ankles.",
        # "Photograph of a red apple on a wooden table while in the background a window is observed through which a flash of light enters"
        # # Simple subject, detailed style
        # "photo of a bicycle, detailed, 8k uhd, dslr, high quality, film grain, Fujifilm XT3",
        # # Detailed subject and style
        # "RAW photo, aristocratic russian noblewoman, dressed in medieval dress, model face, come hither gesture, medieval mansion, medieval nobility, slim body, high detailed skin, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3",
        # "A close-up photograph of a fat orange cat with lasagna in its mouth. Shot on Leica M6.",
        # "iOS app icon, cute white and yellow tiger head, happy, colorful, minimal, detailed, fantasy, app icon, 8k, white background",
        # "Cinematic, off-center, two-shot, 35mm film still of a 30-year-old french man, curly brown hair and a stained beige polo sweater, reading a book to his adorable 5-year-old daughter, wearing fuzzy pink pajamas, sitting in a cozy corner nook, sunny natural lighting, sun shining through the glass of the window, warm morning glow, sharp focus, heavenly illumination, unconditional love",
        # "futuristic robot but wearing medieval lord clothes and in a medieval castle, extremely detailed digital art, ambient lightning, interior, castle, medieval, painting, digital painting, trending on devianart, photorealistic, sunrays",
        # "very dark focused flash photo, amazing quality, masterpiece, best quality, hyper detailed, ultra detailed, UHD, perfect anatomy, portrait, dof, hyper-realism, majestic, awesome, inspiring,Capture the thrilling showdown between the ancient mummy and the colossal sand boss in an epic battle amidst swirling dust and desert sands. Embrace the action and chaos as these formidable forces clash in the heart of the dunes. cinematic composition, soft shadows, national geographic style",
        # "8k breathtaking view of a Boat on the waterside, Swiss alp, mountain lake, milky haze, morning lake mist, masterpiece, award-winning, professional, highly detailed in yvonne coomber style, undefined, snow, mist, in the distance, glowing eyes",
    ]

    # torch.random.manual_seed(1234)
    for initial_prompt in test_prompts:
        # Select a style. (Fixed to the first entry for now; random selection is not yet implemented, and
        # `style` is currently unused below.)
        style = STYLES[0]
        print("----------------------")
        detailed_prompt = generate_augmented_prompt(model, tokenizer, initial_prompt)
        # style_prompt = generate_style_prompt(model, tokenizer, initial_prompt, detailed_prompt)
        print(f"Original Prompt: '{initial_prompt}'\n\n")
        print(f"Detailed Prompt: '{detailed_prompt}'\n\n")
        # print(f"Style Prompt: '{style_prompt}'\n\n")


if __name__ == "__main__":
    main()
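
The script above calls a run_model helper whose definition is not part of this excerpt. As a rough sketch only — the body below is an assumption based on the call site and the standard Hugging Face chat-template workflow, not this branch's actual implementation:

# Sketch only: an assumed body for `run_model`, inferred from the call site above.
import torch


def run_model(model, tokenizer, instruction: str) -> str:
    """Run one instruction through the model and return only the newly generated text."""
    messages = [{"role": "user", "content": instruction}]
    # Phi-3-mini-4k-instruct is a chat model, so format the prompt with its chat template.
    input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
    input_ids = input_ids.to(model.device)
    with torch.no_grad():
        # Sampling here would explain the (commented-out) torch.random.manual_seed(1234) above.
        output_ids = model.generate(input_ids, max_new_tokens=150, do_sample=True, temperature=0.7)
    # Decode only the newly generated tokens, dropping the prompt.
    return tokenizer.decode(output_ids[0, input_ids.shape[-1] :], skip_special_tokens=True).strip()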
@@ -46,6 +46,7 @@ dependencies = [
     "opencv-python==4.9.0.80",
     "pytorch-lightning==2.1.3",
     "safetensors==0.4.3",
+    "spandrel==0.3.4",
     "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
     "torch==2.2.2",
     "torchmetrics==0.11.4",
@@ -54,7 +55,7 @@ dependencies = [
     "transformers==4.41.1",

     # Core application dependencies, pinned for reproducible builds.
-    "fastapi-events==0.11.0",
+    "fastapi-events==0.11.1",
     "fastapi==0.111.0",
     "huggingface-hub==0.23.1",
     "pydantic-settings==2.2.1",
tests/backend/stable_diffusion/extensions/test_base.py (new file, 46 lines)
@@ -0,0 +1,46 @@
from unittest import mock

from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback


class MockExtension(ExtensionBase):
    """A mock ExtensionBase subclass for testing purposes."""

    def __init__(self, x: int):
        super().__init__()
        self._x = x

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
    def set_step_index(self, ctx: DenoiseContext):
        ctx.step_index = self._x


def test_extension_base_callback_registration():
    """Test that a callback can be successfully registered with an extension."""
    val = 5
    mock_extension = MockExtension(val)

    mock_ctx = mock.MagicMock()

    callbacks = mock_extension.get_callbacks()
    pre_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.PRE_DENOISE_LOOP, [])
    assert len(pre_denoise_loop_cbs) == 1

    # Call the registered callback with a mock context.
    pre_denoise_loop_cbs[0].function(mock_ctx)

    # Confirm that the callback ran.
    assert mock_ctx.step_index == val


def test_extension_base_empty_callback_type():
    """Test that an empty list is returned when no callbacks are registered for a given callback type."""
    mock_extension = MockExtension(5)

    # There should be no callbacks registered for POST_DENOISE_LOOP.
    callbacks = mock_extension.get_callbacks()

    post_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.POST_DENOISE_LOOP, [])
    assert len(post_denoise_loop_cbs) == 0
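
These tests pin down the registration contract: decorating a method with @callback attaches metadata (a callback type plus an order value), and ExtensionBase collects the decorated methods into the mapping that get_callbacks returns. As an illustrative sketch only — the real invokeai.backend.stable_diffusion.extensions.base may differ in detail:

# Sketch only: an assumed implementation of `callback`/`ExtensionBase`,
# inferred from the behaviour the tests above assert.
from dataclasses import dataclass
from typing import Any, Callable, Dict, List


@dataclass
class CallbackMetadata:
    callback_type: Any  # e.g. an ExtensionCallbackType enum member
    order: int


@dataclass
class CallbackFunctionWithMetadata:
    metadata: CallbackMetadata
    function: Callable[..., None]


def callback(callback_type: Any, order: int = 0):
    """Decorator that tags a method as an extension callback of the given type."""

    def decorator(function):
        function._ext_metadata = CallbackMetadata(callback_type, order)
        return function

    return decorator


class ExtensionBase:
    def __init__(self):
        # Scan the instance for decorated methods and group them by callback type.
        self._callbacks: Dict[Any, List[CallbackFunctionWithMetadata]] = {}
        for name in dir(self):
            method = getattr(self, name)
            metadata = getattr(method, "_ext_metadata", None)
            if isinstance(metadata, CallbackMetadata):
                self._callbacks.setdefault(metadata.callback_type, []).append(
                    CallbackFunctionWithMetadata(metadata, method)
                )

    def get_callbacks(self) -> Dict[Any, List[CallbackFunctionWithMetadata]]:
        return self._callbacks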
tests/backend/stable_diffusion/test_extension_manager.py (new file, 112 lines)
@@ -0,0 +1,112 @@
from unittest import mock

import pytest

from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager


class MockExtension(ExtensionBase):
    """A mock ExtensionBase subclass for testing purposes."""

    def __init__(self, x: int):
        super().__init__()
        self._x = x

    # Note that order is not specified. It should default to 0.
    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
    def set_step_index(self, ctx: DenoiseContext):
        ctx.step_index = self._x


class MockExtensionLate(ExtensionBase):
    """A mock ExtensionBase subclass with a high order value on its PRE_DENOISE_LOOP callback."""

    def __init__(self, x: int):
        super().__init__()
        self._x = x

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
    def set_step_index(self, ctx: DenoiseContext):
        ctx.step_index = self._x


def test_extension_manager_run_callback():
    """Test that run_callback runs all callbacks for the given callback type."""

    em = ExtensionsManager()
    mock_extension_1 = MockExtension(1)
    em.add_extension(mock_extension_1)

    mock_ctx = mock.MagicMock()
    em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx)

    assert mock_ctx.step_index == 1


def test_extension_manager_run_callback_no_callbacks():
    """Test that run_callback does not raise an error when there are no callbacks for the given callback type."""
    em = ExtensionsManager()
    mock_ctx = mock.MagicMock()
    em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx)


@pytest.mark.parametrize(
    ["extension_1", "extension_2"],
    # Regardless of initialization order, we expect MockExtensionLate to run last.
    [(MockExtension(1), MockExtensionLate(2)), (MockExtensionLate(2), MockExtension(1))],
)
def test_extension_manager_order_callbacks(extension_1: ExtensionBase, extension_2: ExtensionBase):
    """Test that run_callback runs callbacks in the correct order."""
    em = ExtensionsManager()
    em.add_extension(extension_1)
    em.add_extension(extension_2)

    mock_ctx = mock.MagicMock()
    em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx)

    assert mock_ctx.step_index == 2


class MockExtensionStableSort(ExtensionBase):
    """A mock extension with three PRE_DENOISE_LOOP callbacks, each with a different order value."""

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=-1000)
    def early(self, ctx: DenoiseContext):
        pass

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
    def middle(self, ctx: DenoiseContext):
        pass

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
    def late(self, ctx: DenoiseContext):
        pass


def test_extension_manager_stable_sort():
    """Test that when two callbacks have the same 'order' value, they are sorted based on the order they were
    added to the ExtensionsManager."""

    em = ExtensionsManager()

    mock_extension_1 = MockExtensionStableSort()
    mock_extension_2 = MockExtensionStableSort()

    em.add_extension(mock_extension_1)
    em.add_extension(mock_extension_2)

    expected_order = [
        mock_extension_1.early,
        mock_extension_2.early,
        mock_extension_1.middle,
        mock_extension_2.middle,
        mock_extension_1.late,
        mock_extension_2.late,
    ]

    # It's not ideal that we are accessing a private attribute here, but this was the most direct way to assert
    # the desired behaviour.
    assert [cb.function for cb in em._ordered_callbacks[ExtensionCallbackType.PRE_DENOISE_LOOP]] == expected_order
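
The stable-sort test implies the manager's merge strategy: callbacks from all extensions are concatenated in the order their extensions were added, then sorted by their order value with a stable sort, so equal-order callbacks keep insertion order. A minimal sketch of that strategy follows; build_ordered_callbacks is a hypothetical name, not the manager's real method:

# Hypothetical helper (name and signature are assumptions); sketches the
# stable-sort merge that test_extension_manager_stable_sort implies.
def build_ordered_callbacks(extensions):
    ordered = {}
    # Concatenate callbacks in the order their extensions were added...
    for extension in extensions:
        for cb_type, cbs in extension.get_callbacks().items():
            ordered.setdefault(cb_type, []).extend(cbs)
    # ...then stable-sort on `order` (Python's sorted() is stable), preserving
    # insertion order for callbacks with equal `order` values.
    for cb_type in ordered:
        ordered[cb_type] = sorted(ordered[cb_type], key=lambda cb: cb.metadata.order)
    return ordered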