Re-enable classification API as fallback (#8007)

## Summary

- Fallback to new classification API if legacy probe fails
- Method to read model metadata
- Created `StrippedModelOnDisk` class for testing
- Test to verify that only a single config `matches` a given model

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
This commit is contained in:
jazzhaiku
2025-05-20 11:25:38 +10:00
committed by GitHub
6 changed files with 91 additions and 38 deletions

View File

@@ -38,6 +38,7 @@ from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
InvalidModelConfigException,
ModelConfigBase,
)
from invokeai.backend.model_manager.legacy_probe import ModelProbe
from invokeai.backend.model_manager.metadata import (
@@ -646,14 +647,18 @@ class ModelInstallService(ModelInstallServiceBase):
hash_algo = self._app_config.hashing_algorithm
fields = config.model_dump()
return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo)
# New model probe API is disabled pending resolution of issue caused by a change of the ordering of checks.
# See commit message for details.
# try:
# return ModelConfigBase.classify(model_path=model_path, hash_algo=hash_algo, **fields)
# except InvalidModelConfigException:
# return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo) # type: ignore
# WARNING!
# The legacy probe relies on the implicit order of tests to determine model classification.
# This can lead to regressions between the legacy and new probes.
# Do NOT change the order of `probe` and `classify` without implementing one of the following fixes:
# Short-term fix: `classify` tests `matches` in the same order as the legacy probe.
# Long-term fix: Improve `matches` to be more specific so that only one config matches
# any given model - eliminating ambiguity and removing reliance on order.
# After implementing either of these fixes, remove @pytest.mark.xfail from `test_regression_against_model_probe`
try:
return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo) # type: ignore
except InvalidModelConfigException:
return ModelConfigBase.classify(model_path, hash_algo, **fields)
def _register(
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None

View File

@@ -146,33 +146,35 @@ class ModelConfigBase(ABC, BaseModel):
)
usage_info: Optional[str] = Field(default=None, description="Usage information for this model")
_USING_LEGACY_PROBE: ClassVar[set] = set()
_USING_CLASSIFY_API: ClassVar[set] = set()
USING_LEGACY_PROBE: ClassVar[set] = set()
USING_CLASSIFY_API: ClassVar[set] = set()
_MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.MED
def __init_subclass__(cls, **kwargs):
    """Register each new config subclass with the probe mechanism it uses.

    Subclasses mixing in ``LegacyProbeMixin`` are bucketed under the legacy
    probe; everything else is assumed to implement the new classify API.
    """
    super().__init_subclass__(**kwargs)
    if issubclass(cls, LegacyProbeMixin):
        ModelConfigBase.USING_LEGACY_PROBE.add(cls)
    else:
        ModelConfigBase.USING_CLASSIFY_API.add(cls)
@staticmethod
def all_config_classes():
    """Return every concrete config class, regardless of which probe API it uses.

    Abstract intermediate classes are filtered out so callers only see
    instantiable configs.
    """
    subclasses = ModelConfigBase.USING_LEGACY_PROBE | ModelConfigBase.USING_CLASSIFY_API
    return {cls for cls in subclasses if not isabstract(cls)}
@staticmethod
def classify(model_path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides):
def classify(mod: str | Path | ModelOnDisk, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides):
"""
Returns the best matching ModelConfig instance from a model's file/folder path.
Raises InvalidModelConfigException if no valid configuration is found.
Created to deprecate ModelProbe.probe
"""
candidates = ModelConfigBase._USING_CLASSIFY_API
if isinstance(mod, Path | str):
mod = ModelOnDisk(mod, hash_algo)
candidates = ModelConfigBase.USING_CLASSIFY_API
sorted_by_match_speed = sorted(candidates, key=lambda cls: (cls._MATCH_SPEED, cls.__name__))
mod = ModelOnDisk(model_path, hash_algo)
for config_cls in sorted_by_match_speed:
try:

View File

@@ -4,6 +4,7 @@ from typing import Any, Optional, TypeAlias
import safetensors.torch
import torch
from picklescan.scanner import scan_file_path
from safetensors import safe_open
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
@@ -35,12 +36,21 @@ class ModelOnDisk:
return self.path.stat().st_size
return sum(file.stat().st_size for file in self.path.rglob("*"))
def component_paths(self) -> set[Path]:
def weight_files(self) -> set[Path]:
    """Return the set of weight files belonging to this model.

    A single-file model contributes just its own path; a directory model is
    scanned recursively for files with known weight extensions.
    """
    if self.path.is_file():
        return {self.path}
    known_suffixes = (".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf")
    return {candidate for candidate in self.path.rglob("*") if candidate.suffix in known_suffixes}
def metadata(self, path: Optional[Path] = None) -> dict[str, str]:
    """Return the metadata dict embedded in the model's safetensors header.

    Args:
        path: Optional explicit weight file to read. When omitted, the model's
            single weight file is resolved automatically. (Fixes the original,
            which accepted ``path`` but silently ignored it and always read
            ``self.path`` — inconsistent with the ``StrippedModelOnDisk``
            override, which honors it.)

    Returns:
        The safetensors metadata mapping, or an empty dict when the file
        cannot be opened, is not a safetensors file, or the weight file is
        ambiguous — this method is best-effort by design.
    """
    try:
        target = self.resolve_weight_file(path)
        with safe_open(target, framework="pt", device="cpu") as f:
            metadata = f.metadata()
            assert isinstance(metadata, dict)
            return metadata
    except Exception:
        # Best-effort: any failure is treated as "no metadata available".
        return {}
def repo_variant(self) -> Optional[ModelRepoVariant]:
if self.path.is_file():
return None
@@ -64,18 +74,7 @@ class ModelOnDisk:
if path in sd_cache:
return sd_cache[path]
if not path:
components = list(self.component_paths())
match components:
case []:
raise ValueError("No weight files found for this model")
case [p]:
path = p
case ps if len(ps) >= 2:
raise ValueError(
f"Multiple weight files found for this model: {ps}. "
f"Please specify the intended file using the 'path' argument"
)
path = self.resolve_weight_file(path)
with SilenceWarnings():
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
@@ -94,3 +93,18 @@ class ModelOnDisk:
state_dict = checkpoint.get("state_dict", checkpoint)
sd_cache[path] = state_dict
return state_dict
def resolve_weight_file(self, path: Optional[Path] = None) -> Path:
    """Resolve which weight file to operate on.

    If *path* is provided it is returned unchanged. Otherwise the model must
    contain exactly one weight file, which is returned.

    Raises:
        ValueError: when no path is given and the model has zero or more than
            one weight file.
    """
    if path:
        return path
    found = list(self.weight_files())
    if not found:
        raise ValueError("No weight files found for this model")
    if len(found) == 1:
        return found[0]
    raise ValueError(
        f"Multiple weight files found for this model: {found}. "
        f"Please specify the intended file using the 'path' argument"
    )

View File

@@ -28,9 +28,9 @@ args = parser.parse_args()
def classify_with_fallback(path: Path, hash_algo: HASHING_ALGORITHMS):
    """Classify a model, trying the legacy probe first and falling back to the classify API.

    The rendered original contained two identical ``except InvalidModelConfigException``
    clauses (stripped-diff residue), making the second unreachable; this restores the
    intended post-commit behavior, matching the install service's probe-then-classify order.

    NOTE: the order matters — the legacy probe relies on the implicit ordering of its
    checks, so `probe` must run before `classify` (see commit message).
    """
    try:
        return ModelProbe.probe(path, hash_algo=hash_algo)
    except InvalidModelConfigException:
        return ModelConfigBase.classify(path, hash_algo)
for path in args.model_path:

View File

@@ -18,13 +18,16 @@ import json
import shutil
import sys
from pathlib import Path
from typing import Optional
import humanize
import torch
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk, StateDict
from invokeai.backend.model_manager.search import ModelSearch
METADATA_KEY = "metadata_key_for_stripped_models"
def strip(v):
match v:
@@ -57,9 +60,22 @@ def dress(v):
def load_stripped_model(path: Path, *args, **kwargs):
    """Load a stripped (JSON-serialized) model file and rehydrate its state dict.

    The embedded metadata entry is discarded — it is not part of the state dict.
    """
    with open(path, "r") as f:
        payload = json.load(f)
    payload.pop(METADATA_KEY, None)
    return dress(payload)
class StrippedModelOnDisk(ModelOnDisk):
    """ModelOnDisk variant for test fixtures whose weights are stripped to JSON files."""

    def load_state_dict(self, path: Optional[Path] = None) -> StateDict:
        """Rehydrate the state dict from the stripped JSON weight file."""
        target = self.resolve_weight_file(path)
        return load_stripped_model(target)

    def metadata(self, path: Optional[Path] = None) -> dict[str, str]:
        """Read the metadata mapping embedded in the stripped JSON weight file."""
        target = self.resolve_weight_file(path)
        with open(target, "r") as f:
            payload = json.load(f)
        return payload.get(METADATA_KEY, {})
def create_stripped_model(original_model_path: Path, stripped_model_path: Path) -> ModelOnDisk:
original = ModelOnDisk(original_model_path)
if original.path.is_file():
@@ -69,11 +85,14 @@ def create_stripped_model(original_model_path: Path, stripped_model_path: Path)
stripped = ModelOnDisk(stripped_model_path)
print(f"Created clone of {original.name} at {stripped.path}")
for component_path in stripped.component_paths():
for component_path in stripped.weight_files():
original_state_dict = stripped.load_state_dict(component_path)
stripped_state_dict = strip(original_state_dict) # type: ignore
metadata = stripped.metadata()
contents = {**stripped_state_dict, METADATA_KEY: metadata}
with open(component_path, "w") as f:
json.dump(stripped_state_dict, f, indent=4)
json.dump(contents, f, indent=4)
before_size = humanize.naturalsize(original.size())
after_size = humanize.naturalsize(stripped.size())

View File

@@ -29,6 +29,9 @@ from invokeai.backend.model_manager.legacy_probe import (
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util.logging import InvokeAILogger
from scripts.strip_models import StrippedModelOnDisk
logger = InvokeAILogger.get_logger(__file__)
@pytest.mark.parametrize(
@@ -156,7 +159,8 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading):
pass
try:
new_config = ModelConfigBase.classify(path, hash=fake_hash, key=fake_key)
stripped_mod = StrippedModelOnDisk(path)
new_config = ModelConfigBase.classify(stripped_mod, hash=fake_hash, key=fake_key)
except InvalidModelConfigException:
pass
@@ -165,10 +169,10 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading):
assert legacy_config.model_dump_json() == new_config.model_dump_json()
elif legacy_config:
assert type(legacy_config) in ModelConfigBase._USING_LEGACY_PROBE
assert type(legacy_config) in ModelConfigBase.USING_LEGACY_PROBE
elif new_config:
assert type(new_config) in ModelConfigBase._USING_CLASSIFY_API
assert type(new_config) in ModelConfigBase.USING_CLASSIFY_API
else:
raise ValueError(f"Both probe and classify failed to classify model at path {path}.")
@@ -177,7 +181,6 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading):
configs_with_tests.add(config_type)
untested_configs = ModelConfigBase.all_config_classes() - configs_with_tests - {MinimalConfigExample}
logger = InvokeAILogger.get_logger(__file__)
logger.warning(f"Function test_regression_against_model_probe missing test case for: {untested_configs}")
@@ -255,3 +258,13 @@ def test_any_model_config_includes_all_config_classes():
expected = set(ModelConfigBase.all_config_classes()) - {MinimalConfigExample}
assert extracted == expected
def test_config_uniquely_matches_model(datadir: Path):
    """Verify each stripped test model matches at most one classify-API config class."""
    for path in ModelSearch().search(datadir / "stripped_models"):
        mod = StrippedModelOnDisk(path)
        matches = set()
        for cls in ModelConfigBase.USING_CLASSIFY_API:
            if cls.matches(mod):
                matches.add(cls)
        assert len(matches) <= 1, f"Model at path {path} matches multiple config classes: {matches}"
        if not matches:
            logger.warning(f"Model at path {path} does not match any config classes using classify API.")