mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
refactor: port MM probes to new api
- Add concept of match certainty to new probe - Port CLIP Embed models to new API - Fiddle with stuff
This commit is contained in:
@@ -13,7 +13,7 @@ from invokeai.app.util.t5_model_identifier import (
|
||||
preprocess_t5_encoder_model_identifier,
|
||||
preprocess_t5_tokenizer_model_identifier,
|
||||
)
|
||||
from invokeai.backend.flux.util import max_seq_lengths
|
||||
from invokeai.backend.flux.util import get_flux_max_seq_length
|
||||
from invokeai.backend.model_manager.config import (
|
||||
CheckpointConfigBase,
|
||||
)
|
||||
@@ -94,5 +94,5 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
|
||||
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder, loras=[]),
|
||||
vae=VAEField(vae=vae),
|
||||
max_seq_len=max_seq_lengths[transformer_config.config_path],
|
||||
max_seq_len=get_flux_max_seq_length(transformer_config.variant),
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from queue import Empty, Queue
|
||||
from shutil import move, rmtree
|
||||
@@ -370,6 +371,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
model_path = self.app_config.models_path / model.path
|
||||
if model_path.is_file() or model_path.is_symlink():
|
||||
model_path.unlink()
|
||||
assert model_path.parent != self.app_config.models_path
|
||||
os.rmdir(model_path.parent)
|
||||
elif model_path.is_dir():
|
||||
rmtree(model_path)
|
||||
self.unregister(key)
|
||||
@@ -607,9 +610,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
# any given model - eliminating ambiguity and removing reliance on order.
|
||||
# After implementing either of these fixes, remove @pytest.mark.xfail from `test_regression_against_model_probe`
|
||||
try:
|
||||
return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo) # type: ignore
|
||||
return ModelProbe.probe(model_path=model_path, fields=deepcopy(fields), hash_algo=hash_algo) # type: ignore
|
||||
except InvalidModelConfigException:
|
||||
return ModelConfigBase.classify(model_path, hash_algo, **fields)
|
||||
return ModelConfigBase.classify(mod=model_path, hash_algo=hash_algo, **fields)
|
||||
|
||||
def _register(
|
||||
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
|
||||
|
||||
@@ -21,6 +21,7 @@ from invokeai.backend.model_manager.config import (
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
BaseModelType,
|
||||
ClipVariantType,
|
||||
FluxVariantType,
|
||||
ModelFormat,
|
||||
ModelSourceType,
|
||||
ModelType,
|
||||
@@ -90,7 +91,9 @@ class ModelRecordChanges(BaseModelExcludeNull):
|
||||
|
||||
# Checkpoint-specific changes
|
||||
# TODO(MM2): Should we expose these? Feels footgun-y...
|
||||
variant: Optional[ModelVariantType | ClipVariantType] = Field(description="The variant of the model.", default=None)
|
||||
variant: Optional[ModelVariantType | ClipVariantType | FluxVariantType] = Field(
|
||||
description="The variant of the model.", default=None
|
||||
)
|
||||
prediction_type: Optional[SchedulerPredictionType] = Field(
|
||||
description="The prediction type of the model.", default=None
|
||||
)
|
||||
|
||||
@@ -141,10 +141,25 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.transaction() as cursor:
|
||||
record = self.get_model(key)
|
||||
|
||||
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
|
||||
for field_name in changes.model_fields_set:
|
||||
setattr(record, field_name, getattr(changes, field_name))
|
||||
# The changes may mean the model config class changes. So we need to:
|
||||
#
|
||||
# 1. convert the existing record to a dict
|
||||
# 2. apply the changes to the dict
|
||||
# 3. create a new model config from the updated dict
|
||||
#
|
||||
# This way we ensure that the update does not inadvertently create an invalid model config.
|
||||
|
||||
# 1. convert the existing record to a dict
|
||||
record_as_dict = record.model_dump()
|
||||
|
||||
# 2. apply the changes to the dict
|
||||
for field_name in changes.model_fields_set:
|
||||
record_as_dict[field_name] = getattr(changes, field_name)
|
||||
|
||||
# 3. create a new model config from the updated dict
|
||||
record = ModelConfigFactory.make_config(record_as_dict)
|
||||
|
||||
# If we get this far, the updated model config is valid, so we can save it to the database.
|
||||
json_serialized = record.model_dump_json()
|
||||
|
||||
cursor.execute(
|
||||
|
||||
@@ -8,7 +8,7 @@ from pydantic import ValidationError
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, AnyModelConfigValidator
|
||||
from invokeai.backend.model_manager.config import AnyModelConfigValidator
|
||||
|
||||
|
||||
class NormalizeResult(NamedTuple):
|
||||
@@ -30,7 +30,7 @@ class Migration22Callback:
|
||||
for model_id, config_json in rows:
|
||||
try:
|
||||
# Get the model config as a pydantic object
|
||||
config = self._load_model_config(config_json)
|
||||
config = AnyModelConfigValidator.validate_json(config_json)
|
||||
except ValidationError:
|
||||
# This could happen if the config schema changed in a way that makes old configs invalid. Unlikely
|
||||
# for users, more likely for devs testing out migration paths.
|
||||
@@ -216,11 +216,6 @@ class Migration22Callback:
|
||||
|
||||
self._logger.info("Pruned %d empty directories under %s", len(removed_dirs), self._models_dir)
|
||||
|
||||
def _load_model_config(self, config_json: str) -> AnyModelConfig:
|
||||
# The typing of the validator says it returns Unknown, but it's really a AnyModelConfig. This utility function
|
||||
# just makes that clear.
|
||||
return AnyModelConfigValidator.validate_json(config_json)
|
||||
|
||||
|
||||
def build_migration_22(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
|
||||
"""Builds the migration object for migrating from version 21 to version 22.
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Literal
|
||||
from typing import Literal
|
||||
|
||||
from invokeai.backend.flux.model import FluxParams
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams
|
||||
from invokeai.backend.model_manager.taxonomy import AnyVariant, FluxVariantType
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -41,30 +42,39 @@ PREFERED_KONTEXT_RESOLUTIONS = [
|
||||
]
|
||||
|
||||
|
||||
max_seq_lengths: Dict[str, Literal[256, 512]] = {
|
||||
"flux-dev": 512,
|
||||
"flux-dev-fill": 512,
|
||||
"flux-schnell": 256,
|
||||
_flux_max_seq_lengths: dict[AnyVariant, Literal[256, 512]] = {
|
||||
FluxVariantType.Dev: 512,
|
||||
FluxVariantType.DevFill: 512,
|
||||
FluxVariantType.Schnell: 256,
|
||||
}
|
||||
|
||||
|
||||
ae_params = {
|
||||
"flux": AutoEncoderParams(
|
||||
resolution=256,
|
||||
in_channels=3,
|
||||
ch=128,
|
||||
out_ch=3,
|
||||
ch_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
z_channels=16,
|
||||
scale_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
}
|
||||
def get_flux_max_seq_length(variant: AnyVariant):
|
||||
try:
|
||||
return _flux_max_seq_lengths[variant]
|
||||
except KeyError:
|
||||
raise ValueError(f"Unknown variant for FLUX max seq len: {variant}")
|
||||
|
||||
|
||||
params = {
|
||||
"flux-dev": FluxParams(
|
||||
_flux_ae_params = AutoEncoderParams(
|
||||
resolution=256,
|
||||
in_channels=3,
|
||||
ch=128,
|
||||
out_ch=3,
|
||||
ch_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
z_channels=16,
|
||||
scale_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
|
||||
|
||||
def get_flux_ae_params() -> AutoEncoderParams:
|
||||
return _flux_ae_params
|
||||
|
||||
|
||||
_flux_transformer_params: dict[AnyVariant, FluxParams] = {
|
||||
FluxVariantType.Dev: FluxParams(
|
||||
in_channels=64,
|
||||
vec_in_dim=768,
|
||||
context_in_dim=4096,
|
||||
@@ -78,7 +88,7 @@ params = {
|
||||
qkv_bias=True,
|
||||
guidance_embed=True,
|
||||
),
|
||||
"flux-schnell": FluxParams(
|
||||
FluxVariantType.Schnell: FluxParams(
|
||||
in_channels=64,
|
||||
vec_in_dim=768,
|
||||
context_in_dim=4096,
|
||||
@@ -92,7 +102,7 @@ params = {
|
||||
qkv_bias=True,
|
||||
guidance_embed=False,
|
||||
),
|
||||
"flux-dev-fill": FluxParams(
|
||||
FluxVariantType.DevFill: FluxParams(
|
||||
in_channels=384,
|
||||
out_channels=64,
|
||||
vec_in_dim=768,
|
||||
@@ -108,3 +118,10 @@ params = {
|
||||
guidance_embed=True,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def get_flux_transformers_params(variant: AnyVariant):
|
||||
try:
|
||||
return _flux_transformer_params[variant]
|
||||
except KeyError:
|
||||
raise ValueError(f"Unknown variant for FLUX transformer params: {variant}")
|
||||
|
||||
@@ -44,6 +44,7 @@ from invokeai.backend.model_manager.taxonomy import (
|
||||
BaseModelType,
|
||||
ClipVariantType,
|
||||
FluxLoRAFormat,
|
||||
FluxVariantType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelSourceType,
|
||||
@@ -51,6 +52,7 @@ from invokeai.backend.model_manager.taxonomy import (
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
variant_type_adapter,
|
||||
)
|
||||
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length
|
||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
||||
@@ -71,7 +73,7 @@ DEFAULTS_PRECISION = Literal["fp16", "fp32"]
|
||||
class SubmodelDefinition(BaseModel):
|
||||
path_or_prefix: str
|
||||
model_type: ModelType
|
||||
variant: AnyVariant = None
|
||||
variant: AnyVariant | None = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@@ -111,6 +113,15 @@ class MatchSpeed(int, Enum):
|
||||
SLOW = 2
|
||||
|
||||
|
||||
class MatchCertainty(int, Enum):
|
||||
"""Represents the certainty of a config's 'matches' method."""
|
||||
|
||||
NEVER = 0
|
||||
MAYBE = 1
|
||||
EXACT = 2
|
||||
OVERRIDE = 3
|
||||
|
||||
|
||||
class LegacyProbeMixin:
|
||||
"""Mixin for classes using the legacy probe for model classification."""
|
||||
|
||||
@@ -194,20 +205,47 @@ class ModelConfigBase(ABC, BaseModel):
|
||||
Created to deprecate ModelProbe.probe
|
||||
"""
|
||||
if isinstance(mod, Path | str):
|
||||
mod = ModelOnDisk(mod, hash_algo)
|
||||
mod = ModelOnDisk(Path(mod), hash_algo)
|
||||
|
||||
candidates = ModelConfigBase.USING_CLASSIFY_API
|
||||
sorted_by_match_speed = sorted(candidates, key=lambda cls: (cls._MATCH_SPEED, cls.__name__))
|
||||
|
||||
overrides = overrides or {}
|
||||
ModelConfigBase.cast_overrides(**overrides)
|
||||
|
||||
matches: dict[Type[ModelConfigBase], MatchCertainty] = {}
|
||||
|
||||
for config_cls in sorted_by_match_speed:
|
||||
try:
|
||||
if not config_cls.matches(mod):
|
||||
score = config_cls.matches(mod, **overrides)
|
||||
|
||||
# A score of 0 means "no match"
|
||||
if score is MatchCertainty.NEVER:
|
||||
continue
|
||||
|
||||
matches[config_cls] = score
|
||||
|
||||
if score is MatchCertainty.EXACT or score is MatchCertainty.OVERRIDE:
|
||||
# Perfect match - skip further checks
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Unexpected exception while matching {mod.name} to '{config_cls.__name__}': {e}")
|
||||
continue
|
||||
else:
|
||||
return config_cls.from_model_on_disk(mod, **overrides)
|
||||
|
||||
if matches:
|
||||
# Select the config class with the highest score
|
||||
sorted_by_score = sorted(matches.items(), key=lambda item: item[1].value)
|
||||
# Check if there are multiple classes with the same top score
|
||||
top_score = sorted_by_score[-1][1]
|
||||
top_classes = [cls for cls, score in sorted_by_score if score is top_score]
|
||||
if len(top_classes) > 1:
|
||||
logger.warning(
|
||||
f"Multiple model config classes matched with the same top score ({top_score}) for model {mod.name}: {[cls.__name__ for cls in top_classes]}. Using {top_classes[0].__name__}."
|
||||
)
|
||||
config_cls = top_classes[0]
|
||||
# Finally, create the config instance
|
||||
logger.info(f"Model {mod.name} classified as {config_cls.__name__} with score {top_score.name}")
|
||||
return config_cls.from_model_on_disk(mod, **overrides)
|
||||
|
||||
if app_config.allow_unknown_models:
|
||||
try:
|
||||
@@ -234,13 +272,13 @@ class ModelConfigBase(ABC, BaseModel):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def matches(cls, mod: ModelOnDisk) -> bool:
|
||||
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
|
||||
"""Performs a quick check to determine if the config matches the model.
|
||||
This doesn't need to be a perfect test - the aim is to eliminate unlikely matches quickly before parsing."""
|
||||
Returns a MatchCertainty score."""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def cast_overrides(overrides: dict[str, Any]):
|
||||
def cast_overrides(**overrides):
|
||||
"""Casts user overrides from str to Enum"""
|
||||
if "type" in overrides:
|
||||
overrides["type"] = ModelType(overrides["type"])
|
||||
@@ -255,13 +293,13 @@ class ModelConfigBase(ABC, BaseModel):
|
||||
overrides["source_type"] = ModelSourceType(overrides["source_type"])
|
||||
|
||||
if "variant" in overrides:
|
||||
overrides["variant"] = ModelVariantType(overrides["variant"])
|
||||
overrides["variant"] = variant_type_adapter.validate_strings(overrides["variant"])
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
|
||||
"""Creates an instance of this config or raises InvalidModelConfigException."""
|
||||
fields = cls.parse(mod)
|
||||
cls.cast_overrides(overrides)
|
||||
cls.cast_overrides(**overrides)
|
||||
fields.update(overrides)
|
||||
|
||||
fields["path"] = mod.path.as_posix()
|
||||
@@ -283,8 +321,8 @@ class UnknownModelConfig(ModelConfigBase):
|
||||
format: Literal[ModelFormat.Unknown] = ModelFormat.Unknown
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk) -> bool:
|
||||
return False
|
||||
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
@@ -370,17 +408,34 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
|
||||
format: Literal[ModelFormat.OMI] = ModelFormat.OMI
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk) -> bool:
|
||||
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
|
||||
is_lora_override = overrides.get("type") is ModelType.LoRA
|
||||
is_omi_override = overrides.get("format") is ModelFormat.OMI
|
||||
|
||||
# If both type and format are overridden, skip the heuristic checks
|
||||
if is_lora_override and is_omi_override:
|
||||
return MatchCertainty.OVERRIDE
|
||||
|
||||
# OMI LoRAs are always files, never directories
|
||||
if mod.path.is_dir():
|
||||
return False
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
# Avoid false positive match against ControlLoRA and Diffusers
|
||||
if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
metadata = mod.metadata()
|
||||
return (
|
||||
is_omi_lora_heuristic = (
|
||||
bool(metadata.get("modelspec.sai_model_spec"))
|
||||
and metadata.get("ot_branch") == "omi_format"
|
||||
and metadata["modelspec.architecture"].split("/")[1].lower() == "lora"
|
||||
and metadata.get("modelspec.architecture", "").split("/")[1].lower() == "lora"
|
||||
)
|
||||
|
||||
if is_omi_lora_heuristic:
|
||||
return MatchCertainty.EXACT
|
||||
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
metadata = mod.metadata()
|
||||
@@ -402,27 +457,55 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
|
||||
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk) -> bool:
|
||||
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
|
||||
is_lora_override = overrides.get("type") is ModelType.LoRA
|
||||
is_omi_override = overrides.get("format") is ModelFormat.LyCORIS
|
||||
|
||||
# If both type and format are overridden, skip the heuristic checks and return a perfect score
|
||||
if is_lora_override and is_omi_override:
|
||||
return MatchCertainty.OVERRIDE
|
||||
|
||||
# LyCORIS LoRAs are always files, never directories
|
||||
if mod.path.is_dir():
|
||||
return False
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
# Avoid false positive match against ControlLoRA and Diffusers
|
||||
if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
|
||||
return False
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
state_dict = mod.load_state_dict()
|
||||
for key in state_dict.keys():
|
||||
if isinstance(key, int):
|
||||
continue
|
||||
|
||||
if key.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")):
|
||||
return True
|
||||
# Existence of these key prefixes/suffixes does not guarantee that this is a LoRA.
|
||||
# Some main models have these keys, likely due to the creator merging in a LoRA.
|
||||
|
||||
has_key_with_lora_prefix = key.startswith(
|
||||
(
|
||||
"lora_te_",
|
||||
"lora_unet_",
|
||||
"lora_te1_",
|
||||
"lora_te2_",
|
||||
"lora_transformer_",
|
||||
)
|
||||
)
|
||||
|
||||
# "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT
|
||||
# LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
|
||||
if key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")):
|
||||
return True
|
||||
has_key_with_lora_suffix = key.endswith(
|
||||
(
|
||||
"to_k_lora.up.weight",
|
||||
"to_q_lora.down.weight",
|
||||
"lora_A.weight",
|
||||
"lora_B.weight",
|
||||
)
|
||||
)
|
||||
|
||||
return False
|
||||
if has_key_with_lora_prefix or has_key_with_lora_suffix:
|
||||
return MatchCertainty.MAYBE
|
||||
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
@@ -459,13 +542,31 @@ class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk) -> bool:
|
||||
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
|
||||
is_lora_override = overrides.get("type") is ModelType.LoRA
|
||||
is_diffusers_override = overrides.get("format") is ModelFormat.Diffusers
|
||||
|
||||
# If both type and format are overridden, skip the heuristic checks and return a perfect score
|
||||
if is_lora_override and is_diffusers_override:
|
||||
return MatchCertainty.OVERRIDE
|
||||
|
||||
# Diffusers LoRAs are always directories, never files
|
||||
if mod.path.is_file():
|
||||
return cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
is_flux_lora_diffusers = cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers
|
||||
|
||||
suffixes = ["bin", "safetensors"]
|
||||
weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
|
||||
return any(wf.exists() for wf in weight_files)
|
||||
has_lora_weight_file = any(wf.exists() for wf in weight_files)
|
||||
|
||||
if is_flux_lora_diffusers and has_lora_weight_file:
|
||||
return MatchCertainty.EXACT
|
||||
|
||||
if is_flux_lora_diffusers or has_lora_weight_file:
|
||||
return MatchCertainty.MAYBE
|
||||
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
@@ -520,7 +621,7 @@ class MainConfigBase(ABC, BaseModel):
|
||||
default_settings: Optional[MainModelDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
)
|
||||
variant: AnyVariant = ModelVariantType.Normal
|
||||
variant: ModelVariantType | FluxVariantType = ModelVariantType.Normal
|
||||
|
||||
|
||||
class VideoConfigBase(ABC, BaseModel):
|
||||
@@ -529,7 +630,7 @@ class VideoConfigBase(ABC, BaseModel):
|
||||
default_settings: Optional[MainModelDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
)
|
||||
variant: AnyVariant = ModelVariantType.Normal
|
||||
variant: ModelVariantType = ModelVariantType.Normal
|
||||
|
||||
|
||||
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
@@ -538,6 +639,14 @@ class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixi
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
|
||||
# @classmethod
|
||||
# def matches(cls, mod: ModelOnDisk) -> bool:
|
||||
# pass
|
||||
|
||||
# @classmethod
|
||||
# def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
# pass
|
||||
|
||||
|
||||
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for main checkpoint models."""
|
||||
@@ -583,12 +692,51 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, LegacyProbeMixin, ModelConf
|
||||
class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
|
||||
"""Model config for Clip Embeddings."""
|
||||
|
||||
variant: ClipVariantType = Field(description="Clip variant for this model")
|
||||
variant: ClipVariantType = Field(...)
|
||||
type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
base: Literal[BaseModelType.Any] = BaseModelType.Any
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
clip_variant = cls.get_clip_variant_type(mod)
|
||||
if clip_variant is None:
|
||||
raise InvalidModelConfigException("Unable to determine CLIP variant type")
|
||||
|
||||
return {"variant": clip_variant}
|
||||
|
||||
@classmethod
|
||||
def get_clip_variant_type(cls, mod: ModelOnDisk) -> ClipVariantType | None:
|
||||
try:
|
||||
with open(mod.path / "config.json") as file:
|
||||
config = json.load(file)
|
||||
hidden_size = config.get("hidden_size")
|
||||
match hidden_size:
|
||||
case 1280:
|
||||
return ClipVariantType.G
|
||||
case 768:
|
||||
return ClipVariantType.L
|
||||
case _:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def is_clip_text_encoder(cls, mod: ModelOnDisk) -> bool:
|
||||
try:
|
||||
with open(mod.path / "config.json", "r") as file:
|
||||
config = json.load(file)
|
||||
architectures = config.get("architectures")
|
||||
return architectures[0] in (
|
||||
"CLIPModel",
|
||||
"CLIPTextModel",
|
||||
"CLIPTextModelWithProjection",
|
||||
)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, LegacyProbeMixin, ModelConfigBase):
|
||||
class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
||||
"""Model config for CLIP-G Embeddings."""
|
||||
|
||||
variant: Literal[ClipVariantType.G] = ClipVariantType.G
|
||||
@@ -597,8 +745,28 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, LegacyProbeMixin, Mode
|
||||
def get_tag(cls) -> Tag:
|
||||
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G.value}")
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
|
||||
is_clip_embed_override = overrides.get("type") is ModelType.CLIPEmbed
|
||||
is_diffusers_override = overrides.get("format") is ModelFormat.Diffusers
|
||||
has_clip_variant_override = overrides.get("variant") is ClipVariantType.G
|
||||
|
||||
class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, LegacyProbeMixin, ModelConfigBase):
|
||||
if is_clip_embed_override and is_diffusers_override and has_clip_variant_override:
|
||||
return MatchCertainty.OVERRIDE
|
||||
|
||||
if mod.path.is_file():
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
is_clip_embed = cls.is_clip_text_encoder(mod)
|
||||
clip_variant = cls.get_clip_variant_type(mod)
|
||||
|
||||
if is_clip_embed and clip_variant is ClipVariantType.G:
|
||||
return MatchCertainty.EXACT
|
||||
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
|
||||
class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
||||
"""Model config for CLIP-L Embeddings."""
|
||||
|
||||
variant: Literal[ClipVariantType.L] = ClipVariantType.L
|
||||
@@ -607,6 +775,26 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, LegacyProbeMixin, Mode
|
||||
def get_tag(cls) -> Tag:
|
||||
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L.value}")
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
|
||||
is_clip_embed_override = overrides.get("type") is ModelType.CLIPEmbed
|
||||
is_diffusers_override = overrides.get("format") is ModelFormat.Diffusers
|
||||
has_clip_variant_override = overrides.get("variant") is ClipVariantType.L
|
||||
|
||||
if is_clip_embed_override and is_diffusers_override and has_clip_variant_override:
|
||||
return MatchCertainty.OVERRIDE
|
||||
|
||||
if mod.path.is_file():
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
is_clip_embed = cls.is_clip_text_encoder(mod)
|
||||
clip_variant = cls.get_clip_variant_type(mod)
|
||||
|
||||
if is_clip_embed and clip_variant is ClipVariantType.L:
|
||||
return MatchCertainty.EXACT
|
||||
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
|
||||
class CLIPVisionDiffusersConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for CLIPVision."""
|
||||
@@ -649,22 +837,30 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
"""Model config for Llava Onevision models."""
|
||||
|
||||
type: Literal[ModelType.LlavaOnevision] = ModelType.LlavaOnevision
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk) -> bool:
|
||||
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
|
||||
is_llava_override = overrides.get("type") is ModelType.LlavaOnevision
|
||||
is_diffusers_override = overrides.get("format") is ModelFormat.Diffusers
|
||||
|
||||
if is_llava_override and is_diffusers_override:
|
||||
return MatchCertainty.OVERRIDE
|
||||
|
||||
if mod.path.is_file():
|
||||
return False
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
config_path = mod.path / "config.json"
|
||||
try:
|
||||
with open(config_path, "r") as file:
|
||||
config = json.load(file)
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
architectures = config.get("architectures")
|
||||
return architectures and architectures[0] == "LlavaOnevisionForConditionalGeneration"
|
||||
if architectures and architectures[0] == "LlavaOnevisionForConditionalGeneration":
|
||||
return MatchCertainty.EXACT
|
||||
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
@@ -680,9 +876,9 @@ class ApiModelConfig(MainConfigBase, ModelConfigBase):
|
||||
format: Literal[ModelFormat.Api] = ModelFormat.Api
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk) -> bool:
|
||||
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
|
||||
# API models are not stored on disk, so we can't match them.
|
||||
return False
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
@@ -695,9 +891,9 @@ class VideoApiModelConfig(VideoConfigBase, ModelConfigBase):
|
||||
format: Literal[ModelFormat.Api] = ModelFormat.Api
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk) -> bool:
|
||||
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
|
||||
# API models are not stored on disk, so we can't match them.
|
||||
return False
|
||||
return MatchCertainty.NEVER
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
@@ -735,8 +931,7 @@ def get_model_discriminator_value(v: Any) -> str:
|
||||
|
||||
# Previously, CLIPEmbed did not have any variants, meaning older database entries lack a variant field.
|
||||
# To maintain compatibility, we default to ClipVariantType.L in this case.
|
||||
if type_ == ModelType.CLIPEmbed.value and format_ == ModelFormat.Diffusers.value:
|
||||
variant_ = variant_ or ClipVariantType.L.value
|
||||
if type_ == ModelType.CLIPEmbed.value:
|
||||
return f"{type_}.{format_}.{variant_}"
|
||||
return f"{type_}.{format_}"
|
||||
|
||||
@@ -779,7 +974,7 @@ AnyModelConfig = Annotated[
|
||||
Discriminator(get_model_discriminator_value),
|
||||
]
|
||||
|
||||
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
||||
AnyModelConfigValidator = TypeAdapter[AnyModelConfig](AnyModelConfig)
|
||||
AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, LoraModelDefaultSettings, ControlAdapterDefaultSettings]
|
||||
|
||||
|
||||
@@ -787,8 +982,8 @@ class ModelConfigFactory:
|
||||
@staticmethod
|
||||
def make_config(model_data: Dict[str, Any], timestamp: Optional[float] = None) -> AnyModelConfig:
|
||||
"""Return the appropriate config object from raw dict values."""
|
||||
model = AnyModelConfigValidator.validate_python(model_data) # type: ignore
|
||||
model = AnyModelConfigValidator.validate_python(model_data)
|
||||
if isinstance(model, CheckpointConfigBase) and timestamp:
|
||||
model.converted_at = timestamp
|
||||
validate_hash(model.hash)
|
||||
return model # type: ignore
|
||||
return model
|
||||
|
||||
@@ -33,6 +33,7 @@ from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
AnyVariant,
|
||||
BaseModelType,
|
||||
FluxVariantType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelSourceType,
|
||||
@@ -63,6 +64,7 @@ from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
CkptType = Dict[str | int, Any]
|
||||
|
||||
|
||||
LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[SchedulerPredictionType, str]]]] = {
|
||||
BaseModelType.StableDiffusion1: {
|
||||
ModelVariantType.Normal: {
|
||||
@@ -106,7 +108,7 @@ class ProbeBase(object):
|
||||
"""Get model file format."""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_variant_type(self) -> Optional[ModelVariantType]:
|
||||
def get_variant_type(self) -> AnyVariant | None:
|
||||
"""Get model variant type."""
|
||||
return None
|
||||
|
||||
@@ -256,7 +258,7 @@ class ModelProbe(object):
|
||||
if fields["base"] == BaseModelType.StableDiffusion3 and callable(get_submodels):
|
||||
fields["submodels"] = get_submodels()
|
||||
|
||||
model_info = ModelConfigFactory.make_config(fields) # , key=fields.get("key", None))
|
||||
model_info = ModelConfigFactory.make_config(fields)
|
||||
return model_info
|
||||
|
||||
@classmethod
|
||||
@@ -580,7 +582,7 @@ class CheckpointProbeBase(ProbeBase):
|
||||
return ModelFormat.GGUFQuantized
|
||||
return ModelFormat("checkpoint")
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
def get_variant_type(self) -> AnyVariant:
|
||||
model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint)
|
||||
base_type = self.get_base_type()
|
||||
if model_type != ModelType.Main:
|
||||
@@ -597,19 +599,26 @@ class CheckpointProbeBase(ProbeBase):
|
||||
)
|
||||
return ModelVariantType.Normal
|
||||
|
||||
is_flux_dev = (
|
||||
"guidance_in.out_layer.weight" in state_dict
|
||||
or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict
|
||||
)
|
||||
|
||||
# FLUX Model variant types are distinguished by input channels:
|
||||
# - Unquantized Dev and Schnell have in_channels=64
|
||||
# - BNB-NF4 Dev and Schnell have in_channels=1
|
||||
# - FLUX Fill has in_channels=384
|
||||
# - Unsure of quantized FLUX Fill models
|
||||
# - Unsure of GGUF-quantized models
|
||||
if in_channels == 384:
|
||||
if is_flux_dev and in_channels == 384:
|
||||
# This is a FLUX Fill model. FLUX Fill needs special handling throughout the application. The variant
|
||||
# type is used to determine whether to use the fill model or the base model.
|
||||
return ModelVariantType.Inpaint
|
||||
else:
|
||||
return FluxVariantType.DevFill
|
||||
elif is_flux_dev:
|
||||
# Fall back on "normal" variant type for all other FLUX models.
|
||||
return ModelVariantType.Normal
|
||||
return FluxVariantType.Dev
|
||||
else:
|
||||
return FluxVariantType.Schnell
|
||||
|
||||
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||
if in_channels == 9:
|
||||
|
||||
@@ -33,10 +33,11 @@ from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import (
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
|
||||
from invokeai.backend.flux.util import ae_params, params
|
||||
from invokeai.backend.flux.util import get_flux_ae_params, get_flux_transformers_params
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
CheckpointConfigBase,
|
||||
CLIPEmbedCheckpointConfig,
|
||||
CLIPEmbedDiffusersConfig,
|
||||
ControlNetCheckpointConfig,
|
||||
ControlNetDiffusersConfig,
|
||||
@@ -56,6 +57,7 @@ from invokeai.backend.model_manager.taxonomy import (
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.util.model_util import (
|
||||
@@ -90,7 +92,7 @@ class FluxVAELoader(ModelLoader):
|
||||
model_path = Path(config.path)
|
||||
|
||||
with accelerate.init_empty_weights():
|
||||
model = AutoEncoder(ae_params[config.config_path])
|
||||
model = AutoEncoder(get_flux_ae_params())
|
||||
sd = load_file(model_path)
|
||||
model.load_state_dict(sd, assign=True)
|
||||
# VAE is broken in float16, which mps defaults to
|
||||
@@ -107,7 +109,7 @@ class FluxVAELoader(ModelLoader):
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers)
|
||||
class ClipCheckpointModel(ModelLoader):
|
||||
class CLIPDiffusersLoader(ModelLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
def _load_model(
|
||||
@@ -129,6 +131,27 @@ class ClipCheckpointModel(ModelLoader):
|
||||
)
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Checkpoint)
|
||||
class CLIPCheckpointLoader(ModelLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, CLIPEmbedCheckpointConfig):
|
||||
raise ValueError("Only CLIPEmbedCheckpointConfig models are currently supported here.")
|
||||
|
||||
match submodel_type:
|
||||
case SubModelType.TextEncoder:
|
||||
return CLIPTextModel.from_pretrained(Path(config.path), use_safetensors=True)
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Only TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
|
||||
)
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.BnbQuantizedLlmInt8b)
|
||||
class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
|
||||
"""Class to load main models."""
|
||||
@@ -229,7 +252,7 @@ class FluxCheckpointModel(ModelLoader):
|
||||
model_path = Path(config.path)
|
||||
|
||||
with accelerate.init_empty_weights():
|
||||
model = Flux(params[config.config_path])
|
||||
model = Flux(get_flux_transformers_params(config.variant))
|
||||
|
||||
sd = load_file(model_path)
|
||||
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
|
||||
@@ -271,7 +294,7 @@ class FluxGGUFCheckpointModel(ModelLoader):
|
||||
model_path = Path(config.path)
|
||||
|
||||
with accelerate.init_empty_weights():
|
||||
model = Flux(params[config.config_path])
|
||||
model = Flux(get_flux_transformers_params(config.variant))
|
||||
|
||||
# HACK(ryand): We shouldn't be hard-coding the compute_dtype here.
|
||||
sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16)
|
||||
@@ -322,7 +345,7 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
|
||||
|
||||
with SilenceWarnings():
|
||||
with accelerate.init_empty_weights():
|
||||
model = Flux(params[config.config_path])
|
||||
model = Flux(get_flux_transformers_params(config.variant))
|
||||
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
|
||||
sd = load_file(model_path)
|
||||
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
|
||||
@@ -362,7 +385,7 @@ class FluxControlnetModel(ModelLoader):
|
||||
def _load_xlabs_controlnet(self, sd: dict[str, torch.Tensor]) -> AnyModel:
|
||||
with accelerate.init_empty_weights():
|
||||
# HACK(ryand): Is it safe to assume dev here?
|
||||
model = XLabsControlNetFlux(params["flux-dev"])
|
||||
model = XLabsControlNetFlux(get_flux_transformers_params(ModelVariantType.FluxDev))
|
||||
|
||||
model.load_state_dict(sd, assign=True)
|
||||
return model
|
||||
|
||||
86
invokeai/backend/model_manager/single_file_config_files.py
Normal file
86
invokeai/backend/model_manager/single_file_config_files.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LegacyConfigKey:
|
||||
type: ModelType
|
||||
base: BaseModelType
|
||||
variant: ModelVariantType | None = None
|
||||
pred: SchedulerPredictionType | None = None
|
||||
|
||||
|
||||
LEGACY_CONFIG_MAP: dict[LegacyConfigKey, str] = {
|
||||
LegacyConfigKey(
|
||||
ModelType.Main,
|
||||
BaseModelType.StableDiffusion1,
|
||||
ModelVariantType.Normal,
|
||||
SchedulerPredictionType.Epsilon,
|
||||
): "stable-diffusion/v1-inference.yaml",
|
||||
LegacyConfigKey(
|
||||
ModelType.Main,
|
||||
BaseModelType.StableDiffusion1,
|
||||
ModelVariantType.Normal,
|
||||
SchedulerPredictionType.VPrediction,
|
||||
): "stable-diffusion/v1-inference-v.yaml",
|
||||
LegacyConfigKey(
|
||||
ModelType.Main,
|
||||
BaseModelType.StableDiffusion1,
|
||||
ModelVariantType.Inpaint,
|
||||
): "stable-diffusion/v1-inpainting-inference.yaml",
|
||||
LegacyConfigKey(
|
||||
ModelType.Main,
|
||||
BaseModelType.StableDiffusion2,
|
||||
ModelVariantType.Normal,
|
||||
SchedulerPredictionType.Epsilon,
|
||||
): "stable-diffusion/v2-inference.yaml",
|
||||
LegacyConfigKey(
|
||||
ModelType.Main,
|
||||
BaseModelType.StableDiffusion2,
|
||||
ModelVariantType.Normal,
|
||||
SchedulerPredictionType.VPrediction,
|
||||
): "stable-diffusion/v2-inference-v.yaml",
|
||||
LegacyConfigKey(
|
||||
ModelType.Main,
|
||||
BaseModelType.StableDiffusion2,
|
||||
ModelVariantType.Inpaint,
|
||||
SchedulerPredictionType.Epsilon,
|
||||
): "stable-diffusion/v2-inpainting-inference.yaml",
|
||||
LegacyConfigKey(
|
||||
ModelType.Main,
|
||||
BaseModelType.StableDiffusion2,
|
||||
ModelVariantType.Inpaint,
|
||||
SchedulerPredictionType.VPrediction,
|
||||
): "stable-diffusion/v2-inpainting-inference-v.yaml",
|
||||
LegacyConfigKey(
|
||||
ModelType.Main,
|
||||
BaseModelType.StableDiffusion2,
|
||||
ModelVariantType.Depth,
|
||||
): "stable-diffusion/v2-midas-inference.yaml",
|
||||
LegacyConfigKey(
|
||||
ModelType.Main,
|
||||
BaseModelType.StableDiffusionXL,
|
||||
ModelVariantType.Normal,
|
||||
): "stable-diffusion/sd_xl_base.yaml",
|
||||
LegacyConfigKey(
|
||||
ModelType.Main,
|
||||
BaseModelType.StableDiffusionXL,
|
||||
ModelVariantType.Inpaint,
|
||||
): "stable-diffusion/sd_xl_inpaint.yaml",
|
||||
LegacyConfigKey(
|
||||
ModelType.Main,
|
||||
BaseModelType.StableDiffusionXLRefiner,
|
||||
ModelVariantType.Normal,
|
||||
): "stable-diffusion/sd_xl_refiner.yaml",
|
||||
LegacyConfigKey(ModelType.ControlNet, BaseModelType.StableDiffusion1): "controlnet/cldm_v15.yaml",
|
||||
LegacyConfigKey(ModelType.ControlNet, BaseModelType.StableDiffusion2): "controlnet/cldm_v21.yaml",
|
||||
LegacyConfigKey(ModelType.VAE, BaseModelType.StableDiffusion1): "stable-diffusion/v1-inference.yaml",
|
||||
LegacyConfigKey(ModelType.VAE, BaseModelType.StableDiffusion2): "stable-diffusion/v2-inference.yaml",
|
||||
LegacyConfigKey(ModelType.VAE, BaseModelType.StableDiffusionXL): "stable-diffusion/sd_xl_base.yaml",
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import diffusers
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
from diffusers import ModelMixin
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
@@ -30,6 +31,7 @@ class BaseModelType(str, Enum):
|
||||
Imagen4 = "imagen4"
|
||||
Gemini2_5 = "gemini-2.5"
|
||||
ChatGPT4o = "chatgpt-4o"
|
||||
# This is actually the FLUX Kontext API model. Local FLUX Kontext is just BaseModelType.Flux.
|
||||
FluxKontext = "flux-kontext"
|
||||
Veo3 = "veo3"
|
||||
Runway = "runway"
|
||||
@@ -92,6 +94,12 @@ class ModelVariantType(str, Enum):
|
||||
Depth = "depth"
|
||||
|
||||
|
||||
class FluxVariantType(str, Enum):
|
||||
Schnell = "schnell"
|
||||
Dev = "dev"
|
||||
DevFill = "dev_fill"
|
||||
|
||||
|
||||
class ModelFormat(str, Enum):
|
||||
"""Storage format of model."""
|
||||
|
||||
@@ -149,4 +157,7 @@ class FluxLoRAFormat(str, Enum):
|
||||
AIToolkit = "flux.aitoolkit"
|
||||
|
||||
|
||||
AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]
|
||||
AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, FluxVariantType]
|
||||
variant_type_adapter = TypeAdapter[ModelVariantType | ClipVariantType | FluxVariantType](
|
||||
ModelVariantType | ClipVariantType | FluxVariantType
|
||||
)
|
||||
|
||||
@@ -12,7 +12,10 @@ from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.util import InvokeAILogger
|
||||
|
||||
|
||||
def is_state_dict_likely_in_flux_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any] = None) -> bool:
|
||||
def is_state_dict_likely_in_flux_aitoolkit_format(
|
||||
state_dict: dict[str, Any],
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> bool:
|
||||
if metadata:
|
||||
try:
|
||||
software = json.loads(metadata.get("software", "{}"))
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat
|
||||
from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_aitoolkit_format,
|
||||
@@ -14,7 +16,10 @@ from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_u
|
||||
)
|
||||
|
||||
|
||||
def flux_format_from_state_dict(state_dict: dict, metadata: dict | None = None) -> FluxLoRAFormat | None:
|
||||
def flux_format_from_state_dict(
|
||||
state_dict: dict[str, Any],
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> FluxLoRAFormat | None:
|
||||
if is_state_dict_likely_in_flux_kohya_format(state_dict):
|
||||
return FluxLoRAFormat.Kohya
|
||||
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict):
|
||||
|
||||
@@ -4,7 +4,8 @@ import accelerate
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.util import params
|
||||
from invokeai.backend.flux.util import get_flux_transformers_params
|
||||
from invokeai.backend.model_manager.taxonomy import ModelVariantType
|
||||
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
|
||||
from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time
|
||||
|
||||
@@ -22,7 +23,7 @@ def main():
|
||||
|
||||
with log_time("Initialize FLUX transformer on meta device"):
|
||||
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
||||
p = params["flux-schnell"]
|
||||
p = get_flux_transformers_params(ModelVariantType.FluxSchnell)
|
||||
|
||||
# Initialize the model on the "meta" device.
|
||||
with accelerate.init_empty_weights():
|
||||
|
||||
@@ -7,7 +7,8 @@ import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.util import params
|
||||
from invokeai.backend.flux.util import get_flux_transformers_params
|
||||
from invokeai.backend.model_manager.taxonomy import ModelVariantType
|
||||
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
||||
|
||||
|
||||
@@ -35,7 +36,7 @@ def main():
|
||||
# inference_dtype = torch.bfloat16
|
||||
with log_time("Initialize FLUX transformer on meta device"):
|
||||
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
||||
p = params["flux-schnell"]
|
||||
p = get_flux_transformers_params(ModelVariantType.FluxSchnell)
|
||||
|
||||
# Initialize the model on the "meta" device.
|
||||
with accelerate.init_empty_weights():
|
||||
|
||||
@@ -223,6 +223,9 @@ export const MODEL_VARIANT_TO_LONG_NAME: Record<ModelVariantType, string> = {
|
||||
normal: 'Normal',
|
||||
inpaint: 'Inpaint',
|
||||
depth: 'Depth',
|
||||
dev: 'FLUX Dev',
|
||||
dev_fill: 'FLUX Dev Fill',
|
||||
schnell: 'FLUX Schnell',
|
||||
};
|
||||
|
||||
export const MODEL_FORMAT_TO_LONG_NAME: Record<ModelFormat, string> = {
|
||||
|
||||
@@ -147,7 +147,7 @@ export const zSubModelType = z.enum([
|
||||
]);
|
||||
|
||||
export const zClipVariantType = z.enum(['large', 'gigantic']);
|
||||
export const zModelVariantType = z.enum(['normal', 'inpaint', 'depth']);
|
||||
export const zModelVariantType = z.enum(['normal', 'inpaint', 'depth', 'dev', 'dev_fill', 'schnell']);
|
||||
export type ModelVariantType = z.infer<typeof zModelVariantType>;
|
||||
export const zModelFormat = z.enum([
|
||||
'omi',
|
||||
|
||||
@@ -5490,7 +5490,7 @@ export type components = {
|
||||
* Config Path
|
||||
* @description path to the checkpoint model config file
|
||||
*/
|
||||
config_path: string;
|
||||
config_path?: string | null;
|
||||
/**
|
||||
* Converted At
|
||||
* @description When this model was last converted to diffusers
|
||||
@@ -14818,7 +14818,7 @@ export type components = {
|
||||
* Config Path
|
||||
* @description path to the checkpoint model config file
|
||||
*/
|
||||
config_path: string;
|
||||
config_path?: string | null;
|
||||
/**
|
||||
* Converted At
|
||||
* @description When this model was last converted to diffusers
|
||||
@@ -14927,7 +14927,7 @@ export type components = {
|
||||
* Config Path
|
||||
* @description path to the checkpoint model config file
|
||||
*/
|
||||
config_path: string;
|
||||
config_path?: string | null;
|
||||
/**
|
||||
* Converted At
|
||||
* @description When this model was last converted to diffusers
|
||||
@@ -15128,7 +15128,7 @@ export type components = {
|
||||
* Config Path
|
||||
* @description path to the checkpoint model config file
|
||||
*/
|
||||
config_path: string;
|
||||
config_path?: string | null;
|
||||
/**
|
||||
* Converted At
|
||||
* @description When this model was last converted to diffusers
|
||||
@@ -17487,7 +17487,7 @@ export type components = {
|
||||
* @description Variant type.
|
||||
* @enum {string}
|
||||
*/
|
||||
ModelVariantType: "normal" | "inpaint" | "depth";
|
||||
ModelVariantType: "normal" | "inpaint" | "depth" | "flux_dev" | "flux_dev_fill" | "flux_schnell";
|
||||
/**
|
||||
* ModelsList
|
||||
* @description Return list of configs.
|
||||
@@ -22199,7 +22199,7 @@ export type components = {
|
||||
* Config Path
|
||||
* @description path to the checkpoint model config file
|
||||
*/
|
||||
config_path: string;
|
||||
config_path?: string | null;
|
||||
/**
|
||||
* Converted At
|
||||
* @description When this model was last converted to diffusers
|
||||
|
||||
@@ -2,7 +2,8 @@ import accelerate
|
||||
import pytest
|
||||
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.util import params
|
||||
from invokeai.backend.flux.util import get_flux_transformers_params
|
||||
from invokeai.backend.model_manager.taxonomy import ModelVariantType
|
||||
from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
|
||||
_group_state_by_submodel,
|
||||
is_state_dict_likely_in_flux_aitoolkit_format,
|
||||
@@ -44,7 +45,7 @@ def test_flux_aitoolkit_transformer_state_dict_is_in_invoke_format():
|
||||
|
||||
# Initialize a FLUX model on the meta device.
|
||||
with accelerate.init_empty_weights():
|
||||
model = Flux(params["flux-schnell"])
|
||||
model = Flux(get_flux_transformers_params(ModelVariantType.FluxSchnell))
|
||||
model_keys = set(model.state_dict().keys())
|
||||
|
||||
for converted_key_prefix in converted_key_prefixes:
|
||||
|
||||
@@ -3,7 +3,8 @@ import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.util import params
|
||||
from invokeai.backend.flux.util import get_flux_transformers_params
|
||||
from invokeai.backend.model_manager.taxonomy import ModelVariantType
|
||||
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
|
||||
_convert_flux_transformer_kohya_state_dict_to_invoke_format,
|
||||
is_state_dict_likely_in_flux_kohya_format,
|
||||
@@ -63,7 +64,7 @@ def test_convert_flux_transformer_kohya_state_dict_to_invoke_format():
|
||||
|
||||
# Initialize a FLUX model on the meta device.
|
||||
with accelerate.init_empty_weights():
|
||||
model = Flux(params["flux-dev"])
|
||||
model = Flux(get_flux_transformers_params(ModelVariantType.FluxSchnell))
|
||||
model_keys = set(model.state_dict().keys())
|
||||
|
||||
# Assert that the converted state dict matches the keys in the actual model.
|
||||
|
||||
@@ -115,7 +115,7 @@ class MinimalConfigExample(ModelConfigBase):
|
||||
fun_quote: str
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk) -> bool:
|
||||
def matches(cls, mod: ModelOnDisk, **overrides) -> bool:
|
||||
return mod.path.suffix == ".json"
|
||||
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user