diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py
index 8aa57066cb..e08d267a3c 100644
--- a/invokeai/backend/model_manager/config.py
+++ b/invokeai/backend/model_manager/config.py
@@ -95,6 +95,13 @@ class SubModelType(str, Enum):
     SafetyChecker = "safety_checker"
 
 
+class ClipVariantType(str, Enum):
+    """Variant type."""
+
+    L = "large"
+    G = "gigantic"
+
+
 class ModelVariantType(str, Enum):
     """Variant type."""
 
@@ -150,9 +157,13 @@ class ModelSourceType(str, Enum):
 DEFAULTS_PRECISION = Literal["fp16", "fp32"]
 
 
+AnyVariant: TypeAlias = Union[ModelRepoVariant, ClipVariantType, None]
+
+
 class SubmodelDefinition(BaseModel):
     path_or_prefix: str
     model_type: ModelType
+    variant: AnyVariant = None
 
 
 class MainModelDefaultSettings(BaseModel):
@@ -430,6 +441,7 @@ class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
 
     type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
     format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
+    variant: ClipVariantType = ClipVariantType.L
 
     @staticmethod
     def get_tag() -> Tag:
diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py
index 4471100441..e988ac0372 100644
--- a/invokeai/backend/model_manager/probe.py
+++ b/invokeai/backend/model_manager/probe.py
@@ -1,7 +1,7 @@
 import json
 import re
 from pathlib import Path
-from typing import Any, Dict, Literal, Optional, Union
+from typing import Any, Callable, Dict, Literal, Optional, Union
 
 import safetensors.torch
 import spandrel
@@ -22,6 +22,7 @@ from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import i
 from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
 from invokeai.backend.model_manager.config import (
     AnyModelConfig,
+    AnyVariant,
     BaseModelType,
     ControlAdapterDefaultSettings,
     InvalidModelConfigException,
@@ -37,7 +38,11 @@ from invokeai.backend.model_manager.config import (
     SubModelType,
 )
 from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import ConfigLoader
-from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
+from invokeai.backend.model_manager.util.model_util import (
+    get_clip_variant_type,
+    lora_token_vector_length,
+    read_checkpoint_meta,
+)
 from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
 from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
@@ -130,6 +135,8 @@ class ModelProbe(object):
         "CLIPTextModelWithProjection": ModelType.CLIPEmbed,
     }
 
+    TYPE2VARIANT: Dict[ModelType, Callable[[str], Optional[AnyVariant]]] = {ModelType.CLIPEmbed: get_clip_variant_type}
+
     @classmethod
     def register_probe(
         cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: type[ProbeBase]
@@ -176,7 +183,10 @@ class ModelProbe(object):
         fields["path"] = model_path.as_posix()
         fields["type"] = fields.get("type") or model_type
         fields["base"] = fields.get("base") or probe.get_base_type()
-        fields["variant"] = fields.get("variant") or probe.get_variant_type()
+        variant_func = cls.TYPE2VARIANT.get(fields["type"], None)
+        fields["variant"] = (
+            fields.get("variant") or (variant_func and variant_func(model_path.as_posix())) or probe.get_variant_type()
+        )
         fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type()
         fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
         fields["name"] = fields.get("name") or cls.get_model_name(model_path)
@@ -803,9 +813,11 @@ class PipelineFolderProbe(FolderProbeBase):
                 continue
             model_loader = str(value[1])
             if model_type := ModelProbe.CLASS2TYPE.get(model_loader):
+                variant_func = ModelProbe.TYPE2VARIANT.get(model_type, None)
                 submodels[SubModelType(key)] = SubmodelDefinition(
                     path_or_prefix=(self.model_path / key).resolve().as_posix(),
                     model_type=model_type,
+                    variant=variant_func and variant_func((self.model_path / key).as_posix()),
                 )
 
         return submodels
diff --git a/invokeai/backend/model_manager/util/model_util.py b/invokeai/backend/model_manager/util/model_util.py
index fec59a60f8..d4efdeb0ae 100644
--- a/invokeai/backend/model_manager/util/model_util.py
+++ b/invokeai/backend/model_manager/util/model_util.py
@@ -8,6 +8,7 @@ import safetensors
 import torch
 from picklescan.scanner import scan_file_path
 
+from invokeai.backend.model_manager.config import ClipVariantType
 from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
 
 
@@ -165,3 +166,20 @@ def convert_bundle_to_flux_transformer_checkpoint(
             del transformer_state_dict[k]
 
     return original_state_dict
+
+
+def get_clip_variant_type(location: str) -> Optional[ClipVariantType]:
+    path = Path(location)
+    config_path = path / "config.json"
+    if not config_path.exists():
+        return None
+    with open(config_path) as file:
+        clip_conf = json.load(file)
+    hidden_size = clip_conf.get("hidden_size", -1)
+    match hidden_size:
+        case 1280:
+            return ClipVariantType.G
+        case 768:
+            return ClipVariantType.L
+        case _:
+            return None
diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts
index 5818e3a4fc..c8a679c0ba 100644
--- a/invokeai/frontend/web/src/services/api/schema.ts
+++ b/invokeai/frontend/web/src/services/api/schema.ts
@@ -2711,6 +2711,8 @@ export type components = {
        * @enum {string}
        */
       type: "clip_embed";
+      /** @default large */
+      variant?: components["schemas"]["ClipVariantType"];
     };
     /** CLIPField */
     CLIPField: {
@@ -3515,6 +3517,12 @@ export type components = {
        */
      deleted: number;
     };
+    /**
+     * ClipVariantType
+     * @description Variant type.
+     * @enum {string}
+     */
+    ClipVariantType: "large" | "gigantic";
     /**
      * CollectInvocation
      * @description Collects values into a collection
@@ -14216,7 +14224,7 @@ export type components = {
       /**
        * CFG Scale
        * @description Classifier-Free Guidance scale
-       * @default 7
+       * @default 3.5
        */
       cfg_scale?: number | number[];
       /**
@@ -16259,6 +16267,8 @@ export type components = {
       /** Path Or Prefix */
       path_or_prefix: string;
       model_type: components["schemas"]["ModelType"];
+      /** Variant */
+      variant?: components["schemas"]["ModelRepoVariant"] | components["schemas"]["ClipVariantType"] | null;
     };
     /**
      * Subtract Integers
diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts
index 1e486acb7e..4815c61125 100644
--- a/invokeai/frontend/web/src/services/api/types.ts
+++ b/invokeai/frontend/web/src/services/api/types.ts
@@ -75,21 +75,25 @@ export type AnyModelConfig =
   | MainModelConfig
   | CLIPVisionDiffusersConfig;
 
-const check_submodel_model_type = (submodels: AnyModelConfig['submodels'], model_type: string): boolean => {
+const check_submodel = (submodels: AnyModelConfig['submodels'], check_str: string): boolean => {
   for (const submodel in submodels) {
-    if (submodel && submodels[submodel] && submodels[submodel].model_type === model_type) {
+    if (
+      submodel &&
+      submodels[submodel] &&
+      (submodels[submodel].model_type === check_str || submodels[submodel].variant === check_str)
+    ) {
       return true;
     }
   }
   return false;
 };
 
-const check_submodels = (indentifier: string, config: AnyModelConfig): boolean => {
-  return (
-    (config.type === 'main' &&
+const check_submodels = (indentifiers: string[], config: AnyModelConfig): boolean => {
+  return indentifiers.every(
+    (indentifier) =>
+      config.type === 'main' &&
       config.submodels &&
-      (indentifier in config.submodels || check_submodel_model_type(config.submodels, indentifier))) ||
-    false
+      (indentifier in config.submodels || check_submodel(config.submodels, indentifier))
   );
 };
 
@@ -98,15 +102,15 @@ export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelCo
 };
 
 export const isVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig | MainModelConfig => {
-  return config.type === 'vae' || check_submodels('vae', config);
+  return config.type === 'vae' || check_submodels(['vae'], config);
 };
 
 export const isNonFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig | MainModelConfig => {
-  return (config.type === 'vae' || check_submodels('vae', config)) && config.base !== 'flux';
+  return (config.type === 'vae' || check_submodels(['vae'], config)) && config.base !== 'flux';
 };
 
 export const isFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig | MainModelConfig => {
-  return (config.type === 'vae' || check_submodels('vae', config)) && config.base === 'flux';
+  return (config.type === 'vae' || check_submodels(['vae'], config)) && config.base === 'flux';
 };
 
 export const isControlNetModelConfig = (config: AnyModelConfig): config is ControlNetModelConfig => {
@@ -128,11 +132,23 @@ export const isT2IAdapterModelConfig = (config: AnyModelConfig): config is T2IAd
 export const isT5EncoderModelConfig = (
   config: AnyModelConfig
 ): config is T5EncoderModelConfig | T5EncoderBnbQuantizedLlmInt8bModelConfig | MainModelConfig => {
-  return config.type === 't5_encoder' || check_submodels('t5_encoder', config);
+  return config.type === 't5_encoder' || check_submodels(['t5_encoder'], config);
 };
 
 export const isCLIPEmbedModelConfig = (config: AnyModelConfig): config is CLIPEmbedModelConfig | MainModelConfig => {
-  return config.type === 'clip_embed' || check_submodels('clip_embed', config);
+  return config.type === 'clip_embed' || check_submodels(['clip_embed'], config);
+};
+
+export const isCLIPLEmbedModelConfig = (config: AnyModelConfig): config is CLIPEmbedModelConfig | MainModelConfig => {
+  return (
+    (config.type === 'clip_embed' && config.variant === 'large') || check_submodels(['clip_embed', 'gigantic'], config)
+  );
+};
+
+export const isCLIPGEmbedModelConfig = (config: AnyModelConfig): config is CLIPEmbedModelConfig | MainModelConfig => {
+  return (
+    (config.type === 'clip_embed' && config.variant === 'gigantic') || check_submodels(['clip_embed', 'large'], config)
+  );
 };
 
 export const isSpandrelImageToImageModelConfig = (
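
A minimal usage sketch (not part of the diff) of the probing helper added in `model_util.py` above: `get_clip_variant_type` reads a submodel folder's `config.json` and maps `hidden_size` 768 to CLIP-L and 1280 to CLIP-G, returning `None` otherwise. The folder path below is purely illustrative.

```python
# Sketch only: exercising the new helper on a local diffusers-style text encoder folder.
# "sd3.5-medium/text_encoder_2" is a hypothetical path; any folder whose config.json
# carries a "hidden_size" key will do.
from invokeai.backend.model_manager.config import ClipVariantType
from invokeai.backend.model_manager.util.model_util import get_clip_variant_type

variant = get_clip_variant_type("sd3.5-medium/text_encoder_2")

if variant is ClipVariantType.G:
    print("CLIP-G text encoder (hidden_size == 1280)")
elif variant is ClipVariantType.L:
    print("CLIP-L text encoder (hidden_size == 768)")
else:
    print("no config.json found, or hidden_size is not a recognized CLIP size")
```

`ModelProbe` consumes the same helper through its `TYPE2VARIANT` map, so both standalone CLIP embed models and the CLIP submodels of a main pipeline get the detected variant stamped onto their config records.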