mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
tests(mm): fix remaining MM tests
This commit is contained in:
@@ -8,7 +8,6 @@ from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from pydantic_core import Url
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
@@ -27,6 +26,7 @@ from invokeai.app.services.model_install import (
|
||||
)
|
||||
from invokeai.app.services.model_install.model_install_common import (
|
||||
InstallStatus,
|
||||
InvalidModelConfigException,
|
||||
LocalModelSource,
|
||||
ModelInstallJob,
|
||||
URLModelSource,
|
||||
@@ -63,18 +63,14 @@ def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_fil
|
||||
assert Path(model_record.path) == embedding_file
|
||||
assert Path(model_record.path).exists()
|
||||
assert model_record.base == BaseModelType("sd-1")
|
||||
assert model_record.description is not None
|
||||
assert model_record.description is None
|
||||
assert model_record.source is not None
|
||||
assert Path(model_record.source) == embedding_file
|
||||
|
||||
|
||||
def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
|
||||
key = None
|
||||
with pytest.raises(ValidationError):
|
||||
key = mm2_installer.register_path(
|
||||
embedding_file, ModelRecordChanges(name="banana_sushi", type=ModelType("lora"))
|
||||
)
|
||||
assert key is None
|
||||
with pytest.raises(InvalidModelConfigException):
|
||||
mm2_installer.register_path(embedding_file, ModelRecordChanges(name="banana_sushi", type=ModelType("lora")))
|
||||
|
||||
|
||||
def test_registration_meta_override_succeed(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
|
||||
@@ -106,7 +102,11 @@ def test_rename(
|
||||
key = mm2_installer.install_path(embedding_file)
|
||||
model_record = store.get_model(key)
|
||||
assert model_record.path.endswith(f"{key}/test_embedding.safetensors")
|
||||
new_model_record = store.update_model(key, ModelRecordChanges(name="new model name", base=BaseModelType("sd-2")))
|
||||
new_model_record = store.update_model(
|
||||
key,
|
||||
ModelRecordChanges(name="new model name", base=BaseModelType.StableDiffusion2),
|
||||
allow_class_change=True,
|
||||
)
|
||||
# Renaming the model record shouldn't rename the file
|
||||
assert new_model_record.name == "new model name"
|
||||
assert model_record.path.endswith(f"{key}/test_embedding.safetensors")
|
||||
|
||||
@@ -17,6 +17,7 @@ from invokeai.app.services.model_records import (
|
||||
)
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||
from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings
|
||||
from invokeai.backend.model_manager.configs.lora import LoRA_LyCORIS_SDXL_Config
|
||||
from invokeai.backend.model_manager.configs.main import (
|
||||
Main_Diffusers_SD1_Config,
|
||||
Main_Diffusers_SD2_Config,
|
||||
@@ -97,7 +98,29 @@ def test_model_records_updates_model(store: ModelRecordServiceBase):
|
||||
assert new_config.name == new_name
|
||||
|
||||
|
||||
def test_model_records_rejects_invalid_changes(store: ModelRecordServiceBase):
|
||||
def test_model_records_updates_model_class(store: ModelRecordServiceBase):
|
||||
config = example_ti_config("key1")
|
||||
store.add_model(config)
|
||||
changes = ModelRecordChanges(
|
||||
type=ModelType.LoRA,
|
||||
format=ModelFormat.LyCORIS,
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
)
|
||||
new_config = store.update_model(config.key, changes, allow_class_change=True)
|
||||
assert isinstance(new_config, LoRA_LyCORIS_SDXL_Config)
|
||||
|
||||
|
||||
def test_model_records_rejects_invalid_attr_changes(store: ModelRecordServiceBase):
|
||||
config = example_ti_config("key1")
|
||||
store.add_model(config)
|
||||
config = store.get_model("key1")
|
||||
# upcast_attention is an invalid field for TIs
|
||||
changes = ModelRecordChanges(upcast_attention=True)
|
||||
with pytest.raises(ValidationError):
|
||||
store.update_model(config.key, changes)
|
||||
|
||||
|
||||
def test_model_records_rejects_invalid_attr_changes_that_change_class(store: ModelRecordServiceBase):
|
||||
config = example_ti_config("key1")
|
||||
store.add_model(config)
|
||||
config = store.get_model("key1")
|
||||
@@ -191,7 +214,7 @@ def test_filter(store: ModelRecordServiceBase):
|
||||
assert len(matches) == 3
|
||||
|
||||
|
||||
def test_unique(store: ModelRecordServiceBase):
|
||||
def test_unique_by_path(store: ModelRecordServiceBase):
|
||||
config1 = Main_Diffusers_SD1_Config(
|
||||
path="/tmp/config1",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
@@ -230,7 +253,7 @@ def test_unique(store: ModelRecordServiceBase):
|
||||
repo_variant=ModelRepoVariant.Default,
|
||||
)
|
||||
config4 = Main_Diffusers_SD1_Config(
|
||||
path="/tmp/config4",
|
||||
path="/tmp/config1",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.Main,
|
||||
name="nonuniquename",
|
||||
@@ -242,13 +265,13 @@ def test_unique(store: ModelRecordServiceBase):
|
||||
prediction_type=SchedulerPredictionType.Epsilon,
|
||||
repo_variant=ModelRepoVariant.Default,
|
||||
)
|
||||
# config1, config2 and config3 are compatible because they have unique combos
|
||||
# config1, config2 and config3 are compatible because they have unique paths
|
||||
# of name, type and base
|
||||
for c in config1, config2, config3:
|
||||
c.key = sha256(c.path.encode("utf-8")).hexdigest()
|
||||
store.add_model(c)
|
||||
|
||||
# config4 clashes with config1 and should raise an integrity error
|
||||
# config4 clashes with config1 (same path) and should raise an integrity error
|
||||
with pytest.raises(DuplicateModelException):
|
||||
config4.key = sha256(config4.path.encode("utf-8")).hexdigest()
|
||||
store.add_model(config4)
|
||||
|
||||
@@ -77,7 +77,7 @@ def diffusers_dir(mm2_model_files: Path) -> Path:
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
|
||||
app_config = InvokeAIAppConfig(models_dir=mm2_root_dir / "models", log_level="info")
|
||||
app_config = InvokeAIAppConfig(models_dir=mm2_root_dir / "models", log_level="info", allow_unknown_models=False)
|
||||
app_config._root = mm2_root_dir
|
||||
return app_config
|
||||
|
||||
|
||||
Reference in New Issue
Block a user