mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Add SigLIP model type and probing.
This commit is contained in:
committed by
psychedelicious
parent
7f10f8f96a
commit
34959ef573
@@ -152,6 +152,7 @@ class FieldDescriptions:
|
||||
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
||||
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
||||
spandrel_image_to_image_model = "Image-to-Image model"
|
||||
vllm_model = "VLLM model"
|
||||
lora_weight = "The weight at which the LoRA is applied to each model"
|
||||
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
|
||||
raw_prompt = "Raw prompt text (no parsing)"
|
||||
|
||||
@@ -76,6 +76,7 @@ class ModelType(str, Enum):
|
||||
T2IAdapter = "t2i_adapter"
|
||||
T5Encoder = "t5_encoder"
|
||||
SpandrelImageToImage = "spandrel_image_to_image"
|
||||
SigLIP = "siglip"
|
||||
|
||||
|
||||
class SubModelType(str, Enum):
|
||||
@@ -528,6 +529,17 @@ class SpandrelImageToImageConfig(ModelConfigBase):
|
||||
return Tag(f"{ModelType.SpandrelImageToImage.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
|
||||
class SigLIPConfig(DiffusersConfigBase):
|
||||
"""Model config for SigLIP."""
|
||||
|
||||
type: Literal[ModelType.SigLIP] = ModelType.SigLIP
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.SigLIP.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
def get_model_discriminator_value(v: Any) -> str:
|
||||
"""
|
||||
Computes the discriminator value for a model config.
|
||||
@@ -575,6 +587,7 @@ AnyModelConfig = Annotated[
|
||||
Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
|
||||
Annotated[CLIPLEmbedDiffusersConfig, CLIPLEmbedDiffusersConfig.get_tag()],
|
||||
Annotated[CLIPGEmbedDiffusersConfig, CLIPGEmbedDiffusersConfig.get_tag()],
|
||||
Annotated[SigLIPConfig, SigLIPConfig.get_tag()],
|
||||
],
|
||||
Discriminator(get_model_discriminator_value),
|
||||
]
|
||||
|
||||
@@ -139,6 +139,7 @@ class ModelProbe(object):
|
||||
"FluxControlNetModel": ModelType.ControlNet,
|
||||
"SD3Transformer2DModel": ModelType.Main,
|
||||
"CLIPTextModelWithProjection": ModelType.CLIPEmbed,
|
||||
"SiglipModel": ModelType.SigLIP,
|
||||
}
|
||||
|
||||
TYPE2VARIANT: Dict[ModelType, Callable[[str], Optional[AnyVariant]]] = {ModelType.CLIPEmbed: get_clip_variant_type}
|
||||
@@ -752,6 +753,11 @@ class SpandrelImageToImageCheckpointProbe(CheckpointProbeBase):
|
||||
return BaseModelType.Any
|
||||
|
||||
|
||||
class SigLIPCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
########################################################
|
||||
# classes for probing folders
|
||||
#######################################################
|
||||
@@ -1022,6 +1028,11 @@ class SpandrelImageToImageFolderProbe(FolderProbeBase):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class SigLIPFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
return BaseModelType.Any
|
||||
|
||||
|
||||
class T2IAdapterFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
config_file = self.model_path / "config.json"
|
||||
@@ -1055,6 +1066,7 @@ ModelProbe.register_probe("diffusers", ModelType.CLIPEmbed, CLIPEmbedFolderProbe
|
||||
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.SigLIP, SigLIPFolderProbe)
|
||||
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
|
||||
@@ -1066,5 +1078,6 @@ ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpoint
|
||||
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.SigLIP, SigLIPCheckpointProbe)
|
||||
|
||||
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
|
||||
|
||||
Reference in New Issue
Block a user