Create clip variant type, create new functions for discerning clipL and clipG in the frontend

Brandon Rising
2024-10-30 11:10:08 -04:00
committed by Brandon
parent 1eca4f12c8
commit b87f4e59a5
5 changed files with 84 additions and 16 deletions

View File

@@ -95,6 +95,13 @@ class SubModelType(str, Enum):
     SafetyChecker = "safety_checker"
 
 
+class ClipVariantType(str, Enum):
+    """Variant type."""
+
+    L = "large"
+    G = "gigantic"
+
+
 class ModelVariantType(str, Enum):
     """Variant type."""
@@ -150,9 +157,13 @@ class ModelSourceType(str, Enum):
 DEFAULTS_PRECISION = Literal["fp16", "fp32"]
 
+AnyVariant: TypeAlias = Union[ModelRepoVariant, ClipVariantType, None]
+
+
 class SubmodelDefinition(BaseModel):
     path_or_prefix: str
     model_type: ModelType
+    variant: AnyVariant = None
 
 
 class MainModelDefaultSettings(BaseModel):
@@ -430,6 +441,7 @@ class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
     type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
     format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
+    variant: ClipVariantType = ClipVariantType.L
 
     @staticmethod
     def get_tag() -> Tag:
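Taken together, the config changes let a SubmodelDefinition carry the new variant alongside its model type. A minimal sketch of what the field holds, assuming an InvokeAI checkout with this commit applied (the "text_encoder_2" prefix is illustrative, not taken from this diff):

from invokeai.backend.model_manager.config import ClipVariantType, ModelType, SubmodelDefinition

# Hypothetical entry for an SDXL-style second text encoder.
sub = SubmodelDefinition(
    path_or_prefix="text_encoder_2",  # illustrative prefix
    model_type=ModelType.CLIPEmbed,
    variant=ClipVariantType.G,
)
print(sub.variant.value)  # "gigantic"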

View File

@@ -1,7 +1,7 @@
 import json
 import re
 from pathlib import Path
-from typing import Any, Dict, Literal, Optional, Union
+from typing import Any, Callable, Dict, Literal, Optional, Union
 
 import safetensors.torch
 import spandrel
@@ -22,6 +22,7 @@ from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import i
 from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
 from invokeai.backend.model_manager.config import (
     AnyModelConfig,
+    AnyVariant,
     BaseModelType,
     ControlAdapterDefaultSettings,
     InvalidModelConfigException,
@@ -37,7 +38,11 @@ from invokeai.backend.model_manager.config import (
     SubModelType,
 )
 from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import ConfigLoader
-from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
+from invokeai.backend.model_manager.util.model_util import (
+    get_clip_variant_type,
+    lora_token_vector_length,
+    read_checkpoint_meta,
+)
 from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
 from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
@@ -130,6 +135,8 @@ class ModelProbe(object):
         "CLIPTextModelWithProjection": ModelType.CLIPEmbed,
     }
 
+    TYPE2VARIANT: Dict[ModelType, Callable[[str], Optional[AnyVariant]]] = {ModelType.CLIPEmbed: get_clip_variant_type}
+
     @classmethod
     def register_probe(
         cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: type[ProbeBase]
@@ -176,7 +183,10 @@ class ModelProbe(object):
         fields["path"] = model_path.as_posix()
         fields["type"] = fields.get("type") or model_type
         fields["base"] = fields.get("base") or probe.get_base_type()
-        fields["variant"] = fields.get("variant") or probe.get_variant_type()
+        variant_func = cls.TYPE2VARIANT.get(fields["type"], None)
+        fields["variant"] = (
+            fields.get("variant") or (variant_func and variant_func(model_path.as_posix())) or probe.get_variant_type()
+        )
         fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type()
         fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
         fields["name"] = fields.get("name") or cls.get_model_name(model_path)
@@ -803,9 +813,11 @@ class PipelineFolderProbe(FolderProbeBase):
                 continue
             model_loader = str(value[1])
             if model_type := ModelProbe.CLASS2TYPE.get(model_loader):
+                variant_func = ModelProbe.TYPE2VARIANT.get(model_type, None)
                 submodels[SubModelType(key)] = SubmodelDefinition(
                     path_or_prefix=(self.model_path / key).resolve().as_posix(),
                     model_type=model_type,
+                    variant=variant_func and variant_func((self.model_path / key).as_posix()),
                 )
 
         return submodels
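The probe now resolves variants in three steps: an explicitly supplied "variant" field wins, then the per-type detector from TYPE2VARIANT (only ModelType.CLIPEmbed registers one in this commit), then the probe's generic get_variant_type(). A minimal sketch of the table lookup, assuming ModelProbe is importable from invokeai.backend.model_manager.probe as in the InvokeAI tree of this era; the model path is hypothetical:

from invokeai.backend.model_manager.config import ModelType
from invokeai.backend.model_manager.probe import ModelProbe

# Only ModelType.CLIPEmbed maps to a detector after this commit.
detector = ModelProbe.TYPE2VARIANT.get(ModelType.CLIPEmbed)
if detector is not None:
    # Returns ClipVariantType.L, ClipVariantType.G, or None (no config.json found).
    print(detector("/path/to/clip_text_encoder"))  # hypothetical path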

View File

@@ -8,6 +8,7 @@ import safetensors
 import torch
 from picklescan.scanner import scan_file_path
 
+from invokeai.backend.model_manager.config import ClipVariantType
 from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
@@ -165,3 +166,20 @@ def convert_bundle_to_flux_transformer_checkpoint(
             del transformer_state_dict[k]
 
     return original_state_dict
+
+
+def get_clip_variant_type(location: str) -> Optional[ClipVariantType]:
+    path = Path(location)
+    config_path = path / "config.json"
+    if not config_path.exists():
+        return None
+    with open(config_path) as file:
+        clip_conf = json.load(file)
+        hidden_size = clip_conf.get("hidden_size", -1)
+        match hidden_size:
+            case 1280:
+                return ClipVariantType.G
+            case 768:
+                return ClipVariantType.L
+            case _:
+                return None
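Since get_clip_variant_type keys off hidden_size alone (768 for CLIP-L, 1280 for the OpenCLIP bigG text encoder that SDXL uses), it can be exercised without real weights. A self-contained check, assuming this commit is installed:

import json
import tempfile
from pathlib import Path

from invokeai.backend.model_manager.util.model_util import get_clip_variant_type

with tempfile.TemporaryDirectory() as tmp:
    # A bare config.json is all the helper reads.
    (Path(tmp) / "config.json").write_text(json.dumps({"hidden_size": 1280}))
    print(get_clip_variant_type(tmp))  # ClipVariantType.G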

View File

@@ -2711,6 +2711,8 @@ export type components = {
        * @enum {string}
        */
       type: "clip_embed";
+      /** @default large */
+      variant?: components["schemas"]["ClipVariantType"];
     };
     /** CLIPField */
     CLIPField: {
@@ -3515,6 +3517,12 @@ export type components = {
        */
       deleted: number;
     };
+    /**
+     * ClipVariantType
+     * @description Variant type.
+     * @enum {string}
+     */
+    ClipVariantType: "large" | "gigantic";
     /**
      * CollectInvocation
      * @description Collects values into a collection
@@ -14216,7 +14224,7 @@ export type components = {
       /**
        * CFG Scale
        * @description Classifier-Free Guidance scale
-       * @default 7
+       * @default 3.5
        */
       cfg_scale?: number | number[];
       /**
@@ -16259,6 +16267,8 @@ export type components = {
       /** Path Or Prefix */
       path_or_prefix: string;
       model_type: components["schemas"]["ModelType"];
+      /** Variant */
+      variant?: components["schemas"]["ModelRepoVariant"] | components["schemas"]["ClipVariantType"] | null;
     };
     /**
      * Subtract Integers

View File

@@ -75,21 +75,25 @@ export type AnyModelConfig =
   | MainModelConfig
   | CLIPVisionDiffusersConfig;
 
-const check_submodel_model_type = (submodels: AnyModelConfig['submodels'], model_type: string): boolean => {
+const check_submodel = (submodels: AnyModelConfig['submodels'], check_str: string): boolean => {
   for (const submodel in submodels) {
-    if (submodel && submodels[submodel] && submodels[submodel].model_type === model_type) {
+    if (
+      submodel &&
+      submodels[submodel] &&
+      (submodels[submodel].model_type === check_str || submodels[submodel].variant === check_str)
+    ) {
       return true;
     }
   }
   return false;
 };
 
-const check_submodels = (indentifier: string, config: AnyModelConfig): boolean => {
-  return (
-    (config.type === 'main' &&
+const check_submodels = (identifiers: string[], config: AnyModelConfig): boolean => {
+  return identifiers.every(
+    (identifier) =>
+      config.type === 'main' &&
       config.submodels &&
-      (indentifier in config.submodels || check_submodel_model_type(config.submodels, indentifier))) ||
-    false
+      (identifier in config.submodels || check_submodel(config.submodels, identifier))
   );
 };
@@ -98,15 +102,15 @@ export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelCo
 };
 
 export const isVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig | MainModelConfig => {
-  return config.type === 'vae' || check_submodels('vae', config);
+  return config.type === 'vae' || check_submodels(['vae'], config);
 };
 
 export const isNonFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig | MainModelConfig => {
-  return (config.type === 'vae' || check_submodels('vae', config)) && config.base !== 'flux';
+  return (config.type === 'vae' || check_submodels(['vae'], config)) && config.base !== 'flux';
 };
 
 export const isFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig | MainModelConfig => {
-  return (config.type === 'vae' || check_submodels('vae', config)) && config.base === 'flux';
+  return (config.type === 'vae' || check_submodels(['vae'], config)) && config.base === 'flux';
 };
 
 export const isControlNetModelConfig = (config: AnyModelConfig): config is ControlNetModelConfig => {
@@ -128,11 +132,23 @@ export const isT2IAdapterModelConfig = (config: AnyModelConfig): config is T2IAd
 export const isT5EncoderModelConfig = (
   config: AnyModelConfig
 ): config is T5EncoderModelConfig | T5EncoderBnbQuantizedLlmInt8bModelConfig | MainModelConfig => {
-  return config.type === 't5_encoder' || check_submodels('t5_encoder', config);
+  return config.type === 't5_encoder' || check_submodels(['t5_encoder'], config);
 };
 
 export const isCLIPEmbedModelConfig = (config: AnyModelConfig): config is CLIPEmbedModelConfig | MainModelConfig => {
-  return config.type === 'clip_embed' || check_submodels('clip_embed', config);
+  return config.type === 'clip_embed' || check_submodels(['clip_embed'], config);
 };
 
+export const isCLIPLEmbedModelConfig = (config: AnyModelConfig): config is CLIPEmbedModelConfig | MainModelConfig => {
+  return (
+    (config.type === 'clip_embed' && config.variant === 'large') || check_submodels(['clip_embed', 'large'], config)
+  );
+};
+
+export const isCLIPGEmbedModelConfig = (config: AnyModelConfig): config is CLIPEmbedModelConfig | MainModelConfig => {
+  return (
+    (config.type === 'clip_embed' && config.variant === 'gigantic') || check_submodels(['clip_embed', 'gigantic'], config)
+  );
+};
+
 export const isSpandrelImageToImageModelConfig = (