From 5adce0266ba434e1a4d7da53d9b7371387141fdc Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 7 Oct 2025 14:47:53 +1100 Subject: [PATCH] refactor(mm): remove legacy probe, new configs dir structure, update imports --- invokeai/app/api/routers/model_manager.py | 6 +- invokeai/app/invocations/cogview4_denoise.py | 2 +- .../app/invocations/cogview4_model_loader.py | 3 +- .../app/invocations/create_gradient_mask.py | 11 +- invokeai/app/invocations/denoise_latents.py | 2 +- invokeai/app/invocations/flux_ip_adapter.py | 4 +- invokeai/app/invocations/flux_model_loader.py | 4 +- invokeai/app/invocations/flux_redux.py | 4 +- invokeai/app/invocations/flux_text_encoder.py | 2 +- invokeai/app/invocations/flux_vae_encode.py | 2 +- invokeai/app/invocations/image_to_latents.py | 2 +- invokeai/app/invocations/ip_adapter.py | 4 +- invokeai/app/invocations/model.py | 4 +- invokeai/app/invocations/sd3_denoise.py | 2 +- invokeai/app/services/events/events_base.py | 4 +- invokeai/app/services/events/events_common.py | 4 +- .../model_install/model_install_common.py | 2 +- .../model_install/model_install_default.py | 9 +- .../services/model_load/model_load_base.py | 2 +- .../services/model_load/model_load_default.py | 2 +- .../app/services/model_manager/__init__.py | 2 - .../model_records/model_records_base.py | 10 +- .../model_records/model_records_sql.py | 27 +- .../model_relationships_default.py | 2 +- .../app/services/shared/invocation_context.py | 6 +- .../migrations/migration_22.py | 2 +- invokeai/app/util/custom_openapi.py | 2 +- .../backend/flux/flux_state_dict_utils.py | 7 +- invokeai/backend/model_manager/__init__.py | 45 - invokeai/backend/model_manager/config.py | 2584 ----------------- .../backend/model_manager/configs/base.py | 2 +- .../model_manager/configs/controlnet.py | 39 +- .../backend/model_manager/configs/factory.py | 48 +- .../backend/model_manager/configs/lora.py | 2 +- .../backend/model_manager/configs/main.py | 4 +- .../model_manager/configs/t2i_adapter.py | 2 +- .../backend/model_manager/legacy_probe.py | 1034 ------- .../backend/model_manager/load/load_base.py | 2 +- .../model_manager/load/load_default.py | 5 +- .../load/model_loader_registry.py | 6 +- .../load/model_loaders/clip_vision.py | 6 +- .../load/model_loaders/cogview4.py | 7 +- .../load/model_loaders/controlnet.py | 6 +- .../model_manager/load/model_loaders/flux.py | 19 +- .../load/model_loaders/generic_diffusers.py | 11 +- .../load/model_loaders/ip_adapter.py | 2 +- .../load/model_loaders/llava_onevision.py | 4 +- .../model_manager/load/model_loaders/lora.py | 2 +- .../model_manager/load/model_loaders/onnx.py | 2 +- .../load/model_loaders/sig_lip.py | 4 +- .../model_loaders/spandrel_image_to_image.py | 4 +- .../load/model_loaders/stable_diffusion.py | 7 +- .../load/model_loaders/textual_inversion.py | 2 +- .../model_manager/load/model_loaders/vae.py | 3 +- .../util/lora_metadata_extractor.py | 3 +- invokeai/backend/util/test_utils.py | 3 +- .../components/ParamDenoisingStrength.tsx | 7 +- .../controlLayers/store/validators.ts | 2 +- .../subpanels/ModelPanel/ModelView.tsx | 11 +- .../nodes/util/graph/generation/Graph.test.ts | 2 +- .../graph/generation/addControlAdapters.ts | 6 +- .../util/graph/generation/buildFLUXGraph.ts | 2 +- .../nodes/util/graph/graphBuilderUtils.ts | 2 +- .../frontend/web/src/services/api/schema.ts | 1768 +++++------ .../frontend/web/src/services/api/types.ts | 7 +- scripts/classify-model.py | 2 +- tests/test_model_probe.py | 2 +- 67 files changed, 892 insertions(+), 4910 deletions(-) delete mode 100644 invokeai/backend/model_manager/config.py delete mode 100644 invokeai/backend/model_manager/legacy_probe.py diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index 84db65252e..7add1d09cf 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -28,9 +28,8 @@ from invokeai.app.services.model_records import ( UnknownModelException, ) from invokeai.app.util.suppress_output import SuppressOutput -from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelType -from invokeai.backend.model_manager.config import ( - AnyModelConfig, +from invokeai.backend.model_manager.configs.factory import AnyModelConfig +from invokeai.backend.model_manager.configs.main import ( Main_Checkpoint_SD1_Config, Main_Checkpoint_SD2_Config, Main_Checkpoint_SDXL_Config, @@ -47,6 +46,7 @@ from invokeai.backend.model_manager.starter_models import ( StarterModelBundle, StarterModelWithoutDependencies, ) +from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"]) diff --git a/invokeai/app/invocations/cogview4_denoise.py b/invokeai/app/invocations/cogview4_denoise.py index c0b962ba31..070d8a3478 100644 --- a/invokeai/app/invocations/cogview4_denoise.py +++ b/invokeai/app/invocations/cogview4_denoise.py @@ -22,7 +22,7 @@ from invokeai.app.invocations.model import TransformerField from invokeai.app.invocations.primitives import LatentsOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional -from invokeai.backend.model_manager.config import BaseModelType +from invokeai.backend.model_manager.taxonomy import BaseModelType from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusion.conditioning_data import CogView4ConditioningInfo diff --git a/invokeai/app/invocations/cogview4_model_loader.py b/invokeai/app/invocations/cogview4_model_loader.py index 9db4f3c053..fbafcd345f 100644 --- a/invokeai/app/invocations/cogview4_model_loader.py +++ b/invokeai/app/invocations/cogview4_model_loader.py @@ -13,8 +13,7 @@ from invokeai.app.invocations.model import ( VAEField, ) from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.backend.model_manager.config import SubModelType -from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType +from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType @invocation_output("cogview4_model_loader_output") diff --git a/invokeai/app/invocations/create_gradient_mask.py b/invokeai/app/invocations/create_gradient_mask.py index f6e046d096..8a7e7c5231 100644 --- a/invokeai/app/invocations/create_gradient_mask.py +++ b/invokeai/app/invocations/create_gradient_mask.py @@ -20,9 +20,7 @@ from invokeai.app.invocations.fields import ( from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation from invokeai.app.invocations.model import UNetField, VAEField from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.backend.model_manager import LoadedModel -from invokeai.backend.model_manager.config import Main_Config_Base -from invokeai.backend.model_manager.taxonomy import ModelVariantType +from invokeai.backend.model_manager.taxonomy import FluxVariantType, ModelType, ModelVariantType from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor @@ -182,10 +180,11 @@ class CreateGradientMaskInvocation(BaseInvocation): if self.unet is not None and self.vae is not None and self.image is not None: # all three fields must be present at the same time main_model_config = context.models.get_config(self.unet.unet.key) - assert isinstance(main_model_config, Main_Config_Base) - if main_model_config.variant is ModelVariantType.Inpaint: + assert main_model_config.type is ModelType.Main + variant = getattr(main_model_config, "variant", None) + if variant is ModelVariantType.Inpaint or variant is FluxVariantType.DevFill: mask = dilated_mask_tensor - vae_info: LoadedModel = context.models.load(self.vae.vae) + vae_info = context.models.load(self.vae.vae) image = context.images.get_pil(self.image.image_name) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) if image_tensor.dim() == 3: diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index 37b385914c..bb114263e2 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -39,7 +39,7 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.backend.ip_adapter.ip_adapter import IPAdapter -from invokeai.backend.model_manager.config import AnyModelConfig +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelVariantType from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.patches.layer_patcher import LayerPatcher diff --git a/invokeai/app/invocations/flux_ip_adapter.py b/invokeai/app/invocations/flux_ip_adapter.py index c564023a3a..4a1997c512 100644 --- a/invokeai/app/invocations/flux_ip_adapter.py +++ b/invokeai/app/invocations/flux_ip_adapter.py @@ -16,9 +16,7 @@ from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.backend.model_manager.config import ( - IPAdapter_Checkpoint_FLUX_Config, -) +from invokeai.backend.model_manager.configs.ip_adapter import IPAdapter_Checkpoint_FLUX_Config from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType diff --git a/invokeai/app/invocations/flux_model_loader.py b/invokeai/app/invocations/flux_model_loader.py index 2803db48e0..eaac82bafc 100644 --- a/invokeai/app/invocations/flux_model_loader.py +++ b/invokeai/app/invocations/flux_model_loader.py @@ -14,9 +14,7 @@ from invokeai.app.util.t5_model_identifier import ( preprocess_t5_tokenizer_model_identifier, ) from invokeai.backend.flux.util import get_flux_max_seq_length -from invokeai.backend.model_manager.config import ( - Checkpoint_Config_Base, -) +from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType diff --git a/invokeai/app/invocations/flux_redux.py b/invokeai/app/invocations/flux_redux.py index 3e34497b10..403d78b078 100644 --- a/invokeai/app/invocations/flux_redux.py +++ b/invokeai/app/invocations/flux_redux.py @@ -24,9 +24,9 @@ from invokeai.app.invocations.primitives import ImageField from invokeai.app.services.model_records.model_records_base import ModelRecordChanges from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel -from invokeai.backend.model_manager import BaseModelType, ModelType -from invokeai.backend.model_manager.config import AnyModelConfig +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.starter_models import siglip +from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline from invokeai.backend.util.devices import TorchDevice diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py index 77b6187840..c395a0bf22 100644 --- a/invokeai/app/invocations/flux_text_encoder.py +++ b/invokeai/app/invocations/flux_text_encoder.py @@ -17,7 +17,7 @@ from invokeai.app.invocations.model import CLIPField, T5EncoderField from invokeai.app.invocations.primitives import FluxConditioningOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.flux.modules.conditioner import HFEncoder -from invokeai.backend.model_manager import ModelFormat +from invokeai.backend.model_manager.taxonomy import ModelFormat from invokeai.backend.patches.layer_patcher import LayerPatcher from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_T5_PREFIX from invokeai.backend.patches.model_patch_raw import ModelPatchRaw diff --git a/invokeai/app/invocations/flux_vae_encode.py b/invokeai/app/invocations/flux_vae_encode.py index 2932517edc..4ec0365c2c 100644 --- a/invokeai/app/invocations/flux_vae_encode.py +++ b/invokeai/app/invocations/flux_vae_encode.py @@ -12,7 +12,7 @@ from invokeai.app.invocations.model import VAEField from invokeai.app.invocations.primitives import LatentsOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.flux.modules.autoencoder import AutoEncoder -from invokeai.backend.model_manager import LoadedModel +from invokeai.backend.model_manager.load.load_base import LoadedModel from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux diff --git a/invokeai/app/invocations/image_to_latents.py b/invokeai/app/invocations/image_to_latents.py index 552f5edb1b..fde70a34fd 100644 --- a/invokeai/app/invocations/image_to_latents.py +++ b/invokeai/app/invocations/image_to_latents.py @@ -23,7 +23,7 @@ from invokeai.app.invocations.fields import ( from invokeai.app.invocations.model import VAEField from invokeai.app.invocations.primitives import LatentsOutput from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.backend.model_manager import LoadedModel +from invokeai.backend.model_manager.load.load_base import LoadedModel from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params from invokeai.backend.util.devices import TorchDevice diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index 7c3234bdc7..2b2931e78f 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -11,8 +11,8 @@ from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.services.model_records.model_records_base import ModelRecordChanges from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.backend.model_manager.config import ( - AnyModelConfig, +from invokeai.backend.model_manager.configs.factory import AnyModelConfig +from invokeai.backend.model_manager.configs.ip_adapter import ( IPAdapter_Checkpoint_Config_Base, IPAdapter_InvokeAI_Config_Base, ) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 327de6ac70..753ae77c55 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -12,9 +12,7 @@ from invokeai.app.invocations.baseinvocation import ( from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.shared.models import FreeUConfig -from invokeai.backend.model_manager.config import ( - AnyModelConfig, -) +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType diff --git a/invokeai/app/invocations/sd3_denoise.py b/invokeai/app/invocations/sd3_denoise.py index f43f26ae0e..b9d69369b7 100644 --- a/invokeai/app/invocations/sd3_denoise.py +++ b/invokeai/app/invocations/sd3_denoise.py @@ -23,7 +23,7 @@ from invokeai.app.invocations.primitives import LatentsOutput from invokeai.app.invocations.sd3_text_encoder import SD3_T5_MAX_SEQ_LEN from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional -from invokeai.backend.model_manager import BaseModelType +from invokeai.backend.model_manager.taxonomy import BaseModelType from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusion.conditioning_data import SD3ConditioningInfo diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index fc0f0bb2c6..c70ef3fa16 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -44,8 +44,8 @@ if TYPE_CHECKING: SessionQueueItem, SessionQueueStatus, ) - from invokeai.backend.model_manager import SubModelType - from invokeai.backend.model_manager.config import AnyModelConfig + from invokeai.backend.model_manager.configs.factory import AnyModelConfig + from invokeai.backend.model_manager.taxonomy import SubModelType class EventServiceBase: diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index 8fbb08015a..2f99529398 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -16,8 +16,8 @@ from invokeai.app.services.session_queue.session_queue_common import ( ) from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput from invokeai.app.util.misc import get_timestamp -from invokeai.backend.model_manager import SubModelType -from invokeai.backend.model_manager.config import AnyModelConfig +from invokeai.backend.model_manager.configs.factory import AnyModelConfig +from invokeai.backend.model_manager.taxonomy import SubModelType if TYPE_CHECKING: from invokeai.app.services.download.download_base import DownloadJob diff --git a/invokeai/app/services/model_install/model_install_common.py b/invokeai/app/services/model_install/model_install_common.py index fea75d7375..67832466f3 100644 --- a/invokeai/app/services/model_install/model_install_common.py +++ b/invokeai/app/services/model_install/model_install_common.py @@ -10,7 +10,7 @@ from typing_extensions import Annotated from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob from invokeai.app.services.model_records import ModelRecordChanges -from invokeai.backend.model_manager.config import AnyModelConfig +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata from invokeai.backend.model_manager.taxonomy import ModelRepoVariant, ModelSourceType diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 10a954a563..53bb5cc12d 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -35,10 +35,9 @@ from invokeai.app.services.model_install.model_install_common import ( ) from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase from invokeai.app.services.model_records.model_records_base import ModelRecordChanges -from invokeai.backend.model_manager.config import ( +from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base +from invokeai.backend.model_manager.configs.factory import ( AnyModelConfig, - Checkpoint_Config_Base, - InvalidModelConfigException, ModelConfigFactory, ) from invokeai.backend.model_manager.metadata import ( @@ -532,7 +531,7 @@ class ModelInstallService(ModelInstallServiceBase): x.content_type is not None and "text/html" in x.content_type for x in multifile_download_job.download_parts ): install_job.set_error( - InvalidModelConfigException( + ValueError( f"At least one file in {install_job.local_path} is an HTML page, not a model. This can happen when an access token is required to download." ) ) @@ -602,7 +601,7 @@ class ModelInstallService(ModelInstallServiceBase): return ModelConfigFactory.from_model_on_disk( mod=model_path, - overrides=deepcopy(fields), + override_fields=deepcopy(fields), hash_algo=hash_algo, ) diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index 8aae80e29d..87a405b4ea 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from pathlib import Path from typing import Callable, Optional -from invokeai.backend.model_manager.config import AnyModelConfig +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index ad4ad97a02..2e2d2ae219 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -11,7 +11,7 @@ from torch import load as torch_load from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_load.model_load_base import ModelLoadServiceBase -from invokeai.backend.model_manager.config import AnyModelConfig +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load import ( LoadedModel, LoadedModelWithoutConfig, diff --git a/invokeai/app/services/model_manager/__init__.py b/invokeai/app/services/model_manager/__init__.py index aad67ff352..e703d4f1ff 100644 --- a/invokeai/app/services/model_manager/__init__.py +++ b/invokeai/app/services/model_manager/__init__.py @@ -1,12 +1,10 @@ """Initialization file for model manager service.""" from invokeai.app.services.model_manager.model_manager_default import ModelManagerService, ModelManagerServiceBase -from invokeai.backend.model_manager import AnyModelConfig from invokeai.backend.model_manager.load import LoadedModel __all__ = [ "ModelManagerServiceBase", "ModelManagerService", - "AnyModelConfig", "LoadedModel", ] diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 48f5317536..2d34832dbe 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -12,12 +12,10 @@ from pydantic import BaseModel, Field from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.util.model_exclude_null import BaseModelExcludeNull -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - ControlAdapterDefaultSettings, - LoraModelDefaultSettings, - MainModelDefaultSettings, -) +from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings +from invokeai.backend.model_manager.configs.factory import AnyModelConfig +from invokeai.backend.model_manager.configs.lora import LoraModelDefaultSettings +from invokeai.backend.model_manager.configs.main import MainModelDefaultSettings from invokeai.backend.model_manager.taxonomy import ( BaseModelType, ClipVariantType, diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index 7fad1761cc..6d9a33ba4a 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -58,10 +58,7 @@ from invokeai.app.services.model_records.model_records_base import ( ) from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - ModelConfigFactory, -) +from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType @@ -157,7 +154,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): record_as_dict[field_name] = getattr(changes, field_name) # 3. create a new model config from the updated dict - record = ModelConfigFactory.make_config(record_as_dict) + record = ModelConfigFactory.from_dict(record_as_dict) # If we get this far, the updated model config is valid, so we can save it to the database. json_serialized = record.model_dump_json() @@ -187,7 +184,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): with self._db.transaction() as cursor: cursor.execute( """--sql - SELECT config, strftime('%s',updated_at) FROM models + SELECT config FROM models WHERE id=?; """, (key,), @@ -195,14 +192,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): rows = cursor.fetchone() if not rows: raise UnknownModelException("model not found") - model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1]) + model = ModelConfigFactory.from_dict(json.loads(rows[0])) return model def get_model_by_hash(self, hash: str) -> AnyModelConfig: with self._db.transaction() as cursor: cursor.execute( """--sql - SELECT config, strftime('%s',updated_at) FROM models + SELECT config FROM models WHERE hash=?; """, (hash,), @@ -210,7 +207,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): rows = cursor.fetchone() if not rows: raise UnknownModelException("model not found") - model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1]) + model = ModelConfigFactory.from_dict(json.loads(rows[0])) return model def exists(self, key: str) -> bool: @@ -278,7 +275,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): cursor.execute( f"""--sql - SELECT config, strftime('%s',updated_at) + SELECT config FROM models {where} ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason; @@ -291,7 +288,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): results: list[AnyModelConfig] = [] for row in result: try: - model_config = ModelConfigFactory.make_config(json.loads(row[0]), timestamp=row[1]) + model_config = ModelConfigFactory.from_dict(json.loads(row[0])) except pydantic.ValidationError as e: # We catch this error so that the app can still run if there are invalid model configs in the database. # One reason that an invalid model config might be in the database is if someone had to rollback from a @@ -315,12 +312,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): with self._db.transaction() as cursor: cursor.execute( """--sql - SELECT config, strftime('%s',updated_at) FROM models + SELECT config FROM models WHERE path=?; """, (str(path),), ) - results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()] + results = [ModelConfigFactory.from_dict(json.loads(x[0])) for x in cursor.fetchall()] return results def search_by_hash(self, hash: str) -> List[AnyModelConfig]: @@ -328,12 +325,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): with self._db.transaction() as cursor: cursor.execute( """--sql - SELECT config, strftime('%s',updated_at) FROM models + SELECT config FROM models WHERE hash=?; """, (hash,), ) - results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()] + results = [ModelConfigFactory.from_dict(json.loads(x[0])) for x in cursor.fetchall()] return results def list_models( diff --git a/invokeai/app/services/model_relationships/model_relationships_default.py b/invokeai/app/services/model_relationships/model_relationships_default.py index 67fa6c0069..e4da482ff2 100644 --- a/invokeai/app/services/model_relationships/model_relationships_default.py +++ b/invokeai/app/services/model_relationships/model_relationships_default.py @@ -1,6 +1,6 @@ from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_relationships.model_relationships_base import ModelRelationshipsServiceABC -from invokeai.backend.model_manager.config import AnyModelConfig +from invokeai.backend.model_manager.configs.factory import AnyModelConfig class ModelRelationshipsService(ModelRelationshipsServiceABC): diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 16aacbb985..97291230e0 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -19,10 +19,8 @@ from invokeai.app.services.model_records.model_records_base import UnknownModelE from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.app.util.step_callback import diffusion_step_callback -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - Config_Base, -) +from invokeai.backend.model_manager.configs.base import Config_Base +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_22.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_22.py index 08b0e76068..1d9c81529e 100644 --- a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_22.py +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_22.py @@ -8,7 +8,7 @@ from pydantic import ValidationError from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration -from invokeai.backend.model_manager.config import AnyModelConfigValidator +from invokeai.backend.model_manager.configs.factory import AnyModelConfigValidator class NormalizeResult(NamedTuple): diff --git a/invokeai/app/util/custom_openapi.py b/invokeai/app/util/custom_openapi.py index 2e07622530..d400e0ff11 100644 --- a/invokeai/app/util/custom_openapi.py +++ b/invokeai/app/util/custom_openapi.py @@ -12,7 +12,7 @@ from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFie from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.services.events.events_common import EventBase from invokeai.app.services.session_processor.session_processor_common import ProgressImage -from invokeai.backend.model_manager.config import AnyModelConfigValidator +from invokeai.backend.model_manager.configs.factory import AnyModelConfigValidator from invokeai.backend.util.logging import InvokeAILogger logger = InvokeAILogger.get_logger() diff --git a/invokeai/backend/flux/flux_state_dict_utils.py b/invokeai/backend/flux/flux_state_dict_utils.py index 8ffab54c68..c306c88f96 100644 --- a/invokeai/backend/flux/flux_state_dict_utils.py +++ b/invokeai/backend/flux/flux_state_dict_utils.py @@ -1,10 +1,7 @@ -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from invokeai.backend.model_manager.legacy_probe import CkptType +from typing import Any -def get_flux_in_channels_from_state_dict(state_dict: "CkptType") -> int | None: +def get_flux_in_channels_from_state_dict(state_dict: dict[str | int, Any]) -> int | None: """Gets the in channels from the state dict.""" # "Standard" FLUX models use "img_in.weight", but some community fine tunes use diff --git a/invokeai/backend/model_manager/__init__.py b/invokeai/backend/model_manager/__init__.py index a167687d2e..e69de29bb2 100644 --- a/invokeai/backend/model_manager/__init__.py +++ b/invokeai/backend/model_manager/__init__.py @@ -1,45 +0,0 @@ -"""Re-export frequently-used symbols from the Model Manager backend.""" - -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - InvalidModelConfigException, - Config_Base, - ModelConfigFactory, -) -from invokeai.backend.model_manager.legacy_probe import ModelProbe -from invokeai.backend.model_manager.load import LoadedModel -from invokeai.backend.model_manager.search import ModelSearch -from invokeai.backend.model_manager.taxonomy import ( - AnyModel, - AnyVariant, - BaseModelType, - ClipVariantType, - ModelFormat, - ModelRepoVariant, - ModelSourceType, - ModelType, - ModelVariantType, - SchedulerPredictionType, - SubModelType, -) - -__all__ = [ - "AnyModelConfig", - "InvalidModelConfigException", - "LoadedModel", - "ModelConfigFactory", - "ModelProbe", - "ModelSearch", - "Config_Base", - "AnyModel", - "AnyVariant", - "BaseModelType", - "ClipVariantType", - "ModelFormat", - "ModelRepoVariant", - "ModelSourceType", - "ModelType", - "ModelVariantType", - "SchedulerPredictionType", - "SubModelType", -] diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py deleted file mode 100644 index 188ac9ad11..0000000000 --- a/invokeai/backend/model_manager/config.py +++ /dev/null @@ -1,2584 +0,0 @@ -# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team -""" -Configuration definitions for image generation models. - -Typical usage: - - from invokeai.backend.model_manager import ModelConfigFactory - raw = dict(path='models/sd-1/main/foo.ckpt', - name='foo', - base='sd-1', - type='main', - config='configs/stable-diffusion/v1-inference.yaml', - variant='normal', - format='checkpoint' - ) - config = ModelConfigFactory.make_config(raw) - print(config.name) - -Validation errors will raise an InvalidModelConfigException error. - -""" - -import json -import logging -import re -import time -from abc import ABC -from enum import Enum -from functools import cache -from inspect import isabstract -from pathlib import Path -from typing import ( - ClassVar, - Literal, - Optional, - Self, - Type, - Union, -) - -import torch -from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter, ValidationError -from pydantic_core import CoreSchema, PydanticUndefined, SchemaValidator -from typing_extensions import Annotated, Any, Dict - -from invokeai.app.services.config.config_default import get_config -from invokeai.app.util.misc import uuid_string -from invokeai.backend.flux.controlnet.state_dict_utils import ( - is_state_dict_instantx_controlnet, - is_state_dict_xlabs_controlnet, -) -from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter -from invokeai.backend.flux.redux.flux_redux_state_dict_utils import is_state_dict_likely_flux_redux -from invokeai.backend.model_hash.hash_validator import validate_hash -from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS -from invokeai.backend.model_manager.model_on_disk import ModelOnDisk -from invokeai.backend.model_manager.omi import flux_dev_1_lora, stable_diffusion_xl_1_lora -from invokeai.backend.model_manager.taxonomy import ( - AnyVariant, - BaseModelType, - ClipVariantType, - FluxLoRAFormat, - FluxVariantType, - ModelFormat, - ModelRepoVariant, - ModelSourceType, - ModelType, - ModelVariantType, - SchedulerPredictionType, - SubModelType, - variant_type_adapter, -) -from invokeai.backend.model_manager.util.model_util import lora_token_vector_length -from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control -from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor -from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel -from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES - -logger = logging.getLogger(__name__) -app_config = get_config() - - -class InvalidModelConfigException(Exception): - """Exception for when config parser doesn't recognize this combination of model type and format.""" - - pass - - -class NotAMatch(Exception): - """Exception for when a model does not match a config class. - - Args: - config_class: The config class that was being tested. - reason: The reason why the model did not match. - """ - - def __init__( - self, - config_class: type, - reason: str, - ): - super().__init__(f"{config_class.__name__}: {reason}") - - -DEFAULTS_PRECISION = Literal["fp16", "fp32"] - - -class FieldValidator: - """Utility class for validating individual fields of a Pydantic model without instantiating the whole model. - - See: https://github.com/pydantic/pydantic/discussions/7367#discussioncomment-14213144 - """ - - @staticmethod - def find_field_schema(model: type[BaseModel], field_name: str) -> CoreSchema: - """Find the Pydantic core schema for a specific field in a model.""" - schema: CoreSchema = model.__pydantic_core_schema__.copy() - # we shallow copied, be careful not to mutate the original schema! - - assert schema["type"] in ["definitions", "model"] - - # find the field schema - field_schema = schema["schema"] # type: ignore - while "fields" not in field_schema: - field_schema = field_schema["schema"] # type: ignore - - field_schema = field_schema["fields"][field_name]["schema"] # type: ignore - - # if the original schema is a definition schema, replace the model schema with the field schema - if schema["type"] == "definitions": - schema["schema"] = field_schema - return schema - else: - return field_schema - - @cache - @staticmethod - def get_validator(model: type[BaseModel], field_name: str) -> SchemaValidator: - """Get a SchemaValidator for a specific field in a model.""" - return SchemaValidator(FieldValidator.find_field_schema(model, field_name)) - - @staticmethod - def validate_field(model: type[BaseModel], field_name: str, value: Any) -> Any: - """Validate a value for a specific field in a model.""" - return FieldValidator.get_validator(model, field_name).validate_python(value) - - -def has_any_keys(state_dict: dict[str | int, Any], keys: str | set[str]) -> bool: - """Returns true if the state dict has any of the specified keys.""" - _keys = {keys} if isinstance(keys, str) else keys - return any(key in state_dict for key in _keys) - - -def has_any_keys_starting_with(state_dict: dict[str | int, Any], prefixes: str | set[str]) -> bool: - """Returns true if the state dict has any keys starting with any of the specified prefixes.""" - _prefixes = {prefixes} if isinstance(prefixes, str) else prefixes - return any(any(key.startswith(prefix) for prefix in _prefixes) for key in state_dict.keys() if isinstance(key, str)) - - -def has_any_keys_ending_with(state_dict: dict[str | int, Any], suffixes: str | set[str]) -> bool: - """Returns true if the state dict has any keys ending with any of the specified suffixes.""" - _suffixes = {suffixes} if isinstance(suffixes, str) else suffixes - return any(any(key.endswith(suffix) for suffix in _suffixes) for key in state_dict.keys() if isinstance(key, str)) - - -def common_config_paths(path: Path) -> set[Path]: - """Returns common config file paths for models stored in directories.""" - return {path / "config.json", path / "model_index.json"} - - -# These utility functions are tightly coupled to the config classes below in order to make the process of raising -# NotAMatch exceptions as easy and consistent as possible. - - -def _get_config_or_raise( - config_class: type, - config_path: Path | set[Path], -) -> dict[str, Any]: - """Load the config file at the given path, or raise NotAMatch if it cannot be loaded.""" - paths_to_check = config_path if isinstance(config_path, set) else {config_path} - - problems: dict[Path, str] = {} - - for p in paths_to_check: - if not p.exists(): - problems[p] = "file does not exist" - continue - - try: - with open(p, "r") as file: - config = json.load(file) - - return config - except Exception as e: - problems[p] = str(e) - continue - - raise NotAMatch(config_class, f"unable to load config file(s): {problems}") - - -def _get_class_name_from_config( - config_class: type, - config_path: Path | set[Path], -) -> str: - """Load the config file and return the class name. - - Raises: - NotAMatch if the config file is missing or does not contain a valid class name. - """ - - config = _get_config_or_raise(config_class, config_path) - - try: - if "_class_name" in config: - config_class_name = config["_class_name"] - elif "architectures" in config: - config_class_name = config["architectures"][0] - else: - raise ValueError("missing _class_name or architectures field") - except Exception as e: - raise NotAMatch(config_class, f"unable to determine class name from config file: {config_path}") from e - - if not isinstance(config_class_name, str): - raise NotAMatch(config_class, f"_class_name or architectures field is not a string: {config_class_name}") - - return config_class_name - - -def _validate_class_name(config_class: type[BaseModel], config_path: Path | set[Path], expected: set[str]) -> None: - """Check if the class name in the config file matches the expected class names. - - Args: - config_class: The config class that is being tested. - config_path: The path to the config file. - expected: The expected class names.""" - - class_name = _get_class_name_from_config(config_class, config_path) - if class_name not in expected: - raise NotAMatch(config_class, f"invalid class name from config: {class_name}") - - -def _validate_override_fields( - config_class: type[BaseModel], - override_fields: dict[str, Any], -) -> None: - """Check if the provided override fields are valid for the config class. - - Args: - config_class: The config class that is being tested. - override_fields: The override fields provided by the user. - - Raises: - NotAMatch if any override field is invalid for the config. - """ - for field_name, override_value in override_fields.items(): - if field_name not in config_class.model_fields: - raise NotAMatch(config_class, f"unknown override field: {field_name}") - try: - FieldValidator.validate_field(config_class, field_name, override_value) - except ValidationError as e: - raise NotAMatch(config_class, f"invalid override for field '{field_name}': {e}") from e - - -def _validate_is_file( - config_class: type, - mod: ModelOnDisk, -) -> None: - """Raise NotAMatch if the model path is not a file.""" - if not mod.path.is_file(): - raise NotAMatch(config_class, "model path is not a file") - - -def _validate_is_dir( - config_class: type, - mod: ModelOnDisk, -) -> None: - """Raise NotAMatch if the model path is not a directory.""" - if not mod.path.is_dir(): - raise NotAMatch(config_class, "model path is not a directory") - - -class SubmodelDefinition(BaseModel): - path_or_prefix: str - model_type: ModelType - variant: AnyVariant | None = None - - model_config = ConfigDict(protected_namespaces=()) - - -class MainModelDefaultSettings(BaseModel): - vae: str | None = Field(default=None, description="Default VAE for this model (model key)") - vae_precision: DEFAULTS_PRECISION | None = Field(default=None, description="Default VAE precision for this model") - scheduler: SCHEDULER_NAME_VALUES | None = Field(default=None, description="Default scheduler for this model") - steps: int | None = Field(default=None, gt=0, description="Default number of steps for this model") - cfg_scale: float | None = Field(default=None, ge=1, description="Default CFG Scale for this model") - cfg_rescale_multiplier: float | None = Field( - default=None, ge=0, lt=1, description="Default CFG Rescale Multiplier for this model" - ) - width: int | None = Field(default=None, multiple_of=8, ge=64, description="Default width for this model") - height: int | None = Field(default=None, multiple_of=8, ge=64, description="Default height for this model") - guidance: float | None = Field(default=None, ge=1, description="Default Guidance for this model") - - model_config = ConfigDict(extra="forbid") - - -class LoraModelDefaultSettings(BaseModel): - weight: float | None = Field(default=None, ge=-1, le=2, description="Default weight for this model") - model_config = ConfigDict(extra="forbid") - - -class ControlAdapterDefaultSettings(BaseModel): - # This could be narrowed to controlnet processor nodes, but they change. Leaving this a string is safer. - preprocessor: str | None - model_config = ConfigDict(extra="forbid") - - -class LegacyProbeMixin: - """Mixin for classes using the legacy probe for model classification.""" - - pass - - -class Config_Base(ABC, BaseModel): - """ - Abstract base class for model configurations. A model config describes a specific combination of model base, type and - format, along with other metadata about the model. For example, a Stable Diffusion 1.x main model in checkpoint format - would have base=sd-1, type=main, format=checkpoint. - - To create a new config type, inherit from this class and implement its interface: - - Define method 'from_model_on_disk' that returns an instance of the class or raises NotAMatch. This method will be - called during model installation to determine the correct config class for a model. - - Define fields 'type', 'base' and 'format' as pydantic fields. These should be Literals with a single value. A - default must be provided for each of these fields. - - If multiple combinations of base, type and format need to be supported, create a separate subclass for each. - - See MinimalConfigExample in test_model_probe.py for an example implementation. - """ - - key: str = Field( - description="A unique key for this model.", - default_factory=uuid_string, - ) - hash: str = Field( - description="The hash of the model file(s).", - ) - path: str = Field( - description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory.", - ) - file_size: int = Field( - description="The size of the model in bytes.", - ) - name: str = Field( - description="Name of the model.", - ) - description: str | None = Field( - description="Model description", - default=None, - ) - source: str = Field( - description="The original source of the model (path, URL or repo_id).", - ) - source_type: ModelSourceType = Field( - description="The type of source", - ) - source_api_response: str | None = Field( - description="The original API response from the source, as stringified JSON.", - default=None, - ) - cover_image: str | None = Field( - description="Url for image to preview model", - default=None, - ) - submodels: dict[SubModelType, SubmodelDefinition] | None = Field( - description="Loadable submodels in this model", - default=None, - ) - usage_info: str | None = Field( - default=None, - description="Usage information for this model", - ) - - CONFIG_CLASSES: ClassVar[set[Type["AnyModelConfig"]]] = set() - - model_config = ConfigDict( - validate_assignment=True, - json_schema_serialization_defaults_required=True, - json_schema_mode_override="serialization", - ) - - @classmethod - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - # Register non-abstract subclasses so we can iterate over them later during model probing. - if not isabstract(cls): - cls.CONFIG_CLASSES.add(cls) - - @classmethod - def __pydantic_init_subclass__(cls, **kwargs): - # Ensure that subclasses define 'base', 'type' and 'format' fields and provide defaults for them. Each subclass - # is expected to represent a single combination of base, type and format. - for name in ("type", "base", "format"): - assert name in cls.model_fields, f"{cls.__name__} must define a '{name}' field" - assert cls.model_fields[name].default is not PydanticUndefined, ( - f"{cls.__name__} must define a default for the '{name}' field" - ) - - @classmethod - def get_tag(cls) -> Tag: - """Constructs a pydantic discriminated union tag for this model config class. When a config is deserialized, - pydantic uses the tag to determine which subclass to instantiate. - - The tag is a dot-separated string of the type, format, base and variant (if applicable). - """ - tag_strings: list[str] = [] - for name in ("type", "format", "base", "variant"): - if field := cls.model_fields.get(name): - if field.default is not PydanticUndefined: - # We expect each of these fields has an Enum for its default; we want the value of the enum. - tag_strings.append(field.default.value) - return Tag(".".join(tag_strings)) - - @staticmethod - def get_model_discriminator_value(v: Any) -> str: - """ - Computes the discriminator value for a model config. - https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator - """ - if isinstance(v, Config_Base): - # We have an instance of a ModelConfigBase subclass - use its tag directly. - return v.get_tag().tag - if isinstance(v, dict): - # We have a dict - compute the tag from its fields. - tag_strings: list[str] = [] - if type_ := v.get("type"): - if isinstance(type_, Enum): - type_ = type_.value - tag_strings.append(type_) - - if format_ := v.get("format"): - if isinstance(format_, Enum): - format_ = format_.value - tag_strings.append(format_) - - if base_ := v.get("base"): - if isinstance(base_, Enum): - base_ = base_.value - tag_strings.append(base_) - - # Special case: CLIP Embed models also need the variant to distinguish them. - if ( - type_ == ModelType.CLIPEmbed.value - and format_ == ModelFormat.Diffusers.value - and base_ == BaseModelType.Any.value - ): - if variant_value := v.get("variant"): - if isinstance(variant_value, Enum): - variant_value = variant_value.value - tag_strings.append(variant_value) - else: - raise ValueError("CLIP Embed model config dict must include a 'variant' field") - - return ".".join(tag_strings) - else: - raise TypeError("Model config discriminator value must be computed from a dict or ModelConfigBase instance") - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - """Given the model on disk and any overrides, return an instance of this config class. - - Implementations should raise NotAMatch if the model does not match this config class.""" - raise NotImplementedError(f"from_model_on_disk not implemented for {cls.__name__}") - - -class Unknown_Config(Config_Base): - """Model config for unknown models, used as a fallback when we cannot identify a model.""" - - base: Literal[BaseModelType.Unknown] = Field(default=BaseModelType.Unknown) - type: Literal[ModelType.Unknown] = Field(default=ModelType.Unknown) - format: Literal[ModelFormat.Unknown] = Field(default=ModelFormat.Unknown) - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - raise NotAMatch(cls, "unknown model config cannot match any model") - - -class Checkpoint_Config_Base(ABC, BaseModel): - """Base class for checkpoint-style models.""" - - config_path: str | None = Field( - description="Path to the config for this model, if any.", - default=None, - ) - converted_at: float | None = Field( - description="When this model was last converted to diffusers", - default_factory=time.time, - ) - - -class Diffusers_Config_Base(ABC, BaseModel): - """Base class for diffusers-style models.""" - - format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) - repo_variant: Optional[ModelRepoVariant] = Field(ModelRepoVariant.Default) - - @classmethod - def _get_repo_variant_or_raise(cls, mod: ModelOnDisk) -> ModelRepoVariant: - # get all files ending in .bin or .safetensors - weight_files = list(mod.path.glob("**/*.safetensors")) - weight_files.extend(list(mod.path.glob("**/*.bin"))) - for x in weight_files: - if ".fp16" in x.suffixes: - return ModelRepoVariant.FP16 - if "openvino_model" in x.name: - return ModelRepoVariant.OpenVINO - if "flax_model" in x.name: - return ModelRepoVariant.Flax - if x.suffix == ".onnx": - return ModelRepoVariant.ONNX - return ModelRepoVariant.Default - - -class T5Encoder_T5Encoder_Config(Config_Base): - base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) - type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder) - format: Literal[ModelFormat.T5Encoder] = Field(default=ModelFormat.T5Encoder) - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_dir(cls, mod) - - _validate_override_fields(cls, fields) - - _validate_class_name( - cls, - common_config_paths(mod.path), - { - "T5EncoderModel", - }, - ) - - cls._validate_has_unquantized_config_file(mod) - - return cls(**fields) - - @classmethod - def _validate_has_unquantized_config_file(cls, mod: ModelOnDisk) -> None: - has_unquantized_config = (mod.path / "text_encoder_2" / "model.safetensors.index.json").exists() - - if not has_unquantized_config: - raise NotAMatch(cls, "missing text_encoder_2/model.safetensors.index.json") - - -class T5Encoder_BnBLLMint8_Config(Config_Base): - base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) - type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder) - format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = Field(default=ModelFormat.BnbQuantizedLlmInt8b) - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_dir(cls, mod) - - _validate_override_fields(cls, fields) - - _validate_class_name( - cls, - common_config_paths(mod.path), - { - "T5EncoderModel", - }, - ) - - cls._validate_filename_looks_like_bnb_quantized(mod) - - cls._validate_model_looks_like_bnb_quantized(mod) - - return cls(**fields) - - @classmethod - def _validate_filename_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None: - filename_looks_like_bnb = any(x for x in mod.weight_files() if "llm_int8" in x.as_posix()) - if not filename_looks_like_bnb: - raise NotAMatch(cls, "filename does not look like bnb quantized llm_int8") - - @classmethod - def _validate_model_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None: - has_scb_key_suffix = has_any_keys_ending_with(mod.load_state_dict(), "SCB") - if not has_scb_key_suffix: - raise NotAMatch(cls, "state dict does not look like bnb quantized llm_int8") - - -class LoRA_Config_Base(ABC, BaseModel): - """Base class for LoRA models.""" - - type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA) - trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) - default_settings: Optional[LoraModelDefaultSettings] = Field( - description="Default settings for this model", default=None - ) - - -def _get_flux_lora_format(mod: ModelOnDisk) -> FluxLoRAFormat | None: - # TODO(psyche): Moving this import to the function to avoid circular imports. Refactor later. - from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict - - state_dict = mod.load_state_dict(mod.path) - value = flux_format_from_state_dict(state_dict, mod.metadata()) - return value - - -class LoRA_OMI_Config_Base(LoRA_Config_Base): - format: Literal[ModelFormat.OMI] = Field(default=ModelFormat.OMI) - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_file(cls, mod) - - _validate_override_fields(cls, fields) - - cls._validate_looks_like_omi_lora(mod) - - cls._validate_base(mod) - - return cls(**fields) - - @classmethod - def _validate_base(cls, mod: ModelOnDisk) -> None: - """Raise `NotAMatch` if the model base does not match this config class.""" - expected_base = cls.model_fields["base"].default - recognized_base = cls._get_base_or_raise(mod) - if expected_base is not recognized_base: - raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") - - @classmethod - def _validate_looks_like_omi_lora(cls, mod: ModelOnDisk) -> None: - """Raise `NotAMatch` if the model metadata does not look like an OMI LoRA.""" - flux_format = _get_flux_lora_format(mod) - if flux_format in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]: - raise NotAMatch(cls, "model looks like ControlLoRA or Diffusers LoRA") - - metadata = mod.metadata() - - metadata_looks_like_omi_lora = ( - bool(metadata.get("modelspec.sai_model_spec")) - and metadata.get("ot_branch") == "omi_format" - and metadata.get("modelspec.architecture", "").split("/")[1].lower() == "lora" - ) - - if not metadata_looks_like_omi_lora: - raise NotAMatch(cls, "metadata does not look like OMI LoRA") - - @classmethod - def _get_base_or_raise(cls, mod: ModelOnDisk) -> Literal[BaseModelType.Flux, BaseModelType.StableDiffusionXL]: - metadata = mod.metadata() - architecture = metadata["modelspec.architecture"] - - if architecture == stable_diffusion_xl_1_lora: - return BaseModelType.StableDiffusionXL - elif architecture == flux_dev_1_lora: - return BaseModelType.Flux - else: - raise NotAMatch(cls, f"unrecognised/unsupported architecture for OMI LoRA: {architecture}") - - -class LoRA_OMI_SDXL_Config(LoRA_OMI_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) - - -class LoRA_OMI_FLUX_Config(LoRA_OMI_Config_Base, Config_Base): - base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) - - -class LoRA_LyCORIS_Config_Base(LoRA_Config_Base): - """Model config for LoRA/Lycoris models.""" - - type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA) - format: Literal[ModelFormat.LyCORIS] = Field(default=ModelFormat.LyCORIS) - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_file(cls, mod) - - _validate_override_fields(cls, fields) - - cls._validate_looks_like_lora(mod) - - cls._validate_base(mod) - - return cls(**fields) - - @classmethod - def _validate_base(cls, mod: ModelOnDisk) -> None: - """Raise `NotAMatch` if the model base does not match this config class.""" - expected_base = cls.model_fields["base"].default - recognized_base = cls._get_base_or_raise(mod) - if expected_base is not recognized_base: - raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") - - @classmethod - def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None: - # First rule out ControlLoRA and Diffusers LoRA - flux_format = _get_flux_lora_format(mod) - if flux_format in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]: - raise NotAMatch(cls, "model looks like ControlLoRA or Diffusers LoRA") - - # Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA. - # Some main models have these keys, likely due to the creator merging in a LoRA. - has_key_with_lora_prefix = has_any_keys_starting_with( - mod.load_state_dict(), - { - "lora_te_", - "lora_unet_", - "lora_te1_", - "lora_te2_", - "lora_transformer_", - }, - ) - - has_key_with_lora_suffix = has_any_keys_ending_with( - mod.load_state_dict(), - { - "to_k_lora.up.weight", - "to_q_lora.down.weight", - "lora_A.weight", - "lora_B.weight", - }, - ) - - if not has_key_with_lora_prefix and not has_key_with_lora_suffix: - raise NotAMatch(cls, "model does not match LyCORIS LoRA heuristics") - - @classmethod - def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: - if _get_flux_lora_format(mod): - return BaseModelType.Flux - - state_dict = mod.load_state_dict() - # If we've gotten here, we assume that the model is a Stable Diffusion model - token_vector_length = lora_token_vector_length(state_dict) - if token_vector_length == 768: - return BaseModelType.StableDiffusion1 - elif token_vector_length == 1024: - return BaseModelType.StableDiffusion2 - elif token_vector_length == 1280: - return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641 - elif token_vector_length == 2048: - return BaseModelType.StableDiffusionXL - else: - raise NotAMatch(cls, f"unrecognized token vector length {token_vector_length}") - - -class LoRA_LyCORIS_SD1_Config(LoRA_LyCORIS_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) - - -class LoRA_LyCORIS_SD2_Config(LoRA_LyCORIS_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) - - -class LoRA_LyCORIS_SDXL_Config(LoRA_LyCORIS_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) - - -class LoRA_LyCORIS_FLUX_Config(LoRA_LyCORIS_Config_Base, Config_Base): - base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) - - -class ControlAdapter_Config_Base(ABC, BaseModel): - default_settings: ControlAdapterDefaultSettings | None = Field(None) - - -class ControlLoRA_LyCORIS_FLUX_Config(ControlAdapter_Config_Base, Config_Base): - """Model config for Control LoRA models.""" - - base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) - type: Literal[ModelType.ControlLoRa] = Field(default=ModelType.ControlLoRa) - format: Literal[ModelFormat.LyCORIS] = Field(default=ModelFormat.LyCORIS) - - trigger_phrases: set[str] | None = Field(None) - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_file(cls, mod) - - _validate_override_fields(cls, fields) - - cls._validate_looks_like_control_lora(mod) - - return cls(**fields) - - @classmethod - def _validate_looks_like_control_lora(cls, mod: ModelOnDisk) -> None: - state_dict = mod.load_state_dict() - - if not is_state_dict_likely_flux_control(state_dict): - raise NotAMatch(cls, "model state dict does not look like a Flux Control LoRA") - - -class LoRA_Diffusers_Config_Base(LoRA_Config_Base): - """Model config for LoRA/Diffusers models.""" - - # TODO(psyche): Needs base handling. For FLUX, the Diffusers format does not indicate a folder model; it indicates - # the weights format. FLUX Diffusers LoRAs are single files. - - format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_dir(cls, mod) - - _validate_override_fields(cls, fields) - - cls._validate_base(mod) - - return cls(**fields) - - @classmethod - def _validate_base(cls, mod: ModelOnDisk) -> None: - """Raise `NotAMatch` if the model base does not match this config class.""" - expected_base = cls.model_fields["base"].default - recognized_base = cls._get_base_or_raise(mod) - if expected_base is not recognized_base: - raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") - - @classmethod - def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: - if _get_flux_lora_format(mod): - return BaseModelType.Flux - - # If we've gotten here, we assume that the LoRA is a Stable Diffusion LoRA - path_to_weight_file = cls._get_weight_file_or_raise(mod) - state_dict = mod.load_state_dict(path_to_weight_file) - token_vector_length = lora_token_vector_length(state_dict) - - match token_vector_length: - case 768: - return BaseModelType.StableDiffusion1 - case 1024: - return BaseModelType.StableDiffusion2 - case 1280: - return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641 - case 2048: - return BaseModelType.StableDiffusionXL - case _: - raise NotAMatch(cls, f"unrecognized token vector length {token_vector_length}") - - @classmethod - def _get_weight_file_or_raise(cls, mod: ModelOnDisk) -> Path: - suffixes = ["bin", "safetensors"] - weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes] - for wf in weight_files: - if wf.exists(): - return wf - raise NotAMatch(cls, "missing pytorch_lora_weights.bin or pytorch_lora_weights.safetensors") - - -class LoRA_Diffusers_SD1_Config(LoRA_Diffusers_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) - - -class LoRA_Diffusers_SD2_Config(LoRA_Diffusers_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) - - -class LoRA_Diffusers_SDXL_Config(LoRA_Diffusers_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) - - -class LoRA_Diffusers_FLUX_Config(LoRA_Diffusers_Config_Base, Config_Base): - base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) - - -class VAE_Checkpoint_Config_Base(Checkpoint_Config_Base): - """Model config for standalone VAE models.""" - - type: Literal[ModelType.VAE] = Field(default=ModelType.VAE) - format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) - - REGEX_TO_BASE: ClassVar[dict[str, BaseModelType]] = { - r"xl": BaseModelType.StableDiffusionXL, - r"sd2": BaseModelType.StableDiffusion2, - r"vae": BaseModelType.StableDiffusion1, - r"FLUX.1-schnell_ae": BaseModelType.Flux, - } - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_file(cls, mod) - - _validate_override_fields(cls, fields) - - cls._validate_looks_like_vae(mod) - - cls._validate_base(mod) - - return cls(**fields) - - @classmethod - def _validate_base(cls, mod: ModelOnDisk) -> None: - """Raise `NotAMatch` if the model base does not match this config class.""" - expected_base = cls.model_fields["base"].default - recognized_base = cls._get_base_or_raise(mod) - if expected_base is not recognized_base: - raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") - - @classmethod - def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None: - if not has_any_keys_starting_with( - mod.load_state_dict(), - { - "encoder.conv_in", - "decoder.conv_in", - }, - ): - raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics") - - @classmethod - def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: - # Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name - for regexp, base in cls.REGEX_TO_BASE.items(): - if re.search(regexp, mod.path.name, re.IGNORECASE): - return base - - raise NotAMatch(cls, "cannot determine base type") - - -class VAE_Checkpoint_SD1_Config(VAE_Checkpoint_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) - - -class VAE_Checkpoint_SD2_Config(VAE_Checkpoint_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) - - -class VAE_Checkpoint_SDXL_Config(VAE_Checkpoint_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) - - -class VAE_Checkpoint_FLUX_Config(VAE_Checkpoint_Config_Base, Config_Base): - base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) - - -class VAE_Diffusers_Config_Base(Diffusers_Config_Base): - """Model config for standalone VAE models (diffusers version).""" - - type: Literal[ModelType.VAE] = Field(default=ModelType.VAE) - format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_dir(cls, mod) - - _validate_override_fields(cls, fields) - - _validate_class_name( - cls, - common_config_paths(mod.path), - { - "AutoencoderKL", - "AutoencoderTiny", - }, - ) - - cls._validate_base(mod) - - return cls(**fields) - - @classmethod - def _validate_base(cls, mod: ModelOnDisk) -> None: - """Raise `NotAMatch` if the model base does not match this config class.""" - expected_base = cls.model_fields["base"].default - recognized_base = cls._get_base_or_raise(mod) - if expected_base is not recognized_base: - raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") - - @classmethod - def _config_looks_like_sdxl(cls, config: dict[str, Any]) -> bool: - # Heuristic: These config values that distinguish Stability's SD 1.x VAE from their SDXL VAE. - return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024] - - @classmethod - def _name_looks_like_sdxl(cls, mod: ModelOnDisk) -> bool: - # Heuristic: SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down - # by a factor of 8), so we can't necessarily tell them apart by config hyperparameters. Best - # we can do is guess based on name. - return bool(re.search(r"xl\b", cls._guess_name(mod), re.IGNORECASE)) - - @classmethod - def _guess_name(cls, mod: ModelOnDisk) -> str: - name = mod.path.name - if name == "vae": - name = mod.path.parent.name - return name - - @classmethod - def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: - config = _get_config_or_raise(cls, common_config_paths(mod.path)) - if cls._config_looks_like_sdxl(config): - return BaseModelType.StableDiffusionXL - elif cls._name_looks_like_sdxl(mod): - return BaseModelType.StableDiffusionXL - else: - # TODO(psyche): Figure out how to positively identify SD1 here, and raise if we can't. Until then, YOLO. - return BaseModelType.StableDiffusion1 - - -class VAE_Diffusers_SD1_Config(VAE_Diffusers_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) - - -class VAE_Diffusers_SDXL_Config(VAE_Diffusers_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) - - -class ControlNet_Diffusers_Config_Base(Diffusers_Config_Base, ControlAdapter_Config_Base): - """Model config for ControlNet models (diffusers version).""" - - type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet) - format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_dir(cls, mod) - - _validate_override_fields(cls, fields) - - _validate_class_name( - cls, - common_config_paths(mod.path), - { - "ControlNetModel", - "FluxControlNetModel", - }, - ) - - cls._validate_base(mod) - - return cls(**fields) - - @classmethod - def _validate_base(cls, mod: ModelOnDisk) -> None: - """Raise `NotAMatch` if the model base does not match this config class.""" - expected_base = cls.model_fields["base"].default - recognized_base = cls._get_base_or_raise(mod) - if expected_base is not recognized_base: - raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") - - @classmethod - def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: - config = _get_config_or_raise(cls, common_config_paths(mod.path)) - - if config.get("_class_name") == "FluxControlNetModel": - return BaseModelType.Flux - - dimension = config.get("cross_attention_dim") - - match dimension: - case 768: - return BaseModelType.StableDiffusion1 - case 1024: - # No obvious way to distinguish between sd2-base and sd2-768, but we don't really differentiate them - # anyway. - return BaseModelType.StableDiffusion2 - case 2048: - return BaseModelType.StableDiffusionXL - case _: - raise NotAMatch(cls, f"unrecognized cross_attention_dim {dimension}") - - -class ControlNet_Diffusers_SD1_Config(ControlNet_Diffusers_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) - - -class ControlNet_Diffusers_SD2_Config(ControlNet_Diffusers_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) - - -class ControlNet_Diffusers_SDXL_Config(ControlNet_Diffusers_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) - - -class ControlNet_Diffusers_FLUX_Config(ControlNet_Diffusers_Config_Base, Config_Base): - base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) - - -class ControlNet_Checkpoint_Config_Base(Checkpoint_Config_Base, ControlAdapter_Config_Base): - """Model config for ControlNet models (diffusers version).""" - - type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet) - format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_file(cls, mod) - - _validate_override_fields(cls, fields) - - cls._validate_looks_like_controlnet(mod) - - cls._validate_base(mod) - - return cls(**fields) - - @classmethod - def _validate_base(cls, mod: ModelOnDisk) -> None: - """Raise `NotAMatch` if the model base does not match this config class.""" - expected_base = cls.model_fields["base"].default - recognized_base = cls._get_base_or_raise(mod) - if expected_base is not recognized_base: - raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") - - @classmethod - def _validate_looks_like_controlnet(cls, mod: ModelOnDisk) -> None: - if has_any_keys_starting_with( - mod.load_state_dict(), - { - "controlnet", - "control_model", - "input_blocks", - # XLabs FLUX ControlNet models have keys starting with "controlnet_blocks." - # For example: https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors - # TODO(ryand): This is very fragile. XLabs FLUX ControlNet models also contain keys starting with - # "double_blocks.", which we check for above. But, I'm afraid to modify this logic because it is so - # delicate. - "controlnet_blocks", - }, - ): - raise NotAMatch(cls, "state dict does not look like a ControlNet checkpoint") - - @classmethod - def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: - state_dict = mod.load_state_dict() - - if is_state_dict_xlabs_controlnet(state_dict) or is_state_dict_instantx_controlnet(state_dict): - # TODO(ryand): Should I distinguish between XLabs, InstantX and other ControlNet models by implementing - # get_format()? - return BaseModelType.Flux - - for key in ( - "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", - "controlnet_mid_block.bias", - "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", - "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", - ): - if key not in state_dict: - continue - width = state_dict[key].shape[-1] - match width: - case 768: - return BaseModelType.StableDiffusion1 - case 1024: - return BaseModelType.StableDiffusion2 - case 2048: - return BaseModelType.StableDiffusionXL - case 1280: - return BaseModelType.StableDiffusionXL - case _: - pass - - raise NotAMatch(cls, "unable to determine base type from state dict") - - -class ControlNet_Checkpoint_SD1_Config(ControlNet_Checkpoint_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) - - -class ControlNet_Checkpoint_SD2_Config(ControlNet_Checkpoint_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) - - -class ControlNet_Checkpoint_SDXL_Config(ControlNet_Checkpoint_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) - - -class ControlNet_Checkpoint_FLUX_Config(ControlNet_Checkpoint_Config_Base, Config_Base): - base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) - - -class TI_Config_Base(ABC, BaseModel): - type: Literal[ModelType.TextualInversion] = Field(default=ModelType.TextualInversion) - - @classmethod - def _validate_base(cls, mod: ModelOnDisk, path: Path | None = None) -> None: - """Raise `NotAMatch` if the model base does not match this config class.""" - expected_base = cls.model_fields["base"].default - recognized_base = cls._get_base_or_raise(mod, path) - if expected_base is not recognized_base: - raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") - - @classmethod - def _file_looks_like_embedding(cls, mod: ModelOnDisk, path: Path | None = None) -> bool: - try: - p = path or mod.path - - if not p.exists(): - return False - - if p.is_dir(): - return False - - if p.name in [f"learned_embeds.{s}" for s in mod.weight_files()]: - return True - - state_dict = mod.load_state_dict(p) - - # Heuristic: textual inversion embeddings have these keys - if any(key in {"string_to_param", "emb_params", "clip_g"} for key in state_dict.keys()): - return True - - # Heuristic: small state dict with all tensor values - if (len(state_dict)) < 10 and all(isinstance(v, torch.Tensor) for v in state_dict.values()): - return True - - return False - except Exception: - return False - - @classmethod - def _get_base_or_raise(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType: - p = path or mod.path - - try: - state_dict = mod.load_state_dict(p) - except Exception as e: - raise NotAMatch(cls, f"unable to load state dict from {p}: {e}") from e - - try: - if "string_to_token" in state_dict: - token_dim = list(state_dict["string_to_param"].values())[0].shape[-1] - elif "emb_params" in state_dict: - token_dim = state_dict["emb_params"].shape[-1] - elif "clip_g" in state_dict: - token_dim = state_dict["clip_g"].shape[-1] - else: - token_dim = list(state_dict.values())[0].shape[0] - except Exception as e: - raise NotAMatch(cls, f"unable to determine token dimension from state dict in {p}: {e}") from e - - match token_dim: - case 768: - return BaseModelType.StableDiffusion1 - case 1024: - return BaseModelType.StableDiffusion2 - case 1280: - return BaseModelType.StableDiffusionXL - case _: - raise NotAMatch(cls, f"unrecognized token dimension {token_dim}") - - -class TI_File_Config_Base(TI_Config_Base): - """Model config for textual inversion embeddings.""" - - format: Literal[ModelFormat.EmbeddingFile] = Field(default=ModelFormat.EmbeddingFile) - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_file(cls, mod) - - _validate_override_fields(cls, fields) - - if not cls._file_looks_like_embedding(mod): - raise NotAMatch(cls, "model does not look like a textual inversion embedding file") - - cls._validate_base(mod) - - return cls(**fields) - - -class TI_File_SD1_Config(TI_File_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) - - -class TI_File_SD2_Config(TI_File_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) - - -class TI_File_SDXL_Config(TI_File_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) - - -class TI_Folder_Config_Base(TI_Config_Base): - """Model config for textual inversion embeddings.""" - - format: Literal[ModelFormat.EmbeddingFolder] = Field(default=ModelFormat.EmbeddingFolder) - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_dir(cls, mod) - - _validate_override_fields(cls, fields) - - for p in mod.weight_files(): - if cls._file_looks_like_embedding(mod, p): - cls._validate_base(mod, p) - return cls(**fields) - - raise NotAMatch(cls, "model does not look like a textual inversion embedding folder") - - -class TI_Folder_SD1_Config(TI_Folder_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) - - -class TI_Folder_SD2_Config(TI_Folder_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) - - -class TI_Folder_SDXL_Config(TI_Folder_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) - - -class Main_Config_Base(ABC, BaseModel): - type: Literal[ModelType.Main] = Field(default=ModelType.Main) - trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) - default_settings: Optional[MainModelDefaultSettings] = Field( - description="Default settings for this model", default=None - ) - - -def _has_bnb_nf4_keys(state_dict: dict[str | int, Any]) -> bool: - bnb_nf4_keys = { - "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4", - "model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4", - } - return any(key in state_dict for key in bnb_nf4_keys) - - -def _has_ggml_tensors(state_dict: dict[str | int, Any]) -> bool: - return any(isinstance(v, GGMLTensor) for v in state_dict.values()) - - -def _has_main_keys(state_dict: dict[str | int, Any]) -> bool: - for key in state_dict.keys(): - if isinstance(key, int): - continue - elif key.startswith( - ( - "cond_stage_model.", - "first_stage_model.", - "model.diffusion_model.", - # Some FLUX checkpoint files contain transformer keys prefixed with "model.diffusion_model". - # This prefix is typically used to distinguish between multiple models bundled in a single file. - "model.diffusion_model.double_blocks.", - ) - ): - return True - elif key.startswith("double_blocks.") and "ip_adapter" not in key: - # FLUX models in the official BFL format contain keys with the "double_blocks." prefix, but we must be - # careful to avoid false positives on XLabs FLUX IP-Adapter models. - return True - return False - - -class Main_Checkpoint_Config_Base(Checkpoint_Config_Base, Main_Config_Base): - """Model config for main checkpoint models.""" - - format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) - - prediction_type: SchedulerPredictionType = Field() - variant: ModelVariantType = Field() - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_file(cls, mod) - - _validate_override_fields(cls, fields) - - cls._validate_looks_like_main_model(mod) - - cls._validate_base(mod) - - prediction_type = fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod) - - variant = fields.get("variant") or cls._get_variant_or_raise(mod) - - return cls(**fields, prediction_type=prediction_type, variant=variant) - - @classmethod - def _validate_base(cls, mod: ModelOnDisk) -> None: - """Raise `NotAMatch` if the model base does not match this config class.""" - expected_base = cls.model_fields["base"].default - recognized_base = cls._get_base_or_raise(mod) - if expected_base is not recognized_base: - raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") - - @classmethod - def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: - state_dict = mod.load_state_dict() - - key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" - if key_name in state_dict and state_dict[key_name].shape[-1] == 768: - return BaseModelType.StableDiffusion1 - if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: - return BaseModelType.StableDiffusion2 - - key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight" - if key_name in state_dict and state_dict[key_name].shape[-1] == 2048: - return BaseModelType.StableDiffusionXL - elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280: - return BaseModelType.StableDiffusionXLRefiner - - raise NotAMatch(cls, "unable to determine base type from state dict") - - @classmethod - def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk) -> SchedulerPredictionType: - base = cls.model_fields["base"].default - - if base is BaseModelType.StableDiffusion2: - state_dict = mod.load_state_dict() - key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" - if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: - if "global_step" in state_dict: - if state_dict["global_step"] == 220000: - return SchedulerPredictionType.Epsilon - elif state_dict["global_step"] == 110000: - return SchedulerPredictionType.VPrediction - return SchedulerPredictionType.VPrediction - else: - return SchedulerPredictionType.Epsilon - - @classmethod - def _get_variant_or_raise(cls, mod: ModelOnDisk) -> ModelVariantType: - base = cls.model_fields["base"].default - - state_dict = mod.load_state_dict() - key_name = "model.diffusion_model.input_blocks.0.0.weight" - - if key_name not in state_dict: - raise NotAMatch(cls, "unable to determine model variant from state dict") - - in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1] - - match in_channels: - case 4: - return ModelVariantType.Normal - case 5: - # Only SD2 has a depth variant - assert base is BaseModelType.StableDiffusion2, f"unexpected unet in_channels 5 for base '{base}'" - return ModelVariantType.Depth - case 9: - return ModelVariantType.Inpaint - case _: - raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'") - - @classmethod - def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None: - has_main_model_keys = _has_main_keys(mod.load_state_dict()) - if not has_main_model_keys: - raise NotAMatch(cls, "state dict does not look like a main model") - - -class Main_Checkpoint_SD1_Config(Main_Checkpoint_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) - - -class Main_Checkpoint_SD2_Config(Main_Checkpoint_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) - - -class Main_Checkpoint_SDXL_Config(Main_Checkpoint_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) - - -class Main_Checkpoint_SDXLRefiner_Config(Main_Checkpoint_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(default=BaseModelType.StableDiffusionXLRefiner) - - -def _get_flux_variant(state_dict: dict[str | int, Any]) -> FluxVariantType | None: - # FLUX Model variant types are distinguished by input channels and the presence of certain keys. - - # Input channels are derived from the shape of either "img_in.weight" or "model.diffusion_model.img_in.weight". - # - # Known models that use the latter key: - # - https://civitai.com/models/885098?modelVersionId=990775 - # - https://civitai.com/models/1018060?modelVersionId=1596255 - # - https://civitai.com/models/978314/ultrareal-fine-tune?modelVersionId=1413133 - # - # Input channels for known FLUX models: - # - Unquantized Dev and Schnell have in_channels=64 - # - BNB-NF4 Dev and Schnell have in_channels=1 - # - FLUX Fill has in_channels=384 - # - Unsure of quantized FLUX Fill models - # - Unsure of GGUF-quantized models - - in_channels = None - for key in {"img_in.weight", "model.diffusion_model.img_in.weight"}: - if key in state_dict: - in_channels = state_dict[key].shape[1] - break - - if in_channels is None: - # TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant, - # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX - # model, we should figure out a good fallback value. - return None - - # Because FLUX Dev and Schnell models have the same in_channels, we need to check for the presence of - # certain keys to distinguish between them. - is_flux_dev = ( - "guidance_in.out_layer.weight" in state_dict - or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict - ) - - if is_flux_dev and in_channels == 384: - return FluxVariantType.DevFill - elif is_flux_dev: - return FluxVariantType.Dev - else: - # Must be a Schnell model...? - return FluxVariantType.Schnell - - -class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base): - """Model config for main checkpoint models.""" - - format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) - base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) - - variant: FluxVariantType = Field() - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_file(cls, mod) - - _validate_override_fields(cls, fields) - - cls._validate_looks_like_main_model(mod) - - cls._validate_is_flux(mod) - - cls._validate_does_not_look_like_bnb_quantized(mod) - - cls._validate_does_not_look_like_gguf_quantized(mod) - - variant = fields.get("variant") or cls._get_variant_or_raise(mod) - - return cls(**fields, variant=variant) - - @classmethod - def _validate_is_flux(cls, mod: ModelOnDisk) -> None: - if not has_any_keys( - mod.load_state_dict(), - { - "double_blocks.0.img_attn.norm.key_norm.scale", - "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale", - }, - ): - raise NotAMatch(cls, "state dict does not look like a FLUX checkpoint") - - @classmethod - def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType: - # FLUX Model variant types are distinguished by input channels and the presence of certain keys. - state_dict = mod.load_state_dict() - variant = _get_flux_variant(state_dict) - - if variant is None: - # TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant, - # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX - # model, we should figure out a good fallback value. - raise NotAMatch(cls, "unable to determine model variant from state dict") - - return variant - - @classmethod - def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None: - has_main_model_keys = _has_main_keys(mod.load_state_dict()) - if not has_main_model_keys: - raise NotAMatch(cls, "state dict does not look like a main model") - - @classmethod - def _validate_does_not_look_like_bnb_quantized(cls, mod: ModelOnDisk) -> None: - has_bnb_nf4_keys = _has_bnb_nf4_keys(mod.load_state_dict()) - if has_bnb_nf4_keys: - raise NotAMatch(cls, "state dict looks like bnb quantized nf4") - - @classmethod - def _validate_does_not_look_like_gguf_quantized(cls, mod: ModelOnDisk): - has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict()) - if has_ggml_tensors: - raise NotAMatch(cls, "state dict looks like GGUF quantized") - - -class Main_BnBNF4_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base): - """Model config for main checkpoint models.""" - - base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) - format: Literal[ModelFormat.BnbQuantizednf4b] = Field(default=ModelFormat.BnbQuantizednf4b) - - variant: FluxVariantType = Field() - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_file(cls, mod) - - _validate_override_fields(cls, fields) - - cls._validate_looks_like_main_model(mod) - - cls._validate_model_looks_like_bnb_quantized(mod) - - variant = fields.get("variant") or cls._get_variant_or_raise(mod) - - return cls(**fields, variant=variant) - - @classmethod - def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType: - # FLUX Model variant types are distinguished by input channels and the presence of certain keys. - state_dict = mod.load_state_dict() - variant = _get_flux_variant(state_dict) - - if variant is None: - # TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant, - # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX - # model, we should figure out a good fallback value. - raise NotAMatch(cls, "unable to determine model variant from state dict") - - return variant - - @classmethod - def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None: - has_main_model_keys = _has_main_keys(mod.load_state_dict()) - if not has_main_model_keys: - raise NotAMatch(cls, "state dict does not look like a main model") - - @classmethod - def _validate_model_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None: - has_bnb_nf4_keys = _has_bnb_nf4_keys(mod.load_state_dict()) - if not has_bnb_nf4_keys: - raise NotAMatch(cls, "state dict does not look like bnb quantized nf4") - - -class Main_GGUF_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base): - """Model config for main checkpoint models.""" - - base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) - format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized) - - variant: FluxVariantType = Field() - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_file(cls, mod) - - _validate_override_fields(cls, fields) - - cls._validate_looks_like_main_model(mod) - - cls._validate_looks_like_gguf_quantized(mod) - - variant = fields.get("variant") or cls._get_variant_or_raise(mod) - - return cls(**fields, variant=variant) - - @classmethod - def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType: - # FLUX Model variant types are distinguished by input channels and the presence of certain keys. - state_dict = mod.load_state_dict() - variant = _get_flux_variant(state_dict) - - if variant is None: - # TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant, - # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX - # model, we should figure out a good fallback value. - raise NotAMatch(cls, "unable to determine model variant from state dict") - - return variant - - @classmethod - def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None: - has_main_model_keys = _has_main_keys(mod.load_state_dict()) - if not has_main_model_keys: - raise NotAMatch(cls, "state dict does not look like a main model") - - @classmethod - def _validate_looks_like_gguf_quantized(cls, mod: ModelOnDisk) -> None: - has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict()) - if not has_ggml_tensors: - raise NotAMatch(cls, "state dict does not look like GGUF quantized") - - -class Main_Diffusers_Config_Base(Diffusers_Config_Base, Main_Config_Base): - prediction_type: SchedulerPredictionType = Field() - variant: ModelVariantType = Field() - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_dir(cls, mod) - - _validate_override_fields(cls, fields) - - _validate_class_name( - cls, - common_config_paths(mod.path), - { - # SD 1.x and 2.x - "StableDiffusionPipeline", - "StableDiffusionInpaintPipeline", - # SDXL - "StableDiffusionXLPipeline", - "StableDiffusionXLInpaintPipeline", - # SDXL Refiner - "StableDiffusionXLImg2ImgPipeline", - # TODO(psyche): Do we actually support LCM models? I don't see using this class anywhere in the codebase. - "LatentConsistencyModelPipeline", - }, - ) - - cls._validate_base(mod) - - variant = fields.get("variant") or cls._get_variant_or_raise(mod) - - prediction_type = fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod) - - repo_variant = fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) - - return cls( - **fields, - variant=variant, - prediction_type=prediction_type, - repo_variant=repo_variant, - ) - - @classmethod - def _validate_base(cls, mod: ModelOnDisk) -> None: - """Raise `NotAMatch` if the model base does not match this config class.""" - expected_base = cls.model_fields["base"].default - recognized_base = cls._get_base_or_raise(mod) - if expected_base is not recognized_base: - raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") - - @classmethod - def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: - # Handle pipelines with a UNet (i.e SD 1.x, SD2.x, SDXL). - unet_config_path = mod.path / "unet" / "config.json" - if unet_config_path.exists(): - with open(unet_config_path) as file: - unet_conf = json.load(file) - cross_attention_dim = unet_conf.get("cross_attention_dim") - match cross_attention_dim: - case 768: - return BaseModelType.StableDiffusion1 - case 1024: - return BaseModelType.StableDiffusion2 - case 1280: - return BaseModelType.StableDiffusionXLRefiner - case 2048: - return BaseModelType.StableDiffusionXL - case _: - raise NotAMatch(cls, f"unrecognized cross_attention_dim {cross_attention_dim}") - - raise NotAMatch(cls, "unable to determine base type") - - @classmethod - def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk) -> SchedulerPredictionType: - scheduler_conf = _get_config_or_raise(cls, mod.path / "scheduler" / "scheduler_config.json") - - # TODO(psyche): Is epsilon the right default or should we raise if it's not present? - prediction_type = scheduler_conf.get("prediction_type", "epsilon") - - match prediction_type: - case "v_prediction": - return SchedulerPredictionType.VPrediction - case "epsilon": - return SchedulerPredictionType.Epsilon - case _: - raise NotAMatch(cls, f"unrecognized scheduler prediction_type {prediction_type}") - - @classmethod - def _get_variant_or_raise(cls, mod: ModelOnDisk) -> ModelVariantType: - base = cls.model_fields["base"].default - unet_config = _get_config_or_raise(cls, mod.path / "unet" / "config.json") - in_channels = unet_config.get("in_channels") - - match in_channels: - case 4: - return ModelVariantType.Normal - case 5: - # Only SD2 has a depth variant - assert base is BaseModelType.StableDiffusion2, f"unexpected unet in_channels 5 for base '{base}'" - return ModelVariantType.Depth - case 9: - return ModelVariantType.Inpaint - case _: - raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'") - - -class Main_Diffusers_SD1_Config(Main_Diffusers_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusion1] = Field(BaseModelType.StableDiffusion1) - - -class Main_Diffusers_SD2_Config(Main_Diffusers_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusion2] = Field(BaseModelType.StableDiffusion2) - - -class Main_Diffusers_SDXL_Config(Main_Diffusers_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusionXL] = Field(BaseModelType.StableDiffusionXL) - - -class Main_Diffusers_SDXLRefiner_Config(Main_Diffusers_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(BaseModelType.StableDiffusionXLRefiner) - - -class Main_Diffusers_SD3_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusion3] = Field(BaseModelType.StableDiffusion3) - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_dir(cls, mod) - - _validate_override_fields(cls, fields) - - # This check implies the base type - no further validation needed. - _validate_class_name( - cls, - common_config_paths(mod.path), - { - "StableDiffusion3Pipeline", - "SD3Transformer2DModel", - }, - ) - - submodels = fields.get("submodels") or cls._get_submodels_or_raise(mod) - - repo_variant = fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) - - return cls( - **fields, - submodels=submodels, - repo_variant=repo_variant, - ) - - @classmethod - def _get_submodels_or_raise(cls, mod: ModelOnDisk) -> dict[SubModelType, SubmodelDefinition]: - # Example: https://huggingface.co/stabilityai/stable-diffusion-3.5-medium/blob/main/model_index.json - config = _get_config_or_raise(cls, common_config_paths(mod.path)) - - submodels: dict[SubModelType, SubmodelDefinition] = {} - - for key, value in config.items(): - # Anything that starts with an underscore is top-level metadata, not a submodel - if key.startswith("_") or not (isinstance(value, list) and len(value) == 2): - continue - # The key is something like "transformer" and is a submodel - it will be in a dir of the same name. - # The value value is something like ["diffusers", "SD3Transformer2DModel"] - _library_name, class_name = value - - match class_name: - case "CLIPTextModelWithProjection": - model_type = ModelType.CLIPEmbed - path_or_prefix = (mod.path / key).resolve().as_posix() - - # We need to read the config to determine the variant of the CLIP model. - clip_embed_config = _get_config_or_raise( - cls, {mod.path / key / "config.json", mod.path / key / "model_index.json"} - ) - variant = _get_clip_variant_type_from_config(clip_embed_config) - submodels[SubModelType(key)] = SubmodelDefinition( - path_or_prefix=path_or_prefix, - model_type=model_type, - variant=variant, - ) - case "SD3Transformer2DModel": - model_type = ModelType.Main - path_or_prefix = (mod.path / key).resolve().as_posix() - variant = None - submodels[SubModelType(key)] = SubmodelDefinition( - path_or_prefix=path_or_prefix, - model_type=model_type, - variant=variant, - ) - case _: - pass - - return submodels - - -class Main_Diffusers_CogView4_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base): - base: Literal[BaseModelType.CogView4] = Field(BaseModelType.CogView4) - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_dir(cls, mod) - - _validate_override_fields(cls, fields) - - # This check implies the base type - no further validation needed. - _validate_class_name( - cls, - common_config_paths(mod.path), - { - "CogView4Pipeline", - }, - ) - - repo_variant = fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) - - return cls( - **fields, - repo_variant=repo_variant, - ) - - -class IPAdapter_Config_Base(ABC, BaseModel): - type: Literal[ModelType.IPAdapter] = Field(default=ModelType.IPAdapter) - - -class IPAdapter_InvokeAI_Config_Base(IPAdapter_Config_Base): - """Model config for IP Adapter diffusers format models.""" - - format: Literal[ModelFormat.InvokeAI] = Field(default=ModelFormat.InvokeAI) - - # TODO(ryand): Should we deprecate this field? From what I can tell, it hasn't been probed correctly for a long - # time. Need to go through the history to make sure I'm understanding this fully. - image_encoder_model_id: str = Field() - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_dir(cls, mod) - - _validate_override_fields(cls, fields) - - cls._validate_has_weights_file(mod) - - cls._validate_has_image_encoder_metadata_file(mod) - - cls._validate_base(mod) - - return cls(**fields) - - @classmethod - def _validate_base(cls, mod: ModelOnDisk) -> None: - """Raise `NotAMatch` if the model base does not match this config class.""" - expected_base = cls.model_fields["base"].default - recognized_base = cls._get_base_or_raise(mod) - if expected_base is not recognized_base: - raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") - - @classmethod - def _validate_has_weights_file(cls, mod: ModelOnDisk) -> None: - weights_file = mod.path / "ip_adapter.bin" - if not weights_file.exists(): - raise NotAMatch(cls, "missing ip_adapter.bin weights file") - - @classmethod - def _validate_has_image_encoder_metadata_file(cls, mod: ModelOnDisk) -> None: - image_encoder_metadata_file = mod.path / "image_encoder.txt" - if not image_encoder_metadata_file.exists(): - raise NotAMatch(cls, "missing image_encoder.txt metadata file") - - @classmethod - def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: - state_dict = mod.load_state_dict() - - try: - cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1] - except Exception as e: - raise NotAMatch(cls, f"unable to determine cross attention dimension: {e}") from e - - match cross_attention_dim: - case 1280: - return BaseModelType.StableDiffusionXL - case 768: - return BaseModelType.StableDiffusion1 - case 1024: - return BaseModelType.StableDiffusion2 - case _: - raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}") - - -class IPAdapter_InvokeAI_SD1_Config(IPAdapter_InvokeAI_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) - - -class IPAdapter_InvokeAI_SD2_Config(IPAdapter_InvokeAI_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) - - -class IPAdapter_InvokeAI_SDXL_Config(IPAdapter_InvokeAI_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) - - -class IPAdapter_Checkpoint_Config_Base(IPAdapter_Config_Base): - """Model config for IP Adapter checkpoint format models.""" - - format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_file(cls, mod) - - _validate_override_fields(cls, fields) - - cls._validate_looks_like_ip_adapter(mod) - - cls._validate_base(mod) - - return cls(**fields) - - @classmethod - def _validate_base(cls, mod: ModelOnDisk) -> None: - """Raise `NotAMatch` if the model base does not match this config class.""" - expected_base = cls.model_fields["base"].default - recognized_base = cls._get_base_or_raise(mod) - if expected_base is not recognized_base: - raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") - - @classmethod - def _validate_looks_like_ip_adapter(cls, mod: ModelOnDisk) -> None: - if not has_any_keys_starting_with( - mod.load_state_dict(), - { - "image_proj.", - "ip_adapter.", - # XLabs FLUX IP-Adapter models have keys startinh with "ip_adapter_proj_model.". - "ip_adapter_proj_model.", - }, - ): - raise NotAMatch(cls, "model does not match Checkpoint IP Adapter heuristics") - - @classmethod - def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: - state_dict = mod.load_state_dict() - - if is_state_dict_xlabs_ip_adapter(state_dict): - return BaseModelType.Flux - - try: - cross_attention_dim = state_dict["ip_adapter.1.to_k_ip.weight"].shape[-1] - except Exception as e: - raise NotAMatch(cls, f"unable to determine cross attention dimension: {e}") from e - - match cross_attention_dim: - case 1280: - return BaseModelType.StableDiffusionXL - case 768: - return BaseModelType.StableDiffusion1 - case 1024: - return BaseModelType.StableDiffusion2 - case _: - raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}") - - -class IPAdapter_Checkpoint_SD1_Config(IPAdapter_Checkpoint_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) - - -class IPAdapter_Checkpoint_SD2_Config(IPAdapter_Checkpoint_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) - - -class IPAdapter_Checkpoint_SDXL_Config(IPAdapter_Checkpoint_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) - - -class IPAdapter_Checkpoint_FLUX_Config(IPAdapter_Checkpoint_Config_Base, Config_Base): - base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) - - -def _get_clip_variant_type_from_config(config: dict[str, Any]) -> ClipVariantType | None: - try: - hidden_size = config.get("hidden_size") - match hidden_size: - case 1280: - return ClipVariantType.G - case 768: - return ClipVariantType.L - case _: - return None - except Exception: - return None - - -class CLIPEmbed_Diffusers_Config_Base(Diffusers_Config_Base): - base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) - type: Literal[ModelType.CLIPEmbed] = Field(default=ModelType.CLIPEmbed) - format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_dir(cls, mod) - - _validate_override_fields(cls, fields) - - _validate_class_name( - cls, - { - mod.path / "config.json", - mod.path / "text_encoder" / "config.json", - }, - { - "CLIPModel", - "CLIPTextModel", - "CLIPTextModelWithProjection", - }, - ) - - cls._validate_variant(mod) - - return cls(**fields) - - @classmethod - def _validate_variant(cls, mod: ModelOnDisk) -> None: - """Raise `NotAMatch` if the model variant does not match this config class.""" - expected_variant = cls.model_fields["variant"].default - config = _get_config_or_raise( - cls, - { - mod.path / "config.json", - mod.path / "text_encoder" / "config.json", - }, - ) - recognized_variant = _get_clip_variant_type_from_config(config) - - if recognized_variant is None: - raise NotAMatch(cls, "unable to determine CLIP variant from config") - - if expected_variant is not recognized_variant: - raise NotAMatch(cls, f"variant is {recognized_variant}, not {expected_variant}") - - -class CLIPEmbed_Diffusers_G_Config(CLIPEmbed_Diffusers_Config_Base, Config_Base): - variant: Literal[ClipVariantType.G] = Field(default=ClipVariantType.G) - - -class CLIPEmbed_Diffusers_L_Config(CLIPEmbed_Diffusers_Config_Base, Config_Base): - variant: Literal[ClipVariantType.L] = Field(default=ClipVariantType.L) - - -class CLIPVision_Diffusers_Config(Diffusers_Config_Base, Config_Base): - """Model config for CLIPVision.""" - - base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) - type: Literal[ModelType.CLIPVision] = Field(default=ModelType.CLIPVision) - format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_dir(cls, mod) - - _validate_override_fields(cls, fields) - - _validate_class_name( - cls, - common_config_paths(mod.path), - { - "CLIPVisionModelWithProjection", - }, - ) - - return cls(**fields) - - -class T2IAdapter_Diffusers_Config_Base(Diffusers_Config_Base, ControlAdapter_Config_Base): - """Model config for T2I.""" - - type: Literal[ModelType.T2IAdapter] = Field(default=ModelType.T2IAdapter) - format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_dir(cls, mod) - - _validate_override_fields(cls, fields) - - _validate_class_name( - cls, - common_config_paths(mod.path), - { - "T2IAdapter", - }, - ) - - cls._validate_base(mod) - - return cls(**fields) - - @classmethod - def _validate_base(cls, mod: ModelOnDisk) -> None: - """Raise `NotAMatch` if the model base does not match this config class.""" - expected_base = cls.model_fields["base"].default - recognized_base = cls._get_base_or_raise(mod) - if expected_base is not recognized_base: - raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") - - @classmethod - def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: - config = _get_config_or_raise(cls, common_config_paths(mod.path)) - - adapter_type = config.get("adapter_type") - - match adapter_type: - case "full_adapter_xl": - return BaseModelType.StableDiffusionXL - case "full_adapter" | "light_adapter": - return BaseModelType.StableDiffusion1 - case _: - raise NotAMatch(cls, f"unrecognized adapter_type '{adapter_type}'") - - -class T2IAdapter_Diffusers_SD1_Config(T2IAdapter_Diffusers_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) - - -class T2IAdapter_Diffusers_SDXL_Config(T2IAdapter_Diffusers_Config_Base, Config_Base): - base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) - - -class Spandrel_Checkpoint_Config(Config_Base): - """Model config for Spandrel Image to Image models.""" - - base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) - type: Literal[ModelType.SpandrelImageToImage] = Field(default=ModelType.SpandrelImageToImage) - format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_file(cls, mod) - - _validate_override_fields(cls, fields) - - cls._validate_spandrel_loads_model(mod) - - return cls(**fields) - - @classmethod - def _validate_spandrel_loads_model(cls, mod: ModelOnDisk) -> None: - try: - # It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were - # explored to avoid this: - # 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta - # device. Unfortunately, some Spandrel models perform operations during initialization that are not - # supported on meta tensors. - # 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model. - # This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to - # maintain it, and the risk of false positive detections is higher. - SpandrelImageToImageModel.load_from_file(mod.path) - except Exception as e: - raise NotAMatch(cls, "model does not match SpandrelImageToImage heuristics") from e - - -class SigLIP_Diffusers_Config(Diffusers_Config_Base, Config_Base): - """Model config for SigLIP.""" - - type: Literal[ModelType.SigLIP] = Field(default=ModelType.SigLIP) - format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) - base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_dir(cls, mod) - - _validate_override_fields(cls, fields) - - _validate_class_name( - cls, - common_config_paths(mod.path), - { - "SiglipModel", - }, - ) - - return cls(**fields) - - -class FLUXRedux_Checkpoint_Config(Config_Base): - """Model config for FLUX Tools Redux model.""" - - type: Literal[ModelType.FluxRedux] = Field(default=ModelType.FluxRedux) - format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) - base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_file(cls, mod) - - _validate_override_fields(cls, fields) - - if not is_state_dict_likely_flux_redux(mod.load_state_dict()): - raise NotAMatch(cls, "model does not match FLUX Tools Redux heuristics") - - return cls(**fields) - - -class LlavaOnevision_Diffusers_Config(Diffusers_Config_Base, Config_Base): - """Model config for Llava Onevision models.""" - - type: Literal[ModelType.LlavaOnevision] = Field(default=ModelType.LlavaOnevision) - base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) - variant: Literal[ModelVariantType.Normal] = Field(default=ModelVariantType.Normal) - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - _validate_is_dir(cls, mod) - - _validate_override_fields(cls, fields) - - _validate_class_name( - cls, - common_config_paths(mod.path), - { - "LlavaOnevisionForConditionalGeneration", - }, - ) - - return cls(**fields) - - -class ExternalAPI_Config_Base(ABC, BaseModel): - """Model config for API-based models.""" - - format: Literal[ModelFormat.Api] = Field(default=ModelFormat.Api) - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - raise NotAMatch(cls, "External API models cannot be built from disk") - - -class ExternalAPI_ChatGPT4o_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base): - base: Literal[BaseModelType.ChatGPT4o] = Field(default=BaseModelType.ChatGPT4o) - - -class ExternalAPI_Gemini2_5_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base): - base: Literal[BaseModelType.Gemini2_5] = Field(default=BaseModelType.Gemini2_5) - - -class ExternalAPI_Imagen3_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base): - base: Literal[BaseModelType.Imagen3] = Field(default=BaseModelType.Imagen3) - - -class ExternalAPI_Imagen4_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base): - base: Literal[BaseModelType.Imagen4] = Field(default=BaseModelType.Imagen4) - - -class ExternalAPI_FluxKontext_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base): - base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext) - - -class VideoConfigBase(ABC, BaseModel): - type: Literal[ModelType.Video] = Field(default=ModelType.Video) - trigger_phrases: set[str] | None = Field(description="Set of trigger phrases for this model", default=None) - default_settings: MainModelDefaultSettings | None = Field( - description="Default settings for this model", default=None - ) - - -class ExternalAPI_Veo3_Config(ExternalAPI_Config_Base, VideoConfigBase, Config_Base): - base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext) - - -class ExternalAPI_Runway_Config(ExternalAPI_Config_Base, VideoConfigBase, Config_Base): - base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext) - - -# The types are listed explicitly because IDEs/LSPs can't identify the correct types -# when AnyModelConfig is constructed dynamically using ModelConfigBase.all_config_classes -AnyModelConfig = Annotated[ - Union[ - # Main (Pipeline) - diffusers format - Annotated[Main_Diffusers_SD1_Config, Main_Diffusers_SD1_Config.get_tag()], - Annotated[Main_Diffusers_SD2_Config, Main_Diffusers_SD2_Config.get_tag()], - Annotated[Main_Diffusers_SDXL_Config, Main_Diffusers_SDXL_Config.get_tag()], - Annotated[Main_Diffusers_SDXLRefiner_Config, Main_Diffusers_SDXLRefiner_Config.get_tag()], - Annotated[Main_Diffusers_SD3_Config, Main_Diffusers_SD3_Config.get_tag()], - Annotated[Main_Diffusers_CogView4_Config, Main_Diffusers_CogView4_Config.get_tag()], - # Main (Pipeline) - checkpoint format - Annotated[Main_Checkpoint_SD1_Config, Main_Checkpoint_SD1_Config.get_tag()], - Annotated[Main_Checkpoint_SD2_Config, Main_Checkpoint_SD2_Config.get_tag()], - Annotated[Main_Checkpoint_SDXL_Config, Main_Checkpoint_SDXL_Config.get_tag()], - Annotated[Main_Checkpoint_SDXLRefiner_Config, Main_Checkpoint_SDXLRefiner_Config.get_tag()], - Annotated[Main_Checkpoint_FLUX_Config, Main_Checkpoint_FLUX_Config.get_tag()], - # Main (Pipeline) - quantized formats - Annotated[Main_BnBNF4_FLUX_Config, Main_BnBNF4_FLUX_Config.get_tag()], - Annotated[Main_GGUF_FLUX_Config, Main_GGUF_FLUX_Config.get_tag()], - # VAE - checkpoint format - Annotated[VAE_Checkpoint_SD1_Config, VAE_Checkpoint_SD1_Config.get_tag()], - Annotated[VAE_Checkpoint_SD2_Config, VAE_Checkpoint_SD2_Config.get_tag()], - Annotated[VAE_Checkpoint_SDXL_Config, VAE_Checkpoint_SDXL_Config.get_tag()], - Annotated[VAE_Checkpoint_FLUX_Config, VAE_Checkpoint_FLUX_Config.get_tag()], - # VAE - diffusers format - Annotated[VAE_Diffusers_SD1_Config, VAE_Diffusers_SD1_Config.get_tag()], - Annotated[VAE_Diffusers_SDXL_Config, VAE_Diffusers_SDXL_Config.get_tag()], - # ControlNet - checkpoint format - Annotated[ControlNet_Checkpoint_SD1_Config, ControlNet_Checkpoint_SD1_Config.get_tag()], - Annotated[ControlNet_Checkpoint_SD2_Config, ControlNet_Checkpoint_SD2_Config.get_tag()], - Annotated[ControlNet_Checkpoint_SDXL_Config, ControlNet_Checkpoint_SDXL_Config.get_tag()], - Annotated[ControlNet_Checkpoint_FLUX_Config, ControlNet_Checkpoint_FLUX_Config.get_tag()], - # ControlNet - diffusers format - Annotated[ControlNet_Diffusers_SD1_Config, ControlNet_Diffusers_SD1_Config.get_tag()], - Annotated[ControlNet_Diffusers_SD2_Config, ControlNet_Diffusers_SD2_Config.get_tag()], - Annotated[ControlNet_Diffusers_SDXL_Config, ControlNet_Diffusers_SDXL_Config.get_tag()], - Annotated[ControlNet_Diffusers_FLUX_Config, ControlNet_Diffusers_FLUX_Config.get_tag()], - # LoRA - LyCORIS format - Annotated[LoRA_LyCORIS_SD1_Config, LoRA_LyCORIS_SD1_Config.get_tag()], - Annotated[LoRA_LyCORIS_SD2_Config, LoRA_LyCORIS_SD2_Config.get_tag()], - Annotated[LoRA_LyCORIS_SDXL_Config, LoRA_LyCORIS_SDXL_Config.get_tag()], - Annotated[LoRA_LyCORIS_FLUX_Config, LoRA_LyCORIS_FLUX_Config.get_tag()], - # LoRA - OMI format - Annotated[LoRA_OMI_SDXL_Config, LoRA_OMI_SDXL_Config.get_tag()], - Annotated[LoRA_OMI_FLUX_Config, LoRA_OMI_FLUX_Config.get_tag()], - # LoRA - diffusers format - Annotated[LoRA_Diffusers_SD1_Config, LoRA_Diffusers_SD1_Config.get_tag()], - Annotated[LoRA_Diffusers_SD2_Config, LoRA_Diffusers_SD2_Config.get_tag()], - Annotated[LoRA_Diffusers_SDXL_Config, LoRA_Diffusers_SDXL_Config.get_tag()], - Annotated[LoRA_Diffusers_FLUX_Config, LoRA_Diffusers_FLUX_Config.get_tag()], - # ControlLoRA - diffusers format - Annotated[ControlLoRA_LyCORIS_FLUX_Config, ControlLoRA_LyCORIS_FLUX_Config.get_tag()], - # T5 Encoder - all formats - Annotated[T5Encoder_T5Encoder_Config, T5Encoder_T5Encoder_Config.get_tag()], - Annotated[T5Encoder_BnBLLMint8_Config, T5Encoder_BnBLLMint8_Config.get_tag()], - # TI - file format - Annotated[TI_File_SD1_Config, TI_File_SD1_Config.get_tag()], - Annotated[TI_File_SD2_Config, TI_File_SD2_Config.get_tag()], - Annotated[TI_File_SDXL_Config, TI_File_SDXL_Config.get_tag()], - # TI - folder format - Annotated[TI_Folder_SD1_Config, TI_Folder_SD1_Config.get_tag()], - Annotated[TI_Folder_SD2_Config, TI_Folder_SD2_Config.get_tag()], - Annotated[TI_Folder_SDXL_Config, TI_Folder_SDXL_Config.get_tag()], - # IP Adapter - InvokeAI format - Annotated[IPAdapter_InvokeAI_SD1_Config, IPAdapter_InvokeAI_SD1_Config.get_tag()], - Annotated[IPAdapter_InvokeAI_SD2_Config, IPAdapter_InvokeAI_SD2_Config.get_tag()], - Annotated[IPAdapter_InvokeAI_SDXL_Config, IPAdapter_InvokeAI_SDXL_Config.get_tag()], - # IP Adapter - checkpoint format - Annotated[IPAdapter_Checkpoint_SD1_Config, IPAdapter_Checkpoint_SD1_Config.get_tag()], - Annotated[IPAdapter_Checkpoint_SD2_Config, IPAdapter_Checkpoint_SD2_Config.get_tag()], - Annotated[IPAdapter_Checkpoint_SDXL_Config, IPAdapter_Checkpoint_SDXL_Config.get_tag()], - Annotated[IPAdapter_Checkpoint_FLUX_Config, IPAdapter_Checkpoint_FLUX_Config.get_tag()], - # T2I Adapter - diffusers format - Annotated[T2IAdapter_Diffusers_SD1_Config, T2IAdapter_Diffusers_SD1_Config.get_tag()], - Annotated[T2IAdapter_Diffusers_SDXL_Config, T2IAdapter_Diffusers_SDXL_Config.get_tag()], - # Misc models - Annotated[Spandrel_Checkpoint_Config, Spandrel_Checkpoint_Config.get_tag()], - Annotated[CLIPEmbed_Diffusers_G_Config, CLIPEmbed_Diffusers_G_Config.get_tag()], - Annotated[CLIPEmbed_Diffusers_L_Config, CLIPEmbed_Diffusers_L_Config.get_tag()], - Annotated[CLIPVision_Diffusers_Config, CLIPVision_Diffusers_Config.get_tag()], - Annotated[SigLIP_Diffusers_Config, SigLIP_Diffusers_Config.get_tag()], - Annotated[FLUXRedux_Checkpoint_Config, FLUXRedux_Checkpoint_Config.get_tag()], - Annotated[LlavaOnevision_Diffusers_Config, LlavaOnevision_Diffusers_Config.get_tag()], - # API models - Annotated[ExternalAPI_ChatGPT4o_Config, ExternalAPI_ChatGPT4o_Config.get_tag()], - Annotated[ExternalAPI_Gemini2_5_Config, ExternalAPI_Gemini2_5_Config.get_tag()], - Annotated[ExternalAPI_Imagen3_Config, ExternalAPI_Imagen3_Config.get_tag()], - Annotated[ExternalAPI_Imagen4_Config, ExternalAPI_Imagen4_Config.get_tag()], - Annotated[ExternalAPI_FluxKontext_Config, ExternalAPI_FluxKontext_Config.get_tag()], - Annotated[ExternalAPI_Veo3_Config, ExternalAPI_Veo3_Config.get_tag()], - Annotated[ExternalAPI_Runway_Config, ExternalAPI_Runway_Config.get_tag()], - # Unknown model (fallback) - Annotated[Unknown_Config, Unknown_Config.get_tag()], - ], - Discriminator(Config_Base.get_model_discriminator_value), -] - -AnyModelConfigValidator = TypeAdapter[AnyModelConfig](AnyModelConfig) - - -class ModelConfigFactory: - @staticmethod - def make_config(model_data: Dict[str, Any], timestamp: Optional[float] = None) -> AnyModelConfig: - """Return the appropriate config object from raw dict values.""" - model = AnyModelConfigValidator.validate_python(model_data) - if isinstance(model, Checkpoint_Config_Base) and timestamp: - model.converted_at = timestamp - validate_hash(model.hash) - return model - - @staticmethod - def build_common_fields( - mod: ModelOnDisk, - overrides: dict[str, Any] | None = None, - ) -> dict[str, Any]: - """Builds the common fields for all model configs. - - Args: - mod: The model on disk to extract fields from. - overrides: A optional dictionary of fields to override. These fields will take precedence over the values - extracted from the model on disk. - - - Casts string fields to their Enum types. - - Does not validate the fields against the model config schema. - """ - - _overrides: dict[str, Any] = overrides or {} - fields: dict[str, Any] = {} - - if "type" in _overrides: - fields["type"] = ModelType(_overrides["type"]) - - if "format" in _overrides: - fields["format"] = ModelFormat(_overrides["format"]) - - if "base" in _overrides: - fields["base"] = BaseModelType(_overrides["base"]) - - if "source_type" in _overrides: - fields["source_type"] = ModelSourceType(_overrides["source_type"]) - - if "variant" in _overrides: - fields["variant"] = variant_type_adapter.validate_strings(_overrides["variant"]) - - fields["path"] = mod.path.as_posix() - fields["source"] = _overrides.get("source") or fields["path"] - fields["source_type"] = _overrides.get("source_type") or ModelSourceType.Path - fields["name"] = _overrides.get("name") or mod.name - fields["hash"] = _overrides.get("hash") or mod.hash() - fields["key"] = _overrides.get("key") or uuid_string() - fields["description"] = _overrides.get("description") - fields["file_size"] = _overrides.get("file_size") or mod.size() - - return fields - - @staticmethod - def from_model_on_disk( - mod: str | Path | ModelOnDisk, - overrides: dict[str, Any] | None = None, - hash_algo: HASHING_ALGORITHMS = "blake3_single", - ) -> AnyModelConfig: - """ - Returns the best matching ModelConfig instance from a model's file/folder path. - Raises InvalidModelConfigException if no valid configuration is found. - Created to deprecate ModelProbe.probe - """ - if isinstance(mod, Path | str): - mod = ModelOnDisk(Path(mod), hash_algo) - - # We will always need these fields to build any model config. - fields = ModelConfigFactory.build_common_fields(mod, overrides) - - # Store results as a mapping of config class to either an instance of that class or an exception - # that was raised when trying to build it. - results: dict[str, AnyModelConfig | Exception] = {} - - # Try to build an instance of each model config class that uses the classify API. - # Each class will either return an instance of itself or raise NotAMatch if it doesn't match. - # Other exceptions may be raised if something unexpected happens during matching or building. - for config_class in Config_Base.CONFIG_CLASSES: - class_name = config_class.__name__ - try: - instance = config_class.from_model_on_disk(mod, fields) - results[class_name] = instance - except NotAMatch as e: - results[class_name] = e - logger.debug(f"No match for {config_class.__name__} on model {mod.name}") - except ValidationError as e: - # This means the model matched, but we couldn't create the pydantic model instance for the config. - # Maybe invalid overrides were provided? - results[class_name] = e - logger.warning(f"Schema validation error for {config_class.__name__} on model {mod.name}: {e}") - except Exception as e: - results[class_name] = e - logger.warning(f"Unexpected exception while matching {mod.name} to {config_class.__name__}: {e}") - - matches = [r for r in results.values() if isinstance(r, Config_Base)] - - if not matches and app_config.allow_unknown_models: - logger.warning(f"Unable to identify model {mod.name}, falling back to Unknown_Config") - return Unknown_Config(**fields) - - if len(matches) > 1: - # We have multiple matches, in which case at most 1 is correct. We need to pick one. - # - # Known cases: - # - SD main models can look like a LoRA when they have merged in LoRA weights. Prefer the main model. - # - SD main models in diffusers format can look like a CLIP Embed; they have a text_encoder folder with - # a config.json file. Prefer the main model. - - # Sort the matching according to known special cases. - def sort_key(m: AnyModelConfig) -> int: - match m.type: - case ModelType.Main: - return 0 - case ModelType.LoRA: - return 1 - case ModelType.CLIPEmbed: - return 2 - case _: - return 3 - - matches.sort(key=sort_key) - logger.warning( - f"Multiple model config classes matched for model {mod.name}: {[type(m).__name__ for m in matches]}. Using {type(matches[0]).__name__}." - ) - - instance = matches[0] - logger.info(f"Model {mod.name} classified as {type(instance).__name__}") - return instance diff --git a/invokeai/backend/model_manager/configs/base.py b/invokeai/backend/model_manager/configs/base.py index 9e997e4bcd..e67efd2009 100644 --- a/invokeai/backend/model_manager/configs/base.py +++ b/invokeai/backend/model_manager/configs/base.py @@ -191,8 +191,8 @@ class Config_Base(ABC, BaseModel): else: raise TypeError("Model config discriminator value must be computed from a dict or ModelConfigBase instance") - @abstractmethod @classmethod + @abstractmethod def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: """Given the model on disk and any override fields, attempt to construct an instance of this config class. diff --git a/invokeai/backend/model_manager/configs/controlnet.py b/invokeai/backend/model_manager/configs/controlnet.py index 7db782a862..630e81fd24 100644 --- a/invokeai/backend/model_manager/configs/controlnet.py +++ b/invokeai/backend/model_manager/configs/controlnet.py @@ -3,14 +3,13 @@ from typing import ( Self, ) -from pydantic import Field +from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Any from invokeai.backend.flux.controlnet.state_dict_utils import ( is_state_dict_instantx_controlnet, is_state_dict_xlabs_controlnet, ) -from invokeai.backend.model_manager.config import ControlAdapterDefaultSettings from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Config_Base, Diffusers_Config_Base from invokeai.backend.model_manager.configs.identification_utils import ( NotAMatchError, @@ -29,6 +28,42 @@ from invokeai.backend.model_manager.taxonomy import ( ModelType, ) +MODEL_NAME_TO_PREPROCESSOR = { + "canny": "canny_image_processor", + "mlsd": "mlsd_image_processor", + "depth": "depth_anything_image_processor", + "bae": "normalbae_image_processor", + "normal": "normalbae_image_processor", + "sketch": "pidi_image_processor", + "scribble": "lineart_image_processor", + "lineart anime": "lineart_anime_image_processor", + "lineart_anime": "lineart_anime_image_processor", + "lineart": "lineart_image_processor", + "soft": "hed_image_processor", + "softedge": "hed_image_processor", + "hed": "hed_image_processor", + "shuffle": "content_shuffle_image_processor", + "pose": "dw_openpose_image_processor", + "mediapipe": "mediapipe_face_processor", + "pidi": "pidi_image_processor", + "zoe": "zoe_depth_image_processor", + "color": "color_map_image_processor", +} + + +class ControlAdapterDefaultSettings(BaseModel): + # This could be narrowed to controlnet processor nodes, but they change. Leaving this a string is safer. + preprocessor: str | None + model_config = ConfigDict(extra="forbid") + + @classmethod + def from_model_name(cls, model_name: str) -> Self: + for k, v in MODEL_NAME_TO_PREPROCESSOR.items(): + model_name_lower = model_name.lower() + if k in model_name_lower: + return cls(preprocessor=v) + return cls(preprocessor=None) + class ControlNet_Diffusers_Config_Base(Diffusers_Config_Base): """Model config for ControlNet models (diffusers version).""" diff --git a/invokeai/backend/model_manager/configs/factory.py b/invokeai/backend/model_manager/configs/factory.py index 27b6f252f1..6ab16cd5f6 100644 --- a/invokeai/backend/model_manager/configs/factory.py +++ b/invokeai/backend/model_manager/configs/factory.py @@ -14,6 +14,7 @@ from invokeai.backend.model_manager.configs.base import Config_Base from invokeai.backend.model_manager.configs.clip_embed import CLIPEmbed_Diffusers_G_Config, CLIPEmbed_Diffusers_L_Config from invokeai.backend.model_manager.configs.clip_vision import CLIPVision_Diffusers_Config from invokeai.backend.model_manager.configs.controlnet import ( + ControlAdapterDefaultSettings, ControlNet_Checkpoint_FLUX_Config, ControlNet_Checkpoint_SD1_Config, ControlNet_Checkpoint_SD2_Config, @@ -47,6 +48,7 @@ from invokeai.backend.model_manager.configs.lora import ( LoRA_LyCORIS_SDXL_Config, LoRA_OMI_FLUX_Config, LoRA_OMI_SDXL_Config, + LoraModelDefaultSettings, ) from invokeai.backend.model_manager.configs.main import ( Main_BnBNF4_FLUX_Config, @@ -67,6 +69,7 @@ from invokeai.backend.model_manager.configs.main import ( Main_ExternalAPI_Imagen3_Config, Main_ExternalAPI_Imagen4_Config, Main_GGUF_FLUX_Config, + MainModelDefaultSettings, Video_ExternalAPI_Runway_Config, Video_ExternalAPI_Veo3_Config, ) @@ -332,9 +335,52 @@ class ModelConfigFactory: matches.sort(key=sort_key) logger.warning( - f"Multiple model config classes matched for model {mod.name}: {[type(m).__name__ for m in matches]}. Using {type(matches[0]).__name__}." + f"Multiple model config classes matched for model {mod.name}: {[type(m).__name__ for m in matches]}." ) instance = matches[0] logger.info(f"Model {mod.name} classified as {type(instance).__name__}") + + # Now do any post-processing needed for specific model types/bases/etc. + match instance.type: + case ModelType.Main: + match instance.base: + case BaseModelType.StableDiffusion1: + instance.default_settings = MainModelDefaultSettings(width=512, height=512) + case BaseModelType.StableDiffusion2: + instance.default_settings = MainModelDefaultSettings(width=768, height=768) + case BaseModelType.StableDiffusionXL: + instance.default_settings = MainModelDefaultSettings(width=1024, height=1024) + case _: + pass + case ModelType.ControlNet | ModelType.T2IAdapter | ModelType.ControlLoRa: + instance.default_settings = ControlAdapterDefaultSettings.from_model_name(instance.name) + case ModelType.LoRA: + instance.default_settings = LoraModelDefaultSettings() + case _: + pass + return instance + + +MODEL_NAME_TO_PREPROCESSOR = { + "canny": "canny_image_processor", + "mlsd": "mlsd_image_processor", + "depth": "depth_anything_image_processor", + "bae": "normalbae_image_processor", + "normal": "normalbae_image_processor", + "sketch": "pidi_image_processor", + "scribble": "lineart_image_processor", + "lineart anime": "lineart_anime_image_processor", + "lineart_anime": "lineart_anime_image_processor", + "lineart": "lineart_image_processor", + "soft": "hed_image_processor", + "softedge": "hed_image_processor", + "hed": "hed_image_processor", + "shuffle": "content_shuffle_image_processor", + "pose": "dw_openpose_image_processor", + "mediapipe": "mediapipe_face_processor", + "pidi": "pidi_image_processor", + "zoe": "zoe_depth_image_processor", + "color": "color_map_image_processor", +} diff --git a/invokeai/backend/model_manager/configs/lora.py b/invokeai/backend/model_manager/configs/lora.py index 512137b4c3..24e10c035a 100644 --- a/invokeai/backend/model_manager/configs/lora.py +++ b/invokeai/backend/model_manager/configs/lora.py @@ -9,10 +9,10 @@ from typing import ( from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Any -from invokeai.backend.model_manager.config import ControlAdapterDefaultSettings from invokeai.backend.model_manager.configs.base import ( Config_Base, ) +from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings from invokeai.backend.model_manager.configs.identification_utils import ( NotAMatchError, raise_for_override_fields, diff --git a/invokeai/backend/model_manager/configs/main.py b/invokeai/backend/model_manager/configs/main.py index ef1ab1fe77..26f6b5b60e 100644 --- a/invokeai/backend/model_manager/configs/main.py +++ b/invokeai/backend/model_manager/configs/main.py @@ -685,8 +685,8 @@ class Video_Config_Base(ABC, BaseModel): class Video_ExternalAPI_Veo3_Config(ExternalAPI_Config_Base, Video_Config_Base, Config_Base): - base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext) + base: Literal[BaseModelType.Veo3] = Field(default=BaseModelType.Veo3) class Video_ExternalAPI_Runway_Config(ExternalAPI_Config_Base, Video_Config_Base, Config_Base): - base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext) + base: Literal[BaseModelType.Runway] = Field(default=BaseModelType.Runway) diff --git a/invokeai/backend/model_manager/configs/t2i_adapter.py b/invokeai/backend/model_manager/configs/t2i_adapter.py index 865c4dc763..a1da40e9b4 100644 --- a/invokeai/backend/model_manager/configs/t2i_adapter.py +++ b/invokeai/backend/model_manager/configs/t2i_adapter.py @@ -6,8 +6,8 @@ from typing import ( from pydantic import Field from typing_extensions import Any -from invokeai.backend.model_manager.config import ControlAdapterDefaultSettings from invokeai.backend.model_manager.configs.base import Config_Base, Diffusers_Config_Base +from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings from invokeai.backend.model_manager.configs.identification_utils import ( NotAMatchError, common_config_paths, diff --git a/invokeai/backend/model_manager/legacy_probe.py b/invokeai/backend/model_manager/legacy_probe.py deleted file mode 100644 index 85a39fd25e..0000000000 --- a/invokeai/backend/model_manager/legacy_probe.py +++ /dev/null @@ -1,1034 +0,0 @@ -import json -from pathlib import Path -from typing import Any, Callable, Dict, Literal, Optional, Union - -import picklescan.scanner as pscan -import safetensors.torch -import torch - -import invokeai.backend.util.logging as logger -from invokeai.app.services.config.config_default import get_config -from invokeai.app.util.misc import uuid_string -from invokeai.backend.flux.controlnet.state_dict_utils import ( - is_state_dict_instantx_controlnet, - is_state_dict_xlabs_controlnet, -) -from invokeai.backend.flux.flux_state_dict_utils import get_flux_in_channels_from_state_dict -from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter -from invokeai.backend.flux.redux.flux_redux_state_dict_utils import is_state_dict_likely_flux_redux -from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - ControlAdapterDefaultSettings, - InvalidModelConfigException, - LoraModelDefaultSettings, - MainModelDefaultSettings, - ModelConfigFactory, - SubmodelDefinition, -) -from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import ConfigLoader -from invokeai.backend.model_manager.model_on_disk import ModelOnDisk -from invokeai.backend.model_manager.taxonomy import ( - AnyVariant, - BaseModelType, - FluxVariantType, - ModelFormat, - ModelRepoVariant, - ModelSourceType, - ModelType, - ModelVariantType, - SchedulerPredictionType, - SubModelType, -) -from invokeai.backend.model_manager.util.model_util import ( - get_clip_variant_type, - lora_token_vector_length, - read_checkpoint_meta, -) -from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control -from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import ( - is_state_dict_likely_in_flux_diffusers_format, -) -from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import ( - is_state_dict_likely_in_flux_kohya_format, -) -from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import ( - is_state_dict_likely_in_flux_onetrainer_format, -) -from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor -from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader -from invokeai.backend.util.silence_warnings import SilenceWarnings - -CkptType = Dict[str | int, Any] - - -LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[SchedulerPredictionType, str]]]] = { - BaseModelType.StableDiffusion1: { - ModelVariantType.Normal: { - SchedulerPredictionType.Epsilon: "v1-inference.yaml", - SchedulerPredictionType.VPrediction: "v1-inference-v.yaml", - }, - ModelVariantType.Inpaint: "v1-inpainting-inference.yaml", - }, - BaseModelType.StableDiffusion2: { - ModelVariantType.Normal: { - SchedulerPredictionType.Epsilon: "v2-inference.yaml", - SchedulerPredictionType.VPrediction: "v2-inference-v.yaml", - }, - ModelVariantType.Inpaint: { - SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml", - SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml", - }, - ModelVariantType.Depth: "v2-midas-inference.yaml", - }, - BaseModelType.StableDiffusionXL: { - ModelVariantType.Normal: "sd_xl_base.yaml", - ModelVariantType.Inpaint: "sd_xl_inpaint.yaml", - }, - BaseModelType.StableDiffusionXLRefiner: { - ModelVariantType.Normal: "sd_xl_refiner.yaml", - }, -} - - -class ProbeBase(object): - """Base class for probes.""" - - def __init__(self, model_path: Path): - self.model_path = model_path - - def get_base_type(self) -> BaseModelType: - """Get model base type.""" - raise NotImplementedError - - def get_format(self) -> ModelFormat: - """Get model file format.""" - raise NotImplementedError - - def get_variant_type(self) -> AnyVariant | None: - """Get model variant type.""" - return None - - def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]: - """Get model scheduler prediction type.""" - return None - - def get_image_encoder_model_id(self) -> Optional[str]: - """Get image encoder (IP adapters only).""" - return None - - -class ModelProbe(object): - PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = { - "diffusers": {}, - "checkpoint": {}, - "onnx": {}, - } - - CLASS2TYPE = { - "FluxPipeline": ModelType.Main, - "StableDiffusionPipeline": ModelType.Main, - "StableDiffusionInpaintPipeline": ModelType.Main, - "StableDiffusionXLPipeline": ModelType.Main, - "StableDiffusionXLImg2ImgPipeline": ModelType.Main, - "StableDiffusionXLInpaintPipeline": ModelType.Main, - "StableDiffusion3Pipeline": ModelType.Main, - "LatentConsistencyModelPipeline": ModelType.Main, - "AutoencoderKL": ModelType.VAE, - "AutoencoderTiny": ModelType.VAE, - "ControlNetModel": ModelType.ControlNet, - "CLIPVisionModelWithProjection": ModelType.CLIPVision, - "T2IAdapter": ModelType.T2IAdapter, - "CLIPModel": ModelType.CLIPEmbed, - "CLIPTextModel": ModelType.CLIPEmbed, - "T5EncoderModel": ModelType.T5Encoder, - "FluxControlNetModel": ModelType.ControlNet, - "SD3Transformer2DModel": ModelType.Main, - "CLIPTextModelWithProjection": ModelType.CLIPEmbed, - "SiglipModel": ModelType.SigLIP, - "LlavaOnevisionForConditionalGeneration": ModelType.LlavaOnevision, - "CogView4Pipeline": ModelType.Main, - } - - TYPE2VARIANT: Dict[ModelType, Callable[[str], Optional[AnyVariant]]] = {ModelType.CLIPEmbed: get_clip_variant_type} - - @classmethod - def register_probe( - cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: type[ProbeBase] - ) -> None: - cls.PROBES[format][model_type] = probe_class - - @classmethod - def probe( - cls, model_path: Path, fields: Optional[Dict[str, Any]] = None, hash_algo: HASHING_ALGORITHMS = "blake3_single" - ) -> AnyModelConfig: - """ - Probe the model at model_path and return its configuration record. - - :param model_path: Path to the model file (checkpoint) or directory (diffusers). - :param fields: An optional dictionary that can be used to override probed - fields. Typically used for fields that don't probe well, such as prediction_type. - - Returns: The appropriate model configuration derived from ModelConfigBase. - """ - if fields is None: - fields = {} - - model_path = model_path.resolve() - - format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint - model_info = None - model_type = ModelType(fields["type"]) if "type" in fields and fields["type"] else None - if not model_type: - if format_type is ModelFormat.Diffusers: - model_type = cls.get_model_type_from_folder(model_path) - else: - model_type = cls.get_model_type_from_checkpoint(model_path) - format_type = ModelFormat.ONNX if model_type == ModelType.ONNX else format_type - - probe_class = cls.PROBES[format_type].get(model_type) - if not probe_class: - raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}") - - probe = probe_class(model_path) - - fields["source_type"] = fields.get("source_type") or ModelSourceType.Path - fields["source"] = fields.get("source") or model_path.as_posix() - fields["key"] = fields.get("key", uuid_string()) - fields["path"] = model_path.as_posix() - fields["type"] = fields.get("type") or model_type - fields["base"] = fields.get("base") or probe.get_base_type() - variant_func = cls.TYPE2VARIANT.get(fields["type"], None) - fields["variant"] = ( - fields.get("variant") or (variant_func and variant_func(model_path.as_posix())) or probe.get_variant_type() - ) - fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type() - fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id() - fields["name"] = fields.get("name") or cls.get_model_name(model_path) - fields["description"] = ( - fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}" - ) - fields["format"] = ModelFormat(fields.get("format")) if "format" in fields else probe.get_format() - fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path) - fields["file_size"] = fields.get("file_size") or ModelOnDisk(model_path).size() - - fields["default_settings"] = fields.get("default_settings") - - if not fields["default_settings"]: - if fields["type"] in {ModelType.ControlNet, ModelType.T2IAdapter, ModelType.ControlLoRa}: - fields["default_settings"] = get_default_settings_control_adapters(fields["name"]) - if fields["type"] in {ModelType.LoRA}: - fields["default_settings"] = get_default_settings_lora() - elif fields["type"] is ModelType.Main: - fields["default_settings"] = get_default_settings_main(fields["base"]) - - if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase): - fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant() - - # additional fields needed for main and controlnet models - if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] and fields["format"] in [ - ModelFormat.Checkpoint, - ModelFormat.BnbQuantizednf4b, - ModelFormat.GGUFQuantized, - ]: - ckpt_config_path = cls._get_checkpoint_config_path( - model_path, - model_type=fields["type"], - base_type=fields["base"], - variant_type=fields["variant"], - prediction_type=fields["prediction_type"], - ) - fields["config_path"] = str(ckpt_config_path) - - # additional fields needed for main non-checkpoint models - elif fields["type"] == ModelType.Main and fields["format"] in [ - ModelFormat.ONNX, - ModelFormat.Olive, - ModelFormat.Diffusers, - ]: - fields["upcast_attention"] = fields.get("upcast_attention") or ( - fields["base"] == BaseModelType.StableDiffusion2 - and fields["prediction_type"] == SchedulerPredictionType.VPrediction - ) - - get_submodels = getattr(probe, "get_submodels", None) - if fields["base"] == BaseModelType.StableDiffusion3 and callable(get_submodels): - fields["submodels"] = get_submodels() - - model_info = ModelConfigFactory.make_config(fields) - return model_info - - @classmethod - def get_model_name(cls, model_path: Path) -> str: - if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}: - return model_path.stem - else: - return model_path.name - - @classmethod - def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[CkptType] = None) -> ModelType: - if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth", ".gguf"): - raise InvalidModelConfigException(f"{model_path}: unrecognized suffix") - - if model_path.name == "learned_embeds.bin": - return ModelType.TextualInversion - - ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True) - ckpt = ckpt.get("state_dict", ckpt) - - if isinstance(ckpt, dict) and is_state_dict_likely_flux_control(ckpt): - return ModelType.ControlLoRa - - if isinstance(ckpt, dict) and is_state_dict_likely_flux_redux(ckpt): - return ModelType.FluxRedux - - for key in [str(k) for k in ckpt.keys()]: - if key.startswith( - ( - "cond_stage_model.", - "first_stage_model.", - "model.diffusion_model.", - # Some FLUX checkpoint files contain transformer keys prefixed with "model.diffusion_model". - # This prefix is typically used to distinguish between multiple models bundled in a single file. - "model.diffusion_model.double_blocks.", - ) - ): - # Keys starting with double_blocks are associated with Flux models - return ModelType.Main - # FLUX models in the official BFL format contain keys with the "double_blocks." prefix, but we must be - # careful to avoid false positives on XLabs FLUX IP-Adapter models. - elif key.startswith("double_blocks.") and "ip_adapter" not in key: - return ModelType.Main - elif key.startswith(("encoder.conv_in", "decoder.conv_in")): - return ModelType.VAE - elif key.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")): - return ModelType.LoRA - # "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT - # LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models. - elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")): - return ModelType.LoRA - elif key.startswith( - ( - "controlnet", - "control_model", - "input_blocks", - # XLabs FLUX ControlNet models have keys starting with "controlnet_blocks." - # For example: https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors - # TODO(ryand): This is very fragile. XLabs FLUX ControlNet models also contain keys starting with - # "double_blocks.", which we check for above. But, I'm afraid to modify this logic because it is so - # delicate. - "controlnet_blocks", - ) - ): - return ModelType.ControlNet - elif key.startswith( - ( - "image_proj.", - "ip_adapter.", - # XLabs FLUX IP-Adapter models have keys startinh with "ip_adapter_proj_model.". - "ip_adapter_proj_model.", - ) - ): - return ModelType.IPAdapter - elif key in {"emb_params", "string_to_param"}: - return ModelType.TextualInversion - - # diffusers-ti - if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()): - return ModelType.TextualInversion - - raise InvalidModelConfigException(f"Unable to determine model type for {model_path}") - - @classmethod - def get_model_type_from_folder(cls, folder_path: Path) -> ModelType: - """Get the model type of a hugging-face style folder.""" - class_name = None - error_hint = None - for suffix in ["bin", "safetensors"]: - if (folder_path / f"learned_embeds.{suffix}").exists(): - return ModelType.TextualInversion - if (folder_path / f"pytorch_lora_weights.{suffix}").exists(): - return ModelType.LoRA - if (folder_path / "unet/model.onnx").exists(): - return ModelType.ONNX - if (folder_path / "image_encoder.txt").exists(): - return ModelType.IPAdapter - - config_path = None - for p in [ - folder_path / "model_index.json", # pipeline - folder_path / "config.json", # most diffusers - folder_path / "text_encoder_2" / "config.json", # T5 text encoder - folder_path / "text_encoder" / "config.json", # T5 CLIP - ]: - if p.exists(): - config_path = p - break - - if config_path: - with open(config_path, "r") as file: - conf = json.load(file) - if "_class_name" in conf: - class_name = conf["_class_name"] - elif "architectures" in conf: - class_name = conf["architectures"][0] - else: - class_name = None - else: - error_hint = f"No model_index.json or config.json found in {folder_path}." - - if class_name and (type := cls.CLASS2TYPE.get(class_name)): - return type - else: - error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]" - - # give up - raise InvalidModelConfigException( - f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "") - ) - - @classmethod - def _get_checkpoint_config_path( - cls, - model_path: Path, - model_type: ModelType, - base_type: BaseModelType, - variant_type: ModelVariantType, - prediction_type: SchedulerPredictionType, - ) -> Path: - # look for a YAML file adjacent to the model file first - possible_conf = model_path.with_suffix(".yaml") - if possible_conf.exists(): - return possible_conf.absolute() - - if model_type is ModelType.Main: - if base_type == BaseModelType.Flux: - # TODO: Decide between dev/schnell - checkpoint = ModelProbe._scan_and_load_checkpoint(model_path) - state_dict = checkpoint.get("state_dict") or checkpoint - - # HACK: For FLUX, config_file is used as a key into invokeai.backend.flux.util.params during model - # loading. When FLUX support was first added, it was decided that this was the easiest way to support - # the various FLUX formats rather than adding new model types/formats. Be careful when modifying this in - # the future. - if ( - "guidance_in.out_layer.weight" in state_dict - or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict - ): - if variant_type == ModelVariantType.Normal: - config_file = "flux-dev" - elif variant_type == ModelVariantType.Inpaint: - config_file = "flux-dev-fill" - else: - raise ValueError(f"Unexpected FLUX variant type: {variant_type}") - else: - config_file = "flux-schnell" - else: - config_file = LEGACY_CONFIGS[base_type][variant_type] - if isinstance(config_file, dict): # need another tier for sd-2.x models - config_file = config_file[prediction_type] - config_file = f"stable-diffusion/{config_file}" - elif model_type is ModelType.ControlNet: - config_file = ( - "controlnet/cldm_v15.yaml" - if base_type is BaseModelType.StableDiffusion1 - else "controlnet/cldm_v21.yaml" - ) - elif model_type is ModelType.VAE: - config_file = ( - # For flux, this is a key in invokeai.backend.flux.util.ae_params - # Due to model type and format being the descriminator for model configs this - # is used rather than attempting to support flux with separate model types and format - # If changed in the future, please fix me - "flux" - if base_type is BaseModelType.Flux - else "stable-diffusion/v1-inference.yaml" - if base_type is BaseModelType.StableDiffusion1 - else "stable-diffusion/sd_xl_base.yaml" - if base_type is BaseModelType.StableDiffusionXL - else "stable-diffusion/v2-inference.yaml" - ) - else: - raise InvalidModelConfigException( - f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}" - ) - return Path(config_file) - - @classmethod - def _scan_and_load_checkpoint(cls, model_path: Path) -> CkptType: - with SilenceWarnings(): - if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")): - cls._scan_model(model_path.name, model_path) - model = torch.load(model_path, map_location="cpu") - assert isinstance(model, dict) - return model - elif model_path.suffix.endswith(".gguf"): - return gguf_sd_loader(model_path, compute_dtype=torch.float32) - else: - return safetensors.torch.load_file(model_path) - - @classmethod - def _scan_model(cls, model_name: str, checkpoint: Path) -> None: - """ - Apply picklescanner to the indicated checkpoint and issue a warning - and option to exit if an infected file is identified. - """ - # scan model - scan_result = pscan.scan_file_path(checkpoint) - if scan_result.infected_files != 0: - if get_config().unsafe_disable_picklescan: - logger.warning( - f"The model {model_name} is potentially infected by malware, but picklescan is disabled. " - "Proceeding with caution." - ) - else: - raise RuntimeError(f"The model {model_name} is potentially infected by malware. Aborting import.") - if scan_result.scan_err: - if get_config().unsafe_disable_picklescan: - logger.warning( - f"Error scanning the model at {model_name} for malware, but picklescan is disabled. " - "Proceeding with caution." - ) - else: - raise RuntimeError(f"Error scanning the model at {model_name} for malware. Aborting import.") - - -# Probing utilities -MODEL_NAME_TO_PREPROCESSOR = { - "canny": "canny_image_processor", - "mlsd": "mlsd_image_processor", - "depth": "depth_anything_image_processor", - "bae": "normalbae_image_processor", - "normal": "normalbae_image_processor", - "sketch": "pidi_image_processor", - "scribble": "lineart_image_processor", - "lineart anime": "lineart_anime_image_processor", - "lineart_anime": "lineart_anime_image_processor", - "lineart": "lineart_image_processor", - "soft": "hed_image_processor", - "softedge": "hed_image_processor", - "hed": "hed_image_processor", - "shuffle": "content_shuffle_image_processor", - "pose": "dw_openpose_image_processor", - "mediapipe": "mediapipe_face_processor", - "pidi": "pidi_image_processor", - "zoe": "zoe_depth_image_processor", - "color": "color_map_image_processor", -} - - -def get_default_settings_control_adapters(model_name: str) -> Optional[ControlAdapterDefaultSettings]: - for k, v in MODEL_NAME_TO_PREPROCESSOR.items(): - model_name_lower = model_name.lower() - if k in model_name_lower: - return ControlAdapterDefaultSettings(preprocessor=v) - return None - - -def get_default_settings_lora() -> LoraModelDefaultSettings: - return LoraModelDefaultSettings() - - -def get_default_settings_main(model_base: BaseModelType) -> Optional[MainModelDefaultSettings]: - if model_base is BaseModelType.StableDiffusion1 or model_base is BaseModelType.StableDiffusion2: - return MainModelDefaultSettings(width=512, height=512) - elif model_base is BaseModelType.StableDiffusionXL: - return MainModelDefaultSettings(width=1024, height=1024) - # We don't provide defaults for BaseModelType.StableDiffusionXLRefiner, as they are not standalone models. - return None - - -# ##################################################3 -# Checkpoint probing -# ##################################################3 - - -class CheckpointProbeBase(ProbeBase): - def __init__(self, model_path: Path): - super().__init__(model_path) - self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path) - - def get_format(self) -> ModelFormat: - state_dict = self.checkpoint.get("state_dict") or self.checkpoint - if ( - "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict - or "model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict - ): - return ModelFormat.BnbQuantizednf4b - elif any(isinstance(v, GGMLTensor) for v in state_dict.values()): - return ModelFormat.GGUFQuantized - return ModelFormat("checkpoint") - - def get_variant_type(self) -> AnyVariant: - model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint) - base_type = self.get_base_type() - if model_type != ModelType.Main: - return ModelVariantType.Normal - state_dict = self.checkpoint.get("state_dict") or self.checkpoint - - if base_type == BaseModelType.Flux: - in_channels = get_flux_in_channels_from_state_dict(state_dict) - - if in_channels is None: - # If we cannot find the in_channels, we assume that this is a normal variant. Log a warning. - logger.warning( - f"{self.model_path} does not have img_in.weight or model.diffusion_model.img_in.weight key. Assuming normal variant." - ) - return ModelVariantType.Normal - - is_flux_dev = ( - "guidance_in.out_layer.weight" in state_dict - or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict - ) - - # FLUX Model variant types are distinguished by input channels: - # - Unquantized Dev and Schnell have in_channels=64 - # - BNB-NF4 Dev and Schnell have in_channels=1 - # - FLUX Fill has in_channels=384 - # - Unsure of quantized FLUX Fill models - # - Unsure of GGUF-quantized models - if is_flux_dev and in_channels == 384: - # This is a FLUX Fill model. FLUX Fill needs special handling throughout the application. The variant - # type is used to determine whether to use the fill model or the base model. - return FluxVariantType.DevFill - elif is_flux_dev: - # Fall back on "normal" variant type for all other FLUX models. - return FluxVariantType.Dev - else: - return FluxVariantType.Schnell - - in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1] - if in_channels == 9: - return ModelVariantType.Inpaint - elif in_channels == 5: - return ModelVariantType.Depth - elif in_channels == 4: - return ModelVariantType.Normal - else: - raise InvalidModelConfigException( - f"Cannot determine variant type (in_channels={in_channels}) at {self.model_path}" - ) - - -class PipelineCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - checkpoint = self.checkpoint - state_dict = self.checkpoint.get("state_dict") or checkpoint - if ( - "double_blocks.0.img_attn.norm.key_norm.scale" in state_dict - or "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict - ): - return BaseModelType.Flux - key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" - if key_name in state_dict and state_dict[key_name].shape[-1] == 768: - return BaseModelType.StableDiffusion1 - if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: - return BaseModelType.StableDiffusion2 - key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight" - if key_name in state_dict and state_dict[key_name].shape[-1] == 2048: - return BaseModelType.StableDiffusionXL - elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280: - return BaseModelType.StableDiffusionXLRefiner - else: - raise InvalidModelConfigException("Cannot determine base type") - - def get_scheduler_prediction_type(self) -> SchedulerPredictionType: - """Return model prediction type.""" - type = self.get_base_type() - if type == BaseModelType.StableDiffusion2: - checkpoint = self.checkpoint - state_dict = self.checkpoint.get("state_dict") or checkpoint - key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" - if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: - if "global_step" in checkpoint: - if checkpoint["global_step"] == 220000: - return SchedulerPredictionType.Epsilon - elif checkpoint["global_step"] == 110000: - return SchedulerPredictionType.VPrediction - return SchedulerPredictionType.VPrediction # a guess for sd2 ckpts - - elif type == BaseModelType.StableDiffusion1: - return SchedulerPredictionType.Epsilon # a reasonable guess for sd1 ckpts - else: - return SchedulerPredictionType.Epsilon - - -class LoRACheckpointProbe(CheckpointProbeBase): - """Class for LoRA checkpoints.""" - - def get_format(self) -> ModelFormat: - if is_state_dict_likely_in_flux_diffusers_format(self.checkpoint): - # TODO(ryand): This is an unusual case. In other places throughout the codebase, we treat - # ModelFormat.Diffusers as meaning that the model is in a directory. In this case, the model is a single - # file, but the weight keys are in the diffusers format. - return ModelFormat.Diffusers - return ModelFormat.LyCORIS - - def get_base_type(self) -> BaseModelType: - if ( - is_state_dict_likely_in_flux_kohya_format(self.checkpoint) - or is_state_dict_likely_in_flux_onetrainer_format(self.checkpoint) - or is_state_dict_likely_in_flux_diffusers_format(self.checkpoint) - or is_state_dict_likely_flux_control(self.checkpoint) - ): - return BaseModelType.Flux - - # If we've gotten here, we assume that the model is a Stable Diffusion model. - token_vector_length = lora_token_vector_length(self.checkpoint) - if token_vector_length == 768: - return BaseModelType.StableDiffusion1 - elif token_vector_length == 1024: - return BaseModelType.StableDiffusion2 - elif token_vector_length == 1280: - return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641 - elif token_vector_length == 2048: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelConfigException(f"Unknown LoRA type: {self.model_path}") - - -class ControlNetCheckpointProbe(CheckpointProbeBase): - """Class for probing controlnets.""" - - def get_base_type(self) -> BaseModelType: - checkpoint = self.checkpoint - if is_state_dict_xlabs_controlnet(checkpoint) or is_state_dict_instantx_controlnet(checkpoint): - # TODO(ryand): Should I distinguish between XLabs, InstantX and other ControlNet models by implementing - # get_format()? - return BaseModelType.Flux - - for key_name in ( - "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", - "controlnet_mid_block.bias", - "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", - "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", - ): - if key_name not in checkpoint: - continue - width = checkpoint[key_name].shape[-1] - if width == 768: - return BaseModelType.StableDiffusion1 - elif width == 1024: - return BaseModelType.StableDiffusion2 - elif width == 2048: - return BaseModelType.StableDiffusionXL - elif width == 1280: - return BaseModelType.StableDiffusionXL - raise InvalidModelConfigException(f"{self.model_path}: Unable to determine base type") - - -class IPAdapterCheckpointProbe(CheckpointProbeBase): - """Class for probing IP Adapters""" - - def get_base_type(self) -> BaseModelType: - checkpoint = self.checkpoint - - if is_state_dict_xlabs_ip_adapter(checkpoint): - return BaseModelType.Flux - - for key in checkpoint.keys(): - if not key.startswith(("image_proj.", "ip_adapter.")): - continue - cross_attention_dim = checkpoint["ip_adapter.1.to_k_ip.weight"].shape[-1] - if cross_attention_dim == 768: - return BaseModelType.StableDiffusion1 - elif cross_attention_dim == 1024: - return BaseModelType.StableDiffusion2 - elif cross_attention_dim == 2048: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelConfigException( - f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}." - ) - raise InvalidModelConfigException(f"{self.model_path}: Unable to determine base type") - - -class CLIPVisionCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -class T2IAdapterCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -class SpandrelImageToImageCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - return BaseModelType.Any - - -class SigLIPCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -class FluxReduxCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - return BaseModelType.Flux - - -class LlavaOnevisionCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -######################################################## -# classes for probing folders -####################################################### -class FolderProbeBase(ProbeBase): - def get_variant_type(self) -> ModelVariantType: - return ModelVariantType.Normal - - def get_format(self) -> ModelFormat: - return ModelFormat("diffusers") - - def get_repo_variant(self) -> ModelRepoVariant: - # get all files ending in .bin or .safetensors - weight_files = list(self.model_path.glob("**/*.safetensors")) - weight_files.extend(list(self.model_path.glob("**/*.bin"))) - for x in weight_files: - if ".fp16" in x.suffixes: - return ModelRepoVariant.FP16 - if "openvino_model" in x.name: - return ModelRepoVariant.OpenVINO - if "flax_model" in x.name: - return ModelRepoVariant.Flax - if x.suffix == ".onnx": - return ModelRepoVariant.ONNX - return ModelRepoVariant.Default - - -class PipelineFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - # Handle pipelines with a UNet (i.e SD 1.x, SD2, SDXL). - config_path = self.model_path / "unet" / "config.json" - if config_path.exists(): - with open(config_path) as file: - unet_conf = json.load(file) - if unet_conf["cross_attention_dim"] == 768: - return BaseModelType.StableDiffusion1 - elif unet_conf["cross_attention_dim"] == 1024: - return BaseModelType.StableDiffusion2 - elif unet_conf["cross_attention_dim"] == 1280: - return BaseModelType.StableDiffusionXLRefiner - elif unet_conf["cross_attention_dim"] == 2048: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelConfigException(f"Unknown base model for {self.model_path}") - - # Handle pipelines with a transformer (i.e. SD3). - config_path = self.model_path / "transformer" / "config.json" - if config_path.exists(): - with open(config_path) as file: - transformer_conf = json.load(file) - if transformer_conf["_class_name"] == "SD3Transformer2DModel": - return BaseModelType.StableDiffusion3 - elif transformer_conf["_class_name"] == "CogView4Transformer2DModel": - return BaseModelType.CogView4 - else: - raise InvalidModelConfigException(f"Unknown base model for {self.model_path}") - - raise InvalidModelConfigException(f"Unknown base model for {self.model_path}") - - def get_scheduler_prediction_type(self) -> SchedulerPredictionType: - with open(self.model_path / "scheduler" / "scheduler_config.json", "r") as file: - scheduler_conf = json.load(file) - if scheduler_conf.get("prediction_type", "epsilon") == "v_prediction": - return SchedulerPredictionType.VPrediction - elif scheduler_conf.get("prediction_type", "epsilon") == "epsilon": - return SchedulerPredictionType.Epsilon - else: - raise InvalidModelConfigException("Unknown scheduler prediction type: {scheduler_conf['prediction_type']}") - - def get_submodels(self) -> Dict[SubModelType, SubmodelDefinition]: - config = ConfigLoader.load_config(self.model_path, config_name="model_index.json") - submodels: Dict[SubModelType, SubmodelDefinition] = {} - for key, value in config.items(): - if key.startswith("_") or not (isinstance(value, list) and len(value) == 2): - continue - model_loader = str(value[1]) - if model_type := ModelProbe.CLASS2TYPE.get(model_loader): - variant_func = ModelProbe.TYPE2VARIANT.get(model_type, None) - submodels[SubModelType(key)] = SubmodelDefinition( - path_or_prefix=(self.model_path / key).resolve().as_posix(), - model_type=model_type, - variant=variant_func and variant_func((self.model_path / key).as_posix()), - ) - - return submodels - - def get_variant_type(self) -> ModelVariantType: - # This only works for pipelines! Any kind of - # exception results in our returning the - # "normal" variant type - try: - config_file = self.model_path / "unet" / "config.json" - with open(config_file, "r") as file: - conf = json.load(file) - - in_channels = conf["in_channels"] - if in_channels == 9: - return ModelVariantType.Inpaint - elif in_channels == 5: - return ModelVariantType.Depth - elif in_channels == 4: - return ModelVariantType.Normal - except Exception: - pass - return ModelVariantType.Normal - - -class ONNXFolderProbe(PipelineFolderProbe): - def get_base_type(self) -> BaseModelType: - # Due to the way the installer is set up, the configuration file for safetensors - # will come along for the ride if both the onnx and safetensors forms - # share the same directory. We take advantage of this here. - if (self.model_path / "unet" / "config.json").exists(): - return super().get_base_type() - else: - logger.warning('Base type probing is not implemented for ONNX models. Assuming "sd-1"') - return BaseModelType.StableDiffusion1 - - def get_format(self) -> ModelFormat: - return ModelFormat("onnx") - - def get_variant_type(self) -> ModelVariantType: - return ModelVariantType.Normal - - -class ControlNetFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - config_file = self.model_path / "config.json" - if not config_file.exists(): - raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}") - with open(config_file, "r") as file: - config = json.load(file) - - if config.get("_class_name", None) == "FluxControlNetModel": - return BaseModelType.Flux - - # no obvious way to distinguish between sd2-base and sd2-768 - dimension = config["cross_attention_dim"] - if dimension == 768: - return BaseModelType.StableDiffusion1 - if dimension == 1024: - return BaseModelType.StableDiffusion2 - if dimension == 2048: - return BaseModelType.StableDiffusionXL - raise InvalidModelConfigException(f"Unable to determine model base for {self.model_path}") - - -class LoRAFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - model_file = None - for suffix in ["safetensors", "bin"]: - base_file = self.model_path / f"pytorch_lora_weights.{suffix}" - if base_file.exists(): - model_file = base_file - break - if not model_file: - raise InvalidModelConfigException("Unknown LoRA format encountered") - return LoRACheckpointProbe(model_file).get_base_type() - - -class IPAdapterFolderProbe(FolderProbeBase): - def get_format(self) -> ModelFormat: - return ModelFormat.InvokeAI - - def get_base_type(self) -> BaseModelType: - model_file = self.model_path / "ip_adapter.bin" - if not model_file.exists(): - raise InvalidModelConfigException("Unknown IP-Adapter model format.") - - state_dict = torch.load(model_file, map_location="cpu") - cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1] - if cross_attention_dim == 768: - return BaseModelType.StableDiffusion1 - elif cross_attention_dim == 1024: - return BaseModelType.StableDiffusion2 - elif cross_attention_dim == 2048: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelConfigException( - f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}." - ) - - def get_image_encoder_model_id(self) -> Optional[str]: - encoder_id_path = self.model_path / "image_encoder.txt" - if not encoder_id_path.exists(): - return None - with open(encoder_id_path, "r") as f: - image_encoder_model = f.readline().strip() - return image_encoder_model - - -class CLIPVisionFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - return BaseModelType.Any - - -class SpandrelImageToImageFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -class SigLIPFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - return BaseModelType.Any - - -class FluxReduxFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -class LlaveOnevisionFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - return BaseModelType.Any - - -class T2IAdapterFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - config_file = self.model_path / "config.json" - if not config_file.exists(): - raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}") - with open(config_file, "r") as file: - config = json.load(file) - - adapter_type = config.get("adapter_type", None) - if adapter_type == "full_adapter_xl": - return BaseModelType.StableDiffusionXL - elif adapter_type == "full_adapter" or "light_adapter": - # I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter. - return BaseModelType.StableDiffusion1 - else: - raise InvalidModelConfigException( - f"Unable to determine base model for '{self.model_path}' (adapter_type = {adapter_type})." - ) - - -# Register probe classes -ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.ControlLoRa, LoRAFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.SigLIP, SigLIPFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.FluxRedux, FluxReduxFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.LlavaOnevision, LlaveOnevisionFolderProbe) - -ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.ControlLoRa, LoRACheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.SigLIP, SigLIPCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.FluxRedux, FluxReduxCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.LlavaOnevision, LlavaOnevisionCheckpointProbe) - -ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe) diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 75191517c7..c4d71bfe31 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -12,7 +12,7 @@ from typing import Any, Dict, Generator, Optional, Tuple import torch from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_manager.config import AnyModelConfig +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 139a7d2940..3fb7a574f3 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -6,7 +6,8 @@ from pathlib import Path from typing import Optional from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_manager.config import AnyModelConfig, Diffusers_Config_Base, InvalidModelConfigException +from invokeai.backend.model_manager.configs.base import Diffusers_Config_Base +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key @@ -50,7 +51,7 @@ class ModelLoader(ModelLoaderBase): model_path = self._get_model_path(model_config) if not model_path.exists(): - raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}") + raise FileNotFoundError(f"Files for model '{model_config.name}' not found at {model_path}") with skip_torch_weight_init(): cache_record = self._load_and_cache(model_config, submodel_type) diff --git a/invokeai/backend/model_manager/load/model_loader_registry.py b/invokeai/backend/model_manager/load/model_loader_registry.py index 9b242fe167..ca9ea56edb 100644 --- a/invokeai/backend/model_manager/load/model_loader_registry.py +++ b/invokeai/backend/model_manager/load/model_loader_registry.py @@ -18,10 +18,8 @@ Use like this: from abc import ABC, abstractmethod from typing import Callable, Dict, Optional, Tuple, Type, TypeVar -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - Config_Base, -) +from invokeai.backend.model_manager.configs.base import Config_Base +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load import ModelLoaderBase from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType, SubModelType diff --git a/invokeai/backend/model_manager/load/model_loaders/clip_vision.py b/invokeai/backend/model_manager/load/model_loaders/clip_vision.py index 9065e51fbf..0150e24248 100644 --- a/invokeai/backend/model_manager/load/model_loaders/clip_vision.py +++ b/invokeai/backend/model_manager/load/model_loaders/clip_vision.py @@ -3,10 +3,8 @@ from typing import Optional from transformers import CLIPVisionModelWithProjection -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - Diffusers_Config_Base, -) +from invokeai.backend.model_manager.configs.base import Diffusers_Config_Base +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType diff --git a/invokeai/backend/model_manager/load/model_loaders/cogview4.py b/invokeai/backend/model_manager/load/model_loaders/cogview4.py index a1a9269edb..782ff38450 100644 --- a/invokeai/backend/model_manager/load/model_loaders/cogview4.py +++ b/invokeai/backend/model_manager/load/model_loaders/cogview4.py @@ -3,11 +3,8 @@ from typing import Optional import torch -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - Checkpoint_Config_Base, - Diffusers_Config_Base, -) +from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Diffusers_Config_Base +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader from invokeai.backend.model_manager.taxonomy import ( diff --git a/invokeai/backend/model_manager/load/model_loaders/controlnet.py b/invokeai/backend/model_manager/load/model_loaders/controlnet.py index 62a8ed4f65..8fd1796b8f 100644 --- a/invokeai/backend/model_manager/load/model_loaders/controlnet.py +++ b/invokeai/backend/model_manager/load/model_loaders/controlnet.py @@ -5,10 +5,8 @@ from typing import Optional from diffusers import ControlNetModel -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - ControlNet_Checkpoint_Config_Base, -) +from invokeai.backend.model_manager.configs.controlnet import ControlNet_Checkpoint_Config_Base +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader from invokeai.backend.model_manager.taxonomy import ( diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py index 07967c7c56..e44ddec382 100644 --- a/invokeai/backend/model_manager/load/model_loaders/flux.py +++ b/invokeai/backend/model_manager/load/model_loaders/flux.py @@ -34,21 +34,22 @@ from invokeai.backend.flux.model import Flux from invokeai.backend.flux.modules.autoencoder import AutoEncoder from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel from invokeai.backend.flux.util import get_flux_ae_params, get_flux_transformers_params -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - Checkpoint_Config_Base, - CLIPEmbed_Diffusers_Config_Base, +from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base +from invokeai.backend.model_manager.configs.clip_embed import CLIPEmbed_Diffusers_Config_Base +from invokeai.backend.model_manager.configs.controlnet import ( ControlNet_Checkpoint_Config_Base, ControlNet_Diffusers_Config_Base, - FLUXRedux_Checkpoint_Config, - IPAdapter_Checkpoint_Config_Base, +) +from invokeai.backend.model_manager.configs.factory import AnyModelConfig +from invokeai.backend.model_manager.configs.flux_redux import FLUXRedux_Checkpoint_Config +from invokeai.backend.model_manager.configs.ip_adapter import IPAdapter_Checkpoint_Config_Base +from invokeai.backend.model_manager.configs.main import ( Main_BnBNF4_FLUX_Config, Main_Checkpoint_FLUX_Config, Main_GGUF_FLUX_Config, - T5Encoder_BnBLLMint8_Config, - T5Encoder_T5Encoder_Config, - VAE_Checkpoint_Config_Base, ) +from invokeai.backend.model_manager.configs.t5_encoder import T5Encoder_BnBLLMint8_Config, T5Encoder_T5Encoder_Config +from invokeai.backend.model_manager.configs.vae import VAE_Checkpoint_Config_Base from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.taxonomy import ( diff --git a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py index 407a116b68..b888c69edf 100644 --- a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +++ b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py @@ -8,7 +8,8 @@ from typing import Any, Optional from diffusers.configuration_utils import ConfigMixin from diffusers.models.modeling_utils import ModelMixin -from invokeai.backend.model_manager.config import AnyModelConfig, Diffusers_Config_Base, InvalidModelConfigException +from invokeai.backend.model_manager.configs.base import Diffusers_Config_Base +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.taxonomy import ( @@ -56,9 +57,7 @@ class GenericDiffusersLoader(ModelLoader): module, class_name = config[submodel_type.value] result = self._hf_definition_to_type(module=module, class_name=class_name) except KeyError as e: - raise InvalidModelConfigException( - f'The "{submodel_type}" submodel is not available for this model.' - ) from e + raise ValueError(f'The "{submodel_type}" submodel is not available for this model.') from e else: try: config = self._load_diffusers_config(model_path, config_name="config.json") @@ -67,9 +66,9 @@ class GenericDiffusersLoader(ModelLoader): elif class_name := config.get("architectures"): result = self._hf_definition_to_type(module="transformers", class_name=class_name[0]) else: - raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json") + raise RuntimeError("Unable to decipher Load Class based on given config.json") except KeyError as e: - raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e + raise ValueError("An expected config.json file is missing from this model.") from e assert result is not None return result diff --git a/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py b/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py index d103bc5dbc..d133a36498 100644 --- a/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py +++ b/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py @@ -7,7 +7,7 @@ from typing import Optional import torch from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter -from invokeai.backend.model_manager.config import AnyModelConfig +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType from invokeai.backend.raw_model import RawModel diff --git a/invokeai/backend/model_manager/load/model_loaders/llava_onevision.py b/invokeai/backend/model_manager/load/model_loaders/llava_onevision.py index b508137f81..e459bbf2bb 100644 --- a/invokeai/backend/model_manager/load/model_loaders/llava_onevision.py +++ b/invokeai/backend/model_manager/load/model_loaders/llava_onevision.py @@ -3,9 +3,7 @@ from typing import Optional from transformers import LlavaOnevisionForConditionalGeneration -from invokeai.backend.model_manager.config import ( - AnyModelConfig, -) +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py index 98f54224fa..29fb815d54 100644 --- a/invokeai/backend/model_manager/load/model_loaders/lora.py +++ b/invokeai/backend/model_manager/load/model_loaders/lora.py @@ -9,7 +9,7 @@ import torch from safetensors.torch import load_file from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_manager.config import AnyModelConfig +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry diff --git a/invokeai/backend/model_manager/load/model_loaders/onnx.py b/invokeai/backend/model_manager/load/model_loaders/onnx.py index 3078d622b4..a565bb11d0 100644 --- a/invokeai/backend/model_manager/load/model_loaders/onnx.py +++ b/invokeai/backend/model_manager/load/model_loaders/onnx.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import Optional -from invokeai.backend.model_manager.config import AnyModelConfig +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader from invokeai.backend.model_manager.taxonomy import ( diff --git a/invokeai/backend/model_manager/load/model_loaders/sig_lip.py b/invokeai/backend/model_manager/load/model_loaders/sig_lip.py index bdf38887a3..16b8e6c88d 100644 --- a/invokeai/backend/model_manager/load/model_loaders/sig_lip.py +++ b/invokeai/backend/model_manager/load/model_loaders/sig_lip.py @@ -3,9 +3,7 @@ from typing import Optional from transformers import SiglipVisionModel -from invokeai.backend.model_manager.config import ( - AnyModelConfig, -) +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType diff --git a/invokeai/backend/model_manager/load/model_loaders/spandrel_image_to_image.py b/invokeai/backend/model_manager/load/model_loaders/spandrel_image_to_image.py index 44cb0277fc..e6d8f42990 100644 --- a/invokeai/backend/model_manager/load/model_loaders/spandrel_image_to_image.py +++ b/invokeai/backend/model_manager/load/model_loaders/spandrel_image_to_image.py @@ -3,9 +3,7 @@ from typing import Optional import torch -from invokeai.backend.model_manager.config import ( - AnyModelConfig, -) +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType diff --git a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py index 647ad4dbf4..d0cc589379 100644 --- a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +++ b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py @@ -11,10 +11,9 @@ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpain StableDiffusionXLInpaintPipeline, ) -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - Checkpoint_Config_Base, - Diffusers_Config_Base, +from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Diffusers_Config_Base +from invokeai.backend.model_manager.configs.factory import AnyModelConfig +from invokeai.backend.model_manager.configs.main import ( Main_Checkpoint_SD1_Config, Main_Checkpoint_SD2_Config, Main_Checkpoint_SDXL_Config, diff --git a/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py index 60ae4ea08b..2d0411a8df 100644 --- a/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py +++ b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Optional -from invokeai.backend.model_manager.config import AnyModelConfig +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.taxonomy import ( diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py b/invokeai/backend/model_manager/load/model_loaders/vae.py index 12789e58c2..e91903ccda 100644 --- a/invokeai/backend/model_manager/load/model_loaders/vae.py +++ b/invokeai/backend/model_manager/load/model_loaders/vae.py @@ -5,7 +5,8 @@ from typing import Optional from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL -from invokeai.backend.model_manager.config import AnyModelConfig, VAE_Checkpoint_Config_Base +from invokeai.backend.model_manager.configs.factory import AnyModelConfig +from invokeai.backend.model_manager.configs.vae import VAE_Checkpoint_Config_Base from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader from invokeai.backend.model_manager.taxonomy import ( diff --git a/invokeai/backend/model_manager/util/lora_metadata_extractor.py b/invokeai/backend/model_manager/util/lora_metadata_extractor.py index 842e78a788..12b1073935 100644 --- a/invokeai/backend/model_manager/util/lora_metadata_extractor.py +++ b/invokeai/backend/model_manager/util/lora_metadata_extractor.py @@ -8,7 +8,8 @@ from typing import Any, Dict, Optional, Set, Tuple from PIL import Image from invokeai.app.util.thumbnails import make_thumbnail -from invokeai.backend.model_manager.config import AnyModelConfig, ModelType +from invokeai.backend.model_manager.configs.factory import AnyModelConfig +from invokeai.backend.model_manager.taxonomy import ModelType logger = logging.getLogger(__name__) diff --git a/invokeai/backend/util/test_utils.py b/invokeai/backend/util/test_utils.py index add394e71b..e4208dc848 100644 --- a/invokeai/backend/util/test_utils.py +++ b/invokeai/backend/util/test_utils.py @@ -7,7 +7,8 @@ import torch from invokeai.app.services.model_manager import ModelManagerServiceBase from invokeai.app.services.model_records import UnknownModelException -from invokeai.backend.model_manager import BaseModelType, LoadedModel, ModelType, SubModelType +from invokeai.backend.model_manager.load.load_base import LoadedModel +from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType @pytest.fixture(scope="session") diff --git a/invokeai/frontend/web/src/features/controlLayers/components/ParamDenoisingStrength.tsx b/invokeai/frontend/web/src/features/controlLayers/components/ParamDenoisingStrength.tsx index bf4464bd5b..49a289b875 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/ParamDenoisingStrength.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/ParamDenoisingStrength.tsx @@ -17,6 +17,7 @@ import { selectImg2imgStrengthConfig } from 'features/system/store/configSlice'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useSelectedModelConfig } from 'services/api/hooks/useSelectedModelConfig'; +import { isFluxFillMainModelModelConfig } from 'services/api/types'; const selectHasRasterLayersWithContent = createSelector( selectActiveRasterLayerEntities, @@ -46,11 +47,7 @@ export const ParamDenoisingStrength = memo(() => { // Denoising strength does nothing if there are no raster layers w/ content return true; } - if ( - selectedModelConfig?.type === 'main' && - selectedModelConfig?.base === 'flux' && - selectedModelConfig.variant === 'inpaint' - ) { + if (selectedModelConfig && isFluxFillMainModelModelConfig(selectedModelConfig)) { // Denoising strength is ignored by FLUX Fill, which is indicated by the variant being 'inpaint' return true; } diff --git a/invokeai/frontend/web/src/features/controlLayers/store/validators.ts b/invokeai/frontend/web/src/features/controlLayers/store/validators.ts index 03ef5404a6..197a3d6e3e 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/validators.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/validators.ts @@ -154,7 +154,7 @@ export const getControlLayerWarnings = ( warnings.push(WARNINGS.CONTROL_ADAPTER_INCOMPATIBLE_BASE_MODEL); } else if ( model.base === 'flux' && - model.variant === 'inpaint' && + model.variant === 'dev_fill' && entity.controlAdapter.model.type === 'control_lora' ) { // FLUX inpaint variants are FLUX Fill models - not compatible w/ Control LoRA diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx index 75b3ba4bc4..538ebf597e 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx @@ -56,18 +56,17 @@ export const ModelView = memo(({ modelConfig }: Props) => { - {modelConfig.type === 'main' && ( + {modelConfig.type === 'main' && 'variant' in modelConfig && ( )} {modelConfig.type === 'main' && modelConfig.format === 'diffusers' && modelConfig.repo_variant && ( )} {modelConfig.type === 'main' && modelConfig.format === 'checkpoint' && ( - <> - - - - + + )} + {modelConfig.type === 'main' && modelConfig.format === 'checkpoint' && 'prediction_type' in modelConfig && ( + )} {modelConfig.type === 'ip_adapter' && modelConfig.format === 'invokeai' && ( diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.test.ts index 5a766e8d39..42d66f0fc8 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.test.ts @@ -660,6 +660,7 @@ describe('Graph', () => { cover_image: null, type: 'main', trigger_phrases: null, + prediction_type: 'epsilon', default_settings: { vae: null, vae_precision: null, @@ -673,7 +674,6 @@ describe('Graph', () => { variant: 'inpaint', format: 'diffusers', repo_variant: 'fp16', - submodels: null, usage_info: null, }); expect(field).toEqual({ diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts index 5fcd13ba4f..0f7163c44c 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts @@ -5,7 +5,7 @@ import type { CanvasControlLayerState, Rect } from 'features/controlLayers/store import { getControlLayerWarnings } from 'features/controlLayers/store/validators'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import { serializeError } from 'serialize-error'; -import type { ImageDTO, Invocation, MainModelConfig } from 'services/api/types'; +import type { FLUXModelConfig, ImageDTO, Invocation, MainModelConfig } from 'services/api/types'; import { assert } from 'tsafe'; const log = logger('system'); @@ -113,7 +113,7 @@ type AddControlLoRAArg = { entities: CanvasControlLayerState[]; g: Graph; rect: Rect; - model: MainModelConfig; + model: FLUXModelConfig; denoise: Invocation<'flux_denoise'>; }; @@ -129,7 +129,7 @@ export const addControlLoRA = async ({ manager, entities, g, rect, model, denois return; } - assert(model.variant !== 'inpaint', 'FLUX Control LoRA is not compatible with FLUX Fill.'); + assert(model.variant !== 'dev_fill', 'FLUX Control LoRA is not compatible with FLUX Fill.'); assert(validControlLayers.length <= 1, 'Cannot add more than one FLUX control LoRA.'); const getImageDTOResult = await withResultAsync(() => { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts index 6b0140bf9e..558f8b2ffe 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts @@ -49,7 +49,7 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise { // Model Configs export type AnyModelConfig = S['AnyModelConfig']; export type MainModelConfig = Extract; +export type FLUXModelConfig = Extract; export type ControlLoRAModelConfig = Extract; export type LoRAModelConfig = Extract; export type VAEModelConfig = Extract; @@ -134,6 +135,7 @@ type UnknownModelConfig = Extract; export type FLUXKontextModelConfig = MainModelConfig; export type ChatGPT4oModelConfig = ApiModelConfig; export type Gemini2_5ModelConfig = ApiModelConfig; +type SubmodelDefinition = S['SubmodelDefinition']; /** * Checks if a list of submodels contains any that match a given variant or type @@ -141,7 +143,7 @@ export type Gemini2_5ModelConfig = ApiModelConfig; * @param checkStr The string to check against for variant or type * @returns A boolean */ -const checkSubmodel = (submodels: AnyModelConfig['submodels'], checkStr: string): boolean => { +const checkSubmodel = (submodels: Record, checkStr: string): boolean => { for (const submodel in submodels) { if ( submodel && @@ -164,6 +166,7 @@ const checkSubmodels = (identifiers: string[], config: AnyModelConfig): boolean return identifiers.every( (identifier) => config.type === 'main' && + 'submodels' in config && config.submodels && (identifier in config.submodels || checkSubmodel(config.submodels, identifier)) ); @@ -332,7 +335,7 @@ export const isRefinerMainModelModelConfig = (config: AnyModelConfig): config is }; export const isFluxFillMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => { - return config.type === 'main' && config.base === 'flux' && config.variant === 'inpaint'; + return config.type === 'main' && config.base === 'flux' && config.variant === 'dev_fill'; }; export const isTIModelConfig = (config: AnyModelConfig): config is MainModelConfig => { diff --git a/scripts/classify-model.py b/scripts/classify-model.py index 2ae253b72f..a9129860a7 100755 --- a/scripts/classify-model.py +++ b/scripts/classify-model.py @@ -8,7 +8,7 @@ from typing import get_args from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS from invokeai.backend.model_manager import InvalidModelConfigException, ModelProbe -from invokeai.backend.model_manager.config import ModelConfigFactory +from invokeai.backend.model_manager.configs.factory import ModelConfigFactory algos = ", ".join(set(get_args(HASHING_ALGORITHMS))) diff --git a/tests/test_model_probe.py b/tests/test_model_probe.py index 03a7428382..f1249e4dc1 100644 --- a/tests/test_model_probe.py +++ b/tests/test_model_probe.py @@ -13,9 +13,9 @@ from torch import tensor from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelRepoVariant, ModelType, ModelVariantType from invokeai.backend.model_manager.config import ( AnyModelConfig, + Config_Base, InvalidModelConfigException, MainDiffusersConfig, - Config_Base, ModelConfigFactory, get_model_discriminator_value, )