Compare commits


6 Commits

Author | SHA1 | Message | Date
Brandon Rising | 6d044358c8 | Run ruff | 2024-10-22 14:55:00 -04:00
Brandon Rising | 59a2165d5e | Remove submodel loads from gguf models, move submodel property to checkpoint config | 2024-10-22 14:54:07 -04:00
Brandon Rising | 27659fb23b | Fix typing in frontend | 2024-10-22 14:25:43 -04:00
Brandon Rising | ba1e35b34b | Setup loading clip and vae models from bundled checkpoints | 2024-10-22 14:25:43 -04:00
Brandon Rising | 84dc7fbbd9 | Update way t5 encoders and tokenizers are loaded on checkpoint bundles | 2024-10-22 14:25:43 -04:00
Brandon Rising | 7d5f9b6664 | Initial test at bundled model installs for flux | 2024-10-22 14:25:43 -04:00
15 changed files with 215 additions and 33 deletions

View File

@@ -95,7 +95,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
         # Apply LoRA models to the CLIP encoder.
         # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
-        if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
+        if clip_text_encoder_config.format in [ModelFormat.Diffusers, ModelFormat.Checkpoint]:
             # The model is non-quantized, so we can apply the LoRA weights directly into the model.
             exit_stack.enter_context(
                 LoRAPatcher.apply_lora_patches(
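
For orientation: the widened check above decides whether LoRA weights can be merged straight into the model tensors. A minimal sketch of that dispatch, with hypothetical names (the real call goes through LoRAPatcher.apply_lora_patches as shown):

NON_QUANTIZED_FORMATS = {"diffusers", "checkpoint"}

def choose_patch_strategy(model_format: str) -> str:
    # Sketch only: plain-tensor formats can take LoRA deltas in place;
    # quantized formats (e.g. bnb, GGUF) need a sidecar/wrapper path.
    if model_format in NON_QUANTIZED_FORMATS:
        return "direct"
    return "sidecar"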

View File

@@ -205,6 +205,7 @@ class CheckpointConfigBase(ModelConfigBase):
     converted_at: Optional[float] = Field(
         description="When this model was last converted to diffusers", default_factory=time.time
     )
+    submodels: dict[ModelType, str] = {}
 
 
 class DiffusersConfigBase(ModelConfigBase):
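
For a bundled FLUX checkpoint, the new submodels field maps each detected submodel type to the key prefix it occupies inside the single state dict. An illustrative sketch (string keys stand in for the ModelType enum members; the prefixes mirror PREFIX_MAP introduced in the probe changes below):

submodels = {
    "main": "model.diffusion_model.",
    "t5_encoder": "text_encoders.t5xxl.transformer.",
    "clip_embed": "text_encoders.clip_l.transformer.",
    "vae": "vae.",
}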

View File

@@ -7,7 +7,16 @@ from typing import Optional
import accelerate
import torch
from safetensors.torch import load_file
-from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
+from transformers import (
+    AutoConfig,
+    AutoModelForTextEncoding,
+    CLIPTextConfig,
+    CLIPTextModel,
+    CLIPTokenizer,
+    T5Config,
+    T5EncoderModel,
+    T5Tokenizer,
+)
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
@@ -45,6 +54,7 @@ from invokeai.backend.model_manager.config import (
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.util.model_util import (
+    FilteredStringDict,
     convert_bundle_to_flux_transformer_checkpoint,
 )
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
@@ -186,12 +196,74 @@ class FluxCheckpointModel(ModelLoader):
         config: AnyModelConfig,
         submodel_type: Optional[SubModelType] = None,
     ) -> AnyModel:
-        if not isinstance(config, CheckpointConfigBase):
-            raise ValueError("Only CheckpointConfigBase models are currently supported here.")
+        if not isinstance(config, MainCheckpointConfig):
+            raise ValueError("Only MainCheckpointConfig models are currently supported here.")
         match submodel_type:
+            case SubModelType.Tokenizer:
+                return CLIPTokenizer.from_pretrained(
+                    "InvokeAI/clip-vit-large-patch14-text-encoder", subfolder="bfloat16/tokenizer"
+                )
+            case SubModelType.TextEncoder:
+                if not (prefix := config.submodels.get(ModelType.CLIPEmbed)):
+                    raise ValueError(f"This model does not contain a {ModelType.CLIPEmbed} prefix")
+                model = CLIPTextModel(
+                    CLIPTextConfig(
+                        hidden_size=768,
+                        intermediate_size=3072,
+                        projection_dim=768,
+                    )
+                )
+                sd = load_file(Path(config.path))
+                encoder_keys = [
+                    k[len(prefix) :]
+                    for k in sd.keys()
+                    if k.startswith(prefix) and not k.endswith("text_projection.weight")
+                ]
+                clip_sd = FilteredStringDict(sd, encoder_keys, prefix)
+                model.load_state_dict(state_dict=clip_sd)
+                return model
+            case SubModelType.Tokenizer2:
+                return T5Tokenizer.from_pretrained(
+                    "InvokeAI/t5-v1_1-xxl", subfolder="bfloat16/tokenizer_2", max_length=512
+                )
+            case SubModelType.TextEncoder2:
+                if not (prefix := config.submodels.get(ModelType.T5Encoder)):
+                    raise ValueError(f"This model does not contain a {ModelType.T5Encoder} prefix")
+                sd = load_file(Path(config.path))
+                model = T5EncoderModel(
+                    T5Config(
+                        d_model=4096,
+                        d_ff=10240,
+                        num_layers=24,
+                        num_decoder_layers=24,
+                        num_heads=64,
+                        feed_forward_proj="gated-gelu",
+                        is_encoder_decoder=False,
+                    )
+                )
+                embeds = model.get_input_embeddings()  # type: ignore
+                if not isinstance(embeds, torch.nn.Embedding):
+                    raise ValueError("Unable to load given T5 Encoder")
+                sd[f"{prefix}encoder.embed_tokens.weight"] = embeds.weight
+                encoder_keys = [k[len(prefix) :] for k in sd.keys() if k.startswith(prefix)]
+                t5_sd = FilteredStringDict(sd, encoder_keys, prefix)
+                model.load_state_dict(state_dict=t5_sd)
+                return model
             case SubModelType.Transformer:
                 return self._load_from_singlefile(config)
+            case SubModelType.VAE:
+                model_path = Path(config.path)
+                if not (prefix := config.submodels.get(ModelType.VAE)):
+                    raise ValueError(f"This model does not contain a {ModelType.VAE} prefix")
+                with SilenceWarnings():
+                    model = AutoEncoder(ae_params["flux"])
+                    sd = load_file(model_path)
+                    encoder_keys = [k[len(prefix) :] for k in sd.keys() if k.startswith(prefix)]
+                    vae_sd = FilteredStringDict(sd, encoder_keys, prefix)
+                    model.load_state_dict(vae_sd, assign=True)
+                    model.to(dtype=self._torch_dtype)
+                return model
 
         raise ValueError(
             f"Unsupported submodel type: {submodel_type.value if submodel_type else 'None'}"
         )

View File

@@ -215,6 +215,9 @@ class ModelProbe(object):
fields["base"] == BaseModelType.StableDiffusion2
and fields["prediction_type"] == SchedulerPredictionType.VPrediction
)
get_submodels = getattr(probe, "get_submodels", None)
if callable(get_submodels):
fields["submodels"] = get_submodels()
model_info = ModelConfigFactory.make_config(fields) # , key=fields.get("key", None))
return model_info
@@ -443,6 +446,13 @@ class ModelProbe(object):
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
PREFIX_MAP = {
"text_encoders.t5xxl.transformer.": ModelType.T5Encoder,
"model.diffusion_model.": ModelType.Main,
"text_encoders.clip_l.transformer.": ModelType.CLIPEmbed,
"vae": ModelType.VAE,
}
# Probing utilities
MODEL_NAME_TO_PREPROCESSOR = {
"canny": "canny_image_processor",
@@ -492,6 +502,16 @@ class CheckpointProbeBase(ProbeBase):
         super().__init__(model_path)
         self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
 
+    def get_submodels(self) -> dict[ModelType, str]:
+        state_dict = self.checkpoint.get("state_dict") or self.checkpoint
+        submodels: dict[ModelType, str] = {
+            model_type: prefix
+            for prefix, model_type in PREFIX_MAP.items()
+            if any(key.startswith(prefix) for key in state_dict.keys() if isinstance(key, str))
+        }
+        return submodels
+
     def get_format(self) -> ModelFormat:
         state_dict = self.checkpoint.get("state_dict") or self.checkpoint
         if (
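
To make the probe behavior concrete, a toy run of the detection logic above, with string stand-ins for the ModelType members and abbreviated keys:

toy_sd = {
    "model.diffusion_model.img_in.weight": None,
    "vae.decoder.conv_in.weight": None,
}
prefix_map = {
    "model.diffusion_model.": "main",
    "vae.": "vae",
}
found = {model_type: prefix for prefix, model_type in prefix_map.items() if any(k.startswith(prefix) for k in toy_sd)}
assert found == {"main": "model.diffusion_model.", "vae": "vae."}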

View File

@@ -1,8 +1,9 @@
"""Utilities for parsing model files, used mostly by probe.py"""
import json
from collections.abc import MutableMapping
from pathlib import Path
from typing import Dict, Optional, Union
from typing import Dict, Generic, Iterable, Iterator, Optional, Tuple, TypeVar, Union
import safetensors
import torch
@@ -165,3 +166,65 @@ def convert_bundle_to_flux_transformer_checkpoint(
             del transformer_state_dict[k]
     return original_state_dict
 
 
+_KT = TypeVar("_KT", bound=str)
+_VT = TypeVar("_VT")
+
+
+class FilteredStringDict(MutableMapping[_KT, _VT], Generic[_KT, _VT]):
+    def __init__(self, original_dict: Dict[_KT, _VT], keys: Iterable[_KT], prefix: Optional[_KT] = None) -> None:
+        self._original_dict: Dict[_KT, _VT] = original_dict
+        self.prefix = prefix
+        self._keys = set(keys)  # Keys without the prefix
+
+    def _get_prefixed_key(self, key: _KT) -> _KT:
+        if self.prefix:
+            prefixed_key = self.prefix + key
+            if not isinstance(prefixed_key, type(key)):
+                raise ValueError("Unable to prefix keys")
+            return prefixed_key
+        return key
+
+    def __getitem__(self, key: _KT) -> _VT:
+        if key in self._keys:
+            prefixed_key = self._get_prefixed_key(key)
+            return self._original_dict[prefixed_key]
+        else:
+            raise KeyError(key)
+
+    def __setitem__(self, key: _KT, value: _VT) -> None:
+        if key in self._keys:
+            prefixed_key = self._get_prefixed_key(key)
+            self._original_dict[prefixed_key] = value
+        else:
+            raise KeyError(f"Key {key} not allowed in FilteredStringDict")
+
+    def __delitem__(self, key: _KT) -> None:
+        if key in self._keys:
+            prefixed_key = self._get_prefixed_key(key)
+            del self._original_dict[prefixed_key]
+            self._keys.remove(key)
+        else:
+            raise KeyError(key)
+
+    def __iter__(self) -> Iterator[_KT]:
+        return iter(self._keys)
+
+    def __len__(self) -> int:
+        return len(self._keys)
+
+    def __repr__(self) -> str:
+        items = ", ".join(f"{k}: {self._original_dict[self._get_prefixed_key(k)]!r}" for k in self._keys)
+        return f"{{{items}}}"
+
+    def keys(self) -> Iterator[_KT]:  # type: ignore[override]
+        return iter(self._keys)
+
+    def values(self) -> Iterator[_VT]:  # type: ignore[override]
+        for key in self._keys:
+            yield self[key]  # Access values lazily
+
+    def items(self) -> Iterator[Tuple[_KT, _VT]]:  # type: ignore[override]
+        for key in self._keys:
+            yield key, self[key]  # Access values lazily
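
Assuming the class behaves as defined above, a short usage sketch: the view exposes only the allowed keys, with the prefix stripped, and reads and writes pass through to the original dict:

sd = {"vae.encoder.weight": 1.0, "vae.decoder.weight": 2.0, "other.key": 3.0}
view = FilteredStringDict(sd, keys=["encoder.weight", "decoder.weight"], prefix="vae.")

assert view["encoder.weight"] == 1.0  # reads through to sd["vae.encoder.weight"]
assert len(view) == 2  # "other.key" is not exposed
view["decoder.weight"] = 4.0  # writes back into the original dict
assert sd["vae.decoder.weight"] == 4.0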

View File

@@ -6,7 +6,7 @@ import type { CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
-import type { CLIPEmbedModelConfig } from 'services/api/types';
+import type { CheckpointModelConfig, CLIPEmbedModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@@ -19,7 +19,7 @@ const CLIPEmbedModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
const _onChange = useCallback(
-(value: CLIPEmbedModelConfig | null) => {
+(value: CLIPEmbedModelConfig | CheckpointModelConfig | null) => {
if (!value) {
return;
}

View File

@@ -6,7 +6,7 @@ import type { FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate } f
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useFluxVAEModels } from 'services/api/hooks/modelsByType';
-import type { VAEModelConfig } from 'services/api/types';
+import type { CheckpointModelConfig, VAEModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@@ -19,7 +19,7 @@ const FluxVAEModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useFluxVAEModels();
const _onChange = useCallback(
-(value: VAEModelConfig | null) => {
+(value: VAEModelConfig | CheckpointModelConfig | null) => {
if (!value) {
return;
}

View File

@@ -7,7 +7,11 @@ import { selectIsModelsTabDisabled } from 'features/system/store/configSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useT5EncoderModels } from 'services/api/hooks/modelsByType';
-import type { T5EncoderBnbQuantizedLlmInt8bModelConfig, T5EncoderModelConfig } from 'services/api/types';
+import type {
+  CheckpointModelConfig,
+  T5EncoderBnbQuantizedLlmInt8bModelConfig,
+  T5EncoderModelConfig,
+} from 'services/api/types';
import type { FieldComponentProps } from './types';
@@ -20,7 +24,7 @@ const T5EncoderModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useT5EncoderModels();
const _onChange = useCallback(
-(value: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | null) => {
+(value: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | CheckpointModelConfig | null) => {
if (!value) {
return;
}

View File

@@ -5,7 +5,7 @@ import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useVAEModels } from 'services/api/hooks/modelsByType';
-import type { VAEModelConfig } from 'services/api/types';
+import type { CheckpointModelConfig, VAEModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@@ -16,7 +16,7 @@ const VAEModelFieldInputComponent = (props: Props) => {
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useVAEModels();
const _onChange = useCallback(
-(value: VAEModelConfig | null) => {
+(value: VAEModelConfig | CheckpointModelConfig | null) => {
if (!value) {
return;
}

View File

@@ -6,7 +6,7 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
-import type { CLIPEmbedModelConfig } from 'services/api/types';
+import type { CheckpointModelConfig, CLIPEmbedModelConfig } from 'services/api/types';
const ParamCLIPEmbedModelSelect = () => {
const dispatch = useAppDispatch();
@@ -15,7 +15,7 @@ const ParamCLIPEmbedModelSelect = () => {
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
const _onChange = useCallback(
-(clipEmbedModel: CLIPEmbedModelConfig | null) => {
+(clipEmbedModel: CLIPEmbedModelConfig | CheckpointModelConfig | null) => {
if (clipEmbedModel) {
dispatch(clipEmbedModelSelected(zModelIdentifierField.parse(clipEmbedModel)));
}

View File

@@ -6,7 +6,11 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useT5EncoderModels } from 'services/api/hooks/modelsByType';
-import type { T5EncoderBnbQuantizedLlmInt8bModelConfig, T5EncoderModelConfig } from 'services/api/types';
+import type {
+  CheckpointModelConfig,
+  T5EncoderBnbQuantizedLlmInt8bModelConfig,
+  T5EncoderModelConfig,
+} from 'services/api/types';
const ParamT5EncoderModelSelect = () => {
const dispatch = useAppDispatch();
@@ -15,7 +19,9 @@ const ParamT5EncoderModelSelect = () => {
const [modelConfigs, { isLoading }] = useT5EncoderModels();
const _onChange = useCallback(
-(t5EncoderModel: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | null) => {
+(
+  t5EncoderModel: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | CheckpointModelConfig | null
+) => {
if (t5EncoderModel) {
dispatch(t5EncoderModelSelected(zModelIdentifierField.parse(t5EncoderModel)));
}

View File

@@ -7,7 +7,7 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useFluxVAEModels } from 'services/api/hooks/modelsByType';
-import type { VAEModelConfig } from 'services/api/types';
+import type { CheckpointModelConfig, VAEModelConfig } from 'services/api/types';
const ParamFLUXVAEModelSelect = () => {
const dispatch = useAppDispatch();
@@ -16,7 +16,7 @@ const ParamFLUXVAEModelSelect = () => {
const [modelConfigs, { isLoading }] = useFluxVAEModels();
const _onChange = useCallback(
-(vae: VAEModelConfig | null) => {
+(vae: VAEModelConfig | CheckpointModelConfig | null) => {
if (vae) {
dispatch(fluxVAESelected(zModelIdentifierField.parse(vae)));
}

View File

@@ -7,7 +7,7 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useVAEModels } from 'services/api/hooks/modelsByType';
-import type { VAEModelConfig } from 'services/api/types';
+import type { CheckpointModelConfig, VAEModelConfig } from 'services/api/types';
const ParamVAEModelSelect = () => {
const dispatch = useAppDispatch();
@@ -16,7 +16,7 @@ const ParamVAEModelSelect = () => {
const vae = useAppSelector(selectVAE);
const [modelConfigs, { isLoading }] = useVAEModels();
const getIsDisabled = useCallback(
-(vae: VAEModelConfig): boolean => {
+(vae: VAEModelConfig | CheckpointModelConfig): boolean => {
const isCompatible = base === vae.base;
const hasMainModel = Boolean(base);
return !hasMainModel || !isCompatible;
@@ -24,7 +24,7 @@ const ParamVAEModelSelect = () => {
[base]
);
const _onChange = useCallback(
-(vae: VAEModelConfig | null) => {
+(vae: VAEModelConfig | CheckpointModelConfig | null) => {
dispatch(vaeSelected(vae ? zModelIdentifierField.parse(vae) : null));
},
[dispatch]

View File

@@ -11061,6 +11061,13 @@ export type components = {
* @default false
*/
upcast_attention?: boolean;
+/**
+ * Submodels
+ * @default {}
+ */
+submodels?: {
+  [key: string]: string;
+};
};
/**
* MainDiffusersConfig

View File

@@ -75,20 +75,27 @@ export type AnyModelConfig =
| MainModelConfig
| CLIPVisionDiffusersConfig;
+const check_submodels = (model_type: string, config: AnyModelConfig): boolean => {
+  return (
+    config.format === 'checkpoint' && config.type === 'main' && !!config.submodels && model_type in config.submodels
+  );
+};
export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelConfig => {
return config.type === 'lora';
};
-export const isVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => {
-  return config.type === 'vae';
+export const isVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig | CheckpointModelConfig => {
+  return config.type === 'vae' || check_submodels('vae', config);
 };
-export const isNonFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => {
-  return config.type === 'vae' && config.base !== 'flux';
+export const isNonFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig | CheckpointModelConfig => {
+  return (config.type === 'vae' || check_submodels('vae', config)) && config.base !== 'flux';
 };
-export const isFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => {
-  return config.type === 'vae' && config.base === 'flux';
+export const isFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig | CheckpointModelConfig => {
+  return (config.type === 'vae' || check_submodels('vae', config)) && config.base === 'flux';
 };
export const isControlNetModelConfig = (config: AnyModelConfig): config is ControlNetModelConfig => {
@@ -109,12 +116,14 @@ export const isT2IAdapterModelConfig = (config: AnyModelConfig): config is T2IAd
export const isT5EncoderModelConfig = (
config: AnyModelConfig
-): config is T5EncoderModelConfig | T5EncoderBnbQuantizedLlmInt8bModelConfig => {
-  return config.type === 't5_encoder';
+): config is T5EncoderModelConfig | T5EncoderBnbQuantizedLlmInt8bModelConfig | CheckpointModelConfig => {
+  return config.type === 't5_encoder' || check_submodels('t5_encoder', config);
};
-export const isCLIPEmbedModelConfig = (config: AnyModelConfig): config is CLIPEmbedModelConfig => {
-  return config.type === 'clip_embed';
+export const isCLIPEmbedModelConfig = (
+  config: AnyModelConfig
+): config is CLIPEmbedModelConfig | CheckpointModelConfig => {
+  return config.type === 'clip_embed' || check_submodels('clip_embed', config);
};
export const isSpandrelImageToImageModelConfig = (