Mirror of https://github.com/invoke-ai/InvokeAI.git
Compare commits: v5.9.1 ... lstein/fea (7 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 434b19ac4a | |
| | bc0ab06ccc | |
| | 081c590335 | |
| | 4013213578 | |
| | 7391e23dc8 | |
| | 4d95129c27 | |
| | ef18ecd788 | |
@@ -0,0 +1,25 @@
+{
+  "_name_or_path": "openai/clip-vit-large-patch14",
+  "architectures": [
+    "CLIPTextModel"
+  ],
+  "attention_dropout": 0.0,
+  "bos_token_id": 0,
+  "dropout": 0.0,
+  "eos_token_id": 2,
+  "hidden_act": "quick_gelu",
+  "hidden_size": 768,
+  "initializer_factor": 1.0,
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "layer_norm_eps": 1e-05,
+  "max_position_embeddings": 77,
+  "model_type": "clip_text_model",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
+  "pad_token_id": 1,
+  "projection_dim": 768,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.43.3",
+  "vocab_size": 49408
+}
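This is the stock `transformers` text-encoder configuration for `openai/clip-vit-large-patch14`. As a minimal sketch (the local file name `clip_text_config.json` is hypothetical), such a file can be turned into an empty `CLIPTextModel` skeleton whose weights are later filled from a checkpoint, which is what the loader added further down does:

```python
from transformers import CLIPTextConfig, CLIPTextModel

# Build the architecture from the JSON config alone; parameters stay randomly
# initialized until a state dict is loaded into the module.
config = CLIPTextConfig.from_json_file("clip_text_config.json")  # hypothetical local copy of the file above
model = CLIPTextModel(config)
print(model.config.hidden_size, model.config.num_hidden_layers)  # 768 12
```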
@@ -406,6 +406,17 @@ class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
         return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}")


+class CLIPEmbedCheckpointConfig(CheckpointConfigBase):
+    """Model config for CLIP Embedding checkpoints."""
+
+    type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
+    format: Literal[ModelFormat.Checkpoint]
+
+    @staticmethod
+    def get_tag() -> Tag:
+        return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Checkpoint.value}")
+
+
 class CLIPVisionDiffusersConfig(DiffusersConfigBase):
     """Model config for CLIPVision."""

@@ -481,6 +492,7 @@ AnyModelConfig = Annotated[
         Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
         Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
         Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
+        Annotated[CLIPEmbedCheckpointConfig, CLIPEmbedCheckpointConfig.get_tag()],
     ],
     Discriminator(get_model_discriminator_value),
 ]
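For context on the pattern: `AnyModelConfig` is a Pydantic v2 discriminated union in which each member carries a `"{type}.{format}"` tag and a callable discriminator picks the matching class. A stripped-down sketch follows; the `DiffusersCfg`/`CheckpointCfg` classes and tag strings are hypothetical stand-ins for the much larger InvokeAI configs.

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter


class DiffusersCfg(BaseModel):
    type: Literal["clip_embed"] = "clip_embed"
    format: Literal["diffusers"] = "diffusers"


class CheckpointCfg(BaseModel):
    type: Literal["clip_embed"] = "clip_embed"
    format: Literal["checkpoint"] = "checkpoint"


def discriminator_value(v) -> str:
    # Combine type and format into a single tag, analogous to get_model_discriminator_value.
    data = v if isinstance(v, dict) else v.__dict__
    return f"{data['type']}.{data['format']}"


AnyCfg = Annotated[
    Union[
        Annotated[DiffusersCfg, Tag("clip_embed.diffusers")],
        Annotated[CheckpointCfg, Tag("clip_embed.checkpoint")],
    ],
    Discriminator(discriminator_value),
]

cfg = TypeAdapter(AnyCfg).validate_python({"type": "clip_embed", "format": "checkpoint"})
assert isinstance(cfg, CheckpointCfg)
```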
@@ -7,8 +7,17 @@ from typing import Optional
 import accelerate
 import torch
 from safetensors.torch import load_file
-from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
+from transformers import (
+    AutoConfig,
+    AutoModelForTextEncoding,
+    CLIPTextConfig,
+    CLIPTextModel,
+    CLIPTokenizer,
+    T5EncoderModel,
+    T5Tokenizer,
+)

+import invokeai.backend.assets.model_base_conf_files as model_conf_files
 from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.modules.autoencoder import AutoEncoder
@@ -23,6 +32,7 @@ from invokeai.backend.model_manager import (
 )
 from invokeai.backend.model_manager.config import (
     CheckpointConfigBase,
+    CLIPEmbedCheckpointConfig,
     CLIPEmbedDiffusersConfig,
     MainBnbQuantized4bCheckpointConfig,
     MainCheckpointConfig,
@@ -71,7 +81,7 @@ class FluxVAELoader(ModelLoader):


 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers)
-class ClipCheckpointModel(ModelLoader):
+class ClipDiffusersModel(ModelLoader):
     """Class to load main models."""

     def _load_model(
@@ -93,6 +103,39 @@ class ClipCheckpointModel(ModelLoader):
         )


+@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Checkpoint)
+class ClipCheckpointModel(ModelLoader):
+    """Class to load main models."""
+
+    def _load_model(
+        self,
+        config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        if not isinstance(config, CLIPEmbedCheckpointConfig):
+            raise ValueError("Only CLIPEmbedCheckpointConfig models are currently supported here.")
+
+        match submodel_type:
+            case SubModelType.Tokenizer:
+                # Clip embedding checkpoints don't have an integrated tokenizer, so we cheat and fetch it into the HuggingFace cache
+                # TODO: Fix this ugly workaround
+                return CLIPTokenizer.from_pretrained(
+                    "InvokeAI/clip-vit-large-patch14-text-encoder", subfolder="bfloat16/tokenizer"
+                )
+            case SubModelType.TextEncoder:
+                config_json = CLIPTextConfig.from_json_file(Path(model_conf_files.__path__[0], config.config_path))
+                model = CLIPTextModel(config_json)
+                state_dict = load_file(config.path)
+                new_dict = {key: value for (key, value) in state_dict.items() if key.startswith("text_model.")}
+                model.load_state_dict(new_dict)
+                model.eval()
+                return model
+
+        raise ValueError(
+            f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
+        )
+
+
 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.BnbQuantizedLlmInt8b)
 class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
     """Class to load main models."""

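Outside of the loader framework, the `TextEncoder` branch above boils down to the following standalone sketch (the helper name and argument paths are illustrative, not part of the PR): build the model skeleton from the JSON config, then load only the `text_model.*` tensors from the safetensors file.

```python
from pathlib import Path

from safetensors.torch import load_file
from transformers import CLIPTextConfig, CLIPTextModel


def load_clip_text_encoder(checkpoint_path: Path, config_json_path: Path) -> CLIPTextModel:
    """Rebuild a CLIPTextModel from a bare .safetensors state dict plus a config.json."""
    model = CLIPTextModel(CLIPTextConfig.from_json_file(str(config_json_path)))
    state_dict = load_file(str(checkpoint_path))
    # Keep only the text-encoder weights; a full CLIP checkpoint may also carry
    # vision-tower or projection tensors that CLIPTextModel has no parameters for.
    text_weights = {k: v for k, v in state_dict.items() if k.startswith("text_model.")}
    model.load_state_dict(text_weights)
    return model.eval()
```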
@@ -8,7 +8,6 @@ import spandrel
 import torch
 from picklescan.scanner import scan_file_path

-import invokeai.backend.util.logging as logger
 from invokeai.app.util.misc import uuid_string
 from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
     is_state_dict_likely_in_flux_diffusers_format,
@@ -31,6 +30,7 @@ from invokeai.backend.model_manager.config import (
 )
 from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
 from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
+from invokeai.backend.util.logging import InvokeAILogger
 from invokeai.backend.util.silence_warnings import SilenceWarnings

 CkptType = Dict[str | int, Any]
@@ -184,7 +184,9 @@ class ModelProbe(object):
         fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()

         # additional fields needed for main and controlnet models
-        if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] and fields["format"] in [
+        if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE, ModelType.CLIPEmbed] and fields[
+            "format"
+        ] in [
             ModelFormat.Checkpoint,
             ModelFormat.BnbQuantizednf4b,
         ]:
@@ -207,7 +209,6 @@ class ModelProbe(object):
                 fields["base"] == BaseModelType.StableDiffusion2
                 and fields["prediction_type"] == SchedulerPredictionType.VPrediction
             )
-
         model_info = ModelConfigFactory.make_config(fields)  # , key=fields.get("key", None))
         return model_info

@@ -258,6 +259,8 @@ class ModelProbe(object):
                 return ModelType.IPAdapter
             elif key in {"emb_params", "string_to_param"}:
                 return ModelType.TextualInversion
+            elif key.startswith(("text_model.embeddings", "text_model.encoder")):
+                return ModelType.CLIPEmbed

         # diffusers-ti
         if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
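As a hedged illustration of the same heuristic outside the probe class (the helper name is made up), a bare CLIP text-encoder checkpoint can be recognized purely from its tensor names:

```python
from safetensors import safe_open


def looks_like_clip_text_encoder(path: str) -> bool:
    # Same signal the probe uses above: top-level text_model.embeddings.* /
    # text_model.encoder.* keys identify a standalone CLIP text encoder.
    with safe_open(path, framework="pt") as f:
        return any(k.startswith(("text_model.embeddings", "text_model.encoder")) for k in f.keys())
```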
@@ -394,6 +397,8 @@ class ModelProbe(object):
                 if base_type is BaseModelType.StableDiffusionXL
                 else "stable-diffusion/v2-inference.yaml"
             )
+        elif model_type is ModelType.CLIPEmbed:
+            return Path("clip_text_model", "config.json")
         else:
             raise InvalidModelConfigException(
                 f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}"
@@ -665,6 +670,11 @@ class CLIPVisionCheckpointProbe(CheckpointProbeBase):
         raise NotImplementedError()


+class CLIPEmbedCheckpointProbe(CheckpointProbeBase):
+    def get_base_type(self) -> BaseModelType:
+        return BaseModelType.Any
+
+
 class T2IAdapterCheckpointProbe(CheckpointProbeBase):
     def get_base_type(self) -> BaseModelType:
         raise NotImplementedError()
@@ -822,7 +832,7 @@ class ONNXFolderProbe(PipelineFolderProbe):
         if (self.model_path / "unet" / "config.json").exists():
             return super().get_base_type()
         else:
-            logger.warning('Base type probing is not implemented for ONNX models. Assuming "sd-1"')
+            InvokeAILogger.get_logger().warning('Base type probing is not implemented for ONNX models. Assuming "sd-1"')
         return BaseModelType.StableDiffusion1

     def get_format(self) -> ModelFormat:
@@ -956,6 +966,7 @@ ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInver
 ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
+ModelProbe.register_probe("checkpoint", ModelType.CLIPEmbed, CLIPEmbedCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)