Mirror of https://github.com/invoke-ai/InvokeAI.git
Create clip variant type, create new functions for discerning clipL and clipG in the frontend
@@ -95,6 +95,13 @@ class SubModelType(str, Enum):
     SafetyChecker = "safety_checker"
 
 
+class ClipVariantType(str, Enum):
+    """Variant type."""
+
+    L = "large"
+    G = "gigantic"
+
+
 class ModelVariantType(str, Enum):
     """Variant type."""
 
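Because ClipVariantType subclasses str, its members compare and serialize as the plain strings "large" and "gigantic", which is exactly how the regenerated OpenAPI schema further down exposes the type. A standalone illustration, not part of the diff:

from enum import Enum


class ClipVariantType(str, Enum):
    """Variant type."""

    L = "large"
    G = "gigantic"


# str-backed members round-trip through their plain string values.
assert ClipVariantType("large") is ClipVariantType.L
assert ClipVariantType.G == "gigantic"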
@@ -150,9 +157,13 @@ class ModelSourceType(str, Enum):
 DEFAULTS_PRECISION = Literal["fp16", "fp32"]
 
 
+AnyVariant: TypeAlias = Union[ModelRepoVariant, ClipVariantType, None]
+
+
 class SubmodelDefinition(BaseModel):
     path_or_prefix: str
     model_type: ModelType
+    variant: AnyVariant = None
 
 
 class MainModelDefaultSettings(BaseModel):
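A minimal sketch of how the new variant field rides along on a submodel record; the imports mirror this diff, while the path and variant chosen here are purely illustrative:

from invokeai.backend.model_manager.config import (
    ClipVariantType,
    ModelType,
    SubmodelDefinition,
)

# Hypothetical entry for an SDXL-style second text encoder (CLIP-G).
entry = SubmodelDefinition(
    path_or_prefix="/models/sdxl/text_encoder_2",
    model_type=ModelType.CLIPEmbed,
    variant=ClipVariantType.G,  # stays None for models with no variant
)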
@@ -430,6 +441,7 @@ class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
 
     type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
     format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
+    variant: ClipVariantType = ClipVariantType.L
 
     @staticmethod
     def get_tag() -> Tag:
@@ -1,7 +1,7 @@
 import json
 import re
 from pathlib import Path
-from typing import Any, Dict, Literal, Optional, Union
+from typing import Any, Callable, Dict, Literal, Optional, Union
 
 import safetensors.torch
 import spandrel
@@ -22,6 +22,7 @@ from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import i
 from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
 from invokeai.backend.model_manager.config import (
     AnyModelConfig,
+    AnyVariant,
     BaseModelType,
     ControlAdapterDefaultSettings,
     InvalidModelConfigException,
@@ -37,7 +38,11 @@ from invokeai.backend.model_manager.config import (
     SubModelType,
 )
 from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import ConfigLoader
-from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
+from invokeai.backend.model_manager.util.model_util import (
+    get_clip_variant_type,
+    lora_token_vector_length,
+    read_checkpoint_meta,
+)
 from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
 from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
@@ -130,6 +135,8 @@ class ModelProbe(object):
         "CLIPTextModelWithProjection": ModelType.CLIPEmbed,
     }
 
+    TYPE2VARIANT: Dict[ModelType, Callable[[str], Optional[AnyVariant]]] = {ModelType.CLIPEmbed: get_clip_variant_type}
+
     @classmethod
     def register_probe(
         cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: type[ProbeBase]
@@ -176,7 +183,10 @@ class ModelProbe(object):
         fields["path"] = model_path.as_posix()
         fields["type"] = fields.get("type") or model_type
         fields["base"] = fields.get("base") or probe.get_base_type()
-        fields["variant"] = fields.get("variant") or probe.get_variant_type()
+        variant_func = cls.TYPE2VARIANT.get(fields["type"], None)
+        fields["variant"] = (
+            fields.get("variant") or (variant_func and variant_func(model_path.as_posix())) or probe.get_variant_type()
+        )
         fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type()
         fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
         fields["name"] = fields.get("name") or cls.get_model_name(model_path)
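Variant resolution here is a three-way fallback: an explicit value already present in fields, then a type-specific detector looked up in TYPE2VARIANT, then the probe's generic get_variant_type(). A self-contained sketch of that pattern with stand-in names, not InvokeAI's actual API:

from typing import Callable, Dict, Optional

# Stand-in registry: maps a model type to a variant-detection callable.
TYPE2VARIANT: Dict[str, Callable[[str], Optional[str]]] = {
    "clip_embed": lambda path: "large",  # stands in for get_clip_variant_type
}


def resolve_variant(model_type: str, path: str, override: Optional[str], generic: str) -> Optional[str]:
    variant_func = TYPE2VARIANT.get(model_type, None)
    # As in the diff: `variant_func and variant_func(path)` is None when no
    # detector is registered, so the chain falls through to the generic value.
    return override or (variant_func and variant_func(path)) or generic


assert resolve_variant("clip_embed", "/any/path", None, "normal") == "large"
assert resolve_variant("vae", "/any/path", None, "normal") == "normal"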
@@ -803,9 +813,11 @@ class PipelineFolderProbe(FolderProbeBase):
                 continue
             model_loader = str(value[1])
             if model_type := ModelProbe.CLASS2TYPE.get(model_loader):
+                variant_func = ModelProbe.TYPE2VARIANT.get(model_type, None)
                 submodels[SubModelType(key)] = SubmodelDefinition(
                     path_or_prefix=(self.model_path / key).resolve().as_posix(),
                     model_type=model_type,
+                    variant=variant_func and variant_func((self.model_path / key).as_posix()),
                 )
 
         return submodels
@@ -8,6 +8,7 @@ import safetensors
 import torch
 from picklescan.scanner import scan_file_path
 
+from invokeai.backend.model_manager.config import ClipVariantType
 from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
 
 
@@ -165,3 +166,20 @@ def convert_bundle_to_flux_transformer_checkpoint(
             del transformer_state_dict[k]
 
     return original_state_dict
+
+
+def get_clip_variant_type(location: str) -> Optional[ClipVariantType]:
+    path = Path(location)
+    config_path = path / "config.json"
+    if not config_path.exists():
+        return None
+    with open(config_path) as file:
+        clip_conf = json.load(file)
+        hidden_size = clip_conf.get("hidden_size", -1)
+        match hidden_size:
+            case 1280:
+                return ClipVariantType.G
+            case 768:
+                return ClipVariantType.L
+            case _:
+                return None
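get_clip_variant_type keys off hidden_size in a Diffusers config.json: 768 is the CLIP-L text-encoder width, 1280 is CLIP-G's, and anything else yields None. A hypothetical call site (the folder layout is illustrative):

from invokeai.backend.model_manager.util.model_util import get_clip_variant_type

# SDXL-style Diffusers layout: text_encoder is CLIP-L, text_encoder_2 is CLIP-G.
print(get_clip_variant_type("/models/sdxl/text_encoder"))    # ClipVariantType.L (hidden_size 768)
print(get_clip_variant_type("/models/sdxl/text_encoder_2"))  # ClipVariantType.G (hidden_size 1280)
print(get_clip_variant_type("/models/sdxl/unet"))            # None (no matching hidden_size)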
@@ -2711,6 +2711,8 @@ export type components = {
        * @enum {string}
        */
       type: "clip_embed";
+      /** @default large */
+      variant?: components["schemas"]["ClipVariantType"];
     };
     /** CLIPField */
     CLIPField: {
@@ -3515,6 +3517,12 @@ export type components = {
        */
       deleted: number;
     };
+    /**
+     * ClipVariantType
+     * @description Variant type.
+     * @enum {string}
+     */
+    ClipVariantType: "large" | "gigantic";
     /**
      * CollectInvocation
      * @description Collects values into a collection
@@ -14216,7 +14224,7 @@ export type components = {
       /**
        * CFG Scale
        * @description Classifier-Free Guidance scale
-       * @default 7
+       * @default 3.5
        */
       cfg_scale?: number | number[];
       /**
@@ -16259,6 +16267,8 @@ export type components = {
       /** Path Or Prefix */
       path_or_prefix: string;
       model_type: components["schemas"]["ModelType"];
+      /** Variant */
+      variant?: components["schemas"]["ModelRepoVariant"] | components["schemas"]["ClipVariantType"] | null;
     };
     /**
      * Subtract Integers
@@ -75,21 +75,25 @@ export type AnyModelConfig =
   | MainModelConfig
   | CLIPVisionDiffusersConfig;
 
-const check_submodel_model_type = (submodels: AnyModelConfig['submodels'], model_type: string): boolean => {
+const check_submodel = (submodels: AnyModelConfig['submodels'], check_str: string): boolean => {
   for (const submodel in submodels) {
-    if (submodel && submodels[submodel] && submodels[submodel].model_type === model_type) {
+    if (
+      submodel &&
+      submodels[submodel] &&
+      (submodels[submodel].model_type === check_str || submodels[submodel].variant === check_str)
+    ) {
       return true;
     }
   }
   return false;
 };
 
-const check_submodels = (indentifier: string, config: AnyModelConfig): boolean => {
-  return (
-    (config.type === 'main' &&
+const check_submodels = (identifiers: string[], config: AnyModelConfig): boolean => {
+  return identifiers.every(
+    (identifier) =>
+      config.type === 'main' &&
       config.submodels &&
-      (indentifier in config.submodels || check_submodel_model_type(config.submodels, indentifier))) ||
-    false
+      (identifier in config.submodels || check_submodel(config.submodels, identifier))
   );
 };
 
@@ -98,15 +102,15 @@ export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelCo
 };
 
 export const isVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig | MainModelConfig => {
-  return config.type === 'vae' || check_submodels('vae', config);
+  return config.type === 'vae' || check_submodels(['vae'], config);
 };
 
 export const isNonFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig | MainModelConfig => {
-  return (config.type === 'vae' || check_submodels('vae', config)) && config.base !== 'flux';
+  return (config.type === 'vae' || check_submodels(['vae'], config)) && config.base !== 'flux';
 };
 
 export const isFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig | MainModelConfig => {
-  return (config.type === 'vae' || check_submodels('vae', config)) && config.base === 'flux';
+  return (config.type === 'vae' || check_submodels(['vae'], config)) && config.base === 'flux';
 };
 
 export const isControlNetModelConfig = (config: AnyModelConfig): config is ControlNetModelConfig => {
@@ -128,11 +132,23 @@ export const isT2IAdapterModelConfig = (config: AnyModelConfig): config is T2IAd
 export const isT5EncoderModelConfig = (
   config: AnyModelConfig
 ): config is T5EncoderModelConfig | T5EncoderBnbQuantizedLlmInt8bModelConfig | MainModelConfig => {
-  return config.type === 't5_encoder' || check_submodels('t5_encoder', config);
+  return config.type === 't5_encoder' || check_submodels(['t5_encoder'], config);
 };
 
 export const isCLIPEmbedModelConfig = (config: AnyModelConfig): config is CLIPEmbedModelConfig | MainModelConfig => {
-  return config.type === 'clip_embed' || check_submodels('clip_embed', config);
+  return config.type === 'clip_embed' || check_submodels(['clip_embed'], config);
 };
 
+export const isCLIPLEmbedModelConfig = (config: AnyModelConfig): config is CLIPEmbedModelConfig | MainModelConfig => {
+  return (
+    (config.type === 'clip_embed' && config.variant === 'large') || check_submodels(['clip_embed', 'large'], config)
+  );
+};
+
+export const isCLIPGEmbedModelConfig = (config: AnyModelConfig): config is CLIPEmbedModelConfig | MainModelConfig => {
+  return (
+    (config.type === 'clip_embed' && config.variant === 'gigantic') ||
+    check_submodels(['clip_embed', 'gigantic'], config)
+  );
+};
+
 export const isSpandrelImageToImageModelConfig = (