Add LlavaOnevision model type and probing logic.

This commit is contained in:
Ryan Dick
2025-02-26 02:52:45 +00:00
committed by psychedelicious
parent 28d3356710
commit 3f29293e39
2 changed files with 26 additions and 0 deletions

View File

@@ -78,6 +78,7 @@ class ModelType(str, Enum):
SpandrelImageToImage = "spandrel_image_to_image"
SigLIP = "siglip"
FluxRedux = "flux_redux"
LlavaOnevision = "llava_onevision"
class SubModelType(str, Enum):
@@ -552,6 +553,17 @@ class FluxReduxConfig(ModelConfigBase):
return Tag(f"{ModelType.FluxRedux.value}.{ModelFormat.Checkpoint.value}")
class LlavaOnevisionConfig(DiffusersConfigBase):
"""Model config for Llava Onevision models."""
type: Literal[ModelType.LlavaOnevision] = ModelType.LlavaOnevision
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.LlavaOnevision.value}.{ModelFormat.Diffusers.value}")
def get_model_discriminator_value(v: Any) -> str:
"""
Computes the discriminator value for a model config.
@@ -601,6 +613,7 @@ AnyModelConfig = Annotated[
Annotated[CLIPGEmbedDiffusersConfig, CLIPGEmbedDiffusersConfig.get_tag()],
Annotated[SigLIPConfig, SigLIPConfig.get_tag()],
Annotated[FluxReduxConfig, FluxReduxConfig.get_tag()],
Annotated[LlavaOnevisionConfig, LlavaOnevisionConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),
]

View File

@@ -141,6 +141,7 @@ class ModelProbe(object):
"SD3Transformer2DModel": ModelType.Main,
"CLIPTextModelWithProjection": ModelType.CLIPEmbed,
"SiglipModel": ModelType.SigLIP,
"LlavaOnevisionForConditionalGeneration": ModelType.LlavaOnevision,
}
TYPE2VARIANT: Dict[ModelType, Callable[[str], Optional[AnyVariant]]] = {ModelType.CLIPEmbed: get_clip_variant_type}
@@ -767,6 +768,11 @@ class FluxReduxCheckpointProbe(CheckpointProbeBase):
return BaseModelType.Flux
class LlavaOnevisionCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
########################################################
# classes for probing folders
#######################################################
@@ -1047,6 +1053,11 @@ class FluxReduxFolderProbe(FolderProbeBase):
raise NotImplementedError()
class LlaveOnevisionFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
return BaseModelType.Any
class T2IAdapterFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
config_file = self.model_path / "config.json"
@@ -1082,6 +1093,7 @@ ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderPro
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.SigLIP, SigLIPFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.FluxRedux, FluxReduxFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.LlavaOnevision, LlaveOnevisionFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
@@ -1095,5 +1107,6 @@ ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpoi
ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.SigLIP, SigLIPCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.FluxRedux, FluxReduxCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.LlavaOnevision, LlavaOnevisionCheckpointProbe)
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)