mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
refactor(mm): add config validation utils, make it all consistent and clean
This commit is contained in:
@@ -26,6 +26,7 @@ import re
|
||||
import time
|
||||
from abc import ABC
|
||||
from enum import Enum
|
||||
from functools import cache
|
||||
from inspect import isabstract
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
@@ -40,6 +41,7 @@ from typing import (
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter, ValidationError
|
||||
from pydantic_core import CoreSchema, SchemaValidator
|
||||
from typing_extensions import Annotated, Any, Dict
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
@@ -102,6 +104,38 @@ class NotAMatch(Exception):
|
||||
|
||||
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
|
||||
|
||||
|
||||
# Utility from https://github.com/pydantic/pydantic/discussions/7367#discussioncomment-14213144
|
||||
def find_field_schema(model: type[BaseModel], field_name: str) -> CoreSchema:
|
||||
schema: CoreSchema = model.__pydantic_core_schema__.copy()
|
||||
# we shallow copied, be careful not to mutate the original schema!
|
||||
|
||||
assert schema["type"] in ["definitions", "model"]
|
||||
|
||||
# find the field schema
|
||||
field_schema = schema["schema"] # type: ignore
|
||||
while "fields" not in field_schema:
|
||||
field_schema = field_schema["schema"] # type: ignore
|
||||
|
||||
field_schema = field_schema["fields"][field_name]["schema"] # type: ignore
|
||||
|
||||
# if the original schema is a definition schema, replace the model schema with the field schema
|
||||
if schema["type"] == "definitions":
|
||||
schema["schema"] = field_schema
|
||||
return schema
|
||||
else:
|
||||
return field_schema
|
||||
|
||||
|
||||
@cache
|
||||
def validator(model: type[BaseModel], field_name: str) -> SchemaValidator:
|
||||
return SchemaValidator(find_field_schema(model, field_name))
|
||||
|
||||
|
||||
def validate_model_field(model: type[BaseModel], field_name: str, value: Any) -> Any:
|
||||
return validator(model, field_name).validate_python(value)
|
||||
|
||||
|
||||
# These utility functions are tightly coupled to the config classes below in order to make the process of raising
|
||||
# NotAMatch exceptions as easy and consistent as possible.
|
||||
|
||||
@@ -123,12 +157,15 @@ def _get_config_or_raise(
|
||||
raise NotAMatch(config_class, f"unable to load config file: {config_path}") from e
|
||||
|
||||
|
||||
def _validate_class_names(
|
||||
def _get_class_name_from_config(
|
||||
config_class: type,
|
||||
config_path: Path,
|
||||
valid_class_names: set[str],
|
||||
) -> None:
|
||||
"""Raise NotAMatch if the config file is missing or does not contain a valid class name."""
|
||||
) -> str:
|
||||
"""Load the config file and return the class name.
|
||||
|
||||
Raises:
|
||||
NotAMatch if the config file is missing or does not contain a valid class name.
|
||||
"""
|
||||
|
||||
config = _get_config_or_raise(config_class, config_path)
|
||||
|
||||
@@ -142,36 +179,48 @@ def _validate_class_names(
|
||||
except Exception as e:
|
||||
raise NotAMatch(config_class, f"unable to determine class name from config file: {config_path}") from e
|
||||
|
||||
if config_class_name not in valid_class_names:
|
||||
raise NotAMatch(config_class, f"model class is not one of {valid_class_names}, got {config_class_name}")
|
||||
if not isinstance(config_class_name, str):
|
||||
raise NotAMatch(config_class, f"_class_name or architectures field is not a string: {config_class_name}")
|
||||
|
||||
return config_class_name
|
||||
|
||||
|
||||
def _validate_overrides(
|
||||
config_class: type,
|
||||
provided_overrides: dict[str, Any],
|
||||
valid_overrides: dict[str, Any],
|
||||
) -> None:
|
||||
"""Check if the provided overrides match the valid overrides for this config class.
|
||||
def _validate_class_name(config_class: type[BaseModel], config_path: Path, expected: set[str]) -> None:
|
||||
"""Check if the class name in the config file matches the expected class names.
|
||||
|
||||
Args:
|
||||
config_class: The config class that is being tested.
|
||||
provided_overrides: The overrides provided by the user.
|
||||
valid_overrides: The overrides that are valid for this config class.
|
||||
config_path: The path to the config file.
|
||||
expected: The expected class names."""
|
||||
|
||||
class_name = _get_class_name_from_config(config_class, config_path)
|
||||
if class_name not in expected:
|
||||
raise NotAMatch(config_class, f"invalid class name from config: {class_name}")
|
||||
|
||||
|
||||
def _validate_override_fields(
|
||||
config_class: type[BaseModel],
|
||||
override_fields: dict[str, Any],
|
||||
) -> None:
|
||||
"""Check if the provided override fields are valid for the config class.
|
||||
|
||||
Args:
|
||||
config_class: The config class that is being tested.
|
||||
override_fields: The override fields provided by the user.
|
||||
|
||||
Raises:
|
||||
NotAMatch if any override does not match the allowed value.
|
||||
NotAMatch if any override field is invalid for the config.
|
||||
"""
|
||||
for key, value in valid_overrides.items():
|
||||
if key not in provided_overrides:
|
||||
continue
|
||||
if provided_overrides[key] != value:
|
||||
raise NotAMatch(
|
||||
config_class,
|
||||
f"override {key}={provided_overrides[key]} does not match required value {key}={value}",
|
||||
)
|
||||
for field_name, override_value in override_fields.items():
|
||||
if field_name not in config_class.model_fields:
|
||||
raise NotAMatch(config_class, f"unknown override field: {field_name}")
|
||||
try:
|
||||
validate_model_field(config_class, field_name, override_value)
|
||||
except ValidationError as e:
|
||||
raise NotAMatch(config_class, f"invalid override for field '{field_name}': {e}") from e
|
||||
|
||||
|
||||
def _raise_if_not_file(
|
||||
def _validate_is_file(
|
||||
config_class: type,
|
||||
mod: ModelOnDisk,
|
||||
) -> None:
|
||||
@@ -180,7 +229,7 @@ def _raise_if_not_file(
|
||||
raise NotAMatch(config_class, "model path is not a file")
|
||||
|
||||
|
||||
def _raise_if_not_dir(
|
||||
def _validate_is_dir(
|
||||
config_class: type,
|
||||
mod: ModelOnDisk,
|
||||
) -> None:
|
||||
@@ -346,7 +395,10 @@ class UnknownModelConfig(ModelConfigBase):
|
||||
class CheckpointConfigBase(ABC, BaseModel):
|
||||
"""Base class for checkpoint-style models."""
|
||||
|
||||
config_path: str | None = Field(None, description="Path to the config for this model, if any.")
|
||||
config_path: str | None = Field(
|
||||
description="Path to the config for this model, if any.",
|
||||
default=None,
|
||||
)
|
||||
converted_at: float | None = Field(
|
||||
description="When this model was last converted to diffusers",
|
||||
default_factory=time.time,
|
||||
@@ -365,66 +417,57 @@ class T5EncoderConfig(ModelConfigBase):
|
||||
type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder)
|
||||
format: Literal[ModelFormat.T5Encoder] = Field(default=ModelFormat.T5Encoder)
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.T5Encoder,
|
||||
"format": ModelFormat.T5Encoder,
|
||||
}
|
||||
|
||||
VALID_CLASS_NAMES: ClassVar = {
|
||||
"T5EncoderModel",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_dir(cls, mod)
|
||||
_validate_is_dir(cls, mod)
|
||||
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
_validate_class_names(cls, mod.path / "text_encoder_2" / "config.json", cls.VALID_CLASS_NAMES)
|
||||
_validate_class_name(cls, mod.path / "config.json", {"T5EncoderModel"})
|
||||
|
||||
# Heuristic: Look for the presence of the unquantized config file (not present for bnb-quantized models)
|
||||
cls._validate_has_unquantized_config_file(mod)
|
||||
|
||||
return cls(**fields)
|
||||
|
||||
@classmethod
|
||||
def _validate_has_unquantized_config_file(cls, mod: ModelOnDisk) -> None:
|
||||
has_unquantized_config = (mod.path / "text_encoder_2" / "model.safetensors.index.json").exists()
|
||||
|
||||
if not has_unquantized_config:
|
||||
raise NotAMatch(cls, "missing text_encoder_2/model.safetensors.index.json")
|
||||
|
||||
return cls(**fields)
|
||||
|
||||
|
||||
class T5EncoderBnbQuantizedLlmInt8bConfig(ModelConfigBase):
|
||||
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
|
||||
type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder)
|
||||
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = Field(default=ModelFormat.BnbQuantizedLlmInt8b)
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.T5Encoder,
|
||||
"format": ModelFormat.BnbQuantizedLlmInt8b,
|
||||
}
|
||||
|
||||
VALID_CLASS_NAMES: ClassVar = {
|
||||
"T5EncoderModel",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_dir(cls, mod)
|
||||
_validate_is_dir(cls, mod)
|
||||
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
# Heuristic: Look for the T5EncoderModel class name in the config
|
||||
_validate_class_names(cls, mod.path / "text_encoder_2" / "config.json", cls.VALID_CLASS_NAMES)
|
||||
_validate_class_name(cls, mod.path / "config.json", {"T5EncoderModel"})
|
||||
|
||||
# Heuristic: look for the quantization in the filename name
|
||||
filename_looks_like_bnb = any(x for x in mod.weight_files() if "llm_int8" in x.as_posix())
|
||||
cls._validate_filename_looks_like_bnb_quantized(mod)
|
||||
|
||||
# Heuristic: Look for the presence of "SCB" suffixes in state dict keys
|
||||
has_scb_key_suffix = mod.has_keys_ending_with("SCB")
|
||||
|
||||
if not filename_looks_like_bnb and not has_scb_key_suffix:
|
||||
raise NotAMatch(cls, "missing bnb quantization indicators")
|
||||
cls._validate_model_looks_like_bnb_quantized(mod)
|
||||
|
||||
return cls(**fields)
|
||||
|
||||
@classmethod
|
||||
def _validate_filename_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
|
||||
filename_looks_like_bnb = any(x for x in mod.weight_files() if "llm_int8" in x.as_posix())
|
||||
if not filename_looks_like_bnb:
|
||||
raise NotAMatch(cls, "filename does not look like bnb quantized llm_int8")
|
||||
|
||||
@classmethod
|
||||
def _validate_model_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
|
||||
has_scb_key_suffix = mod.has_keys_ending_with("SCB")
|
||||
if not has_scb_key_suffix:
|
||||
raise NotAMatch(cls, "state dict does not look like bnb quantized llm_int8")
|
||||
|
||||
|
||||
class LoRAConfigBase(ABC, BaseModel):
|
||||
"""Base class for LoRA models."""
|
||||
@@ -453,36 +496,40 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
|
||||
base: Literal[BaseModelType.Flux, BaseModelType.StableDiffusionXL] = Field()
|
||||
format: Literal[ModelFormat.OMI] = Field(default=ModelFormat.OMI)
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.LoRA,
|
||||
"format": ModelFormat.OMI,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
# OMI LoRAs are always files
|
||||
_raise_if_not_file(cls, mod)
|
||||
_validate_is_file(cls, mod)
|
||||
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
# Heuristic: differential diagnosis vs ControlLoRA and Diffusers
|
||||
if get_flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
|
||||
raise NotAMatch(cls, "model is a ControlLoRA or Diffusers LoRA")
|
||||
cls._validate_is_not_controllora_or_diffusers(mod)
|
||||
|
||||
# Heuristic: Look for OMI LoRA metadata
|
||||
cls._validate_metadata_looks_like_omi(mod)
|
||||
|
||||
base = fields.get("base") or cls._get_base_or_raise(mod)
|
||||
|
||||
return cls(**fields, base=base)
|
||||
|
||||
@classmethod
|
||||
def _validate_is_not_controllora_or_diffusers(cls, mod: ModelOnDisk) -> None:
|
||||
"""Raise `NotAMatch` if the model is a ControlLoRA or Diffusers LoRA."""
|
||||
flux_format = get_flux_lora_format(mod)
|
||||
if flux_format in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
|
||||
raise NotAMatch(cls, "model looks like ControlLoRA or Diffusers LoRA")
|
||||
|
||||
@classmethod
|
||||
def _validate_metadata_looks_like_omi(cls, mod: ModelOnDisk) -> None:
|
||||
"""Raise `NotAMatch` if the model metadata does not look like an OMI LoRA."""
|
||||
metadata = mod.metadata()
|
||||
is_omi_lora_heuristic = (
|
||||
|
||||
metadata_looks_like_omi_lora = (
|
||||
bool(metadata.get("modelspec.sai_model_spec"))
|
||||
and metadata.get("ot_branch") == "omi_format"
|
||||
and metadata.get("modelspec.architecture", "").split("/")[1].lower() == "lora"
|
||||
)
|
||||
|
||||
if not is_omi_lora_heuristic:
|
||||
raise NotAMatch(cls, "model does not match OMI LoRA heuristics")
|
||||
|
||||
base = fields.get("base") or cls._get_base_or_raise(mod)
|
||||
|
||||
return cls(**fields, base=base)
|
||||
if not metadata_looks_like_omi_lora:
|
||||
raise NotAMatch(cls, "metadata does not look like OMI LoRA")
|
||||
|
||||
@classmethod
|
||||
def _get_base_or_raise(cls, mod: ModelOnDisk) -> Literal[BaseModelType.Flux, BaseModelType.StableDiffusionXL]:
|
||||
@@ -512,21 +559,20 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
|
||||
type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA)
|
||||
format: Literal[ModelFormat.LyCORIS] = Field(default=ModelFormat.LyCORIS)
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.LoRA,
|
||||
"format": ModelFormat.LyCORIS,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_file(cls, mod)
|
||||
_validate_is_file(cls, mod)
|
||||
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
# Heuristic: differential diagnosis vs ControlLoRA and Diffusers
|
||||
if get_flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
|
||||
raise NotAMatch(cls, "model is a ControlLoRA or Diffusers LoRA")
|
||||
cls._validate_is_not_controllora_or_diffusers(mod)
|
||||
|
||||
cls._validate_looks_like_lora(mod)
|
||||
|
||||
return cls(**fields)
|
||||
|
||||
@classmethod
|
||||
def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
|
||||
# Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA.
|
||||
# Some main models have these keys, likely due to the creator merging in a LoRA.
|
||||
has_key_with_lora_prefix = mod.has_keys_starting_with(
|
||||
@@ -551,7 +597,12 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
|
||||
if not has_key_with_lora_prefix and not has_key_with_lora_suffix:
|
||||
raise NotAMatch(cls, "model does not match LyCORIS LoRA heuristics")
|
||||
|
||||
return cls(**fields)
|
||||
@classmethod
|
||||
def _validate_is_not_controllora_or_diffusers(cls, mod: ModelOnDisk) -> None:
|
||||
"""Raise `NotAMatch` if the model is a ControlLoRA or Diffusers LoRA."""
|
||||
flux_format = get_flux_lora_format(mod)
|
||||
if flux_format in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
|
||||
raise NotAMatch(cls, "model looks like ControlLoRA or Diffusers LoRA")
|
||||
|
||||
@classmethod
|
||||
def _get_base_or_raise(cls, mod: ModelOnDisk) -> LoRALyCORIS_SupportedBases:
|
||||
@@ -591,15 +642,21 @@ class ControlLoRALyCORISConfig(ControlAdapterConfigBase, ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_file(cls, mod)
|
||||
_validate_is_file(cls, mod)
|
||||
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
cls._validate_looks_like_control_lora(mod)
|
||||
|
||||
return cls(**fields)
|
||||
|
||||
@classmethod
|
||||
def _validate_looks_like_control_lora(cls, mod: ModelOnDisk) -> None:
|
||||
state_dict = mod.load_state_dict()
|
||||
|
||||
if not is_state_dict_likely_flux_control(state_dict):
|
||||
raise NotAMatch(cls, "model state dict does not look like a Flux Control LoRA")
|
||||
|
||||
return cls(**fields)
|
||||
|
||||
|
||||
ControlLoRADiffusers_SupportedBases: TypeAlias = Literal[BaseModelType.Flux]
|
||||
|
||||
@@ -618,28 +675,31 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
|
||||
base: LoRADiffusers_SupportedBases = Field()
|
||||
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.LoRA,
|
||||
"format": ModelFormat.Diffusers,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
# Diffusers-style models always directories
|
||||
_raise_if_not_dir(cls, mod)
|
||||
_validate_is_dir(cls, mod)
|
||||
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
is_flux_lora_diffusers = get_flux_lora_format(mod) is FluxLoRAFormat.Diffusers
|
||||
cls._validate_looks_like_diffusers_lora(mod)
|
||||
|
||||
cls._validate_has_lora_weight_file(mod)
|
||||
|
||||
return cls(**fields)
|
||||
|
||||
@classmethod
|
||||
def _validate_looks_like_diffusers_lora(cls, mod: ModelOnDisk) -> None:
|
||||
flux_lora_format = get_flux_lora_format(mod)
|
||||
if flux_lora_format is not FluxLoRAFormat.Diffusers:
|
||||
raise NotAMatch(cls, "model does not look like a FLUX Diffusers LoRA")
|
||||
|
||||
@classmethod
|
||||
def _validate_has_lora_weight_file(cls, mod: ModelOnDisk) -> None:
|
||||
suffixes = ["bin", "safetensors"]
|
||||
weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
|
||||
has_lora_weight_file = any(wf.exists() for wf in weight_files)
|
||||
|
||||
if not is_flux_lora_diffusers and not has_lora_weight_file:
|
||||
raise NotAMatch(cls, "model does not match Diffusers LoRA heuristics")
|
||||
|
||||
return cls(**fields)
|
||||
if not has_lora_weight_file:
|
||||
raise NotAMatch(cls, "missing pytorch_lora_weights.bin or pytorch_lora_weights.safetensors")
|
||||
|
||||
|
||||
VAECheckpointConfig_SupportedBases: TypeAlias = Literal[
|
||||
@@ -657,11 +717,6 @@ class VAECheckpointConfig(CheckpointConfigBase, ModelConfigBase):
|
||||
type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
|
||||
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.VAE,
|
||||
"format": ModelFormat.Checkpoint,
|
||||
}
|
||||
|
||||
REGEX_TO_BASE: ClassVar[dict[str, VAECheckpointConfig_SupportedBases]] = {
|
||||
r"xl": BaseModelType.StableDiffusionXL,
|
||||
r"sd2": BaseModelType.StableDiffusion2,
|
||||
@@ -671,16 +726,20 @@ class VAECheckpointConfig(CheckpointConfigBase, ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_file(cls, mod)
|
||||
_validate_is_file(cls, mod)
|
||||
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
if not mod.has_keys_starting_with({"encoder.conv_in", "decoder.conv_in"}):
|
||||
raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics")
|
||||
cls._validate_looks_like_vae(mod)
|
||||
|
||||
base = fields.get("base") or cls._get_base_or_raise(mod)
|
||||
return cls(**fields, base=base)
|
||||
|
||||
@classmethod
|
||||
def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None:
|
||||
if not mod.has_keys_starting_with({"encoder.conv_in", "decoder.conv_in"}):
|
||||
raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics")
|
||||
|
||||
@classmethod
|
||||
def _get_base_or_raise(cls, mod: ModelOnDisk) -> VAECheckpointConfig_SupportedBases:
|
||||
# Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name
|
||||
@@ -704,22 +763,13 @@ class VAEDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
|
||||
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.VAE,
|
||||
"format": ModelFormat.Diffusers,
|
||||
}
|
||||
VALID_CLASS_NAMES: ClassVar = {
|
||||
"AutoencoderKL",
|
||||
"AutoencoderTiny",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_dir(cls, mod)
|
||||
_validate_is_dir(cls, mod)
|
||||
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
_validate_class_names(cls, mod.path / "config.json", cls.VALID_CLASS_NAMES)
|
||||
_validate_class_name(cls, mod.path / "config.json", {"AutoencoderKL", "AutoencoderTiny"})
|
||||
|
||||
base = fields.get("base") or cls._get_base_or_raise(mod)
|
||||
return cls(**fields, base=base)
|
||||
@@ -770,23 +820,13 @@ class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, M
|
||||
type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet)
|
||||
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.ControlNet,
|
||||
"format": ModelFormat.Diffusers,
|
||||
}
|
||||
|
||||
VALID_CLASS_NAMES: ClassVar = {
|
||||
"ControlNetModel",
|
||||
"FluxControlNetModel",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_dir(cls, mod)
|
||||
_validate_is_dir(cls, mod)
|
||||
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
_validate_class_names(cls, mod.path / "config.json", cls.VALID_CLASS_NAMES)
|
||||
_validate_class_name(cls, mod.path / "config.json", {"ControlNetModel", "FluxControlNetModel"})
|
||||
|
||||
base = fields.get("base") or cls._get_base_or_raise(mod)
|
||||
|
||||
@@ -829,17 +869,20 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase,
|
||||
type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet)
|
||||
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.ControlNet,
|
||||
"format": ModelFormat.Checkpoint,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_file(cls, mod)
|
||||
_validate_is_file(cls, mod)
|
||||
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
cls._validate_looks_like_controlnet(mod)
|
||||
|
||||
base = fields.get("base") or cls._get_base_or_raise(mod)
|
||||
|
||||
return cls(**fields, base=base)
|
||||
|
||||
@classmethod
|
||||
def _validate_looks_like_controlnet(cls, mod: ModelOnDisk) -> None:
|
||||
if not mod.has_keys_starting_with(
|
||||
{
|
||||
"controlnet",
|
||||
@@ -855,10 +898,6 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase,
|
||||
):
|
||||
raise NotAMatch(cls, "state dict does not look like a ControlNet checkpoint")
|
||||
|
||||
base = fields.get("base") or cls._get_base_or_raise(mod)
|
||||
|
||||
return cls(**fields, base=base)
|
||||
|
||||
@classmethod
|
||||
def _get_base_or_raise(cls, mod: ModelOnDisk) -> ControlNetCheckpoint_SupportedBases:
|
||||
state_dict = mod.load_state_dict()
|
||||
@@ -970,20 +1009,11 @@ class TextualInversionFileConfig(TextualInversionConfigBase, ModelConfigBase):
|
||||
|
||||
format: Literal[ModelFormat.EmbeddingFile] = Field(default=ModelFormat.EmbeddingFile)
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.TextualInversion,
|
||||
"format": ModelFormat.EmbeddingFile,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_tag(cls) -> Tag:
|
||||
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}")
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_file(cls, mod)
|
||||
_validate_is_file(cls, mod)
|
||||
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
if not cls._file_looks_like_embedding(mod):
|
||||
raise NotAMatch(cls, "model does not look like a textual inversion embedding file")
|
||||
@@ -997,20 +1027,11 @@ class TextualInversionFolderConfig(TextualInversionConfigBase, ModelConfigBase):
|
||||
|
||||
format: Literal[ModelFormat.EmbeddingFolder] = Field(default=ModelFormat.EmbeddingFolder)
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.TextualInversion,
|
||||
"format": ModelFormat.EmbeddingFolder,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_tag(cls) -> Tag:
|
||||
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_dir(cls, mod)
|
||||
_validate_is_dir(cls, mod)
|
||||
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
for p in mod.weight_files():
|
||||
if cls._file_looks_like_embedding(mod, p):
|
||||
@@ -1105,28 +1126,31 @@ class IPAdapterInvokeAIConfig(IPAdapterConfigBase, ModelConfigBase):
|
||||
# time. Need to go through the history to make sure I'm understanding this fully.
|
||||
image_encoder_model_id: str = Field()
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.IPAdapter,
|
||||
"format": ModelFormat.InvokeAI,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_dir(cls, mod)
|
||||
_validate_is_dir(cls, mod)
|
||||
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
cls._validate_has_weights_file(mod)
|
||||
|
||||
cls._validate_has_image_encoder_metadata_file(mod)
|
||||
|
||||
base = fields.get("base") or cls._get_base_or_raise(mod)
|
||||
return cls(**fields, base=base)
|
||||
|
||||
@classmethod
|
||||
def _validate_has_weights_file(cls, mod: ModelOnDisk) -> None:
|
||||
weights_file = mod.path / "ip_adapter.bin"
|
||||
if not weights_file.exists():
|
||||
raise NotAMatch(cls, "missing ip_adapter.bin weights file")
|
||||
|
||||
@classmethod
|
||||
def _validate_has_image_encoder_metadata_file(cls, mod: ModelOnDisk) -> None:
|
||||
image_encoder_metadata_file = mod.path / "image_encoder.txt"
|
||||
if not image_encoder_metadata_file.exists():
|
||||
raise NotAMatch(cls, "missing image_encoder.txt metadata file")
|
||||
|
||||
base = fields.get("base") or cls._get_base_or_raise(mod)
|
||||
return cls(**fields, base=base)
|
||||
|
||||
@classmethod
|
||||
def _get_base_or_raise(cls, mod: ModelOnDisk) -> IPAdapterInvokeAI_SupportedBases:
|
||||
state_dict = mod.load_state_dict()
|
||||
@@ -1161,17 +1185,19 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase):
|
||||
base: IPAdapterCheckpoint_SupportedBases = Field()
|
||||
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.IPAdapter,
|
||||
"format": ModelFormat.Checkpoint,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_file(cls, mod)
|
||||
_validate_is_file(cls, mod)
|
||||
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
cls._validate_looks_like_ip_adapter(mod)
|
||||
|
||||
base = fields.get("base") or cls._get_base_or_raise(mod)
|
||||
return cls(**fields, base=base)
|
||||
|
||||
@classmethod
|
||||
def _validate_looks_like_ip_adapter(cls, mod: ModelOnDisk) -> None:
|
||||
if not mod.has_keys_starting_with(
|
||||
{
|
||||
"image_proj.",
|
||||
@@ -1182,9 +1208,6 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase):
|
||||
):
|
||||
raise NotAMatch(cls, "model does not match Checkpoint IP Adapter heuristics")
|
||||
|
||||
base = fields.get("base") or cls._get_base_or_raise(mod)
|
||||
return cls(**fields, base=base)
|
||||
|
||||
@classmethod
|
||||
def _get_base_or_raise(cls, mod: ModelOnDisk) -> IPAdapterCheckpoint_SupportedBases:
|
||||
state_dict = mod.load_state_dict()
|
||||
@@ -1215,14 +1238,8 @@ class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
|
||||
type: Literal[ModelType.CLIPEmbed] = Field(default=ModelType.CLIPEmbed)
|
||||
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
|
||||
|
||||
VALID_CLASS_NAMES: ClassVar = {
|
||||
"CLIPModel",
|
||||
"CLIPTextModel",
|
||||
"CLIPTextModelWithProjection",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_clip_variant_type(cls, config: dict[str, Any]) -> ClipVariantType | None:
|
||||
def _get_clip_variant_type(cls, config: dict[str, Any]) -> ClipVariantType | None:
|
||||
try:
|
||||
hidden_size = config.get("hidden_size")
|
||||
match hidden_size:
|
||||
@@ -1241,69 +1258,64 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
||||
|
||||
variant: Literal[ClipVariantType.G] = Field(default=ClipVariantType.G)
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.CLIPEmbed,
|
||||
"format": ModelFormat.Diffusers,
|
||||
"variant": ClipVariantType.G,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_tag(cls) -> Tag:
|
||||
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G.value}")
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_dir(cls, mod)
|
||||
_validate_is_dir(cls, mod)
|
||||
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
config_path = mod.path / "config.json"
|
||||
_validate_class_name(
|
||||
cls, mod.path / "config.json", {"CLIPModel", "CLIPTextModel", "CLIPTextModelWithProjection"}
|
||||
)
|
||||
|
||||
_validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES)
|
||||
cls._validate_clip_g_variant(mod)
|
||||
|
||||
config = _get_config_or_raise(cls, config_path)
|
||||
return cls(**fields)
|
||||
|
||||
clip_variant = cls.get_clip_variant_type(config)
|
||||
@classmethod
|
||||
def _validate_clip_g_variant(cls, mod: ModelOnDisk) -> None:
|
||||
config = _get_config_or_raise(cls, mod.path / "config.json")
|
||||
clip_variant = cls._get_clip_variant_type(config)
|
||||
|
||||
if clip_variant is not ClipVariantType.G:
|
||||
raise NotAMatch(cls, "model does not match CLIP-G heuristics")
|
||||
|
||||
return cls(**fields)
|
||||
|
||||
|
||||
class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
||||
"""Model config for CLIP-L Embeddings."""
|
||||
|
||||
variant: Literal[ClipVariantType.L] = Field(default=ClipVariantType.L)
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.CLIPEmbed,
|
||||
"format": ModelFormat.Diffusers,
|
||||
"variant": ClipVariantType.L,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_tag(cls) -> Tag:
|
||||
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L.value}")
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_dir(cls, mod)
|
||||
_validate_is_dir(cls, mod)
|
||||
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
config_path = mod.path / "config.json"
|
||||
_validate_class_name(
|
||||
cls, mod.path / "config.json", {"CLIPModel", "CLIPTextModel", "CLIPTextModelWithProjection"}
|
||||
)
|
||||
|
||||
_validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES)
|
||||
|
||||
config = _get_config_or_raise(cls, config_path)
|
||||
clip_variant = cls.get_clip_variant_type(config)
|
||||
|
||||
if clip_variant is not ClipVariantType.L:
|
||||
raise NotAMatch(cls, "model does not match CLIP-L heuristics")
|
||||
cls._validate_clip_l_variant(mod)
|
||||
|
||||
return cls(**fields)
|
||||
|
||||
@classmethod
|
||||
def _validate_clip_l_variant(cls, mod: ModelOnDisk) -> None:
|
||||
config = _get_config_or_raise(cls, mod.path / "config.json")
|
||||
clip_variant = cls._get_clip_variant_type(config)
|
||||
|
||||
if clip_variant is not ClipVariantType.L:
|
||||
raise NotAMatch(cls, "model does not match CLIP-G heuristics")
|
||||
|
||||
|
||||
class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
"""Model config for CLIPVision."""
|
||||
@@ -1312,24 +1324,13 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
type: Literal[ModelType.CLIPVision] = Field(default=ModelType.CLIPVision)
|
||||
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.CLIPVision,
|
||||
"format": ModelFormat.Diffusers,
|
||||
}
|
||||
|
||||
VALID_CLASS_NAMES: ClassVar = {
|
||||
"CLIPVisionModelWithProjection",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_dir(cls, mod)
|
||||
_validate_is_dir(cls, mod)
|
||||
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
config_path = mod.path / "config.json"
|
||||
|
||||
_validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES)
|
||||
_validate_class_name(cls, mod.path / "config.json", {"CLIPVisionModelWithProjection"})
|
||||
|
||||
return cls(**fields)
|
||||
|
||||
@@ -1347,24 +1348,13 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi
|
||||
type: Literal[ModelType.T2IAdapter] = Field(default=ModelType.T2IAdapter)
|
||||
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.T2IAdapter,
|
||||
"format": ModelFormat.Diffusers,
|
||||
}
|
||||
|
||||
VALID_CLASS_NAMES: ClassVar = {
|
||||
"T2IAdapter",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_dir(cls, mod)
|
||||
_validate_is_dir(cls, mod)
|
||||
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
config_path = mod.path / "config.json"
|
||||
|
||||
_validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES)
|
||||
_validate_class_name(cls, mod.path / "config.json", {"T2IAdapter"})
|
||||
|
||||
base = fields.get("base") or cls._get_base_or_raise(mod)
|
||||
|
||||
@@ -1392,17 +1382,18 @@ class SpandrelImageToImageConfig(ModelConfigBase):
|
||||
type: Literal[ModelType.SpandrelImageToImage] = Field(default=ModelType.SpandrelImageToImage)
|
||||
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.SpandrelImageToImage,
|
||||
"format": ModelFormat.Checkpoint,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_file(cls, mod)
|
||||
_validate_is_file(cls, mod)
|
||||
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
cls._validate_spandrel_loads_model(mod)
|
||||
|
||||
return cls(**fields)
|
||||
|
||||
@classmethod
|
||||
def _validate_spandrel_loads_model(cls, mod: ModelOnDisk) -> None:
|
||||
try:
|
||||
# It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
|
||||
# explored to avoid this:
|
||||
@@ -1413,7 +1404,6 @@ class SpandrelImageToImageConfig(ModelConfigBase):
|
||||
# This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
|
||||
# maintain it, and the risk of false positive detections is higher.
|
||||
SpandrelImageToImageModel.load_from_file(mod.path)
|
||||
return cls(**fields)
|
||||
except Exception as e:
|
||||
raise NotAMatch(cls, "model does not match SpandrelImageToImage heuristics") from e
|
||||
|
||||
@@ -1425,24 +1415,13 @@ class SigLIPConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
|
||||
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.SigLIP,
|
||||
"format": ModelFormat.Diffusers,
|
||||
}
|
||||
|
||||
VALID_CLASS_NAMES: ClassVar = {
|
||||
"SiglipModel",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_dir(cls, mod)
|
||||
_validate_is_dir(cls, mod)
|
||||
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
config_path = mod.path / "config.json"
|
||||
|
||||
_validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES)
|
||||
_validate_class_name(cls, mod.path / "config.json", {"SiglipModel"})
|
||||
|
||||
return cls(**fields)
|
||||
|
||||
@@ -1454,16 +1433,11 @@ class FluxReduxConfig(ModelConfigBase):
|
||||
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
|
||||
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.FluxRedux,
|
||||
"format": ModelFormat.Checkpoint,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_file(cls, mod)
|
||||
_validate_is_file(cls, mod)
|
||||
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
if not is_state_dict_likely_flux_redux(mod.load_state_dict()):
|
||||
raise NotAMatch(cls, "model does not match FLUX Tools Redux heuristics")
|
||||
@@ -1478,24 +1452,13 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
|
||||
variant: Literal[ModelVariantType.Normal] = Field(default=ModelVariantType.Normal)
|
||||
|
||||
VALID_OVERRIDES: ClassVar = {
|
||||
"type": ModelType.LlavaOnevision,
|
||||
"format": ModelFormat.Diffusers,
|
||||
}
|
||||
|
||||
VALID_CLASS_NAMES: ClassVar = {
|
||||
"LlavaOnevisionForConditionalGeneration",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_raise_if_not_dir(cls, mod)
|
||||
_validate_is_dir(cls, mod)
|
||||
|
||||
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
config_path = mod.path / "config.json"
|
||||
|
||||
_validate_class_names(cls, config_path, cls.VALID_CLASS_NAMES)
|
||||
_validate_class_name(cls, mod.path / "config.json", {"LlavaOnevisionForConditionalGeneration"})
|
||||
|
||||
return cls(**fields)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user