Compare commits

..

21 Commits

Author SHA1 Message Date
Lincoln Stein
39881d3d7d fix installer logic for tokenizer_3 and text_encoder_3 2024-06-21 23:34:18 -04:00
Lincoln Stein
28f1d25973 unpin dependencies; fix typo in sd3.py 2024-06-21 15:59:47 -04:00
Lincoln Stein
95377ea159 add non-commercial use message to sd3 starter; rebuild frontend 2024-06-20 21:59:28 -04:00
Lincoln Stein
445561e3a4 add sd3 to starter models 2024-06-20 18:13:46 -04:00
blessedcoolant
66260fd345 fix: Update Clip 3 slot title & lint issues 2024-06-20 08:53:35 +05:30
blessedcoolant
c403efa83f fix: Make TE5 Optional 2024-06-20 08:45:36 +05:30
blessedcoolant
cd99ef2f46 Merge branch 'main' into lstein/feat/sd3-model-loading 2024-06-20 08:43:34 +05:30
Lincoln Stein
9dce4f09ae scale default RAM cache by size of system virtual memory 2024-06-18 13:49:12 -04:00
blessedcoolant
22b5c036aa Revert "fix: height and weight not working on sd3 node"
This reverts commit be14fd59c9.
2024-06-17 06:41:49 +05:30
blessedcoolant
be14fd59c9 fix: height and weight not working on sd3 node 2024-06-17 06:34:01 +05:30
Lincoln Stein
423057a2e8 add config variable to suppress loading of sd3 text_encoder_3 T5 model 2024-06-16 16:28:39 -04:00
blessedcoolant
f65d50a4dd wip: basic wrapper for generating sd3 images 2024-06-16 04:18:20 +05:30
Lincoln Stein
554809c647 return correct base type for sd3 VAEs 2024-06-15 18:17:03 -04:00
Lincoln Stein
ac0396e6f7 Merge branch 'lstein/feat/sd3-model-loading' of github.com:invoke-ai/InvokeAI into lstein/feat/sd3-model-loading 2024-06-14 16:48:20 -04:00
Lincoln Stein
78f704e7d5 tweak installer to select correct components of HF SD3 diffusers models 2024-06-14 16:46:24 -04:00
blessedcoolant
41236031b2 chore: remove unrequired changes to v1 workflow field types 2024-06-15 00:00:44 +05:30
blessedcoolant
ddbd2ebd9d wip: add Transformer Field to Node UI 2024-06-14 22:25:26 +05:30
blessedcoolant
0c970bc880 wip: add SD3 Model Loader Invocation 2024-06-14 22:21:09 +05:30
blessedcoolant
c79d9b9ecf wip: Add Initial support for select SD3 models in UI 2024-06-14 16:04:16 +05:30
Lincoln Stein
03b9d17d0b draft sd3 loading; probable VRAM leak when using quantized submodels 2024-06-13 00:51:00 -04:00
Lincoln Stein
002f8242a1 add draft SD3 probing; there is an issue with FromOriginalControlNetMixin in backend.util.hotfixes due to new diffusers 2024-06-12 22:44:34 -04:00
62 changed files with 1205 additions and 731 deletions


@@ -1328,7 +1328,7 @@ from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegist
config = InvokeAIAppConfig.get_config()
ram_cache = ModelCache(
max_cache_size=config.ram_cache_size, logger=logger
max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
)
convert_cache = ModelConvertCache(
cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
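To make the snippet above self-contained, here is a hedged sketch of the surrounding load stack as it would read after this change. The import line comes from the hunk header; the config import path, the logger setup, and the ModelLoadService keyword arguments are assumptions rather than verified against the branch.

from invokeai.app.services.config import InvokeAIAppConfig  # import path assumed
from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegistry
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
from invokeai.backend.util.logging import InvokeAILogger

config = InvokeAIAppConfig.get_config()
logger = InvokeAILogger.get_logger()

# The RAM cache now also takes a VRAM budget, restored by this branch.
ram_cache = ModelCache(
    max_cache_size=config.ram_cache_size,
    max_vram_cache_size=config.vram_cache_size,
    logger=logger,
)
convert_cache = ModelConvertCache(
    cache_path=config.models_convert_cache_path,
    max_size=config.convert_cache_size,
)
# Keyword names below are assumed; check the ModelLoadService signature in the branch.
loader = ModelLoadService(
    app_config=config,
    ram_cache=ram_cache,
    convert_cache=convert_cache,
    registry=ModelLoaderRegistry,
)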


@@ -103,7 +103,6 @@ class CompelInvocation(BaseInvocation):
textual_inversion_manager=ti_manager,
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
truncate_long_prompts=False,
device=TorchDevice.choose_torch_device(),
)
conjunction = Compel.parse_prompt_string(self.prompt)
@@ -118,7 +117,6 @@ class CompelInvocation(BaseInvocation):
conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)])
conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
@@ -205,7 +203,6 @@ class SDXLPromptInvocationBase:
truncate_long_prompts=False, # TODO:
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
requires_pooled=get_pooled,
device=TorchDevice.choose_torch_device(),
)
conjunction = Compel.parse_prompt_string(prompt)
@@ -316,6 +313,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
)
]
)
conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput(


@@ -1,5 +1,4 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import copy
import inspect
from contextlib import ExitStack
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
@@ -194,8 +193,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
text_embeddings_masks: list[Optional[torch.Tensor]] = []
for cond in cond_list:
cond_data = copy.deepcopy(context.conditioning.load(cond.conditioning_name))
cond_data = context.conditioning.load(cond.conditioning_name)
text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
mask = cond.mask
if mask is not None:
mask = context.tensors.load(mask.tensor_name)
@@ -226,7 +226,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
resized_mask = tf(mask)
assert isinstance(resized_mask, torch.Tensor)
return resized_mask
def _concat_regional_text_embeddings(


@@ -42,6 +42,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
MainModel = "MainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
SD3MainModel = "SD3MainModelField"
ONNXModel = "ONNXModelField"
VAEModel = "VAEModelField"
LoRAModel = "LoRAModelField"
@@ -125,6 +126,7 @@ class FieldDescriptions:
noise = "Noise tensor"
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
unet = "UNet (scheduler, LoRAs)"
transformer = "Transformer"
vae = "VAE"
cond = "Conditioning tensor"
controlnet_model = "ControlNet model to load"
@@ -133,6 +135,7 @@ class FieldDescriptions:
main_model = "Main model (UNet, VAE, CLIP) to load"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
sd3_main_model = "SD3 Main Model (Transformer, CLIP1, CLIP2, CLIP3, VAE) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
lora_weight = "The weight at which the LoRA is applied to each model"
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"


@@ -12,14 +12,7 @@ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField, WithBoard, WithMetadata
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext


@@ -8,13 +8,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output
class ModelIdentifierField(BaseModel):
@@ -54,6 +48,11 @@ class UNetField(BaseModel):
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load transformer submodel")
scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
class CLIPField(BaseModel):
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
@@ -61,6 +60,15 @@ class CLIPField(BaseModel):
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class SD3CLIPField(BaseModel):
tokenizer_1: ModelIdentifierField = Field(description="Info to load tokenizer 1 submodel")
text_encoder_1: ModelIdentifierField = Field(description="Info to load text_encoder 1 submodel")
tokenizer_2: ModelIdentifierField = Field(description="Info to load tokenizer 2 submodel")
text_encoder_2: ModelIdentifierField = Field(description="Info to load text_encoder 2 submodel")
tokenizer_3: Optional[ModelIdentifierField] = Field(description="Info to load tokenizer 3 submodel")
text_encoder_3: Optional[ModelIdentifierField] = Field(description="Info to load text_encoder 3 submodel")
class VAEField(BaseModel):
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
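For illustration, a small sketch of how the new SD3CLIPField above might be populated from a main-model identifier, mirroring what the SD3 model loader later in this diff does. The helper function is made up for the example, and the third (T5) tokenizer/encoder pair is left as None to show that it is optional.

from invokeai.app.invocations.model import ModelIdentifierField, SD3CLIPField
from invokeai.backend.model_manager.config import SubModelType

def make_sd3_clip(model: ModelIdentifierField) -> SD3CLIPField:
    # Each tokenizer/text-encoder pair is the same model record with a different submodel_type.
    return SD3CLIPField(
        tokenizer_1=model.model_copy(update={"submodel_type": SubModelType.Tokenizer}),
        text_encoder_1=model.model_copy(update={"submodel_type": SubModelType.TextEncoder}),
        tokenizer_2=model.model_copy(update={"submodel_type": SubModelType.Tokenizer2}),
        text_encoder_2=model.model_copy(update={"submodel_type": SubModelType.TextEncoder2}),
        tokenizer_3=None,       # the T5 pair may be absent from an install, hence Optional
        text_encoder_3=None,
    )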


@@ -0,0 +1,200 @@
from contextlib import ExitStack
from typing import Optional, cast
import torch
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
from pydantic import field_validator
from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Input,
invocation,
invocation_output,
)
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.denoise_latents import get_scheduler
from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField, SD3CLIPField, TransformerField, VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import SEED_MAX
from invokeai.backend.model_manager.config import SubModelType
sd3_pipeline: Optional[StableDiffusion3Pipeline] = None
class FakeVae:
class FakeVaeConfig:
def __init__(self) -> None:
self.block_out_channels = [0]
def __init__(self) -> None:
self.config = FakeVae.FakeVaeConfig()
@invocation_output("sd3_model_loader_output")
class SD3ModelLoaderOutput(BaseInvocationOutput):
"""Stable Diffuion 3 base model loader output"""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
clip: SD3CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation("sd3_model_loader", title="SD3 Main Model", tags=["model", "sd3"], category="model", version="1.0.0")
class SD3ModelLoaderInvocation(BaseInvocation):
"""Loads an SD3 base model, outputting its submodels."""
model: ModelIdentifierField = InputField(description=FieldDescriptions.sd3_main_model, ui_type=UIType.SD3MainModel)
def invoke(self, context: InvocationContext) -> SD3ModelLoaderOutput:
model_key = self.model.key
if not context.models.exists(model_key):
raise Exception(f"Unknown model: {model_key}")
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer_1 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder_1 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer_2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
text_encoder_2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
try:
tokenizer_3 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
text_encoder_3 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
except Exception:
tokenizer_3 = None
text_encoder_3 = None
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return SD3ModelLoaderOutput(
transformer=TransformerField(transformer=transformer, scheduler=scheduler),
clip=SD3CLIPField(
tokenizer_1=tokenizer_1,
text_encoder_1=text_encoder_1,
tokenizer_2=tokenizer_2,
text_encoder_2=text_encoder_2,
tokenizer_3=tokenizer_3,
text_encoder_3=text_encoder_3,
),
vae=VAEField(vae=vae),
)
@invocation(
"sd3_image_generator", title="Stable Diffusion 3", tags=["latent", "sd3"], category="latents", version="1.0.0"
)
class StableDiffusion3Invocation(BaseInvocation):
"""Generates an image using Stable Diffusion 3."""
transformer: TransformerField = InputField(
description=FieldDescriptions.transformer,
input=Input.Connection,
title="Transformer",
ui_order=0,
)
clip: SD3CLIPField = InputField(
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP",
ui_order=1,
)
noise: Optional[LatentsField] = InputField(
default=None,
description=FieldDescriptions.noise,
input=Input.Connection,
ui_order=2,
)
scheduler: SCHEDULER_NAME_VALUES = InputField(
default="euler_f",
description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler,
)
positive_prompt: str = InputField(default="", title="Positive Prompt")
negative_prompt: str = InputField(default="", title="Negative Prompt")
steps: int = InputField(default=20, gt=0, description=FieldDescriptions.steps)
guidance_scale: float = InputField(default=7.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
use_clip_3: bool = InputField(default=True, description="Use TE5 Encoder of SD3", title="Use TE5 Encoder")
seed: int = InputField(
default=0,
ge=0,
le=SEED_MAX,
description=FieldDescriptions.seed,
)
width: int = InputField(
default=1024,
multiple_of=LATENT_SCALE_FACTOR,
gt=0,
description=FieldDescriptions.width,
)
height: int = InputField(
default=1024,
multiple_of=LATENT_SCALE_FACTOR,
gt=0,
description=FieldDescriptions.height,
)
@field_validator("seed", mode="before")
def modulo_seed(cls, v: int):
"""Return the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
return v % (SEED_MAX + 1)
def invoke(self, context: InvocationContext) -> LatentsOutput:
with ExitStack() as stack:
tokenizer_1 = stack.enter_context(context.models.load(self.clip.tokenizer_1))
tokenizer_2 = stack.enter_context(context.models.load(self.clip.tokenizer_2))
text_encoder_1 = stack.enter_context(context.models.load(self.clip.text_encoder_1))
text_encoder_2 = stack.enter_context(context.models.load(self.clip.text_encoder_2))
transformer = stack.enter_context(context.models.load(self.transformer.transformer))
assert isinstance(transformer, SD3Transformer2DModel)
assert isinstance(text_encoder_1, CLIPTextModelWithProjection)
assert isinstance(text_encoder_2, CLIPTextModelWithProjection)
assert isinstance(tokenizer_1, CLIPTokenizer)
assert isinstance(tokenizer_2, CLIPTokenizer)
if self.use_clip_3 and self.clip.tokenizer_3 and self.clip.text_encoder_3:
tokenizer_3 = stack.enter_context(context.models.load(self.clip.tokenizer_3))
text_encoder_3 = stack.enter_context(context.models.load(self.clip.text_encoder_3))
assert isinstance(text_encoder_3, T5EncoderModel)
assert isinstance(tokenizer_3, T5TokenizerFast)
else:
tokenizer_3 = None
text_encoder_3 = None
scheduler = get_scheduler(
context=context,
scheduler_info=self.transformer.scheduler,
scheduler_name=self.scheduler,
seed=self.seed,
)
sd3_pipeline = StableDiffusion3Pipeline(
transformer=transformer,
vae=FakeVae(),
text_encoder=text_encoder_1,
text_encoder_2=text_encoder_2,
text_encoder_3=text_encoder_3,
tokenizer=tokenizer_1,
tokenizer_2=tokenizer_2,
tokenizer_3=tokenizer_3,
scheduler=scheduler,
)
results = sd3_pipeline(
self.positive_prompt,
negative_prompt=self.negative_prompt,
num_inference_steps=self.steps,
guidance_scale=self.guidance_scale,
output_type="latent",
)
latents = cast(torch.Tensor, results.images[0])
latents = latents.unsqueeze(0)
latents_name = context.tensors.save(latents)
return LatentsOutput.build(latents_name, latents=latents, seed=self.seed)
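For readers unfamiliar with the latent-only flow used by StableDiffusion3Invocation above, here is a rough standalone sketch of the same pattern using diffusers directly. The model id and the hardware assumption (a CUDA device with enough VRAM) are illustrative only and are not part of this branch.

import torch
from diffusers import StableDiffusion3Pipeline

# Assumed checkpoint id; any SD3 diffusers layout with the same submodels should work.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

result = pipe(
    "a photo of an astronaut riding a horse",
    negative_prompt="",
    num_inference_steps=20,
    guidance_scale=7.0,
    output_type="latent",  # return latents, as the invocation does, rather than decoded images
)
latents = result.images[0]  # latent tensor; decoding with the SD3 VAE happens in a separate step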


@@ -26,13 +26,14 @@ LEGACY_INIT_FILE = Path("invokeai.init")
DEFAULT_RAM_CACHE = 10.0
DEFAULT_VRAM_CACHE = 0.25
DEFAULT_CONVERT_CACHE = 20.0
DEVICE = Literal["auto", "cpu", "cuda:0", "cuda:1", "cuda:2", "cuda:3", "cuda:4", "cuda:5", "cuda:6", "cuda:7", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
CONFIG_SCHEMA_VERSION = "4.0.2"
SYSTEM_RAM_TO_CACHE_SIZE_FACTOR = 0.25 # after 60 GB, default ram cache will scale by this factor
CONFIG_SCHEMA_VERSION = "4.0.1"
def get_default_ram_cache_size() -> float:
@@ -45,7 +46,7 @@ def get_default_ram_cache_size() -> float:
max_ram = psutil.virtual_memory().total / GB
if max_ram >= 60:
return 15.0
return max_ram * SYSTEM_RAM_TO_CACHE_SIZE_FACTOR
if max_ram >= 30:
return 7.5
if max_ram >= 14:
@@ -105,16 +106,14 @@ class InvokeAIAppConfig(BaseSettings):
convert_cache: Maximum size of on-disk converted models cache (GB).
lazy_offload: Keep models in VRAM until their space is needed.
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda:0`, `cuda:1`, `cuda:2`, `cuda:3`, `cuda:4`, `cuda:5`, `cuda:6`, `cuda:7`, `mps`
devices: List of execution devices; will override default device selected.
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
attention_slice_size: Slice size, valid when attention_type=="sliced".<br>Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
max_queue_size: Maximum number of items in the session queue.
max_threads: Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.
clear_queue_on_startup: Empties session queue on startup.
allow_nodes: List of nodes to allow. Omit to allow all.
deny_nodes: List of nodes to deny. Omit to deny none.
@@ -180,7 +179,6 @@ class InvokeAIAppConfig(BaseSettings):
# DEVICE
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
devices: Optional[list[DEVICE]] = Field(default=None, description="List of execution devices; will override default device selected.")
precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.")
# GENERATION
@@ -190,7 +188,6 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
max_threads: Optional[int] = Field(default=None, description="Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.")
clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.")
# NODES
@@ -380,6 +377,9 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
if k == "max_cache_size" and "ram" not in category_dict:
parsed_config_dict["ram"] = v
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
if k == "max_vram_cache_size" and "vram" not in category_dict:
parsed_config_dict["vram"] = v
# autocast was removed in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
@@ -427,27 +427,6 @@ def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig
return config
def migrate_v4_0_1_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate v4.0.1 config dictionary to a current config object.
A few new multi-GPU options were added in 4.0.2, and this simply
updates the schema label.
Args:
config_dict: A dictionary of settings from a v4.0.1 config file.
Returns:
An instance of `InvokeAIAppConfig` with the migrated settings.
"""
parsed_config_dict: dict[str, Any] = {}
for k, _ in config_dict.items():
if k == "schema_version":
parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
return config
# TO DO: replace this with a formal registration and migration system
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""Load and migrate a config file to the latest version.
@@ -479,10 +458,6 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
loaded_config_dict.write_file(config_path)
elif loaded_config_dict["schema_version"] == "4.0.1":
loaded_config_dict = migrate_v4_0_1_config_dict(loaded_config_dict)
loaded_config_dict.write_file(config_path)
# Attempt to load as a v4 config file
try:
# Meta is not included in the model fields, so we need to validate it separately
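The config changes above revert the default RAM cache from a fraction-of-system-memory rule back to fixed tiers and drop the 4.0.2 schema bump. A sketch of the tiered default as it reads after the change follows; the hunk is truncated below the 14 GB threshold, so the lowest return values are assumptions.

import psutil

GB = 2**30

def get_default_ram_cache_size() -> float:
    """Pick a default model RAM cache size (GB) from the amount of installed system RAM."""
    max_ram = psutil.virtual_memory().total / GB
    if max_ram >= 60:
        return 15.0
    if max_ram >= 30:
        return 7.5
    if max_ram >= 14:
        return 4.0  # assumed: the hunk is cut off after the 14 GB check
    return 2.1      # assumed floor for low-memory systems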


@@ -53,11 +53,11 @@ class InvocationServices:
model_images: "ModelImageFileStorageBase",
model_manager: "ModelManagerServiceBase",
download_queue: "DownloadQueueServiceBase",
performance_statistics: "InvocationStatsServiceBase",
session_queue: "SessionQueueBase",
session_processor: "SessionProcessorBase",
invocation_cache: "InvocationCacheBase",
names: "NameServiceBase",
performance_statistics: "InvocationStatsServiceBase",
urls: "UrlServiceBase",
workflow_records: "WorkflowRecordsStorageBase",
tensors: "ObjectSerializerBase[torch.Tensor]",
@@ -77,11 +77,11 @@ class InvocationServices:
self.model_images = model_images
self.model_manager = model_manager
self.download_queue = download_queue
self.performance_statistics = performance_statistics
self.session_queue = session_queue
self.session_processor = session_processor
self.invocation_cache = invocation_cache
self.names = names
self.performance_statistics = performance_statistics
self.urls = urls
self.workflow_records = workflow_records
self.tensors = tensors


@@ -74,9 +74,9 @@ class InvocationStatsService(InvocationStatsServiceBase):
)
self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
def reset_stats(self, graph_execution_state_id: str):
self._stats.pop(graph_execution_state_id)
self._cache_stats.pop(graph_execution_state_id)
def reset_stats(self):
self._stats = {}
self._cache_stats = {}
def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
graph_stats_summary = self._get_graph_summary(graph_execution_state_id)


@@ -284,14 +284,9 @@ class ModelInstallService(ModelInstallServiceBase):
unfinished_jobs = [x for x in self._install_jobs if not x.in_terminal_state]
self._install_jobs = unfinished_jobs
def _migrate_yaml(self, rename_yaml: Optional[bool] = True, overwrite_db: Optional[bool] = False) -> None:
def _migrate_yaml(self) -> None:
db_models = self.record_store.all_models()
if overwrite_db:
for model in db_models:
self.record_store.del_model(model.key)
db_models = self.record_store.all_models()
legacy_models_yaml_path = (
self._app_config.legacy_models_yaml_path or self._app_config.root_path / "configs" / "models.yaml"
)
@@ -341,8 +336,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.warning(f"Model at {model_path} could not be migrated: {e}")
# Rename `models.yaml` to `models.yaml.bak` to prevent re-migration
if rename_yaml:
legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))
legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))
# Unset the path - we are done with it either way
self._app_config.legacy_models_yaml_path = None


@@ -33,11 +33,6 @@ class ModelLoadServiceBase(ABC):
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""
@property
@abstractmethod
def gpu_count(self) -> int:
"""Return the number of GPUs we are configured to use."""
@abstractmethod
def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None


@@ -46,7 +46,6 @@ class ModelLoadService(ModelLoadServiceBase):
self._registry = registry
def start(self, invoker: Invoker) -> None:
"""Start the service."""
self._invoker = invoker
@property
@@ -54,11 +53,6 @@ class ModelLoadService(ModelLoadServiceBase):
"""Return the RAM cache used by this loader."""
return self._ram_cache
@property
def gpu_count(self) -> int:
"""Return the number of GPUs available for our uses."""
return len(self._ram_cache.execution_devices)
@property
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""


@@ -1,7 +1,6 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
from abc import ABC, abstractmethod
from typing import Optional, Set
import torch
from typing_extensions import Self
@@ -32,7 +31,7 @@ class ModelManagerServiceBase(ABC):
model_record_service: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
events: EventServiceBase,
execution_devices: Optional[Set[torch.device]] = None,
execution_device: torch.device,
) -> Self:
"""
Construct the model manager service instance.


@@ -1,10 +1,14 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase."""
from typing import Optional
import torch
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
from ..config import InvokeAIAppConfig
@@ -65,6 +69,7 @@ class ModelManagerService(ModelManagerServiceBase):
model_record_service: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
events: EventServiceBase,
execution_device: Optional[torch.device] = None,
) -> Self:
"""
Construct the model manager service instance.
@@ -77,7 +82,9 @@ class ModelManagerService(ModelManagerServiceBase):
ram_cache = ModelCache(
max_cache_size=app_config.ram,
max_vram_cache_size=app_config.vram,
lazy_offloading=app_config.lazy_offload,
logger=logger,
execution_device=execution_device or TorchDevice.choose_torch_device(),
)
convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
loader = ModelLoadService(
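The hunks above restore a single optional execution_device argument on the model manager factory. A hedged sketch of a call site follows; the classmethod name build_model_manager and the pre-existing service objects (config, record_store, download_queue, event_service) are assumptions taken from context and do not appear in this diff.

from invokeai.backend.util.devices import TorchDevice

model_manager = ModelManagerService.build_model_manager(  # method name assumed
    app_config=config,
    model_record_service=record_store,
    download_queue=download_queue,
    events=event_service,
    # Optional; when omitted, TorchDevice.choose_torch_device() is used as the default.
    execution_device=TorchDevice.choose_torch_device(),
)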


@@ -1,6 +1,5 @@
import shutil
import tempfile
import threading
import typing
from pathlib import Path
from typing import TYPE_CHECKING, Optional, TypeVar
@@ -10,7 +9,6 @@ import torch
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
from invokeai.app.util.misc import uuid_string
from invokeai.backend.util.devices import TorchDevice
if TYPE_CHECKING:
from invokeai.app.services.invoker import Invoker
@@ -72,10 +70,7 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
return self._output_dir / name
def _new_name(self) -> str:
tid = threading.current_thread().ident
# Add tid to the object name because uuid4 not thread-safe on windows
# See https://stackoverflow.com/questions/2759644/python-multiprocessing-doesnt-play-nicely-with-uuid-uuid4
return f"{self._obj_class_name}_{tid}-{uuid_string()}"
return f"{self._obj_class_name}_{uuid_string()}"
def _tempdir_cleanup(self) -> None:
"""Calls `cleanup` on the temporary directory, if it exists."""


@@ -1,9 +1,8 @@
import traceback
from contextlib import suppress
from queue import Queue
from threading import BoundedSemaphore, Lock, Thread
from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent
from typing import Optional, Set
from typing import Optional
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.events.events_common import (
@@ -27,7 +26,6 @@ from invokeai.app.services.session_queue.session_queue_common import SessionQueu
from invokeai.app.services.shared.graph import NodeInputError
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
from invokeai.app.util.profiler import Profiler
from invokeai.backend.util.devices import TorchDevice
from ..invoker import Invoker
from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase
@@ -59,11 +57,8 @@ class DefaultSessionRunner(SessionRunnerBase):
self._on_after_run_node_callbacks = on_after_run_node_callbacks or []
self._on_node_error_callbacks = on_node_error_callbacks or []
self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
self._process_lock = Lock()
def start(
self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None
) -> None:
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
self._services = services
self._cancel_event = cancel_event
self._profiler = profiler
@@ -81,8 +76,7 @@ class DefaultSessionRunner(SessionRunnerBase):
# Loop over invocations until the session is complete or canceled
while True:
try:
with self._process_lock:
invocation = queue_item.session.next()
invocation = queue_item.session.next()
# Anything other than a `NodeInputError` is handled as a processor error
except NodeInputError as e:
error_type = e.__class__.__name__
@@ -114,7 +108,7 @@ class DefaultSessionRunner(SessionRunnerBase):
self._on_after_run_session(queue_item=queue_item)
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
try:
# Any unhandled exception in this scope is an invocation error & will fail the graph
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
@@ -216,7 +210,7 @@ class DefaultSessionRunner(SessionRunnerBase):
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._services.performance_statistics.log_stats(queue_item.session.id)
self._services.performance_statistics.reset_stats(queue_item.session.id)
self._services.performance_statistics.reset_stats()
for callback in self._on_after_run_session_callbacks:
callback(queue_item=queue_item)
@@ -330,7 +324,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
def start(self, invoker: Invoker) -> None:
self._invoker: Invoker = invoker
self._active_queue_items: Set[SessionQueueItem] = set()
self._queue_item: Optional[SessionQueueItem] = None
self._invocation: Optional[BaseInvocation] = None
self._resume_event = ThreadEvent()
@@ -356,14 +350,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
else None
)
self._worker_thread_count = self._invoker.services.configuration.max_threads or len(
TorchDevice.execution_devices()
)
self._session_worker_queue: Queue[SessionQueueItem] = Queue()
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event, profiler=self._profiler)
# Session processor - singlethreaded
self._thread = Thread(
name="session_processor",
target=self._process,
@@ -376,16 +363,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
)
self._thread.start()
# Session processor workers - multithreaded
self._invoker.services.logger.debug(f"Starting {self._worker_thread_count} session processing threads.")
for _i in range(0, self._worker_thread_count):
worker = Thread(
name="session_worker",
target=self._process_next_session,
daemon=True,
)
worker.start()
def stop(self, *args, **kwargs) -> None:
self._stop_event.set()
@@ -393,7 +370,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._poll_now_event.set()
async def _on_queue_cleared(self, event: FastAPIEvent[QueueClearedEvent]) -> None:
if any(item.queue_id == event[1].queue_id for item in self._active_queue_items):
if self._queue_item and self._queue_item.queue_id == event[1].queue_id:
self._cancel_event.set()
self._poll_now()
@@ -401,7 +378,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._poll_now()
async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None:
if self._active_queue_items and event[1].status in ["completed", "failed", "canceled"]:
if self._queue_item and event[1].status in ["completed", "failed", "canceled"]:
# When the queue item is canceled via HTTP, the queue item status is set to `"canceled"` and this event is
# emitted. We need to respond to this event and stop graph execution. This is done by setting the cancel
# event, which the session runner checks between invocations. If set, the session runner loop is broken.
@@ -426,7 +403,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
def get_status(self) -> SessionProcessorStatus:
return SessionProcessorStatus(
is_started=self._resume_event.is_set(),
is_processing=len(self._active_queue_items) > 0,
is_processing=self._queue_item is not None,
)
def _process(
@@ -451,22 +428,30 @@ class DefaultSessionProcessor(SessionProcessorBase):
resume_event.wait()
# Get the next session to process
queue_item = self._invoker.services.session_queue.dequeue()
self._queue_item = self._invoker.services.session_queue.dequeue()
if queue_item is None:
if self._queue_item is None:
# The queue was empty, wait for next polling interval or event to try again
self._invoker.services.logger.debug("Waiting for next polling interval or event")
poll_now_event.wait(self._polling_interval)
continue
self._session_worker_queue.put(queue_item)
self._invoker.services.logger.debug(f"Scheduling queue item {queue_item.item_id} to run")
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
cancel_event.clear()
# Run the graph
# self.session_runner.run(queue_item=self._queue_item)
self.session_runner.run(queue_item=self._queue_item)
except Exception:
except Exception as e:
error_type = e.__class__.__name__
error_message = str(e)
error_traceback = traceback.format_exc()
self._on_non_fatal_processor_error(
queue_item=self._queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
# Wait for next polling interval or event to try again
poll_now_event.wait(self._polling_interval)
continue
@@ -481,25 +466,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
finally:
stop_event.clear()
poll_now_event.clear()
self._queue_item = None
self._thread_semaphore.release()
def _process_next_session(self) -> None:
while True:
self._resume_event.wait()
queue_item = self._session_worker_queue.get()
if queue_item.status == "canceled":
continue
try:
self._active_queue_items.add(queue_item)
# reserve a GPU for this session - may block
with self._invoker.services.model_manager.load.ram_cache.reserve_execution_device():
# Run the session on the reserved GPU
self.session_runner.run(queue_item=queue_item)
except Exception:
continue
finally:
self._active_queue_items.remove(queue_item)
def _on_non_fatal_processor_error(
self,
queue_item: Optional[SessionQueueItem],
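The session-processor changes above remove the multi-GPU worker pool and return to one processor thread that owns a single in-flight queue item. A simplified sketch of the resulting loop shape, with profiling, pause/resume handling, and event plumbing omitted; attribute names are taken from the hunks above.

import traceback

def _process_loop(self) -> None:
    # Simplified: the real _process method also manages semaphores, profiling and resume events.
    while not self._stop_event.is_set():
        self._queue_item = self._invoker.services.session_queue.dequeue()
        if self._queue_item is None:
            # Queue empty: wait for the next polling interval or a poll-now event.
            self._poll_now_event.wait(self._polling_interval)
            continue
        self._cancel_event.clear()
        try:
            # Run the whole graph for this queue item on the single processor thread.
            self.session_runner.run(queue_item=self._queue_item)
        except Exception as e:
            # Non-fatal processor error: report it and move on to the next item.
            self._on_non_fatal_processor_error(
                queue_item=self._queue_item,
                error_type=e.__class__.__name__,
                error_message=str(e),
                error_traceback=traceback.format_exc(),
            )
        finally:
            self._queue_item = None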


@@ -236,9 +236,6 @@ class SessionQueueItemWithoutGraph(BaseModel):
}
)
def __hash__(self) -> int:
return self.item_id
class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
pass


@@ -652,7 +652,7 @@ class Graph(BaseModel):
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
# Input type must be a list
if get_origin(input_field) is not list:
if get_origin(input_field) != list:
return False
# Validate that all outputs match the input type


@@ -2,7 +2,6 @@ from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Optional, Union
import torch
from PIL.Image import Image
from pydantic.networks import AnyHttpUrl
from torch import Tensor
@@ -27,13 +26,11 @@ from invokeai.backend.model_manager.config import (
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
from invokeai.backend.util.devices import TorchDevice
if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
"""
The InvocationContext provides access to various services and data about the current invocation.
@@ -326,6 +323,7 @@ class ConditioningInterface(InvocationContextInterface):
Returns:
The loaded conditioning data.
"""
return self._services.conditioning.load(name)
@@ -559,28 +557,6 @@ class UtilInterface(InvocationContextInterface):
is_canceled=self.is_canceled,
)
def torch_device(self) -> torch.device:
"""
Return a torch device to use in the current invocation.
Returns:
A torch.device not currently in use by the system.
"""
ram_cache: "ModelCacheBase[AnyModel]" = self._services.model_manager.load.ram_cache
return ram_cache.get_execution_device()
def torch_dtype(self, device: Optional[torch.device] = None) -> torch.dtype:
"""
Return a precision type to use with the current invocation and torch device.
Args:
device: Optional device.
Returns:
A torch.dtype suited for the current device.
"""
return TorchDevice.choose_torch_dtype(device)
class InvocationContext:
"""Provides access to various services and data for the current invocation.


@@ -25,7 +25,6 @@ from enum import Enum
from typing import Literal, Optional, Type, TypeAlias, Union
import torch
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict
@@ -38,7 +37,7 @@ from ..raw_model import RawModel
# ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
AnyModel = Union[ConfigMixin, ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]
class InvalidModelConfigException(Exception):
@@ -53,6 +52,7 @@ class BaseModelType(str, Enum):
StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
StableDiffusion3 = "sd-3"
# Kandinsky2_1 = "kandinsky-2.1"
@@ -76,8 +76,11 @@ class SubModelType(str, Enum):
UNet = "unet"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
TextEncoder3 = "text_encoder_3"
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Tokenizer3 = "tokenizer_3"
Transformer = "transformer"
VAE = "vae"
VAEDecoder = "vae_decoder"
VAEEncoder = "vae_encoder"
@@ -178,7 +181,6 @@ class ModelConfigBase(BaseModel):
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
"""Extend the pydantic schema from a json."""
schema["required"].extend(["key", "type", "format"])
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
@@ -445,7 +447,7 @@ class ModelConfigFactory(object):
model = dest_class.model_validate(model_data)
else:
# mypy doesn't typecheck TypeAdapters well?
model = AnyModelConfigValidator.validate_python(model_data)
model = AnyModelConfigValidator.validate_python(model_data) # type: ignore
assert model is not None
if key:
model.key = key


@@ -65,7 +65,8 @@ class LoadedModelWithoutConfig:
def __enter__(self) -> AnyModel:
"""Context entry."""
return self._locker.lock()
self._locker.lock()
return self.model
def __exit__(self, *args: Any, **kwargs: Any) -> None:
"""Context exit."""


@@ -84,6 +84,8 @@ class ModelLoader(ModelLoaderBase):
except IndexError:
pass
self._logger.info(f"Loading {config.key}:{submodel_type}")
cache_path: Path = self._convert_cache.cache_path(str(model_path))
if self._needs_conversion(config, model_path, cache_path):
loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)


@@ -8,10 +8,9 @@ model will be cleared and (re)loaded from disk when next needed.
"""
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, field
from logging import Logger
from typing import Dict, Generator, Generic, Optional, Set, TypeVar
from typing import Dict, Generic, Optional, TypeVar
import torch
@@ -52,13 +51,45 @@ class CacheRecord(Generic[T]):
Elements of the cache:
key: Unique key for each model, same as used in the models database.
model: Read-only copy of the model *without weights* residing in the "meta device"
model: Model in memory.
state_dict: A read-only copy of the model's state dict in RAM. It will be
used as a template for creating a copy in the VRAM.
size: Size of the model
loaded: True if the model's state dict is currently in VRAM
Before a model is executed, the state_dict template is copied into VRAM,
and then injected into the model. When the model is finished, the VRAM
copy of the state dict is deleted, and the RAM version is reinjected
into the model.
The state_dict should be treated as a read-only attribute. Do not attempt
to patch or otherwise modify it. Instead, patch the copy of the state_dict
after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
context manager call `model_on_device()`.
"""
key: str
size: int
model: T
device: torch.device
state_dict: Optional[Dict[str, torch.Tensor]]
size: int
is_quantized: bool = False
loaded: bool = False
_locks: int = 0
def lock(self) -> None:
"""Lock this record."""
self._locks += 1
def unlock(self) -> None:
"""Unlock this record."""
self._locks -= 1
assert self._locks >= 0
@property
def locked(self) -> bool:
"""Return true if record is locked."""
return self._locks > 0
@dataclass
@@ -85,27 +116,14 @@ class ModelCacheBase(ABC, Generic[T]):
@property
@abstractmethod
def execution_devices(self) -> Set[torch.device]:
"""Return the set of available execution devices."""
def execution_device(self) -> torch.device:
"""Return the exection device (e.g. "cuda" for VRAM)."""
pass
@contextmanager
@property
@abstractmethod
def reserve_execution_device(self, timeout: int = 0) -> Generator[torch.device, None, None]:
"""Reserve an execution device (GPU) under the current thread id."""
pass
@abstractmethod
def get_execution_device(self) -> torch.device:
"""
Return an execution device that has been reserved for current thread.
Note that reservations are done using the current thread's TID.
It might be better to do this using the session ID, but that involves
too many detailed changes to model manager calls.
May generate a ValueError if no GPU has been reserved.
"""
def lazy_offloading(self) -> bool:
"""Return true if the cache is configured to lazily offload models in VRAM."""
pass
@property
@@ -114,6 +132,16 @@ class ModelCacheBase(ABC, Generic[T]):
"""Return true if the cache is configured to lazily offload models in VRAM."""
pass
@abstractmethod
def offload_unlocked_models(self, size_required: int) -> None:
"""Offload from VRAM any models not actively in use."""
pass
@abstractmethod
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
"""Move model into the indicated device."""
pass
@property
@abstractmethod
def stats(self) -> Optional[CacheStats]:
@@ -175,11 +203,6 @@ class ModelCacheBase(ABC, Generic[T]):
"""Return true if the model identified by key and submodel_type is in the cache."""
pass
@abstractmethod
def model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> AnyModel:
"""Move a copy of the model into the indicated device and return it."""
pass
@abstractmethod
def cache_size(self) -> int:
"""Get the total size of the models currently cached."""


@@ -18,19 +18,17 @@ context. Use like this:
"""
import copy
import gc
import sys
import threading
from contextlib import contextmanager, suppress
import math
import time
from contextlib import suppress
from logging import Logger
from threading import BoundedSemaphore
from typing import Dict, Generator, List, Optional, Set
from typing import Dict, List, Optional
import torch
from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
@@ -41,7 +39,9 @@ from .model_locker import ModelLocker
# Maximum size of the cache, in gigs
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
DEFAULT_MAX_CACHE_SIZE = 6.0
DEFAULT_MAX_VRAM_CACHE_SIZE = 0.25
# amount of GPU memory to hold in reserve for use by generations (GB)
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
# actual size of a gig
GIG = 1073741824
@@ -57,8 +57,10 @@ class ModelCache(ModelCacheBase[AnyModel]):
self,
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
execution_device: torch.device = torch.device("cuda"),
storage_device: torch.device = torch.device("cpu"),
precision: torch.dtype = torch.float16,
lazy_offloading: bool = True,
log_memory_usage: bool = False,
logger: Optional[Logger] = None,
):
@@ -66,19 +68,22 @@ class ModelCache(ModelCacheBase[AnyModel]):
Initialize the model RAM cache.
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
:param execution_device: Torch device to load active model into [torch.device('cuda')]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param precision: Precision for loaded models [torch.float16]
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
behaviour.
"""
# allow lazy offloading only when vram cache enabled
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
self._precision: torch.dtype = precision
self._max_cache_size: float = max_cache_size
self._max_vram_cache_size: float = max_vram_cache_size
self._execution_device: torch.device = execution_device
self._storage_device: torch.device = storage_device
self._ram_lock = threading.Lock()
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
self._log_memory_usage = log_memory_usage
self._stats: Optional[CacheStats] = None
@@ -86,87 +91,25 @@ class ModelCache(ModelCacheBase[AnyModel]):
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
self._cache_stack: List[str] = []
# device to thread id
self._device_lock = threading.Lock()
self._execution_devices: Dict[torch.device, int] = {x: 0 for x in TorchDevice.execution_devices()}
self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
self.logger.info(
f"Using rendering device(s): {', '.join(sorted([str(x) for x in self._execution_devices.keys()]))}"
)
@property
def logger(self) -> Logger:
"""Return the logger used by the cache."""
return self._logger
@property
def lazy_offloading(self) -> bool:
"""Return true if the cache is configured to lazily offload models in VRAM."""
return self._lazy_offloading
@property
def storage_device(self) -> torch.device:
"""Return the storage device (e.g. "CPU" for RAM)."""
return self._storage_device
@property
def execution_devices(self) -> Set[torch.device]:
"""Return the set of available execution devices."""
devices = self._execution_devices.keys()
return set(devices)
def get_execution_device(self) -> torch.device:
"""
Return an execution device that has been reserved for current thread.
Note that reservations are done using the current thread's TID.
It would be better to do this using the session ID, but that involves
too many detailed changes to model manager calls.
May generate a ValueError if no GPU has been reserved.
"""
current_thread = threading.current_thread().ident
assert current_thread is not None
assigned = [x for x, tid in self._execution_devices.items() if current_thread == tid]
if not assigned:
raise ValueError(f"No GPU has been reserved for the use of thread {current_thread}")
return assigned[0]
@contextmanager
def reserve_execution_device(self, timeout: Optional[int] = None) -> Generator[torch.device, None, None]:
"""Reserve an execution device (e.g. GPU) for exclusive use by a generation thread.
Note that the reservation is done using the current thread's TID.
It would be better to do this using the session ID, but that involves
too many detailed changes to model manager calls.
"""
device = None
with self._device_lock:
current_thread = threading.current_thread().ident
assert current_thread is not None
# look for a device that has already been assigned to this thread
assigned = [x for x, tid in self._execution_devices.items() if current_thread == tid]
if assigned:
device = assigned[0]
# no device already assigned. Get one.
if device is None:
self._free_execution_device.acquire(timeout=timeout)
with self._device_lock:
free_device = [x for x, tid in self._execution_devices.items() if tid == 0]
self._execution_devices[free_device[0]] = current_thread
device = free_device[0]
# we are outside the lock region now
self.logger.info(f"{current_thread} Reserved torch device {device}")
# Tell TorchDevice to use this object to get the torch device.
TorchDevice.set_model_cache(self)
try:
yield device
finally:
with self._device_lock:
self.logger.info(f"{current_thread} Released torch device {device}")
self._execution_devices[device] = 0
self._free_execution_device.release()
torch.cuda.empty_cache()
def execution_device(self) -> torch.device:
"""Return the exection device (e.g. "cuda" for VRAM)."""
return self._execution_device
@property
def max_cache_size(self) -> float:
@@ -211,16 +154,26 @@ class ModelCache(ModelCacheBase[AnyModel]):
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Store model under key and optional submodel_type."""
with self._ram_lock:
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
return
size = calc_model_size_by_data(model)
self.make_room(size)
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
return
size = calc_model_size_by_data(model)
self.make_room(size)
cache_record = CacheRecord(key=key, model=model, size=size)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
is_quantized = hasattr(model, "is_quantized") and model.is_quantized
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not is_quantized else None
cache_record = CacheRecord(
key=key,
model=model,
device=self._execution_device
if is_quantized
else self._storage_device, # quantized models are loaded directly into CUDA
is_quantized=is_quantized,
state_dict=state_dict,
size=size,
)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
def get(
self,
@@ -238,37 +191,36 @@ class ModelCache(ModelCacheBase[AnyModel]):
This may raise an IndexError if the model is not in the cache.
"""
with self._ram_lock:
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
if self.stats:
self.stats.hits += 1
else:
if self.stats:
self.stats.misses += 1
raise IndexError(f"The model with key {key} is not in the cache.")
cache_entry = self._cached_models[key]
# more stats
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
if self.stats:
stats_name = stats_name or key
self.stats.cache_size = int(self._max_cache_size * GIG)
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[stats_name] = max(
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
)
self.stats.hits += 1
else:
if self.stats:
self.stats.misses += 1
raise IndexError(f"The model with key {key} is not in the cache.")
# this moves the entry to the top (right end) of the stack
with suppress(Exception):
self._cache_stack.remove(key)
self._cache_stack.append(key)
return ModelLocker(
cache=self,
cache_entry=cache_entry,
cache_entry = self._cached_models[key]
# more stats
if self.stats:
stats_name = stats_name or key
self.stats.cache_size = int(self._max_cache_size * GIG)
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[stats_name] = max(
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
)
# this moves the entry to the top (right end) of the stack
with suppress(Exception):
self._cache_stack.remove(key)
self._cache_stack.append(key)
return ModelLocker(
cache=self,
cache_entry=cache_entry,
)
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
if self._log_memory_usage:
return MemorySnapshot.capture()
@@ -280,34 +232,142 @@ class ModelCache(ModelCacheBase[AnyModel]):
else:
return model_key
def model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> AnyModel:
"""Move a copy of the model into the indicated device and return it.
def offload_unlocked_models(self, size_required: int) -> None:
"""Move any unused models from VRAM."""
reserved = self._max_vram_cache_size * GIG
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved:
break
# Special handling of the stable-diffusion-3:text_encoder_3
# submodel when the user has loaded a quantized model.
# The only way to remove the quantized version of this model from VRAM is to
# delete it completely - it can't be moved from device to device.
# This also works around quantized models that would otherwise
# persist indefinitely in VRAM.
if cache_entry.is_quantized:
self._empty_quantized_state_dict(cache_entry.model)
cache_entry.model = None
self._delete_cache_entry(cache_entry)
vram_in_use = torch.cuda.memory_allocated() + size_required
continue
if not cache_entry.loaded:
continue
if not cache_entry.locked:
self.move_model_to_device(cache_entry, self.storage_device)
cache_entry.loaded = False
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
)
gc.collect()
TorchDevice.empty_cache()
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
"""Move model into the indicated device.
:param cache_entry: The CacheRecord for the model
:param target_device: The torch.device to move the model into
May raise a torch.cuda.OutOfMemoryError
"""
with self._ram_lock:
self.logger.debug(f"Called to move {cache_entry.key} ({type(cache_entry.model)=}) to {target_device}")
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
source_device = cache_entry.device
# Some models don't have a state dictionary, in which case the
# stored model will still reside in CPU
if hasattr(cache_entry.model, "to"):
model_in_gpu = copy.deepcopy(cache_entry.model)
assert hasattr(model_in_gpu, "to")
model_in_gpu.to(target_device)
return model_in_gpu
else:
return cache_entry.model # what happens in CPU stays in CPU
# Note: We compare device types so that 'cuda' == 'cuda:0'.
# This would need to be revised to support multi-GPU.
if torch.device(source_device).type == torch.device(target_device).type:
return
# Some models don't have a `to` method, in which case they run in RAM/CPU.
if not hasattr(cache_entry.model, "to"):
return
# This roundabout method for moving the model around is done to avoid
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
# When moving to VRAM, we copy (not move) each element of the state dict from
# RAM to a new state dict in VRAM, and then inject it into the model.
# This operation is slightly faster than running `to()` on the whole model.
#
# When the model needs to be removed from VRAM we simply delete the copy
# of the state dict in VRAM, and reinject the state dict that is cached
# in RAM into the model. So this operation is very fast.
start_model_to_time = time.time()
snapshot_before = self._capture_memory_snapshot()
try:
if cache_entry.state_dict is not None:
assert hasattr(cache_entry.model, "load_state_dict")
if target_device == self.storage_device:
cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
else:
new_dict: Dict[str, torch.Tensor] = {}
for k, v in cache_entry.state_dict.items():
new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True)
cache_entry.model.load_state_dict(new_dict, assign=True)
cache_entry.model.to(target_device, non_blocking=True)
cache_entry.device = target_device
except Exception as e: # blow away cache entry
self._delete_cache_entry(cache_entry)
raise e
snapshot_after = self._capture_memory_snapshot()
end_model_to_time = time.time()
self.logger.debug(
f"Moved model '{cache_entry.key}' from {source_device} to"
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
if (
snapshot_before is not None
and snapshot_after is not None
and snapshot_before.vram is not None
and snapshot_after.vram is not None
):
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
# If the estimated model size does not match the change in VRAM, log a warning.
if not math.isclose(
vram_change,
cache_entry.size,
rel_tol=0.1,
abs_tol=10 * MB,
):
self.logger.debug(
f"Moving model '{cache_entry.key}' from {source_device} to"
f" {target_device} caused an unexpected change in VRAM usage. The model's"
" estimated size may be incorrect. Estimated model size:"
f" {(cache_entry.size/GIG):.3f} GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
def print_cuda_stats(self) -> None:
"""Log CUDA diagnostics."""
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
ram = "%4.2fG" % (self.cache_size() / GIG)
in_ram_models = len(self._cached_models)
self.logger.debug(f"Current VRAM/RAM usage for {in_ram_models} models: {vram}/{ram}")
in_ram_models = 0
in_vram_models = 0
locked_in_vram_models = 0
for cache_record in self._cached_models.values():
if hasattr(cache_record.model, "device"):
if cache_record.model.device == self.storage_device:
in_ram_models += 1
else:
in_vram_models += 1
if cache_record.locked:
locked_in_vram_models += 1
self.logger.debug(
f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
)
def make_room(self, size: int) -> None:
"""Make enough room in the cache to accommodate a new model of indicated size."""
@@ -330,14 +390,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
model_key = self._cache_stack[pos]
cache_entry = self._cached_models[model_key]
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
self.logger.debug(
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
)
refs = sys.getrefcount(cache_entry.model)
# Expected refs:
# 1 from cache_entry
# 1 from getrefcount function
# 1 from onnx runtime object
if refs <= (3 if "onnx" in model_key else 2):
if not cache_entry.locked:
self.logger.debug(
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
)
@@ -364,26 +422,27 @@ class ModelCache(ModelCacheBase[AnyModel]):
if self.stats:
self.stats.cleared = models_cleared
gc.collect()
TorchDevice.empty_cache()
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
def _check_free_vram(self, target_device: torch.device, needed_size: int) -> None:
if target_device.type != "cuda":
return
vram_device = ( # mem_get_info() needs an indexed device
target_device if target_device.index is not None else torch.device(str(target_device), index=0)
)
free_mem, _ = torch.cuda.mem_get_info(torch.device(vram_device))
if needed_size > free_mem:
raise torch.cuda.OutOfMemoryError
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
try:
self._cache_stack.remove(cache_entry.key)
del self._cached_models[cache_entry.key]
except ValueError:
pass
self._cache_stack.remove(cache_entry.key)
del self._cached_models[cache_entry.key]
del cache_entry
gc.collect()
TorchDevice.empty_cache()
@staticmethod
def _device_name(device: torch.device) -> str:
return f"{device.type}:{device.index}"
def _empty_quantized_state_dict(self, model: AnyModel) -> None:
"""Set all keys of a model's state dict to None.
This is a partial workaround for a poorly-understood bug in
transformers' support for quantized T5EncoderModels (text_encoder_3
of SD3). This allows most of the model to be unloaded from VRAM, but
still leaks 8K of VRAM each time the model is unloaded. Using the quantized
version of stable-diffusion-3-medium is NOT recommended.
"""
assert isinstance(model, torch.nn.Module)
sd = model.state_dict()
for k in sd.keys():
sd[k] = None
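The comments in move_model_to_device() describe the core trick: copy the cached CPU state dict into VRAM and inject it, rather than paying for a full round trip with to(). A minimal sketch of that idea, assuming a CUDA device, a plain torch.nn.Module, and a torch version that supports load_state_dict(assign=True); the cache's error handling and device bookkeeping are omitted:

import torch

def push_to_vram(model: torch.nn.Module, cpu_state_dict: dict) -> None:
    # Copy (not move) each tensor of the cached CPU state dict into VRAM,
    # inject the new dict, then move any remaining buffers with to().
    vram_dict = {k: v.to("cuda", copy=True, non_blocking=True) for k, v in cpu_state_dict.items()}
    model.load_state_dict(vram_dict, assign=True)
    model.to("cuda", non_blocking=True)

def pull_from_vram(model: torch.nn.Module, cpu_state_dict: dict) -> None:
    # Reinject the cached CPU state dict; the VRAM copy is freed once nothing
    # references it any longer, which makes the return trip cheap.
    model.load_state_dict(cpu_state_dict, assign=True)
    model.to("cpu")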

View File

@@ -10,8 +10,6 @@ from invokeai.backend.model_manager import AnyModel
from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase
MAX_GPU_WAIT = 600 # wait up to 10 minutes for a GPU to become free
class ModelLocker(ModelLockerBase):
"""Internal class that mediates movement in and out of GPU."""
@@ -31,29 +29,33 @@ class ModelLocker(ModelLockerBase):
"""Return the model without moving it around."""
return self._cache_entry.model
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
"""Return the state dict (if any) for the cached model."""
return self._cache_entry.state_dict
def lock(self) -> AnyModel:
"""Move the model into the execution device (GPU) and lock it."""
self._cache_entry.lock()
try:
device = self._cache.get_execution_device()
model_on_device = self._cache.model_to_device(self._cache_entry, device)
self._cache.logger.debug(f"Moved {self._cache_entry.key} to {device}")
if self._cache.lazy_offloading:
self._cache.offload_unlocked_models(self._cache_entry.size)
self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
self._cache_entry.loaded = True
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
self._cache.print_cuda_stats()
except torch.cuda.OutOfMemoryError:
self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
self._cache_entry.unlock()
raise
except Exception:
self._cache_entry.unlock()
raise
return model_on_device
return self.model
# It is no longer necessary to move the model out of VRAM
# because it will be removed when it goes out of scope
# in the caller's context
def unlock(self) -> None:
"""Call upon exit from context."""
self._cache.print_cuda_stats()
# This is no longer in use in MGPU.
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
"""Return the state dict (if any) for the cached model."""
return None
self._cache_entry.unlock()
if not self._cache.lazy_offloading:
self._cache.offload_unlocked_models(0)
self._cache.print_cuda_stats()
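Callers normally do not drive lock()/unlock() by hand; they go through the context manager returned by the load service, as the tests further below do. A hedged usage sketch (model_manager, config, and inputs stand in for objects supplied by the surrounding application):

loaded_model = model_manager.load.load_model(config)  # wraps a ModelLocker
with loaded_model as model:
    # lock() has run: the model sits on the execution device and is pinned there
    output = model(inputs)
# on exit unlock() runs; unless lazy offloading is enabled, unlocked models
# may be moved back to the storage device right away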

View File

@@ -36,9 +36,11 @@ VARIANT_TO_IN_CHANNEL_MAP = {
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
"""Class to load main models."""
# note - will be removed for load_single_file()
model_base_to_model_type = {
BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
BaseModelType.StableDiffusion3: "SD3",
BaseModelType.StableDiffusionXL: "SDXL",
BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
}
@@ -65,7 +67,10 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
if variant and "no file named" in str(
e
): # try without the variant, just in case user's preferences changed
result = load_class.from_pretrained(model_path, torch_dtype=self._torch_dtype)
result = load_class.from_pretrained(
model_path,
torch_dtype=self._torch_dtype,
)
else:
raise e

View File

@@ -100,6 +100,7 @@ class ModelProbe(object):
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
"StableDiffusionXLInpaintPipeline": ModelType.Main,
"LatentConsistencyModelPipeline": ModelType.Main,
"StableDiffusion3Pipeline": ModelType.Main,
"AutoencoderKL": ModelType.VAE,
"AutoencoderTiny": ModelType.VAE,
"ControlNetModel": ModelType.ControlNet,
@@ -298,10 +299,13 @@ class ModelProbe(object):
return possible_conf.absolute()
if model_type is ModelType.Main:
config_file = LEGACY_CONFIGS[base_type][variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models
config_file = config_file[prediction_type]
config_file = f"stable-diffusion/{config_file}"
if base_type is BaseModelType.StableDiffusion3:
config_file = "stable-diffusion/v3-inference.yaml"
else:
config_file = LEGACY_CONFIGS[base_type][variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models
config_file = config_file[prediction_type]
config_file = f"stable-diffusion/{config_file}"
elif model_type is ModelType.ControlNet:
config_file = (
"controlnet/cldm_v15.yaml"
@@ -374,7 +378,7 @@ def get_default_settings_controlnet_t2i_adapter(model_name: str) -> Optional[Con
def get_default_settings_main(model_base: BaseModelType) -> Optional[MainModelDefaultSettings]:
if model_base is BaseModelType.StableDiffusion1 or model_base is BaseModelType.StableDiffusion2:
return MainModelDefaultSettings(width=512, height=512)
elif model_base is BaseModelType.StableDiffusionXL:
elif model_base in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusion3]:
return MainModelDefaultSettings(width=1024, height=1024)
# We don't provide defaults for BaseModelType.StableDiffusionXLRefiner, as they are not standalone models.
return None
@@ -398,7 +402,10 @@ class CheckpointProbeBase(ProbeBase):
if model_type != ModelType.Main:
return ModelVariantType.Normal
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
key = "model.diffusion_model.input_blocks.0.0.weight"
if key not in state_dict:
return ModelVariantType.Normal
in_channels = state_dict[key].shape[1]
if in_channels == 9:
return ModelVariantType.Inpaint
elif in_channels == 5:
@@ -425,6 +432,9 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
return BaseModelType.StableDiffusionXL
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
return BaseModelType.StableDiffusionXLRefiner
key_name = "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight"
if key_name in state_dict:
return BaseModelType.StableDiffusion3
else:
raise InvalidModelConfigException("Cannot determine base type")
@@ -596,6 +606,10 @@ class FolderProbeBase(ProbeBase):
class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
with open(self.model_path / "model_index.json", "r") as file:
index_conf = json.load(file)
if index_conf.get("_class_name") == "StableDiffusion3Pipeline":
return BaseModelType.StableDiffusion3
with open(self.model_path / "unet" / "config.json", "r") as file:
unet_conf = json.load(file)
if unet_conf["cross_attention_dim"] == 768:
@@ -644,6 +658,8 @@ class VaeFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
if self._config_looks_like_sdxl():
return BaseModelType.StableDiffusionXL
elif self._config_looks_like_sd3():
return BaseModelType.StableDiffusion3
elif self._name_looks_like_sdxl():
# but since the SD and SDXL VAEs have the same shape (3-channel RGB to 4-channel float scaled down
# by a factor of 8), we can't necessarily tell them apart by config hyperparameters.
@@ -663,6 +679,15 @@ class VaeFolderProbe(FolderProbeBase):
def _name_looks_like_sdxl(self) -> bool:
return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
def _config_looks_like_sd3(self) -> bool:
# config values that distinguish the SD3 VAE from Stability's SD 1.x and SDXL VAEs.
config_file = self.model_path / "config.json"
if not config_file.exists():
raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
with open(config_file, "r") as file:
config = json.load(file)
return config.get("scaling_factor", 0) == 1.5305 and config.get("sample_size") in [512, 1024]
def _guess_name(self) -> str:
name = self.model_path.name
if name == "vae":
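A minimal illustration of the _config_looks_like_sd3() test above run against an in-memory dict; the two fields mirror the check in the code, and a real SD3 VAE config.json carries many more keys:

config = {"scaling_factor": 1.5305, "sample_size": 1024}  # illustrative values only
looks_like_sd3 = config.get("scaling_factor", 0) == 1.5305 and config.get("sample_size") in [512, 1024]
assert looks_like_sd3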

View File

@@ -122,6 +122,13 @@ STARTER_MODELS: list[StarterModel] = [
type=ModelType.Main,
dependencies=[sdxl_fp16_vae_fix],
),
StarterModel(
name="Stable Diffusion 3",
base=BaseModelType.StableDiffusion3,
source="stabilityai/stable-diffusion-3-medium-diffusers",
description="The OG Stable Diffusion 3 base model **NOT FOR COMMERCIAL USE**.",
type=ModelType.Main,
),
# endregion
# region VAE
sdxl_fp16_vae_fix,

View File

@@ -35,6 +35,18 @@ def filter_files(
The file list can be obtained from the `files` field of HuggingFaceMetadata,
as defined in `invokeai.backend.model_manager.metadata.metadata_base`.
"""
# BRITTLENESS WARNING!!
# The following pattern is designed to match model files that are components of diffusers submodels,
# but not to match other random stuff found in huggingface repos.
# Diffusers models always seem to have "model" in their name, and the regex filter below is applied to avoid
# downloading random checkpoints that might also be in the repo. However there is no guarantee
# that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
# will adhere to this naming convention, so this is an area to be careful of.
DIFFUSERS_COMPONENT_PATTERN = (
r"model(-fp16)?(-\d+-of-\d+)?(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"
)
variant = variant or ModelRepoVariant.Default
paths: List[Path] = []
root = files[0].parts[0]
@@ -45,31 +57,26 @@ def filter_files(
# Start by filtering on model file extensions, discarding images, docs, etc
for file in files:
if file.name.endswith((".json", ".txt")):
paths.append(file)
elif file.name.endswith(
if file.name.endswith(
(
".json",
".txt",
"learned_embeds.bin",
"ip_adapter.bin",
"lora_weights.safetensors",
"weights.pb",
"onnx_data",
"spiece.model",
)
):
paths.append(file)
# BRITTLENESS WARNING!!
# Diffusers models always seem to have "model" in their name, and the regex filter below is applied to avoid
# downloading random checkpoints that might also be in the repo. However there is no guarantee
# that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
# will adhere to this naming convention, so this is an area to be careful of.
elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
elif re.search(DIFFUSERS_COMPONENT_PATTERN, file.name):
paths.append(file)
# limit search to subfolder if requested
if subfolder:
subfolder = root / subfolder
paths = [x for x in paths if x.parent == Path(subfolder)]
# _filter_by_variant uniquifies the paths and returns a set
return sorted(_filter_by_variant(paths, variant))
@@ -97,9 +104,22 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
if variant == ModelRepoVariant.Flax:
result.add(path)
elif path.suffix in [".json", ".txt"]:
elif path.suffix in [".json", ".txt", ".model"]:
result.add(path)
# handle shard patterns
elif re.match(r"model\.fp16-\d+-of-\d+\.safetensors", path.name):
if variant is ModelRepoVariant.FP16:
result.add(path)
else:
continue
elif re.match(r"model-\d+-of-\d+\.safetensors", path.name):
if variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]:
result.add(path)
else:
continue
elif variant in [
ModelRepoVariant.FP16,
ModelRepoVariant.FP32,
@@ -123,6 +143,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
score += 1
candidate_variant_label = path.suffixes[0] if len(path.suffixes) == 2 else None
candidate_variant_label, *_ = str(candidate_variant_label).split("-") # handle shard pattern
# Some special handling is needed here if there is not an exact match and if we cannot infer the variant
# from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT.
@@ -139,6 +160,8 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
else:
continue
print(subfolder_weights)
for candidate_list in subfolder_weights.values():
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
if highest_score_candidate:
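To see what the BRITTLENESS WARNING above is guarding against, here is a quick check of DIFFUSERS_COMPONENT_PATTERN against a few representative file names (the names are illustrative of Hugging Face conventions, not pulled from a specific repo listing):

import re

DIFFUSERS_COMPONENT_PATTERN = (
    r"model(-fp16)?(-\d+-of-\d+)?(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"
)

names = [
    "diffusion_pytorch_model.safetensors",  # matches: standard diffusers component
    "model.fp16.safetensors",               # matches: fp16 variant
    "model-00001-of-00002.safetensors",     # matches: sharded component
    "some_random_checkpoint.ckpt",          # no match: no "model" stem, so it is skipped
]
for name in names:
    print(name, bool(re.search(DIFFUSERS_COMPONENT_PATTERN, name)))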

View File

@@ -4,7 +4,6 @@
from __future__ import annotations
import pickle
import threading
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
@@ -35,8 +34,6 @@ with LoRAHelper.apply_lora_unet(unet, loras):
# TODO: rename smth like ModelPatcher and add TI method?
class ModelPatcher:
_thread_lock = threading.Lock()
@staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key
@@ -109,7 +106,7 @@ class ModelPatcher:
"""
original_weights = {}
try:
with torch.no_grad(), cls._thread_lock:
with torch.no_grad():
for lora, lora_weight in loras:
# assert lora.device.type == "cpu"
for layer_key, layer in lora.layers.items():
@@ -132,7 +129,9 @@ class ModelPatcher:
dtype = module.weight.dtype
if module_key not in original_weights:
if model_state_dict is None: # no CPU copy of the state dict was provided
if model_state_dict is not None: # we were provided with the CPU copy of the state dict
original_weights[module_key] = model_state_dict[module_key + ".weight"]
else:
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
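The hunk above corrects which branch reuses a pre-fetched CPU state dict when caching original weights. The underlying save-patch-restore pattern, reduced to a minimal form (the module key and the weight delta are stand-ins, not the real LoRA math):

import torch

module = torch.nn.Linear(8, 8)
module_key = "unet.some_linear"  # hypothetical key
original_weights: dict = {}

# cache a CPU copy of the original weight, then patch the module in place
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
with torch.no_grad():
    module.weight.add_(0.05 * torch.randn_like(module.weight))  # stand-in for the scaled LoRA delta

# restore the original weight when the patch context exits
with torch.no_grad():
    module.weight.copy_(original_weights[module_key])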

View File

@@ -32,11 +32,8 @@ class SDXLConditioningInfo(BasicConditioningInfo):
def to(self, device, dtype=None):
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
assert self.pooled_embeds.device == device
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
result = super().to(device=device, dtype=dtype)
assert self.embeds.device == device
return result
return super().to(device=device, dtype=dtype)
@dataclass

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import math
import threading
from typing import Any, Callable, Optional, Union
import torch
@@ -294,31 +293,24 @@ class InvokeAIDiffuserComponent:
cross_attention_kwargs["regional_ip_data"] = regional_ip_data
added_cond_kwargs = None
try:
if conditioning_data.is_sdxl():
# tid = threading.current_thread().ident
# print(f'DEBUG {tid} {conditioning_data.uncond_text.pooled_embeds.device=} {conditioning_data.cond_text.pooled_embeds.device=}', flush=True),
added_cond_kwargs = {
"text_embeds": torch.cat(
[
# TODO: how to pad? just by zeros? or even truncate?
conditioning_data.uncond_text.pooled_embeds,
conditioning_data.cond_text.pooled_embeds,
],
dim=0,
),
"time_ids": torch.cat(
[
conditioning_data.uncond_text.add_time_ids,
conditioning_data.cond_text.add_time_ids,
],
dim=0,
),
}
except Exception as e:
tid = threading.current_thread().ident
print(f"DEBUG: {tid} {str(e)}")
raise e
if conditioning_data.is_sdxl():
added_cond_kwargs = {
"text_embeds": torch.cat(
[
# TODO: how to pad? just by zeros? or even truncate?
conditioning_data.uncond_text.pooled_embeds,
conditioning_data.cond_text.pooled_embeds,
],
dim=0,
),
"time_ids": torch.cat(
[
conditioning_data.uncond_text.add_time_ids,
conditioning_data.cond_text.add_time_ids,
],
dim=0,
),
}
if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
# TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings

View File

@@ -7,6 +7,7 @@ from diffusers import (
DPMSolverSinglestepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
FlowMatchEulerDiscreteScheduler,
HeunDiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
KDPM2DiscreteScheduler,
@@ -29,6 +30,7 @@ SCHEDULER_MAP = {
"euler": (EulerDiscreteScheduler, {"use_karras_sigmas": False}),
"euler_k": (EulerDiscreteScheduler, {"use_karras_sigmas": True}),
"euler_a": (EulerAncestralDiscreteScheduler, {}),
"euler_f": (FlowMatchEulerDiscreteScheduler, {}),
"kdpm_2": (KDPM2DiscreteScheduler, {}),
"kdpm_2_a": (KDPM2AncestralDiscreteScheduler, {}),
"dpmpp_2s": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": False}),

View File

@@ -1,16 +1,10 @@
"""Torch Device class provides torch device selection services."""
from typing import TYPE_CHECKING, Dict, Literal, Optional, Set, Union
from typing import Dict, Literal, Optional, Union
import torch
from deprecated import deprecated
from invokeai.app.services.config.config_default import get_config
if TYPE_CHECKING:
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
# legacy APIs
TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
CPU_DEVICE = torch.device("cpu")
@@ -48,23 +42,9 @@ PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {v: k for k, v in NA
class TorchDevice:
"""Abstraction layer for torch devices."""
_model_cache: Optional["ModelCacheBase[AnyModel]"] = None
@classmethod
def set_model_cache(cls, cache: "ModelCacheBase[AnyModel]"):
"""Set the current model cache."""
cls._model_cache = cache
@classmethod
def choose_torch_device(cls) -> torch.device:
"""Return the torch.device to use for accelerated inference."""
if cls._model_cache:
return cls._model_cache.get_execution_device()
else:
return cls._choose_device()
@classmethod
def _choose_device(cls) -> torch.device:
app_config = get_config()
if app_config.device != "auto":
device = torch.device(app_config.device)
@@ -76,19 +56,11 @@ class TorchDevice:
device = CPU_DEVICE
return cls.normalize(device)
@classmethod
def execution_devices(cls) -> Set[torch.device]:
"""Return a list of torch.devices that can be used for accelerated inference."""
app_config = get_config()
if app_config.devices is None:
return cls._lookup_execution_devices()
return {torch.device(x) for x in app_config.devices}
@classmethod
def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
"""Return the precision to use for accelerated inference."""
device = device or cls.choose_torch_device()
config = get_config()
device = device or cls._choose_device()
if device.type == "cuda" and torch.cuda.is_available():
device_name = torch.cuda.get_device_name(device)
if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
@@ -136,13 +108,3 @@ class TorchDevice:
@classmethod
def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
return NAME_TO_PRECISION[precision_name]
@classmethod
def _lookup_execution_devices(cls) -> Set[torch.device]:
if torch.cuda.is_available():
devices = {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
elif torch.backends.mps.is_available():
devices = {torch.device("mps")}
else:
devices = {torch.device("cpu")}
return devices

View File

@@ -3,7 +3,12 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import diffusers
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalControlNetMixin
# The following import generates import errors with diffusers 0.28.2.
# Tried from diffusers.loaders.controlnet import FromOriginalControlNetMixin, but this
# fails as well.
# from diffusers.loaders import FromOriginalControlNetMixin
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
from diffusers.models.embeddings import (
@@ -32,7 +37,7 @@ from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger(__name__)
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
class ControlNetModel(ModelMixin, ConfigMixin):
"""
A ControlNet model.

View File

@@ -11,6 +11,7 @@ const BASE_COLOR_MAP: Record<BaseModelType, string> = {
any: 'base',
'sd-1': 'green',
'sd-2': 'teal',
'sd-3': 'purple',
sdxl: 'invokeBlue',
'sdxl-refiner': 'invokeBlue',
};

View File

@@ -10,6 +10,7 @@ import type { UpdateModelArg } from 'services/api/endpoints/models';
const options: ComboboxOption[] = [
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
{ value: 'sd-3', label: MODEL_TYPE_MAP['sd-3'] },
{ value: 'sdxl', label: MODEL_TYPE_MAP['sdxl'] },
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
];

View File

@@ -28,6 +28,8 @@ import {
isModelIdentifierFieldInputTemplate,
isSchedulerFieldInputInstance,
isSchedulerFieldInputTemplate,
isSD3MainModelFieldInputInstance,
isSD3MainModelFieldInputTemplate,
isSDXLMainModelFieldInputInstance,
isSDXLMainModelFieldInputTemplate,
isSDXLRefinerModelFieldInputInstance,
@@ -53,6 +55,7 @@ import MainModelFieldInputComponent from './inputs/MainModelFieldInputComponent'
import NumberFieldInputComponent from './inputs/NumberFieldInputComponent';
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComponent';
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
@@ -133,6 +136,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isSD3MainModelFieldInputInstance(fieldInstance) && isSD3MainModelFieldInputTemplate(fieldTemplate)) {
return <SD3MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isSchedulerFieldInputInstance(fieldInstance) && isSchedulerFieldInputTemplate(fieldTemplate)) {
return <SchedulerFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

View File

@@ -0,0 +1,55 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useSD3Models } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate>;
const SD3MainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useSD3Models();
const _onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldMainModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
);
};
export default memo(SD3MainModelFieldInputComponent);

View File

@@ -631,6 +631,7 @@ export const schema = {
'euler',
'euler_k',
'euler_a',
'euler_f',
'kdpm_2',
'kdpm_2_a',
'dpmpp_2s',
@@ -694,6 +695,7 @@ export const schema = {
'euler',
'euler_k',
'euler_a',
'euler_f',
'kdpm_2',
'kdpm_2_a',
'dpmpp_2s',
@@ -839,7 +841,7 @@ export const schema = {
},
BaseModelType: {
description: 'Base model type.',
enum: ['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
enum: ['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner'],
title: 'BaseModelType',
type: 'string',
},
@@ -855,8 +857,11 @@ export const schema = {
'unet',
'text_encoder',
'text_encoder_2',
'text_encoder_3',
'tokenizer',
'tokenizer_2',
'tokenizer_3',
'transformer',
'vae',
'vae_decoder',
'vae_encoder',

View File

@@ -47,6 +47,7 @@ export const zSchedulerField = z.enum([
'heun_k',
'lms_k',
'euler_a',
'euler_f',
'kdpm_2_a',
'lcm',
'tcd',
@@ -55,7 +56,7 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
// #endregion
// #region Model-related schemas
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner']);
const zModelType = z.enum([
'main',
'vae',
@@ -71,8 +72,11 @@ const zSubModelType = z.enum([
'unet',
'text_encoder',
'text_encoder_2',
'text_encoder_3',
'tokenizer',
'tokenizer_2',
'tokenizer_3',
'transformer',
'vae',
'vae_decoder',
'vae_encoder',

View File

@@ -32,11 +32,14 @@ export const MODEL_TYPES = [
'LoRAModelField',
'MainModelField',
'SDXLMainModelField',
'SD3MainModelField',
'SDXLRefinerModelField',
'VaeModelField',
'UNetField',
'TransformerField',
'VAEField',
'CLIPField',
'SD3CLIPField',
'T2IAdapterModelField',
];
@@ -47,6 +50,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
BoardField: 'purple.500',
BooleanField: 'green.500',
CLIPField: 'green.500',
SD3CLIPField: 'green.500',
ColorField: 'pink.300',
ConditioningField: 'cyan.500',
ControlField: 'teal.500',
@@ -62,10 +66,12 @@ export const FIELD_COLORS: { [key: string]: string } = {
MainModelField: 'teal.500',
SDXLMainModelField: 'teal.500',
SDXLRefinerModelField: 'teal.500',
SD3MainModelField: 'teal.500',
StringField: 'yellow.500',
T2IAdapterField: 'teal.500',
T2IAdapterModelField: 'teal.500',
UNetField: 'red.500',
TransformerField: 'red.500',
VAEField: 'blue.500',
VAEModelField: 'teal.500',
};

View File

@@ -119,6 +119,10 @@ const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLRefinerModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSD3MainModelFieldType = zFieldTypeBase.extend({
name: z.literal('SD3MainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zVAEModelFieldType = zFieldTypeBase.extend({
name: z.literal('VAEModelField'),
originalType: zStatelessFieldType.optional(),
@@ -155,6 +159,7 @@ const zStatefulFieldType = z.union([
zMainModelFieldType,
zSDXLMainModelFieldType,
zSDXLRefinerModelFieldType,
zSD3MainModelFieldType,
zVAEModelFieldType,
zLoRAModelFieldType,
zControlNetModelFieldType,
@@ -466,6 +471,28 @@ export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLR
zSDXLRefinerModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region SD3MainModelField
const zSD3MainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SD3 models only.
const zSD3MainModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zSD3MainModelFieldValue,
});
const zSD3MainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zSD3MainModelFieldType,
originalType: zFieldType.optional(),
default: zSD3MainModelFieldValue,
});
const zSD3MainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zSD3MainModelFieldType,
});
export type SD3MainModelFieldInputInstance = z.infer<typeof zSD3MainModelFieldInputInstance>;
export type SD3MainModelFieldInputTemplate = z.infer<typeof zSD3MainModelFieldInputTemplate>;
export const isSD3MainModelFieldInputInstance = (val: unknown): val is SD3MainModelFieldInputInstance =>
zSD3MainModelFieldInputInstance.safeParse(val).success;
export const isSD3MainModelFieldInputTemplate = (val: unknown): val is SD3MainModelFieldInputTemplate =>
zSD3MainModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region VAEModelField
export const zVAEModelFieldValue = zModelIdentifierField.optional();
@@ -662,6 +689,7 @@ export const zStatefulFieldValue = z.union([
zMainModelFieldValue,
zSDXLMainModelFieldValue,
zSDXLRefinerModelFieldValue,
zSD3MainModelFieldValue,
zVAEModelFieldValue,
zLoRAModelFieldValue,
zControlNetModelFieldValue,
@@ -689,6 +717,7 @@ const zStatefulFieldInputInstance = z.union([
zMainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance,
zSDXLRefinerModelFieldInputInstance,
zSD3MainModelFieldInputInstance,
zVAEModelFieldInputInstance,
zLoRAModelFieldInputInstance,
zControlNetModelFieldInputInstance,
@@ -717,6 +746,7 @@ const zStatefulFieldInputTemplate = z.union([
zMainModelFieldInputTemplate,
zSDXLMainModelFieldInputTemplate,
zSDXLRefinerModelFieldInputTemplate,
zSD3MainModelFieldInputTemplate,
zVAEModelFieldInputTemplate,
zLoRAModelFieldInputTemplate,
zControlNetModelFieldInputTemplate,
@@ -746,6 +776,7 @@ const zStatefulFieldOutputTemplate = z.union([
zMainModelFieldOutputTemplate,
zSDXLMainModelFieldOutputTemplate,
zSDXLRefinerModelFieldOutputTemplate,
zSD3MainModelFieldOutputTemplate,
zVAEModelFieldOutputTemplate,
zLoRAModelFieldOutputTemplate,
zControlNetModelFieldOutputTemplate,

View File

@@ -44,7 +44,7 @@ export const zSchedulerField = z.enum([
// #endregion
// #region Model-related schemas
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner']);
const zModelName = z.string().min(3);
export const zModelIdentifier = z.object({
model_name: zModelName,

View File

@@ -217,6 +217,20 @@ const zSDXLRefinerModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
});
// #endregion
// #region SDXLMainModelField
const zSD3MainModelFieldType = zFieldTypeBase.extend({
name: z.literal('SD3MainModelField'),
});
const zSD3MainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SD3 models only.
const zSD3MainModelFieldInputInstance = zFieldInputInstanceBase.extend({
type: zSD3MainModelFieldType,
value: zSD3MainModelFieldValue,
});
const zSD3MainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
type: zSD3MainModelFieldType,
});
// #endregion
// #region VAEModelField
const zVAEModelFieldType = zFieldTypeBase.extend({
name: z.literal('VAEModelField'),
@@ -339,6 +353,7 @@ const zStatefulFieldType = z.union([
zMainModelFieldType,
zSDXLMainModelFieldType,
zSDXLRefinerModelFieldType,
zSD3MainModelFieldType,
zVAEModelFieldType,
zLoRAModelFieldType,
zControlNetModelFieldType,
@@ -378,6 +393,7 @@ const zStatefulFieldInputInstance = z.union([
zMainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance,
zSDXLRefinerModelFieldInputInstance,
zSD3MainModelFieldInputInstance,
zVAEModelFieldInputInstance,
zLoRAModelFieldInputInstance,
zControlNetModelFieldInputInstance,
@@ -402,6 +418,7 @@ const zStatefulFieldOutputInstance = z.union([
zMainModelFieldOutputInstance,
zSDXLMainModelFieldOutputInstance,
zSDXLRefinerModelFieldOutputInstance,
zSD3MainModelFieldOutputInstance,
zVAEModelFieldOutputInstance,
zLoRAModelFieldOutputInstance,
zControlNetModelFieldOutputInstance,

View File

@@ -15,6 +15,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
MainModelField: undefined,
SchedulerField: 'euler',
SDXLMainModelField: undefined,
SD3MainModelField: undefined,
SDXLRefinerModelField: undefined,
StringField: '',
T2IAdapterModelField: undefined,

View File

@@ -15,6 +15,7 @@ import type {
MainModelFieldInputTemplate,
ModelIdentifierFieldInputTemplate,
SchedulerFieldInputTemplate,
SD3MainModelFieldInputTemplate,
SDXLMainModelFieldInputTemplate,
SDXLRefinerModelFieldInputTemplate,
StatefulFieldType,
@@ -193,6 +194,20 @@ const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLRefiner
return template;
};
const buildSD3MainModelFieldInputTemplate: FieldInputTemplateBuilder<SD3MainModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: SD3MainModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildVAEModelFieldInputTemplate: FieldInputTemplateBuilder<VAEModelFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -375,6 +390,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
SchedulerField: buildSchedulerFieldInputTemplate,
SDXLMainModelField: buildSDXLMainModelFieldInputTemplate,
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
SD3MainModelField: buildSD3MainModelFieldInputTemplate,
StringField: buildStringFieldInputTemplate,
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
VAEModelField: buildVAEModelFieldInputTemplate,

View File

@@ -30,6 +30,7 @@ const MODEL_FIELD_TYPES = [
'MainModelField',
'SDXLMainModelField',
'SDXLRefinerModelField',
'SD3MainModelField',
'VAEModelField',
'LoRAModelField',
'ControlNetModelField',

View File

@@ -39,7 +39,7 @@ const ParamClipSkip = () => {
return CLIP_SKIP_MAP[model.base].markers;
}, [model]);
if (model?.base === 'sdxl') {
if (model?.base === 'sdxl' || model?.base === 'sd-3') {
return null;
}

View File

@@ -7,6 +7,7 @@ export const MODEL_TYPE_MAP = {
any: 'Any',
'sd-1': 'Stable Diffusion 1.x',
'sd-2': 'Stable Diffusion 2.x',
'sd-3': 'Stable Diffusion 3.x',
sdxl: 'Stable Diffusion XL',
'sdxl-refiner': 'Stable Diffusion XL Refiner',
};
@@ -18,6 +19,7 @@ export const MODEL_TYPE_SHORT_MAP = {
any: 'Any',
'sd-1': 'SD1.X',
'sd-2': 'SD2.X',
'sd-3': 'SD3.X',
sdxl: 'SDXL',
'sdxl-refiner': 'SDXLR',
};
@@ -38,6 +40,11 @@ export const CLIP_SKIP_MAP = {
maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
},
// TODO: Update this when we have more details on how CLIP SKIP works with SD3
'sd-3': {
maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
},
sdxl: {
maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
@@ -73,6 +80,7 @@ export const SCHEDULER_OPTIONS: ComboboxOption[] = [
{ value: 'heun_k', label: 'Heun Karras' },
{ value: 'lms_k', label: 'LMS Karras' },
{ value: 'euler_a', label: 'Euler Ancestral' },
{ value: 'euler_f', label: 'Euler Flow Match' },
{ value: 'kdpm_2_a', label: 'KDPM 2 Ancestral' },
{ value: 'lcm', label: 'LCM' },
{ value: 'tcd', label: 'TCD' },

View File

@@ -10,6 +10,7 @@ import {
isNonRefinerMainModelConfig,
isNonSDXLMainModelConfig,
isRefinerMainModelModelConfig,
isSD3MainModelModelConfig,
isSDXLMainModelModelConfig,
isT2IAdapterModelConfig,
isTIModelConfig,
@@ -35,6 +36,7 @@ export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
export const useSD3Models = buildModelsHook(isSD3MainModelModelConfig);
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);

File diff suppressed because one or more lines are too long

View File

@@ -109,7 +109,11 @@ export const isSDXLMainModelModelConfig = (config: AnyModelConfig): config is Ma
};
export const isNonSDXLMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && (config.base === 'sd-1' || config.base === 'sd-2');
return config.type === 'main' && (config.base === 'sd-1' || config.base === 'sd-2' || config.base === 'sd-3');
};
export const isSD3MainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base === 'sd-3';
};
export const isTIModelConfig = (config: AnyModelConfig): config is MainModelConfig => {

View File

@@ -39,6 +39,7 @@ from invokeai.app.invocations.model import (
ModelIdentifierField,
ModelLoaderOutput,
SDXLLoRALoaderOutput,
TransformerField,
UNetField,
UNetOutput,
VAEField,
@@ -117,6 +118,7 @@ __all__ = [
# invokeai.app.invocations.model
"ModelIdentifierField",
"UNetField",
"TransformerField",
"CLIPField",
"VAEField",
"UNetOutput",

View File

@@ -33,30 +33,32 @@ classifiers = [
]
dependencies = [
# Core generation dependencies, pinned for reproducible builds.
"accelerate==0.30.1",
"accelerate",
"bitsandbytes",
"clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel==2.0.2",
"controlnet-aux==0.0.7",
"diffusers[torch]==0.27.2",
"diffusers[torch]",
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
"mediapipe==0.10.7", # needed for "mediapipeface" controlnet model
"numpy==1.26.4", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
"numpy", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
"onnx==1.15.0",
"onnxruntime==1.16.3",
"opencv-python==4.9.0.80",
"pytorch-lightning==2.1.3",
"pytorch-lightning",
"safetensors==0.4.3",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"torch==2.2.2",
"torchmetrics==0.11.4",
"torch",
"torchmetrics",
"torchsde==0.2.6",
"torchvision==0.17.2",
"transformers==4.41.1",
"torchvision",
"transformers",
"sentencepiece==0.1.99",
# Core application dependencies, pinned for reproducible builds.
"fastapi-events==0.11.0",
"fastapi==0.111.0",
"huggingface-hub==0.23.1",
"huggingface-hub",
"pydantic-settings==2.2.1",
"pydantic==2.7.2",
"python-socketio==5.11.1",
@@ -73,7 +75,7 @@ dependencies = [
"easing-functions",
"einops",
"facexlib",
"matplotlib", # needed for plotting of Penner easing functions
"matplotlib", # needed for plotting of Penner easing functions
"npyscreen",
"omegaconf",
"picklescan",

View File

@@ -1,54 +0,0 @@
#!/bin/env python
from argparse import ArgumentParser, Namespace
from pathlib import Path
from invokeai.app.services.config import InvokeAIAppConfig, get_config
from invokeai.app.services.download import DownloadQueueService
from invokeai.app.services.model_install import ModelInstallService
from invokeai.app.services.model_records import ModelRecordServiceSQL
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.util.logging import InvokeAILogger
def get_args() -> Namespace:
parser = ArgumentParser(description="Update models database from yaml file")
parser.add_argument("--root", type=Path, required=False, default=None)
parser.add_argument("--yaml_file", type=Path, required=False, default=None)
return parser.parse_args()
def populate_config() -> InvokeAIAppConfig:
args = get_args()
config = get_config()
if args.root:
config._root = args.root
if args.yaml_file:
config.legacy_models_yaml_path = args.yaml_file
else:
config.legacy_models_yaml_path = config.root_path / "configs/models.yaml"
return config
def initialize_installer(config: InvokeAIAppConfig) -> ModelInstallService:
logger = InvokeAILogger.get_logger(config=config)
db = SqliteDatabase(config.db_path, logger)
record_store = ModelRecordServiceSQL(db)
queue = DownloadQueueService()
queue.start()
installer = ModelInstallService(app_config=config, record_store=record_store, download_queue=queue)
return installer
def main() -> None:
config = populate_config()
installer = initialize_installer(config)
installer._migrate_yaml(rename_yaml=False, overwrite_db=True)
print("\n<INSTALLED MODELS>")
print("\t".join(["key", "name", "type", "path"]))
for model in installer.record_store.all_models():
print("\t".join([model.key, model.name, model.type, (config.models_path / model.path).as_posix()]))
if __name__ == "__main__":
main()

View File

@@ -14,14 +14,13 @@ def test_loading(mm2_model_manager: ModelManagerServiceBase, embedding_file: Pat
matches = store.search_by_attr(model_name="test_embedding")
assert len(matches) == 0
key = mm2_model_manager.install.register_path(embedding_file)
with mm2_model_manager.load.ram_cache.reserve_execution_device():
loaded_model = mm2_model_manager.load.load_model(store.get_model(key))
assert loaded_model is not None
assert loaded_model.config.key == key
with loaded_model as model:
assert isinstance(model, TextualInversionModelRaw)
loaded_model = mm2_model_manager.load.load_model(store.get_model(key))
assert loaded_model is not None
assert loaded_model.config.key == key
with loaded_model as model:
assert isinstance(model, TextualInversionModelRaw)
config = mm2_model_manager.store.get_model(key)
loaded_model_2 = mm2_model_manager.load.load_model(config)
config = mm2_model_manager.store.get_model(key)
loaded_model_2 = mm2_model_manager.load.load_model(config)
assert loaded_model.config.key == loaded_model_2.config.key
assert loaded_model.config.key == loaded_model_2.config.key

View File

@@ -89,10 +89,11 @@ def mm2_download_queue(mm2_session: Session) -> DownloadQueueServiceBase:
@pytest.fixture
def mm2_loader(mm2_app_config: InvokeAIAppConfig) -> ModelLoadServiceBase:
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
ram_cache = ModelCache(
logger=InvokeAILogger.get_logger(),
max_cache_size=mm2_app_config.ram,
max_vram_cache_size=mm2_app_config.vram,
)
convert_cache = ModelConvertCache(mm2_app_config.convert_cache_path)
return ModelLoadService(

View File

@@ -8,9 +8,7 @@ import pytest
import torch
from invokeai.app.services.config import get_config
from invokeai.backend.model_manager.load import ModelCache
from invokeai.backend.util.devices import TorchDevice, choose_precision, choose_torch_device, torch_dtype
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
devices = ["cpu", "cuda:0", "cuda:1", "mps"]
device_types_cpu = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float32)]
@@ -22,7 +20,6 @@ device_types_mps = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", t
def test_device_choice(device_name):
config = get_config()
config.device = device_name
TorchDevice.set_model_cache(None) # disable dynamic selection of GPU device
torch_device = TorchDevice.choose_torch_device()
assert torch_device == torch.device(device_name)
@@ -133,32 +130,3 @@ def test_legacy_precision_name():
assert "float16" == choose_precision(torch.device("cuda"))
assert "float16" == choose_precision(torch.device("mps"))
assert "float32" == choose_precision(torch.device("cpu"))
def test_multi_device_support_1():
config = get_config()
config.devices = ["cuda:0", "cuda:1"]
assert TorchDevice.execution_devices() == {torch.device("cuda:0"), torch.device("cuda:1")}
def test_multi_device_support_2():
config = get_config()
config.devices = None
with (
patch("torch.cuda.device_count", return_value=3),
patch("torch.cuda.is_available", return_value=True),
):
assert TorchDevice.execution_devices() == {
torch.device("cuda:0"),
torch.device("cuda:1"),
torch.device("cuda:2"),
}
def test_multi_device_support_3():
config = get_config()
config.devices = ["cuda:0", "cuda:1"]
cache = ModelCache()
with cache.reserve_execution_device() as gpu:
assert gpu in [torch.device(x) for x in config.devices]
assert TorchDevice.choose_torch_device() == gpu

View File

@@ -17,6 +17,7 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.images.images_default import ImageService
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.invoker import Invoker
from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager.model_manager_fixtures import * # noqa: F403
@@ -48,13 +49,13 @@ def mock_services() -> InvocationServices:
model_manager=None, # type: ignore
download_queue=None, # type: ignore
names=None, # type: ignore
performance_statistics=InvocationStatsService(),
session_processor=None, # type: ignore
session_queue=None, # type: ignore
urls=None, # type: ignore
workflow_records=None, # type: ignore
tensors=None, # type: ignore
conditioning=None, # type: ignore
performance_statistics=None, # type: ignore
)

View File

@@ -92,6 +92,7 @@ def test_migrate_v3_config_from_file(tmp_path: Path, patch_rootdir: None):
assert config.host == "192.168.1.1"
assert config.port == 8080
assert config.ram == 100
assert config.vram == 50
assert config.legacy_models_yaml_path == Path("/custom/models.yaml")
# This should be stripped out
assert not hasattr(config, "esrgan")