Port LLaVA to new API (#7817)

## Summary

- Port LLaVA model config to new classification API
- Add 2 test cases (stripped LLaVA models variants to git-lfs)

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
This commit is contained in:
jazzhaiku
2025-03-24 22:50:54 +11:00
committed by GitHub
32 changed files with 160 additions and 6 deletions

View File

@@ -21,6 +21,7 @@ Validation errors will raise an InvalidModelConfigException error.
"""
# pyright: reportIncompatibleVariableOverride=false
import json
import logging
import time
from abc import ABC, abstractmethod
@@ -232,6 +233,23 @@ class ModelOnDisk:
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
return {f for f in self.path.rglob("*") if f.suffix in extensions}
def repo_variant(self):
if self.format_type == ModelFormat.Checkpoint:
return None
weight_files = list(self.path.glob("**/*.safetensors"))
weight_files.extend(list(self.path.glob("**/*.bin")))
for x in weight_files:
if ".fp16" in x.suffixes:
return ModelRepoVariant.FP16
if "openvino_model" in x.name:
return ModelRepoVariant.OpenVINO
if "flax_model" in x.name:
return ModelRepoVariant.Flax
if x.suffix == ".onnx":
return ModelRepoVariant.ONNX
return ModelRepoVariant.Default
@staticmethod
def load_state_dict(path: Path):
with SilenceWarnings():
@@ -359,6 +377,21 @@ class ModelConfigBase(ABC, BaseModel):
This doesn't need to be a perfect test - the aim is to eliminate unlikely matches quickly before parsing."""
pass
@staticmethod
def cast_overrides(overrides: dict[str, Any]):
"""Casts user overrides from str to Enum"""
if "type" in overrides:
overrides["type"] = ModelType(overrides["type"])
if "format" in overrides:
overrides["format"] = ModelFormat(overrides["format"])
if "base" in overrides:
overrides["base"] = BaseModelType(overrides["base"])
if "source_type" in overrides:
overrides["source_type"] = ModelSourceType(overrides["source_type"])
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
"""Creates an instance of this config or raises InvalidModelConfigException."""
@@ -366,14 +399,21 @@ class ModelConfigBase(ABC, BaseModel):
raise InvalidModelConfigException(f"Path {mod.path} does not match {cls.__name__} format")
fields = cls.parse(mod)
cls.cast_overrides(overrides)
fields.update(overrides)
type = fields.get("type") or cls.model_fields["type"].default
base = fields.get("base") or cls.model_fields["base"].default
fields["path"] = mod.path.as_posix()
fields["source"] = fields.get("source") or fields["path"]
fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
fields["name"] = mod.name
fields["name"] = name = fields.get("name") or mod.name
fields["hash"] = fields.get("hash") or mod.hash()
fields["key"] = fields.get("key") or uuid_string()
fields["description"] = fields.get("description") or f"{base.value} {type.value} model {name}"
fields["repo_variant"] = fields.get("repo_variant") or mod.repo_variant()
fields.update(overrides)
return cls(**fields)
@@ -625,12 +665,34 @@ class FluxReduxConfig(LegacyProbeMixin, ModelConfigBase):
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
class LlavaOnevisionConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase):
class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
"""Model config for Llava Onevision models."""
type: Literal[ModelType.LlavaOnevision] = ModelType.LlavaOnevision
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
if mod.format_type == ModelFormat.Checkpoint:
return False
config_path = mod.path / "config.json"
try:
with open(config_path, "r") as file:
config = json.load(file)
except FileNotFoundError:
return False
architectures = config.get("architectures")
return architectures and architectures[0] == "LlavaOnevisionForConditionalGeneration"
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
return {
"base": BaseModelType.Any,
"variant": ModelVariantType.Normal,
}
def get_model_discriminator_value(v: Any) -> str:
"""

View File

@@ -148,22 +148,24 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading):
configs_with_tests = set()
model_paths = ModelSearch().search(datadir / "stripped_models")
fake_hash = "abcdefgh" # skip hashing to make test quicker
fake_key = "123" # fixed uuid for comparison
for path in model_paths:
legacy_config = new_config = None
try:
legacy_config = ModelProbe.probe(path, {"hash": fake_hash})
legacy_config = ModelProbe.probe(path, {"hash": fake_hash, "key": fake_key})
except InvalidModelConfigException:
pass
try:
new_config = ModelConfigBase.classify(path, hash=fake_hash)
new_config = ModelConfigBase.classify(path, hash=fake_hash, key=fake_key)
except InvalidModelConfigException:
pass
if legacy_config and new_config:
assert legacy_config == new_config
assert type(legacy_config) is type(new_config)
assert legacy_config.model_dump_json() == new_config.model_dump_json()
elif legacy_config:
assert type(legacy_config) in ModelConfigBase._USING_LEGACY_PROBE

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:33e0fb93dadacb864bd2f2e8441e147daa2baceb67f94d3ef5283b495572cea0
size 122

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2466d1704df30f0067f28d8e30e0190a1bf74e5b430942697af974d162a056bd
size 826

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:839a4fba0bd6949f0db22d4f840935cb0318f6ac28b29c9ce1a5b15735a4a740
size 2591

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:89dc53229f50b59570b6852056dafeac8116c458f1a748bff491b6d4d24d3b51
size 126

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8831e4f1a044471340f7c0a83d7bd71306a5b867e95fd870f74d0c5308a904d5
size 1671853

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a0e4b0349d188ce8618b4cfdc01f87d58f912bf9c55cfb8a2d80b9ecae39a870
size 136697

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3644c108b9f0fa53e62ff422a9be6639642f0e64dab4a71f961c7911d4386384
size 1732

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:04e9899e93f2a412c94e153cab4081457c9a44defb3b2c0b9df673d42c42cdd0
size 178

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f4f79e08d97f4d1c87f8d89264f525c8789da3b73b3bb55d1e12f692f41a7b1b
size 367

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3c0ce3213b50ff38d8aa1e91136a2d2cb142a3f569246170872e439cb2a29d15
size 7028579

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:494a5592a446535be00acc531ccf7a53fd6c6c392c122d444c389160261572e0
size 1800

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1e71b2e75b90ddf696692529485b4a75fd54ecd4bbc03e2cc7be4af032875765
size 428

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ca10d7e9fb3ed18575dd1e277a2579c16d108e32f27439684afa0e10b1440910
size 2776833

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:33e0fb93dadacb864bd2f2e8441e147daa2baceb67f94d3ef5283b495572cea0
size 122

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2466d1704df30f0067f28d8e30e0190a1bf74e5b430942697af974d162a056bd
size 826

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:69277c2c9ba8a4f61d2e72f79fbe02e043b8c3af670858a606818527f971a0c2
size 2528

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:89dc53229f50b59570b6852056dafeac8116c458f1a748bff491b6d4d24d3b51
size 126

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8831e4f1a044471340f7c0a83d7bd71306a5b867e95fd870f74d0c5308a904d5
size 1671853

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:073982033138424bfbc1cfb89344dead5abe5f07662e98c1451d273e3bc8d4a4
size 97724

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:eda752c2a09970cd7cb70499a70cb0b1bc9a43a62278602e36e5ee44ddc98650
size 24660

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7e861fb6dde62d5a36b3ad6ab525cf892711a2ecca0b92844e001f473e0c84b1
size 23020

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:591625a572138c8c333dc296af0502e23c76c76cc8f26dc6ed0901ae83bd08f6
size 893

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e389b969d3b9b7120f136fbf6592cbbc1c07326157cae6f863fb64559261dae9
size 77952

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3644c108b9f0fa53e62ff422a9be6639642f0e64dab4a71f961c7911d4386384
size 1732

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:04e9899e93f2a412c94e153cab4081457c9a44defb3b2c0b9df673d42c42cdd0
size 178

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f4f79e08d97f4d1c87f8d89264f525c8789da3b73b3bb55d1e12f692f41a7b1b
size 367

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3c0ce3213b50ff38d8aa1e91136a2d2cb142a3f569246170872e439cb2a29d15
size 7028579

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b3151cdc153a44fa4e28abc8fc83567fc663aef7cd9dc041f5d66e617fe10b8e
size 1801

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1e71b2e75b90ddf696692529485b4a75fd54ecd4bbc03e2cc7be4af032875765
size 428

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ca10d7e9fb3ed18575dd1e277a2579c16d108e32f27439684afa0e10b1440910
size 2776833