Files
InvokeAI/tests/app/services/model_install/test_model_install.py
psychedelicious 417db71471 feat(db): decouple SqliteDatabase from config object
- Simplify init args to path (None means use memory), logger, and verbose
- Add docstrings to SqliteDatabase (it had almost none)
- Update all usages of the class
2023-12-12 10:30:37 +11:00

213 lines
7.9 KiB
Python

"""
Test the model installer
"""
from pathlib import Path
from typing import Any, Dict, List
import pytest
from pydantic import BaseModel, ValidationError
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.model_install import (
InstallStatus,
LocalModelSource,
ModelInstallJob,
ModelInstallService,
ModelInstallServiceBase,
)
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import migration_2
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SQLiteMigrator
from invokeai.backend.model_manager.config import BaseModelType, ModelType
from invokeai.backend.util.logging import InvokeAILogger
@pytest.fixture
def test_file(datadir: Path) -> Path:
return datadir / "test_embedding.safetensors"
@pytest.fixture
def app_config(datadir: Path) -> InvokeAIAppConfig:
return InvokeAIAppConfig(
root=datadir / "root",
models_dir=datadir / "root/models",
)
@pytest.fixture
def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
logger = InvokeAILogger.get_logger(config=app_config)
db_path = None if app_config.use_memory_db else app_config.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=app_config.log_sql)
migrator = SQLiteMigrator(
db_path=db.db_path,
conn=db.conn,
lock=db.lock,
logger=logger,
log_sql=app_config.log_sql,
)
migrator.register_migration(migration_1)
migrator.register_migration(migration_2)
migrator.run_migrations()
# this test uses a file database, so we need to reinitialize it after migrations
db.reinitialize()
store: ModelRecordServiceBase = ModelRecordServiceSQL(db)
return store
@pytest.fixture
def installer(app_config: InvokeAIAppConfig, store: ModelRecordServiceBase) -> ModelInstallServiceBase:
return ModelInstallService(
app_config=app_config,
record_store=store,
event_bus=DummyEventService(),
)
class DummyEvent(BaseModel):
"""Dummy Event to use with Dummy Event service."""
event_name: str
payload: Dict[str, Any]
class DummyEventService(EventServiceBase):
"""Dummy event service for testing."""
events: List[DummyEvent]
def __init__(self) -> None:
super().__init__()
self.events = []
def dispatch(self, event_name: str, payload: Any) -> None:
"""Dispatch an event by appending it to self.events."""
self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"]))
def test_registration(installer: ModelInstallServiceBase, test_file: Path) -> None:
store = installer.record_store
matches = store.search_by_attr(model_name="test_embedding")
assert len(matches) == 0
key = installer.register_path(test_file)
assert key is not None
assert len(key) == 32
def test_registration_meta(installer: ModelInstallServiceBase, test_file: Path) -> None:
store = installer.record_store
key = installer.register_path(test_file)
model_record = store.get_model(key)
assert model_record is not None
assert model_record.name == "test_embedding"
assert model_record.type == ModelType.TextualInversion
assert Path(model_record.path) == test_file
assert model_record.base == BaseModelType("sd-1")
assert model_record.description is not None
assert model_record.source is not None
assert Path(model_record.source) == test_file
def test_registration_meta_override_fail(installer: ModelInstallServiceBase, test_file: Path) -> None:
key = None
with pytest.raises(ValidationError):
key = installer.register_path(test_file, {"name": "banana_sushi", "type": ModelType("lora")})
assert key is None
def test_registration_meta_override_succeed(installer: ModelInstallServiceBase, test_file: Path) -> None:
store = installer.record_store
key = installer.register_path(
test_file, {"name": "banana_sushi", "source": "fake/repo_id", "current_hash": "New Hash"}
)
model_record = store.get_model(key)
assert model_record.name == "banana_sushi"
assert model_record.source == "fake/repo_id"
assert model_record.current_hash == "New Hash"
def test_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig) -> None:
store = installer.record_store
key = installer.install_path(test_file)
model_record = store.get_model(key)
assert model_record.path == "sd-1/embedding/test_embedding.safetensors"
assert model_record.source == test_file.as_posix()
def test_background_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig) -> None:
"""Note: may want to break this down into several smaller unit tests."""
path = test_file
description = "Test of metadata assignment"
source = LocalModelSource(path=path, inplace=False)
job = installer.import_model(source, config={"description": description})
assert job is not None
assert isinstance(job, ModelInstallJob)
# See if job is registered properly
assert job in installer.get_job(source)
# test that the job object tracked installation correctly
jobs = installer.wait_for_installs()
assert len(jobs) > 0
my_job = [x for x in jobs if x.source == source]
assert len(my_job) == 1
assert my_job[0].status == InstallStatus.COMPLETED
# test that the expected events were issued
bus = installer.event_bus
assert bus is not None # sigh - ruff is a stickler for type checking
assert isinstance(bus, DummyEventService)
assert len(bus.events) == 2
event_names = [x.event_name for x in bus.events]
assert "model_install_started" in event_names
assert "model_install_completed" in event_names
assert Path(bus.events[0].payload["source"]) == source
assert Path(bus.events[1].payload["source"]) == source
key = bus.events[1].payload["key"]
assert key is not None
# see if the thing actually got installed at the expected location
model_record = installer.record_store.get_model(key)
assert model_record is not None
assert model_record.path == "sd-1/embedding/test_embedding.safetensors"
assert Path(app_config.models_dir / model_record.path).exists()
# see if metadata was properly passed through
assert model_record.description == description
# see if prune works properly
installer.prune_jobs()
assert not installer.get_job(source)
def test_delete_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig):
store = installer.record_store
key = installer.install_path(test_file)
model_record = store.get_model(key)
assert Path(app_config.models_dir / model_record.path).exists()
assert test_file.exists() # original should still be there after installation
installer.delete(key)
assert not Path(
app_config.models_dir / model_record.path
).exists() # after deletion, installed copy should not exist
assert test_file.exists() # but original should still be there
with pytest.raises(UnknownModelException):
store.get_model(key)
def test_delete_register(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig):
store = installer.record_store
key = installer.register_path(test_file)
model_record = store.get_model(key)
assert Path(app_config.models_dir / model_record.path).exists()
assert test_file.exists() # original should still be there after installation
installer.delete(key)
assert Path(app_config.models_dir / model_record.path).exists()
with pytest.raises(UnknownModelException):
store.get_model(key)