Compare commits

...

5 Commits

Author SHA1 Message Date
Lincoln Stein
12f9bda524 Merge github.com:invoke-ai/InvokeAI into lstein/bugfix/model-install-thread-stop 2024-03-20 22:42:02 -04:00
Lincoln Stein
b65eff1c65 add timeouts to the download tests 2024-03-20 22:40:55 -04:00
Lincoln Stein
ce687a2869 after stopping install and download services, wait for thread exit 2024-03-20 22:11:52 -04:00
Lincoln Stein
e452c6171b added debugging statements 2024-03-20 00:13:57 -04:00
Lincoln Stein
b15d05f8a8 refactor big _install_next_item() loop 2024-03-19 23:45:02 -04:00
5 changed files with 102 additions and 96 deletions

View File

@@ -87,6 +87,8 @@ class DownloadQueueService(DownloadQueueServiceBase):
self._queue.queue.clear() self._queue.queue.clear()
self.join() # wait for all active jobs to finish self.join() # wait for all active jobs to finish
self._stop_event.set() self._stop_event.set()
for thread in self._worker_pool:
thread.join()
self._worker_pool.clear() self._worker_pool.clear()
def submit_download_job( def submit_download_job(

View File

@@ -34,6 +34,7 @@ from invokeai.backend.model_manager.config import (
from invokeai.backend.model_manager.metadata import ( from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata, AnyModelRepoMetadata,
HuggingFaceMetadataFetch, HuggingFaceMetadataFetch,
ModelMetadataFetchBase,
ModelMetadataWithFiles, ModelMetadataWithFiles,
RemoteModelFile, RemoteModelFile,
) )
@@ -92,6 +93,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {} self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
self._running = False self._running = False
self._session = session self._session = session
self._install_thread: Optional[threading.Thread] = None
self._next_job_id = 0 self._next_job_id = 0
@property @property
@@ -126,6 +128,8 @@ class ModelInstallService(ModelInstallServiceBase):
self._stop_event.set() self._stop_event.set()
self._clear_pending_jobs() self._clear_pending_jobs()
self._download_cache.clear() self._download_cache.clear()
assert self._install_thread is not None
self._install_thread.join()
self._running = False self._running = False
def _clear_pending_jobs(self) -> None: def _clear_pending_jobs(self) -> None:
@@ -275,6 +279,7 @@ class ModelInstallService(ModelInstallServiceBase):
if timeout > 0 and time.time() - start > timeout: if timeout > 0 and time.time() - start > timeout:
raise TimeoutError("Timeout exceeded") raise TimeoutError("Timeout exceeded")
self._install_queue.join() self._install_queue.join()
return self._install_jobs return self._install_jobs
def cancel_job(self, job: ModelInstallJob) -> None: def cancel_job(self, job: ModelInstallJob) -> None:
@@ -415,15 +420,16 @@ class ModelInstallService(ModelInstallServiceBase):
# Internal functions that manage the installer threads # Internal functions that manage the installer threads
# -------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------
def _start_installer_thread(self) -> None: def _start_installer_thread(self) -> None:
threading.Thread(target=self._install_next_item, daemon=True).start() self._install_thread = threading.Thread(target=self._install_next_item, daemon=True)
self._install_thread.start()
self._running = True self._running = True
def _install_next_item(self) -> None: def _install_next_item(self) -> None:
done = False self._logger.info(f"Installer thread {threading.get_ident()} starting")
while not done: while True:
if self._stop_event.is_set(): if self._stop_event.is_set():
done = True break
continue self._logger.info(f"Installer thread {threading.get_ident()} running")
try: try:
job = self._install_queue.get(timeout=1) job = self._install_queue.get(timeout=1)
except Empty: except Empty:
@@ -436,39 +442,14 @@ class ModelInstallService(ModelInstallServiceBase):
elif job.errored: elif job.errored:
self._signal_job_errored(job) self._signal_job_errored(job)
elif ( elif job.waiting or job.downloads_done:
job.waiting or job.downloads_done self._register_or_install(job)
): # local jobs will be in waiting state, remote jobs will be downloading state
job.total_bytes = self._stat_size(job.local_path)
job.bytes = job.total_bytes
self._signal_job_running(job)
job.config_in["source"] = str(job.source)
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
# enter the metadata, if there is any
if isinstance(job.source_metadata, (HuggingFaceMetadata)):
job.config_in["source_api_response"] = job.source_metadata.api_response
if job.inplace:
key = self.register_path(job.local_path, job.config_in)
else:
key = self.install_path(job.local_path, job.config_in)
job.config_out = self.record_store.get_model(key)
self._signal_job_completed(job)
except InvalidModelConfigException as excp: except InvalidModelConfigException as excp:
if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts): self._set_error(job, excp)
job.set_error(
InvalidModelConfigException(
f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
)
)
else:
job.set_error(excp)
self._signal_job_errored(job)
except (OSError, DuplicateModelException) as excp: except (OSError, DuplicateModelException) as excp:
job.set_error(excp) self._set_error(job, excp)
self._signal_job_errored(job)
finally: finally:
# if this is an install of a remote file, then clean up the temporary directory # if this is an install of a remote file, then clean up the temporary directory
@@ -476,6 +457,36 @@ class ModelInstallService(ModelInstallServiceBase):
rmtree(job._install_tmpdir) rmtree(job._install_tmpdir)
self._install_completed_event.set() self._install_completed_event.set()
self._install_queue.task_done() self._install_queue.task_done()
self._logger.info(f"Installer thread {threading.get_ident()} exiting")
def _register_or_install(self, job: ModelInstallJob) -> None:
# local jobs will be in waiting state, remote jobs will be downloading state
job.total_bytes = self._stat_size(job.local_path)
job.bytes = job.total_bytes
self._signal_job_running(job)
job.config_in["source"] = str(job.source)
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
# enter the metadata, if there is any
if isinstance(job.source_metadata, (HuggingFaceMetadata)):
job.config_in["source_api_response"] = job.source_metadata.api_response
if job.inplace:
key = self.register_path(job.local_path, job.config_in)
else:
key = self.install_path(job.local_path, job.config_in)
job.config_out = self.record_store.get_model(key)
self._signal_job_completed(job)
def _set_error(self, job: ModelInstallJob, excp: Exception) -> None:
if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts):
job.set_error(
InvalidModelConfigException(
f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
)
)
else:
job.set_error(excp)
self._signal_job_errored(job)
# -------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------
# Internal functions that manage the models directory # Internal functions that manage the models directory
@@ -905,7 +916,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id) self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id)
@staticmethod @staticmethod
def get_fetcher_from_url(url: str): def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()): if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
return HuggingFaceMetadataFetch return HuggingFaceMetadataFetch
raise ValueError(f"Unsupported model source: '{url}'") raise ValueError(f"Unsupported model source: '{url}'")

View File

@@ -51,6 +51,7 @@ def session() -> Session:
return sess return sess
@pytest.mark.timeout(timeout=20, method="thread")
def test_basic_queue_download(tmp_path: Path, session: Session) -> None: def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
events = set() events = set()
@@ -80,6 +81,7 @@ def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
queue.stop() queue.stop()
@pytest.mark.timeout(timeout=20, method="thread")
def test_errors(tmp_path: Path, session: Session) -> None: def test_errors(tmp_path: Path, session: Session) -> None:
queue = DownloadQueueService( queue = DownloadQueueService(
requests_session=session, requests_session=session,
@@ -101,6 +103,7 @@ def test_errors(tmp_path: Path, session: Session) -> None:
queue.stop() queue.stop()
@pytest.mark.timeout(timeout=20, method="thread")
def test_event_bus(tmp_path: Path, session: Session) -> None: def test_event_bus(tmp_path: Path, session: Session) -> None:
event_bus = TestEventService() event_bus = TestEventService()
@@ -136,6 +139,7 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
queue.stop() queue.stop()
@pytest.mark.timeout(timeout=20, method="thread")
def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None: def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
queue = DownloadQueueService( queue = DownloadQueueService(
requests_session=session, requests_session=session,

View File

@@ -5,6 +5,7 @@ Test the model installer
import platform import platform
import uuid import uuid
from pathlib import Path from pathlib import Path
from typing import Any, Dict
import pytest import pytest
from pydantic import ValidationError from pydantic import ValidationError
@@ -276,48 +277,48 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: In
# TODO: Fix bug in model install causing jobs to get installed multiple times then uncomment this test # TODO: Fix bug in model install causing jobs to get installed multiple times then uncomment this test
# @pytest.mark.parametrize( @pytest.mark.parametrize(
# "model_params", "model_params",
# [ [
# # SDXL, Lora # SDXL, Lora
# { {
# "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors", "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
# "name": "test_lora", "name": "test_lora",
# "type": "embedding", "type": "embedding",
# }, },
# # SDXL, Lora - incorrect type # SDXL, Lora - incorrect type
# { {
# "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors", "repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
# "name": "test_lora", "name": "test_lora",
# "type": "lora", "type": "lora",
# }, },
# ], ],
# ) )
# @pytest.mark.timeout(timeout=40, method="thread") @pytest.mark.timeout(timeout=40, method="thread")
# def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]): def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
# """Test whether or not type is respected on configs when passed to heuristic import.""" """Test whether or not type is respected on configs when passed to heuristic import."""
# assert "name" in model_params and "type" in model_params assert "name" in model_params and "type" in model_params
# config1: Dict[str, Any] = { config1: Dict[str, Any] = {
# "name": f"{model_params['name']}_1", "name": f"{model_params['name']}_1",
# "type": model_params["type"], "type": model_params["type"],
# "hash": "placeholder1", "hash": "placeholder1",
# } }
# config2: Dict[str, Any] = { config2: Dict[str, Any] = {
# "name": f"{model_params['name']}_2", "name": f"{model_params['name']}_2",
# "type": ModelType(model_params["type"]), "type": ModelType(model_params["type"]),
# "hash": "placeholder2", "hash": "placeholder2",
# } }
# assert "repo_id" in model_params assert "repo_id" in model_params
# install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1) install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
# mm2_installer.wait_for_job(install_job1, timeout=20) mm2_installer.wait_for_job(install_job1, timeout=20)
# if model_params["type"] != "embedding": if model_params["type"] != "embedding":
# assert install_job1.errored assert install_job1.errored
# assert install_job1.error_type == "InvalidModelConfigException" assert install_job1.error_type == "InvalidModelConfigException"
# return return
# assert install_job1.complete assert install_job1.complete
# assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out
# install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2) install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
# mm2_installer.wait_for_job(install_job2, timeout=20) mm2_installer.wait_for_job(install_job2, timeout=20)
# assert install_job2.complete assert install_job2.complete
# assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out

View File

@@ -2,13 +2,11 @@
import os import os
import shutil import shutil
import time
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List from typing import Any, Dict, List
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
from pytest import FixtureRequest
from requests.sessions import Session from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession from requests_testadapter import TestAdapter, TestSession
@@ -99,15 +97,11 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
@pytest.fixture @pytest.fixture
def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> DownloadQueueServiceBase: def mm2_download_queue(mm2_session: Session) -> DownloadQueueServiceBase:
download_queue = DownloadQueueService(requests_session=mm2_session) download_queue = DownloadQueueService(requests_session=mm2_session)
download_queue.start() download_queue.start()
yield download_queue
def stop_queue() -> None: download_queue.stop()
download_queue.stop()
request.addfinalizer(stop_queue)
return download_queue
@pytest.fixture @pytest.fixture
@@ -130,7 +124,6 @@ def mm2_installer(
mm2_app_config: InvokeAIAppConfig, mm2_app_config: InvokeAIAppConfig,
mm2_download_queue: DownloadQueueServiceBase, mm2_download_queue: DownloadQueueServiceBase,
mm2_session: Session, mm2_session: Session,
request: FixtureRequest,
) -> ModelInstallServiceBase: ) -> ModelInstallServiceBase:
logger = InvokeAILogger.get_logger() logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(mm2_app_config, logger) db = create_mock_sqlite_database(mm2_app_config, logger)
@@ -145,13 +138,8 @@ def mm2_installer(
session=mm2_session, session=mm2_session,
) )
installer.start() installer.start()
yield installer
def stop_installer() -> None: installer.stop()
installer.stop()
time.sleep(0.1) # avoid error message from the logger when it is closed before thread prints final message
request.addfinalizer(stop_installer)
return installer
@pytest.fixture @pytest.fixture