diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py
index 84db65252e..7add1d09cf 100644
--- a/invokeai/app/api/routers/model_manager.py
+++ b/invokeai/app/api/routers/model_manager.py
@@ -28,9 +28,8 @@ from invokeai.app.services.model_records import (
UnknownModelException,
)
from invokeai.app.util.suppress_output import SuppressOutput
-from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelType
-from invokeai.backend.model_manager.config import (
- AnyModelConfig,
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
+from invokeai.backend.model_manager.configs.main import (
Main_Checkpoint_SD1_Config,
Main_Checkpoint_SD2_Config,
Main_Checkpoint_SDXL_Config,
@@ -47,6 +46,7 @@ from invokeai.backend.model_manager.starter_models import (
StarterModelBundle,
StarterModelWithoutDependencies,
)
+from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
diff --git a/invokeai/app/invocations/cogview4_denoise.py b/invokeai/app/invocations/cogview4_denoise.py
index c0b962ba31..070d8a3478 100644
--- a/invokeai/app/invocations/cogview4_denoise.py
+++ b/invokeai/app/invocations/cogview4_denoise.py
@@ -22,7 +22,7 @@ from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
-from invokeai.backend.model_manager.config import BaseModelType
+from invokeai.backend.model_manager.taxonomy import BaseModelType
from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import CogView4ConditioningInfo
diff --git a/invokeai/app/invocations/cogview4_model_loader.py b/invokeai/app/invocations/cogview4_model_loader.py
index 9db4f3c053..fbafcd345f 100644
--- a/invokeai/app/invocations/cogview4_model_loader.py
+++ b/invokeai/app/invocations/cogview4_model_loader.py
@@ -13,8 +13,7 @@ from invokeai.app.invocations.model import (
VAEField,
)
from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.model_manager.config import SubModelType
-from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
+from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
@invocation_output("cogview4_model_loader_output")
diff --git a/invokeai/app/invocations/create_gradient_mask.py b/invokeai/app/invocations/create_gradient_mask.py
index f6e046d096..8a7e7c5231 100644
--- a/invokeai/app/invocations/create_gradient_mask.py
+++ b/invokeai/app/invocations/create_gradient_mask.py
@@ -20,9 +20,7 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.model import UNetField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.model_manager import LoadedModel
-from invokeai.backend.model_manager.config import Main_Config_Base
-from invokeai.backend.model_manager.taxonomy import ModelVariantType
+from invokeai.backend.model_manager.taxonomy import FluxVariantType, ModelType, ModelVariantType
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
@@ -182,10 +180,11 @@ class CreateGradientMaskInvocation(BaseInvocation):
if self.unet is not None and self.vae is not None and self.image is not None:
# all three fields must be present at the same time
main_model_config = context.models.get_config(self.unet.unet.key)
- assert isinstance(main_model_config, Main_Config_Base)
- if main_model_config.variant is ModelVariantType.Inpaint:
+ assert main_model_config.type is ModelType.Main
+ variant = getattr(main_model_config, "variant", None)
+ if variant is ModelVariantType.Inpaint or variant is FluxVariantType.DevFill:
mask = dilated_mask_tensor
- vae_info: LoadedModel = context.models.load(self.vae.vae)
+ vae_info = context.models.load(self.vae.vae)
image = context.images.get_pil(self.image.image_name)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py
index 37b385914c..bb114263e2 100644
--- a/invokeai/app/invocations/denoise_latents.py
+++ b/invokeai/app/invocations/denoise_latents.py
@@ -39,7 +39,7 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.model_manager.config import AnyModelConfig
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelVariantType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.patches.layer_patcher import LayerPatcher
diff --git a/invokeai/app/invocations/flux_ip_adapter.py b/invokeai/app/invocations/flux_ip_adapter.py
index c564023a3a..4a1997c512 100644
--- a/invokeai/app/invocations/flux_ip_adapter.py
+++ b/invokeai/app/invocations/flux_ip_adapter.py
@@ -16,9 +16,7 @@ from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.model_manager.config import (
- IPAdapter_Checkpoint_FLUX_Config,
-)
+from invokeai.backend.model_manager.configs.ip_adapter import IPAdapter_Checkpoint_FLUX_Config
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
diff --git a/invokeai/app/invocations/flux_model_loader.py b/invokeai/app/invocations/flux_model_loader.py
index 2803db48e0..eaac82bafc 100644
--- a/invokeai/app/invocations/flux_model_loader.py
+++ b/invokeai/app/invocations/flux_model_loader.py
@@ -14,9 +14,7 @@ from invokeai.app.util.t5_model_identifier import (
preprocess_t5_tokenizer_model_identifier,
)
from invokeai.backend.flux.util import get_flux_max_seq_length
-from invokeai.backend.model_manager.config import (
- Checkpoint_Config_Base,
-)
+from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
diff --git a/invokeai/app/invocations/flux_redux.py b/invokeai/app/invocations/flux_redux.py
index 3e34497b10..403d78b078 100644
--- a/invokeai/app/invocations/flux_redux.py
+++ b/invokeai/app/invocations/flux_redux.py
@@ -24,9 +24,9 @@ from invokeai.app.invocations.primitives import ImageField
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
-from invokeai.backend.model_manager import BaseModelType, ModelType
-from invokeai.backend.model_manager.config import AnyModelConfig
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.starter_models import siglip
+from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline
from invokeai.backend.util.devices import TorchDevice
diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py
index 77b6187840..c395a0bf22 100644
--- a/invokeai/app/invocations/flux_text_encoder.py
+++ b/invokeai/app/invocations/flux_text_encoder.py
@@ -17,7 +17,7 @@ from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.conditioner import HFEncoder
-from invokeai.backend.model_manager import ModelFormat
+from invokeai.backend.model_manager.taxonomy import ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_T5_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
diff --git a/invokeai/app/invocations/flux_vae_encode.py b/invokeai/app/invocations/flux_vae_encode.py
index 2932517edc..4ec0365c2c 100644
--- a/invokeai/app/invocations/flux_vae_encode.py
+++ b/invokeai/app/invocations/flux_vae_encode.py
@@ -12,7 +12,7 @@ from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
-from invokeai.backend.model_manager import LoadedModel
+from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux
diff --git a/invokeai/app/invocations/image_to_latents.py b/invokeai/app/invocations/image_to_latents.py
index 552f5edb1b..fde70a34fd 100644
--- a/invokeai/app/invocations/image_to_latents.py
+++ b/invokeai/app/invocations/image_to_latents.py
@@ -23,7 +23,7 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.model_manager import LoadedModel
+from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
from invokeai.backend.util.devices import TorchDevice
diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py
index 7c3234bdc7..2b2931e78f 100644
--- a/invokeai/app/invocations/ip_adapter.py
+++ b/invokeai/app/invocations/ip_adapter.py
@@ -11,8 +11,8 @@ from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.model_manager.config import (
- AnyModelConfig,
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
+from invokeai.backend.model_manager.configs.ip_adapter import (
IPAdapter_Checkpoint_Config_Base,
IPAdapter_InvokeAI_Config_Base,
)
diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py
index 327de6ac70..753ae77c55 100644
--- a/invokeai/app/invocations/model.py
+++ b/invokeai/app/invocations/model.py
@@ -12,9 +12,7 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
-from invokeai.backend.model_manager.config import (
- AnyModelConfig,
-)
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
diff --git a/invokeai/app/invocations/sd3_denoise.py b/invokeai/app/invocations/sd3_denoise.py
index f43f26ae0e..b9d69369b7 100644
--- a/invokeai/app/invocations/sd3_denoise.py
+++ b/invokeai/app/invocations/sd3_denoise.py
@@ -23,7 +23,7 @@ from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.invocations.sd3_text_encoder import SD3_T5_MAX_SEQ_LEN
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
-from invokeai.backend.model_manager import BaseModelType
+from invokeai.backend.model_manager.taxonomy import BaseModelType
from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import SD3ConditioningInfo
diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py
index fc0f0bb2c6..c70ef3fa16 100644
--- a/invokeai/app/services/events/events_base.py
+++ b/invokeai/app/services/events/events_base.py
@@ -44,8 +44,8 @@ if TYPE_CHECKING:
SessionQueueItem,
SessionQueueStatus,
)
- from invokeai.backend.model_manager import SubModelType
- from invokeai.backend.model_manager.config import AnyModelConfig
+ from invokeai.backend.model_manager.configs.factory import AnyModelConfig
+ from invokeai.backend.model_manager.taxonomy import SubModelType
class EventServiceBase:
diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py
index 8fbb08015a..2f99529398 100644
--- a/invokeai/app/services/events/events_common.py
+++ b/invokeai/app/services/events/events_common.py
@@ -16,8 +16,8 @@ from invokeai.app.services.session_queue.session_queue_common import (
)
from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
from invokeai.app.util.misc import get_timestamp
-from invokeai.backend.model_manager import SubModelType
-from invokeai.backend.model_manager.config import AnyModelConfig
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
+from invokeai.backend.model_manager.taxonomy import SubModelType
if TYPE_CHECKING:
from invokeai.app.services.download.download_base import DownloadJob
diff --git a/invokeai/app/services/model_install/model_install_common.py b/invokeai/app/services/model_install/model_install_common.py
index fea75d7375..67832466f3 100644
--- a/invokeai/app/services/model_install/model_install_common.py
+++ b/invokeai/app/services/model_install/model_install_common.py
@@ -10,7 +10,7 @@ from typing_extensions import Annotated
from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob
from invokeai.app.services.model_records import ModelRecordChanges
-from invokeai.backend.model_manager.config import AnyModelConfig
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant, ModelSourceType
diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py
index 10a954a563..53bb5cc12d 100644
--- a/invokeai/app/services/model_install/model_install_default.py
+++ b/invokeai/app/services/model_install/model_install_default.py
@@ -35,10 +35,9 @@ from invokeai.app.services.model_install.model_install_common import (
)
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
-from invokeai.backend.model_manager.config import (
+from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base
+from invokeai.backend.model_manager.configs.factory import (
AnyModelConfig,
- Checkpoint_Config_Base,
- InvalidModelConfigException,
ModelConfigFactory,
)
from invokeai.backend.model_manager.metadata import (
@@ -532,7 +531,7 @@ class ModelInstallService(ModelInstallServiceBase):
x.content_type is not None and "text/html" in x.content_type for x in multifile_download_job.download_parts
):
install_job.set_error(
- InvalidModelConfigException(
+ ValueError(
f"At least one file in {install_job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
)
)
@@ -602,7 +601,7 @@ class ModelInstallService(ModelInstallServiceBase):
return ModelConfigFactory.from_model_on_disk(
mod=model_path,
- overrides=deepcopy(fields),
+ override_fields=deepcopy(fields),
hash_algo=hash_algo,
)
diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py
index 8aae80e29d..87a405b4ea 100644
--- a/invokeai/app/services/model_load/model_load_base.py
+++ b/invokeai/app/services/model_load/model_load_base.py
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Optional
-from invokeai.backend.model_manager.config import AnyModelConfig
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType
diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py
index ad4ad97a02..2e2d2ae219 100644
--- a/invokeai/app/services/model_load/model_load_default.py
+++ b/invokeai/app/services/model_load/model_load_default.py
@@ -11,7 +11,7 @@ from torch import load as torch_load
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_load.model_load_base import ModelLoadServiceBase
-from invokeai.backend.model_manager.config import AnyModelConfig
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load import (
LoadedModel,
LoadedModelWithoutConfig,
diff --git a/invokeai/app/services/model_manager/__init__.py b/invokeai/app/services/model_manager/__init__.py
index aad67ff352..e703d4f1ff 100644
--- a/invokeai/app/services/model_manager/__init__.py
+++ b/invokeai/app/services/model_manager/__init__.py
@@ -1,12 +1,10 @@
"""Initialization file for model manager service."""
from invokeai.app.services.model_manager.model_manager_default import ModelManagerService, ModelManagerServiceBase
-from invokeai.backend.model_manager import AnyModelConfig
from invokeai.backend.model_manager.load import LoadedModel
__all__ = [
"ModelManagerServiceBase",
"ModelManagerService",
- "AnyModelConfig",
"LoadedModel",
]
diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py
index 48f5317536..2d34832dbe 100644
--- a/invokeai/app/services/model_records/model_records_base.py
+++ b/invokeai/app/services/model_records/model_records_base.py
@@ -12,12 +12,10 @@ from pydantic import BaseModel, Field
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
-from invokeai.backend.model_manager.config import (
- AnyModelConfig,
- ControlAdapterDefaultSettings,
- LoraModelDefaultSettings,
- MainModelDefaultSettings,
-)
+from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
+from invokeai.backend.model_manager.configs.lora import LoraModelDefaultSettings
+from invokeai.backend.model_manager.configs.main import MainModelDefaultSettings
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ClipVariantType,
diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py
index 7fad1761cc..6d9a33ba4a 100644
--- a/invokeai/app/services/model_records/model_records_sql.py
+++ b/invokeai/app/services/model_records/model_records_sql.py
@@ -58,10 +58,7 @@ from invokeai.app.services.model_records.model_records_base import (
)
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
-from invokeai.backend.model_manager.config import (
- AnyModelConfig,
- ModelConfigFactory,
-)
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
@@ -157,7 +154,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
record_as_dict[field_name] = getattr(changes, field_name)
# 3. create a new model config from the updated dict
- record = ModelConfigFactory.make_config(record_as_dict)
+ record = ModelConfigFactory.from_dict(record_as_dict)
# If we get this far, the updated model config is valid, so we can save it to the database.
json_serialized = record.model_dump_json()
@@ -187,7 +184,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
- SELECT config, strftime('%s',updated_at) FROM models
+ SELECT config FROM models
WHERE id=?;
""",
(key,),
@@ -195,14 +192,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
rows = cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
- model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
+ model = ModelConfigFactory.from_dict(json.loads(rows[0]))
return model
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
- SELECT config, strftime('%s',updated_at) FROM models
+ SELECT config FROM models
WHERE hash=?;
""",
(hash,),
@@ -210,7 +207,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
rows = cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
- model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
+ model = ModelConfigFactory.from_dict(json.loads(rows[0]))
return model
def exists(self, key: str) -> bool:
@@ -278,7 +275,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
cursor.execute(
f"""--sql
- SELECT config, strftime('%s',updated_at)
+ SELECT config
FROM models
{where}
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
@@ -291,7 +288,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
results: list[AnyModelConfig] = []
for row in result:
try:
- model_config = ModelConfigFactory.make_config(json.loads(row[0]), timestamp=row[1])
+ model_config = ModelConfigFactory.from_dict(json.loads(row[0]))
except pydantic.ValidationError as e:
# We catch this error so that the app can still run if there are invalid model configs in the database.
# One reason that an invalid model config might be in the database is if someone had to rollback from a
@@ -315,12 +312,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
- SELECT config, strftime('%s',updated_at) FROM models
+ SELECT config FROM models
WHERE path=?;
""",
(str(path),),
)
- results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
+ results = [ModelConfigFactory.from_dict(json.loads(x[0])) for x in cursor.fetchall()]
return results
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
@@ -328,12 +325,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
- SELECT config, strftime('%s',updated_at) FROM models
+ SELECT config FROM models
WHERE hash=?;
""",
(hash,),
)
- results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
+ results = [ModelConfigFactory.from_dict(json.loads(x[0])) for x in cursor.fetchall()]
return results
def list_models(
diff --git a/invokeai/app/services/model_relationships/model_relationships_default.py b/invokeai/app/services/model_relationships/model_relationships_default.py
index 67fa6c0069..e4da482ff2 100644
--- a/invokeai/app/services/model_relationships/model_relationships_default.py
+++ b/invokeai/app/services/model_relationships/model_relationships_default.py
@@ -1,6 +1,6 @@
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_relationships.model_relationships_base import ModelRelationshipsServiceABC
-from invokeai.backend.model_manager.config import AnyModelConfig
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
class ModelRelationshipsService(ModelRelationshipsServiceABC):
diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py
index 16aacbb985..97291230e0 100644
--- a/invokeai/app/services/shared/invocation_context.py
+++ b/invokeai/app/services/shared/invocation_context.py
@@ -19,10 +19,8 @@ from invokeai.app.services.model_records.model_records_base import UnknownModelE
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.util.step_callback import diffusion_step_callback
-from invokeai.backend.model_manager.config import (
- AnyModelConfig,
- Config_Base,
-)
+from invokeai.backend.model_manager.configs.base import Config_Base
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_22.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_22.py
index 08b0e76068..1d9c81529e 100644
--- a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_22.py
+++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_22.py
@@ -8,7 +8,7 @@ from pydantic import ValidationError
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
-from invokeai.backend.model_manager.config import AnyModelConfigValidator
+from invokeai.backend.model_manager.configs.factory import AnyModelConfigValidator
class NormalizeResult(NamedTuple):
diff --git a/invokeai/app/util/custom_openapi.py b/invokeai/app/util/custom_openapi.py
index 2e07622530..d400e0ff11 100644
--- a/invokeai/app/util/custom_openapi.py
+++ b/invokeai/app/util/custom_openapi.py
@@ -12,7 +12,7 @@ from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFie
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
-from invokeai.backend.model_manager.config import AnyModelConfigValidator
+from invokeai.backend.model_manager.configs.factory import AnyModelConfigValidator
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger()
diff --git a/invokeai/backend/flux/flux_state_dict_utils.py b/invokeai/backend/flux/flux_state_dict_utils.py
index 8ffab54c68..c306c88f96 100644
--- a/invokeai/backend/flux/flux_state_dict_utils.py
+++ b/invokeai/backend/flux/flux_state_dict_utils.py
@@ -1,10 +1,7 @@
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
- from invokeai.backend.model_manager.legacy_probe import CkptType
+from typing import Any
-def get_flux_in_channels_from_state_dict(state_dict: "CkptType") -> int | None:
+def get_flux_in_channels_from_state_dict(state_dict: dict[str | int, Any]) -> int | None:
"""Gets the in channels from the state dict."""
# "Standard" FLUX models use "img_in.weight", but some community fine tunes use
diff --git a/invokeai/backend/model_manager/__init__.py b/invokeai/backend/model_manager/__init__.py
index a167687d2e..e69de29bb2 100644
--- a/invokeai/backend/model_manager/__init__.py
+++ b/invokeai/backend/model_manager/__init__.py
@@ -1,45 +0,0 @@
-"""Re-export frequently-used symbols from the Model Manager backend."""
-
-from invokeai.backend.model_manager.config import (
- AnyModelConfig,
- InvalidModelConfigException,
- Config_Base,
- ModelConfigFactory,
-)
-from invokeai.backend.model_manager.legacy_probe import ModelProbe
-from invokeai.backend.model_manager.load import LoadedModel
-from invokeai.backend.model_manager.search import ModelSearch
-from invokeai.backend.model_manager.taxonomy import (
- AnyModel,
- AnyVariant,
- BaseModelType,
- ClipVariantType,
- ModelFormat,
- ModelRepoVariant,
- ModelSourceType,
- ModelType,
- ModelVariantType,
- SchedulerPredictionType,
- SubModelType,
-)
-
-__all__ = [
- "AnyModelConfig",
- "InvalidModelConfigException",
- "LoadedModel",
- "ModelConfigFactory",
- "ModelProbe",
- "ModelSearch",
- "Config_Base",
- "AnyModel",
- "AnyVariant",
- "BaseModelType",
- "ClipVariantType",
- "ModelFormat",
- "ModelRepoVariant",
- "ModelSourceType",
- "ModelType",
- "ModelVariantType",
- "SchedulerPredictionType",
- "SubModelType",
-]
diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py
deleted file mode 100644
index 188ac9ad11..0000000000
--- a/invokeai/backend/model_manager/config.py
+++ /dev/null
@@ -1,2584 +0,0 @@
-# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
-"""
-Configuration definitions for image generation models.
-
-Typical usage:
-
- from invokeai.backend.model_manager import ModelConfigFactory
- raw = dict(path='models/sd-1/main/foo.ckpt',
- name='foo',
- base='sd-1',
- type='main',
- config='configs/stable-diffusion/v1-inference.yaml',
- variant='normal',
- format='checkpoint'
- )
- config = ModelConfigFactory.make_config(raw)
- print(config.name)
-
-Validation errors will raise an InvalidModelConfigException error.
-
-"""
-
-import json
-import logging
-import re
-import time
-from abc import ABC
-from enum import Enum
-from functools import cache
-from inspect import isabstract
-from pathlib import Path
-from typing import (
- ClassVar,
- Literal,
- Optional,
- Self,
- Type,
- Union,
-)
-
-import torch
-from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter, ValidationError
-from pydantic_core import CoreSchema, PydanticUndefined, SchemaValidator
-from typing_extensions import Annotated, Any, Dict
-
-from invokeai.app.services.config.config_default import get_config
-from invokeai.app.util.misc import uuid_string
-from invokeai.backend.flux.controlnet.state_dict_utils import (
- is_state_dict_instantx_controlnet,
- is_state_dict_xlabs_controlnet,
-)
-from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter
-from invokeai.backend.flux.redux.flux_redux_state_dict_utils import is_state_dict_likely_flux_redux
-from invokeai.backend.model_hash.hash_validator import validate_hash
-from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
-from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
-from invokeai.backend.model_manager.omi import flux_dev_1_lora, stable_diffusion_xl_1_lora
-from invokeai.backend.model_manager.taxonomy import (
- AnyVariant,
- BaseModelType,
- ClipVariantType,
- FluxLoRAFormat,
- FluxVariantType,
- ModelFormat,
- ModelRepoVariant,
- ModelSourceType,
- ModelType,
- ModelVariantType,
- SchedulerPredictionType,
- SubModelType,
- variant_type_adapter,
-)
-from invokeai.backend.model_manager.util.model_util import lora_token_vector_length
-from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
-from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
-from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
-from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
-
-logger = logging.getLogger(__name__)
-app_config = get_config()
-
-
-class InvalidModelConfigException(Exception):
- """Exception for when config parser doesn't recognize this combination of model type and format."""
-
- pass
-
-
-class NotAMatch(Exception):
- """Exception for when a model does not match a config class.
-
- Args:
- config_class: The config class that was being tested.
- reason: The reason why the model did not match.
- """
-
- def __init__(
- self,
- config_class: type,
- reason: str,
- ):
- super().__init__(f"{config_class.__name__}: {reason}")
-
-
-DEFAULTS_PRECISION = Literal["fp16", "fp32"]
-
-
-class FieldValidator:
- """Utility class for validating individual fields of a Pydantic model without instantiating the whole model.
-
- See: https://github.com/pydantic/pydantic/discussions/7367#discussioncomment-14213144
- """
-
- @staticmethod
- def find_field_schema(model: type[BaseModel], field_name: str) -> CoreSchema:
- """Find the Pydantic core schema for a specific field in a model."""
- schema: CoreSchema = model.__pydantic_core_schema__.copy()
- # we shallow copied, be careful not to mutate the original schema!
-
- assert schema["type"] in ["definitions", "model"]
-
- # find the field schema
- field_schema = schema["schema"] # type: ignore
- while "fields" not in field_schema:
- field_schema = field_schema["schema"] # type: ignore
-
- field_schema = field_schema["fields"][field_name]["schema"] # type: ignore
-
- # if the original schema is a definition schema, replace the model schema with the field schema
- if schema["type"] == "definitions":
- schema["schema"] = field_schema
- return schema
- else:
- return field_schema
-
- @cache
- @staticmethod
- def get_validator(model: type[BaseModel], field_name: str) -> SchemaValidator:
- """Get a SchemaValidator for a specific field in a model."""
- return SchemaValidator(FieldValidator.find_field_schema(model, field_name))
-
- @staticmethod
- def validate_field(model: type[BaseModel], field_name: str, value: Any) -> Any:
- """Validate a value for a specific field in a model."""
- return FieldValidator.get_validator(model, field_name).validate_python(value)
-
-
-def has_any_keys(state_dict: dict[str | int, Any], keys: str | set[str]) -> bool:
- """Returns true if the state dict has any of the specified keys."""
- _keys = {keys} if isinstance(keys, str) else keys
- return any(key in state_dict for key in _keys)
-
-
-def has_any_keys_starting_with(state_dict: dict[str | int, Any], prefixes: str | set[str]) -> bool:
- """Returns true if the state dict has any keys starting with any of the specified prefixes."""
- _prefixes = {prefixes} if isinstance(prefixes, str) else prefixes
- return any(any(key.startswith(prefix) for prefix in _prefixes) for key in state_dict.keys() if isinstance(key, str))
-
-
-def has_any_keys_ending_with(state_dict: dict[str | int, Any], suffixes: str | set[str]) -> bool:
- """Returns true if the state dict has any keys ending with any of the specified suffixes."""
- _suffixes = {suffixes} if isinstance(suffixes, str) else suffixes
- return any(any(key.endswith(suffix) for suffix in _suffixes) for key in state_dict.keys() if isinstance(key, str))
-
-
-def common_config_paths(path: Path) -> set[Path]:
- """Returns common config file paths for models stored in directories."""
- return {path / "config.json", path / "model_index.json"}
-
-
-# These utility functions are tightly coupled to the config classes below in order to make the process of raising
-# NotAMatch exceptions as easy and consistent as possible.
-
-
-def _get_config_or_raise(
- config_class: type,
- config_path: Path | set[Path],
-) -> dict[str, Any]:
- """Load the config file at the given path, or raise NotAMatch if it cannot be loaded."""
- paths_to_check = config_path if isinstance(config_path, set) else {config_path}
-
- problems: dict[Path, str] = {}
-
- for p in paths_to_check:
- if not p.exists():
- problems[p] = "file does not exist"
- continue
-
- try:
- with open(p, "r") as file:
- config = json.load(file)
-
- return config
- except Exception as e:
- problems[p] = str(e)
- continue
-
- raise NotAMatch(config_class, f"unable to load config file(s): {problems}")
-
-
-def _get_class_name_from_config(
- config_class: type,
- config_path: Path | set[Path],
-) -> str:
- """Load the config file and return the class name.
-
- Raises:
- NotAMatch if the config file is missing or does not contain a valid class name.
- """
-
- config = _get_config_or_raise(config_class, config_path)
-
- try:
- if "_class_name" in config:
- config_class_name = config["_class_name"]
- elif "architectures" in config:
- config_class_name = config["architectures"][0]
- else:
- raise ValueError("missing _class_name or architectures field")
- except Exception as e:
- raise NotAMatch(config_class, f"unable to determine class name from config file: {config_path}") from e
-
- if not isinstance(config_class_name, str):
- raise NotAMatch(config_class, f"_class_name or architectures field is not a string: {config_class_name}")
-
- return config_class_name
-
-
-def _validate_class_name(config_class: type[BaseModel], config_path: Path | set[Path], expected: set[str]) -> None:
- """Check if the class name in the config file matches the expected class names.
-
- Args:
- config_class: The config class that is being tested.
- config_path: The path to the config file.
- expected: The expected class names."""
-
- class_name = _get_class_name_from_config(config_class, config_path)
- if class_name not in expected:
- raise NotAMatch(config_class, f"invalid class name from config: {class_name}")
-
-
-def _validate_override_fields(
- config_class: type[BaseModel],
- override_fields: dict[str, Any],
-) -> None:
- """Check if the provided override fields are valid for the config class.
-
- Args:
- config_class: The config class that is being tested.
- override_fields: The override fields provided by the user.
-
- Raises:
- NotAMatch if any override field is invalid for the config.
- """
- for field_name, override_value in override_fields.items():
- if field_name not in config_class.model_fields:
- raise NotAMatch(config_class, f"unknown override field: {field_name}")
- try:
- FieldValidator.validate_field(config_class, field_name, override_value)
- except ValidationError as e:
- raise NotAMatch(config_class, f"invalid override for field '{field_name}': {e}") from e
-
-
-def _validate_is_file(
- config_class: type,
- mod: ModelOnDisk,
-) -> None:
- """Raise NotAMatch if the model path is not a file."""
- if not mod.path.is_file():
- raise NotAMatch(config_class, "model path is not a file")
-
-
-def _validate_is_dir(
- config_class: type,
- mod: ModelOnDisk,
-) -> None:
- """Raise NotAMatch if the model path is not a directory."""
- if not mod.path.is_dir():
- raise NotAMatch(config_class, "model path is not a directory")
-
-
-class SubmodelDefinition(BaseModel):
- path_or_prefix: str
- model_type: ModelType
- variant: AnyVariant | None = None
-
- model_config = ConfigDict(protected_namespaces=())
-
-
-class MainModelDefaultSettings(BaseModel):
- vae: str | None = Field(default=None, description="Default VAE for this model (model key)")
- vae_precision: DEFAULTS_PRECISION | None = Field(default=None, description="Default VAE precision for this model")
- scheduler: SCHEDULER_NAME_VALUES | None = Field(default=None, description="Default scheduler for this model")
- steps: int | None = Field(default=None, gt=0, description="Default number of steps for this model")
- cfg_scale: float | None = Field(default=None, ge=1, description="Default CFG Scale for this model")
- cfg_rescale_multiplier: float | None = Field(
- default=None, ge=0, lt=1, description="Default CFG Rescale Multiplier for this model"
- )
- width: int | None = Field(default=None, multiple_of=8, ge=64, description="Default width for this model")
- height: int | None = Field(default=None, multiple_of=8, ge=64, description="Default height for this model")
- guidance: float | None = Field(default=None, ge=1, description="Default Guidance for this model")
-
- model_config = ConfigDict(extra="forbid")
-
-
-class LoraModelDefaultSettings(BaseModel):
- weight: float | None = Field(default=None, ge=-1, le=2, description="Default weight for this model")
- model_config = ConfigDict(extra="forbid")
-
-
-class ControlAdapterDefaultSettings(BaseModel):
- # This could be narrowed to controlnet processor nodes, but they change. Leaving this a string is safer.
- preprocessor: str | None
- model_config = ConfigDict(extra="forbid")
-
-
-class LegacyProbeMixin:
- """Mixin for classes using the legacy probe for model classification."""
-
- pass
-
-
-class Config_Base(ABC, BaseModel):
- """
- Abstract base class for model configurations. A model config describes a specific combination of model base, type and
- format, along with other metadata about the model. For example, a Stable Diffusion 1.x main model in checkpoint format
- would have base=sd-1, type=main, format=checkpoint.
-
- To create a new config type, inherit from this class and implement its interface:
- - Define method 'from_model_on_disk' that returns an instance of the class or raises NotAMatch. This method will be
- called during model installation to determine the correct config class for a model.
- - Define fields 'type', 'base' and 'format' as pydantic fields. These should be Literals with a single value. A
- default must be provided for each of these fields.
-
- If multiple combinations of base, type and format need to be supported, create a separate subclass for each.
-
- See MinimalConfigExample in test_model_probe.py for an example implementation.
- """
-
- key: str = Field(
- description="A unique key for this model.",
- default_factory=uuid_string,
- )
- hash: str = Field(
- description="The hash of the model file(s).",
- )
- path: str = Field(
- description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory.",
- )
- file_size: int = Field(
- description="The size of the model in bytes.",
- )
- name: str = Field(
- description="Name of the model.",
- )
- description: str | None = Field(
- description="Model description",
- default=None,
- )
- source: str = Field(
- description="The original source of the model (path, URL or repo_id).",
- )
- source_type: ModelSourceType = Field(
- description="The type of source",
- )
- source_api_response: str | None = Field(
- description="The original API response from the source, as stringified JSON.",
- default=None,
- )
- cover_image: str | None = Field(
- description="Url for image to preview model",
- default=None,
- )
- submodels: dict[SubModelType, SubmodelDefinition] | None = Field(
- description="Loadable submodels in this model",
- default=None,
- )
- usage_info: str | None = Field(
- default=None,
- description="Usage information for this model",
- )
-
- CONFIG_CLASSES: ClassVar[set[Type["AnyModelConfig"]]] = set()
-
- model_config = ConfigDict(
- validate_assignment=True,
- json_schema_serialization_defaults_required=True,
- json_schema_mode_override="serialization",
- )
-
- @classmethod
- def __init_subclass__(cls, **kwargs):
- super().__init_subclass__(**kwargs)
- # Register non-abstract subclasses so we can iterate over them later during model probing.
- if not isabstract(cls):
- cls.CONFIG_CLASSES.add(cls)
-
- @classmethod
- def __pydantic_init_subclass__(cls, **kwargs):
- # Ensure that subclasses define 'base', 'type' and 'format' fields and provide defaults for them. Each subclass
- # is expected to represent a single combination of base, type and format.
- for name in ("type", "base", "format"):
- assert name in cls.model_fields, f"{cls.__name__} must define a '{name}' field"
- assert cls.model_fields[name].default is not PydanticUndefined, (
- f"{cls.__name__} must define a default for the '{name}' field"
- )
-
- @classmethod
- def get_tag(cls) -> Tag:
- """Constructs a pydantic discriminated union tag for this model config class. When a config is deserialized,
- pydantic uses the tag to determine which subclass to instantiate.
-
- The tag is a dot-separated string of the type, format, base and variant (if applicable).
- """
- tag_strings: list[str] = []
- for name in ("type", "format", "base", "variant"):
- if field := cls.model_fields.get(name):
- if field.default is not PydanticUndefined:
- # We expect each of these fields has an Enum for its default; we want the value of the enum.
- tag_strings.append(field.default.value)
- return Tag(".".join(tag_strings))
-
- @staticmethod
- def get_model_discriminator_value(v: Any) -> str:
- """
- Computes the discriminator value for a model config.
- https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
- """
- if isinstance(v, Config_Base):
- # We have an instance of a ModelConfigBase subclass - use its tag directly.
- return v.get_tag().tag
- if isinstance(v, dict):
- # We have a dict - compute the tag from its fields.
- tag_strings: list[str] = []
- if type_ := v.get("type"):
- if isinstance(type_, Enum):
- type_ = type_.value
- tag_strings.append(type_)
-
- if format_ := v.get("format"):
- if isinstance(format_, Enum):
- format_ = format_.value
- tag_strings.append(format_)
-
- if base_ := v.get("base"):
- if isinstance(base_, Enum):
- base_ = base_.value
- tag_strings.append(base_)
-
- # Special case: CLIP Embed models also need the variant to distinguish them.
- if (
- type_ == ModelType.CLIPEmbed.value
- and format_ == ModelFormat.Diffusers.value
- and base_ == BaseModelType.Any.value
- ):
- if variant_value := v.get("variant"):
- if isinstance(variant_value, Enum):
- variant_value = variant_value.value
- tag_strings.append(variant_value)
- else:
- raise ValueError("CLIP Embed model config dict must include a 'variant' field")
-
- return ".".join(tag_strings)
- else:
- raise TypeError("Model config discriminator value must be computed from a dict or ModelConfigBase instance")
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- """Given the model on disk and any overrides, return an instance of this config class.
-
- Implementations should raise NotAMatch if the model does not match this config class."""
- raise NotImplementedError(f"from_model_on_disk not implemented for {cls.__name__}")
-
-
-class Unknown_Config(Config_Base):
- """Model config for unknown models, used as a fallback when we cannot identify a model."""
-
- base: Literal[BaseModelType.Unknown] = Field(default=BaseModelType.Unknown)
- type: Literal[ModelType.Unknown] = Field(default=ModelType.Unknown)
- format: Literal[ModelFormat.Unknown] = Field(default=ModelFormat.Unknown)
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- raise NotAMatch(cls, "unknown model config cannot match any model")
-
-
-class Checkpoint_Config_Base(ABC, BaseModel):
- """Base class for checkpoint-style models."""
-
- config_path: str | None = Field(
- description="Path to the config for this model, if any.",
- default=None,
- )
- converted_at: float | None = Field(
- description="When this model was last converted to diffusers",
- default_factory=time.time,
- )
-
-
-class Diffusers_Config_Base(ABC, BaseModel):
- """Base class for diffusers-style models."""
-
- format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
- repo_variant: Optional[ModelRepoVariant] = Field(ModelRepoVariant.Default)
-
- @classmethod
- def _get_repo_variant_or_raise(cls, mod: ModelOnDisk) -> ModelRepoVariant:
- # get all files ending in .bin or .safetensors
- weight_files = list(mod.path.glob("**/*.safetensors"))
- weight_files.extend(list(mod.path.glob("**/*.bin")))
- for x in weight_files:
- if ".fp16" in x.suffixes:
- return ModelRepoVariant.FP16
- if "openvino_model" in x.name:
- return ModelRepoVariant.OpenVINO
- if "flax_model" in x.name:
- return ModelRepoVariant.Flax
- if x.suffix == ".onnx":
- return ModelRepoVariant.ONNX
- return ModelRepoVariant.Default
-
-
-class T5Encoder_T5Encoder_Config(Config_Base):
- base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
- type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder)
- format: Literal[ModelFormat.T5Encoder] = Field(default=ModelFormat.T5Encoder)
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_dir(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- _validate_class_name(
- cls,
- common_config_paths(mod.path),
- {
- "T5EncoderModel",
- },
- )
-
- cls._validate_has_unquantized_config_file(mod)
-
- return cls(**fields)
-
- @classmethod
- def _validate_has_unquantized_config_file(cls, mod: ModelOnDisk) -> None:
- has_unquantized_config = (mod.path / "text_encoder_2" / "model.safetensors.index.json").exists()
-
- if not has_unquantized_config:
- raise NotAMatch(cls, "missing text_encoder_2/model.safetensors.index.json")
-
-
-class T5Encoder_BnBLLMint8_Config(Config_Base):
- base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
- type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder)
- format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = Field(default=ModelFormat.BnbQuantizedLlmInt8b)
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_dir(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- _validate_class_name(
- cls,
- common_config_paths(mod.path),
- {
- "T5EncoderModel",
- },
- )
-
- cls._validate_filename_looks_like_bnb_quantized(mod)
-
- cls._validate_model_looks_like_bnb_quantized(mod)
-
- return cls(**fields)
-
- @classmethod
- def _validate_filename_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
- filename_looks_like_bnb = any(x for x in mod.weight_files() if "llm_int8" in x.as_posix())
- if not filename_looks_like_bnb:
- raise NotAMatch(cls, "filename does not look like bnb quantized llm_int8")
-
- @classmethod
- def _validate_model_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
- has_scb_key_suffix = has_any_keys_ending_with(mod.load_state_dict(), "SCB")
- if not has_scb_key_suffix:
- raise NotAMatch(cls, "state dict does not look like bnb quantized llm_int8")
-
-
-class LoRA_Config_Base(ABC, BaseModel):
- """Base class for LoRA models."""
-
- type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA)
- trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
- default_settings: Optional[LoraModelDefaultSettings] = Field(
- description="Default settings for this model", default=None
- )
-
-
-def _get_flux_lora_format(mod: ModelOnDisk) -> FluxLoRAFormat | None:
- # TODO(psyche): Moving this import to the function to avoid circular imports. Refactor later.
- from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict
-
- state_dict = mod.load_state_dict(mod.path)
- value = flux_format_from_state_dict(state_dict, mod.metadata())
- return value
-
-
-class LoRA_OMI_Config_Base(LoRA_Config_Base):
- format: Literal[ModelFormat.OMI] = Field(default=ModelFormat.OMI)
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_file(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- cls._validate_looks_like_omi_lora(mod)
-
- cls._validate_base(mod)
-
- return cls(**fields)
-
- @classmethod
- def _validate_base(cls, mod: ModelOnDisk) -> None:
- """Raise `NotAMatch` if the model base does not match this config class."""
- expected_base = cls.model_fields["base"].default
- recognized_base = cls._get_base_or_raise(mod)
- if expected_base is not recognized_base:
- raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
-
- @classmethod
- def _validate_looks_like_omi_lora(cls, mod: ModelOnDisk) -> None:
- """Raise `NotAMatch` if the model metadata does not look like an OMI LoRA."""
- flux_format = _get_flux_lora_format(mod)
- if flux_format in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
- raise NotAMatch(cls, "model looks like ControlLoRA or Diffusers LoRA")
-
- metadata = mod.metadata()
-
- metadata_looks_like_omi_lora = (
- bool(metadata.get("modelspec.sai_model_spec"))
- and metadata.get("ot_branch") == "omi_format"
- and metadata.get("modelspec.architecture", "").split("/")[1].lower() == "lora"
- )
-
- if not metadata_looks_like_omi_lora:
- raise NotAMatch(cls, "metadata does not look like OMI LoRA")
-
- @classmethod
- def _get_base_or_raise(cls, mod: ModelOnDisk) -> Literal[BaseModelType.Flux, BaseModelType.StableDiffusionXL]:
- metadata = mod.metadata()
- architecture = metadata["modelspec.architecture"]
-
- if architecture == stable_diffusion_xl_1_lora:
- return BaseModelType.StableDiffusionXL
- elif architecture == flux_dev_1_lora:
- return BaseModelType.Flux
- else:
- raise NotAMatch(cls, f"unrecognised/unsupported architecture for OMI LoRA: {architecture}")
-
-
-class LoRA_OMI_SDXL_Config(LoRA_OMI_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
-
-
-class LoRA_OMI_FLUX_Config(LoRA_OMI_Config_Base, Config_Base):
- base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
-
-
-class LoRA_LyCORIS_Config_Base(LoRA_Config_Base):
- """Model config for LoRA/Lycoris models."""
-
- type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA)
- format: Literal[ModelFormat.LyCORIS] = Field(default=ModelFormat.LyCORIS)
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_file(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- cls._validate_looks_like_lora(mod)
-
- cls._validate_base(mod)
-
- return cls(**fields)
-
- @classmethod
- def _validate_base(cls, mod: ModelOnDisk) -> None:
- """Raise `NotAMatch` if the model base does not match this config class."""
- expected_base = cls.model_fields["base"].default
- recognized_base = cls._get_base_or_raise(mod)
- if expected_base is not recognized_base:
- raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
-
- @classmethod
- def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
- # First rule out ControlLoRA and Diffusers LoRA
- flux_format = _get_flux_lora_format(mod)
- if flux_format in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
- raise NotAMatch(cls, "model looks like ControlLoRA or Diffusers LoRA")
-
- # Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA.
- # Some main models have these keys, likely due to the creator merging in a LoRA.
- has_key_with_lora_prefix = has_any_keys_starting_with(
- mod.load_state_dict(),
- {
- "lora_te_",
- "lora_unet_",
- "lora_te1_",
- "lora_te2_",
- "lora_transformer_",
- },
- )
-
- has_key_with_lora_suffix = has_any_keys_ending_with(
- mod.load_state_dict(),
- {
- "to_k_lora.up.weight",
- "to_q_lora.down.weight",
- "lora_A.weight",
- "lora_B.weight",
- },
- )
-
- if not has_key_with_lora_prefix and not has_key_with_lora_suffix:
- raise NotAMatch(cls, "model does not match LyCORIS LoRA heuristics")
-
- @classmethod
- def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
- if _get_flux_lora_format(mod):
- return BaseModelType.Flux
-
- state_dict = mod.load_state_dict()
- # If we've gotten here, we assume that the model is a Stable Diffusion model
- token_vector_length = lora_token_vector_length(state_dict)
- if token_vector_length == 768:
- return BaseModelType.StableDiffusion1
- elif token_vector_length == 1024:
- return BaseModelType.StableDiffusion2
- elif token_vector_length == 1280:
- return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
- elif token_vector_length == 2048:
- return BaseModelType.StableDiffusionXL
- else:
- raise NotAMatch(cls, f"unrecognized token vector length {token_vector_length}")
-
-
-class LoRA_LyCORIS_SD1_Config(LoRA_LyCORIS_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
-
-
-class LoRA_LyCORIS_SD2_Config(LoRA_LyCORIS_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
-
-
-class LoRA_LyCORIS_SDXL_Config(LoRA_LyCORIS_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
-
-
-class LoRA_LyCORIS_FLUX_Config(LoRA_LyCORIS_Config_Base, Config_Base):
- base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
-
-
-class ControlAdapter_Config_Base(ABC, BaseModel):
- default_settings: ControlAdapterDefaultSettings | None = Field(None)
-
-
-class ControlLoRA_LyCORIS_FLUX_Config(ControlAdapter_Config_Base, Config_Base):
- """Model config for Control LoRA models."""
-
- base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
- type: Literal[ModelType.ControlLoRa] = Field(default=ModelType.ControlLoRa)
- format: Literal[ModelFormat.LyCORIS] = Field(default=ModelFormat.LyCORIS)
-
- trigger_phrases: set[str] | None = Field(None)
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_file(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- cls._validate_looks_like_control_lora(mod)
-
- return cls(**fields)
-
- @classmethod
- def _validate_looks_like_control_lora(cls, mod: ModelOnDisk) -> None:
- state_dict = mod.load_state_dict()
-
- if not is_state_dict_likely_flux_control(state_dict):
- raise NotAMatch(cls, "model state dict does not look like a Flux Control LoRA")
-
-
-class LoRA_Diffusers_Config_Base(LoRA_Config_Base):
- """Model config for LoRA/Diffusers models."""
-
- # TODO(psyche): Needs base handling. For FLUX, the Diffusers format does not indicate a folder model; it indicates
- # the weights format. FLUX Diffusers LoRAs are single files.
-
- format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_dir(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- cls._validate_base(mod)
-
- return cls(**fields)
-
- @classmethod
- def _validate_base(cls, mod: ModelOnDisk) -> None:
- """Raise `NotAMatch` if the model base does not match this config class."""
- expected_base = cls.model_fields["base"].default
- recognized_base = cls._get_base_or_raise(mod)
- if expected_base is not recognized_base:
- raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
-
- @classmethod
- def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
- if _get_flux_lora_format(mod):
- return BaseModelType.Flux
-
- # If we've gotten here, we assume that the LoRA is a Stable Diffusion LoRA
- path_to_weight_file = cls._get_weight_file_or_raise(mod)
- state_dict = mod.load_state_dict(path_to_weight_file)
- token_vector_length = lora_token_vector_length(state_dict)
-
- match token_vector_length:
- case 768:
- return BaseModelType.StableDiffusion1
- case 1024:
- return BaseModelType.StableDiffusion2
- case 1280:
- return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
- case 2048:
- return BaseModelType.StableDiffusionXL
- case _:
- raise NotAMatch(cls, f"unrecognized token vector length {token_vector_length}")
-
- @classmethod
- def _get_weight_file_or_raise(cls, mod: ModelOnDisk) -> Path:
- suffixes = ["bin", "safetensors"]
- weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
- for wf in weight_files:
- if wf.exists():
- return wf
- raise NotAMatch(cls, "missing pytorch_lora_weights.bin or pytorch_lora_weights.safetensors")
-
-
-class LoRA_Diffusers_SD1_Config(LoRA_Diffusers_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
-
-
-class LoRA_Diffusers_SD2_Config(LoRA_Diffusers_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
-
-
-class LoRA_Diffusers_SDXL_Config(LoRA_Diffusers_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
-
-
-class LoRA_Diffusers_FLUX_Config(LoRA_Diffusers_Config_Base, Config_Base):
- base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
-
-
-class VAE_Checkpoint_Config_Base(Checkpoint_Config_Base):
- """Model config for standalone VAE models."""
-
- type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
- format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
-
- REGEX_TO_BASE: ClassVar[dict[str, BaseModelType]] = {
- r"xl": BaseModelType.StableDiffusionXL,
- r"sd2": BaseModelType.StableDiffusion2,
- r"vae": BaseModelType.StableDiffusion1,
- r"FLUX.1-schnell_ae": BaseModelType.Flux,
- }
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_file(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- cls._validate_looks_like_vae(mod)
-
- cls._validate_base(mod)
-
- return cls(**fields)
-
- @classmethod
- def _validate_base(cls, mod: ModelOnDisk) -> None:
- """Raise `NotAMatch` if the model base does not match this config class."""
- expected_base = cls.model_fields["base"].default
- recognized_base = cls._get_base_or_raise(mod)
- if expected_base is not recognized_base:
- raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
-
- @classmethod
- def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None:
- if not has_any_keys_starting_with(
- mod.load_state_dict(),
- {
- "encoder.conv_in",
- "decoder.conv_in",
- },
- ):
- raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics")
-
- @classmethod
- def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
- # Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name
- for regexp, base in cls.REGEX_TO_BASE.items():
- if re.search(regexp, mod.path.name, re.IGNORECASE):
- return base
-
- raise NotAMatch(cls, "cannot determine base type")
-
-
-class VAE_Checkpoint_SD1_Config(VAE_Checkpoint_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
-
-
-class VAE_Checkpoint_SD2_Config(VAE_Checkpoint_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
-
-
-class VAE_Checkpoint_SDXL_Config(VAE_Checkpoint_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
-
-
-class VAE_Checkpoint_FLUX_Config(VAE_Checkpoint_Config_Base, Config_Base):
- base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
-
-
-class VAE_Diffusers_Config_Base(Diffusers_Config_Base):
- """Model config for standalone VAE models (diffusers version)."""
-
- type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
- format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_dir(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- _validate_class_name(
- cls,
- common_config_paths(mod.path),
- {
- "AutoencoderKL",
- "AutoencoderTiny",
- },
- )
-
- cls._validate_base(mod)
-
- return cls(**fields)
-
- @classmethod
- def _validate_base(cls, mod: ModelOnDisk) -> None:
- """Raise `NotAMatch` if the model base does not match this config class."""
- expected_base = cls.model_fields["base"].default
- recognized_base = cls._get_base_or_raise(mod)
- if expected_base is not recognized_base:
- raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
-
- @classmethod
- def _config_looks_like_sdxl(cls, config: dict[str, Any]) -> bool:
- # Heuristic: These config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
- return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
-
- @classmethod
- def _name_looks_like_sdxl(cls, mod: ModelOnDisk) -> bool:
- # Heuristic: SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
- # by a factor of 8), so we can't necessarily tell them apart by config hyperparameters. Best
- # we can do is guess based on name.
- return bool(re.search(r"xl\b", cls._guess_name(mod), re.IGNORECASE))
-
- @classmethod
- def _guess_name(cls, mod: ModelOnDisk) -> str:
- name = mod.path.name
- if name == "vae":
- name = mod.path.parent.name
- return name
-
- @classmethod
- def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
- config = _get_config_or_raise(cls, common_config_paths(mod.path))
- if cls._config_looks_like_sdxl(config):
- return BaseModelType.StableDiffusionXL
- elif cls._name_looks_like_sdxl(mod):
- return BaseModelType.StableDiffusionXL
- else:
- # TODO(psyche): Figure out how to positively identify SD1 here, and raise if we can't. Until then, YOLO.
- return BaseModelType.StableDiffusion1
-
-
-class VAE_Diffusers_SD1_Config(VAE_Diffusers_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
-
-
-class VAE_Diffusers_SDXL_Config(VAE_Diffusers_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
-
-
-class ControlNet_Diffusers_Config_Base(Diffusers_Config_Base, ControlAdapter_Config_Base):
- """Model config for ControlNet models (diffusers version)."""
-
- type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet)
- format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_dir(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- _validate_class_name(
- cls,
- common_config_paths(mod.path),
- {
- "ControlNetModel",
- "FluxControlNetModel",
- },
- )
-
- cls._validate_base(mod)
-
- return cls(**fields)
-
- @classmethod
- def _validate_base(cls, mod: ModelOnDisk) -> None:
- """Raise `NotAMatch` if the model base does not match this config class."""
- expected_base = cls.model_fields["base"].default
- recognized_base = cls._get_base_or_raise(mod)
- if expected_base is not recognized_base:
- raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
-
- @classmethod
- def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
- config = _get_config_or_raise(cls, common_config_paths(mod.path))
-
- if config.get("_class_name") == "FluxControlNetModel":
- return BaseModelType.Flux
-
- dimension = config.get("cross_attention_dim")
-
- match dimension:
- case 768:
- return BaseModelType.StableDiffusion1
- case 1024:
- # No obvious way to distinguish between sd2-base and sd2-768, but we don't really differentiate them
- # anyway.
- return BaseModelType.StableDiffusion2
- case 2048:
- return BaseModelType.StableDiffusionXL
- case _:
- raise NotAMatch(cls, f"unrecognized cross_attention_dim {dimension}")
-
-
-class ControlNet_Diffusers_SD1_Config(ControlNet_Diffusers_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
-
-
-class ControlNet_Diffusers_SD2_Config(ControlNet_Diffusers_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
-
-
-class ControlNet_Diffusers_SDXL_Config(ControlNet_Diffusers_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
-
-
-class ControlNet_Diffusers_FLUX_Config(ControlNet_Diffusers_Config_Base, Config_Base):
- base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
-
-
-class ControlNet_Checkpoint_Config_Base(Checkpoint_Config_Base, ControlAdapter_Config_Base):
- """Model config for ControlNet models (diffusers version)."""
-
- type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet)
- format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_file(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- cls._validate_looks_like_controlnet(mod)
-
- cls._validate_base(mod)
-
- return cls(**fields)
-
- @classmethod
- def _validate_base(cls, mod: ModelOnDisk) -> None:
- """Raise `NotAMatch` if the model base does not match this config class."""
- expected_base = cls.model_fields["base"].default
- recognized_base = cls._get_base_or_raise(mod)
- if expected_base is not recognized_base:
- raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
-
- @classmethod
- def _validate_looks_like_controlnet(cls, mod: ModelOnDisk) -> None:
- if has_any_keys_starting_with(
- mod.load_state_dict(),
- {
- "controlnet",
- "control_model",
- "input_blocks",
- # XLabs FLUX ControlNet models have keys starting with "controlnet_blocks."
- # For example: https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors
- # TODO(ryand): This is very fragile. XLabs FLUX ControlNet models also contain keys starting with
- # "double_blocks.", which we check for above. But, I'm afraid to modify this logic because it is so
- # delicate.
- "controlnet_blocks",
- },
- ):
- raise NotAMatch(cls, "state dict does not look like a ControlNet checkpoint")
-
- @classmethod
- def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
- state_dict = mod.load_state_dict()
-
- if is_state_dict_xlabs_controlnet(state_dict) or is_state_dict_instantx_controlnet(state_dict):
- # TODO(ryand): Should I distinguish between XLabs, InstantX and other ControlNet models by implementing
- # get_format()?
- return BaseModelType.Flux
-
- for key in (
- "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
- "controlnet_mid_block.bias",
- "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
- "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight",
- ):
- if key not in state_dict:
- continue
- width = state_dict[key].shape[-1]
- match width:
- case 768:
- return BaseModelType.StableDiffusion1
- case 1024:
- return BaseModelType.StableDiffusion2
- case 2048:
- return BaseModelType.StableDiffusionXL
- case 1280:
- return BaseModelType.StableDiffusionXL
- case _:
- pass
-
- raise NotAMatch(cls, "unable to determine base type from state dict")
-
-
-class ControlNet_Checkpoint_SD1_Config(ControlNet_Checkpoint_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
-
-
-class ControlNet_Checkpoint_SD2_Config(ControlNet_Checkpoint_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
-
-
-class ControlNet_Checkpoint_SDXL_Config(ControlNet_Checkpoint_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
-
-
-class ControlNet_Checkpoint_FLUX_Config(ControlNet_Checkpoint_Config_Base, Config_Base):
- base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
-
-
-class TI_Config_Base(ABC, BaseModel):
- type: Literal[ModelType.TextualInversion] = Field(default=ModelType.TextualInversion)
-
- @classmethod
- def _validate_base(cls, mod: ModelOnDisk, path: Path | None = None) -> None:
- """Raise `NotAMatch` if the model base does not match this config class."""
- expected_base = cls.model_fields["base"].default
- recognized_base = cls._get_base_or_raise(mod, path)
- if expected_base is not recognized_base:
- raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
-
- @classmethod
- def _file_looks_like_embedding(cls, mod: ModelOnDisk, path: Path | None = None) -> bool:
- try:
- p = path or mod.path
-
- if not p.exists():
- return False
-
- if p.is_dir():
- return False
-
- if p.name in [f"learned_embeds.{s}" for s in mod.weight_files()]:
- return True
-
- state_dict = mod.load_state_dict(p)
-
- # Heuristic: textual inversion embeddings have these keys
- if any(key in {"string_to_param", "emb_params", "clip_g"} for key in state_dict.keys()):
- return True
-
- # Heuristic: small state dict with all tensor values
- if (len(state_dict)) < 10 and all(isinstance(v, torch.Tensor) for v in state_dict.values()):
- return True
-
- return False
- except Exception:
- return False
-
- @classmethod
- def _get_base_or_raise(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType:
- p = path or mod.path
-
- try:
- state_dict = mod.load_state_dict(p)
- except Exception as e:
- raise NotAMatch(cls, f"unable to load state dict from {p}: {e}") from e
-
- try:
- if "string_to_token" in state_dict:
- token_dim = list(state_dict["string_to_param"].values())[0].shape[-1]
- elif "emb_params" in state_dict:
- token_dim = state_dict["emb_params"].shape[-1]
- elif "clip_g" in state_dict:
- token_dim = state_dict["clip_g"].shape[-1]
- else:
- token_dim = list(state_dict.values())[0].shape[0]
- except Exception as e:
- raise NotAMatch(cls, f"unable to determine token dimension from state dict in {p}: {e}") from e
-
- match token_dim:
- case 768:
- return BaseModelType.StableDiffusion1
- case 1024:
- return BaseModelType.StableDiffusion2
- case 1280:
- return BaseModelType.StableDiffusionXL
- case _:
- raise NotAMatch(cls, f"unrecognized token dimension {token_dim}")
-
-
-class TI_File_Config_Base(TI_Config_Base):
- """Model config for textual inversion embeddings."""
-
- format: Literal[ModelFormat.EmbeddingFile] = Field(default=ModelFormat.EmbeddingFile)
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_file(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- if not cls._file_looks_like_embedding(mod):
- raise NotAMatch(cls, "model does not look like a textual inversion embedding file")
-
- cls._validate_base(mod)
-
- return cls(**fields)
-
-
-class TI_File_SD1_Config(TI_File_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
-
-
-class TI_File_SD2_Config(TI_File_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
-
-
-class TI_File_SDXL_Config(TI_File_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
-
-
-class TI_Folder_Config_Base(TI_Config_Base):
- """Model config for textual inversion embeddings."""
-
- format: Literal[ModelFormat.EmbeddingFolder] = Field(default=ModelFormat.EmbeddingFolder)
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_dir(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- for p in mod.weight_files():
- if cls._file_looks_like_embedding(mod, p):
- cls._validate_base(mod, p)
- return cls(**fields)
-
- raise NotAMatch(cls, "model does not look like a textual inversion embedding folder")
-
-
-class TI_Folder_SD1_Config(TI_Folder_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
-
-
-class TI_Folder_SD2_Config(TI_Folder_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
-
-
-class TI_Folder_SDXL_Config(TI_Folder_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
-
-
-class Main_Config_Base(ABC, BaseModel):
- type: Literal[ModelType.Main] = Field(default=ModelType.Main)
- trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
- default_settings: Optional[MainModelDefaultSettings] = Field(
- description="Default settings for this model", default=None
- )
-
-
-def _has_bnb_nf4_keys(state_dict: dict[str | int, Any]) -> bool:
- bnb_nf4_keys = {
- "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4",
- "model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4",
- }
- return any(key in state_dict for key in bnb_nf4_keys)
-
-
-def _has_ggml_tensors(state_dict: dict[str | int, Any]) -> bool:
- return any(isinstance(v, GGMLTensor) for v in state_dict.values())
-
-
-def _has_main_keys(state_dict: dict[str | int, Any]) -> bool:
- for key in state_dict.keys():
- if isinstance(key, int):
- continue
- elif key.startswith(
- (
- "cond_stage_model.",
- "first_stage_model.",
- "model.diffusion_model.",
- # Some FLUX checkpoint files contain transformer keys prefixed with "model.diffusion_model".
- # This prefix is typically used to distinguish between multiple models bundled in a single file.
- "model.diffusion_model.double_blocks.",
- )
- ):
- return True
- elif key.startswith("double_blocks.") and "ip_adapter" not in key:
- # FLUX models in the official BFL format contain keys with the "double_blocks." prefix, but we must be
- # careful to avoid false positives on XLabs FLUX IP-Adapter models.
- return True
- return False
-
-
-class Main_Checkpoint_Config_Base(Checkpoint_Config_Base, Main_Config_Base):
- """Model config for main checkpoint models."""
-
- format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
-
- prediction_type: SchedulerPredictionType = Field()
- variant: ModelVariantType = Field()
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_file(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- cls._validate_looks_like_main_model(mod)
-
- cls._validate_base(mod)
-
- prediction_type = fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod)
-
- variant = fields.get("variant") or cls._get_variant_or_raise(mod)
-
- return cls(**fields, prediction_type=prediction_type, variant=variant)
-
- @classmethod
- def _validate_base(cls, mod: ModelOnDisk) -> None:
- """Raise `NotAMatch` if the model base does not match this config class."""
- expected_base = cls.model_fields["base"].default
- recognized_base = cls._get_base_or_raise(mod)
- if expected_base is not recognized_base:
- raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
-
- @classmethod
- def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
- state_dict = mod.load_state_dict()
-
- key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
- if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
- return BaseModelType.StableDiffusion1
- if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
- return BaseModelType.StableDiffusion2
-
- key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
- if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
- return BaseModelType.StableDiffusionXL
- elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
- return BaseModelType.StableDiffusionXLRefiner
-
- raise NotAMatch(cls, "unable to determine base type from state dict")
-
- @classmethod
- def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk) -> SchedulerPredictionType:
- base = cls.model_fields["base"].default
-
- if base is BaseModelType.StableDiffusion2:
- state_dict = mod.load_state_dict()
- key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
- if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
- if "global_step" in state_dict:
- if state_dict["global_step"] == 220000:
- return SchedulerPredictionType.Epsilon
- elif state_dict["global_step"] == 110000:
- return SchedulerPredictionType.VPrediction
- return SchedulerPredictionType.VPrediction
- else:
- return SchedulerPredictionType.Epsilon
-
- @classmethod
- def _get_variant_or_raise(cls, mod: ModelOnDisk) -> ModelVariantType:
- base = cls.model_fields["base"].default
-
- state_dict = mod.load_state_dict()
- key_name = "model.diffusion_model.input_blocks.0.0.weight"
-
- if key_name not in state_dict:
- raise NotAMatch(cls, "unable to determine model variant from state dict")
-
- in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
-
- match in_channels:
- case 4:
- return ModelVariantType.Normal
- case 5:
- # Only SD2 has a depth variant
- assert base is BaseModelType.StableDiffusion2, f"unexpected unet in_channels 5 for base '{base}'"
- return ModelVariantType.Depth
- case 9:
- return ModelVariantType.Inpaint
- case _:
- raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'")
-
- @classmethod
- def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
- has_main_model_keys = _has_main_keys(mod.load_state_dict())
- if not has_main_model_keys:
- raise NotAMatch(cls, "state dict does not look like a main model")
-
-
-class Main_Checkpoint_SD1_Config(Main_Checkpoint_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
-
-
-class Main_Checkpoint_SD2_Config(Main_Checkpoint_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
-
-
-class Main_Checkpoint_SDXL_Config(Main_Checkpoint_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
-
-
-class Main_Checkpoint_SDXLRefiner_Config(Main_Checkpoint_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(default=BaseModelType.StableDiffusionXLRefiner)
-
-
-def _get_flux_variant(state_dict: dict[str | int, Any]) -> FluxVariantType | None:
- # FLUX Model variant types are distinguished by input channels and the presence of certain keys.
-
- # Input channels are derived from the shape of either "img_in.weight" or "model.diffusion_model.img_in.weight".
- #
- # Known models that use the latter key:
- # - https://civitai.com/models/885098?modelVersionId=990775
- # - https://civitai.com/models/1018060?modelVersionId=1596255
- # - https://civitai.com/models/978314/ultrareal-fine-tune?modelVersionId=1413133
- #
- # Input channels for known FLUX models:
- # - Unquantized Dev and Schnell have in_channels=64
- # - BNB-NF4 Dev and Schnell have in_channels=1
- # - FLUX Fill has in_channels=384
- # - Unsure of quantized FLUX Fill models
- # - Unsure of GGUF-quantized models
-
- in_channels = None
- for key in {"img_in.weight", "model.diffusion_model.img_in.weight"}:
- if key in state_dict:
- in_channels = state_dict[key].shape[1]
- break
-
- if in_channels is None:
- # TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant,
- # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX
- # model, we should figure out a good fallback value.
- return None
-
- # Because FLUX Dev and Schnell models have the same in_channels, we need to check for the presence of
- # certain keys to distinguish between them.
- is_flux_dev = (
- "guidance_in.out_layer.weight" in state_dict
- or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict
- )
-
- if is_flux_dev and in_channels == 384:
- return FluxVariantType.DevFill
- elif is_flux_dev:
- return FluxVariantType.Dev
- else:
- # Must be a Schnell model...?
- return FluxVariantType.Schnell
-
-
-class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
- """Model config for main checkpoint models."""
-
- format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
- base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
-
- variant: FluxVariantType = Field()
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_file(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- cls._validate_looks_like_main_model(mod)
-
- cls._validate_is_flux(mod)
-
- cls._validate_does_not_look_like_bnb_quantized(mod)
-
- cls._validate_does_not_look_like_gguf_quantized(mod)
-
- variant = fields.get("variant") or cls._get_variant_or_raise(mod)
-
- return cls(**fields, variant=variant)
-
- @classmethod
- def _validate_is_flux(cls, mod: ModelOnDisk) -> None:
- if not has_any_keys(
- mod.load_state_dict(),
- {
- "double_blocks.0.img_attn.norm.key_norm.scale",
- "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
- },
- ):
- raise NotAMatch(cls, "state dict does not look like a FLUX checkpoint")
-
- @classmethod
- def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
- # FLUX Model variant types are distinguished by input channels and the presence of certain keys.
- state_dict = mod.load_state_dict()
- variant = _get_flux_variant(state_dict)
-
- if variant is None:
- # TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant,
- # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX
- # model, we should figure out a good fallback value.
- raise NotAMatch(cls, "unable to determine model variant from state dict")
-
- return variant
-
- @classmethod
- def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
- has_main_model_keys = _has_main_keys(mod.load_state_dict())
- if not has_main_model_keys:
- raise NotAMatch(cls, "state dict does not look like a main model")
-
- @classmethod
- def _validate_does_not_look_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
- has_bnb_nf4_keys = _has_bnb_nf4_keys(mod.load_state_dict())
- if has_bnb_nf4_keys:
- raise NotAMatch(cls, "state dict looks like bnb quantized nf4")
-
- @classmethod
- def _validate_does_not_look_like_gguf_quantized(cls, mod: ModelOnDisk):
- has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
- if has_ggml_tensors:
- raise NotAMatch(cls, "state dict looks like GGUF quantized")
-
-
-class Main_BnBNF4_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
- """Model config for main checkpoint models."""
-
- base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
- format: Literal[ModelFormat.BnbQuantizednf4b] = Field(default=ModelFormat.BnbQuantizednf4b)
-
- variant: FluxVariantType = Field()
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_file(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- cls._validate_looks_like_main_model(mod)
-
- cls._validate_model_looks_like_bnb_quantized(mod)
-
- variant = fields.get("variant") or cls._get_variant_or_raise(mod)
-
- return cls(**fields, variant=variant)
-
- @classmethod
- def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
- # FLUX Model variant types are distinguished by input channels and the presence of certain keys.
- state_dict = mod.load_state_dict()
- variant = _get_flux_variant(state_dict)
-
- if variant is None:
- # TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant,
- # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX
- # model, we should figure out a good fallback value.
- raise NotAMatch(cls, "unable to determine model variant from state dict")
-
- return variant
-
- @classmethod
- def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
- has_main_model_keys = _has_main_keys(mod.load_state_dict())
- if not has_main_model_keys:
- raise NotAMatch(cls, "state dict does not look like a main model")
-
- @classmethod
- def _validate_model_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
- has_bnb_nf4_keys = _has_bnb_nf4_keys(mod.load_state_dict())
- if not has_bnb_nf4_keys:
- raise NotAMatch(cls, "state dict does not look like bnb quantized nf4")
-
-
-class Main_GGUF_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
- """Model config for main checkpoint models."""
-
- base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
- format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
-
- variant: FluxVariantType = Field()
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_file(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- cls._validate_looks_like_main_model(mod)
-
- cls._validate_looks_like_gguf_quantized(mod)
-
- variant = fields.get("variant") or cls._get_variant_or_raise(mod)
-
- return cls(**fields, variant=variant)
-
- @classmethod
- def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
- # FLUX Model variant types are distinguished by input channels and the presence of certain keys.
- state_dict = mod.load_state_dict()
- variant = _get_flux_variant(state_dict)
-
- if variant is None:
- # TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant,
- # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX
- # model, we should figure out a good fallback value.
- raise NotAMatch(cls, "unable to determine model variant from state dict")
-
- return variant
-
- @classmethod
- def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
- has_main_model_keys = _has_main_keys(mod.load_state_dict())
- if not has_main_model_keys:
- raise NotAMatch(cls, "state dict does not look like a main model")
-
- @classmethod
- def _validate_looks_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
- has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
- if not has_ggml_tensors:
- raise NotAMatch(cls, "state dict does not look like GGUF quantized")
-
-
-class Main_Diffusers_Config_Base(Diffusers_Config_Base, Main_Config_Base):
- prediction_type: SchedulerPredictionType = Field()
- variant: ModelVariantType = Field()
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_dir(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- _validate_class_name(
- cls,
- common_config_paths(mod.path),
- {
- # SD 1.x and 2.x
- "StableDiffusionPipeline",
- "StableDiffusionInpaintPipeline",
- # SDXL
- "StableDiffusionXLPipeline",
- "StableDiffusionXLInpaintPipeline",
- # SDXL Refiner
- "StableDiffusionXLImg2ImgPipeline",
- # TODO(psyche): Do we actually support LCM models? I don't see using this class anywhere in the codebase.
- "LatentConsistencyModelPipeline",
- },
- )
-
- cls._validate_base(mod)
-
- variant = fields.get("variant") or cls._get_variant_or_raise(mod)
-
- prediction_type = fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod)
-
- repo_variant = fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
-
- return cls(
- **fields,
- variant=variant,
- prediction_type=prediction_type,
- repo_variant=repo_variant,
- )
-
- @classmethod
- def _validate_base(cls, mod: ModelOnDisk) -> None:
- """Raise `NotAMatch` if the model base does not match this config class."""
- expected_base = cls.model_fields["base"].default
- recognized_base = cls._get_base_or_raise(mod)
- if expected_base is not recognized_base:
- raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
-
- @classmethod
- def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
- # Handle pipelines with a UNet (i.e SD 1.x, SD2.x, SDXL).
- unet_config_path = mod.path / "unet" / "config.json"
- if unet_config_path.exists():
- with open(unet_config_path) as file:
- unet_conf = json.load(file)
- cross_attention_dim = unet_conf.get("cross_attention_dim")
- match cross_attention_dim:
- case 768:
- return BaseModelType.StableDiffusion1
- case 1024:
- return BaseModelType.StableDiffusion2
- case 1280:
- return BaseModelType.StableDiffusionXLRefiner
- case 2048:
- return BaseModelType.StableDiffusionXL
- case _:
- raise NotAMatch(cls, f"unrecognized cross_attention_dim {cross_attention_dim}")
-
- raise NotAMatch(cls, "unable to determine base type")
-
- @classmethod
- def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk) -> SchedulerPredictionType:
- scheduler_conf = _get_config_or_raise(cls, mod.path / "scheduler" / "scheduler_config.json")
-
- # TODO(psyche): Is epsilon the right default or should we raise if it's not present?
- prediction_type = scheduler_conf.get("prediction_type", "epsilon")
-
- match prediction_type:
- case "v_prediction":
- return SchedulerPredictionType.VPrediction
- case "epsilon":
- return SchedulerPredictionType.Epsilon
- case _:
- raise NotAMatch(cls, f"unrecognized scheduler prediction_type {prediction_type}")
-
- @classmethod
- def _get_variant_or_raise(cls, mod: ModelOnDisk) -> ModelVariantType:
- base = cls.model_fields["base"].default
- unet_config = _get_config_or_raise(cls, mod.path / "unet" / "config.json")
- in_channels = unet_config.get("in_channels")
-
- match in_channels:
- case 4:
- return ModelVariantType.Normal
- case 5:
- # Only SD2 has a depth variant
- assert base is BaseModelType.StableDiffusion2, f"unexpected unet in_channels 5 for base '{base}'"
- return ModelVariantType.Depth
- case 9:
- return ModelVariantType.Inpaint
- case _:
- raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'")
-
-
-class Main_Diffusers_SD1_Config(Main_Diffusers_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusion1] = Field(BaseModelType.StableDiffusion1)
-
-
-class Main_Diffusers_SD2_Config(Main_Diffusers_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusion2] = Field(BaseModelType.StableDiffusion2)
-
-
-class Main_Diffusers_SDXL_Config(Main_Diffusers_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusionXL] = Field(BaseModelType.StableDiffusionXL)
-
-
-class Main_Diffusers_SDXLRefiner_Config(Main_Diffusers_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(BaseModelType.StableDiffusionXLRefiner)
-
-
-class Main_Diffusers_SD3_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusion3] = Field(BaseModelType.StableDiffusion3)
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_dir(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- # This check implies the base type - no further validation needed.
- _validate_class_name(
- cls,
- common_config_paths(mod.path),
- {
- "StableDiffusion3Pipeline",
- "SD3Transformer2DModel",
- },
- )
-
- submodels = fields.get("submodels") or cls._get_submodels_or_raise(mod)
-
- repo_variant = fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
-
- return cls(
- **fields,
- submodels=submodels,
- repo_variant=repo_variant,
- )
-
- @classmethod
- def _get_submodels_or_raise(cls, mod: ModelOnDisk) -> dict[SubModelType, SubmodelDefinition]:
- # Example: https://huggingface.co/stabilityai/stable-diffusion-3.5-medium/blob/main/model_index.json
- config = _get_config_or_raise(cls, common_config_paths(mod.path))
-
- submodels: dict[SubModelType, SubmodelDefinition] = {}
-
- for key, value in config.items():
- # Anything that starts with an underscore is top-level metadata, not a submodel
- if key.startswith("_") or not (isinstance(value, list) and len(value) == 2):
- continue
- # The key is something like "transformer" and is a submodel - it will be in a dir of the same name.
- # The value value is something like ["diffusers", "SD3Transformer2DModel"]
- _library_name, class_name = value
-
- match class_name:
- case "CLIPTextModelWithProjection":
- model_type = ModelType.CLIPEmbed
- path_or_prefix = (mod.path / key).resolve().as_posix()
-
- # We need to read the config to determine the variant of the CLIP model.
- clip_embed_config = _get_config_or_raise(
- cls, {mod.path / key / "config.json", mod.path / key / "model_index.json"}
- )
- variant = _get_clip_variant_type_from_config(clip_embed_config)
- submodels[SubModelType(key)] = SubmodelDefinition(
- path_or_prefix=path_or_prefix,
- model_type=model_type,
- variant=variant,
- )
- case "SD3Transformer2DModel":
- model_type = ModelType.Main
- path_or_prefix = (mod.path / key).resolve().as_posix()
- variant = None
- submodels[SubModelType(key)] = SubmodelDefinition(
- path_or_prefix=path_or_prefix,
- model_type=model_type,
- variant=variant,
- )
- case _:
- pass
-
- return submodels
-
-
-class Main_Diffusers_CogView4_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
- base: Literal[BaseModelType.CogView4] = Field(BaseModelType.CogView4)
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_dir(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- # This check implies the base type - no further validation needed.
- _validate_class_name(
- cls,
- common_config_paths(mod.path),
- {
- "CogView4Pipeline",
- },
- )
-
- repo_variant = fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
-
- return cls(
- **fields,
- repo_variant=repo_variant,
- )
-
-
-class IPAdapter_Config_Base(ABC, BaseModel):
- type: Literal[ModelType.IPAdapter] = Field(default=ModelType.IPAdapter)
-
-
-class IPAdapter_InvokeAI_Config_Base(IPAdapter_Config_Base):
- """Model config for IP Adapter diffusers format models."""
-
- format: Literal[ModelFormat.InvokeAI] = Field(default=ModelFormat.InvokeAI)
-
- # TODO(ryand): Should we deprecate this field? From what I can tell, it hasn't been probed correctly for a long
- # time. Need to go through the history to make sure I'm understanding this fully.
- image_encoder_model_id: str = Field()
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_dir(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- cls._validate_has_weights_file(mod)
-
- cls._validate_has_image_encoder_metadata_file(mod)
-
- cls._validate_base(mod)
-
- return cls(**fields)
-
- @classmethod
- def _validate_base(cls, mod: ModelOnDisk) -> None:
- """Raise `NotAMatch` if the model base does not match this config class."""
- expected_base = cls.model_fields["base"].default
- recognized_base = cls._get_base_or_raise(mod)
- if expected_base is not recognized_base:
- raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
-
- @classmethod
- def _validate_has_weights_file(cls, mod: ModelOnDisk) -> None:
- weights_file = mod.path / "ip_adapter.bin"
- if not weights_file.exists():
- raise NotAMatch(cls, "missing ip_adapter.bin weights file")
-
- @classmethod
- def _validate_has_image_encoder_metadata_file(cls, mod: ModelOnDisk) -> None:
- image_encoder_metadata_file = mod.path / "image_encoder.txt"
- if not image_encoder_metadata_file.exists():
- raise NotAMatch(cls, "missing image_encoder.txt metadata file")
-
- @classmethod
- def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
- state_dict = mod.load_state_dict()
-
- try:
- cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
- except Exception as e:
- raise NotAMatch(cls, f"unable to determine cross attention dimension: {e}") from e
-
- match cross_attention_dim:
- case 1280:
- return BaseModelType.StableDiffusionXL
- case 768:
- return BaseModelType.StableDiffusion1
- case 1024:
- return BaseModelType.StableDiffusion2
- case _:
- raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}")
-
-
-class IPAdapter_InvokeAI_SD1_Config(IPAdapter_InvokeAI_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
-
-
-class IPAdapter_InvokeAI_SD2_Config(IPAdapter_InvokeAI_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
-
-
-class IPAdapter_InvokeAI_SDXL_Config(IPAdapter_InvokeAI_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
-
-
-class IPAdapter_Checkpoint_Config_Base(IPAdapter_Config_Base):
- """Model config for IP Adapter checkpoint format models."""
-
- format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_file(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- cls._validate_looks_like_ip_adapter(mod)
-
- cls._validate_base(mod)
-
- return cls(**fields)
-
- @classmethod
- def _validate_base(cls, mod: ModelOnDisk) -> None:
- """Raise `NotAMatch` if the model base does not match this config class."""
- expected_base = cls.model_fields["base"].default
- recognized_base = cls._get_base_or_raise(mod)
- if expected_base is not recognized_base:
- raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
-
- @classmethod
- def _validate_looks_like_ip_adapter(cls, mod: ModelOnDisk) -> None:
- if not has_any_keys_starting_with(
- mod.load_state_dict(),
- {
- "image_proj.",
- "ip_adapter.",
- # XLabs FLUX IP-Adapter models have keys startinh with "ip_adapter_proj_model.".
- "ip_adapter_proj_model.",
- },
- ):
- raise NotAMatch(cls, "model does not match Checkpoint IP Adapter heuristics")
-
- @classmethod
- def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
- state_dict = mod.load_state_dict()
-
- if is_state_dict_xlabs_ip_adapter(state_dict):
- return BaseModelType.Flux
-
- try:
- cross_attention_dim = state_dict["ip_adapter.1.to_k_ip.weight"].shape[-1]
- except Exception as e:
- raise NotAMatch(cls, f"unable to determine cross attention dimension: {e}") from e
-
- match cross_attention_dim:
- case 1280:
- return BaseModelType.StableDiffusionXL
- case 768:
- return BaseModelType.StableDiffusion1
- case 1024:
- return BaseModelType.StableDiffusion2
- case _:
- raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}")
-
-
-class IPAdapter_Checkpoint_SD1_Config(IPAdapter_Checkpoint_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
-
-
-class IPAdapter_Checkpoint_SD2_Config(IPAdapter_Checkpoint_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
-
-
-class IPAdapter_Checkpoint_SDXL_Config(IPAdapter_Checkpoint_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
-
-
-class IPAdapter_Checkpoint_FLUX_Config(IPAdapter_Checkpoint_Config_Base, Config_Base):
- base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
-
-
-def _get_clip_variant_type_from_config(config: dict[str, Any]) -> ClipVariantType | None:
- try:
- hidden_size = config.get("hidden_size")
- match hidden_size:
- case 1280:
- return ClipVariantType.G
- case 768:
- return ClipVariantType.L
- case _:
- return None
- except Exception:
- return None
-
-
-class CLIPEmbed_Diffusers_Config_Base(Diffusers_Config_Base):
- base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
- type: Literal[ModelType.CLIPEmbed] = Field(default=ModelType.CLIPEmbed)
- format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_dir(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- _validate_class_name(
- cls,
- {
- mod.path / "config.json",
- mod.path / "text_encoder" / "config.json",
- },
- {
- "CLIPModel",
- "CLIPTextModel",
- "CLIPTextModelWithProjection",
- },
- )
-
- cls._validate_variant(mod)
-
- return cls(**fields)
-
- @classmethod
- def _validate_variant(cls, mod: ModelOnDisk) -> None:
- """Raise `NotAMatch` if the model variant does not match this config class."""
- expected_variant = cls.model_fields["variant"].default
- config = _get_config_or_raise(
- cls,
- {
- mod.path / "config.json",
- mod.path / "text_encoder" / "config.json",
- },
- )
- recognized_variant = _get_clip_variant_type_from_config(config)
-
- if recognized_variant is None:
- raise NotAMatch(cls, "unable to determine CLIP variant from config")
-
- if expected_variant is not recognized_variant:
- raise NotAMatch(cls, f"variant is {recognized_variant}, not {expected_variant}")
-
-
-class CLIPEmbed_Diffusers_G_Config(CLIPEmbed_Diffusers_Config_Base, Config_Base):
- variant: Literal[ClipVariantType.G] = Field(default=ClipVariantType.G)
-
-
-class CLIPEmbed_Diffusers_L_Config(CLIPEmbed_Diffusers_Config_Base, Config_Base):
- variant: Literal[ClipVariantType.L] = Field(default=ClipVariantType.L)
-
-
-class CLIPVision_Diffusers_Config(Diffusers_Config_Base, Config_Base):
- """Model config for CLIPVision."""
-
- base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
- type: Literal[ModelType.CLIPVision] = Field(default=ModelType.CLIPVision)
- format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_dir(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- _validate_class_name(
- cls,
- common_config_paths(mod.path),
- {
- "CLIPVisionModelWithProjection",
- },
- )
-
- return cls(**fields)
-
-
-class T2IAdapter_Diffusers_Config_Base(Diffusers_Config_Base, ControlAdapter_Config_Base):
- """Model config for T2I."""
-
- type: Literal[ModelType.T2IAdapter] = Field(default=ModelType.T2IAdapter)
- format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_dir(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- _validate_class_name(
- cls,
- common_config_paths(mod.path),
- {
- "T2IAdapter",
- },
- )
-
- cls._validate_base(mod)
-
- return cls(**fields)
-
- @classmethod
- def _validate_base(cls, mod: ModelOnDisk) -> None:
- """Raise `NotAMatch` if the model base does not match this config class."""
- expected_base = cls.model_fields["base"].default
- recognized_base = cls._get_base_or_raise(mod)
- if expected_base is not recognized_base:
- raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
-
- @classmethod
- def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
- config = _get_config_or_raise(cls, common_config_paths(mod.path))
-
- adapter_type = config.get("adapter_type")
-
- match adapter_type:
- case "full_adapter_xl":
- return BaseModelType.StableDiffusionXL
- case "full_adapter" | "light_adapter":
- return BaseModelType.StableDiffusion1
- case _:
- raise NotAMatch(cls, f"unrecognized adapter_type '{adapter_type}'")
-
-
-class T2IAdapter_Diffusers_SD1_Config(T2IAdapter_Diffusers_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
-
-
-class T2IAdapter_Diffusers_SDXL_Config(T2IAdapter_Diffusers_Config_Base, Config_Base):
- base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
-
-
-class Spandrel_Checkpoint_Config(Config_Base):
- """Model config for Spandrel Image to Image models."""
-
- base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
- type: Literal[ModelType.SpandrelImageToImage] = Field(default=ModelType.SpandrelImageToImage)
- format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_file(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- cls._validate_spandrel_loads_model(mod)
-
- return cls(**fields)
-
- @classmethod
- def _validate_spandrel_loads_model(cls, mod: ModelOnDisk) -> None:
- try:
- # It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
- # explored to avoid this:
- # 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta
- # device. Unfortunately, some Spandrel models perform operations during initialization that are not
- # supported on meta tensors.
- # 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model.
- # This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
- # maintain it, and the risk of false positive detections is higher.
- SpandrelImageToImageModel.load_from_file(mod.path)
- except Exception as e:
- raise NotAMatch(cls, "model does not match SpandrelImageToImage heuristics") from e
-
-
-class SigLIP_Diffusers_Config(Diffusers_Config_Base, Config_Base):
- """Model config for SigLIP."""
-
- type: Literal[ModelType.SigLIP] = Field(default=ModelType.SigLIP)
- format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
- base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_dir(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- _validate_class_name(
- cls,
- common_config_paths(mod.path),
- {
- "SiglipModel",
- },
- )
-
- return cls(**fields)
-
-
-class FLUXRedux_Checkpoint_Config(Config_Base):
- """Model config for FLUX Tools Redux model."""
-
- type: Literal[ModelType.FluxRedux] = Field(default=ModelType.FluxRedux)
- format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
- base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_file(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- if not is_state_dict_likely_flux_redux(mod.load_state_dict()):
- raise NotAMatch(cls, "model does not match FLUX Tools Redux heuristics")
-
- return cls(**fields)
-
-
-class LlavaOnevision_Diffusers_Config(Diffusers_Config_Base, Config_Base):
- """Model config for Llava Onevision models."""
-
- type: Literal[ModelType.LlavaOnevision] = Field(default=ModelType.LlavaOnevision)
- base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
- variant: Literal[ModelVariantType.Normal] = Field(default=ModelVariantType.Normal)
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- _validate_is_dir(cls, mod)
-
- _validate_override_fields(cls, fields)
-
- _validate_class_name(
- cls,
- common_config_paths(mod.path),
- {
- "LlavaOnevisionForConditionalGeneration",
- },
- )
-
- return cls(**fields)
-
-
-class ExternalAPI_Config_Base(ABC, BaseModel):
- """Model config for API-based models."""
-
- format: Literal[ModelFormat.Api] = Field(default=ModelFormat.Api)
-
- @classmethod
- def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
- raise NotAMatch(cls, "External API models cannot be built from disk")
-
-
-class ExternalAPI_ChatGPT4o_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base):
- base: Literal[BaseModelType.ChatGPT4o] = Field(default=BaseModelType.ChatGPT4o)
-
-
-class ExternalAPI_Gemini2_5_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base):
- base: Literal[BaseModelType.Gemini2_5] = Field(default=BaseModelType.Gemini2_5)
-
-
-class ExternalAPI_Imagen3_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base):
- base: Literal[BaseModelType.Imagen3] = Field(default=BaseModelType.Imagen3)
-
-
-class ExternalAPI_Imagen4_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base):
- base: Literal[BaseModelType.Imagen4] = Field(default=BaseModelType.Imagen4)
-
-
-class ExternalAPI_FluxKontext_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base):
- base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext)
-
-
-class VideoConfigBase(ABC, BaseModel):
- type: Literal[ModelType.Video] = Field(default=ModelType.Video)
- trigger_phrases: set[str] | None = Field(description="Set of trigger phrases for this model", default=None)
- default_settings: MainModelDefaultSettings | None = Field(
- description="Default settings for this model", default=None
- )
-
-
-class ExternalAPI_Veo3_Config(ExternalAPI_Config_Base, VideoConfigBase, Config_Base):
- base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext)
-
-
-class ExternalAPI_Runway_Config(ExternalAPI_Config_Base, VideoConfigBase, Config_Base):
- base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext)
-
-
-# The types are listed explicitly because IDEs/LSPs can't identify the correct types
-# when AnyModelConfig is constructed dynamically using ModelConfigBase.all_config_classes
-AnyModelConfig = Annotated[
- Union[
- # Main (Pipeline) - diffusers format
- Annotated[Main_Diffusers_SD1_Config, Main_Diffusers_SD1_Config.get_tag()],
- Annotated[Main_Diffusers_SD2_Config, Main_Diffusers_SD2_Config.get_tag()],
- Annotated[Main_Diffusers_SDXL_Config, Main_Diffusers_SDXL_Config.get_tag()],
- Annotated[Main_Diffusers_SDXLRefiner_Config, Main_Diffusers_SDXLRefiner_Config.get_tag()],
- Annotated[Main_Diffusers_SD3_Config, Main_Diffusers_SD3_Config.get_tag()],
- Annotated[Main_Diffusers_CogView4_Config, Main_Diffusers_CogView4_Config.get_tag()],
- # Main (Pipeline) - checkpoint format
- Annotated[Main_Checkpoint_SD1_Config, Main_Checkpoint_SD1_Config.get_tag()],
- Annotated[Main_Checkpoint_SD2_Config, Main_Checkpoint_SD2_Config.get_tag()],
- Annotated[Main_Checkpoint_SDXL_Config, Main_Checkpoint_SDXL_Config.get_tag()],
- Annotated[Main_Checkpoint_SDXLRefiner_Config, Main_Checkpoint_SDXLRefiner_Config.get_tag()],
- Annotated[Main_Checkpoint_FLUX_Config, Main_Checkpoint_FLUX_Config.get_tag()],
- # Main (Pipeline) - quantized formats
- Annotated[Main_BnBNF4_FLUX_Config, Main_BnBNF4_FLUX_Config.get_tag()],
- Annotated[Main_GGUF_FLUX_Config, Main_GGUF_FLUX_Config.get_tag()],
- # VAE - checkpoint format
- Annotated[VAE_Checkpoint_SD1_Config, VAE_Checkpoint_SD1_Config.get_tag()],
- Annotated[VAE_Checkpoint_SD2_Config, VAE_Checkpoint_SD2_Config.get_tag()],
- Annotated[VAE_Checkpoint_SDXL_Config, VAE_Checkpoint_SDXL_Config.get_tag()],
- Annotated[VAE_Checkpoint_FLUX_Config, VAE_Checkpoint_FLUX_Config.get_tag()],
- # VAE - diffusers format
- Annotated[VAE_Diffusers_SD1_Config, VAE_Diffusers_SD1_Config.get_tag()],
- Annotated[VAE_Diffusers_SDXL_Config, VAE_Diffusers_SDXL_Config.get_tag()],
- # ControlNet - checkpoint format
- Annotated[ControlNet_Checkpoint_SD1_Config, ControlNet_Checkpoint_SD1_Config.get_tag()],
- Annotated[ControlNet_Checkpoint_SD2_Config, ControlNet_Checkpoint_SD2_Config.get_tag()],
- Annotated[ControlNet_Checkpoint_SDXL_Config, ControlNet_Checkpoint_SDXL_Config.get_tag()],
- Annotated[ControlNet_Checkpoint_FLUX_Config, ControlNet_Checkpoint_FLUX_Config.get_tag()],
- # ControlNet - diffusers format
- Annotated[ControlNet_Diffusers_SD1_Config, ControlNet_Diffusers_SD1_Config.get_tag()],
- Annotated[ControlNet_Diffusers_SD2_Config, ControlNet_Diffusers_SD2_Config.get_tag()],
- Annotated[ControlNet_Diffusers_SDXL_Config, ControlNet_Diffusers_SDXL_Config.get_tag()],
- Annotated[ControlNet_Diffusers_FLUX_Config, ControlNet_Diffusers_FLUX_Config.get_tag()],
- # LoRA - LyCORIS format
- Annotated[LoRA_LyCORIS_SD1_Config, LoRA_LyCORIS_SD1_Config.get_tag()],
- Annotated[LoRA_LyCORIS_SD2_Config, LoRA_LyCORIS_SD2_Config.get_tag()],
- Annotated[LoRA_LyCORIS_SDXL_Config, LoRA_LyCORIS_SDXL_Config.get_tag()],
- Annotated[LoRA_LyCORIS_FLUX_Config, LoRA_LyCORIS_FLUX_Config.get_tag()],
- # LoRA - OMI format
- Annotated[LoRA_OMI_SDXL_Config, LoRA_OMI_SDXL_Config.get_tag()],
- Annotated[LoRA_OMI_FLUX_Config, LoRA_OMI_FLUX_Config.get_tag()],
- # LoRA - diffusers format
- Annotated[LoRA_Diffusers_SD1_Config, LoRA_Diffusers_SD1_Config.get_tag()],
- Annotated[LoRA_Diffusers_SD2_Config, LoRA_Diffusers_SD2_Config.get_tag()],
- Annotated[LoRA_Diffusers_SDXL_Config, LoRA_Diffusers_SDXL_Config.get_tag()],
- Annotated[LoRA_Diffusers_FLUX_Config, LoRA_Diffusers_FLUX_Config.get_tag()],
- # ControlLoRA - diffusers format
- Annotated[ControlLoRA_LyCORIS_FLUX_Config, ControlLoRA_LyCORIS_FLUX_Config.get_tag()],
- # T5 Encoder - all formats
- Annotated[T5Encoder_T5Encoder_Config, T5Encoder_T5Encoder_Config.get_tag()],
- Annotated[T5Encoder_BnBLLMint8_Config, T5Encoder_BnBLLMint8_Config.get_tag()],
- # TI - file format
- Annotated[TI_File_SD1_Config, TI_File_SD1_Config.get_tag()],
- Annotated[TI_File_SD2_Config, TI_File_SD2_Config.get_tag()],
- Annotated[TI_File_SDXL_Config, TI_File_SDXL_Config.get_tag()],
- # TI - folder format
- Annotated[TI_Folder_SD1_Config, TI_Folder_SD1_Config.get_tag()],
- Annotated[TI_Folder_SD2_Config, TI_Folder_SD2_Config.get_tag()],
- Annotated[TI_Folder_SDXL_Config, TI_Folder_SDXL_Config.get_tag()],
- # IP Adapter - InvokeAI format
- Annotated[IPAdapter_InvokeAI_SD1_Config, IPAdapter_InvokeAI_SD1_Config.get_tag()],
- Annotated[IPAdapter_InvokeAI_SD2_Config, IPAdapter_InvokeAI_SD2_Config.get_tag()],
- Annotated[IPAdapter_InvokeAI_SDXL_Config, IPAdapter_InvokeAI_SDXL_Config.get_tag()],
- # IP Adapter - checkpoint format
- Annotated[IPAdapter_Checkpoint_SD1_Config, IPAdapter_Checkpoint_SD1_Config.get_tag()],
- Annotated[IPAdapter_Checkpoint_SD2_Config, IPAdapter_Checkpoint_SD2_Config.get_tag()],
- Annotated[IPAdapter_Checkpoint_SDXL_Config, IPAdapter_Checkpoint_SDXL_Config.get_tag()],
- Annotated[IPAdapter_Checkpoint_FLUX_Config, IPAdapter_Checkpoint_FLUX_Config.get_tag()],
- # T2I Adapter - diffusers format
- Annotated[T2IAdapter_Diffusers_SD1_Config, T2IAdapter_Diffusers_SD1_Config.get_tag()],
- Annotated[T2IAdapter_Diffusers_SDXL_Config, T2IAdapter_Diffusers_SDXL_Config.get_tag()],
- # Misc models
- Annotated[Spandrel_Checkpoint_Config, Spandrel_Checkpoint_Config.get_tag()],
- Annotated[CLIPEmbed_Diffusers_G_Config, CLIPEmbed_Diffusers_G_Config.get_tag()],
- Annotated[CLIPEmbed_Diffusers_L_Config, CLIPEmbed_Diffusers_L_Config.get_tag()],
- Annotated[CLIPVision_Diffusers_Config, CLIPVision_Diffusers_Config.get_tag()],
- Annotated[SigLIP_Diffusers_Config, SigLIP_Diffusers_Config.get_tag()],
- Annotated[FLUXRedux_Checkpoint_Config, FLUXRedux_Checkpoint_Config.get_tag()],
- Annotated[LlavaOnevision_Diffusers_Config, LlavaOnevision_Diffusers_Config.get_tag()],
- # API models
- Annotated[ExternalAPI_ChatGPT4o_Config, ExternalAPI_ChatGPT4o_Config.get_tag()],
- Annotated[ExternalAPI_Gemini2_5_Config, ExternalAPI_Gemini2_5_Config.get_tag()],
- Annotated[ExternalAPI_Imagen3_Config, ExternalAPI_Imagen3_Config.get_tag()],
- Annotated[ExternalAPI_Imagen4_Config, ExternalAPI_Imagen4_Config.get_tag()],
- Annotated[ExternalAPI_FluxKontext_Config, ExternalAPI_FluxKontext_Config.get_tag()],
- Annotated[ExternalAPI_Veo3_Config, ExternalAPI_Veo3_Config.get_tag()],
- Annotated[ExternalAPI_Runway_Config, ExternalAPI_Runway_Config.get_tag()],
- # Unknown model (fallback)
- Annotated[Unknown_Config, Unknown_Config.get_tag()],
- ],
- Discriminator(Config_Base.get_model_discriminator_value),
-]
-
-AnyModelConfigValidator = TypeAdapter[AnyModelConfig](AnyModelConfig)
-
-
-class ModelConfigFactory:
- @staticmethod
- def make_config(model_data: Dict[str, Any], timestamp: Optional[float] = None) -> AnyModelConfig:
- """Return the appropriate config object from raw dict values."""
- model = AnyModelConfigValidator.validate_python(model_data)
- if isinstance(model, Checkpoint_Config_Base) and timestamp:
- model.converted_at = timestamp
- validate_hash(model.hash)
- return model
-
- @staticmethod
- def build_common_fields(
- mod: ModelOnDisk,
- overrides: dict[str, Any] | None = None,
- ) -> dict[str, Any]:
- """Builds the common fields for all model configs.
-
- Args:
- mod: The model on disk to extract fields from.
- overrides: A optional dictionary of fields to override. These fields will take precedence over the values
- extracted from the model on disk.
-
- - Casts string fields to their Enum types.
- - Does not validate the fields against the model config schema.
- """
-
- _overrides: dict[str, Any] = overrides or {}
- fields: dict[str, Any] = {}
-
- if "type" in _overrides:
- fields["type"] = ModelType(_overrides["type"])
-
- if "format" in _overrides:
- fields["format"] = ModelFormat(_overrides["format"])
-
- if "base" in _overrides:
- fields["base"] = BaseModelType(_overrides["base"])
-
- if "source_type" in _overrides:
- fields["source_type"] = ModelSourceType(_overrides["source_type"])
-
- if "variant" in _overrides:
- fields["variant"] = variant_type_adapter.validate_strings(_overrides["variant"])
-
- fields["path"] = mod.path.as_posix()
- fields["source"] = _overrides.get("source") or fields["path"]
- fields["source_type"] = _overrides.get("source_type") or ModelSourceType.Path
- fields["name"] = _overrides.get("name") or mod.name
- fields["hash"] = _overrides.get("hash") or mod.hash()
- fields["key"] = _overrides.get("key") or uuid_string()
- fields["description"] = _overrides.get("description")
- fields["file_size"] = _overrides.get("file_size") or mod.size()
-
- return fields
-
- @staticmethod
- def from_model_on_disk(
- mod: str | Path | ModelOnDisk,
- overrides: dict[str, Any] | None = None,
- hash_algo: HASHING_ALGORITHMS = "blake3_single",
- ) -> AnyModelConfig:
- """
- Returns the best matching ModelConfig instance from a model's file/folder path.
- Raises InvalidModelConfigException if no valid configuration is found.
- Created to deprecate ModelProbe.probe
- """
- if isinstance(mod, Path | str):
- mod = ModelOnDisk(Path(mod), hash_algo)
-
- # We will always need these fields to build any model config.
- fields = ModelConfigFactory.build_common_fields(mod, overrides)
-
- # Store results as a mapping of config class to either an instance of that class or an exception
- # that was raised when trying to build it.
- results: dict[str, AnyModelConfig | Exception] = {}
-
- # Try to build an instance of each model config class that uses the classify API.
- # Each class will either return an instance of itself or raise NotAMatch if it doesn't match.
- # Other exceptions may be raised if something unexpected happens during matching or building.
- for config_class in Config_Base.CONFIG_CLASSES:
- class_name = config_class.__name__
- try:
- instance = config_class.from_model_on_disk(mod, fields)
- results[class_name] = instance
- except NotAMatch as e:
- results[class_name] = e
- logger.debug(f"No match for {config_class.__name__} on model {mod.name}")
- except ValidationError as e:
- # This means the model matched, but we couldn't create the pydantic model instance for the config.
- # Maybe invalid overrides were provided?
- results[class_name] = e
- logger.warning(f"Schema validation error for {config_class.__name__} on model {mod.name}: {e}")
- except Exception as e:
- results[class_name] = e
- logger.warning(f"Unexpected exception while matching {mod.name} to {config_class.__name__}: {e}")
-
- matches = [r for r in results.values() if isinstance(r, Config_Base)]
-
- if not matches and app_config.allow_unknown_models:
- logger.warning(f"Unable to identify model {mod.name}, falling back to Unknown_Config")
- return Unknown_Config(**fields)
-
- if len(matches) > 1:
- # We have multiple matches, in which case at most 1 is correct. We need to pick one.
- #
- # Known cases:
- # - SD main models can look like a LoRA when they have merged in LoRA weights. Prefer the main model.
- # - SD main models in diffusers format can look like a CLIP Embed; they have a text_encoder folder with
- # a config.json file. Prefer the main model.
-
- # Sort the matching according to known special cases.
- def sort_key(m: AnyModelConfig) -> int:
- match m.type:
- case ModelType.Main:
- return 0
- case ModelType.LoRA:
- return 1
- case ModelType.CLIPEmbed:
- return 2
- case _:
- return 3
-
- matches.sort(key=sort_key)
- logger.warning(
- f"Multiple model config classes matched for model {mod.name}: {[type(m).__name__ for m in matches]}. Using {type(matches[0]).__name__}."
- )
-
- instance = matches[0]
- logger.info(f"Model {mod.name} classified as {type(instance).__name__}")
- return instance
diff --git a/invokeai/backend/model_manager/configs/base.py b/invokeai/backend/model_manager/configs/base.py
index 9e997e4bcd..e67efd2009 100644
--- a/invokeai/backend/model_manager/configs/base.py
+++ b/invokeai/backend/model_manager/configs/base.py
@@ -191,8 +191,8 @@ class Config_Base(ABC, BaseModel):
else:
raise TypeError("Model config discriminator value must be computed from a dict or ModelConfigBase instance")
- @abstractmethod
@classmethod
+ @abstractmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
"""Given the model on disk and any override fields, attempt to construct an instance of this config class.
diff --git a/invokeai/backend/model_manager/configs/controlnet.py b/invokeai/backend/model_manager/configs/controlnet.py
index 7db782a862..630e81fd24 100644
--- a/invokeai/backend/model_manager/configs/controlnet.py
+++ b/invokeai/backend/model_manager/configs/controlnet.py
@@ -3,14 +3,13 @@ from typing import (
Self,
)
-from pydantic import Field
+from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Any
from invokeai.backend.flux.controlnet.state_dict_utils import (
is_state_dict_instantx_controlnet,
is_state_dict_xlabs_controlnet,
)
-from invokeai.backend.model_manager.config import ControlAdapterDefaultSettings
from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Config_Base, Diffusers_Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
@@ -29,6 +28,42 @@ from invokeai.backend.model_manager.taxonomy import (
ModelType,
)
+MODEL_NAME_TO_PREPROCESSOR = {
+ "canny": "canny_image_processor",
+ "mlsd": "mlsd_image_processor",
+ "depth": "depth_anything_image_processor",
+ "bae": "normalbae_image_processor",
+ "normal": "normalbae_image_processor",
+ "sketch": "pidi_image_processor",
+ "scribble": "lineart_image_processor",
+ "lineart anime": "lineart_anime_image_processor",
+ "lineart_anime": "lineart_anime_image_processor",
+ "lineart": "lineart_image_processor",
+ "soft": "hed_image_processor",
+ "softedge": "hed_image_processor",
+ "hed": "hed_image_processor",
+ "shuffle": "content_shuffle_image_processor",
+ "pose": "dw_openpose_image_processor",
+ "mediapipe": "mediapipe_face_processor",
+ "pidi": "pidi_image_processor",
+ "zoe": "zoe_depth_image_processor",
+ "color": "color_map_image_processor",
+}
+
+
+class ControlAdapterDefaultSettings(BaseModel):
+ # This could be narrowed to controlnet processor nodes, but they change. Leaving this a string is safer.
+ preprocessor: str | None
+ model_config = ConfigDict(extra="forbid")
+
+ @classmethod
+ def from_model_name(cls, model_name: str) -> Self:
+ for k, v in MODEL_NAME_TO_PREPROCESSOR.items():
+ model_name_lower = model_name.lower()
+ if k in model_name_lower:
+ return cls(preprocessor=v)
+ return cls(preprocessor=None)
+
class ControlNet_Diffusers_Config_Base(Diffusers_Config_Base):
"""Model config for ControlNet models (diffusers version)."""
diff --git a/invokeai/backend/model_manager/configs/factory.py b/invokeai/backend/model_manager/configs/factory.py
index 27b6f252f1..6ab16cd5f6 100644
--- a/invokeai/backend/model_manager/configs/factory.py
+++ b/invokeai/backend/model_manager/configs/factory.py
@@ -14,6 +14,7 @@ from invokeai.backend.model_manager.configs.base import Config_Base
from invokeai.backend.model_manager.configs.clip_embed import CLIPEmbed_Diffusers_G_Config, CLIPEmbed_Diffusers_L_Config
from invokeai.backend.model_manager.configs.clip_vision import CLIPVision_Diffusers_Config
from invokeai.backend.model_manager.configs.controlnet import (
+ ControlAdapterDefaultSettings,
ControlNet_Checkpoint_FLUX_Config,
ControlNet_Checkpoint_SD1_Config,
ControlNet_Checkpoint_SD2_Config,
@@ -47,6 +48,7 @@ from invokeai.backend.model_manager.configs.lora import (
LoRA_LyCORIS_SDXL_Config,
LoRA_OMI_FLUX_Config,
LoRA_OMI_SDXL_Config,
+ LoraModelDefaultSettings,
)
from invokeai.backend.model_manager.configs.main import (
Main_BnBNF4_FLUX_Config,
@@ -67,6 +69,7 @@ from invokeai.backend.model_manager.configs.main import (
Main_ExternalAPI_Imagen3_Config,
Main_ExternalAPI_Imagen4_Config,
Main_GGUF_FLUX_Config,
+ MainModelDefaultSettings,
Video_ExternalAPI_Runway_Config,
Video_ExternalAPI_Veo3_Config,
)
@@ -332,9 +335,52 @@ class ModelConfigFactory:
matches.sort(key=sort_key)
logger.warning(
- f"Multiple model config classes matched for model {mod.name}: {[type(m).__name__ for m in matches]}. Using {type(matches[0]).__name__}."
+ f"Multiple model config classes matched for model {mod.name}: {[type(m).__name__ for m in matches]}."
)
instance = matches[0]
logger.info(f"Model {mod.name} classified as {type(instance).__name__}")
+
+ # Now do any post-processing needed for specific model types/bases/etc.
+ match instance.type:
+ case ModelType.Main:
+ match instance.base:
+ case BaseModelType.StableDiffusion1:
+ instance.default_settings = MainModelDefaultSettings(width=512, height=512)
+ case BaseModelType.StableDiffusion2:
+ instance.default_settings = MainModelDefaultSettings(width=768, height=768)
+ case BaseModelType.StableDiffusionXL:
+ instance.default_settings = MainModelDefaultSettings(width=1024, height=1024)
+ case _:
+ pass
+ case ModelType.ControlNet | ModelType.T2IAdapter | ModelType.ControlLoRa:
+ instance.default_settings = ControlAdapterDefaultSettings.from_model_name(instance.name)
+ case ModelType.LoRA:
+ instance.default_settings = LoraModelDefaultSettings()
+ case _:
+ pass
+
return instance
+
+
+MODEL_NAME_TO_PREPROCESSOR = {
+ "canny": "canny_image_processor",
+ "mlsd": "mlsd_image_processor",
+ "depth": "depth_anything_image_processor",
+ "bae": "normalbae_image_processor",
+ "normal": "normalbae_image_processor",
+ "sketch": "pidi_image_processor",
+ "scribble": "lineart_image_processor",
+ "lineart anime": "lineart_anime_image_processor",
+ "lineart_anime": "lineart_anime_image_processor",
+ "lineart": "lineart_image_processor",
+ "soft": "hed_image_processor",
+ "softedge": "hed_image_processor",
+ "hed": "hed_image_processor",
+ "shuffle": "content_shuffle_image_processor",
+ "pose": "dw_openpose_image_processor",
+ "mediapipe": "mediapipe_face_processor",
+ "pidi": "pidi_image_processor",
+ "zoe": "zoe_depth_image_processor",
+ "color": "color_map_image_processor",
+}
diff --git a/invokeai/backend/model_manager/configs/lora.py b/invokeai/backend/model_manager/configs/lora.py
index 512137b4c3..24e10c035a 100644
--- a/invokeai/backend/model_manager/configs/lora.py
+++ b/invokeai/backend/model_manager/configs/lora.py
@@ -9,10 +9,10 @@ from typing import (
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Any
-from invokeai.backend.model_manager.config import ControlAdapterDefaultSettings
from invokeai.backend.model_manager.configs.base import (
Config_Base,
)
+from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
raise_for_override_fields,
diff --git a/invokeai/backend/model_manager/configs/main.py b/invokeai/backend/model_manager/configs/main.py
index ef1ab1fe77..26f6b5b60e 100644
--- a/invokeai/backend/model_manager/configs/main.py
+++ b/invokeai/backend/model_manager/configs/main.py
@@ -685,8 +685,8 @@ class Video_Config_Base(ABC, BaseModel):
class Video_ExternalAPI_Veo3_Config(ExternalAPI_Config_Base, Video_Config_Base, Config_Base):
- base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext)
+ base: Literal[BaseModelType.Veo3] = Field(default=BaseModelType.Veo3)
class Video_ExternalAPI_Runway_Config(ExternalAPI_Config_Base, Video_Config_Base, Config_Base):
- base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext)
+ base: Literal[BaseModelType.Runway] = Field(default=BaseModelType.Runway)
diff --git a/invokeai/backend/model_manager/configs/t2i_adapter.py b/invokeai/backend/model_manager/configs/t2i_adapter.py
index 865c4dc763..a1da40e9b4 100644
--- a/invokeai/backend/model_manager/configs/t2i_adapter.py
+++ b/invokeai/backend/model_manager/configs/t2i_adapter.py
@@ -6,8 +6,8 @@ from typing import (
from pydantic import Field
from typing_extensions import Any
-from invokeai.backend.model_manager.config import ControlAdapterDefaultSettings
from invokeai.backend.model_manager.configs.base import Config_Base, Diffusers_Config_Base
+from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
common_config_paths,
diff --git a/invokeai/backend/model_manager/legacy_probe.py b/invokeai/backend/model_manager/legacy_probe.py
deleted file mode 100644
index 85a39fd25e..0000000000
--- a/invokeai/backend/model_manager/legacy_probe.py
+++ /dev/null
@@ -1,1034 +0,0 @@
-import json
-from pathlib import Path
-from typing import Any, Callable, Dict, Literal, Optional, Union
-
-import picklescan.scanner as pscan
-import safetensors.torch
-import torch
-
-import invokeai.backend.util.logging as logger
-from invokeai.app.services.config.config_default import get_config
-from invokeai.app.util.misc import uuid_string
-from invokeai.backend.flux.controlnet.state_dict_utils import (
- is_state_dict_instantx_controlnet,
- is_state_dict_xlabs_controlnet,
-)
-from invokeai.backend.flux.flux_state_dict_utils import get_flux_in_channels_from_state_dict
-from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter
-from invokeai.backend.flux.redux.flux_redux_state_dict_utils import is_state_dict_likely_flux_redux
-from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
-from invokeai.backend.model_manager.config import (
- AnyModelConfig,
- ControlAdapterDefaultSettings,
- InvalidModelConfigException,
- LoraModelDefaultSettings,
- MainModelDefaultSettings,
- ModelConfigFactory,
- SubmodelDefinition,
-)
-from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import ConfigLoader
-from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
-from invokeai.backend.model_manager.taxonomy import (
- AnyVariant,
- BaseModelType,
- FluxVariantType,
- ModelFormat,
- ModelRepoVariant,
- ModelSourceType,
- ModelType,
- ModelVariantType,
- SchedulerPredictionType,
- SubModelType,
-)
-from invokeai.backend.model_manager.util.model_util import (
- get_clip_variant_type,
- lora_token_vector_length,
- read_checkpoint_meta,
-)
-from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
-from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
- is_state_dict_likely_in_flux_diffusers_format,
-)
-from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
- is_state_dict_likely_in_flux_kohya_format,
-)
-from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
- is_state_dict_likely_in_flux_onetrainer_format,
-)
-from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
-from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
-from invokeai.backend.util.silence_warnings import SilenceWarnings
-
-CkptType = Dict[str | int, Any]
-
-
-LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[SchedulerPredictionType, str]]]] = {
- BaseModelType.StableDiffusion1: {
- ModelVariantType.Normal: {
- SchedulerPredictionType.Epsilon: "v1-inference.yaml",
- SchedulerPredictionType.VPrediction: "v1-inference-v.yaml",
- },
- ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
- },
- BaseModelType.StableDiffusion2: {
- ModelVariantType.Normal: {
- SchedulerPredictionType.Epsilon: "v2-inference.yaml",
- SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
- },
- ModelVariantType.Inpaint: {
- SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
- SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
- },
- ModelVariantType.Depth: "v2-midas-inference.yaml",
- },
- BaseModelType.StableDiffusionXL: {
- ModelVariantType.Normal: "sd_xl_base.yaml",
- ModelVariantType.Inpaint: "sd_xl_inpaint.yaml",
- },
- BaseModelType.StableDiffusionXLRefiner: {
- ModelVariantType.Normal: "sd_xl_refiner.yaml",
- },
-}
-
-
-class ProbeBase(object):
- """Base class for probes."""
-
- def __init__(self, model_path: Path):
- self.model_path = model_path
-
- def get_base_type(self) -> BaseModelType:
- """Get model base type."""
- raise NotImplementedError
-
- def get_format(self) -> ModelFormat:
- """Get model file format."""
- raise NotImplementedError
-
- def get_variant_type(self) -> AnyVariant | None:
- """Get model variant type."""
- return None
-
- def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
- """Get model scheduler prediction type."""
- return None
-
- def get_image_encoder_model_id(self) -> Optional[str]:
- """Get image encoder (IP adapters only)."""
- return None
-
-
-class ModelProbe(object):
- PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
- "diffusers": {},
- "checkpoint": {},
- "onnx": {},
- }
-
- CLASS2TYPE = {
- "FluxPipeline": ModelType.Main,
- "StableDiffusionPipeline": ModelType.Main,
- "StableDiffusionInpaintPipeline": ModelType.Main,
- "StableDiffusionXLPipeline": ModelType.Main,
- "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
- "StableDiffusionXLInpaintPipeline": ModelType.Main,
- "StableDiffusion3Pipeline": ModelType.Main,
- "LatentConsistencyModelPipeline": ModelType.Main,
- "AutoencoderKL": ModelType.VAE,
- "AutoencoderTiny": ModelType.VAE,
- "ControlNetModel": ModelType.ControlNet,
- "CLIPVisionModelWithProjection": ModelType.CLIPVision,
- "T2IAdapter": ModelType.T2IAdapter,
- "CLIPModel": ModelType.CLIPEmbed,
- "CLIPTextModel": ModelType.CLIPEmbed,
- "T5EncoderModel": ModelType.T5Encoder,
- "FluxControlNetModel": ModelType.ControlNet,
- "SD3Transformer2DModel": ModelType.Main,
- "CLIPTextModelWithProjection": ModelType.CLIPEmbed,
- "SiglipModel": ModelType.SigLIP,
- "LlavaOnevisionForConditionalGeneration": ModelType.LlavaOnevision,
- "CogView4Pipeline": ModelType.Main,
- }
-
- TYPE2VARIANT: Dict[ModelType, Callable[[str], Optional[AnyVariant]]] = {ModelType.CLIPEmbed: get_clip_variant_type}
-
- @classmethod
- def register_probe(
- cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: type[ProbeBase]
- ) -> None:
- cls.PROBES[format][model_type] = probe_class
-
- @classmethod
- def probe(
- cls, model_path: Path, fields: Optional[Dict[str, Any]] = None, hash_algo: HASHING_ALGORITHMS = "blake3_single"
- ) -> AnyModelConfig:
- """
- Probe the model at model_path and return its configuration record.
-
- :param model_path: Path to the model file (checkpoint) or directory (diffusers).
- :param fields: An optional dictionary that can be used to override probed
- fields. Typically used for fields that don't probe well, such as prediction_type.
-
- Returns: The appropriate model configuration derived from ModelConfigBase.
- """
- if fields is None:
- fields = {}
-
- model_path = model_path.resolve()
-
- format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
- model_info = None
- model_type = ModelType(fields["type"]) if "type" in fields and fields["type"] else None
- if not model_type:
- if format_type is ModelFormat.Diffusers:
- model_type = cls.get_model_type_from_folder(model_path)
- else:
- model_type = cls.get_model_type_from_checkpoint(model_path)
- format_type = ModelFormat.ONNX if model_type == ModelType.ONNX else format_type
-
- probe_class = cls.PROBES[format_type].get(model_type)
- if not probe_class:
- raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
-
- probe = probe_class(model_path)
-
- fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
- fields["source"] = fields.get("source") or model_path.as_posix()
- fields["key"] = fields.get("key", uuid_string())
- fields["path"] = model_path.as_posix()
- fields["type"] = fields.get("type") or model_type
- fields["base"] = fields.get("base") or probe.get_base_type()
- variant_func = cls.TYPE2VARIANT.get(fields["type"], None)
- fields["variant"] = (
- fields.get("variant") or (variant_func and variant_func(model_path.as_posix())) or probe.get_variant_type()
- )
- fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type()
- fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
- fields["name"] = fields.get("name") or cls.get_model_name(model_path)
- fields["description"] = (
- fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
- )
- fields["format"] = ModelFormat(fields.get("format")) if "format" in fields else probe.get_format()
- fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
- fields["file_size"] = fields.get("file_size") or ModelOnDisk(model_path).size()
-
- fields["default_settings"] = fields.get("default_settings")
-
- if not fields["default_settings"]:
- if fields["type"] in {ModelType.ControlNet, ModelType.T2IAdapter, ModelType.ControlLoRa}:
- fields["default_settings"] = get_default_settings_control_adapters(fields["name"])
- if fields["type"] in {ModelType.LoRA}:
- fields["default_settings"] = get_default_settings_lora()
- elif fields["type"] is ModelType.Main:
- fields["default_settings"] = get_default_settings_main(fields["base"])
-
- if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
- fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
-
- # additional fields needed for main and controlnet models
- if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] and fields["format"] in [
- ModelFormat.Checkpoint,
- ModelFormat.BnbQuantizednf4b,
- ModelFormat.GGUFQuantized,
- ]:
- ckpt_config_path = cls._get_checkpoint_config_path(
- model_path,
- model_type=fields["type"],
- base_type=fields["base"],
- variant_type=fields["variant"],
- prediction_type=fields["prediction_type"],
- )
- fields["config_path"] = str(ckpt_config_path)
-
- # additional fields needed for main non-checkpoint models
- elif fields["type"] == ModelType.Main and fields["format"] in [
- ModelFormat.ONNX,
- ModelFormat.Olive,
- ModelFormat.Diffusers,
- ]:
- fields["upcast_attention"] = fields.get("upcast_attention") or (
- fields["base"] == BaseModelType.StableDiffusion2
- and fields["prediction_type"] == SchedulerPredictionType.VPrediction
- )
-
- get_submodels = getattr(probe, "get_submodels", None)
- if fields["base"] == BaseModelType.StableDiffusion3 and callable(get_submodels):
- fields["submodels"] = get_submodels()
-
- model_info = ModelConfigFactory.make_config(fields)
- return model_info
-
- @classmethod
- def get_model_name(cls, model_path: Path) -> str:
- if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
- return model_path.stem
- else:
- return model_path.name
-
- @classmethod
- def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[CkptType] = None) -> ModelType:
- if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth", ".gguf"):
- raise InvalidModelConfigException(f"{model_path}: unrecognized suffix")
-
- if model_path.name == "learned_embeds.bin":
- return ModelType.TextualInversion
-
- ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
- ckpt = ckpt.get("state_dict", ckpt)
-
- if isinstance(ckpt, dict) and is_state_dict_likely_flux_control(ckpt):
- return ModelType.ControlLoRa
-
- if isinstance(ckpt, dict) and is_state_dict_likely_flux_redux(ckpt):
- return ModelType.FluxRedux
-
- for key in [str(k) for k in ckpt.keys()]:
- if key.startswith(
- (
- "cond_stage_model.",
- "first_stage_model.",
- "model.diffusion_model.",
- # Some FLUX checkpoint files contain transformer keys prefixed with "model.diffusion_model".
- # This prefix is typically used to distinguish between multiple models bundled in a single file.
- "model.diffusion_model.double_blocks.",
- )
- ):
- # Keys starting with double_blocks are associated with Flux models
- return ModelType.Main
- # FLUX models in the official BFL format contain keys with the "double_blocks." prefix, but we must be
- # careful to avoid false positives on XLabs FLUX IP-Adapter models.
- elif key.startswith("double_blocks.") and "ip_adapter" not in key:
- return ModelType.Main
- elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
- return ModelType.VAE
- elif key.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")):
- return ModelType.LoRA
- # "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT
- # LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
- elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")):
- return ModelType.LoRA
- elif key.startswith(
- (
- "controlnet",
- "control_model",
- "input_blocks",
- # XLabs FLUX ControlNet models have keys starting with "controlnet_blocks."
- # For example: https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors
- # TODO(ryand): This is very fragile. XLabs FLUX ControlNet models also contain keys starting with
- # "double_blocks.", which we check for above. But, I'm afraid to modify this logic because it is so
- # delicate.
- "controlnet_blocks",
- )
- ):
- return ModelType.ControlNet
- elif key.startswith(
- (
- "image_proj.",
- "ip_adapter.",
- # XLabs FLUX IP-Adapter models have keys startinh with "ip_adapter_proj_model.".
- "ip_adapter_proj_model.",
- )
- ):
- return ModelType.IPAdapter
- elif key in {"emb_params", "string_to_param"}:
- return ModelType.TextualInversion
-
- # diffusers-ti
- if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
- return ModelType.TextualInversion
-
- raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")
-
- @classmethod
- def get_model_type_from_folder(cls, folder_path: Path) -> ModelType:
- """Get the model type of a hugging-face style folder."""
- class_name = None
- error_hint = None
- for suffix in ["bin", "safetensors"]:
- if (folder_path / f"learned_embeds.{suffix}").exists():
- return ModelType.TextualInversion
- if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
- return ModelType.LoRA
- if (folder_path / "unet/model.onnx").exists():
- return ModelType.ONNX
- if (folder_path / "image_encoder.txt").exists():
- return ModelType.IPAdapter
-
- config_path = None
- for p in [
- folder_path / "model_index.json", # pipeline
- folder_path / "config.json", # most diffusers
- folder_path / "text_encoder_2" / "config.json", # T5 text encoder
- folder_path / "text_encoder" / "config.json", # T5 CLIP
- ]:
- if p.exists():
- config_path = p
- break
-
- if config_path:
- with open(config_path, "r") as file:
- conf = json.load(file)
- if "_class_name" in conf:
- class_name = conf["_class_name"]
- elif "architectures" in conf:
- class_name = conf["architectures"][0]
- else:
- class_name = None
- else:
- error_hint = f"No model_index.json or config.json found in {folder_path}."
-
- if class_name and (type := cls.CLASS2TYPE.get(class_name)):
- return type
- else:
- error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"
-
- # give up
- raise InvalidModelConfigException(
- f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
- )
-
- @classmethod
- def _get_checkpoint_config_path(
- cls,
- model_path: Path,
- model_type: ModelType,
- base_type: BaseModelType,
- variant_type: ModelVariantType,
- prediction_type: SchedulerPredictionType,
- ) -> Path:
- # look for a YAML file adjacent to the model file first
- possible_conf = model_path.with_suffix(".yaml")
- if possible_conf.exists():
- return possible_conf.absolute()
-
- if model_type is ModelType.Main:
- if base_type == BaseModelType.Flux:
- # TODO: Decide between dev/schnell
- checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
- state_dict = checkpoint.get("state_dict") or checkpoint
-
- # HACK: For FLUX, config_file is used as a key into invokeai.backend.flux.util.params during model
- # loading. When FLUX support was first added, it was decided that this was the easiest way to support
- # the various FLUX formats rather than adding new model types/formats. Be careful when modifying this in
- # the future.
- if (
- "guidance_in.out_layer.weight" in state_dict
- or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict
- ):
- if variant_type == ModelVariantType.Normal:
- config_file = "flux-dev"
- elif variant_type == ModelVariantType.Inpaint:
- config_file = "flux-dev-fill"
- else:
- raise ValueError(f"Unexpected FLUX variant type: {variant_type}")
- else:
- config_file = "flux-schnell"
- else:
- config_file = LEGACY_CONFIGS[base_type][variant_type]
- if isinstance(config_file, dict): # need another tier for sd-2.x models
- config_file = config_file[prediction_type]
- config_file = f"stable-diffusion/{config_file}"
- elif model_type is ModelType.ControlNet:
- config_file = (
- "controlnet/cldm_v15.yaml"
- if base_type is BaseModelType.StableDiffusion1
- else "controlnet/cldm_v21.yaml"
- )
- elif model_type is ModelType.VAE:
- config_file = (
- # For flux, this is a key in invokeai.backend.flux.util.ae_params
- # Due to model type and format being the descriminator for model configs this
- # is used rather than attempting to support flux with separate model types and format
- # If changed in the future, please fix me
- "flux"
- if base_type is BaseModelType.Flux
- else "stable-diffusion/v1-inference.yaml"
- if base_type is BaseModelType.StableDiffusion1
- else "stable-diffusion/sd_xl_base.yaml"
- if base_type is BaseModelType.StableDiffusionXL
- else "stable-diffusion/v2-inference.yaml"
- )
- else:
- raise InvalidModelConfigException(
- f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}"
- )
- return Path(config_file)
-
- @classmethod
- def _scan_and_load_checkpoint(cls, model_path: Path) -> CkptType:
- with SilenceWarnings():
- if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
- cls._scan_model(model_path.name, model_path)
- model = torch.load(model_path, map_location="cpu")
- assert isinstance(model, dict)
- return model
- elif model_path.suffix.endswith(".gguf"):
- return gguf_sd_loader(model_path, compute_dtype=torch.float32)
- else:
- return safetensors.torch.load_file(model_path)
-
- @classmethod
- def _scan_model(cls, model_name: str, checkpoint: Path) -> None:
- """
- Apply picklescanner to the indicated checkpoint and issue a warning
- and option to exit if an infected file is identified.
- """
- # scan model
- scan_result = pscan.scan_file_path(checkpoint)
- if scan_result.infected_files != 0:
- if get_config().unsafe_disable_picklescan:
- logger.warning(
- f"The model {model_name} is potentially infected by malware, but picklescan is disabled. "
- "Proceeding with caution."
- )
- else:
- raise RuntimeError(f"The model {model_name} is potentially infected by malware. Aborting import.")
- if scan_result.scan_err:
- if get_config().unsafe_disable_picklescan:
- logger.warning(
- f"Error scanning the model at {model_name} for malware, but picklescan is disabled. "
- "Proceeding with caution."
- )
- else:
- raise RuntimeError(f"Error scanning the model at {model_name} for malware. Aborting import.")
-
-
-# Probing utilities
-MODEL_NAME_TO_PREPROCESSOR = {
- "canny": "canny_image_processor",
- "mlsd": "mlsd_image_processor",
- "depth": "depth_anything_image_processor",
- "bae": "normalbae_image_processor",
- "normal": "normalbae_image_processor",
- "sketch": "pidi_image_processor",
- "scribble": "lineart_image_processor",
- "lineart anime": "lineart_anime_image_processor",
- "lineart_anime": "lineart_anime_image_processor",
- "lineart": "lineart_image_processor",
- "soft": "hed_image_processor",
- "softedge": "hed_image_processor",
- "hed": "hed_image_processor",
- "shuffle": "content_shuffle_image_processor",
- "pose": "dw_openpose_image_processor",
- "mediapipe": "mediapipe_face_processor",
- "pidi": "pidi_image_processor",
- "zoe": "zoe_depth_image_processor",
- "color": "color_map_image_processor",
-}
-
-
-def get_default_settings_control_adapters(model_name: str) -> Optional[ControlAdapterDefaultSettings]:
- for k, v in MODEL_NAME_TO_PREPROCESSOR.items():
- model_name_lower = model_name.lower()
- if k in model_name_lower:
- return ControlAdapterDefaultSettings(preprocessor=v)
- return None
-
-
-def get_default_settings_lora() -> LoraModelDefaultSettings:
- return LoraModelDefaultSettings()
-
-
-def get_default_settings_main(model_base: BaseModelType) -> Optional[MainModelDefaultSettings]:
- if model_base is BaseModelType.StableDiffusion1 or model_base is BaseModelType.StableDiffusion2:
- return MainModelDefaultSettings(width=512, height=512)
- elif model_base is BaseModelType.StableDiffusionXL:
- return MainModelDefaultSettings(width=1024, height=1024)
- # We don't provide defaults for BaseModelType.StableDiffusionXLRefiner, as they are not standalone models.
- return None
-
-
-# ##################################################3
-# Checkpoint probing
-# ##################################################3
-
-
-class CheckpointProbeBase(ProbeBase):
- def __init__(self, model_path: Path):
- super().__init__(model_path)
- self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
-
- def get_format(self) -> ModelFormat:
- state_dict = self.checkpoint.get("state_dict") or self.checkpoint
- if (
- "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict
- or "model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict
- ):
- return ModelFormat.BnbQuantizednf4b
- elif any(isinstance(v, GGMLTensor) for v in state_dict.values()):
- return ModelFormat.GGUFQuantized
- return ModelFormat("checkpoint")
-
- def get_variant_type(self) -> AnyVariant:
- model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint)
- base_type = self.get_base_type()
- if model_type != ModelType.Main:
- return ModelVariantType.Normal
- state_dict = self.checkpoint.get("state_dict") or self.checkpoint
-
- if base_type == BaseModelType.Flux:
- in_channels = get_flux_in_channels_from_state_dict(state_dict)
-
- if in_channels is None:
- # If we cannot find the in_channels, we assume that this is a normal variant. Log a warning.
- logger.warning(
- f"{self.model_path} does not have img_in.weight or model.diffusion_model.img_in.weight key. Assuming normal variant."
- )
- return ModelVariantType.Normal
-
- is_flux_dev = (
- "guidance_in.out_layer.weight" in state_dict
- or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict
- )
-
- # FLUX Model variant types are distinguished by input channels:
- # - Unquantized Dev and Schnell have in_channels=64
- # - BNB-NF4 Dev and Schnell have in_channels=1
- # - FLUX Fill has in_channels=384
- # - Unsure of quantized FLUX Fill models
- # - Unsure of GGUF-quantized models
- if is_flux_dev and in_channels == 384:
- # This is a FLUX Fill model. FLUX Fill needs special handling throughout the application. The variant
- # type is used to determine whether to use the fill model or the base model.
- return FluxVariantType.DevFill
- elif is_flux_dev:
- # Fall back on "normal" variant type for all other FLUX models.
- return FluxVariantType.Dev
- else:
- return FluxVariantType.Schnell
-
- in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
- if in_channels == 9:
- return ModelVariantType.Inpaint
- elif in_channels == 5:
- return ModelVariantType.Depth
- elif in_channels == 4:
- return ModelVariantType.Normal
- else:
- raise InvalidModelConfigException(
- f"Cannot determine variant type (in_channels={in_channels}) at {self.model_path}"
- )
-
-
-class PipelineCheckpointProbe(CheckpointProbeBase):
- def get_base_type(self) -> BaseModelType:
- checkpoint = self.checkpoint
- state_dict = self.checkpoint.get("state_dict") or checkpoint
- if (
- "double_blocks.0.img_attn.norm.key_norm.scale" in state_dict
- or "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict
- ):
- return BaseModelType.Flux
- key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
- if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
- return BaseModelType.StableDiffusion1
- if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
- return BaseModelType.StableDiffusion2
- key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
- if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
- return BaseModelType.StableDiffusionXL
- elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
- return BaseModelType.StableDiffusionXLRefiner
- else:
- raise InvalidModelConfigException("Cannot determine base type")
-
- def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
- """Return model prediction type."""
- type = self.get_base_type()
- if type == BaseModelType.StableDiffusion2:
- checkpoint = self.checkpoint
- state_dict = self.checkpoint.get("state_dict") or checkpoint
- key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
- if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
- if "global_step" in checkpoint:
- if checkpoint["global_step"] == 220000:
- return SchedulerPredictionType.Epsilon
- elif checkpoint["global_step"] == 110000:
- return SchedulerPredictionType.VPrediction
- return SchedulerPredictionType.VPrediction # a guess for sd2 ckpts
-
- elif type == BaseModelType.StableDiffusion1:
- return SchedulerPredictionType.Epsilon # a reasonable guess for sd1 ckpts
- else:
- return SchedulerPredictionType.Epsilon
-
-
-class LoRACheckpointProbe(CheckpointProbeBase):
- """Class for LoRA checkpoints."""
-
- def get_format(self) -> ModelFormat:
- if is_state_dict_likely_in_flux_diffusers_format(self.checkpoint):
- # TODO(ryand): This is an unusual case. In other places throughout the codebase, we treat
- # ModelFormat.Diffusers as meaning that the model is in a directory. In this case, the model is a single
- # file, but the weight keys are in the diffusers format.
- return ModelFormat.Diffusers
- return ModelFormat.LyCORIS
-
- def get_base_type(self) -> BaseModelType:
- if (
- is_state_dict_likely_in_flux_kohya_format(self.checkpoint)
- or is_state_dict_likely_in_flux_onetrainer_format(self.checkpoint)
- or is_state_dict_likely_in_flux_diffusers_format(self.checkpoint)
- or is_state_dict_likely_flux_control(self.checkpoint)
- ):
- return BaseModelType.Flux
-
- # If we've gotten here, we assume that the model is a Stable Diffusion model.
- token_vector_length = lora_token_vector_length(self.checkpoint)
- if token_vector_length == 768:
- return BaseModelType.StableDiffusion1
- elif token_vector_length == 1024:
- return BaseModelType.StableDiffusion2
- elif token_vector_length == 1280:
- return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
- elif token_vector_length == 2048:
- return BaseModelType.StableDiffusionXL
- else:
- raise InvalidModelConfigException(f"Unknown LoRA type: {self.model_path}")
-
-
-class ControlNetCheckpointProbe(CheckpointProbeBase):
- """Class for probing controlnets."""
-
- def get_base_type(self) -> BaseModelType:
- checkpoint = self.checkpoint
- if is_state_dict_xlabs_controlnet(checkpoint) or is_state_dict_instantx_controlnet(checkpoint):
- # TODO(ryand): Should I distinguish between XLabs, InstantX and other ControlNet models by implementing
- # get_format()?
- return BaseModelType.Flux
-
- for key_name in (
- "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
- "controlnet_mid_block.bias",
- "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
- "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight",
- ):
- if key_name not in checkpoint:
- continue
- width = checkpoint[key_name].shape[-1]
- if width == 768:
- return BaseModelType.StableDiffusion1
- elif width == 1024:
- return BaseModelType.StableDiffusion2
- elif width == 2048:
- return BaseModelType.StableDiffusionXL
- elif width == 1280:
- return BaseModelType.StableDiffusionXL
- raise InvalidModelConfigException(f"{self.model_path}: Unable to determine base type")
-
-
-class IPAdapterCheckpointProbe(CheckpointProbeBase):
- """Class for probing IP Adapters"""
-
- def get_base_type(self) -> BaseModelType:
- checkpoint = self.checkpoint
-
- if is_state_dict_xlabs_ip_adapter(checkpoint):
- return BaseModelType.Flux
-
- for key in checkpoint.keys():
- if not key.startswith(("image_proj.", "ip_adapter.")):
- continue
- cross_attention_dim = checkpoint["ip_adapter.1.to_k_ip.weight"].shape[-1]
- if cross_attention_dim == 768:
- return BaseModelType.StableDiffusion1
- elif cross_attention_dim == 1024:
- return BaseModelType.StableDiffusion2
- elif cross_attention_dim == 2048:
- return BaseModelType.StableDiffusionXL
- else:
- raise InvalidModelConfigException(
- f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}."
- )
- raise InvalidModelConfigException(f"{self.model_path}: Unable to determine base type")
-
-
-class CLIPVisionCheckpointProbe(CheckpointProbeBase):
- def get_base_type(self) -> BaseModelType:
- raise NotImplementedError()
-
-
-class T2IAdapterCheckpointProbe(CheckpointProbeBase):
- def get_base_type(self) -> BaseModelType:
- raise NotImplementedError()
-
-
-class SpandrelImageToImageCheckpointProbe(CheckpointProbeBase):
- def get_base_type(self) -> BaseModelType:
- return BaseModelType.Any
-
-
-class SigLIPCheckpointProbe(CheckpointProbeBase):
- def get_base_type(self) -> BaseModelType:
- raise NotImplementedError()
-
-
-class FluxReduxCheckpointProbe(CheckpointProbeBase):
- def get_base_type(self) -> BaseModelType:
- return BaseModelType.Flux
-
-
-class LlavaOnevisionCheckpointProbe(CheckpointProbeBase):
- def get_base_type(self) -> BaseModelType:
- raise NotImplementedError()
-
-
-########################################################
-# classes for probing folders
-#######################################################
-class FolderProbeBase(ProbeBase):
- def get_variant_type(self) -> ModelVariantType:
- return ModelVariantType.Normal
-
- def get_format(self) -> ModelFormat:
- return ModelFormat("diffusers")
-
- def get_repo_variant(self) -> ModelRepoVariant:
- # get all files ending in .bin or .safetensors
- weight_files = list(self.model_path.glob("**/*.safetensors"))
- weight_files.extend(list(self.model_path.glob("**/*.bin")))
- for x in weight_files:
- if ".fp16" in x.suffixes:
- return ModelRepoVariant.FP16
- if "openvino_model" in x.name:
- return ModelRepoVariant.OpenVINO
- if "flax_model" in x.name:
- return ModelRepoVariant.Flax
- if x.suffix == ".onnx":
- return ModelRepoVariant.ONNX
- return ModelRepoVariant.Default
-
-
-class PipelineFolderProbe(FolderProbeBase):
- def get_base_type(self) -> BaseModelType:
- # Handle pipelines with a UNet (i.e SD 1.x, SD2, SDXL).
- config_path = self.model_path / "unet" / "config.json"
- if config_path.exists():
- with open(config_path) as file:
- unet_conf = json.load(file)
- if unet_conf["cross_attention_dim"] == 768:
- return BaseModelType.StableDiffusion1
- elif unet_conf["cross_attention_dim"] == 1024:
- return BaseModelType.StableDiffusion2
- elif unet_conf["cross_attention_dim"] == 1280:
- return BaseModelType.StableDiffusionXLRefiner
- elif unet_conf["cross_attention_dim"] == 2048:
- return BaseModelType.StableDiffusionXL
- else:
- raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
-
- # Handle pipelines with a transformer (i.e. SD3).
- config_path = self.model_path / "transformer" / "config.json"
- if config_path.exists():
- with open(config_path) as file:
- transformer_conf = json.load(file)
- if transformer_conf["_class_name"] == "SD3Transformer2DModel":
- return BaseModelType.StableDiffusion3
- elif transformer_conf["_class_name"] == "CogView4Transformer2DModel":
- return BaseModelType.CogView4
- else:
- raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
-
- raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
-
- def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
- with open(self.model_path / "scheduler" / "scheduler_config.json", "r") as file:
- scheduler_conf = json.load(file)
- if scheduler_conf.get("prediction_type", "epsilon") == "v_prediction":
- return SchedulerPredictionType.VPrediction
- elif scheduler_conf.get("prediction_type", "epsilon") == "epsilon":
- return SchedulerPredictionType.Epsilon
- else:
- raise InvalidModelConfigException("Unknown scheduler prediction type: {scheduler_conf['prediction_type']}")
-
- def get_submodels(self) -> Dict[SubModelType, SubmodelDefinition]:
- config = ConfigLoader.load_config(self.model_path, config_name="model_index.json")
- submodels: Dict[SubModelType, SubmodelDefinition] = {}
- for key, value in config.items():
- if key.startswith("_") or not (isinstance(value, list) and len(value) == 2):
- continue
- model_loader = str(value[1])
- if model_type := ModelProbe.CLASS2TYPE.get(model_loader):
- variant_func = ModelProbe.TYPE2VARIANT.get(model_type, None)
- submodels[SubModelType(key)] = SubmodelDefinition(
- path_or_prefix=(self.model_path / key).resolve().as_posix(),
- model_type=model_type,
- variant=variant_func and variant_func((self.model_path / key).as_posix()),
- )
-
- return submodels
-
- def get_variant_type(self) -> ModelVariantType:
- # This only works for pipelines! Any kind of
- # exception results in our returning the
- # "normal" variant type
- try:
- config_file = self.model_path / "unet" / "config.json"
- with open(config_file, "r") as file:
- conf = json.load(file)
-
- in_channels = conf["in_channels"]
- if in_channels == 9:
- return ModelVariantType.Inpaint
- elif in_channels == 5:
- return ModelVariantType.Depth
- elif in_channels == 4:
- return ModelVariantType.Normal
- except Exception:
- pass
- return ModelVariantType.Normal
-
-
-class ONNXFolderProbe(PipelineFolderProbe):
- def get_base_type(self) -> BaseModelType:
- # Due to the way the installer is set up, the configuration file for safetensors
- # will come along for the ride if both the onnx and safetensors forms
- # share the same directory. We take advantage of this here.
- if (self.model_path / "unet" / "config.json").exists():
- return super().get_base_type()
- else:
- logger.warning('Base type probing is not implemented for ONNX models. Assuming "sd-1"')
- return BaseModelType.StableDiffusion1
-
- def get_format(self) -> ModelFormat:
- return ModelFormat("onnx")
-
- def get_variant_type(self) -> ModelVariantType:
- return ModelVariantType.Normal
-
-
-class ControlNetFolderProbe(FolderProbeBase):
- def get_base_type(self) -> BaseModelType:
- config_file = self.model_path / "config.json"
- if not config_file.exists():
- raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
- with open(config_file, "r") as file:
- config = json.load(file)
-
- if config.get("_class_name", None) == "FluxControlNetModel":
- return BaseModelType.Flux
-
- # no obvious way to distinguish between sd2-base and sd2-768
- dimension = config["cross_attention_dim"]
- if dimension == 768:
- return BaseModelType.StableDiffusion1
- if dimension == 1024:
- return BaseModelType.StableDiffusion2
- if dimension == 2048:
- return BaseModelType.StableDiffusionXL
- raise InvalidModelConfigException(f"Unable to determine model base for {self.model_path}")
-
-
-class LoRAFolderProbe(FolderProbeBase):
- def get_base_type(self) -> BaseModelType:
- model_file = None
- for suffix in ["safetensors", "bin"]:
- base_file = self.model_path / f"pytorch_lora_weights.{suffix}"
- if base_file.exists():
- model_file = base_file
- break
- if not model_file:
- raise InvalidModelConfigException("Unknown LoRA format encountered")
- return LoRACheckpointProbe(model_file).get_base_type()
-
-
-class IPAdapterFolderProbe(FolderProbeBase):
- def get_format(self) -> ModelFormat:
- return ModelFormat.InvokeAI
-
- def get_base_type(self) -> BaseModelType:
- model_file = self.model_path / "ip_adapter.bin"
- if not model_file.exists():
- raise InvalidModelConfigException("Unknown IP-Adapter model format.")
-
- state_dict = torch.load(model_file, map_location="cpu")
- cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
- if cross_attention_dim == 768:
- return BaseModelType.StableDiffusion1
- elif cross_attention_dim == 1024:
- return BaseModelType.StableDiffusion2
- elif cross_attention_dim == 2048:
- return BaseModelType.StableDiffusionXL
- else:
- raise InvalidModelConfigException(
- f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}."
- )
-
- def get_image_encoder_model_id(self) -> Optional[str]:
- encoder_id_path = self.model_path / "image_encoder.txt"
- if not encoder_id_path.exists():
- return None
- with open(encoder_id_path, "r") as f:
- image_encoder_model = f.readline().strip()
- return image_encoder_model
-
-
-class CLIPVisionFolderProbe(FolderProbeBase):
- def get_base_type(self) -> BaseModelType:
- return BaseModelType.Any
-
-
-class SpandrelImageToImageFolderProbe(FolderProbeBase):
- def get_base_type(self) -> BaseModelType:
- raise NotImplementedError()
-
-
-class SigLIPFolderProbe(FolderProbeBase):
- def get_base_type(self) -> BaseModelType:
- return BaseModelType.Any
-
-
-class FluxReduxFolderProbe(FolderProbeBase):
- def get_base_type(self) -> BaseModelType:
- raise NotImplementedError()
-
-
-class LlaveOnevisionFolderProbe(FolderProbeBase):
- def get_base_type(self) -> BaseModelType:
- return BaseModelType.Any
-
-
-class T2IAdapterFolderProbe(FolderProbeBase):
- def get_base_type(self) -> BaseModelType:
- config_file = self.model_path / "config.json"
- if not config_file.exists():
- raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
- with open(config_file, "r") as file:
- config = json.load(file)
-
- adapter_type = config.get("adapter_type", None)
- if adapter_type == "full_adapter_xl":
- return BaseModelType.StableDiffusionXL
- elif adapter_type == "full_adapter" or "light_adapter":
- # I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter.
- return BaseModelType.StableDiffusion1
- else:
- raise InvalidModelConfigException(
- f"Unable to determine base model for '{self.model_path}' (adapter_type = {adapter_type})."
- )
-
-
-# Register probe classes
-ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
-ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
-ModelProbe.register_probe("diffusers", ModelType.ControlLoRa, LoRAFolderProbe)
-ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
-ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
-ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
-ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
-ModelProbe.register_probe("diffusers", ModelType.SigLIP, SigLIPFolderProbe)
-ModelProbe.register_probe("diffusers", ModelType.FluxRedux, FluxReduxFolderProbe)
-ModelProbe.register_probe("diffusers", ModelType.LlavaOnevision, LlaveOnevisionFolderProbe)
-
-ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
-ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe)
-ModelProbe.register_probe("checkpoint", ModelType.ControlLoRa, LoRACheckpointProbe)
-ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
-ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
-ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
-ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
-ModelProbe.register_probe("checkpoint", ModelType.SigLIP, SigLIPCheckpointProbe)
-ModelProbe.register_probe("checkpoint", ModelType.FluxRedux, FluxReduxCheckpointProbe)
-ModelProbe.register_probe("checkpoint", ModelType.LlavaOnevision, LlavaOnevisionCheckpointProbe)
-
-ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py
index 75191517c7..c4d71bfe31 100644
--- a/invokeai/backend/model_manager/load/load_base.py
+++ b/invokeai/backend/model_manager/load/load_base.py
@@ -12,7 +12,7 @@ from typing import Any, Dict, Generator, Optional, Tuple
import torch
from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.backend.model_manager.config import AnyModelConfig
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType
diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py
index 139a7d2940..3fb7a574f3 100644
--- a/invokeai/backend/model_manager/load/load_default.py
+++ b/invokeai/backend/model_manager/load/load_default.py
@@ -6,7 +6,8 @@ from pathlib import Path
from typing import Optional
from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.backend.model_manager.config import AnyModelConfig, Diffusers_Config_Base, InvalidModelConfigException
+from invokeai.backend.model_manager.configs.base import Diffusers_Config_Base
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key
@@ -50,7 +51,7 @@ class ModelLoader(ModelLoaderBase):
model_path = self._get_model_path(model_config)
if not model_path.exists():
- raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")
+ raise FileNotFoundError(f"Files for model '{model_config.name}' not found at {model_path}")
with skip_torch_weight_init():
cache_record = self._load_and_cache(model_config, submodel_type)
diff --git a/invokeai/backend/model_manager/load/model_loader_registry.py b/invokeai/backend/model_manager/load/model_loader_registry.py
index 9b242fe167..ca9ea56edb 100644
--- a/invokeai/backend/model_manager/load/model_loader_registry.py
+++ b/invokeai/backend/model_manager/load/model_loader_registry.py
@@ -18,10 +18,8 @@ Use like this:
from abc import ABC, abstractmethod
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
-from invokeai.backend.model_manager.config import (
- AnyModelConfig,
- Config_Base,
-)
+from invokeai.backend.model_manager.configs.base import Config_Base
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load import ModelLoaderBase
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType, SubModelType
diff --git a/invokeai/backend/model_manager/load/model_loaders/clip_vision.py b/invokeai/backend/model_manager/load/model_loaders/clip_vision.py
index 9065e51fbf..0150e24248 100644
--- a/invokeai/backend/model_manager/load/model_loaders/clip_vision.py
+++ b/invokeai/backend/model_manager/load/model_loaders/clip_vision.py
@@ -3,10 +3,8 @@ from typing import Optional
from transformers import CLIPVisionModelWithProjection
-from invokeai.backend.model_manager.config import (
- AnyModelConfig,
- Diffusers_Config_Base,
-)
+from invokeai.backend.model_manager.configs.base import Diffusers_Config_Base
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
diff --git a/invokeai/backend/model_manager/load/model_loaders/cogview4.py b/invokeai/backend/model_manager/load/model_loaders/cogview4.py
index a1a9269edb..782ff38450 100644
--- a/invokeai/backend/model_manager/load/model_loaders/cogview4.py
+++ b/invokeai/backend/model_manager/load/model_loaders/cogview4.py
@@ -3,11 +3,8 @@ from typing import Optional
import torch
-from invokeai.backend.model_manager.config import (
- AnyModelConfig,
- Checkpoint_Config_Base,
- Diffusers_Config_Base,
-)
+from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Diffusers_Config_Base
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.model_manager.taxonomy import (
diff --git a/invokeai/backend/model_manager/load/model_loaders/controlnet.py b/invokeai/backend/model_manager/load/model_loaders/controlnet.py
index 62a8ed4f65..8fd1796b8f 100644
--- a/invokeai/backend/model_manager/load/model_loaders/controlnet.py
+++ b/invokeai/backend/model_manager/load/model_loaders/controlnet.py
@@ -5,10 +5,8 @@ from typing import Optional
from diffusers import ControlNetModel
-from invokeai.backend.model_manager.config import (
- AnyModelConfig,
- ControlNet_Checkpoint_Config_Base,
-)
+from invokeai.backend.model_manager.configs.controlnet import ControlNet_Checkpoint_Config_Base
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.model_manager.taxonomy import (
diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py
index 07967c7c56..e44ddec382 100644
--- a/invokeai/backend/model_manager/load/model_loaders/flux.py
+++ b/invokeai/backend/model_manager/load/model_loaders/flux.py
@@ -34,21 +34,22 @@ from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
from invokeai.backend.flux.util import get_flux_ae_params, get_flux_transformers_params
-from invokeai.backend.model_manager.config import (
- AnyModelConfig,
- Checkpoint_Config_Base,
- CLIPEmbed_Diffusers_Config_Base,
+from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base
+from invokeai.backend.model_manager.configs.clip_embed import CLIPEmbed_Diffusers_Config_Base
+from invokeai.backend.model_manager.configs.controlnet import (
ControlNet_Checkpoint_Config_Base,
ControlNet_Diffusers_Config_Base,
- FLUXRedux_Checkpoint_Config,
- IPAdapter_Checkpoint_Config_Base,
+)
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
+from invokeai.backend.model_manager.configs.flux_redux import FLUXRedux_Checkpoint_Config
+from invokeai.backend.model_manager.configs.ip_adapter import IPAdapter_Checkpoint_Config_Base
+from invokeai.backend.model_manager.configs.main import (
Main_BnBNF4_FLUX_Config,
Main_Checkpoint_FLUX_Config,
Main_GGUF_FLUX_Config,
- T5Encoder_BnBLLMint8_Config,
- T5Encoder_T5Encoder_Config,
- VAE_Checkpoint_Config_Base,
)
+from invokeai.backend.model_manager.configs.t5_encoder import T5Encoder_BnBLLMint8_Config, T5Encoder_T5Encoder_Config
+from invokeai.backend.model_manager.configs.vae import VAE_Checkpoint_Config_Base
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import (
diff --git a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py
index 407a116b68..b888c69edf 100644
--- a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py
+++ b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py
@@ -8,7 +8,8 @@ from typing import Any, Optional
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin
-from invokeai.backend.model_manager.config import AnyModelConfig, Diffusers_Config_Base, InvalidModelConfigException
+from invokeai.backend.model_manager.configs.base import Diffusers_Config_Base
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import (
@@ -56,9 +57,7 @@ class GenericDiffusersLoader(ModelLoader):
module, class_name = config[submodel_type.value]
result = self._hf_definition_to_type(module=module, class_name=class_name)
except KeyError as e:
- raise InvalidModelConfigException(
- f'The "{submodel_type}" submodel is not available for this model.'
- ) from e
+ raise ValueError(f'The "{submodel_type}" submodel is not available for this model.') from e
else:
try:
config = self._load_diffusers_config(model_path, config_name="config.json")
@@ -67,9 +66,9 @@ class GenericDiffusersLoader(ModelLoader):
elif class_name := config.get("architectures"):
result = self._hf_definition_to_type(module="transformers", class_name=class_name[0])
else:
- raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json")
+ raise RuntimeError("Unable to decipher Load Class based on given config.json")
except KeyError as e:
- raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
+ raise ValueError("An expected config.json file is missing from this model.") from e
assert result is not None
return result
diff --git a/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py b/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py
index d103bc5dbc..d133a36498 100644
--- a/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py
+++ b/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py
@@ -7,7 +7,7 @@ from typing import Optional
import torch
from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
-from invokeai.backend.model_manager.config import AnyModelConfig
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.raw_model import RawModel
diff --git a/invokeai/backend/model_manager/load/model_loaders/llava_onevision.py b/invokeai/backend/model_manager/load/model_loaders/llava_onevision.py
index b508137f81..e459bbf2bb 100644
--- a/invokeai/backend/model_manager/load/model_loaders/llava_onevision.py
+++ b/invokeai/backend/model_manager/load/model_loaders/llava_onevision.py
@@ -3,9 +3,7 @@ from typing import Optional
from transformers import LlavaOnevisionForConditionalGeneration
-from invokeai.backend.model_manager.config import (
- AnyModelConfig,
-)
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py
index 98f54224fa..29fb815d54 100644
--- a/invokeai/backend/model_manager/load/model_loaders/lora.py
+++ b/invokeai/backend/model_manager/load/model_loaders/lora.py
@@ -9,7 +9,7 @@ import torch
from safetensors.torch import load_file
from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.backend.model_manager.config import AnyModelConfig
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
diff --git a/invokeai/backend/model_manager/load/model_loaders/onnx.py b/invokeai/backend/model_manager/load/model_loaders/onnx.py
index 3078d622b4..a565bb11d0 100644
--- a/invokeai/backend/model_manager/load/model_loaders/onnx.py
+++ b/invokeai/backend/model_manager/load/model_loaders/onnx.py
@@ -5,7 +5,7 @@
from pathlib import Path
from typing import Optional
-from invokeai.backend.model_manager.config import AnyModelConfig
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.model_manager.taxonomy import (
diff --git a/invokeai/backend/model_manager/load/model_loaders/sig_lip.py b/invokeai/backend/model_manager/load/model_loaders/sig_lip.py
index bdf38887a3..16b8e6c88d 100644
--- a/invokeai/backend/model_manager/load/model_loaders/sig_lip.py
+++ b/invokeai/backend/model_manager/load/model_loaders/sig_lip.py
@@ -3,9 +3,7 @@ from typing import Optional
from transformers import SiglipVisionModel
-from invokeai.backend.model_manager.config import (
- AnyModelConfig,
-)
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
diff --git a/invokeai/backend/model_manager/load/model_loaders/spandrel_image_to_image.py b/invokeai/backend/model_manager/load/model_loaders/spandrel_image_to_image.py
index 44cb0277fc..e6d8f42990 100644
--- a/invokeai/backend/model_manager/load/model_loaders/spandrel_image_to_image.py
+++ b/invokeai/backend/model_manager/load/model_loaders/spandrel_image_to_image.py
@@ -3,9 +3,7 @@ from typing import Optional
import torch
-from invokeai.backend.model_manager.config import (
- AnyModelConfig,
-)
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
diff --git a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py
index 647ad4dbf4..d0cc589379 100644
--- a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py
+++ b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py
@@ -11,10 +11,9 @@ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpain
StableDiffusionXLInpaintPipeline,
)
-from invokeai.backend.model_manager.config import (
- AnyModelConfig,
- Checkpoint_Config_Base,
- Diffusers_Config_Base,
+from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Diffusers_Config_Base
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
+from invokeai.backend.model_manager.configs.main import (
Main_Checkpoint_SD1_Config,
Main_Checkpoint_SD2_Config,
Main_Checkpoint_SDXL_Config,
diff --git a/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py
index 60ae4ea08b..2d0411a8df 100644
--- a/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py
+++ b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py
@@ -4,7 +4,7 @@
from pathlib import Path
from typing import Optional
-from invokeai.backend.model_manager.config import AnyModelConfig
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import (
diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py b/invokeai/backend/model_manager/load/model_loaders/vae.py
index 12789e58c2..e91903ccda 100644
--- a/invokeai/backend/model_manager/load/model_loaders/vae.py
+++ b/invokeai/backend/model_manager/load/model_loaders/vae.py
@@ -5,7 +5,8 @@ from typing import Optional
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
-from invokeai.backend.model_manager.config import AnyModelConfig, VAE_Checkpoint_Config_Base
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
+from invokeai.backend.model_manager.configs.vae import VAE_Checkpoint_Config_Base
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.model_manager.taxonomy import (
diff --git a/invokeai/backend/model_manager/util/lora_metadata_extractor.py b/invokeai/backend/model_manager/util/lora_metadata_extractor.py
index 842e78a788..12b1073935 100644
--- a/invokeai/backend/model_manager/util/lora_metadata_extractor.py
+++ b/invokeai/backend/model_manager/util/lora_metadata_extractor.py
@@ -8,7 +8,8 @@ from typing import Any, Dict, Optional, Set, Tuple
from PIL import Image
from invokeai.app.util.thumbnails import make_thumbnail
-from invokeai.backend.model_manager.config import AnyModelConfig, ModelType
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
+from invokeai.backend.model_manager.taxonomy import ModelType
logger = logging.getLogger(__name__)
diff --git a/invokeai/backend/util/test_utils.py b/invokeai/backend/util/test_utils.py
index add394e71b..e4208dc848 100644
--- a/invokeai/backend/util/test_utils.py
+++ b/invokeai/backend/util/test_utils.py
@@ -7,7 +7,8 @@ import torch
from invokeai.app.services.model_manager import ModelManagerServiceBase
from invokeai.app.services.model_records import UnknownModelException
-from invokeai.backend.model_manager import BaseModelType, LoadedModel, ModelType, SubModelType
+from invokeai.backend.model_manager.load.load_base import LoadedModel
+from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
@pytest.fixture(scope="session")
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/ParamDenoisingStrength.tsx b/invokeai/frontend/web/src/features/controlLayers/components/ParamDenoisingStrength.tsx
index bf4464bd5b..49a289b875 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/ParamDenoisingStrength.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/ParamDenoisingStrength.tsx
@@ -17,6 +17,7 @@ import { selectImg2imgStrengthConfig } from 'features/system/store/configSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useSelectedModelConfig } from 'services/api/hooks/useSelectedModelConfig';
+import { isFluxFillMainModelModelConfig } from 'services/api/types';
const selectHasRasterLayersWithContent = createSelector(
selectActiveRasterLayerEntities,
@@ -46,11 +47,7 @@ export const ParamDenoisingStrength = memo(() => {
// Denoising strength does nothing if there are no raster layers w/ content
return true;
}
- if (
- selectedModelConfig?.type === 'main' &&
- selectedModelConfig?.base === 'flux' &&
- selectedModelConfig.variant === 'inpaint'
- ) {
+ if (selectedModelConfig && isFluxFillMainModelModelConfig(selectedModelConfig)) {
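- // Denoising strength is ignored by FLUX Fill, which is indicated by the variant being 'inpaint'
+ // Denoising strength is ignored by FLUX Fill, which is indicated by the variant being 'dev_fill'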
return true;
}
diff --git a/invokeai/frontend/web/src/features/controlLayers/store/validators.ts b/invokeai/frontend/web/src/features/controlLayers/store/validators.ts
index 03ef5404a6..197a3d6e3e 100644
--- a/invokeai/frontend/web/src/features/controlLayers/store/validators.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/store/validators.ts
@@ -154,7 +154,7 @@ export const getControlLayerWarnings = (
warnings.push(WARNINGS.CONTROL_ADAPTER_INCOMPATIBLE_BASE_MODEL);
} else if (
model.base === 'flux' &&
- model.variant === 'inpaint' &&
+ model.variant === 'dev_fill' &&
entity.controlAdapter.model.type === 'control_lora'
) {
- // FLUX inpaint variants are FLUX Fill models - not compatible w/ Control LoRA
+ // FLUX 'dev_fill' variants are FLUX Fill models - not compatible w/ Control LoRA
diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx
index 75b3ba4bc4..538ebf597e 100644
--- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx
+++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx
@@ -56,18 +56,17 @@ export const ModelView = memo(({ modelConfig }: Props) => {
- {modelConfig.type === 'main' && (
+ {modelConfig.type === 'main' && 'variant' in modelConfig && (
)}
{modelConfig.type === 'main' && modelConfig.format === 'diffusers' && modelConfig.repo_variant && (
)}
{modelConfig.type === 'main' && modelConfig.format === 'checkpoint' && (
- <>
-
-
-
- >
+
+ )}
+ {modelConfig.type === 'main' && modelConfig.format === 'checkpoint' && 'prediction_type' in modelConfig && (
+
)}
{modelConfig.type === 'ip_adapter' && modelConfig.format === 'invokeai' && (
diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.test.ts
index 5a766e8d39..42d66f0fc8 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.test.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.test.ts
@@ -660,6 +660,7 @@ describe('Graph', () => {
cover_image: null,
type: 'main',
trigger_phrases: null,
+ prediction_type: 'epsilon',
default_settings: {
vae: null,
vae_precision: null,
@@ -673,7 +674,6 @@ describe('Graph', () => {
variant: 'inpaint',
format: 'diffusers',
repo_variant: 'fp16',
- submodels: null,
usage_info: null,
});
expect(field).toEqual({
diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts
index 5fcd13ba4f..0f7163c44c 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts
@@ -5,7 +5,7 @@ import type { CanvasControlLayerState, Rect } from 'features/controlLayers/store
import { getControlLayerWarnings } from 'features/controlLayers/store/validators';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { serializeError } from 'serialize-error';
-import type { ImageDTO, Invocation, MainModelConfig } from 'services/api/types';
+import type { FLUXModelConfig, ImageDTO, Invocation, MainModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
const log = logger('system');
@@ -113,7 +113,7 @@ type AddControlLoRAArg = {
entities: CanvasControlLayerState[];
g: Graph;
rect: Rect;
- model: MainModelConfig;
+ model: FLUXModelConfig;
denoise: Invocation<'flux_denoise'>;
};
@@ -129,7 +129,7 @@ export const addControlLoRA = async ({ manager, entities, g, rect, model, denois
return;
}
- assert(model.variant !== 'inpaint', 'FLUX Control LoRA is not compatible with FLUX Fill.');
+ assert(model.variant !== 'dev_fill', 'FLUX Control LoRA is not compatible with FLUX Fill.');
assert(validControlLayers.length <= 1, 'Cannot add more than one FLUX control LoRA.');
const getImageDTOResult = await withResultAsync(() => {
diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts
index 6b0140bf9e..558f8b2ffe 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts
@@ -49,7 +49,7 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise {
// Model Configs
export type AnyModelConfig = S['AnyModelConfig'];
export type MainModelConfig = Extract;
+export type FLUXModelConfig = Extract;
export type ControlLoRAModelConfig = Extract;
export type LoRAModelConfig = Extract;
export type VAEModelConfig = Extract;
@@ -134,6 +135,7 @@ type UnknownModelConfig = Extract;
export type FLUXKontextModelConfig = MainModelConfig;
export type ChatGPT4oModelConfig = ApiModelConfig;
export type Gemini2_5ModelConfig = ApiModelConfig;
+type SubmodelDefinition = S['SubmodelDefinition'];
/**
* Checks if a list of submodels contains any that match a given variant or type
@@ -141,7 +143,7 @@ export type Gemini2_5ModelConfig = ApiModelConfig;
* @param checkStr The string to check against for variant or type
* @returns A boolean
*/
-const checkSubmodel = (submodels: AnyModelConfig['submodels'], checkStr: string): boolean => {
+const checkSubmodel = (submodels: Record, checkStr: string): boolean => {
for (const submodel in submodels) {
if (
submodel &&
@@ -164,6 +166,7 @@ const checkSubmodels = (identifiers: string[], config: AnyModelConfig): boolean
return identifiers.every(
(identifier) =>
config.type === 'main' &&
+ 'submodels' in config &&
config.submodels &&
(identifier in config.submodels || checkSubmodel(config.submodels, identifier))
);
@@ -332,7 +335,7 @@ export const isRefinerMainModelModelConfig = (config: AnyModelConfig): config is
};
export const isFluxFillMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
- return config.type === 'main' && config.base === 'flux' && config.variant === 'inpaint';
+ return config.type === 'main' && config.base === 'flux' && config.variant === 'dev_fill';
};
export const isTIModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
diff --git a/scripts/classify-model.py b/scripts/classify-model.py
index 2ae253b72f..a9129860a7 100755
--- a/scripts/classify-model.py
+++ b/scripts/classify-model.py
@@ -8,7 +8,7 @@ from typing import get_args
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.backend.model_manager import InvalidModelConfigException, ModelProbe
-from invokeai.backend.model_manager.config import ModelConfigFactory
+from invokeai.backend.model_manager.configs.factory import ModelConfigFactory
algos = ", ".join(set(get_args(HASHING_ALGORITHMS)))
diff --git a/tests/test_model_probe.py b/tests/test_model_probe.py
index 03a7428382..f1249e4dc1 100644
--- a/tests/test_model_probe.py
+++ b/tests/test_model_probe.py
@@ -13,9 +13,9 @@ from torch import tensor
from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelRepoVariant, ModelType, ModelVariantType
from invokeai.backend.model_manager.config import (
AnyModelConfig,
+ Config_Base,
InvalidModelConfigException,
MainDiffusersConfig,
- Config_Base,
ModelConfigFactory,
get_model_discriminator_value,
)