tests(mm): refactor model identification tests

Overhaul of the model identification (probing) tests. Previously we tested
probing correctness only in a few narrow cases; now we test it
comprehensively.

See tests/model_identification/README.md for a detailed overview of the
new test setup. It includes instructions for adding a new test case. In
brief:

- Download the model you want to add as a test case
- Run a script against it to generate the test model files
- Fill in the expected model type/format/base/etc in the generated test
metadata JSON file

Included test cases:
- All starter models
- A handful of other models that I had installed
- Models from the previous test cases, which were previously only smoke
  tests and are now also tested for correctness
psychedelicious
2025-10-09 21:31:37 +11:00
parent 1373b440c3
commit 4334cf733a
401 changed files with 1393 additions and 849 deletions

.gitattributes (1 change)
View File

@@ -4,3 +4,4 @@
* text=auto
docker/** text eol=lf
tests/test_model_probe/stripped_models/** filter=lfs diff=lfs merge=lfs -text
tests/model_identification/stripped_models/** filter=lfs diff=lfs merge=lfs -text

View File

@@ -1,134 +0,0 @@
"""
Usage:
strip_models.py <models_input_dir> <stripped_output_dir>
Strips tensor data from model state_dicts while preserving metadata.
Used to create lightweight models for testing model classification.
Parameters:
<models_input_dir> Directory containing original models.
<stripped_output_dir> Directory where stripped models will be saved.
Options:
-h, --help Show this help message and exit
"""
import argparse
import json
import shutil
import sys
from pathlib import Path
from typing import Optional
import humanize
import torch
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk, StateDict
from invokeai.backend.model_manager.search import ModelSearch
METADATA_KEY = "metadata_key_for_stripped_models"
def strip(v):
match v:
case torch.Tensor():
return {"shape": v.shape, "dtype": str(v.dtype), "fakeTensor": True}
case dict():
return {k: strip(v) for k, v in v.items()}
case list() | tuple():
return [strip(x) for x in v]
case _:
return v
STR_TO_DTYPE = {str(dtype): dtype for dtype in torch.__dict__.values() if isinstance(dtype, torch.dtype)}
def dress(v):
match v:
case {"shape": shape, "dtype": dtype_str, "fakeTensor": True}:
dtype = STR_TO_DTYPE[dtype_str]
return torch.empty(shape, dtype=dtype)
case dict():
return {k: dress(v) for k, v in v.items()}
case list() | tuple():
return [dress(x) for x in v]
case _:
return v
def load_stripped_model(path: Path, *args, **kwargs):
with open(path, "r") as f:
contents = json.load(f)
contents.pop(METADATA_KEY, None)
return dress(contents)
class StrippedModelOnDisk(ModelOnDisk):
def load_state_dict(self, path: Optional[Path] = None) -> StateDict:
path = self.resolve_weight_file(path)
return load_stripped_model(path)
def metadata(self, path: Optional[Path] = None) -> dict[str, str]:
path = self.resolve_weight_file(path)
with open(path, "r") as f:
contents = json.load(f)
return contents.get(METADATA_KEY, {})
def create_stripped_model(original_model_path: Path, stripped_model_path: Path) -> ModelOnDisk:
original = ModelOnDisk(original_model_path)
if original.path.is_file():
shutil.copy2(original.path, stripped_model_path)
else:
shutil.copytree(original.path, stripped_model_path, dirs_exist_ok=True)
stripped = ModelOnDisk(stripped_model_path)
print(f"Created clone of {original.name} at {stripped.path}")
for component_path in stripped.weight_files():
original_state_dict = stripped.load_state_dict(component_path)
stripped_state_dict = strip(original_state_dict) # type: ignore
metadata = stripped.metadata()
contents = {**stripped_state_dict, METADATA_KEY: metadata}
with open(component_path, "w") as f:
json.dump(contents, f, indent=4)
before_size = humanize.naturalsize(original.size())
after_size = humanize.naturalsize(stripped.size())
print(f"{original.name} before: {before_size}, after: {after_size}")
return stripped
def parse_arguments():
class Parser(argparse.ArgumentParser):
def error(self, reason):
raise ValueError(reason)
parser = Parser()
parser.add_argument("models_input_dir", type=Path)
parser.add_argument("stripped_output_dir", type=Path)
try:
args = parser.parse_args()
except ValueError as e:
print(f"Error: {e}", file=sys.stderr)
print(__doc__, file=sys.stderr)
sys.exit(2)
if not args.models_input_dir.exists():
parser.error(f"Error: Input models directory '{args.models_input_dir}' does not exist.")
if not args.models_input_dir.is_dir():
parser.error(f"Error: '{args.input_models_dir}' is not a directory.")
return args
if __name__ == "__main__":
args = parse_arguments()
model_paths = sorted(ModelSearch().search(args.models_input_dir))
for path in model_paths:
stripped_path = args.stripped_output_dir / path.name
create_stripped_model(path, stripped_path)

View File

@@ -32,9 +32,8 @@ from invokeai.app.services.model_install.model_install_common import (
URLModelSource,
)
from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException
from invokeai.backend.model_manager.config import (
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
InvalidModelConfigException,
ModelFormat,
ModelRepoVariant,
ModelType,
@@ -71,7 +70,7 @@ def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_fil
def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
key = None
with pytest.raises((ValidationError, InvalidModelConfigException)):
with pytest.raises(ValidationError):
key = mm2_installer.register_path(
embedding_file, ModelRecordChanges(name="banana_sushi", type=ModelType("lora"))
)

View File

@@ -16,15 +16,24 @@ from invokeai.app.services.model_records import (
UnknownModelException,
)
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelType
from invokeai.backend.model_manager.config import (
ControlAdapterDefaultSettings,
MainDiffusersConfig,
from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings
from invokeai.backend.model_manager.configs.main import (
Main_Diffusers_SD1_Config,
Main_Diffusers_SD2_Config,
Main_Diffusers_SDXL_Config,
MainModelDefaultSettings,
TI_File_Config,
VAEDiffusersConfig,
)
from invokeai.backend.model_manager.taxonomy import ModelSourceType
from invokeai.backend.model_manager.configs.textual_inversion import TI_File_SD1_Config
from invokeai.backend.model_manager.configs.vae import VAE_Diffusers_SD1_Config
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelFormat,
ModelRepoVariant,
ModelSourceType,
ModelType,
ModelVariantType,
SchedulerPredictionType,
)
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import create_mock_sqlite_database
@@ -40,8 +49,8 @@ def store(
return ModelRecordServiceSQL(db, logger)
def example_ti_config(key: Optional[str] = None) -> TI_File_Config:
config = TI_File_Config(
def example_ti_config(key: Optional[str] = None) -> TI_File_SD1_Config:
config = TI_File_SD1_Config(
source="test/source/",
source_type=ModelSourceType.Path,
path="/tmp/pokemon.bin",
@@ -61,7 +70,7 @@ def test_type(store: ModelRecordServiceBase):
config = example_ti_config("key1")
store.add_model(config)
config1 = store.get_model("key1")
assert isinstance(config1, TI_File_Config)
assert isinstance(config1, TI_File_SD1_Config)
def test_raises_on_violating_uniqueness(store: ModelRecordServiceBase):
@@ -122,7 +131,7 @@ def test_exists(store: ModelRecordServiceBase):
def test_filter(store: ModelRecordServiceBase):
config1 = MainDiffusersConfig(
config1 = Main_Diffusers_SD1_Config(
key="config1",
path="/tmp/config1",
name="config1",
@@ -132,8 +141,11 @@ def test_filter(store: ModelRecordServiceBase):
file_size=1001,
source="test/source",
source_type=ModelSourceType.Path,
variant=ModelVariantType.Normal,
prediction_type=SchedulerPredictionType.Epsilon,
repo_variant=ModelRepoVariant.Default,
)
config2 = MainDiffusersConfig(
config2 = Main_Diffusers_SD1_Config(
key="config2",
path="/tmp/config2",
name="config2",
@@ -143,17 +155,21 @@ def test_filter(store: ModelRecordServiceBase):
file_size=1002,
source="test/source",
source_type=ModelSourceType.Path,
variant=ModelVariantType.Normal,
prediction_type=SchedulerPredictionType.Epsilon,
repo_variant=ModelRepoVariant.Default,
)
config3 = VAEDiffusersConfig(
config3 = VAE_Diffusers_SD1_Config(
key="config3",
path="/tmp/config3",
name="config3",
base=BaseModelType("sd-2"),
base=BaseModelType.StableDiffusion1,
type=ModelType.VAE,
hash="CONFIG3HASH",
file_size=1003,
source="test/source",
source_type=ModelSourceType.Path,
repo_variant=ModelRepoVariant.Default,
)
for c in config1, config2, config3:
store.add_model(c)
@@ -176,7 +192,7 @@ def test_filter(store: ModelRecordServiceBase):
def test_unique(store: ModelRecordServiceBase):
config1 = MainDiffusersConfig(
config1 = Main_Diffusers_SD1_Config(
path="/tmp/config1",
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
@@ -185,28 +201,35 @@ def test_unique(store: ModelRecordServiceBase):
file_size=1004,
source="test/source/",
source_type=ModelSourceType.Path,
variant=ModelVariantType.Normal,
prediction_type=SchedulerPredictionType.Epsilon,
repo_variant=ModelRepoVariant.Default,
)
config2 = MainDiffusersConfig(
config2 = Main_Diffusers_SD2_Config(
path="/tmp/config2",
base=BaseModelType("sd-2"),
base=BaseModelType.StableDiffusion2,
type=ModelType.Main,
name="nonuniquename",
hash="CONFIG1HASH",
file_size=1005,
source="test/source/",
source_type=ModelSourceType.Path,
variant=ModelVariantType.Normal,
prediction_type=SchedulerPredictionType.Epsilon,
repo_variant=ModelRepoVariant.Default,
)
config3 = VAEDiffusersConfig(
config3 = VAE_Diffusers_SD1_Config(
path="/tmp/config3",
base=BaseModelType("sd-2"),
base=BaseModelType.StableDiffusion1,
type=ModelType.VAE,
name="nonuniquename",
hash="CONFIG1HASH",
file_size=1006,
source="test/source/",
source_type=ModelSourceType.Path,
repo_variant=ModelRepoVariant.Default,
)
config4 = MainDiffusersConfig(
config4 = Main_Diffusers_SD1_Config(
path="/tmp/config4",
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
@@ -215,6 +238,9 @@ def test_unique(store: ModelRecordServiceBase):
file_size=1007,
source="test/source/",
source_type=ModelSourceType.Path,
variant=ModelVariantType.Normal,
prediction_type=SchedulerPredictionType.Epsilon,
repo_variant=ModelRepoVariant.Default,
)
# config1, config2 and config3 are compatible because they have unique combos
# of name, type and base
@@ -229,7 +255,7 @@ def test_unique(store: ModelRecordServiceBase):
def test_filter_2(store: ModelRecordServiceBase):
config1 = MainDiffusersConfig(
config1 = Main_Diffusers_SD1_Config(
path="/tmp/config1",
name="config1",
base=BaseModelType.StableDiffusion1,
@@ -238,8 +264,11 @@ def test_filter_2(store: ModelRecordServiceBase):
file_size=1008,
source="test/source/",
source_type=ModelSourceType.Path,
variant=ModelVariantType.Normal,
prediction_type=SchedulerPredictionType.Epsilon,
repo_variant=ModelRepoVariant.Default,
)
config2 = MainDiffusersConfig(
config2 = Main_Diffusers_SD1_Config(
path="/tmp/config2",
name="config2",
base=BaseModelType.StableDiffusion1,
@@ -248,28 +277,37 @@ def test_filter_2(store: ModelRecordServiceBase):
file_size=1009,
source="test/source/",
source_type=ModelSourceType.Path,
variant=ModelVariantType.Normal,
prediction_type=SchedulerPredictionType.Epsilon,
repo_variant=ModelRepoVariant.Default,
)
config3 = MainDiffusersConfig(
config3 = Main_Diffusers_SD2_Config(
path="/tmp/config3",
name="dup_name1",
base=BaseModelType("sd-2"),
base=BaseModelType.StableDiffusion2,
type=ModelType.Main,
hash="CONFIG3HASH",
file_size=1010,
source="test/source/",
source_type=ModelSourceType.Path,
variant=ModelVariantType.Normal,
prediction_type=SchedulerPredictionType.Epsilon,
repo_variant=ModelRepoVariant.Default,
)
config4 = MainDiffusersConfig(
config4 = Main_Diffusers_SDXL_Config(
path="/tmp/config4",
name="dup_name1",
base=BaseModelType("sdxl"),
base=BaseModelType.StableDiffusionXL,
type=ModelType.Main,
hash="CONFIG3HASH",
file_size=1011,
source="test/source/",
source_type=ModelSourceType.Path,
variant=ModelVariantType.Normal,
prediction_type=SchedulerPredictionType.Epsilon,
repo_variant=ModelRepoVariant.Default,
)
config5 = VAEDiffusersConfig(
config5 = VAE_Diffusers_SD1_Config(
path="/tmp/config5",
name="dup_name1",
base=BaseModelType.StableDiffusion1,
@@ -278,6 +316,7 @@ def test_filter_2(store: ModelRecordServiceBase):
file_size=1012,
source="test/source/",
source_type=ModelSourceType.Path,
repo_variant=ModelRepoVariant.Default,
)
for c in config1, config2, config3, config4, config5:
store.add_model(c)

View File

@@ -1,7 +1,7 @@
import pytest
import torch
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
from invokeai.backend.util.test_utils import install_and_load_model

View File

@@ -14,15 +14,19 @@ from invokeai.app.services.model_install import ModelInstallService, ModelInstal
from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase
from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelType, ModelVariantType
from invokeai.backend.model_manager.config import (
LoRADiffusersConfig,
MainCheckpointConfig,
MainDiffusersConfig,
VAEDiffusersConfig,
)
from invokeai.backend.model_manager.configs.lora import LoRA_Diffusers_SD1_Config, LoRA_Diffusers_SDXL_Config
from invokeai.backend.model_manager.configs.main import Main_Checkpoint_SD1_Config, Main_Diffusers_SDXL_Config
from invokeai.backend.model_manager.configs.vae import VAE_Diffusers_SD1_Config
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.taxonomy import ModelSourceType
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelFormat,
ModelRepoVariant,
ModelSourceType,
ModelType,
ModelVariantType,
SchedulerPredictionType,
)
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager.model_metadata.metadata_examples import (
@@ -132,19 +136,20 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBas
db = create_mock_sqlite_database(mm2_app_config, logger)
store = ModelRecordServiceSQL(db, logger)
# add five simple config records to the database
config1 = VAEDiffusersConfig(
config1 = VAE_Diffusers_SD1_Config(
key="test_config_1",
path="/tmp/foo1",
format=ModelFormat.Diffusers,
name="test2",
base=BaseModelType.StableDiffusion2,
base=BaseModelType.StableDiffusion1,
type=ModelType.VAE,
hash="111222333444",
file_size=4096,
source="stabilityai/sdxl-vae",
source_type=ModelSourceType.HFRepoID,
repo_variant=ModelRepoVariant.Default,
)
config2 = MainCheckpointConfig(
config2 = Main_Checkpoint_SD1_Config(
key="test_config_2",
path="/tmp/foo2.ckpt",
name="model1",
@@ -157,8 +162,9 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBas
file_size=8192,
source="https://civitai.com/models/206883/split",
source_type=ModelSourceType.Url,
prediction_type=SchedulerPredictionType.Epsilon,
)
config3 = MainDiffusersConfig(
config3 = Main_Diffusers_SDXL_Config(
key="test_config_3",
path="/tmp/foo3",
format=ModelFormat.Diffusers,
@@ -170,8 +176,11 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBas
source="author3/model3",
description="This is test 3",
source_type=ModelSourceType.HFRepoID,
variant=ModelVariantType.Normal,
prediction_type=SchedulerPredictionType.Epsilon,
repo_variant=ModelRepoVariant.Default,
)
config4 = LoRADiffusersConfig(
config4 = LoRA_Diffusers_SDXL_Config(
key="test_config_4",
path="/tmp/foo4",
format=ModelFormat.Diffusers,
@@ -183,7 +192,7 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBas
source="author4/model4",
source_type=ModelSourceType.HFRepoID,
)
config5 = LoRADiffusersConfig(
config5 = LoRA_Diffusers_SD1_Config(
key="test_config_5",
path="/tmp/foo5",
format=ModelFormat.Diffusers,

View File

@@ -3,7 +3,7 @@ from typing import List
import pytest
from invokeai.backend.model_manager import ModelRepoVariant
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
from invokeai.backend.model_manager.util.select_hf_files import filter_files

View File

@@ -7,14 +7,9 @@
import logging
import shutil
from pathlib import Path
from types import SimpleNamespace
import picklescan.scanner
import pytest
import safetensors.torch
import torch
import invokeai.backend.quantization.gguf.loaders as gguf_loaders
from invokeai.app.services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
@@ -25,7 +20,6 @@ from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.invoker import Invoker
from invokeai.backend.util.logging import InvokeAILogger
from scripts.strip_models import load_stripped_model
from tests.backend.model_manager.model_manager_fixtures import * # noqa: F403
from tests.fixtures.sqlite_database import create_mock_sqlite_database # noqa: F401
from tests.test_nodes import TestEventService
@@ -82,23 +76,3 @@ def invokeai_root_dir(tmp_path_factory) -> Path:
temp_dir: Path = tmp_path_factory.mktemp("data") / "invokeai_root"
shutil.copytree(root_template, temp_dir)
return temp_dir
@pytest.fixture(scope="function")
def override_model_loading(monkeypatch):
"""The legacy model probe directly calls model loading functions (e.g. torch.load) and also performs file scanning
via picklescan.scanner.scan_file_path. This fixture replaces these functions with test-friendly versions for
model files that have been 'stripped' to reduce their size (see scripts/strip_models.py).
Ideally, model loading would be injected as a dependency (i.e. ModelOnDisk) - but to avoid modifying the legacy probe,
we monkeypatch as a temporary workaround until the legacy probe is fully deprecated.
"""
monkeypatch.setattr(torch, "load", load_stripped_model)
monkeypatch.setattr(safetensors.torch, "load", load_stripped_model)
monkeypatch.setattr(safetensors.torch, "load_file", load_stripped_model)
monkeypatch.setattr(gguf_loaders, "gguf_sd_loader", load_stripped_model)
def fake_scan(*args, **kwargs):
return SimpleNamespace(infected_files=0, scan_err=None)
monkeypatch.setattr(picklescan.scanner, "scan_file_path", fake_scan)

View File

@@ -0,0 +1,149 @@
# Model Probe (Identification) Testing
Invoke's model identification system is tested against example model files. Test cases are lightweight representations of real models that have been "stripped" of their tensor data.
## Setup
Test cases are stored with git lfs. You _must_ [install git lfs](https://git-lfs.com/) to pull down the test cases or to add new ones.
```bash
# Only need to do this once
git lfs install
# Pull the actual model files down - if you just do `git pull` you'll only get pointers
git lfs pull
```
## Running the Tests
To run the tests, use:
```bash
pytest -v tests/model_identification/test_identification.py
```
## Stripped Model Files
Invoke abstracts the loading of a model's state dict and metadata in a class called [`ModelOnDisk`](../../invokeai/backend/model_manager/model_on_disk.py). This class loads real model weights. We use it to inspect models and identify them.
For testing purposes, we create a stripped-down version of the model weights that contains only the model structure and per-key metadata, without the actual tensor data. The state dict structure is typically all we need to identify models; the tensors themselves are not needed. This allows us to store test cases in the repo without adding many gigabytes of data.
To see how this works, check out [`StrippedModelOnDisk`](./stripped_model_on_disk.py). This class includes logic to strip models and to load these stripped models for testing.
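For a concrete picture of what a stripped file contains, here is a minimal sketch. The tensor key, shape, and metadata values are invented for illustration; the `"fakeTensor"` marker and reserved metadata key mirror `StrippedModelOnDisk`:

```python
# A stripped weight file is a JSON document. Every tensor in the original
# state dict is replaced by a small dict recording its shape and dtype, and any
# original file-level metadata is kept under a reserved key.
stripped_file_contents = {
    "model.diffusion_model.input_blocks.0.0.weight": {  # hypothetical tensor key
        "shape": [320, 4, 3, 3],
        "dtype": "torch.float16",
        "fakeTensor": True,
    },
    # ...one entry per tensor in the original state dict...
    "metadata_key_for_stripped_models": {"format": "pt"},  # original metadata, if any
}
```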
### Some Models Cannot Be Stripped
Certain models cannot be stripped because identification relies on inspecting the actual tensor data. We have to store the full model files for these test cases.
> Currently, the only models that cannot be stripped are [`spandrel`](https://github.com/chaiNNer-org/spandrel/) image-to-image models. `spandrel` supports _many_ model architectures but doesn't provide a way to identify or assert support for a model by its state dict structure alone.
>
> To positively identify these models, we must attempt to load the model using spandrel. If it loads successfully, we assume it is a supported model. Therefore, we cannot strip these models and must store the full model files in the test cases. We only store one such model to keep the test suite size manageable.
>
> `StrippedModelOnDisk` simply passes the "live" tensor data through for these models when loading them in tests.
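As a rough illustration (this is not the actual probe code, and it assumes spandrel's `ModelLoader` API; `is_spandrel_model` is a hypothetical helper), the check for these models amounts to attempting a load and treating failure as "not a spandrel model":

```python
from pathlib import Path

from spandrel import ModelLoader  # third-party library used for image-to-image models


def is_spandrel_model(path: Path) -> bool:
    """Hypothetical helper: True if spandrel recognises the weights at `path`."""
    try:
        ModelLoader().load_from_file(str(path))
        return True
    except Exception:
        # spandrel raises when the state dict doesn't match any known architecture
        return False
```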
## Adding New Test Cases
Run the [`strip_model.py`](./strip_model.py) script to create a new test case. For example:
```bash
python strip_model.py /path/to/your/model ./stripped_models
```
It supports single-file models and multi-file models (e.g. diffusers-style models). The output will be a directory named with a UUID, containing the stripped model files and a dummy `__test_metadata__.json` file.
Example output structure for a single-file model:
```
stripped_models/
└── 19fd1a40-c5b7-4734-bd3a-6e0e948cce0b/
├── __test_metadata__.json
└── Standard Reference (XLabs FLUX IP-Adapter v2).safetensors
```
This test metadata file should contain a single JSON dict and must be filled out manually with the expected identification results.
### Structure of `__test_metadata__.json`
This file contains a single JSON dict. Here's an example for a FLUX IP Adapter checkpoint:
```json
{
"source": "https://huggingface.co/XLabs-AI/flux-ip-adapter-v2/resolve/main/ip_adapter.safetensors",
"file_name": "Standard Reference (XLabs FLUX IP-Adapter v2).safetensors",
"expected_config_attrs": {
"type": "ip_adapter",
"format": "checkpoint",
"base": "flux"
}
}
```
See the details below for each field.
#### `"source"`
A string indicating the source of the model (e.g. a Hugging Face repo ID or URL). This is not used for identification, but is useful for reference so we know where the model came from. Nothing will break if this field is missing or incorrect, but it is good practice to fill it out.
- Example HF Repo ID: `"RunDiffusion/Juggernaut-XL-v9"`
- Example URL: `"https://huggingface.co/XpucT/Deliberate/resolve/main/Deliberate_v5.safetensors"`
### `"file_name"`
If the model is a single file (e.g. a `.safetensors` file), this is the name of that file. The test suite will look for this file in the test case directory.
If the model is multi-file (e.g. diffusers-style), omit this key or set it to a falsey value like `null` or an empty string.
- Example: `"model.safetensors"`
> The `strip_model.py` script will automatically fill this field in for single-file models.
### `"expected_config_attrs"`
This field is a dict of expected configuration attributes for the model. It is required for all test cases.
It is used to verify that the model's configuration matches expectations. The keys and values in this dict depend on the specific model and its configuration.
These attributes must be included, as they are the primary discriminators for models:
- `"type"`: The type of the model. This is the value of the `ModelType` enum.
- `"format"`: The format of the model files. This is the value of the `ModelFormat` enum.
- `"base"`: The base model pipeline architecture associated with this model. Many models do not have an associated base. For these, use `"any"`. This is the value of the `BaseModelType` enum.
Depending on the kind of model, these additional keys may be useful:
- `"prediction_type"`: The prediction type used by the model. This is the value of the `SchedulerPredictionType` enum.
- `"variant"`: The variant of the model, if applicable. This is the value of the `ModelVariantType` enum.
To see all possible values for these enums, check out their definitions in [`invokeai/backend/model_manager/taxonomy.py`](../../invokeai/backend/model_manager/taxonomy.py).
For example, for an SD1.5 main (pipeline) inpainting model in diffusers format, you might have:
```json
{
"expected_config_attrs": {
"type": "main",
"format": "diffusers",
"base": "sd-1",
"prediction_type": "epsilon",
"variant": "inpaint"
}
}
```
### `"notes"`
This is an optional string field where you can add any notes or comments about the test case. It can be useful for providing context or explaining any special considerations.
### `"override_fields"`
In some rare cases, we may need to provide additional hints to the identification system to help it identify the model correctly.
Currently, the only known case where we need extra information is differentiating between single-file SD1.x, SD2.x, and SDXL VAEs. These models have identical structures, so we need to provide a hint. Though it is far from ideal, we use simple string matching on the model's name for this.
For example, when users install the `taesdxl` VAE from the HF repo `madebyollin/taesdxl`, the identification system receives the model name `taesdxl`, sees "xl" in the name, and infers that this is an SDXL VAE. To reproduce this in a test case, we add the following to `__test_metadata__.json`:
```json
{
"override_fields": {
"name": "taesdxl"
}
}
```

View File

@@ -0,0 +1,112 @@
"""
Usage:
strip_model.py <model_path> <output_dir>
Strips tensor data from model state_dict while preserving metadata.
Used to create lightweight models for testing model classification.
Parameters:
<model_path> The path to the model to be stripped.
<output_dir> Directory where stripped models will be saved (e.g. tests/model_identification/stripped_models)
Options:
-h, --help Show this help message and exit
"""
import argparse
import json
import shutil
import sys
from copy import deepcopy
from pathlib import Path
from typing import Any
import humanize
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from tests.model_identification.stripped_model_on_disk import StrippedModelOnDisk
TEST_METADATA_FILENAME = "__test_metadata__.json"
TEST_METADATA: dict[str, Any] = {
"source": "",
"file_name": "",
"override_fields": {},
"expected_config_attrs": {},
"notes": "",
}
def create_stripped_model(model_path: Path, output_dir: Path):
"""Creates a stripped version of the model at model_path in output_dir. A test metadata file is also created."""
original_mod = ModelOnDisk(model_path)
# The stripped model will be stored in a new directory named with a UUID. This mirrors the application's
# normalized model storage file structure.
uuid = uuid_string()
stripped_model_dir = output_dir / uuid
stripped_model_dir.mkdir(parents=True, exist_ok=True)
test_metadata_content = deepcopy(TEST_METADATA)
if original_mod.path.is_file():
shutil.copy2(original_mod.path, stripped_model_dir / original_mod.path.name)
test_metadata_content["file_name"] = original_mod.path.name
else:
shutil.copytree(original_mod.path, stripped_model_dir, dirs_exist_ok=True)
stripped_mod = ModelOnDisk(stripped_model_dir)
print(f"Created clone of {original_mod.name} at {stripped_mod.path}")
for component_path in stripped_mod.weight_files():
original_state_dict = stripped_mod.load_state_dict(component_path)
stripped_state_dict = StrippedModelOnDisk.strip(original_state_dict)
metadata = stripped_mod.metadata()
contents = {**stripped_state_dict, StrippedModelOnDisk.METADATA_KEY: metadata}
component_path.write_text(json.dumps(contents, indent=2))
test_metadata_path = stripped_model_dir / TEST_METADATA_FILENAME
test_metadata_path.write_text(json.dumps(test_metadata_content, indent=2))
before_size = humanize.naturalsize(original_mod.size())
after_size = humanize.naturalsize(stripped_mod.size())
print(f"{original_mod.name} before: {before_size}, after: {after_size}")
return stripped_mod
def parse_arguments():
class Parser(argparse.ArgumentParser):
def error(self, message: str):
raise ValueError(message)
parser = Parser()
parser.add_argument("model_path", type=Path)
parser.add_argument("output_dir", type=Path)
try:
args = parser.parse_args()
except ValueError as e:
print(f"Error: {e}", file=sys.stderr)
print(__doc__, file=sys.stderr)
sys.exit(2)
if not args.model_path.exists():
parser.error(f"Error: Input model path '{args.model_path}' does not exist.")
return args
if __name__ == "__main__":
args = parse_arguments()
model_path = Path(args.model_path)
output_dir = Path(args.output_dir)
create_stripped_model(model_path, output_dir)
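If several test cases need to be generated in one go, a small driver loop in the spirit of the removed `scripts/strip_models.py` might look like the following sketch. The `tests.model_identification.strip_model` import path and the directory paths are assumptions; `ModelSearch` is the same helper the old script used:

```python
from pathlib import Path

from invokeai.backend.model_manager.search import ModelSearch

# Assumed module path for the script shown above.
from tests.model_identification.strip_model import create_stripped_model

models_dir = Path("/path/to/downloaded/models")  # directory of full-size models
output_dir = Path("tests/model_identification/stripped_models")

for model_path in sorted(ModelSearch().search(models_dir)):
    # Each call creates a UUID-named directory containing the stripped weights
    # plus a __test_metadata__.json skeleton that must be filled in by hand.
    create_stripped_model(model_path, output_dir)
```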

View File

@@ -0,0 +1,83 @@
import json
from pathlib import Path
from typing import Any, Optional
import gguf
import torch
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk, StateDict
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
class StrippedModelOnDisk(ModelOnDisk):
METADATA_KEY = "metadata_key_for_stripped_models"
STR_TO_DTYPE = {str(dtype): dtype for dtype in torch.__dict__.values() if isinstance(dtype, torch.dtype)}
def load_state_dict(self, path: Optional[Path] = None) -> StateDict:
path = self.resolve_weight_file(path)
return self.load_stripped_model(path)
def metadata(self, path: Optional[Path] = None) -> dict[str, str]:
path = self.resolve_weight_file(path)
with open(path, "r") as f:
contents = json.load(f)
return contents.get(self.METADATA_KEY, {})
@classmethod
def strip(cls, v: Any):
match v:
case GGMLTensor():
# GGMLTensor needs special handling to preserve quantization metadata. It is a subclass of torch.Tensor,
# so we need to check for it before checking for torch.Tensor.
return {
"quantized_data": cls.strip(v.quantized_data),
"ggml_quantization_type": v._ggml_quantization_type.name,
"tensor_shape": list(v.tensor_shape),
"compute_dtype": str(v.compute_dtype),
"fakeGGMLTensor": True,
}
case torch.Tensor():
return {"shape": v.shape, "dtype": str(v.dtype), "fakeTensor": True}
case dict():
return {k: cls.strip(v) for k, v in v.items()}
case list() | tuple():
return [cls.strip(x) for x in v]
case _:
return v
@classmethod
def dress(cls, v: Any):
match v:
case {
"quantized_data": quantized_data,
"ggml_quantization_type": qtype_name,
"tensor_shape": tensor_shape,
"compute_dtype": compute_dtype_str,
"fakeGGMLTensor": True,
}:
# Reconstruct the GGMLTensor from stripped data
qtype = gguf.GGMLQuantizationType[qtype_name]
compute_dtype = cls.STR_TO_DTYPE[compute_dtype_str]
dressed_quantized_data = cls.dress(quantized_data)
return GGMLTensor(
data=dressed_quantized_data,
ggml_quantization_type=qtype,
tensor_shape=torch.Size(tensor_shape),
compute_dtype=compute_dtype,
)
case {"shape": shape, "dtype": dtype_str, "fakeTensor": True}:
dtype = cls.STR_TO_DTYPE[dtype_str]
return torch.empty(shape, dtype=dtype)
case dict():
return {k: cls.dress(v) for k, v in v.items()}
case list() | tuple():
return [cls.dress(x) for x in v]
case _:
return v
@classmethod
def load_stripped_model(cls, path: Path, *args, **kwargs):
with open(path, "r") as f:
contents = json.load(f)
contents.pop(cls.METADATA_KEY, None)
return cls.dress(contents)
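A quick way to sanity-check the strip/dress round trip (a standalone sketch, not part of the test suite; the tensor names and shapes are invented, and the import path is assumed):

```python
import torch

from tests.model_identification.stripped_model_on_disk import StrippedModelOnDisk

# Tiny fake state dict standing in for real model weights.
state_dict = {
    "encoder.weight": torch.randn(8, 4),
    "encoder.bias": torch.zeros(8, dtype=torch.float16),
}

stripped = StrippedModelOnDisk.strip(state_dict)  # tensors -> {"shape", "dtype", "fakeTensor"} dicts
restored = StrippedModelOnDisk.dress(stripped)    # dicts -> empty tensors with matching shape/dtype

for key, original in state_dict.items():
    assert restored[key].shape == original.shape
    assert restored[key].dtype == original.dtype
```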

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b9203265975f0a76fce9435a52d7c406bf0a78c33aace1d790aaaee82145f108
size 160

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b5b8f9a7619c9452f19407ef73a01ce3f061fcf932de43ddbaca336d12d02801
size 1235

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:764b113f9978f30219a355c559e198aa2e400ad93ee42e90b8b18e5c45d620e2
size 137262

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:764b113f9978f30219a355c559e198aa2e400ad93ee42e90b8b18e5c45d620e2
size 137262

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f406592ff8542bc6c441b1b9ce7b7871969d4934b71f197d356c542bfc6100dc
size 421

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5cc69c42f1579faa712d66cc7864d2dd49b9cb75b032182a446295f290f71ac1
size 170

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a8b619cf6d992fccde14c00af034a7d58516642acb6dc9b6ea2c4f233a026312
size 1230

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d713facffb33213a0893e7386c29c1b088361713beff027a21d8bbfd1a732fed
size 137262

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8cc85c48a1a03c570280424807982b57d8a93dbb341ecf9a8751b9a49490d865
size 165

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:190f74eab53d6605b4587e8635930602341e6a0ba10a2b2cae6f44c93fad7e40
size 251

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4380556fba7a05d8095ab7373d2cf054b3dc6f637d69760ff3259dd912a1c309
size 161

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:43baf829cca18f27b6c561e422965c05ad81fe032e381eb9d2ab87bf679eaad9
size 79623

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:14992635efcb76a20a07d90dc79faaedabf16113c82ecf29104436af0bf184af
size 13469

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6ca25466bbe5b289a9a77c43c12b607775d204e23c2691595f9afbe7d4fefa74
size 288

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5a5f1a97f5f940688c2249cf8cd23bc38e2cbe20243c10b39e2611883f2835ab
size 161

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:95618921255d678347fa7ba4490afeef084a27a49c42c14cf99cde29306eb74d
size 52258

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c54537e57e94fa4493a0da78af4e0d37d1f3ee6f5216605185e5943f666f48a9
size 162

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b92e42e1f42918321c1a927d5ae7a128b16e95e0570961fb7af69c10472c761c
size 999

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:95618921255d678347fa7ba4490afeef084a27a49c42c14cf99cde29306eb74d
size 52258

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f31e6051478d14a1cc49c6a62c1c5f3735761ba8ec86e6833e634bb81703bf4d
size 160

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5cbf68b91bd3065716b88b9ef85785023fef6f59ffae2d5762b05553cef9450d
size 1274

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:764b113f9978f30219a355c559e198aa2e400ad93ee42e90b8b18e5c45d620e2
size 137262

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5972b48f1ee690922e944b4cc58a8cfe8fff5548febd67641fc468c5de678106
size 193

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7b7055054103676636b98eb87e6f712fbfa1d32fe61d2839ac91e6dd8efdf569
size 160

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b5b8f9a7619c9452f19407ef73a01ce3f061fcf932de43ddbaca336d12d02801
size 1235

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:764b113f9978f30219a355c559e198aa2e400ad93ee42e90b8b18e5c45d620e2
size 137262

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:30c2bc3cc3a01876bb237332ce046c3914d59200ecc2d266e1c7b3bb8c03e6a1
size 108625

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d99c6db720916e6fdc756c1270f6547f755386901e96cb989b252656e9459be8
size 240

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f08ce127136cb04fea2ace436533a8117b1aaefc3c7d86740336ffc35e8028d8
size 163

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1fe6a5da6d1c39714934544613352f02fe755bf3ee07c726521a8d50e007db95
size 1988

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7f3cddb8d6c5c36ef1a4248c600c8cf72f70ced509d087229a4e15b3e947d4ae
size 118855

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e3b3402f5d83838bfddc533d14ab0226c0654345b3c21b2b58de9184da4e5829
size 159

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:95618921255d678347fa7ba4490afeef084a27a49c42c14cf99cde29306eb74d
size 52258

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:836d3065e0ddf91ac744a46c489d01701c413b2bf2cf56ac57537b65d278578f
size 156

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6b28d56fdb9a2048a6ae38bf6afcc781d555300f2bb537c4b4c4a4a15412fe78
size 89593

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:587cdb874c248b9ee9ff5c2e258d717cbec9ddc088eb5984a1332f947c774758
size 157

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:95618921255d678347fa7ba4490afeef084a27a49c42c14cf99cde29306eb74d
size 52258

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:75d154190528bace5b3b8f5e9300964edf4c7d991ae01cf55cd7248ac5358f03
size 160

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c6488d01fc339c83d777b6d71df4f5bc36c035b2c977f987e8a5c0a54c79b706
size 955

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d89c76fb682a4c29a668c84d33ac2503f7f7ed4219aa2f70e4fc6779962e2303
size 52258

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9f639b0732cb9544db020c515b163e2c9f65995f9fcb1efff441c5b9f5337062
size 503787

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ec1a9b5b5f451306aa2660dec27db9e66c5c2b12cd9506dca4cdfe961f6bae46
size 253

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bdcc653ce6538878e8e4b96552a17048c95387a210431deecfd531a8ef7822df
size 207

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1664c0837272bb24919c0a5b6b9fb6970eed2fbfc21d58f87a4ca126f1732023
size 29272

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e4d60954dc9025adad03bc8512700adb50b0726f71971902e2d48b81a93a7ce6
size 78037

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8572891666dca7988bc7cdab5ea18a206c7d6ab9e84d619da19d0babdff35777
size 273411

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:57f04fab20ea8867d504056b1c04f807247f0760c59a5ee4ae733977a5797416
size 36362

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6109f6f70e72601468fde241c3b20e37f35a468c4271352c5837fb0fbf8f68b3
size 33769

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a562c7040fd26f65ba8a6d6803595c32e3a4f6a547f0cbc6e4297bc4b24e3f3e
size 215

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e6289a8c39add7a686eb605c0f1c55f2b630ac2098d7ddbf75aa7e39fe169803
size 162

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aa6fecc8745c0d0b26f0d2ca0f686600c8b33177c667d44eb41a51e56e44a697
size 999

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:95618921255d678347fa7ba4490afeef084a27a49c42c14cf99cde29306eb74d
size 52258

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a388e94405bf905efe5ba92d0b144db6e44d7becadc670c05e6e09d922209071
size 197

Some files were not shown because too many files have changed in this diff.