Add SigLIP model type and probing.

This commit is contained in:
Ryan Dick
2025-02-27 15:51:29 +00:00
committed by psychedelicious
parent 7f10f8f96a
commit 34959ef573
3 changed files with 27 additions and 0 deletions

View File

@@ -152,6 +152,7 @@ class FieldDescriptions:
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
spandrel_image_to_image_model = "Image-to-Image model"
vllm_model = "VLLM model"
lora_weight = "The weight at which the LoRA is applied to each model"
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
raw_prompt = "Raw prompt text (no parsing)"

View File

@@ -76,6 +76,7 @@ class ModelType(str, Enum):
T2IAdapter = "t2i_adapter"
T5Encoder = "t5_encoder"
SpandrelImageToImage = "spandrel_image_to_image"
SigLIP = "siglip"
class SubModelType(str, Enum):
@@ -528,6 +529,17 @@ class SpandrelImageToImageConfig(ModelConfigBase):
return Tag(f"{ModelType.SpandrelImageToImage.value}.{ModelFormat.Checkpoint.value}")
class SigLIPConfig(DiffusersConfigBase):
"""Model config for SigLIP."""
type: Literal[ModelType.SigLIP] = ModelType.SigLIP
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.SigLIP.value}.{ModelFormat.Diffusers.value}")
def get_model_discriminator_value(v: Any) -> str:
"""
Computes the discriminator value for a model config.
@@ -575,6 +587,7 @@ AnyModelConfig = Annotated[
Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
Annotated[CLIPLEmbedDiffusersConfig, CLIPLEmbedDiffusersConfig.get_tag()],
Annotated[CLIPGEmbedDiffusersConfig, CLIPGEmbedDiffusersConfig.get_tag()],
Annotated[SigLIPConfig, SigLIPConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),
]

View File

@@ -139,6 +139,7 @@ class ModelProbe(object):
"FluxControlNetModel": ModelType.ControlNet,
"SD3Transformer2DModel": ModelType.Main,
"CLIPTextModelWithProjection": ModelType.CLIPEmbed,
"SiglipModel": ModelType.SigLIP,
}
TYPE2VARIANT: Dict[ModelType, Callable[[str], Optional[AnyVariant]]] = {ModelType.CLIPEmbed: get_clip_variant_type}
@@ -752,6 +753,11 @@ class SpandrelImageToImageCheckpointProbe(CheckpointProbeBase):
return BaseModelType.Any
class SigLIPCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
########################################################
# classes for probing folders
#######################################################
@@ -1022,6 +1028,11 @@ class SpandrelImageToImageFolderProbe(FolderProbeBase):
raise NotImplementedError()
class SigLIPFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
return BaseModelType.Any
class T2IAdapterFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
config_file = self.model_path / "config.json"
@@ -1055,6 +1066,7 @@ ModelProbe.register_probe("diffusers", ModelType.CLIPEmbed, CLIPEmbedFolderProbe
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.SigLIP, SigLIPFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
@@ -1066,5 +1078,6 @@ ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpoint
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.SigLIP, SigLIPCheckpointProbe)
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)