Create clip variant type, create new functions for discerning clipL and clipG in the frontend

Brandon Rising
2024-10-30 11:10:08 -04:00
committed by Brandon
parent 1eca4f12c8
commit b87f4e59a5
5 changed files with 84 additions and 16 deletions

View File

@@ -95,6 +95,13 @@ class SubModelType(str, Enum):
SafetyChecker = "safety_checker" SafetyChecker = "safety_checker"
class ClipVariantType(str, Enum):
"""Variant type."""
L = "large"
G = "gigantic"
class ModelVariantType(str, Enum): class ModelVariantType(str, Enum):
"""Variant type.""" """Variant type."""
@@ -150,9 +157,13 @@ class ModelSourceType(str, Enum):
 DEFAULTS_PRECISION = Literal["fp16", "fp32"]

+AnyVariant: TypeAlias = Union[ModelRepoVariant, ClipVariantType, None]

 class SubmodelDefinition(BaseModel):
     path_or_prefix: str
     model_type: ModelType
+    variant: AnyVariant = None

 class MainModelDefaultSettings(BaseModel):
@@ -430,6 +441,7 @@ class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
     type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
     format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
+    variant: ClipVariantType = ClipVariantType.L

     @staticmethod
     def get_tag() -> Tag:
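
As a quick orientation, here is a minimal usage sketch (assuming an editable InvokeAI install; the paths are made up) of what the new fields allow: a SubmodelDefinition can now be tagged with a ClipVariantType, and the field defaults to None when omitted.

from invokeai.backend.model_manager.config import (
    ClipVariantType,
    ModelType,
    SubmodelDefinition,
)

# SDXL's second text encoder is CLIP-G; its submodel entry can now say so.
text_encoder_2 = SubmodelDefinition(
    path_or_prefix="/models/sdxl/text_encoder_2",  # illustrative path
    model_type=ModelType.CLIPEmbed,
    variant=ClipVariantType.G,
)

# Non-CLIP submodels simply leave the new field at its None default.
vae = SubmodelDefinition(path_or_prefix="/models/sdxl/vae", model_type=ModelType.VAE)
assert vae.variant is None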

View File

@@ -1,7 +1,7 @@
 import json
 import re
 from pathlib import Path
-from typing import Any, Dict, Literal, Optional, Union
+from typing import Any, Callable, Dict, Literal, Optional, Union

 import safetensors.torch
 import spandrel
@@ -22,6 +22,7 @@ from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import i
 from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
 from invokeai.backend.model_manager.config import (
     AnyModelConfig,
+    AnyVariant,
     BaseModelType,
     ControlAdapterDefaultSettings,
     InvalidModelConfigException,
@@ -37,7 +38,11 @@ from invokeai.backend.model_manager.config import (
     SubModelType,
 )
 from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import ConfigLoader
-from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
+from invokeai.backend.model_manager.util.model_util import (
+    get_clip_variant_type,
+    lora_token_vector_length,
+    read_checkpoint_meta,
+)
 from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
 from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
@@ -130,6 +135,8 @@ class ModelProbe(object):
"CLIPTextModelWithProjection": ModelType.CLIPEmbed, "CLIPTextModelWithProjection": ModelType.CLIPEmbed,
} }
TYPE2VARIANT: Dict[ModelType, Callable[[str], Optional[AnyVariant]]] = {ModelType.CLIPEmbed: get_clip_variant_type}
@classmethod @classmethod
def register_probe( def register_probe(
cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: type[ProbeBase] cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: type[ProbeBase]
@@ -176,7 +183,10 @@ class ModelProbe(object):
fields["path"] = model_path.as_posix() fields["path"] = model_path.as_posix()
fields["type"] = fields.get("type") or model_type fields["type"] = fields.get("type") or model_type
fields["base"] = fields.get("base") or probe.get_base_type() fields["base"] = fields.get("base") or probe.get_base_type()
fields["variant"] = fields.get("variant") or probe.get_variant_type() variant_func = cls.TYPE2VARIANT.get(fields["type"], None)
fields["variant"] = (
fields.get("variant") or (variant_func and variant_func(model_path.as_posix())) or probe.get_variant_type()
)
fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type() fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type()
fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id() fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
fields["name"] = fields.get("name") or cls.get_model_name(model_path) fields["name"] = fields.get("name") or cls.get_model_name(model_path)
@@ -803,9 +813,11 @@ class PipelineFolderProbe(FolderProbeBase):
                 continue
             model_loader = str(value[1])
             if model_type := ModelProbe.CLASS2TYPE.get(model_loader):
+                variant_func = ModelProbe.TYPE2VARIANT.get(model_type, None)
                 submodels[SubModelType(key)] = SubmodelDefinition(
                     path_or_prefix=(self.model_path / key).resolve().as_posix(),
                     model_type=model_type,
+                    variant=variant_func and variant_func((self.model_path / key).as_posix()),
                 )
         return submodels
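
To make the probe change concrete: "variant_func and variant_func(path)" is falsy both when no detector is registered for the model type and when the detector cannot decide, so probe.get_variant_type() remains the fallback for everything else. A standalone sketch of that chain (simplified stand-in names, not the production code path):

from typing import Callable, Dict, Optional

# Stand-in for ModelProbe.TYPE2VARIANT, keyed here by the model-type string.
TYPE2VARIANT: Dict[str, Callable[[str], Optional[str]]] = {
    "clip_embed": lambda path: "large",  # plays the role of get_clip_variant_type
}

def resolve_variant(model_type: str, model_path: str, probed_variant: str) -> str:
    variant_func = TYPE2VARIANT.get(model_type, None)
    # Falsy short-circuit: no registered detector, or a detector that
    # returns None, both fall through to the generic probe result.
    return (variant_func and variant_func(model_path)) or probed_variant

assert resolve_variant("clip_embed", "/any/path", "normal") == "large"
assert resolve_variant("vae", "/any/path", "normal") == "normal"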

View File

@@ -8,6 +8,7 @@ import safetensors
 import torch
 from picklescan.scanner import scan_file_path

+from invokeai.backend.model_manager.config import ClipVariantType
 from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
@@ -165,3 +166,20 @@ def convert_bundle_to_flux_transformer_checkpoint(
         del transformer_state_dict[k]

     return original_state_dict
+
+
+def get_clip_variant_type(location: str) -> Optional[ClipVariantType]:
+    path = Path(location)
+    config_path = path / "config.json"
+    if not config_path.exists():
+        return None
+    with open(config_path) as file:
+        clip_conf = json.load(file)
+    hidden_size = clip_conf.get("hidden_size", -1)
+    match hidden_size:
+        case 1280:
+            return ClipVariantType.G
+        case 768:
+            return ClipVariantType.L
+        case _:
+            return None

View File

@@ -2711,6 +2711,8 @@ export type components = {
       * @enum {string}
       */
      type: "clip_embed";
+      /** @default large */
+      variant?: components["schemas"]["ClipVariantType"];
    };
    /** CLIPField */
    CLIPField: {
/** CLIPField */ /** CLIPField */
CLIPField: { CLIPField: {
@@ -3515,6 +3517,12 @@ export type components = {
       */
      deleted: number;
    };
+    /**
+     * ClipVariantType
+     * @description Variant type.
+     * @enum {string}
+     */
+    ClipVariantType: "large" | "gigantic";
    /**
     * CollectInvocation
     * @description Collects values into a collection
@@ -14216,7 +14224,7 @@ export type components = {
    /**
     * CFG Scale
     * @description Classifier-Free Guidance scale
-     * @default 7
+     * @default 3.5
     */
    cfg_scale?: number | number[];
    /**
@@ -16259,6 +16267,8 @@ export type components = {
      /** Path Or Prefix */
      path_or_prefix: string;
      model_type: components["schemas"]["ModelType"];
+      /** Variant */
+      variant?: components["schemas"]["ModelRepoVariant"] | components["schemas"]["ClipVariantType"] | null;
    };
    /**
     * Subtract Integers

View File

@@ -75,21 +75,25 @@ export type AnyModelConfig =
   | MainModelConfig
   | CLIPVisionDiffusersConfig;

-const check_submodel_model_type = (submodels: AnyModelConfig['submodels'], model_type: string): boolean => {
+const check_submodel = (submodels: AnyModelConfig['submodels'], check_str: string): boolean => {
   for (const submodel in submodels) {
-    if (submodel && submodels[submodel] && submodels[submodel].model_type === model_type) {
+    if (
+      submodel &&
+      submodels[submodel] &&
+      (submodels[submodel].model_type === check_str || submodels[submodel].variant === check_str)
+    ) {
       return true;
     }
   }
   return false;
 };

-const check_submodels = (identifier: string, config: AnyModelConfig): boolean => {
-  return (
-    (config.type === 'main' &&
-      config.submodels &&
-      (identifier in config.submodels || check_submodel_model_type(config.submodels, identifier))) ||
-    false
-  );
+const check_submodels = (identifiers: string[], config: AnyModelConfig): boolean => {
+  return identifiers.every(
+    (identifier) =>
+      config.type === 'main' &&
+      config.submodels &&
+      (identifier in config.submodels || check_submodel(config.submodels, identifier))
+  );
 };
@@ -98,15 +102,15 @@ export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelCo
 };

 export const isVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig | MainModelConfig => {
-  return config.type === 'vae' || check_submodels('vae', config);
+  return config.type === 'vae' || check_submodels(['vae'], config);
 };

 export const isNonFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig | MainModelConfig => {
-  return (config.type === 'vae' || check_submodels('vae', config)) && config.base !== 'flux';
+  return (config.type === 'vae' || check_submodels(['vae'], config)) && config.base !== 'flux';
 };

 export const isFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig | MainModelConfig => {
-  return (config.type === 'vae' || check_submodels('vae', config)) && config.base === 'flux';
+  return (config.type === 'vae' || check_submodels(['vae'], config)) && config.base === 'flux';
 };

 export const isControlNetModelConfig = (config: AnyModelConfig): config is ControlNetModelConfig => {
@@ -128,11 +132,23 @@ export const isT2IAdapterModelConfig = (config: AnyModelConfig): config is T2IAd
 export const isT5EncoderModelConfig = (
   config: AnyModelConfig
 ): config is T5EncoderModelConfig | T5EncoderBnbQuantizedLlmInt8bModelConfig | MainModelConfig => {
-  return config.type === 't5_encoder' || check_submodels('t5_encoder', config);
+  return config.type === 't5_encoder' || check_submodels(['t5_encoder'], config);
 };

 export const isCLIPEmbedModelConfig = (config: AnyModelConfig): config is CLIPEmbedModelConfig | MainModelConfig => {
-  return config.type === 'clip_embed' || check_submodels('clip_embed', config);
+  return config.type === 'clip_embed' || check_submodels(['clip_embed'], config);
+};
+
+export const isCLIPLEmbedModelConfig = (config: AnyModelConfig): config is CLIPEmbedModelConfig | MainModelConfig => {
+  return (
+    (config.type === 'clip_embed' && config.variant === 'large') || check_submodels(['clip_embed', 'large'], config)
+  );
+};
+
+export const isCLIPGEmbedModelConfig = (config: AnyModelConfig): config is CLIPEmbedModelConfig | MainModelConfig => {
+  return (
+    (config.type === 'clip_embed' && config.variant === 'gigantic') || check_submodels(['clip_embed', 'gigantic'], config)
+  );
 };

 export const isSpandrelImageToImageModelConfig = (