refactor(mm): make config classes narrow

Simpler logic for model identification, less complexity when adding a new model, fewer
useless attrs that do not relate to the model arch, etc.
This commit is contained in:
psychedelicious
2025-10-01 16:52:28 +10:00
parent c065655a1d
commit af305250cb
9 changed files with 635 additions and 432 deletions

View File

@@ -17,8 +17,8 @@ from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import (
IPAdapter_InvokeAI_Config_Base,
IPAdapterCheckpointConfig,
IPAdapterInvokeAIConfig,
)
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
@@ -68,7 +68,7 @@ class FluxIPAdapterInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))
assert isinstance(ip_adapter_info, (IPAdapter_InvokeAI_Config_Base, IPAdapterCheckpointConfig))
# Note: There is a IPAdapterInvokeAIConfig.image_encoder_model_id field, but it isn't trustworthy.
image_encoder_starter_model = CLIP_VISION_MODEL_MAP[self.clip_vision_model]

View File

@@ -13,8 +13,8 @@ from invokeai.app.services.model_records.model_records_base import ModelRecordCh
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import (
AnyModelConfig,
IPAdapter_InvokeAI_Config_Base,
IPAdapterCheckpointConfig,
IPAdapterInvokeAIConfig,
)
from invokeai.backend.model_manager.starter_models import (
StarterModel,
@@ -123,9 +123,9 @@ class IPAdapterInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))
assert isinstance(ip_adapter_info, (IPAdapter_InvokeAI_Config_Base, IPAdapterCheckpointConfig))
if isinstance(ip_adapter_info, IPAdapterInvokeAIConfig):
if isinstance(ip_adapter_info, IPAdapter_InvokeAI_Config_Base):
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
else:

File diff suppressed because it is too large Load Diff

View File

@@ -12,9 +12,7 @@ from typing import Any, Dict, Generator, Optional, Tuple
import torch
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager.config import (
AnyModelConfig,
)
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType

View File

@@ -7,7 +7,7 @@ from diffusers import ControlNetModel
from invokeai.backend.model_manager.config import (
AnyModelConfig,
ControlNetCheckpointConfig,
ControlNet_Checkpoint_Config_Base,
)
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
@@ -46,7 +46,7 @@ class ControlNetLoader(GenericDiffusersLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, ControlNetCheckpointConfig):
if isinstance(config, ControlNet_Checkpoint_Config_Base):
return ControlNetModel.from_single_file(
config.path,
torch_dtype=self._torch_dtype,

View File

@@ -37,26 +37,26 @@ from invokeai.backend.flux.util import get_flux_ae_params, get_flux_transformers
from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
CLIPEmbedDiffusersConfig,
ControlNetCheckpointConfig,
ControlNetDiffusersConfig,
FLUX_Quantized_BnB_NF4_CheckpointConfig,
FLUX_Quantized_GGUF_CheckpointConfig,
FLUX_Unquantized_CheckpointConfig,
FluxReduxConfig,
IPAdapterCheckpointConfig,
T5EncoderBnbQuantizedLlmInt8bConfig,
T5EncoderConfig,
VAECheckpointConfig,
CLIPEmbed_Diffusers_Config_Base,
ControlNet_Checkpoint_Config_Base,
ControlNet_Diffusers_Config_Base,
FLUXRedux_Checkpoint_Config,
IPAdapter_Checkpoint_Config_Base,
Main_FLUX_BnBNF4_Config,
Main_FLUX_Checkpoint_Config,
Main_FLUX_GGUF_Config,
T5Encoder_BnBLLMint8_Config,
T5Encoder_T5Encoder_Config,
VAE_Checkpoint_Config_Base,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import (
AnyModel,
BaseModelType,
FluxVariantType,
ModelFormat,
ModelType,
ModelVariantType,
SubModelType,
)
from invokeai.backend.model_manager.util.model_util import (
@@ -86,7 +86,7 @@ class FluxVAELoader(ModelLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, VAECheckpointConfig):
if not isinstance(config, VAE_Checkpoint_Config_Base):
raise ValueError("Only VAECheckpointConfig models are currently supported here.")
model_path = Path(config.path)
@@ -116,7 +116,7 @@ class CLIPDiffusersLoader(ModelLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CLIPEmbedDiffusersConfig):
if not isinstance(config, CLIPEmbed_Diffusers_Config_Base):
raise ValueError("Only CLIPEmbedDiffusersConfig models are currently supported here.")
match submodel_type:
@@ -139,7 +139,7 @@ class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, T5EncoderBnbQuantizedLlmInt8bConfig):
if not isinstance(config, T5Encoder_BnBLLMint8_Config):
raise ValueError("Only T5EncoderBnbQuantizedLlmInt8bConfig models are currently supported here.")
if not bnb_available:
raise ImportError(
@@ -186,7 +186,7 @@ class T5EncoderCheckpointModel(ModelLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, T5EncoderConfig):
if not isinstance(config, T5Encoder_T5Encoder_Config):
raise ValueError("Only T5EncoderConfig models are currently supported here.")
match submodel_type:
@@ -226,7 +226,7 @@ class FluxCheckpointModel(ModelLoader):
self,
config: AnyModelConfig,
) -> AnyModel:
assert isinstance(config, FLUX_Unquantized_CheckpointConfig)
assert isinstance(config, Main_FLUX_Checkpoint_Config)
model_path = Path(config.path)
with accelerate.init_empty_weights():
@@ -268,7 +268,7 @@ class FluxGGUFCheckpointModel(ModelLoader):
self,
config: AnyModelConfig,
) -> AnyModel:
assert isinstance(config, FLUX_Quantized_GGUF_CheckpointConfig)
assert isinstance(config, Main_FLUX_GGUF_Config)
model_path = Path(config.path)
with accelerate.init_empty_weights():
@@ -314,7 +314,7 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
self,
config: AnyModelConfig,
) -> AnyModel:
assert isinstance(config, FLUX_Quantized_BnB_NF4_CheckpointConfig)
assert isinstance(config, Main_FLUX_BnBNF4_Config)
if not bnb_available:
raise ImportError(
"The bnb modules are not available. Please install bitsandbytes if available on your platform."
@@ -342,9 +342,9 @@ class FluxControlnetModel(ModelLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, ControlNetCheckpointConfig):
if isinstance(config, ControlNet_Checkpoint_Config_Base):
model_path = Path(config.path)
elif isinstance(config, ControlNetDiffusersConfig):
elif isinstance(config, ControlNet_Diffusers_Config_Base):
# If this is a diffusers directory, we simply ignore the config file and load from the weight file.
model_path = Path(config.path) / "diffusion_pytorch_model.safetensors"
else:
@@ -363,7 +363,7 @@ class FluxControlnetModel(ModelLoader):
def _load_xlabs_controlnet(self, sd: dict[str, torch.Tensor]) -> AnyModel:
with accelerate.init_empty_weights():
# HACK(ryand): Is it safe to assume dev here?
model = XLabsControlNetFlux(get_flux_transformers_params(ModelVariantType.FluxDev))
model = XLabsControlNetFlux(get_flux_transformers_params(FluxVariantType.Dev))
model.load_state_dict(sd, assign=True)
return model
@@ -389,7 +389,7 @@ class FluxIpAdapterModel(ModelLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, IPAdapterCheckpointConfig):
if not isinstance(config, IPAdapter_Checkpoint_Config_Base):
raise ValueError(f"Unexpected model config type: {type(config)}.")
sd = load_file(Path(config.path))
@@ -412,7 +412,7 @@ class FluxReduxModelLoader(ModelLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, FluxReduxConfig):
if not isinstance(config, FLUXRedux_Checkpoint_Config):
raise ValueError(f"Unexpected model config type: {type(config)}.")
sd = load_file(Path(config.path))

View File

@@ -15,8 +15,14 @@ from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
DiffusersConfigBase,
SD_1_2_XL_XLRefiner_CheckpointConfig,
SD_1_2_XL_XLRefiner_DiffusersConfig,
Main_SD1_Checkpoint_Config,
Main_SD1_Diffusers_Config,
Main_SD2_Checkpoint_Config,
Main_SD2_Diffusers_Config,
Main_SDXL_Checkpoint_Config,
Main_SDXL_Diffusers_Config,
Main_SDXLRefiner_Checkpoint_Config,
Main_SDXLRefiner_Diffusers_Config,
)
from invokeai.backend.model_manager.load.model_cache.model_cache import get_model_cache_key
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
@@ -108,7 +114,19 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
ModelVariantType.Normal: StableDiffusionXLPipeline,
},
}
assert isinstance(config, (SD_1_2_XL_XLRefiner_DiffusersConfig, SD_1_2_XL_XLRefiner_CheckpointConfig))
assert isinstance(
config,
(
Main_SD1_Diffusers_Config,
Main_SD2_Diffusers_Config,
Main_SDXL_Diffusers_Config,
Main_SDXLRefiner_Diffusers_Config,
Main_SD1_Checkpoint_Config,
Main_SD2_Checkpoint_Config,
Main_SDXL_Checkpoint_Config,
Main_SDXLRefiner_Checkpoint_Config,
),
)
try:
load_class = load_classes[config.base][config.variant]
except KeyError as e:

View File

@@ -3,9 +3,9 @@
from typing import Optional
from diffusers import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from invokeai.backend.model_manager.config import AnyModelConfig, VAECheckpointConfig
from invokeai.backend.model_manager.config import AnyModelConfig, VAE_Checkpoint_Config_Base
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.model_manager.taxonomy import (
@@ -27,7 +27,7 @@ class VAELoader(GenericDiffusersLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, VAECheckpointConfig):
if isinstance(config, VAE_Checkpoint_Config_Base):
return AutoencoderKL.from_single_file(
config.path,
torch_dtype=self._torch_dtype,

View File

@@ -21,7 +21,7 @@ from invokeai.backend.model_manager.config import (
ControlAdapterDefaultSettings,
MainDiffusersConfig,
MainModelDefaultSettings,
TextualInversionFileConfig,
TI_File_Config,
VAEDiffusersConfig,
)
from invokeai.backend.model_manager.taxonomy import ModelSourceType
@@ -40,8 +40,8 @@ def store(
return ModelRecordServiceSQL(db, logger)
def example_ti_config(key: Optional[str] = None) -> TextualInversionFileConfig:
config = TextualInversionFileConfig(
def example_ti_config(key: Optional[str] = None) -> TI_File_Config:
config = TI_File_Config(
source="test/source/",
source_type=ModelSourceType.Path,
path="/tmp/pokemon.bin",
@@ -61,7 +61,7 @@ def test_type(store: ModelRecordServiceBase):
config = example_ti_config("key1")
store.add_model(config)
config1 = store.get_model("key1")
assert isinstance(config1, TextualInversionFileConfig)
assert isinstance(config1, TI_File_Config)
def test_raises_on_violating_uniqueness(store: ModelRecordServiceBase):