mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Port LLaVA to new API (#7817)
## Summary - Port LLaVA model config to new classification API - Add 2 test cases (stripped LLaVA model variants added to git-lfs) ## Related Issues / Discussions <!--WHEN APPLICABLE: List any related issues or discussions on github or discord. If this PR closes an issue, please use the "Closes #1234" format, so that the issue will be automatically closed when the PR merges.--> ## QA Instructions <!--WHEN APPLICABLE: Describe how you have tested the changes in this PR. Provide enough detail that a reviewer can reproduce your tests.--> ## Merge Plan <!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like DB schemas, may need some care when merging. For example, a careful rebase by the change author, timing to not interfere with a pending release, or a message to contributors on discord after merging.--> ## Checklist - [ ] _The PR has a short but descriptive title, suitable for a changelog_ - [ ] _Tests added / updated (if applicable)_ - [ ] _Documentation added / updated (if applicable)_ - [ ] _Updated `What's New` copy (if doing a release after this PR)_
This commit is contained in:
@@ -21,6 +21,7 @@ Validation errors will raise an InvalidModelConfigException error.
|
||||
"""
|
||||
|
||||
# pyright: reportIncompatibleVariableOverride=false
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -232,6 +233,23 @@ class ModelOnDisk:
|
||||
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
|
||||
return {f for f in self.path.rglob("*") if f.suffix in extensions}
|
||||
|
||||
def repo_variant(self):
    """Determine the repo variant from the weight files under ``self.path``.

    Returns ``None`` when ``self.format_type`` is ``ModelFormat.Checkpoint``.
    Otherwise every ``*.safetensors`` file (searched recursively) is inspected,
    followed by every ``*.bin`` file, and the first file that identifies a
    variant decides the result. ``ModelRepoVariant.Default`` is returned when
    no file matches.
    """
    if self.format_type == ModelFormat.Checkpoint:
        return None

    # Safetensors files are examined before .bin files; the first hit wins.
    candidates = [
        *self.path.glob("**/*.safetensors"),
        *self.path.glob("**/*.bin"),
    ]

    for weight_file in candidates:
        if ".fp16" in weight_file.suffixes:
            return ModelRepoVariant.FP16
        if "openvino_model" in weight_file.name:
            return ModelRepoVariant.OpenVINO
        if "flax_model" in weight_file.name:
            return ModelRepoVariant.Flax
        # NOTE(review): only *.safetensors and *.bin files are collected above,
        # so this suffix test cannot succeed as written. Kept unchanged to
        # preserve existing behavior - confirm whether *.onnx should be globbed.
        if weight_file.suffix == ".onnx":
            return ModelRepoVariant.ONNX

    return ModelRepoVariant.Default
|
||||
|
||||
@staticmethod
|
||||
def load_state_dict(path: Path):
|
||||
with SilenceWarnings():
|
||||
@@ -359,6 +377,21 @@ class ModelConfigBase(ABC, BaseModel):
|
||||
This doesn't need to be a perfect test - the aim is to eliminate unlikely matches quickly before parsing."""
|
||||
pass
|
||||
|
||||
@staticmethod
def cast_overrides(overrides: dict[str, Any]):
    """Convert enum-valued user overrides to their Enum members.

    ``overrides`` is modified in place: for each of the keys ``type``,
    ``format``, ``base`` and ``source_type`` that is present, the value is
    replaced by the result of calling the corresponding Enum class on it.
    Keys that are absent are left absent; all other keys are untouched.
    """
    # Processed in this order, so a value the Enum rejects stops the
    # conversion before any later key is touched.
    enum_by_key = (
        ("type", ModelType),
        ("format", ModelFormat),
        ("base", BaseModelType),
        ("source_type", ModelSourceType),
    )

    for key, enum_cls in enum_by_key:
        if key in overrides:
            overrides[key] = enum_cls(overrides[key])
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
|
||||
"""Creates an instance of this config or raises InvalidModelConfigException."""
|
||||
@@ -366,14 +399,21 @@ class ModelConfigBase(ABC, BaseModel):
|
||||
raise InvalidModelConfigException(f"Path {mod.path} does not match {cls.__name__} format")
|
||||
|
||||
fields = cls.parse(mod)
|
||||
cls.cast_overrides(overrides)
|
||||
fields.update(overrides)
|
||||
|
||||
type = fields.get("type") or cls.model_fields["type"].default
|
||||
base = fields.get("base") or cls.model_fields["base"].default
|
||||
|
||||
fields["path"] = mod.path.as_posix()
|
||||
fields["source"] = fields.get("source") or fields["path"]
|
||||
fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
|
||||
fields["name"] = mod.name
|
||||
fields["name"] = name = fields.get("name") or mod.name
|
||||
fields["hash"] = fields.get("hash") or mod.hash()
|
||||
fields["key"] = fields.get("key") or uuid_string()
|
||||
fields["description"] = fields.get("description") or f"{base.value} {type.value} model {name}"
|
||||
fields["repo_variant"] = fields.get("repo_variant") or mod.repo_variant()
|
||||
|
||||
fields.update(overrides)
|
||||
return cls(**fields)
|
||||
|
||||
|
||||
@@ -625,12 +665,34 @@ class FluxReduxConfig(LegacyProbeMixin, ModelConfigBase):
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
|
||||
class LlavaOnevisionConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
    """Model config for Llava Onevision models.

    A model is recognised by the first entry of the ``architectures`` list in
    the ``config.json`` found directly under the model path.
    """

    type: Literal[ModelType.LlavaOnevision] = ModelType.LlavaOnevision
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

    @classmethod
    def matches(cls, mod: ModelOnDisk) -> bool:
        """Return True if ``mod`` is a Llava Onevision model.

        Models whose ``format_type`` is ``ModelFormat.Checkpoint`` never match,
        nor do models without a ``config.json`` under ``mod.path``. Otherwise
        the result is whether the first ``architectures`` entry in that file
        equals ``"LlavaOnevisionForConditionalGeneration"``.
        """
        if mod.format_type == ModelFormat.Checkpoint:
            return False

        config_path = mod.path / "config.json"
        try:
            # JSON is UTF-8 on disk; don't depend on the platform's default
            # text encoding when reading it.
            with open(config_path, "r", encoding="utf-8") as file:
                config = json.load(file)
        except FileNotFoundError:
            return False

        architectures = config.get("architectures")
        # Coerce to bool so a missing or empty "architectures" entry yields
        # False rather than leaking None / [] through a `-> bool` signature.
        return bool(architectures) and architectures[0] == "LlavaOnevisionForConditionalGeneration"

    @classmethod
    def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
        """Return the config fields specific to this model type.

        The values are constant: ``base`` is ``BaseModelType.Any`` and
        ``variant`` is ``ModelVariantType.Normal``; ``mod`` is not inspected.
        """
        return {
            "base": BaseModelType.Any,
            "variant": ModelVariantType.Normal,
        }
|
||||
|
||||
|
||||
def get_model_discriminator_value(v: Any) -> str:
|
||||
"""
|
||||
|
||||
@@ -148,22 +148,24 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading):
|
||||
configs_with_tests = set()
|
||||
model_paths = ModelSearch().search(datadir / "stripped_models")
|
||||
fake_hash = "abcdefgh" # skip hashing to make test quicker
|
||||
fake_key = "123" # fixed uuid for comparison
|
||||
|
||||
for path in model_paths:
|
||||
legacy_config = new_config = None
|
||||
|
||||
try:
|
||||
legacy_config = ModelProbe.probe(path, {"hash": fake_hash})
|
||||
legacy_config = ModelProbe.probe(path, {"hash": fake_hash, "key": fake_key})
|
||||
except InvalidModelConfigException:
|
||||
pass
|
||||
|
||||
try:
|
||||
new_config = ModelConfigBase.classify(path, hash=fake_hash)
|
||||
new_config = ModelConfigBase.classify(path, hash=fake_hash, key=fake_key)
|
||||
except InvalidModelConfigException:
|
||||
pass
|
||||
|
||||
if legacy_config and new_config:
|
||||
assert legacy_config == new_config
|
||||
assert type(legacy_config) is type(new_config)
|
||||
assert legacy_config.model_dump_json() == new_config.model_dump_json()
|
||||
|
||||
elif legacy_config:
|
||||
assert type(legacy_config) in ModelConfigBase._USING_LEGACY_PROBE
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:33e0fb93dadacb864bd2f2e8441e147daa2baceb67f94d3ef5283b495572cea0
|
||||
size 122
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2466d1704df30f0067f28d8e30e0190a1bf74e5b430942697af974d162a056bd
|
||||
size 826
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:839a4fba0bd6949f0db22d4f840935cb0318f6ac28b29c9ce1a5b15735a4a740
|
||||
size 2591
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:89dc53229f50b59570b6852056dafeac8116c458f1a748bff491b6d4d24d3b51
|
||||
size 126
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:8831e4f1a044471340f7c0a83d7bd71306a5b867e95fd870f74d0c5308a904d5
|
||||
size 1671853
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a0e4b0349d188ce8618b4cfdc01f87d58f912bf9c55cfb8a2d80b9ecae39a870
|
||||
size 136697
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:3644c108b9f0fa53e62ff422a9be6639642f0e64dab4a71f961c7911d4386384
|
||||
size 1732
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:04e9899e93f2a412c94e153cab4081457c9a44defb3b2c0b9df673d42c42cdd0
|
||||
size 178
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f4f79e08d97f4d1c87f8d89264f525c8789da3b73b3bb55d1e12f692f41a7b1b
|
||||
size 367
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:3c0ce3213b50ff38d8aa1e91136a2d2cb142a3f569246170872e439cb2a29d15
|
||||
size 7028579
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:494a5592a446535be00acc531ccf7a53fd6c6c392c122d444c389160261572e0
|
||||
size 1800
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1e71b2e75b90ddf696692529485b4a75fd54ecd4bbc03e2cc7be4af032875765
|
||||
size 428
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:ca10d7e9fb3ed18575dd1e277a2579c16d108e32f27439684afa0e10b1440910
|
||||
size 2776833
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:33e0fb93dadacb864bd2f2e8441e147daa2baceb67f94d3ef5283b495572cea0
|
||||
size 122
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2466d1704df30f0067f28d8e30e0190a1bf74e5b430942697af974d162a056bd
|
||||
size 826
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:69277c2c9ba8a4f61d2e72f79fbe02e043b8c3af670858a606818527f971a0c2
|
||||
size 2528
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:89dc53229f50b59570b6852056dafeac8116c458f1a748bff491b6d4d24d3b51
|
||||
size 126
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:8831e4f1a044471340f7c0a83d7bd71306a5b867e95fd870f74d0c5308a904d5
|
||||
size 1671853
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:073982033138424bfbc1cfb89344dead5abe5f07662e98c1451d273e3bc8d4a4
|
||||
size 97724
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:eda752c2a09970cd7cb70499a70cb0b1bc9a43a62278602e36e5ee44ddc98650
|
||||
size 24660
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7e861fb6dde62d5a36b3ad6ab525cf892711a2ecca0b92844e001f473e0c84b1
|
||||
size 23020
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:591625a572138c8c333dc296af0502e23c76c76cc8f26dc6ed0901ae83bd08f6
|
||||
size 893
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e389b969d3b9b7120f136fbf6592cbbc1c07326157cae6f863fb64559261dae9
|
||||
size 77952
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:3644c108b9f0fa53e62ff422a9be6639642f0e64dab4a71f961c7911d4386384
|
||||
size 1732
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:04e9899e93f2a412c94e153cab4081457c9a44defb3b2c0b9df673d42c42cdd0
|
||||
size 178
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f4f79e08d97f4d1c87f8d89264f525c8789da3b73b3bb55d1e12f692f41a7b1b
|
||||
size 367
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:3c0ce3213b50ff38d8aa1e91136a2d2cb142a3f569246170872e439cb2a29d15
|
||||
size 7028579
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b3151cdc153a44fa4e28abc8fc83567fc663aef7cd9dc041f5d66e617fe10b8e
|
||||
size 1801
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1e71b2e75b90ddf696692529485b4a75fd54ecd4bbc03e2cc7be4af032875765
|
||||
size 428
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:ca10d7e9fb3ed18575dd1e277a2579c16d108e32f27439684afa0e10b1440910
|
||||
size 2776833
|
||||
Reference in New Issue
Block a user