mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 06:18:03 -05:00
Compare commits
6 Commits
v5.9.1
...
brandon/fl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6d044358c8 | ||
|
|
59a2165d5e | ||
|
|
27659fb23b | ||
|
|
ba1e35b34b | ||
|
|
84dc7fbbd9 | ||
|
|
7d5f9b6664 |
@@ -95,7 +95,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
|
||||
# Apply LoRA models to the CLIP encoder.
|
||||
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
|
||||
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
|
||||
if clip_text_encoder_config.format in [ModelFormat.Diffusers, ModelFormat.Checkpoint]:
|
||||
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
|
||||
exit_stack.enter_context(
|
||||
LoRAPatcher.apply_lora_patches(
|
||||
|
||||
@@ -205,6 +205,7 @@ class CheckpointConfigBase(ModelConfigBase):
|
||||
converted_at: Optional[float] = Field(
|
||||
description="When this model was last converted to diffusers", default_factory=time.time
|
||||
)
|
||||
submodels: dict[ModelType, str] = {}
|
||||
|
||||
|
||||
class DiffusersConfigBase(ModelConfigBase):
|
||||
|
||||
@@ -7,7 +7,16 @@ from typing import Optional
|
||||
import accelerate
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForTextEncoding,
|
||||
CLIPTextConfig,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
T5Config,
|
||||
T5EncoderModel,
|
||||
T5Tokenizer,
|
||||
)
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
|
||||
@@ -45,6 +54,7 @@ from invokeai.backend.model_manager.config import (
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.util.model_util import (
|
||||
FilteredStringDict,
|
||||
convert_bundle_to_flux_transformer_checkpoint,
|
||||
)
|
||||
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
|
||||
@@ -186,12 +196,74 @@ class FluxCheckpointModel(ModelLoader):
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, CheckpointConfigBase):
|
||||
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
|
||||
|
||||
if not isinstance(config, MainCheckpointConfig):
|
||||
raise ValueError("Only MainCheckpointConfig models are currently supported here.")
|
||||
match submodel_type:
|
||||
case SubModelType.Tokenizer:
|
||||
return CLIPTokenizer.from_pretrained(
|
||||
"InvokeAI/clip-vit-large-patch14-text-encoder", subfolder="bfloat16/tokenizer"
|
||||
)
|
||||
case SubModelType.TextEncoder:
|
||||
if not (prefix := config.submodels.get(ModelType.CLIPEmbed)):
|
||||
raise ValueError(f"This model does not contain a {ModelType.T5Encoder} prefix")
|
||||
model = CLIPTextModel(
|
||||
CLIPTextConfig(
|
||||
hidden_size=768,
|
||||
intermediate_size=3072,
|
||||
projection_dim=768,
|
||||
)
|
||||
)
|
||||
sd = load_file(Path(config.path))
|
||||
encoder_keys = [
|
||||
k[len(prefix) :]
|
||||
for k in sd.keys()
|
||||
if k.startswith(prefix) and not k.endswith("text_projection.weight")
|
||||
]
|
||||
clip_sd = FilteredStringDict(sd, encoder_keys, prefix)
|
||||
model.load_state_dict(state_dict=clip_sd)
|
||||
return model
|
||||
case SubModelType.Tokenizer2:
|
||||
prefix = config.submodels.get(ModelType.T5Encoder)
|
||||
return T5Tokenizer.from_pretrained(
|
||||
"InvokeAI/t5-v1_1-xxl", subfolder="bfloat16/tokenizer_2", max_length=512
|
||||
)
|
||||
case SubModelType.TextEncoder2:
|
||||
if not (prefix := config.submodels.get(ModelType.T5Encoder)):
|
||||
raise ValueError(f"This model does not contain a {ModelType.T5Encoder} prefix")
|
||||
sd = load_file(Path(config.path))
|
||||
model = T5EncoderModel(
|
||||
T5Config(
|
||||
d_model=4096,
|
||||
d_ff=10240,
|
||||
num_layers=24,
|
||||
num_decoder_layers=24,
|
||||
num_heads=64,
|
||||
feed_forward_proj="gated-gelu",
|
||||
is_encoder_decoder=False,
|
||||
)
|
||||
)
|
||||
embeds = model.get_input_embeddings() # type: ignore
|
||||
if not isinstance(embeds, torch.nn.Embedding):
|
||||
raise ValueError("Unable to load given T5 Encoder")
|
||||
sd[f"{prefix}encoder.embed_tokens.weight"] = embeds.weight
|
||||
encoder_keys = [k[len(prefix) :] for k in sd.keys() if k.startswith(prefix)]
|
||||
t5_sd = FilteredStringDict(sd, encoder_keys, prefix)
|
||||
model.load_state_dict(state_dict=t5_sd)
|
||||
return model
|
||||
case SubModelType.Transformer:
|
||||
return self._load_from_singlefile(config)
|
||||
case SubModelType.VAE:
|
||||
model_path = Path(config.path)
|
||||
if not (prefix := config.submodels.get(ModelType.VAE)):
|
||||
raise ValueError(f"This model does not contain a {ModelType.VAE} prefix")
|
||||
with SilenceWarnings():
|
||||
model = AutoEncoder(ae_params["flux"])
|
||||
sd = load_file(model_path)
|
||||
encoder_keys = [k[len(prefix) :] for k in sd.keys() if k.startswith(prefix)]
|
||||
t5_sd = FilteredStringDict(sd, encoder_keys, prefix)
|
||||
model.load_state_dict(sd, assign=True)
|
||||
model.to(dtype=self._torch_dtype)
|
||||
return model
|
||||
|
||||
raise ValueError(
|
||||
f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
|
||||
|
||||
@@ -215,6 +215,9 @@ class ModelProbe(object):
|
||||
fields["base"] == BaseModelType.StableDiffusion2
|
||||
and fields["prediction_type"] == SchedulerPredictionType.VPrediction
|
||||
)
|
||||
get_submodels = getattr(probe, "get_submodels", None)
|
||||
if callable(get_submodels):
|
||||
fields["submodels"] = get_submodels()
|
||||
|
||||
model_info = ModelConfigFactory.make_config(fields) # , key=fields.get("key", None))
|
||||
return model_info
|
||||
@@ -443,6 +446,13 @@ class ModelProbe(object):
|
||||
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
|
||||
|
||||
|
||||
PREFIX_MAP = {
|
||||
"text_encoders.t5xxl.transformer.": ModelType.T5Encoder,
|
||||
"model.diffusion_model.": ModelType.Main,
|
||||
"text_encoders.clip_l.transformer.": ModelType.CLIPEmbed,
|
||||
"vae": ModelType.VAE,
|
||||
}
|
||||
|
||||
# Probing utilities
|
||||
MODEL_NAME_TO_PREPROCESSOR = {
|
||||
"canny": "canny_image_processor",
|
||||
@@ -492,6 +502,16 @@ class CheckpointProbeBase(ProbeBase):
|
||||
super().__init__(model_path)
|
||||
self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
|
||||
|
||||
def get_submodels(self) -> dict[ModelType, str]:
|
||||
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
|
||||
submodels: dict[ModelType, str] = {}
|
||||
submodels = {
|
||||
model_type: prefix
|
||||
for prefix, model_type in PREFIX_MAP.items()
|
||||
if any(key.startswith(prefix) for key in state_dict.keys() if isinstance(key, str))
|
||||
}
|
||||
return submodels
|
||||
|
||||
def get_format(self) -> ModelFormat:
|
||||
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
|
||||
if (
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""Utilities for parsing model files, used mostly by probe.py"""
|
||||
|
||||
import json
|
||||
from collections.abc import MutableMapping
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Dict, Generic, Iterable, Iterator, Optional, Tuple, TypeVar, Union
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
@@ -165,3 +166,65 @@ def convert_bundle_to_flux_transformer_checkpoint(
|
||||
del transformer_state_dict[k]
|
||||
|
||||
return original_state_dict
|
||||
|
||||
|
||||
_KT = TypeVar("_KT", bound=str)
|
||||
_VT = TypeVar("_VT")
|
||||
|
||||
|
||||
class FilteredStringDict(MutableMapping[_KT, _VT], Generic[_KT, _VT]):
|
||||
def __init__(self, original_dict: Dict[_KT, _VT], keys: Iterable[_KT], prefix: Optional[_KT] = None) -> None:
|
||||
self._original_dict: Dict[_KT, _VT] = original_dict
|
||||
self.prefix = prefix
|
||||
self._keys = set(keys) # Keys without the prefix
|
||||
|
||||
def _get_prefixed_key(self, key: _KT) -> _KT:
|
||||
if self.prefix:
|
||||
prefixed_key = self.prefix + key
|
||||
if not isinstance(prefixed_key, type(key)):
|
||||
raise ValueError("Unable to prefix keys")
|
||||
return prefixed_key
|
||||
return key
|
||||
|
||||
def __getitem__(self, key: _KT) -> _VT:
|
||||
if key in self._keys:
|
||||
prefixed_key = self._get_prefixed_key(key)
|
||||
return self._original_dict[prefixed_key]
|
||||
else:
|
||||
raise KeyError(key)
|
||||
|
||||
def __setitem__(self, key: _KT, value: _VT) -> None:
|
||||
if key in self._keys:
|
||||
prefixed_key = self._get_prefixed_key(key)
|
||||
self._original_dict[prefixed_key] = value
|
||||
else:
|
||||
raise KeyError(f"Key {key} not allowed in FilteredDict")
|
||||
|
||||
def __delitem__(self, key: _KT) -> None:
|
||||
if key in self._keys:
|
||||
prefixed_key = self._get_prefixed_key(key)
|
||||
del self._original_dict[prefixed_key]
|
||||
self._keys.remove(key)
|
||||
else:
|
||||
raise KeyError(key)
|
||||
|
||||
def __iter__(self) -> Iterator[_KT]:
|
||||
return iter(self._keys)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._keys)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
items = ", ".join(f"{k}: {self._original_dict[self._get_prefixed_key(k)]!r}" for k in self._keys)
|
||||
return f"{{{items}}}"
|
||||
|
||||
def keys(self) -> Iterator[_KT]:
|
||||
return iter(self._keys)
|
||||
|
||||
def values(self) -> Iterator[_VT]:
|
||||
for key in self._keys:
|
||||
yield self[key] # Access values lazily
|
||||
|
||||
def items(self) -> Iterator[Tuple[_KT, _VT]]:
|
||||
for key in self._keys:
|
||||
yield key, self[key] # Access values lazily
|
||||
|
||||
@@ -6,7 +6,7 @@ import type { CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
|
||||
import type { CLIPEmbedModelConfig } from 'services/api/types';
|
||||
import type { CheckpointModelConfig, CLIPEmbedModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
@@ -19,7 +19,7 @@ const CLIPEmbedModelFieldInputComponent = (props: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
|
||||
const _onChange = useCallback(
|
||||
(value: CLIPEmbedModelConfig | null) => {
|
||||
(value: CLIPEmbedModelConfig | CheckpointModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import type { FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate } f
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useFluxVAEModels } from 'services/api/hooks/modelsByType';
|
||||
import type { VAEModelConfig } from 'services/api/types';
|
||||
import type { CheckpointModelConfig, VAEModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
@@ -19,7 +19,7 @@ const FluxVAEModelFieldInputComponent = (props: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useFluxVAEModels();
|
||||
const _onChange = useCallback(
|
||||
(value: VAEModelConfig | null) => {
|
||||
(value: VAEModelConfig | CheckpointModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -7,7 +7,11 @@ import { selectIsModelsTabDisabled } from 'features/system/store/configSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useT5EncoderModels } from 'services/api/hooks/modelsByType';
|
||||
import type { T5EncoderBnbQuantizedLlmInt8bModelConfig, T5EncoderModelConfig } from 'services/api/types';
|
||||
import type {
|
||||
CheckpointModelConfig,
|
||||
T5EncoderBnbQuantizedLlmInt8bModelConfig,
|
||||
T5EncoderModelConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
@@ -20,7 +24,7 @@ const T5EncoderModelFieldInputComponent = (props: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useT5EncoderModels();
|
||||
const _onChange = useCallback(
|
||||
(value: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | null) => {
|
||||
(value: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | CheckpointModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useVAEModels } from 'services/api/hooks/modelsByType';
|
||||
import type { VAEModelConfig } from 'services/api/types';
|
||||
import type { CheckpointModelConfig, VAEModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
@@ -16,7 +16,7 @@ const VAEModelFieldInputComponent = (props: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useVAEModels();
|
||||
const _onChange = useCallback(
|
||||
(value: VAEModelConfig | null) => {
|
||||
(value: VAEModelConfig | CheckpointModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
|
||||
import type { CLIPEmbedModelConfig } from 'services/api/types';
|
||||
import type { CheckpointModelConfig, CLIPEmbedModelConfig } from 'services/api/types';
|
||||
|
||||
const ParamCLIPEmbedModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
@@ -15,7 +15,7 @@ const ParamCLIPEmbedModelSelect = () => {
|
||||
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(clipEmbedModel: CLIPEmbedModelConfig | null) => {
|
||||
(clipEmbedModel: CLIPEmbedModelConfig | CheckpointModelConfig | null) => {
|
||||
if (clipEmbedModel) {
|
||||
dispatch(clipEmbedModelSelected(zModelIdentifierField.parse(clipEmbedModel)));
|
||||
}
|
||||
|
||||
@@ -6,7 +6,11 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useT5EncoderModels } from 'services/api/hooks/modelsByType';
|
||||
import type { T5EncoderBnbQuantizedLlmInt8bModelConfig, T5EncoderModelConfig } from 'services/api/types';
|
||||
import type {
|
||||
CheckpointModelConfig,
|
||||
T5EncoderBnbQuantizedLlmInt8bModelConfig,
|
||||
T5EncoderModelConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
const ParamT5EncoderModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
@@ -15,7 +19,9 @@ const ParamT5EncoderModelSelect = () => {
|
||||
const [modelConfigs, { isLoading }] = useT5EncoderModels();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(t5EncoderModel: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | null) => {
|
||||
(
|
||||
t5EncoderModel: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | CheckpointModelConfig | null
|
||||
) => {
|
||||
if (t5EncoderModel) {
|
||||
dispatch(t5EncoderModelSelected(zModelIdentifierField.parse(t5EncoderModel)));
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useFluxVAEModels } from 'services/api/hooks/modelsByType';
|
||||
import type { VAEModelConfig } from 'services/api/types';
|
||||
import type { CheckpointModelConfig, VAEModelConfig } from 'services/api/types';
|
||||
|
||||
const ParamFLUXVAEModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
@@ -16,7 +16,7 @@ const ParamFLUXVAEModelSelect = () => {
|
||||
const [modelConfigs, { isLoading }] = useFluxVAEModels();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(vae: VAEModelConfig | null) => {
|
||||
(vae: VAEModelConfig | CheckpointModelConfig | null) => {
|
||||
if (vae) {
|
||||
dispatch(fluxVAESelected(zModelIdentifierField.parse(vae)));
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useVAEModels } from 'services/api/hooks/modelsByType';
|
||||
import type { VAEModelConfig } from 'services/api/types';
|
||||
import type { CheckpointModelConfig, VAEModelConfig } from 'services/api/types';
|
||||
|
||||
const ParamVAEModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
@@ -16,7 +16,7 @@ const ParamVAEModelSelect = () => {
|
||||
const vae = useAppSelector(selectVAE);
|
||||
const [modelConfigs, { isLoading }] = useVAEModels();
|
||||
const getIsDisabled = useCallback(
|
||||
(vae: VAEModelConfig): boolean => {
|
||||
(vae: VAEModelConfig | CheckpointModelConfig): boolean => {
|
||||
const isCompatible = base === vae.base;
|
||||
const hasMainModel = Boolean(base);
|
||||
return !hasMainModel || !isCompatible;
|
||||
@@ -24,7 +24,7 @@ const ParamVAEModelSelect = () => {
|
||||
[base]
|
||||
);
|
||||
const _onChange = useCallback(
|
||||
(vae: VAEModelConfig | null) => {
|
||||
(vae: VAEModelConfig | CheckpointModelConfig | null) => {
|
||||
dispatch(vaeSelected(vae ? zModelIdentifierField.parse(vae) : null));
|
||||
},
|
||||
[dispatch]
|
||||
|
||||
@@ -11061,6 +11061,13 @@ export type components = {
|
||||
* @default false
|
||||
*/
|
||||
upcast_attention?: boolean;
|
||||
/**
|
||||
* Submodels
|
||||
* @default {}
|
||||
*/
|
||||
submodels?: {
|
||||
[key: string]: string;
|
||||
};
|
||||
};
|
||||
/**
|
||||
* MainDiffusersConfig
|
||||
|
||||
@@ -75,20 +75,27 @@ export type AnyModelConfig =
|
||||
| MainModelConfig
|
||||
| CLIPVisionDiffusersConfig;
|
||||
|
||||
const check_submodels = (model_type: string, config: AnyModelConfig): boolean => {
|
||||
return (
|
||||
(config.format === 'checkpoint' && config.type === 'main' && config?.submodels && model_type in config.submodels) ||
|
||||
false
|
||||
);
|
||||
};
|
||||
|
||||
export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelConfig => {
|
||||
return config.type === 'lora';
|
||||
};
|
||||
|
||||
export const isVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => {
|
||||
return config.type === 'vae';
|
||||
export const isVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig | CheckpointModelConfig => {
|
||||
return config.type === 'vae' || check_submodels('vae', config);
|
||||
};
|
||||
|
||||
export const isNonFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => {
|
||||
return config.type === 'vae' && config.base !== 'flux';
|
||||
export const isNonFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig | CheckpointModelConfig => {
|
||||
return (config.type === 'vae' || check_submodels('vae', config)) && config.base !== 'flux';
|
||||
};
|
||||
|
||||
export const isFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => {
|
||||
return config.type === 'vae' && config.base === 'flux';
|
||||
export const isFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig | CheckpointModelConfig => {
|
||||
return (config.type === 'vae' || check_submodels('vae', config)) && config.base === 'flux';
|
||||
};
|
||||
|
||||
export const isControlNetModelConfig = (config: AnyModelConfig): config is ControlNetModelConfig => {
|
||||
@@ -109,12 +116,14 @@ export const isT2IAdapterModelConfig = (config: AnyModelConfig): config is T2IAd
|
||||
|
||||
export const isT5EncoderModelConfig = (
|
||||
config: AnyModelConfig
|
||||
): config is T5EncoderModelConfig | T5EncoderBnbQuantizedLlmInt8bModelConfig => {
|
||||
return config.type === 't5_encoder';
|
||||
): config is T5EncoderModelConfig | T5EncoderBnbQuantizedLlmInt8bModelConfig | CheckpointModelConfig => {
|
||||
return config.type === 't5_encoder' || check_submodels('t5_encoder', config);
|
||||
};
|
||||
|
||||
export const isCLIPEmbedModelConfig = (config: AnyModelConfig): config is CLIPEmbedModelConfig => {
|
||||
return config.type === 'clip_embed';
|
||||
export const isCLIPEmbedModelConfig = (
|
||||
config: AnyModelConfig
|
||||
): config is CLIPEmbedModelConfig | CheckpointModelConfig => {
|
||||
return config.type === 'clip_embed' || check_submodels('clip_embed', config);
|
||||
};
|
||||
|
||||
export const isSpandrelImageToImageModelConfig = (
|
||||
|
||||
Reference in New Issue
Block a user