diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index f1b24aefd2..af835561d3 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -8,7 +8,6 @@ from pathlib import Path from typing import Any, Dict import pytest -from pydantic import ValidationError from pydantic_core import Url from invokeai.app.services.config import InvokeAIAppConfig @@ -27,6 +26,7 @@ from invokeai.app.services.model_install import ( ) from invokeai.app.services.model_install.model_install_common import ( InstallStatus, + InvalidModelConfigException, LocalModelSource, ModelInstallJob, URLModelSource, @@ -63,18 +63,14 @@ def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_fil assert Path(model_record.path) == embedding_file assert Path(model_record.path).exists() assert model_record.base == BaseModelType("sd-1") - assert model_record.description is not None + assert model_record.description is None assert model_record.source is not None assert Path(model_record.source) == embedding_file def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None: - key = None - with pytest.raises(ValidationError): - key = mm2_installer.register_path( - embedding_file, ModelRecordChanges(name="banana_sushi", type=ModelType("lora")) - ) - assert key is None + with pytest.raises(InvalidModelConfigException): + mm2_installer.register_path(embedding_file, ModelRecordChanges(name="banana_sushi", type=ModelType("lora"))) def test_registration_meta_override_succeed(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None: @@ -106,7 +102,11 @@ def test_rename( key = mm2_installer.install_path(embedding_file) model_record = store.get_model(key) assert model_record.path.endswith(f"{key}/test_embedding.safetensors") - new_model_record = store.update_model(key, ModelRecordChanges(name="new model name", base=BaseModelType("sd-2"))) + new_model_record = store.update_model( + key, + ModelRecordChanges(name="new model name", base=BaseModelType.StableDiffusion2), + allow_class_change=True, + ) # Renaming the model record shouldn't rename the file assert new_model_record.name == "new model name" assert model_record.path.endswith(f"{key}/test_embedding.safetensors") diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py index df124b8115..2b6c54d5b0 100644 --- a/tests/app/services/model_records/test_model_records_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -17,6 +17,7 @@ from invokeai.app.services.model_records import ( ) from invokeai.app.services.model_records.model_records_base import ModelRecordChanges from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings +from invokeai.backend.model_manager.configs.lora import LoRA_LyCORIS_SDXL_Config from invokeai.backend.model_manager.configs.main import ( Main_Diffusers_SD1_Config, Main_Diffusers_SD2_Config, @@ -97,7 +98,29 @@ def test_model_records_updates_model(store: ModelRecordServiceBase): assert new_config.name == new_name -def test_model_records_rejects_invalid_changes(store: ModelRecordServiceBase): +def test_model_records_updates_model_class(store: ModelRecordServiceBase): + config = example_ti_config("key1") + store.add_model(config) + changes = ModelRecordChanges( + type=ModelType.LoRA, + format=ModelFormat.LyCORIS, + base=BaseModelType.StableDiffusionXL, + ) + new_config = store.update_model(config.key, changes, allow_class_change=True) + assert isinstance(new_config, LoRA_LyCORIS_SDXL_Config) + + +def test_model_records_rejects_invalid_attr_changes(store: ModelRecordServiceBase): + config = example_ti_config("key1") + store.add_model(config) + config = store.get_model("key1") + # upcast_attention is an invalid field for TIs + changes = ModelRecordChanges(upcast_attention=True) + with pytest.raises(ValidationError): + store.update_model(config.key, changes) + + +def test_model_records_rejects_invalid_attr_changes_that_change_class(store: ModelRecordServiceBase): config = example_ti_config("key1") store.add_model(config) config = store.get_model("key1") @@ -191,7 +214,7 @@ def test_filter(store: ModelRecordServiceBase): assert len(matches) == 3 -def test_unique(store: ModelRecordServiceBase): +def test_unique_by_path(store: ModelRecordServiceBase): config1 = Main_Diffusers_SD1_Config( path="/tmp/config1", base=BaseModelType.StableDiffusion1, @@ -230,7 +253,7 @@ def test_unique(store: ModelRecordServiceBase): repo_variant=ModelRepoVariant.Default, ) config4 = Main_Diffusers_SD1_Config( - path="/tmp/config4", + path="/tmp/config1", base=BaseModelType.StableDiffusion1, type=ModelType.Main, name="nonuniquename", @@ -242,13 +265,13 @@ def test_unique(store: ModelRecordServiceBase): prediction_type=SchedulerPredictionType.Epsilon, repo_variant=ModelRepoVariant.Default, ) - # config1, config2 and config3 are compatible because they have unique combos + # config1, config2 and config3 are compatible because they have unique paths # of name, type and base for c in config1, config2, config3: c.key = sha256(c.path.encode("utf-8")).hexdigest() store.add_model(c) - # config4 clashes with config1 and should raise an integrity error + # config4 clashes with config1 (same path) and should raise an integrity error with pytest.raises(DuplicateModelException): config4.key = sha256(config4.path.encode("utf-8")).hexdigest() store.add_model(config4) diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py index 6ca4b77f0a..9ab4570dde 100644 --- a/tests/backend/model_manager/model_manager_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -77,7 +77,7 @@ def diffusers_dir(mm2_model_files: Path) -> Path: @pytest.fixture def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig: - app_config = InvokeAIAppConfig(models_dir=mm2_root_dir / "models", log_level="info") + app_config = InvokeAIAppConfig(models_dir=mm2_root_dir / "models", log_level="info", allow_unknown_models=False) app_config._root = mm2_root_dir return app_config