tests(mm): fix remaining MM tests

This commit is contained in:
psychedelicious
2025-10-10 12:03:07 +11:00
parent adc332b9e3
commit e5935a39e4
3 changed files with 38 additions and 15 deletions

View File

@@ -8,7 +8,6 @@ from pathlib import Path
from typing import Any, Dict
import pytest
from pydantic import ValidationError
from pydantic_core import Url
from invokeai.app.services.config import InvokeAIAppConfig
@@ -27,6 +26,7 @@ from invokeai.app.services.model_install import (
)
from invokeai.app.services.model_install.model_install_common import (
InstallStatus,
InvalidModelConfigException,
LocalModelSource,
ModelInstallJob,
URLModelSource,
@@ -63,18 +63,14 @@ def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_fil
assert Path(model_record.path) == embedding_file
assert Path(model_record.path).exists()
assert model_record.base == BaseModelType("sd-1")
assert model_record.description is not None
assert model_record.description is None
assert model_record.source is not None
assert Path(model_record.source) == embedding_file
def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
key = None
with pytest.raises(ValidationError):
key = mm2_installer.register_path(
embedding_file, ModelRecordChanges(name="banana_sushi", type=ModelType("lora"))
)
assert key is None
with pytest.raises(InvalidModelConfigException):
mm2_installer.register_path(embedding_file, ModelRecordChanges(name="banana_sushi", type=ModelType("lora")))
def test_registration_meta_override_succeed(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
@@ -106,7 +102,11 @@ def test_rename(
key = mm2_installer.install_path(embedding_file)
model_record = store.get_model(key)
assert model_record.path.endswith(f"{key}/test_embedding.safetensors")
new_model_record = store.update_model(key, ModelRecordChanges(name="new model name", base=BaseModelType("sd-2")))
new_model_record = store.update_model(
key,
ModelRecordChanges(name="new model name", base=BaseModelType.StableDiffusion2),
allow_class_change=True,
)
# Renaming the model record shouldn't rename the file
assert new_model_record.name == "new model name"
assert model_record.path.endswith(f"{key}/test_embedding.safetensors")

View File

@@ -17,6 +17,7 @@ from invokeai.app.services.model_records import (
)
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings
from invokeai.backend.model_manager.configs.lora import LoRA_LyCORIS_SDXL_Config
from invokeai.backend.model_manager.configs.main import (
Main_Diffusers_SD1_Config,
Main_Diffusers_SD2_Config,
@@ -97,7 +98,29 @@ def test_model_records_updates_model(store: ModelRecordServiceBase):
assert new_config.name == new_name
def test_model_records_rejects_invalid_changes(store: ModelRecordServiceBase):
def test_model_records_updates_model_class(store: ModelRecordServiceBase):
config = example_ti_config("key1")
store.add_model(config)
changes = ModelRecordChanges(
type=ModelType.LoRA,
format=ModelFormat.LyCORIS,
base=BaseModelType.StableDiffusionXL,
)
new_config = store.update_model(config.key, changes, allow_class_change=True)
assert isinstance(new_config, LoRA_LyCORIS_SDXL_Config)
def test_model_records_rejects_invalid_attr_changes(store: ModelRecordServiceBase):
config = example_ti_config("key1")
store.add_model(config)
config = store.get_model("key1")
# upcast_attention is an invalid field for TIs
changes = ModelRecordChanges(upcast_attention=True)
with pytest.raises(ValidationError):
store.update_model(config.key, changes)
def test_model_records_rejects_invalid_attr_changes_that_change_class(store: ModelRecordServiceBase):
config = example_ti_config("key1")
store.add_model(config)
config = store.get_model("key1")
@@ -191,7 +214,7 @@ def test_filter(store: ModelRecordServiceBase):
assert len(matches) == 3
def test_unique(store: ModelRecordServiceBase):
def test_unique_by_path(store: ModelRecordServiceBase):
config1 = Main_Diffusers_SD1_Config(
path="/tmp/config1",
base=BaseModelType.StableDiffusion1,
@@ -230,7 +253,7 @@ def test_unique(store: ModelRecordServiceBase):
repo_variant=ModelRepoVariant.Default,
)
config4 = Main_Diffusers_SD1_Config(
path="/tmp/config4",
path="/tmp/config1",
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
name="nonuniquename",
@@ -242,13 +265,13 @@ def test_unique(store: ModelRecordServiceBase):
prediction_type=SchedulerPredictionType.Epsilon,
repo_variant=ModelRepoVariant.Default,
)
# config1, config2 and config3 are compatible because they have unique combos
# config1, config2 and config3 are compatible because they have unique paths
# of name, type and base
for c in config1, config2, config3:
c.key = sha256(c.path.encode("utf-8")).hexdigest()
store.add_model(c)
# config4 clashes with config1 and should raise an integrity error
# config4 clashes with config1 (same path) and should raise an integrity error
with pytest.raises(DuplicateModelException):
config4.key = sha256(config4.path.encode("utf-8")).hexdigest()
store.add_model(config4)

View File

@@ -77,7 +77,7 @@ def diffusers_dir(mm2_model_files: Path) -> Path:
@pytest.fixture
def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
app_config = InvokeAIAppConfig(models_dir=mm2_root_dir / "models", log_level="info")
app_config = InvokeAIAppConfig(models_dir=mm2_root_dir / "models", log_level="info", allow_unknown_models=False)
app_config._root = mm2_root_dir
return app_config