resolve merge conflicts

This commit is contained in:
Lincoln Stein
2024-06-23 10:31:35 -04:00
110 changed files with 5600 additions and 3245 deletions

View File

@@ -2,14 +2,18 @@
import re
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Generator, Optional
import pytest
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession
from requests_testadapter import TestAdapter
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
from invokeai.app.services.config import get_config
from invokeai.app.services.config.config_default import URLRegexTokenPair
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService, MultiFileDownloadJob
from invokeai.app.services.events.events_common import (
DownloadCancelledEvent,
DownloadCompleteEvent,
@@ -17,56 +21,23 @@ from invokeai.app.services.events.events_common import (
DownloadProgressEvent,
DownloadStartedEvent,
)
from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, ModelMetadataWithFiles, RemoteModelFile
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
from tests.test_nodes import TestEventService
# Prevent pytest deprecation warnings
TestAdapter.__test__ = False # type: ignore
TestAdapter.__test__ = False
@pytest.fixture
def session() -> Session:
sess = TestSession()
for i in ["12345", "9999", "54321"]:
content = (
b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
) # for pause tests, must make content large
sess.mount(
f"http://www.civitai.com/models/{i}",
TestAdapter(
content,
headers={
"Content-Length": len(content),
"Content-Disposition": f'filename="mock{i}.safetensors"',
},
),
)
# here are some malformed URLs to test
# missing the content length
sess.mount(
"http://www.civitai.com/models/missing",
TestAdapter(
b"Missing content length",
headers={
"Content-Disposition": 'filename="missing.txt"',
},
),
)
# not found test
sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
return sess
@pytest.mark.timeout(timeout=20, method="thread")
def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
@pytest.mark.timeout(timeout=10, method="thread")
def test_basic_queue_download(tmp_path: Path, mm2_session: Session) -> None:
events = set()
def event_handler(job: DownloadJob) -> None:
def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None:
events.add(job.status)
queue = DownloadQueueService(
requests_session=session,
requests_session=mm2_session,
)
queue.start()
job = queue.download(
@@ -82,16 +53,17 @@ def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
queue.join()
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
assert job.download_path == tmp_path / "mock12345.safetensors"
assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
queue.stop()
@pytest.mark.timeout(timeout=20, method="thread")
def test_errors(tmp_path: Path, session: Session) -> None:
@pytest.mark.timeout(timeout=10, method="thread")
def test_errors(tmp_path: Path, mm2_session: Session) -> None:
queue = DownloadQueueService(
requests_session=session,
requests_session=mm2_session,
)
queue.start()
@@ -110,11 +82,11 @@ def test_errors(tmp_path: Path, session: Session) -> None:
queue.stop()
@pytest.mark.timeout(timeout=20, method="thread")
def test_event_bus(tmp_path: Path, session: Session) -> None:
@pytest.mark.timeout(timeout=10, method="thread")
def test_event_bus(tmp_path: Path, mm2_session: Session) -> None:
event_bus = TestEventService()
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
queue.start()
queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
@@ -146,10 +118,10 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
queue.stop()
@pytest.mark.timeout(timeout=20, method="thread")
def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
@pytest.mark.timeout(timeout=10, method="thread")
def test_broken_callbacks(tmp_path: Path, mm2_session: Session, capsys) -> None:
queue = DownloadQueueService(
requests_session=session,
requests_session=mm2_session,
)
queue.start()
@@ -178,11 +150,11 @@ def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
queue.stop()
@pytest.mark.timeout(timeout=15, method="thread")
def test_cancel(tmp_path: Path, session: Session) -> None:
@pytest.mark.timeout(timeout=10, method="thread")
def test_cancel(tmp_path: Path, mm2_session: Session) -> None:
event_bus = TestEventService()
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
queue.start()
cancelled = False
@@ -194,9 +166,6 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
nonlocal cancelled
cancelled = True
def handler(signum, frame):
raise TimeoutError("Join took too long to return")
job = queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
dest=tmp_path,
@@ -212,3 +181,178 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
assert isinstance(events[-1], DownloadCancelledEvent)
assert events[-1].source == "http://www.civitai.com/models/12345"
queue.stop()
@pytest.mark.timeout(timeout=10, method="thread")
def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
fetcher = HuggingFaceMetadataFetch(mm2_session)
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
assert isinstance(metadata, ModelMetadataWithFiles)
events = set()
def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
events.add(job.status)
queue = DownloadQueueService(
requests_session=mm2_session,
)
queue.start()
job = queue.multifile_download(
parts=metadata.download_urls(session=mm2_session),
dest=tmp_path,
on_start=event_handler,
on_progress=event_handler,
on_complete=event_handler,
on_error=event_handler,
)
assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJobBase"
queue.join()
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
assert job.bytes > 0, "expected download bytes to be positive"
assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
assert job.download_path == tmp_path / "sdxl-turbo"
assert Path(
tmp_path, "sdxl-turbo/model_index.json"
).exists(), f"expected {tmp_path}/sdxl-turbo/model_inded.json to exist"
assert Path(
tmp_path, "sdxl-turbo/text_encoder/config.json"
).exists(), f"expected {tmp_path}/sdxl-turbo/text_encoder/config.json to exist"
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
queue.stop()
@pytest.mark.timeout(timeout=10, method="thread")
def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None:
fetcher = HuggingFaceMetadataFetch(mm2_session)
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
assert isinstance(metadata, ModelMetadataWithFiles)
events = set()
def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
events.add(job.status)
queue = DownloadQueueService(
requests_session=mm2_session,
)
queue.start()
files = metadata.download_urls(session=mm2_session)
# this will give a 404 error
files.append(RemoteModelFile(url="https://test.com/missing_model.safetensors", path=Path("sdxl-turbo/broken")))
job = queue.multifile_download(
parts=files,
dest=tmp_path,
on_start=event_handler,
on_progress=event_handler,
on_complete=event_handler,
on_error=event_handler,
)
queue.join()
assert job.status == DownloadJobStatus("error"), "expected job status to be errored"
assert job.error_type is not None
assert "HTTPError(NOT FOUND)" in job.error_type
assert DownloadJobStatus.ERROR in events
queue.stop()
@pytest.mark.timeout(timeout=10, method="thread")
def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch: Any) -> None:
event_bus = TestEventService()
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
queue.start()
cancelled = False
def cancelled_callback(job: DownloadJob) -> None:
nonlocal cancelled
cancelled = True
fetcher = HuggingFaceMetadataFetch(mm2_session)
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
assert isinstance(metadata, ModelMetadataWithFiles)
job = queue.multifile_download(
parts=metadata.download_urls(session=mm2_session),
dest=tmp_path,
on_cancelled=cancelled_callback,
)
queue.cancel_job(job)
queue.join()
assert job.status == DownloadJobStatus.CANCELLED
assert cancelled
events = event_bus.events
assert DownloadCancelledEvent in [type(x) for x in events]
queue.stop()
def test_multifile_onefile(tmp_path: Path, mm2_session: Session) -> None:
queue = DownloadQueueService(
requests_session=mm2_session,
)
queue.start()
job = queue.multifile_download(
parts=[
RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("mock12345.safetensors"))
],
dest=tmp_path,
)
assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJobBase"
queue.join()
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
assert job.bytes > 0, "expected download bytes to be positive"
assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
assert job.download_path == tmp_path / "mock12345.safetensors"
assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
queue.stop()
def test_multifile_no_rel_paths(tmp_path: Path, mm2_session: Session) -> None:
queue = DownloadQueueService(
requests_session=mm2_session,
)
with pytest.raises(AssertionError) as error:
queue.multifile_download(
parts=[RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("/etc/passwd"))],
dest=tmp_path,
)
assert str(error.value) == "only relative download paths accepted"
@contextmanager
def clear_config() -> Generator[None, None, None]:
try:
yield None
finally:
get_config.cache_clear()
def test_tokens(tmp_path: Path, mm2_session: Session):
with clear_config():
config = get_config()
config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")]
queue = DownloadQueueService(requests_session=mm2_session)
queue.start()
# this one has an access token assigned
job1 = queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
dest=tmp_path,
)
# this one doesn't
job2 = queue.download(
source=AnyHttpUrl(
"http://www.huggingface.co/foo.txt",
),
dest=tmp_path,
)
queue.join()
# this token is defined in the temporary root invokeai.yaml
# see tests/backend/model_manager/data/invokeai_root/invokeai.yaml
assert job1.access_token == "cv_12345"
assert job2.access_token is None
queue.stop()

View File

@@ -17,9 +17,11 @@ from invokeai.app.services.events.events_common import (
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallDownloadStartedEvent,
ModelInstallStartedEvent,
)
from invokeai.app.services.model_install import (
HFModelSource,
ModelInstallServiceBase,
)
from invokeai.app.services.model_install.model_install_common import (
@@ -29,7 +31,14 @@ from invokeai.app.services.model_install.model_install_common import (
URLModelSource,
)
from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException
from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType
from invokeai.backend.model_manager.config import (
BaseModelType,
InvalidModelConfigException,
ModelFormat,
ModelRepoVariant,
ModelType,
)
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
from tests.test_nodes import TestEventService
OS = platform.uname().system
@@ -222,7 +231,7 @@ def test_delete_register(
store.get_model(key)
@pytest.mark.timeout(timeout=20, method="thread")
@pytest.mark.timeout(timeout=10, method="thread")
def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))
@@ -243,15 +252,16 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
model_record = store.get_model(key)
assert (mm2_app_config.models_path / model_record.path).exists()
assert len(bus.events) == 4
assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent)
assert isinstance(bus.events[1], ModelInstallDownloadsCompleteEvent)
assert isinstance(bus.events[2], ModelInstallStartedEvent)
assert isinstance(bus.events[3], ModelInstallCompleteEvent)
assert len(bus.events) == 5
assert isinstance(bus.events[0], ModelInstallDownloadStartedEvent) # download starts
assert isinstance(bus.events[1], ModelInstallDownloadProgressEvent) # download progresses
assert isinstance(bus.events[2], ModelInstallDownloadsCompleteEvent) # download completed
assert isinstance(bus.events[3], ModelInstallStartedEvent) # install started
assert isinstance(bus.events[4], ModelInstallCompleteEvent) # install completed
@pytest.mark.timeout(timeout=20, method="thread")
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
@pytest.mark.timeout(timeout=10, method="thread")
def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
bus: TestEventService = mm2_installer.event_bus
@@ -277,6 +287,49 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co
assert len(bus.events) >= 3
@pytest.mark.timeout(timeout=10, method="thread")
def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default)
bus = mm2_installer.event_bus
store = mm2_installer.record_store
assert isinstance(bus, EventServiceBase)
assert store is not None
job = mm2_installer.import_model(source)
job_list = mm2_installer.wait_for_installs(timeout=10)
assert len(job_list) == 1
assert job.complete
assert job.config_out
key = job.config_out.key
model_record = store.get_model(key)
assert (mm2_app_config.models_path / model_record.path).exists()
assert model_record.type == ModelType.Main
assert model_record.format == ModelFormat.Diffusers
assert hasattr(bus, "events") # the dummyeventservice has this
assert len(bus.events) >= 3
event_types = [type(x) for x in bus.events]
assert all(
x in event_types
for x in [
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallStartedEvent,
ModelInstallCompleteEvent,
]
)
completed_events = [x for x in bus.events if isinstance(x, ModelInstallCompleteEvent)]
downloading_events = [x for x in bus.events if isinstance(x, ModelInstallDownloadProgressEvent)]
assert completed_events[0].total_bytes == downloading_events[-1].bytes
assert job.total_bytes == completed_events[0].total_bytes
print(downloading_events[-1])
print(job.download_parts)
assert job.total_bytes == sum(x["total_bytes"] for x in downloading_events[-1].parts)
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://test.com/missing_model.safetensors"))
job = mm2_installer.import_model(source)
@@ -308,7 +361,6 @@ def test_other_error_during_install(
assert job.error == "Test error"
# TODO: Fix bug in model install causing jobs to get installed multiple times then uncomment this test
@pytest.mark.parametrize(
"model_params",
[
@@ -326,7 +378,7 @@ def test_other_error_during_install(
},
],
)
@pytest.mark.timeout(timeout=40, method="thread")
@pytest.mark.timeout(timeout=10, method="thread")
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
"""Test whether or not type is respected on configs when passed to heuristic import."""
assert "name" in model_params and "type" in model_params
@@ -342,7 +394,7 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode
}
assert "repo_id" in model_params
install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
mm2_installer.wait_for_job(install_job1, timeout=20)
mm2_installer.wait_for_job(install_job1, timeout=10)
if model_params["type"] != "embedding":
assert install_job1.errored
assert install_job1.error_type == "InvalidModelConfigException"
@@ -351,6 +403,6 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode
assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out
install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
mm2_installer.wait_for_job(install_job2, timeout=20)
mm2_installer.wait_for_job(install_job2, timeout=10)
assert install_job2.complete
assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out

View File

@@ -0,0 +1,88 @@
from pathlib import Path
import pytest
import torch
from diffusers import AutoencoderTiny
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_manager import ModelManagerServiceBase
from invokeai.app.services.shared.invocation_context import InvocationContext, build_invocation_context
from invokeai.backend.model_manager.load.load_base import LoadedModelWithoutConfig
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
@pytest.fixture()
def mock_context(
mock_services: InvocationServices,
mm2_model_manager: ModelManagerServiceBase,
) -> InvocationContext:
mock_services.model_manager = mm2_model_manager
return build_invocation_context(
services=mock_services,
data=None, # type: ignore
is_canceled=None, # type: ignore
)
def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path) -> None:
downloaded_path = mock_context.models.download_and_cache_model(
"https://www.test.foo/download/test_embedding.safetensors"
)
assert downloaded_path.is_file()
assert downloaded_path.exists()
assert downloaded_path.name == "test_embedding.safetensors"
assert downloaded_path.parent.parent == mm2_root_dir / "models/.download_cache"
downloaded_path_2 = mock_context.models.download_and_cache_model(
"https://www.test.foo/download/test_embedding.safetensors"
)
assert downloaded_path == downloaded_path_2
def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -> None:
downloaded_path = mock_context.models.download_and_cache_model(
"https://www.test.foo/download/test_embedding.safetensors"
)
loaded_model_1 = mock_context.models.load_local_model(downloaded_path)
assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
loaded_model_2 = mock_context.models.load_local_model(downloaded_path)
assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
assert loaded_model_1.model is loaded_model_2.model
loaded_model_3 = mock_context.models.load_local_model(embedding_file)
assert isinstance(loaded_model_3, LoadedModelWithoutConfig)
assert loaded_model_1.model is not loaded_model_3.model
assert isinstance(loaded_model_1.model, dict)
assert isinstance(loaded_model_3.model, dict)
assert torch.equal(loaded_model_1.model["emb_params"], loaded_model_3.model["emb_params"])
@pytest.mark.skip(reason="This requires a test model to load")
def test_load_from_dir(mock_context: InvocationContext, vae_directory: Path) -> None:
loaded_model = mock_context.models.load_local_model(vae_directory)
assert isinstance(loaded_model, LoadedModelWithoutConfig)
assert isinstance(loaded_model.model, AutoencoderTiny)
def test_download_and_load(mock_context: InvocationContext) -> None:
loaded_model_1 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors")
assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
loaded_model_2 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors")
assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
assert loaded_model_1.model is loaded_model_2.model # should be cached copy
def test_download_diffusers(mock_context: InvocationContext) -> None:
model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo")
assert (model_path / "model_index.json").exists()
assert (model_path / "vae").is_dir()
def test_download_diffusers_subfolder(mock_context: InvocationContext) -> None:
model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo::vae")
assert model_path.is_dir()
assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists() or (
model_path / "diffusion_pytorch_model.safetensors"
).exists()

View File

@@ -61,6 +61,13 @@ def embedding_file(mm2_model_files: Path) -> Path:
return mm2_model_files / "test_embedding.safetensors"
# Can be used to test diffusers model directory loading, but
# the test file adds ~10MB of space.
# @pytest.fixture
# def vae_directory(mm2_model_files: Path) -> Path:
# return mm2_model_files / "taesdxl"
@pytest.fixture
def diffusers_dir(mm2_model_files: Path) -> Path:
return mm2_model_files / "test-diffusers-main"
@@ -293,4 +300,45 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
},
),
)
for i in ["12345", "9999", "54321"]:
content = (
b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
) # for pause tests, must make content large
sess.mount(
f"http://www.civitai.com/models/{i}",
TestAdapter(
content,
headers={
"Content-Length": len(content),
"Content-Disposition": f'filename="mock{i}.safetensors"',
},
),
)
sess.mount(
"http://www.huggingface.co/foo.txt",
TestAdapter(
content,
headers={
"Content-Length": len(content),
"Content-Disposition": 'filename="foo.safetensors"',
},
),
)
# here are some malformed URLs to test
# missing the content length
sess.mount(
"http://www.civitai.com/models/missing",
TestAdapter(
b"Missing content length",
headers={
"Content-Disposition": 'filename="missing.txt"',
},
),
)
# not found test
sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
return sess