diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index cae423afe3..4cae15b8e3 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -39,7 +39,7 @@ from invokeai.backend.model_manager.config import ( AnyModelConfig, CheckpointConfigBase, InvalidModelConfigException, - ModelConfigBase, + ModelConfigFactory, ) from invokeai.backend.model_manager.legacy_probe import ModelProbe from invokeai.backend.model_manager.metadata import ( @@ -612,7 +612,11 @@ class ModelInstallService(ModelInstallServiceBase): try: return ModelProbe.probe(model_path=model_path, fields=deepcopy(fields), hash_algo=hash_algo) # type: ignore except InvalidModelConfigException: - return ModelConfigBase.classify(mod=model_path, fields=deepcopy(fields), hash_algo=hash_algo) + return ModelConfigFactory.from_model_on_disk( + mod=model_path, + overrides=deepcopy(fields), + hash_algo=hash_algo, + ) def _register( self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None @@ -633,7 +637,7 @@ class ModelInstallService(ModelInstallServiceBase): info.path = model_path.as_posix() - if isinstance(info, CheckpointConfigBase): + if isinstance(info, CheckpointConfigBase) and info.config_path is not None: # Checkpoints have a config file needed for conversion. Same handling as the model weights - if it's in the # invoke-managed legacy config dir, we use a relative path. legacy_config_path = self.app_config.legacy_conf_path / info.config_path diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 8f5f7c89f6..bb77cf6fdf 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -86,34 +86,86 @@ class NotAMatch(Exception): reason: The reason why the model did not match. """ - def __init__(self, config_class: "Type[AnyModelConfig]", reason: str): + def __init__( + self, + config_class: type, + reason: str, + ): super().__init__(f"{config_class.__name__} does not match: {reason}") DEFAULTS_PRECISION = Literal["fp16", "fp32"] -def get_class_name_from_config(config: dict[str, Any]) -> Optional[str]: - if "_class_name" in config: - return config["_class_name"] - elif "architectures" in config: - return config["architectures"][0] - else: - return None +def get_config_or_raise( + config_class: type, + config_path: Path, +) -> dict[str, Any]: + """Load the config file at the given path, or raise NotAMatch if it cannot be loaded.""" + if not config_path.exists(): + raise NotAMatch(config_class, f"missing config file: {config_path}") + + try: + config = load_json(config_path) + return config + except Exception as e: + raise NotAMatch(config_class, f"unable to load config file: {config_path}") from e -def validate_overrides( - config_class: "Type[AnyModelConfig]", overrides: dict[str, Any], allowed: dict[str, Any] +def raise_for_class_names( + config_class: type, + config_path: Path, + valid_class_names: set[str], ) -> None: - for key, value in allowed.items(): - if key not in overrides: + """Raise NotAMatch if the config file is missing or does not contain a valid class name.""" + + config = get_config_or_raise(config_class, config_path) + + try: + if "_class_name" in config: + config_class_name = config["_class_name"] + elif "architectures" in config: + config_class_name = config["architectures"][0] + else: + raise ValueError("missing _class_name or architectures field") + except Exception as e: + raise NotAMatch(config_class, f"unable to determine class name from config file: {config_path}") from e + + if config_class_name not in valid_class_names: + raise NotAMatch(config_class, f"model class is not one of {valid_class_names}, got {config_class_name}") + + +def matches_overrides( + config_class: "Type[AnyModelConfig]", + provided_overrides: dict[str, Any], + valid_overrides: dict[str, Any], +) -> bool: + """Check if the provided overrides match the valid overrides for this config class. + + Args: + config_class: The config class that is being tested. + provided_overrides: The overrides provided by the user. + valid_overrides: The overrides that are valid for this config class. + + Returns: + True if all provided overrides match the valid overrides, False if some valid overrides are missing. + + Raises: + NotAMatch if any override does not match the allowed value. + """ + is_perfect_match = True + for key, value in valid_overrides.items(): + if key not in provided_overrides: + is_perfect_match = False continue - if overrides[key] != value: + if provided_overrides[key] != value: raise NotAMatch( config_class, - f"override {key}={overrides[key]} does not match required value {key}={value}", + f"override {key}={provided_overrides[key]} does not match required value {key}={value}", ) + return is_perfect_match + class SubmodelDefinition(BaseModel): path_or_prefix: str @@ -327,36 +379,32 @@ def load_json(path: Path) -> dict[str, Any]: class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase): format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder + VALID_OVERRIDES: ClassVar = { + "type": ModelType.T5Encoder, + "format": ModelFormat.T5Encoder, + } + + VALID_CLASS_NAMES: ClassVar = { + "T5EncoderModel", + } + @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - type_override = fields.get("type") - format_override = fields.get("format") - - if type_override is not None and type_override is not ModelType.T5Encoder: - raise NotAMatch(cls, f"type override is {type_override}, not T5Encoder") - - if format_override is not None and format_override is not ModelFormat.T5Encoder: - raise NotAMatch(cls, f"format override is {format_override}, not T5Encoder") - - if type_override is ModelType.T5Encoder and format_override is ModelFormat.T5Encoder: + if matches_overrides( + config_class=cls, + provided_overrides=fields, + valid_overrides=cls.VALID_OVERRIDES, + ): return cls(**fields) if mod.path.is_file(): raise NotAMatch(cls, "model path is a file, not a directory") - # Heuristic: Look for the T5EncoderModel class name in the config - try: - config = load_json(mod.path / "text_encoder_2" / "config.json") - except Exception as e: - raise NotAMatch(cls, "unable to load text_encoder_2/config.json") from e - - try: - config_class_name = get_class_name_from_config(config) - except Exception as e: - raise NotAMatch(cls, "unable to determine class name from text_encoder_2/config.json") from e - - if config_class_name != "T5EncoderModel": - raise NotAMatch(cls, "model class is not T5EncoderModel") + raise_for_class_names( + config_class=cls, + config_path=mod.path / "text_encoder_2" / "config.json", + valid_class_names=cls.VALID_CLASS_NAMES, + ) # Heuristic: Look for the presence of the unquantized config file (not present for bnb-quantized models) has_unquantized_config = (mod.path / "text_encoder_2" / "model.safetensors.index.json").exists() @@ -370,33 +418,30 @@ class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase): class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase): format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b + VALID_OVERRIDES: ClassVar = { + "type": ModelType.T5Encoder, + "format": ModelFormat.BnbQuantizedLlmInt8b, + } + + VALID_CLASS_NAMES: ClassVar = { + "T5EncoderModel", + } + @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - type_override = fields.get("type") - format_override = fields.get("format") - - if type_override is not None and type_override is not ModelType.T5Encoder: - raise NotAMatch(cls, f"type override is {type_override}, not T5Encoder") - - if format_override is not None and format_override is not ModelFormat.BnbQuantizedLlmInt8b: - raise NotAMatch(cls, f"format override is {format_override}, not BnbQuantizedLlmInt8b") - - if type_override is ModelType.T5Encoder and format_override is ModelFormat.BnbQuantizedLlmInt8b: + if matches_overrides( + config_class=cls, + provided_overrides=fields, + valid_overrides=cls.VALID_OVERRIDES, + ): return cls(**fields) # Heuristic: Look for the T5EncoderModel class name in the config - try: - config = load_json(mod.path / "text_encoder_2" / "config.json") - except Exception as e: - raise NotAMatch(cls, "unable to load text_encoder_2/config.json") from e - - try: - config_class_name = get_class_name_from_config(config) - except Exception as e: - raise NotAMatch(cls, "unable to determine class name from text_encoder_2/config.json") from e - - if config_class_name != "T5EncoderModel": - raise NotAMatch(cls, "model class is not T5EncoderModel") + raise_for_class_names( + config_class=cls, + config_path=mod.path / "text_encoder_2" / "config.json", + valid_class_names=cls.VALID_CLASS_NAMES, + ) # Heuristic: look for the quantization in the filename name filename_looks_like_bnb = any(x for x in mod.weight_files() if "llm_int8" in x.as_posix()) @@ -413,18 +458,18 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase): class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase): format: Literal[ModelFormat.OMI] = ModelFormat.OMI + VALID_OVERRIDES: ClassVar = { + "type": ModelType.LoRA, + "format": ModelFormat.OMI, + } + @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - type_override = fields.get("type") - format_override = fields.get("format") - - if type_override is not None and type_override is not ModelType.LoRA: - raise NotAMatch(cls, f"type override is {type_override}, not LoRA") - - if format_override is not None and format_override is not ModelFormat.OMI: - raise NotAMatch(cls, f"format override is {format_override}, not OMI") - - if type_override is ModelType.LoRA and format_override is ModelFormat.OMI: + if matches_overrides( + config_class=cls, + provided_overrides=fields, + valid_overrides=cls.VALID_OVERRIDES, + ): return cls(**fields) # Heuristic: OMI LoRAs are always files, never directories @@ -446,12 +491,12 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase): if not is_omi_lora_heuristic: raise NotAMatch(cls, "model does not match OMI LoRA heuristics") - base = fields.get("base") or cls.get_base_or_raise(mod) + base = fields.get("base") or cls._get_base_or_raise(mod) return cls(**fields, base=base) @classmethod - def get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: metadata = mod.metadata() architecture = metadata["modelspec.architecture"] @@ -468,18 +513,18 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase): format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS + VALID_OVERRIDES: ClassVar = { + "type": ModelType.LoRA, + "format": ModelFormat.LyCORIS, + } + @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - type_override = fields.get("type") - format_override = fields.get("format") - - if type_override is not None and type_override is not ModelType.LoRA: - raise NotAMatch(cls, f"type override is {type_override}, not LoRA") - - if format_override is not None and format_override is not ModelFormat.LyCORIS: - raise NotAMatch(cls, f"format override is {format_override}, not LyCORIS") - - if type_override is ModelType.LoRA and format_override is ModelFormat.LyCORIS: + if matches_overrides( + config_class=cls, + provided_overrides=fields, + valid_overrides=cls.VALID_OVERRIDES, + ): return cls(**fields) # Heuristic: LyCORIS LoRAs are always files, never directories @@ -544,18 +589,18 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase): format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers + VALID_OVERRIDES: ClassVar = { + "type": ModelType.LoRA, + "format": ModelFormat.Diffusers, + } + @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - type_override = fields.get("type") - format_override = fields.get("format") - - if type_override is not None and type_override is not ModelType.LoRA: - raise NotAMatch(cls, f"type override is {type_override}, not LoRA") - - if format_override is not None and format_override is not ModelFormat.Diffusers: - raise NotAMatch(cls, f"format override is {format_override}, not Diffusers") - - if type_override is ModelType.LoRA and format_override is ModelFormat.Diffusers: + if matches_overrides( + config_class=cls, + provided_overrides=fields, + valid_overrides=cls.VALID_OVERRIDES, + ): return cls(**fields) # Heuristic: Diffusers LoRAs are always directories, never files @@ -583,8 +628,31 @@ class VAECheckpointConfig(VAEConfigBase, CheckpointConfigBase, ModelConfigBase): format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint + VALID_OVERRIDES: ClassVar = { + "type": ModelType.VAE, + "format": ModelFormat.Checkpoint, + } + @classmethod - def get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + if matches_overrides( + config_class=cls, + provided_overrides=fields, + valid_overrides=cls.VALID_OVERRIDES, + ): + return cls(**fields) + + if mod.path.is_dir(): + raise NotAMatch(cls, "model path is a directory, not a file") + + if not mod.has_keys_starting_with({"encoder.conv_in", "decoder.conv_in"}): + raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics") + + base = fields.get("base") or cls._get_base_or_raise(mod) + return cls(**fields, base=base) + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: # Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name for regexp, basetype in [ (r"xl", BaseModelType.StableDiffusionXL), @@ -597,36 +665,41 @@ class VAECheckpointConfig(VAEConfigBase, CheckpointConfigBase, ModelConfigBase): raise NotAMatch(cls, "cannot determine base type") - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - type_override = fields.get("type") - format_override = fields.get("format") - - if type_override is not None and type_override is not ModelType.VAE: - raise NotAMatch(cls, f"type override is {type_override}, not VAE") - - if format_override is not None and format_override is not ModelFormat.Checkpoint: - raise NotAMatch(cls, f"format override is {format_override}, not Checkpoint") - - if type_override is ModelType.VAE and format_override is ModelFormat.Checkpoint: - return cls(**fields) - - if mod.path.is_dir(): - raise NotAMatch(cls, "model path is a directory, not a file") - - if not mod.has_keys_starting_with({"encoder.conv_in", "decoder.conv_in"}): - raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics") - - base = fields.get("base") or cls.get_base_or_raise(mod) - return cls(**fields, base=base) - class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase): """Model config for standalone VAE models (diffusers version).""" format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers - CLASS_NAMES: ClassVar = {"AutoencoderKL", "AutoencoderTiny"} + VALID_OVERRIDES: ClassVar = { + "type": ModelType.VAE, + "format": ModelFormat.Diffusers, + } + VALID_CLASS_NAMES: ClassVar = { + "AutoencoderKL", + "AutoencoderTiny", + } + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + if matches_overrides( + config_class=cls, + provided_overrides=fields, + valid_overrides=cls.VALID_OVERRIDES, + ): + return cls(**fields) + + if mod.path.is_file(): + raise NotAMatch(cls, "model path is a file, not a directory") + + raise_for_class_names( + config_class=cls, + config_path=mod.path / "config.json", + valid_class_names=cls.VALID_CLASS_NAMES, + ) + + base = fields.get("base") or cls._get_base_or_raise(mod) + return cls(**fields, base=base) @classmethod def _config_looks_like_sdxl(cls, config: dict[str, Any]) -> bool: @@ -648,7 +721,8 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase): return name @classmethod - def get_base(cls, mod: ModelOnDisk, config: dict[str, Any]) -> BaseModelType: + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + config = get_config_or_raise(cls, mod.path / "config.json") if cls._config_looks_like_sdxl(config): return BaseModelType.StableDiffusionXL elif cls._name_looks_like_sdxl(mod): @@ -657,39 +731,6 @@ class VAEDiffusersConfig(VAEConfigBase, DiffusersConfigBase, ModelConfigBase): # TODO(psyche): Figure out how to positively identify SD1 here, and raise if we can't. Until then, YOLO. return BaseModelType.StableDiffusion1 - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - type_override = fields.get("type") - format_override = fields.get("format") - - if type_override is not None and type_override is not ModelType.VAE: - raise NotAMatch(cls, f"type override is {type_override}, not VAE") - - if format_override is not None and format_override is not ModelFormat.Diffusers: - raise NotAMatch(cls, f"format override is {format_override}, not Diffusers") - - if type_override is ModelType.VAE and format_override is ModelFormat.Diffusers: - return cls(**fields) - - if mod.path.is_file(): - raise NotAMatch(cls, "model path is a file, not a directory") - - try: - config = load_json(mod.path / "config.json") - except Exception as e: - raise NotAMatch(cls, "unable to load config.json") from e - - try: - config_class_name = get_class_name_from_config(config) - except Exception as e: - raise NotAMatch(cls, "unable to determine class name from config") from e - - if config_class_name not in cls.CLASS_NAMES: - raise NotAMatch(cls, f"model class is not one of {cls.CLASS_NAMES}") - - base = fields.get("base") or cls.get_base(mod, config) - return cls(**fields, base=base) - class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase): """Model config for ControlNet models (diffusers version).""" @@ -710,7 +751,7 @@ class TextualInversionConfigBase(ABC, BaseModel): KNOWN_KEYS: ClassVar = {"string_to_param", "emb_params", "clip_g"} @classmethod - def file_looks_like_embedding(cls, mod: ModelOnDisk, path: Path | None = None) -> bool: + def _file_looks_like_embedding(cls, mod: ModelOnDisk, path: Path | None = None) -> bool: try: p = path or mod.path @@ -738,11 +779,15 @@ class TextualInversionConfigBase(ABC, BaseModel): return False @classmethod - def get_base(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType: + def _get_base_or_raise(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType: p = path or mod.path try: state_dict = mod.load_state_dict(p) + except Exception as e: + raise NotAMatch(cls, f"unable to load state dict from {p}: {e}") from e + + try: if "string_to_token" in state_dict: token_dim = list(state_dict["string_to_param"].values())[0].shape[-1] elif "emb_params" in state_dict: @@ -751,49 +796,18 @@ class TextualInversionConfigBase(ABC, BaseModel): token_dim = state_dict["clip_g"].shape[-1] else: token_dim = list(state_dict.values())[0].shape[0] + except Exception as e: + raise NotAMatch(cls, f"unable to determine token dimension from state dict in {p}: {e}") from e - match token_dim: - case 768: - return BaseModelType.StableDiffusion1 - case 1024: - return BaseModelType.StableDiffusion2 - case 1280: - return BaseModelType.StableDiffusionXL - case _: - pass - except Exception: - pass - - raise InvalidModelConfigException(f"{p}: Could not determine base type") - - @classmethod - def get_base_or_raise(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType: - p = path or mod.path - - try: - state_dict = mod.load_state_dict(p) - if "string_to_token" in state_dict: - token_dim = list(state_dict["string_to_param"].values())[0].shape[-1] - elif "emb_params" in state_dict: - token_dim = state_dict["emb_params"].shape[-1] - elif "clip_g" in state_dict: - token_dim = state_dict["clip_g"].shape[-1] - else: - token_dim = list(state_dict.values())[0].shape[0] - - match token_dim: - case 768: - return BaseModelType.StableDiffusion1 - case 1024: - return BaseModelType.StableDiffusion2 - case 1280: - return BaseModelType.StableDiffusionXL - case _: - pass - except Exception: - pass - - raise InvalidModelConfigException(f"{p}: Could not determine base type") + match token_dim: + case 768: + return BaseModelType.StableDiffusion1 + case 1024: + return BaseModelType.StableDiffusion2 + case 1280: + return BaseModelType.StableDiffusionXL + case _: + raise NotAMatch(cls, f"unrecognized token dimension {token_dim}") class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase): @@ -801,31 +815,31 @@ class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase): format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile + VALID_OVERRIDES: ClassVar = { + "type": ModelType.TextualInversion, + "format": ModelFormat.EmbeddingFile, + } + @classmethod def get_tag(cls) -> Tag: return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}") @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - type_override = fields.get("type") - format_override = fields.get("format") - - if type_override is not None and type_override is not ModelType.TextualInversion: - raise NotAMatch(cls, f"type override is {type_override}, not TextualInversion") - - if format_override is not None and format_override is not ModelFormat.EmbeddingFile: - raise NotAMatch(cls, f"format override is {format_override}, not EmbeddingFile") - - if type_override is ModelType.TextualInversion and format_override is ModelFormat.EmbeddingFile: + if matches_overrides( + config_class=cls, + provided_overrides=fields, + valid_overrides=cls.VALID_OVERRIDES, + ): return cls(**fields) if mod.path.is_dir(): raise NotAMatch(cls, "model path is a directory, not a file") - if not cls.file_looks_like_embedding(mod): + if not cls._file_looks_like_embedding(mod): raise NotAMatch(cls, "model does not look like a textual inversion embedding file") - base = fields.get("base") or cls.get_base_or_raise(mod) + base = fields.get("base") or cls._get_base_or_raise(mod) return cls(**fields, base=base) @@ -834,30 +848,30 @@ class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase): format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder + VALID_OVERRIDES: ClassVar = { + "type": ModelType.TextualInversion, + "format": ModelFormat.EmbeddingFolder, + } + @classmethod def get_tag(cls) -> Tag: return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}") @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - type_override = fields.get("type") - format_override = fields.get("format") - - if type_override is not None and type_override is not ModelType.TextualInversion: - raise NotAMatch(cls, f"type override is {type_override}, not TextualInversion") - - if format_override is not None and format_override is not ModelFormat.EmbeddingFolder: - raise NotAMatch(cls, f"format override is {format_override}, not EmbeddingFolder") - - if type_override is ModelType.TextualInversion and format_override is ModelFormat.EmbeddingFolder: + if matches_overrides( + config_class=cls, + provided_overrides=fields, + valid_overrides=cls.VALID_OVERRIDES, + ): return cls(**fields) if mod.path.is_file(): raise NotAMatch(cls, "model path is a file, not a directory") for p in mod.weight_files(): - if cls.file_looks_like_embedding(mod, p): - base = fields.get("base") or cls.get_base_or_raise(mod, p) + if cls._file_looks_like_embedding(mod, p): + base = fields.get("base") or cls._get_base_or_raise(mod, p) return cls(**fields, base=base) raise NotAMatch(cls, "model does not look like a textual inversion embedding folder") @@ -937,7 +951,7 @@ class CLIPEmbedDiffusersConfig(DiffusersConfigBase): format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers base: Literal[BaseModelType.Any] = BaseModelType.Any - CLASS_NAMES: ClassVar = { + VALID_CLASS_NAMES: ClassVar = { "CLIPModel", "CLIPTextModel", "CLIPTextModelWithProjection", @@ -963,47 +977,37 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): variant: Literal[ClipVariantType.G] = ClipVariantType.G + VALID_OVERRIDES: ClassVar = { + "type": ModelType.CLIPEmbed, + "format": ModelFormat.Diffusers, + "variant": ClipVariantType.G, + } + @classmethod def get_tag(cls) -> Tag: return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G.value}") @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - type_override = fields.get("type") - format_override = fields.get("format") - variant_override = fields.get("variant") - - if type_override is not None and type_override is not ModelType.CLIPEmbed: - raise NotAMatch(cls, f"type override is {type_override}, not CLIPEmbed") - - if format_override is not None and format_override is not ModelFormat.Diffusers: - raise NotAMatch(cls, f"format override is {format_override}, not Diffusers") - - if variant_override is not None and variant_override is not ClipVariantType.G: - raise NotAMatch(cls, f"variant override is {variant_override}, not G") - - if ( - type_override is ModelType.CLIPEmbed - and format_override is ModelFormat.Diffusers - and variant_override is ClipVariantType.G + if matches_overrides( + config_class=cls, + provided_overrides=fields, + valid_overrides=cls.VALID_OVERRIDES, ): return cls(**fields) if mod.path.is_file(): raise NotAMatch(cls, "model path is a file, not a directory") - try: - config = load_json(mod.path / "config.json") - except Exception as e: - raise NotAMatch(cls, "unable to load config.json") from e + config_path = mod.path / "config.json" - try: - config_class_name = get_class_name_from_config(config) - except Exception as e: - raise NotAMatch(cls, "unable to determine class name from config") from e + raise_for_class_names( + config_class=cls, + config_path=config_path, + valid_class_names=cls.VALID_CLASS_NAMES, + ) - if config_class_name not in cls.CLASS_NAMES: - raise NotAMatch(cls, f"model class is not one of {cls.CLASS_NAMES}") + config = get_config_or_raise(cls, config_path) clip_variant = cls.get_clip_variant_type(config) @@ -1018,48 +1022,37 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase): variant: Literal[ClipVariantType.L] = ClipVariantType.L + VALID_OVERRIDES: ClassVar = { + "type": ModelType.CLIPEmbed, + "format": ModelFormat.Diffusers, + "variant": ClipVariantType.L, + } + @classmethod def get_tag(cls) -> Tag: return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L.value}") @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - type_override = fields.get("type") - format_override = fields.get("format") - variant_override = fields.get("variant") - - if type_override is not None and type_override is not ModelType.CLIPEmbed: - raise NotAMatch(cls, f"type override is {type_override}, not CLIPEmbed") - - if format_override is not None and format_override is not ModelFormat.Diffusers: - raise NotAMatch(cls, f"format override is {format_override}, not Diffusers") - - if variant_override is not None and variant_override is not ClipVariantType.L: - raise NotAMatch(cls, f"variant override is {variant_override}, not L") - - if ( - type_override is ModelType.CLIPEmbed - and format_override is ModelFormat.Diffusers - and variant_override is ClipVariantType.L + if matches_overrides( + config_class=cls, + provided_overrides=fields, + valid_overrides=cls.VALID_OVERRIDES, ): return cls(**fields) if mod.path.is_file(): raise NotAMatch(cls, "model path is a file, not a directory") - try: - config = load_json(mod.path / "config.json") - except Exception as e: - raise NotAMatch(cls, "unable to load config.json") from e + config_path = mod.path / "config.json" - try: - config_class_name = get_class_name_from_config(config) - except Exception as e: - raise NotAMatch(cls, "unable to determine class name from config") from e - - if config_class_name not in cls.CLASS_NAMES: - raise NotAMatch(cls, f"model class is not one of {cls.CLASS_NAMES}") + raise_for_class_names( + config_class=cls, + config_path=config_path, + valid_class_names=cls.VALID_CLASS_NAMES, + ) + config = get_config_or_raise(cls, config_path) clip_variant = cls.get_clip_variant_type(config) if clip_variant is not ClipVariantType.L: @@ -1089,25 +1082,18 @@ class SpandrelImageToImageConfig(ModelConfigBase): type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint + VALID_OVERRIDES: ClassVar = { + "type": ModelType.SpandrelImageToImage, + "format": ModelFormat.Checkpoint, + "base": BaseModelType.Any, + } + @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - type_override = fields.get("type") - format_override = fields.get("format") - base_override = fields.get("base") - - if type_override is not None and type_override is not ModelType.SpandrelImageToImage: - raise NotAMatch(cls, f"type override is {type_override}, not SpandrelImageToImage") - - if format_override is not None and format_override is not ModelFormat.Checkpoint: - raise NotAMatch(cls, f"format override is {format_override}, not Checkpoint") - - if base_override is not None and base_override is not BaseModelType.Any: - raise NotAMatch(cls, f"base override is {base_override}, not Any") - - if ( - type_override is ModelType.SpandrelImageToImage - and format_override is ModelFormat.Checkpoint - and base_override is BaseModelType.Any + if matches_overrides( + config_class=cls, + provided_overrides=fields, + valid_overrides=cls.VALID_OVERRIDES, ): return cls(**fields) @@ -1151,40 +1137,36 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase): base: Literal[BaseModelType.Any] = BaseModelType.Any variant: Literal[ModelVariantType.Normal] = ModelVariantType.Normal + VALID_OVERRIDES: ClassVar = { + "type": ModelType.LlavaOnevision, + "format": ModelFormat.Diffusers, + } + + VALID_CLASS_NAMES: ClassVar = { + "LlavaOnevisionForConditionalGeneration", + } + @classmethod def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: - type_override = fields.get("type") - format_override = fields.get("format") - - if type_override is not None and type_override is not ModelType.LlavaOnevision: - raise NotAMatch(cls, f"type override is {type_override}, not LlavaOnevision") - - if format_override is not None and format_override is not ModelFormat.Diffusers: - raise NotAMatch(cls, f"format override is {format_override}, not Diffusers") - - if type_override is ModelType.LlavaOnevision and format_override is ModelFormat.Diffusers: + if matches_overrides( + config_class=cls, + provided_overrides=fields, + valid_overrides=cls.VALID_OVERRIDES, + ): return cls(**fields) if mod.path.is_file(): raise NotAMatch(cls, "model path is a file, not a directory") - # Heuristic: Look for the LlavaOnevisionForConditionalGeneration class name in the config - try: - config = load_json(mod.path / "config.json") - except Exception as e: - raise NotAMatch(cls, "unable to load config.json") from e + config_path = mod.path / "config.json" - try: - config_class_name = get_class_name_from_config(config) - except Exception as e: - raise NotAMatch(cls, "unable to determine class name from config.json") from e + raise_for_class_names( + config_class=cls, + config_path=config_path, + valid_class_names=cls.VALID_CLASS_NAMES, + ) - if config_class_name != "LlavaOnevisionForConditionalGeneration": - raise NotAMatch(cls, "model class is not LlavaOnevisionForConditionalGeneration") - - base = fields.get("base") or BaseModelType.Any - variant = fields.get("variant") or ModelVariantType.Normal - return cls(**fields, base=base, variant=variant) + return cls(**fields) class ApiModelConfig(MainConfigBase, ModelConfigBase): diff --git a/scripts/classify-model.py b/scripts/classify-model.py index 6411b4c705..78bbd5a2f6 100755 --- a/scripts/classify-model.py +++ b/scripts/classify-model.py @@ -7,7 +7,8 @@ from pathlib import Path from typing import get_args from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS -from invokeai.backend.model_manager import InvalidModelConfigException, ModelConfigBase, ModelProbe +from invokeai.backend.model_manager import InvalidModelConfigException, ModelProbe +from invokeai.backend.model_manager.config import ModelConfigFactory algos = ", ".join(set(get_args(HASHING_ALGORITHMS))) @@ -30,7 +31,7 @@ def classify_with_fallback(path: Path, hash_algo: HASHING_ALGORITHMS): try: return ModelProbe.probe(path, hash_algo=hash_algo) except InvalidModelConfigException: - return ModelConfigBase.classify(path, hash_algo) + return ModelConfigFactory.from_model_on_disk(mod=path, hash_algo=hash_algo,) for path in args.model_path: diff --git a/tests/test_model_probe.py b/tests/test_model_probe.py index 8112ccdd19..e24a3ac8bd 100644 --- a/tests/test_model_probe.py +++ b/tests/test_model_probe.py @@ -132,7 +132,10 @@ class MinimalConfigExample(ModelConfigBase): def test_minimal_working_example(datadir: Path): model_path = datadir / "minimal_config_model.json" overrides = {"base": BaseModelType.StableDiffusion1} - config = ModelConfigBase.classify(model_path, **overrides) + config = ModelConfigFactory.from_model_on_disk( + mod=model_path, + overrides=overrides, + ) assert isinstance(config, MinimalConfigExample) assert config.base == BaseModelType.StableDiffusion1 @@ -160,7 +163,10 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading): try: stripped_mod = StrippedModelOnDisk(path) - new_config = ModelConfigBase.classify(stripped_mod, hash=fake_hash, key=fake_key) + new_config = ModelConfigFactory.from_model_on_disk( + mod=stripped_mod, + overrides={"hash": fake_hash, "key": fake_key}, + ) except InvalidModelConfigException: pass