mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
resolve merge conflicts
This commit is contained in:
@@ -2,14 +2,18 @@
|
||||
|
||||
import re
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator, Optional
|
||||
|
||||
import pytest
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
from requests_testadapter import TestAdapter, TestSession
|
||||
from requests_testadapter import TestAdapter
|
||||
|
||||
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.services.config.config_default import URLRegexTokenPair
|
||||
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService, MultiFileDownloadJob
|
||||
from invokeai.app.services.events.events_common import (
|
||||
DownloadCancelledEvent,
|
||||
DownloadCompleteEvent,
|
||||
@@ -17,56 +21,23 @@ from invokeai.app.services.events.events_common import (
|
||||
DownloadProgressEvent,
|
||||
DownloadStartedEvent,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, ModelMetadataWithFiles, RemoteModelFile
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
from tests.test_nodes import TestEventService
|
||||
|
||||
# Prevent pytest deprecation warnings
|
||||
TestAdapter.__test__ = False # type: ignore
|
||||
TestAdapter.__test__ = False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session() -> Session:
|
||||
sess = TestSession()
|
||||
for i in ["12345", "9999", "54321"]:
|
||||
content = (
|
||||
b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
|
||||
) # for pause tests, must make content large
|
||||
sess.mount(
|
||||
f"http://www.civitai.com/models/{i}",
|
||||
TestAdapter(
|
||||
content,
|
||||
headers={
|
||||
"Content-Length": len(content),
|
||||
"Content-Disposition": f'filename="mock{i}.safetensors"',
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# here are some malformed URLs to test
|
||||
# missing the content length
|
||||
sess.mount(
|
||||
"http://www.civitai.com/models/missing",
|
||||
TestAdapter(
|
||||
b"Missing content length",
|
||||
headers={
|
||||
"Content-Disposition": 'filename="missing.txt"',
|
||||
},
|
||||
),
|
||||
)
|
||||
# not found test
|
||||
sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
|
||||
|
||||
return sess
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_basic_queue_download(tmp_path: Path, mm2_session: Session) -> None:
|
||||
events = set()
|
||||
|
||||
def event_handler(job: DownloadJob) -> None:
|
||||
def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
events.add(job.status)
|
||||
|
||||
queue = DownloadQueueService(
|
||||
requests_session=session,
|
||||
requests_session=mm2_session,
|
||||
)
|
||||
queue.start()
|
||||
job = queue.download(
|
||||
@@ -82,16 +53,17 @@ def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
|
||||
queue.join()
|
||||
|
||||
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
|
||||
assert job.download_path == tmp_path / "mock12345.safetensors"
|
||||
assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
|
||||
|
||||
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
|
||||
queue.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
def test_errors(tmp_path: Path, session: Session) -> None:
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_errors(tmp_path: Path, mm2_session: Session) -> None:
|
||||
queue = DownloadQueueService(
|
||||
requests_session=session,
|
||||
requests_session=mm2_session,
|
||||
)
|
||||
queue.start()
|
||||
|
||||
@@ -110,11 +82,11 @@ def test_errors(tmp_path: Path, session: Session) -> None:
|
||||
queue.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
def test_event_bus(tmp_path: Path, session: Session) -> None:
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_event_bus(tmp_path: Path, mm2_session: Session) -> None:
|
||||
event_bus = TestEventService()
|
||||
|
||||
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
|
||||
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
|
||||
queue.start()
|
||||
queue.download(
|
||||
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
||||
@@ -146,10 +118,10 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
|
||||
queue.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_broken_callbacks(tmp_path: Path, mm2_session: Session, capsys) -> None:
|
||||
queue = DownloadQueueService(
|
||||
requests_session=session,
|
||||
requests_session=mm2_session,
|
||||
)
|
||||
queue.start()
|
||||
|
||||
@@ -178,11 +150,11 @@ def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
|
||||
queue.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=15, method="thread")
|
||||
def test_cancel(tmp_path: Path, session: Session) -> None:
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_cancel(tmp_path: Path, mm2_session: Session) -> None:
|
||||
event_bus = TestEventService()
|
||||
|
||||
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
|
||||
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
|
||||
queue.start()
|
||||
|
||||
cancelled = False
|
||||
@@ -194,9 +166,6 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
|
||||
nonlocal cancelled
|
||||
cancelled = True
|
||||
|
||||
def handler(signum, frame):
|
||||
raise TimeoutError("Join took too long to return")
|
||||
|
||||
job = queue.download(
|
||||
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
||||
dest=tmp_path,
|
||||
@@ -212,3 +181,178 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
|
||||
assert isinstance(events[-1], DownloadCancelledEvent)
|
||||
assert events[-1].source == "http://www.civitai.com/models/12345"
|
||||
queue.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
|
||||
fetcher = HuggingFaceMetadataFetch(mm2_session)
|
||||
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
events = set()
|
||||
|
||||
def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
events.add(job.status)
|
||||
|
||||
queue = DownloadQueueService(
|
||||
requests_session=mm2_session,
|
||||
)
|
||||
queue.start()
|
||||
job = queue.multifile_download(
|
||||
parts=metadata.download_urls(session=mm2_session),
|
||||
dest=tmp_path,
|
||||
on_start=event_handler,
|
||||
on_progress=event_handler,
|
||||
on_complete=event_handler,
|
||||
on_error=event_handler,
|
||||
)
|
||||
assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJobBase"
|
||||
queue.join()
|
||||
|
||||
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
|
||||
assert job.bytes > 0, "expected download bytes to be positive"
|
||||
assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
|
||||
assert job.download_path == tmp_path / "sdxl-turbo"
|
||||
assert Path(
|
||||
tmp_path, "sdxl-turbo/model_index.json"
|
||||
).exists(), f"expected {tmp_path}/sdxl-turbo/model_inded.json to exist"
|
||||
assert Path(
|
||||
tmp_path, "sdxl-turbo/text_encoder/config.json"
|
||||
).exists(), f"expected {tmp_path}/sdxl-turbo/text_encoder/config.json to exist"
|
||||
|
||||
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
|
||||
queue.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None:
|
||||
fetcher = HuggingFaceMetadataFetch(mm2_session)
|
||||
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
events = set()
|
||||
|
||||
def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
events.add(job.status)
|
||||
|
||||
queue = DownloadQueueService(
|
||||
requests_session=mm2_session,
|
||||
)
|
||||
queue.start()
|
||||
files = metadata.download_urls(session=mm2_session)
|
||||
# this will give a 404 error
|
||||
files.append(RemoteModelFile(url="https://test.com/missing_model.safetensors", path=Path("sdxl-turbo/broken")))
|
||||
job = queue.multifile_download(
|
||||
parts=files,
|
||||
dest=tmp_path,
|
||||
on_start=event_handler,
|
||||
on_progress=event_handler,
|
||||
on_complete=event_handler,
|
||||
on_error=event_handler,
|
||||
)
|
||||
queue.join()
|
||||
|
||||
assert job.status == DownloadJobStatus("error"), "expected job status to be errored"
|
||||
assert job.error_type is not None
|
||||
assert "HTTPError(NOT FOUND)" in job.error_type
|
||||
assert DownloadJobStatus.ERROR in events
|
||||
queue.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch: Any) -> None:
|
||||
event_bus = TestEventService()
|
||||
|
||||
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
|
||||
queue.start()
|
||||
|
||||
cancelled = False
|
||||
|
||||
def cancelled_callback(job: DownloadJob) -> None:
|
||||
nonlocal cancelled
|
||||
cancelled = True
|
||||
|
||||
fetcher = HuggingFaceMetadataFetch(mm2_session)
|
||||
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
|
||||
job = queue.multifile_download(
|
||||
parts=metadata.download_urls(session=mm2_session),
|
||||
dest=tmp_path,
|
||||
on_cancelled=cancelled_callback,
|
||||
)
|
||||
queue.cancel_job(job)
|
||||
queue.join()
|
||||
|
||||
assert job.status == DownloadJobStatus.CANCELLED
|
||||
assert cancelled
|
||||
events = event_bus.events
|
||||
assert DownloadCancelledEvent in [type(x) for x in events]
|
||||
queue.stop()
|
||||
|
||||
|
||||
def test_multifile_onefile(tmp_path: Path, mm2_session: Session) -> None:
|
||||
queue = DownloadQueueService(
|
||||
requests_session=mm2_session,
|
||||
)
|
||||
queue.start()
|
||||
job = queue.multifile_download(
|
||||
parts=[
|
||||
RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("mock12345.safetensors"))
|
||||
],
|
||||
dest=tmp_path,
|
||||
)
|
||||
assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJobBase"
|
||||
queue.join()
|
||||
|
||||
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
|
||||
assert job.bytes > 0, "expected download bytes to be positive"
|
||||
assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
|
||||
assert job.download_path == tmp_path / "mock12345.safetensors"
|
||||
assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
|
||||
queue.stop()
|
||||
|
||||
|
||||
def test_multifile_no_rel_paths(tmp_path: Path, mm2_session: Session) -> None:
|
||||
queue = DownloadQueueService(
|
||||
requests_session=mm2_session,
|
||||
)
|
||||
|
||||
with pytest.raises(AssertionError) as error:
|
||||
queue.multifile_download(
|
||||
parts=[RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("/etc/passwd"))],
|
||||
dest=tmp_path,
|
||||
)
|
||||
assert str(error.value) == "only relative download paths accepted"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def clear_config() -> Generator[None, None, None]:
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
get_config.cache_clear()
|
||||
|
||||
|
||||
def test_tokens(tmp_path: Path, mm2_session: Session):
|
||||
with clear_config():
|
||||
config = get_config()
|
||||
config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")]
|
||||
queue = DownloadQueueService(requests_session=mm2_session)
|
||||
queue.start()
|
||||
# this one has an access token assigned
|
||||
job1 = queue.download(
|
||||
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
||||
dest=tmp_path,
|
||||
)
|
||||
# this one doesn't
|
||||
job2 = queue.download(
|
||||
source=AnyHttpUrl(
|
||||
"http://www.huggingface.co/foo.txt",
|
||||
),
|
||||
dest=tmp_path,
|
||||
)
|
||||
queue.join()
|
||||
# this token is defined in the temporary root invokeai.yaml
|
||||
# see tests/backend/model_manager/data/invokeai_root/invokeai.yaml
|
||||
assert job1.access_token == "cv_12345"
|
||||
assert job2.access_token is None
|
||||
queue.stop()
|
||||
|
||||
@@ -17,9 +17,11 @@ from invokeai.app.services.events.events_common import (
|
||||
ModelInstallCompleteEvent,
|
||||
ModelInstallDownloadProgressEvent,
|
||||
ModelInstallDownloadsCompleteEvent,
|
||||
ModelInstallDownloadStartedEvent,
|
||||
ModelInstallStartedEvent,
|
||||
)
|
||||
from invokeai.app.services.model_install import (
|
||||
HFModelSource,
|
||||
ModelInstallServiceBase,
|
||||
)
|
||||
from invokeai.app.services.model_install.model_install_common import (
|
||||
@@ -29,7 +31,14 @@ from invokeai.app.services.model_install.model_install_common import (
|
||||
URLModelSource,
|
||||
)
|
||||
from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException
|
||||
from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType
|
||||
from invokeai.backend.model_manager.config import (
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
)
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
from tests.test_nodes import TestEventService
|
||||
|
||||
OS = platform.uname().system
|
||||
@@ -222,7 +231,7 @@ def test_delete_register(
|
||||
store.get_model(key)
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))
|
||||
|
||||
@@ -243,15 +252,16 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
|
||||
model_record = store.get_model(key)
|
||||
assert (mm2_app_config.models_path / model_record.path).exists()
|
||||
|
||||
assert len(bus.events) == 4
|
||||
assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent)
|
||||
assert isinstance(bus.events[1], ModelInstallDownloadsCompleteEvent)
|
||||
assert isinstance(bus.events[2], ModelInstallStartedEvent)
|
||||
assert isinstance(bus.events[3], ModelInstallCompleteEvent)
|
||||
assert len(bus.events) == 5
|
||||
assert isinstance(bus.events[0], ModelInstallDownloadStartedEvent) # download starts
|
||||
assert isinstance(bus.events[1], ModelInstallDownloadProgressEvent) # download progresses
|
||||
assert isinstance(bus.events[2], ModelInstallDownloadsCompleteEvent) # download completed
|
||||
assert isinstance(bus.events[3], ModelInstallStartedEvent) # install started
|
||||
assert isinstance(bus.events[4], ModelInstallCompleteEvent) # install completed
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
|
||||
|
||||
bus: TestEventService = mm2_installer.event_bus
|
||||
@@ -277,6 +287,49 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co
|
||||
assert len(bus.events) >= 3
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default)
|
||||
|
||||
bus = mm2_installer.event_bus
|
||||
store = mm2_installer.record_store
|
||||
assert isinstance(bus, EventServiceBase)
|
||||
assert store is not None
|
||||
|
||||
job = mm2_installer.import_model(source)
|
||||
job_list = mm2_installer.wait_for_installs(timeout=10)
|
||||
assert len(job_list) == 1
|
||||
assert job.complete
|
||||
assert job.config_out
|
||||
|
||||
key = job.config_out.key
|
||||
model_record = store.get_model(key)
|
||||
assert (mm2_app_config.models_path / model_record.path).exists()
|
||||
assert model_record.type == ModelType.Main
|
||||
assert model_record.format == ModelFormat.Diffusers
|
||||
|
||||
assert hasattr(bus, "events") # the dummyeventservice has this
|
||||
assert len(bus.events) >= 3
|
||||
event_types = [type(x) for x in bus.events]
|
||||
assert all(
|
||||
x in event_types
|
||||
for x in [
|
||||
ModelInstallDownloadProgressEvent,
|
||||
ModelInstallDownloadsCompleteEvent,
|
||||
ModelInstallStartedEvent,
|
||||
ModelInstallCompleteEvent,
|
||||
]
|
||||
)
|
||||
|
||||
completed_events = [x for x in bus.events if isinstance(x, ModelInstallCompleteEvent)]
|
||||
downloading_events = [x for x in bus.events if isinstance(x, ModelInstallDownloadProgressEvent)]
|
||||
assert completed_events[0].total_bytes == downloading_events[-1].bytes
|
||||
assert job.total_bytes == completed_events[0].total_bytes
|
||||
print(downloading_events[-1])
|
||||
print(job.download_parts)
|
||||
assert job.total_bytes == sum(x["total_bytes"] for x in downloading_events[-1].parts)
|
||||
|
||||
|
||||
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://test.com/missing_model.safetensors"))
|
||||
job = mm2_installer.import_model(source)
|
||||
@@ -308,7 +361,6 @@ def test_other_error_during_install(
|
||||
assert job.error == "Test error"
|
||||
|
||||
|
||||
# TODO: Fix bug in model install causing jobs to get installed multiple times then uncomment this test
|
||||
@pytest.mark.parametrize(
|
||||
"model_params",
|
||||
[
|
||||
@@ -326,7 +378,7 @@ def test_other_error_during_install(
|
||||
},
|
||||
],
|
||||
)
|
||||
@pytest.mark.timeout(timeout=40, method="thread")
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
|
||||
"""Test whether or not type is respected on configs when passed to heuristic import."""
|
||||
assert "name" in model_params and "type" in model_params
|
||||
@@ -342,7 +394,7 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode
|
||||
}
|
||||
assert "repo_id" in model_params
|
||||
install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
|
||||
mm2_installer.wait_for_job(install_job1, timeout=20)
|
||||
mm2_installer.wait_for_job(install_job1, timeout=10)
|
||||
if model_params["type"] != "embedding":
|
||||
assert install_job1.errored
|
||||
assert install_job1.error_type == "InvalidModelConfigException"
|
||||
@@ -351,6 +403,6 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode
|
||||
assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out
|
||||
|
||||
install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
|
||||
mm2_installer.wait_for_job(install_job2, timeout=20)
|
||||
mm2_installer.wait_for_job(install_job2, timeout=10)
|
||||
assert install_job2.complete
|
||||
assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out
|
||||
|
||||
88
tests/app/services/model_load/test_load_api.py
Normal file
88
tests/app/services/model_load/test_load_api.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from diffusers import AutoencoderTiny
|
||||
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.model_manager import ModelManagerServiceBase
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext, build_invocation_context
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModelWithoutConfig
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_context(
|
||||
mock_services: InvocationServices,
|
||||
mm2_model_manager: ModelManagerServiceBase,
|
||||
) -> InvocationContext:
|
||||
mock_services.model_manager = mm2_model_manager
|
||||
return build_invocation_context(
|
||||
services=mock_services,
|
||||
data=None, # type: ignore
|
||||
is_canceled=None, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path) -> None:
|
||||
downloaded_path = mock_context.models.download_and_cache_model(
|
||||
"https://www.test.foo/download/test_embedding.safetensors"
|
||||
)
|
||||
assert downloaded_path.is_file()
|
||||
assert downloaded_path.exists()
|
||||
assert downloaded_path.name == "test_embedding.safetensors"
|
||||
assert downloaded_path.parent.parent == mm2_root_dir / "models/.download_cache"
|
||||
|
||||
downloaded_path_2 = mock_context.models.download_and_cache_model(
|
||||
"https://www.test.foo/download/test_embedding.safetensors"
|
||||
)
|
||||
assert downloaded_path == downloaded_path_2
|
||||
|
||||
|
||||
def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -> None:
|
||||
downloaded_path = mock_context.models.download_and_cache_model(
|
||||
"https://www.test.foo/download/test_embedding.safetensors"
|
||||
)
|
||||
loaded_model_1 = mock_context.models.load_local_model(downloaded_path)
|
||||
assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
|
||||
|
||||
loaded_model_2 = mock_context.models.load_local_model(downloaded_path)
|
||||
assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
|
||||
assert loaded_model_1.model is loaded_model_2.model
|
||||
|
||||
loaded_model_3 = mock_context.models.load_local_model(embedding_file)
|
||||
assert isinstance(loaded_model_3, LoadedModelWithoutConfig)
|
||||
assert loaded_model_1.model is not loaded_model_3.model
|
||||
assert isinstance(loaded_model_1.model, dict)
|
||||
assert isinstance(loaded_model_3.model, dict)
|
||||
assert torch.equal(loaded_model_1.model["emb_params"], loaded_model_3.model["emb_params"])
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This requires a test model to load")
|
||||
def test_load_from_dir(mock_context: InvocationContext, vae_directory: Path) -> None:
|
||||
loaded_model = mock_context.models.load_local_model(vae_directory)
|
||||
assert isinstance(loaded_model, LoadedModelWithoutConfig)
|
||||
assert isinstance(loaded_model.model, AutoencoderTiny)
|
||||
|
||||
|
||||
def test_download_and_load(mock_context: InvocationContext) -> None:
|
||||
loaded_model_1 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors")
|
||||
assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
|
||||
|
||||
loaded_model_2 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors")
|
||||
assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
|
||||
assert loaded_model_1.model is loaded_model_2.model # should be cached copy
|
||||
|
||||
|
||||
def test_download_diffusers(mock_context: InvocationContext) -> None:
|
||||
model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo")
|
||||
assert (model_path / "model_index.json").exists()
|
||||
assert (model_path / "vae").is_dir()
|
||||
|
||||
|
||||
def test_download_diffusers_subfolder(mock_context: InvocationContext) -> None:
|
||||
model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo::vae")
|
||||
assert model_path.is_dir()
|
||||
assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists() or (
|
||||
model_path / "diffusion_pytorch_model.safetensors"
|
||||
).exists()
|
||||
@@ -61,6 +61,13 @@ def embedding_file(mm2_model_files: Path) -> Path:
|
||||
return mm2_model_files / "test_embedding.safetensors"
|
||||
|
||||
|
||||
# Can be used to test diffusers model directory loading, but
|
||||
# the test file adds ~10MB of space.
|
||||
# @pytest.fixture
|
||||
# def vae_directory(mm2_model_files: Path) -> Path:
|
||||
# return mm2_model_files / "taesdxl"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def diffusers_dir(mm2_model_files: Path) -> Path:
|
||||
return mm2_model_files / "test-diffusers-main"
|
||||
@@ -293,4 +300,45 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
for i in ["12345", "9999", "54321"]:
|
||||
content = (
|
||||
b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
|
||||
) # for pause tests, must make content large
|
||||
sess.mount(
|
||||
f"http://www.civitai.com/models/{i}",
|
||||
TestAdapter(
|
||||
content,
|
||||
headers={
|
||||
"Content-Length": len(content),
|
||||
"Content-Disposition": f'filename="mock{i}.safetensors"',
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
sess.mount(
|
||||
"http://www.huggingface.co/foo.txt",
|
||||
TestAdapter(
|
||||
content,
|
||||
headers={
|
||||
"Content-Length": len(content),
|
||||
"Content-Disposition": 'filename="foo.safetensors"',
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# here are some malformed URLs to test
|
||||
# missing the content length
|
||||
sess.mount(
|
||||
"http://www.civitai.com/models/missing",
|
||||
TestAdapter(
|
||||
b"Missing content length",
|
||||
headers={
|
||||
"Content-Disposition": 'filename="missing.txt"',
|
||||
},
|
||||
),
|
||||
)
|
||||
# not found test
|
||||
sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
|
||||
|
||||
return sess
|
||||
|
||||
Reference in New Issue
Block a user