feat(mm): consistent naming for all model config classes

This commit is contained in:
psychedelicious
2025-10-01 18:09:49 +10:00
parent 315ddefbf1
commit 1e1c8b988b
5 changed files with 130 additions and 137 deletions

View File

@@ -31,10 +31,10 @@ from invokeai.app.util.suppress_output import SuppressOutput
from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelType
from invokeai.backend.model_manager.config import (
AnyModelConfig,
Main_SD1_Checkpoint_Config,
Main_SD2_Checkpoint_Config,
Main_SDXL_Checkpoint_Config,
Main_SDXLRefiner_Checkpoint_Config,
Main_Checkpoint_SD1_Config,
Main_Checkpoint_SD2_Config,
Main_Checkpoint_SDXL_Config,
Main_Checkpoint_SDXLRefiner_Config,
)
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
@@ -747,10 +747,10 @@ async def convert_model(
if isinstance(
model_config,
(
Main_SD1_Checkpoint_Config,
Main_SD2_Checkpoint_Config,
Main_SDXL_Checkpoint_Config,
Main_SDXLRefiner_Checkpoint_Config,
Main_Checkpoint_SD1_Config,
Main_Checkpoint_SD2_Config,
Main_Checkpoint_SDXL_Config,
Main_Checkpoint_SDXLRefiner_Config,
),
):
msg = f"The model with key {key} is not a main SD 1/2/XL checkpoint model."

View File

@@ -17,7 +17,7 @@ from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import (
IPAdapter_FLUX_Checkpoint_Config,
IPAdapter_Checkpoint_FLUX_Config,
)
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
@@ -67,7 +67,7 @@ class FluxIPAdapterInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
assert isinstance(ip_adapter_info, IPAdapter_FLUX_Checkpoint_Config)
assert isinstance(ip_adapter_info, IPAdapter_Checkpoint_FLUX_Config)
# Note: There is a IPAdapterInvokeAIConfig.image_encoder_model_id field, but it isn't trustworthy.
image_encoder_starter_model = CLIP_VISION_MODEL_MAP[self.clip_vision_model]

View File

@@ -806,19 +806,19 @@ class LoRA_Diffusers_Config_Base(LoRAConfigBase):
raise NotAMatch(cls, "missing pytorch_lora_weights.bin or pytorch_lora_weights.safetensors")
class LoRA_SD1_Diffusers_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
class LoRA_Diffusers_SD1_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class LoRA_SD2_Diffusers_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
class LoRA_Diffusers_SD2_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class LoRA_SDXL_Diffusers_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
class LoRA_Diffusers_SDXL_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class LoRA_FLUX_Diffusers_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
class LoRA_Diffusers_FLUX_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
@@ -876,19 +876,19 @@ class VAE_Checkpoint_Config_Base(CheckpointConfigBase):
raise NotAMatch(cls, "cannot determine base type")
class VAE_SD1_Checkpoint_Config(VAE_Checkpoint_Config_Base, ModelConfigBase):
class VAE_Checkpoint_SD1_Config(VAE_Checkpoint_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class VAE_SD2_Checkpoint_Config(VAE_Checkpoint_Config_Base, ModelConfigBase):
class VAE_Checkpoint_SD2_Config(VAE_Checkpoint_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class VAE_SDXL_Checkpoint_Config(VAE_Checkpoint_Config_Base, ModelConfigBase):
class VAE_Checkpoint_SDXL_Config(VAE_Checkpoint_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class VAE_FLUX_Checkpoint_Config(VAE_Checkpoint_Config_Base, ModelConfigBase):
class VAE_Checkpoint_FLUX_Config(VAE_Checkpoint_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
@@ -956,11 +956,11 @@ class VAE_Diffusers_Config_Base(DiffusersConfigBase):
return BaseModelType.StableDiffusion1
class VAE_SD1_Diffusers_Config(VAE_Diffusers_Config_Base, ModelConfigBase):
class VAE_Diffusers_SD1_Config(VAE_Diffusers_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class VAE_SDXL_Diffusers_Config(VAE_Diffusers_Config_Base, ModelConfigBase):
class VAE_Diffusers_SDXL_Config(VAE_Diffusers_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
@@ -1019,30 +1019,22 @@ class ControlNet_Diffusers_Config_Base(DiffusersConfigBase, ControlAdapterConfig
raise NotAMatch(cls, f"unrecognized cross_attention_dim {dimension}")
class ControlNet_SD1_Diffusers_Config(ControlNet_Diffusers_Config_Base, ModelConfigBase):
class ControlNet_Diffusers_SD1_Config(ControlNet_Diffusers_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class ControlNet_SD2_Diffusers_Config(ControlNet_Diffusers_Config_Base, ModelConfigBase):
class ControlNet_Diffusers_SD2_Config(ControlNet_Diffusers_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class ControlNet_SDXL_Diffusers_Config(ControlNet_Diffusers_Config_Base, ModelConfigBase):
class ControlNet_Diffusers_SDXL_Config(ControlNet_Diffusers_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class ControlNet_FLUX_Diffusers_Config(ControlNet_Diffusers_Config_Base, ModelConfigBase):
class ControlNet_Diffusers_FLUX_Config(ControlNet_Diffusers_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
ControlNetCheckpoint_SupportedBases: TypeAlias = Literal[
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusionXL,
BaseModelType.Flux,
]
class ControlNet_Checkpoint_Config_Base(CheckpointConfigBase, ControlAdapterConfigBase):
"""Model config for ControlNet models (diffusers version)."""
@@ -1088,7 +1080,7 @@ class ControlNet_Checkpoint_Config_Base(CheckpointConfigBase, ControlAdapterConf
raise NotAMatch(cls, "state dict does not look like a ControlNet checkpoint")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> ControlNetCheckpoint_SupportedBases:
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
state_dict = mod.load_state_dict()
if is_state_dict_xlabs_controlnet(state_dict) or is_state_dict_instantx_controlnet(state_dict):
@@ -1120,19 +1112,19 @@ class ControlNet_Checkpoint_Config_Base(CheckpointConfigBase, ControlAdapterConf
raise NotAMatch(cls, "unable to determine base type from state dict")
class ControlNet_SD1_Checkpoint_Config(ControlNet_Checkpoint_Config_Base, ModelConfigBase):
class ControlNet_Checkpoint_SD1_Config(ControlNet_Checkpoint_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class ControlNet_SD2_Checkpoint_Config(ControlNet_Checkpoint_Config_Base, ModelConfigBase):
class ControlNet_Checkpoint_SD2_Config(ControlNet_Checkpoint_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class ControlNet_SDXL_Checkpoint_Config(ControlNet_Checkpoint_Config_Base, ModelConfigBase):
class ControlNet_Checkpoint_SDXL_Config(ControlNet_Checkpoint_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class ControlNet_FLUX_Checkpoint_Config(ControlNet_Checkpoint_Config_Base, ModelConfigBase):
class ControlNet_Checkpoint_FLUX_Config(ControlNet_Checkpoint_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
@@ -1226,15 +1218,15 @@ class TI_File_Config_Base(TI_Config_Base):
return cls(**fields)
class TI_SD1_File_Config(TI_File_Config_Base, ModelConfigBase):
class TI_File_SD1_Config(TI_File_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class TI_SD2_File_Config(TI_File_Config_Base, ModelConfigBase):
class TI_File_SD2_Config(TI_File_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class TI_SDXL_File_Config(TI_File_Config_Base, ModelConfigBase):
class TI_File_SDXL_Config(TI_File_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
@@ -1257,15 +1249,15 @@ class TI_Folder_Config_Base(TI_Config_Base):
raise NotAMatch(cls, "model does not look like a textual inversion embedding folder")
class TI_SD1_Folder_Config(TI_Folder_Config_Base, ModelConfigBase):
class TI_Folder_SD1_Config(TI_Folder_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class TI_SD2_Folder_Config(TI_Folder_Config_Base, ModelConfigBase):
class TI_Folder_SD2_Config(TI_Folder_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class TI_SDXL_Folder_Config(TI_Folder_Config_Base, ModelConfigBase):
class TI_Folder_SDXL_Config(TI_Folder_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
@@ -1409,19 +1401,19 @@ class Main_Checkpoint_Config_Base(CheckpointConfigBase, MainConfigBase):
raise NotAMatch(cls, "state dict does not look like a main model")
class Main_SD1_Checkpoint_Config(Main_Checkpoint_Config_Base, ModelConfigBase):
class Main_Checkpoint_SD1_Config(Main_Checkpoint_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class Main_SD2_Checkpoint_Config(Main_Checkpoint_Config_Base, ModelConfigBase):
class Main_Checkpoint_SD2_Config(Main_Checkpoint_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class Main_SDXL_Checkpoint_Config(Main_Checkpoint_Config_Base, ModelConfigBase):
class Main_Checkpoint_SDXL_Config(Main_Checkpoint_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class Main_SDXLRefiner_Checkpoint_Config(Main_Checkpoint_Config_Base, ModelConfigBase):
class Main_Checkpoint_SDXLRefiner_Config(Main_Checkpoint_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(default=BaseModelType.StableDiffusionXLRefiner)
@@ -1470,7 +1462,7 @@ def _get_flux_variant(state_dict: dict[str | int, Any]) -> FluxVariantType | Non
return FluxVariantType.Schnell
class Main_FLUX_Checkpoint_Config(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
class Main_Checkpoint_FLUX_Config(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
"""Model config for main checkpoint models."""
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
@@ -1540,7 +1532,7 @@ class Main_FLUX_Checkpoint_Config(CheckpointConfigBase, MainConfigBase, ModelCon
raise NotAMatch(cls, "state dict looks like GGUF quantized")
class Main_FLUX_BnBNF4_Config(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
class Main_BnBNF4_FLUX_Config(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
"""Model config for main checkpoint models."""
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
@@ -1589,7 +1581,7 @@ class Main_FLUX_BnBNF4_Config(CheckpointConfigBase, MainConfigBase, ModelConfigB
raise NotAMatch(cls, "state dict does not look like bnb quantized nf4")
class Main_FLUX_GGUF_Config(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
class Main_GGUF_FLUX_Config(CheckpointConfigBase, MainConfigBase, ModelConfigBase):
"""Model config for main checkpoint models."""
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
@@ -1744,23 +1736,23 @@ class Main_Diffusers_Config_Base(DiffusersConfigBase, MainConfigBase):
raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'")
class Main_SD1_Diffusers_Config(Main_Diffusers_Config_Base, ModelConfigBase):
class Main_Diffusers_SD1_Config(Main_Diffusers_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusion1] = Field(BaseModelType.StableDiffusion1)
class Main_SD2_Diffusers_Config(Main_Diffusers_Config_Base, ModelConfigBase):
class Main_Diffusers_SD2_Config(Main_Diffusers_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusion2] = Field(BaseModelType.StableDiffusion2)
class Main_SDXL_Diffusers_Config(Main_Diffusers_Config_Base, ModelConfigBase):
class Main_Diffusers_SDXL_Config(Main_Diffusers_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusionXL] = Field(BaseModelType.StableDiffusionXL)
class Main_SDXLRefiner_Diffusers_Config(Main_Diffusers_Config_Base, ModelConfigBase):
class Main_Diffusers_SDXLRefiner_Config(Main_Diffusers_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(BaseModelType.StableDiffusionXLRefiner)
class Main_SD3_Diffusers_Config(DiffusersConfigBase, MainConfigBase, ModelConfigBase):
class Main_Diffusers_SD3_Config(DiffusersConfigBase, MainConfigBase, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusion3] = Field(BaseModelType.StableDiffusion3)
@classmethod
@@ -1834,7 +1826,7 @@ class Main_SD3_Diffusers_Config(DiffusersConfigBase, MainConfigBase, ModelConfig
return submodels
class Main_CogView4_Diffusers_Config(DiffusersConfigBase, MainConfigBase, ModelConfigBase):
class Main_Diffusers_CogView4_Config(DiffusersConfigBase, MainConfigBase, ModelConfigBase):
base: Literal[BaseModelType.CogView4] = Field(BaseModelType.CogView4)
@classmethod
@@ -1927,15 +1919,15 @@ class IPAdapter_InvokeAI_Config_Base(IPAdapterConfigBase):
raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}")
class IPAdapter_SD1_InvokeAI_Config(IPAdapter_InvokeAI_Config_Base, ModelConfigBase):
class IPAdapter_InvokeAI_SD1_Config(IPAdapter_InvokeAI_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class IPAdapter_SD2_InvokeAI_Config(IPAdapter_InvokeAI_Config_Base, ModelConfigBase):
class IPAdapter_InvokeAI_SD2_Config(IPAdapter_InvokeAI_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class IPAdapter_SDXL_InvokeAI_Config(IPAdapter_InvokeAI_Config_Base, ModelConfigBase):
class IPAdapter_InvokeAI_SDXL_Config(IPAdapter_InvokeAI_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
@@ -2000,19 +1992,19 @@ class IPAdapter_Checkpoint_Config_Base(IPAdapterConfigBase):
raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}")
class IPAdapter_SD1_Checkpoint_Config(IPAdapter_Checkpoint_Config_Base, ModelConfigBase):
class IPAdapter_Checkpoint_SD1_Config(IPAdapter_Checkpoint_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class IPAdapter_SD2_Checkpoint_Config(IPAdapter_Checkpoint_Config_Base, ModelConfigBase):
class IPAdapter_Checkpoint_SD2_Config(IPAdapter_Checkpoint_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class IPAdapter_SDXL_Checkpoint_Config(IPAdapter_Checkpoint_Config_Base, ModelConfigBase):
class IPAdapter_Checkpoint_SDXL_Config(IPAdapter_Checkpoint_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class IPAdapter_FLUX_Checkpoint_Config(IPAdapter_Checkpoint_Config_Base, ModelConfigBase):
class IPAdapter_Checkpoint_FLUX_Config(IPAdapter_Checkpoint_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
@@ -2069,11 +2061,11 @@ class CLIPEmbed_Diffusers_Config_Base(DiffusersConfigBase):
raise NotAMatch(cls, f"variant is {recognized_variant}, not {expected_variant}")
class CLIPEmbed_G_Diffusers_Config(CLIPEmbed_Diffusers_Config_Base, ModelConfigBase):
class CLIPEmbed_Diffusers_G_Config(CLIPEmbed_Diffusers_Config_Base, ModelConfigBase):
variant: Literal[ClipVariantType.G] = Field(default=ClipVariantType.G)
class CLIPEmbed_L_Diffusers_Config(CLIPEmbed_Diffusers_Config_Base, ModelConfigBase):
class CLIPEmbed_Diffusers_L_Config(CLIPEmbed_Diffusers_Config_Base, ModelConfigBase):
variant: Literal[ClipVariantType.L] = Field(default=ClipVariantType.L)
@@ -2148,11 +2140,11 @@ class T2IAdapter_Diffusers_Config_Base(DiffusersConfigBase, ControlAdapterConfig
raise NotAMatch(cls, f"unrecognized adapter_type '{adapter_type}'")
class T2IAdapter_SD1_Diffusers_Config(T2IAdapter_Diffusers_Config_Base, ModelConfigBase):
class T2IAdapter_Diffusers_SD1_Config(T2IAdapter_Diffusers_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class T2IAdapter_SDXL_Diffusers_Config(T2IAdapter_Diffusers_Config_Base, ModelConfigBase):
class T2IAdapter_Diffusers_SDXL_Config(T2IAdapter_Diffusers_Config_Base, ModelConfigBase):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
@@ -2325,39 +2317,39 @@ def get_model_discriminator_value(v: Any) -> str:
AnyModelConfig = Annotated[
Union[
# Main (Pipeline) - diffusers format
Annotated[Main_SD1_Diffusers_Config, Main_SD1_Diffusers_Config.get_tag()],
Annotated[Main_SD2_Diffusers_Config, Main_SD2_Diffusers_Config.get_tag()],
Annotated[Main_SDXL_Diffusers_Config, Main_SDXL_Diffusers_Config.get_tag()],
Annotated[Main_SDXLRefiner_Diffusers_Config, Main_SDXLRefiner_Diffusers_Config.get_tag()],
Annotated[Main_SD3_Diffusers_Config, Main_SD3_Diffusers_Config.get_tag()],
Annotated[Main_CogView4_Diffusers_Config, Main_CogView4_Diffusers_Config.get_tag()],
Annotated[Main_Diffusers_SD1_Config, Main_Diffusers_SD1_Config.get_tag()],
Annotated[Main_Diffusers_SD2_Config, Main_Diffusers_SD2_Config.get_tag()],
Annotated[Main_Diffusers_SDXL_Config, Main_Diffusers_SDXL_Config.get_tag()],
Annotated[Main_Diffusers_SDXLRefiner_Config, Main_Diffusers_SDXLRefiner_Config.get_tag()],
Annotated[Main_Diffusers_SD3_Config, Main_Diffusers_SD3_Config.get_tag()],
Annotated[Main_Diffusers_CogView4_Config, Main_Diffusers_CogView4_Config.get_tag()],
# Main (Pipeline) - checkpoint format
Annotated[Main_SD1_Checkpoint_Config, Main_SD1_Checkpoint_Config.get_tag()],
Annotated[Main_SD2_Checkpoint_Config, Main_SD2_Checkpoint_Config.get_tag()],
Annotated[Main_SDXL_Checkpoint_Config, Main_SDXL_Checkpoint_Config.get_tag()],
Annotated[Main_SDXLRefiner_Checkpoint_Config, Main_SDXLRefiner_Checkpoint_Config.get_tag()],
Annotated[Main_FLUX_Checkpoint_Config, Main_FLUX_Checkpoint_Config.get_tag()],
Annotated[Main_Checkpoint_SD1_Config, Main_Checkpoint_SD1_Config.get_tag()],
Annotated[Main_Checkpoint_SD2_Config, Main_Checkpoint_SD2_Config.get_tag()],
Annotated[Main_Checkpoint_SDXL_Config, Main_Checkpoint_SDXL_Config.get_tag()],
Annotated[Main_Checkpoint_SDXLRefiner_Config, Main_Checkpoint_SDXLRefiner_Config.get_tag()],
Annotated[Main_Checkpoint_FLUX_Config, Main_Checkpoint_FLUX_Config.get_tag()],
# Main (Pipeline) - quantized formats
Annotated[Main_FLUX_BnBNF4_Config, Main_FLUX_BnBNF4_Config.get_tag()],
Annotated[Main_FLUX_GGUF_Config, Main_FLUX_GGUF_Config.get_tag()],
Annotated[Main_BnBNF4_FLUX_Config, Main_BnBNF4_FLUX_Config.get_tag()],
Annotated[Main_GGUF_FLUX_Config, Main_GGUF_FLUX_Config.get_tag()],
# VAE - checkpoint format
Annotated[VAE_SD1_Checkpoint_Config, VAE_SD1_Checkpoint_Config.get_tag()],
Annotated[VAE_SD2_Checkpoint_Config, VAE_SD2_Checkpoint_Config.get_tag()],
Annotated[VAE_SDXL_Checkpoint_Config, VAE_SDXL_Checkpoint_Config.get_tag()],
Annotated[VAE_FLUX_Checkpoint_Config, VAE_FLUX_Checkpoint_Config.get_tag()],
Annotated[VAE_Checkpoint_SD1_Config, VAE_Checkpoint_SD1_Config.get_tag()],
Annotated[VAE_Checkpoint_SD2_Config, VAE_Checkpoint_SD2_Config.get_tag()],
Annotated[VAE_Checkpoint_SDXL_Config, VAE_Checkpoint_SDXL_Config.get_tag()],
Annotated[VAE_Checkpoint_FLUX_Config, VAE_Checkpoint_FLUX_Config.get_tag()],
# VAE - diffusers format
Annotated[VAE_SD1_Diffusers_Config, VAE_SD1_Diffusers_Config.get_tag()],
Annotated[VAE_SDXL_Diffusers_Config, VAE_SDXL_Diffusers_Config.get_tag()],
Annotated[VAE_Diffusers_SD1_Config, VAE_Diffusers_SD1_Config.get_tag()],
Annotated[VAE_Diffusers_SDXL_Config, VAE_Diffusers_SDXL_Config.get_tag()],
# ControlNet - checkpoint format
Annotated[ControlNet_SD1_Checkpoint_Config, ControlNet_SD1_Checkpoint_Config.get_tag()],
Annotated[ControlNet_SD2_Checkpoint_Config, ControlNet_SD2_Checkpoint_Config.get_tag()],
Annotated[ControlNet_SDXL_Checkpoint_Config, ControlNet_SDXL_Checkpoint_Config.get_tag()],
Annotated[ControlNet_FLUX_Checkpoint_Config, ControlNet_FLUX_Checkpoint_Config.get_tag()],
Annotated[ControlNet_Checkpoint_SD1_Config, ControlNet_Checkpoint_SD1_Config.get_tag()],
Annotated[ControlNet_Checkpoint_SD2_Config, ControlNet_Checkpoint_SD2_Config.get_tag()],
Annotated[ControlNet_Checkpoint_SDXL_Config, ControlNet_Checkpoint_SDXL_Config.get_tag()],
Annotated[ControlNet_Checkpoint_FLUX_Config, ControlNet_Checkpoint_FLUX_Config.get_tag()],
# ControlNet - diffusers format
Annotated[ControlNet_SD1_Diffusers_Config, ControlNet_SD1_Diffusers_Config.get_tag()],
Annotated[ControlNet_SD2_Diffusers_Config, ControlNet_SD2_Diffusers_Config.get_tag()],
Annotated[ControlNet_SDXL_Diffusers_Config, ControlNet_SDXL_Diffusers_Config.get_tag()],
Annotated[ControlNet_FLUX_Diffusers_Config, ControlNet_FLUX_Diffusers_Config.get_tag()],
Annotated[ControlNet_Diffusers_SD1_Config, ControlNet_Diffusers_SD1_Config.get_tag()],
Annotated[ControlNet_Diffusers_SD2_Config, ControlNet_Diffusers_SD2_Config.get_tag()],
Annotated[ControlNet_Diffusers_SDXL_Config, ControlNet_Diffusers_SDXL_Config.get_tag()],
Annotated[ControlNet_Diffusers_FLUX_Config, ControlNet_Diffusers_FLUX_Config.get_tag()],
# LoRA - LyCORIS format
Annotated[LoRA_LyCORIS_SD1_Config, LoRA_LyCORIS_SD1_Config.get_tag()],
Annotated[LoRA_LyCORIS_SD2_Config, LoRA_LyCORIS_SD2_Config.get_tag()],
@@ -2367,38 +2359,39 @@ AnyModelConfig = Annotated[
Annotated[LoRA_OMI_SDXL_Config, LoRA_OMI_SDXL_Config.get_tag()],
Annotated[LoRA_OMI_FLUX_Config, LoRA_OMI_FLUX_Config.get_tag()],
# LoRA - diffusers format
Annotated[LoRA_SD1_Diffusers_Config, LoRA_SD1_Diffusers_Config.get_tag()],
Annotated[LoRA_SD2_Diffusers_Config, LoRA_SD2_Diffusers_Config.get_tag()],
Annotated[LoRA_SDXL_Diffusers_Config, LoRA_SDXL_Diffusers_Config.get_tag()],
Annotated[LoRA_FLUX_Diffusers_Config, LoRA_FLUX_Diffusers_Config.get_tag()],
Annotated[LoRA_Diffusers_SD1_Config, LoRA_Diffusers_SD1_Config.get_tag()],
Annotated[LoRA_Diffusers_SD2_Config, LoRA_Diffusers_SD2_Config.get_tag()],
Annotated[LoRA_Diffusers_SDXL_Config, LoRA_Diffusers_SDXL_Config.get_tag()],
Annotated[LoRA_Diffusers_FLUX_Config, LoRA_Diffusers_FLUX_Config.get_tag()],
# ControlLoRA - diffusers format
Annotated[ControlLoRA_LyCORIS_FLUX_Config, ControlLoRA_LyCORIS_FLUX_Config.get_tag()],
# T5 Encoder - all formats
Annotated[T5Encoder_T5Encoder_Config, T5Encoder_T5Encoder_Config.get_tag()],
Annotated[T5Encoder_BnBLLMint8_Config, T5Encoder_BnBLLMint8_Config.get_tag()],
# TI - file format
Annotated[TI_SD1_File_Config, TI_SD1_File_Config.get_tag()],
Annotated[TI_SD2_File_Config, TI_SD2_File_Config.get_tag()],
Annotated[TI_SDXL_File_Config, TI_SDXL_File_Config.get_tag()],
Annotated[TI_File_SD1_Config, TI_File_SD1_Config.get_tag()],
Annotated[TI_File_SD2_Config, TI_File_SD2_Config.get_tag()],
Annotated[TI_File_SDXL_Config, TI_File_SDXL_Config.get_tag()],
# TI - folder format
Annotated[TI_SD1_Folder_Config, TI_SD1_Folder_Config.get_tag()],
Annotated[TI_SD2_Folder_Config, TI_SD2_Folder_Config.get_tag()],
Annotated[TI_SDXL_Folder_Config, TI_SDXL_Folder_Config.get_tag()],
Annotated[TI_Folder_SD1_Config, TI_Folder_SD1_Config.get_tag()],
Annotated[TI_Folder_SD2_Config, TI_Folder_SD2_Config.get_tag()],
Annotated[TI_Folder_SDXL_Config, TI_Folder_SDXL_Config.get_tag()],
# IP Adapter - InvokeAI format
Annotated[IPAdapter_SD1_InvokeAI_Config, IPAdapter_SD1_InvokeAI_Config.get_tag()],
Annotated[IPAdapter_SD2_InvokeAI_Config, IPAdapter_SD2_InvokeAI_Config.get_tag()],
Annotated[IPAdapter_SDXL_InvokeAI_Config, IPAdapter_SDXL_InvokeAI_Config.get_tag()],
Annotated[IPAdapter_InvokeAI_SD1_Config, IPAdapter_InvokeAI_SD1_Config.get_tag()],
Annotated[IPAdapter_InvokeAI_SD2_Config, IPAdapter_InvokeAI_SD2_Config.get_tag()],
Annotated[IPAdapter_InvokeAI_SDXL_Config, IPAdapter_InvokeAI_SDXL_Config.get_tag()],
# IP Adapter - checkpoint format
Annotated[IPAdapter_SD1_Checkpoint_Config, IPAdapter_SD1_Checkpoint_Config.get_tag()],
Annotated[IPAdapter_SD2_Checkpoint_Config, IPAdapter_SD2_Checkpoint_Config.get_tag()],
Annotated[IPAdapter_SDXL_Checkpoint_Config, IPAdapter_SDXL_Checkpoint_Config.get_tag()],
Annotated[IPAdapter_FLUX_Checkpoint_Config, IPAdapter_FLUX_Checkpoint_Config.get_tag()],
Annotated[IPAdapter_Checkpoint_SD1_Config, IPAdapter_Checkpoint_SD1_Config.get_tag()],
Annotated[IPAdapter_Checkpoint_SD2_Config, IPAdapter_Checkpoint_SD2_Config.get_tag()],
Annotated[IPAdapter_Checkpoint_SDXL_Config, IPAdapter_Checkpoint_SDXL_Config.get_tag()],
Annotated[IPAdapter_Checkpoint_FLUX_Config, IPAdapter_Checkpoint_FLUX_Config.get_tag()],
# T2I Adapter - diffusers format
Annotated[T2IAdapter_SD1_Diffusers_Config, T2IAdapter_SD1_Diffusers_Config.get_tag()],
Annotated[T2IAdapter_SDXL_Diffusers_Config, T2IAdapter_SDXL_Diffusers_Config.get_tag()],
Annotated[T2IAdapter_Diffusers_SD1_Config, T2IAdapter_Diffusers_SD1_Config.get_tag()],
Annotated[T2IAdapter_Diffusers_SDXL_Config, T2IAdapter_Diffusers_SDXL_Config.get_tag()],
# Misc models
Annotated[Spandrel_Checkpoint_Config, Spandrel_Checkpoint_Config.get_tag()],
Annotated[CLIPEmbed_G_Diffusers_Config, CLIPEmbed_G_Diffusers_Config.get_tag()],
Annotated[CLIPEmbed_L_Diffusers_Config, CLIPEmbed_L_Diffusers_Config.get_tag()],
Annotated[CLIPEmbed_Diffusers_G_Config, CLIPEmbed_Diffusers_G_Config.get_tag()],
Annotated[CLIPEmbed_Diffusers_L_Config, CLIPEmbed_Diffusers_L_Config.get_tag()],
Annotated[CLIPVision_Diffusers_Config, CLIPVision_Diffusers_Config.get_tag()],
Annotated[SigLIP_Diffusers_Config, SigLIP_Diffusers_Config.get_tag()],
Annotated[FLUXRedux_Checkpoint_Config, FLUXRedux_Checkpoint_Config.get_tag()],

View File

@@ -42,9 +42,9 @@ from invokeai.backend.model_manager.config import (
ControlNet_Diffusers_Config_Base,
FLUXRedux_Checkpoint_Config,
IPAdapter_Checkpoint_Config_Base,
Main_FLUX_BnBNF4_Config,
Main_FLUX_Checkpoint_Config,
Main_FLUX_GGUF_Config,
Main_BnBNF4_FLUX_Config,
Main_Checkpoint_FLUX_Config,
Main_GGUF_FLUX_Config,
T5Encoder_BnBLLMint8_Config,
T5Encoder_T5Encoder_Config,
VAE_Checkpoint_Config_Base,
@@ -226,7 +226,7 @@ class FluxCheckpointModel(ModelLoader):
self,
config: AnyModelConfig,
) -> AnyModel:
assert isinstance(config, Main_FLUX_Checkpoint_Config)
assert isinstance(config, Main_Checkpoint_FLUX_Config)
model_path = Path(config.path)
with accelerate.init_empty_weights():
@@ -268,7 +268,7 @@ class FluxGGUFCheckpointModel(ModelLoader):
self,
config: AnyModelConfig,
) -> AnyModel:
assert isinstance(config, Main_FLUX_GGUF_Config)
assert isinstance(config, Main_GGUF_FLUX_Config)
model_path = Path(config.path)
with accelerate.init_empty_weights():
@@ -314,7 +314,7 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
self,
config: AnyModelConfig,
) -> AnyModel:
assert isinstance(config, Main_FLUX_BnBNF4_Config)
assert isinstance(config, Main_BnBNF4_FLUX_Config)
if not bnb_available:
raise ImportError(
"The bnb modules are not available. Please install bitsandbytes if available on your platform."

View File

@@ -15,14 +15,14 @@ from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
DiffusersConfigBase,
Main_SD1_Checkpoint_Config,
Main_SD1_Diffusers_Config,
Main_SD2_Checkpoint_Config,
Main_SD2_Diffusers_Config,
Main_SDXL_Checkpoint_Config,
Main_SDXL_Diffusers_Config,
Main_SDXLRefiner_Checkpoint_Config,
Main_SDXLRefiner_Diffusers_Config,
Main_Checkpoint_SD1_Config,
Main_Diffusers_SD1_Config,
Main_Checkpoint_SD2_Config,
Main_Diffusers_SD2_Config,
Main_Checkpoint_SDXL_Config,
Main_Diffusers_SDXL_Config,
Main_Checkpoint_SDXLRefiner_Config,
Main_Diffusers_SDXLRefiner_Config,
)
from invokeai.backend.model_manager.load.model_cache.model_cache import get_model_cache_key
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
@@ -117,14 +117,14 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
assert isinstance(
config,
(
Main_SD1_Diffusers_Config,
Main_SD2_Diffusers_Config,
Main_SDXL_Diffusers_Config,
Main_SDXLRefiner_Diffusers_Config,
Main_SD1_Checkpoint_Config,
Main_SD2_Checkpoint_Config,
Main_SDXL_Checkpoint_Config,
Main_SDXLRefiner_Checkpoint_Config,
Main_Diffusers_SD1_Config,
Main_Diffusers_SD2_Config,
Main_Diffusers_SDXL_Config,
Main_Diffusers_SDXLRefiner_Config,
Main_Checkpoint_SD1_Config,
Main_Checkpoint_SD2_Config,
Main_Checkpoint_SDXL_Config,
Main_Checkpoint_SDXLRefiner_Config,
),
)
try: