diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 9a6c7416f6..0cfcf2f3b7 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -93,7 +93,7 @@ class ApiDependencies: conditioning = ObjectSerializerForwardCache( ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True) ) - download_queue_service = DownloadQueueService(event_bus=events) + download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events) model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images") model_manager = ModelManagerService.build_model_manager( app_config=configuration, diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index 7d8229fba1..d9ab2c7f35 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -15,6 +15,7 @@ from pydantic.networks import AnyHttpUrl from requests import HTTPError from tqdm import tqdm +from invokeai.app.services.config import InvokeAIAppConfig, get_config from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.util.misc import get_iso_timestamp from invokeai.backend.util.logging import InvokeAILogger @@ -40,15 +41,18 @@ class DownloadQueueService(DownloadQueueServiceBase): def __init__( self, max_parallel_dl: int = 5, + app_config: Optional[InvokeAIAppConfig] = None, event_bus: Optional[EventServiceBase] = None, requests_session: Optional[requests.sessions.Session] = None, ): """ Initialize DownloadQueue. + :param app_config: InvokeAIAppConfig object :param max_parallel_dl: Number of simultaneous downloads allowed [5]. :param requests_session: Optional requests.sessions.Session object, for unit tests. """ + self._app_config = app_config or get_config() self._jobs: Dict[int, DownloadJob] = {} self._next_job_id = 0 self._queue: PriorityQueue[DownloadJob] = PriorityQueue() @@ -139,7 +143,7 @@ class DownloadQueueService(DownloadQueueServiceBase): source=source, dest=dest, priority=priority, - access_token=access_token, + access_token=access_token or self._lookup_access_token(source), ) self.submit_download_job( job, @@ -333,6 +337,16 @@ class DownloadQueueService(DownloadQueueServiceBase): def _in_progress_path(self, path: Path) -> Path: return path.with_name(path.name + ".downloading") + def _lookup_access_token(self, source: AnyHttpUrl) -> Optional[str]: + # Pull the token from config if it exists and matches the URL + print(self._app_config) + token = None + for pair in self._app_config.remote_api_tokens or []: + if re.search(pair.url_regex, str(source)): + token = pair.token + break + return token + def _signal_job_started(self, job: DownloadJob) -> None: job.status = DownloadJobStatus.RUNNING if job.on_start: diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 92512baec9..1a08624f8e 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -222,16 +222,9 @@ class ModelInstallService(ModelInstallServiceBase): access_token=access_token, ) elif re.match(r"^https?://[^/]+", source): - # Pull the token from config if it exists and matches the URL - _token = access_token - if _token is None: - for pair in self.app_config.remote_api_tokens or []: - if re.search(pair.url_regex, source): - _token = pair.token - break source_obj = URLModelSource( url=AnyHttpUrl(source), - access_token=_token, + access_token=access_token, ) else: raise ValueError(f"Unsupported model source: '{source}'") diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py index 7a5f433aca..d16c00302e 100644 --- a/invokeai/app/services/model_manager/model_manager_base.py +++ b/invokeai/app/services/model_manager/model_manager_base.py @@ -75,8 +75,6 @@ class ModelManagerServiceBase(ABC): def load_ckpt_from_url( self, source: str | AnyHttpUrl, - access_token: Optional[str] = None, - timeout: Optional[int] = 0, loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None, ) -> LoadedModel: """ @@ -94,9 +92,6 @@ class ModelManagerServiceBase(ABC): Args: source: A URL or a string that can be converted in one. Repo_ids do not work here. - access_token: Optional access token for restricted resources. - timeout: Wait up to the indicated number of seconds before timing - out long downloads. loader: A Callable that expects a Path and returns a Dict[str|int, Any] Returns: diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index 57c409c066..ed274266f3 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -106,8 +106,6 @@ class ModelManagerService(ModelManagerServiceBase): def load_ckpt_from_url( self, source: str | AnyHttpUrl, - access_token: Optional[str] = None, - timeout: Optional[int] = 0, loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None, ) -> LoadedModel: """ @@ -125,13 +123,10 @@ class ModelManagerService(ModelManagerServiceBase): Args: source: A URL or a string that can be converted in one. Repo_ids do not work here. - access_token: Optional access token for restricted resources. - timeout: Wait up to the indicated number of seconds before timing - out long downloads. loader: A Callable that expects a Path and returns a Dict[str|int, Any] Returns: A LoadedModel object. """ - model_path = self.install.download_and_cache_ckpt(source=source, access_token=access_token, timeout=timeout) + model_path = self.install.download_and_cache_ckpt(source=source) return self.load.load_ckpt_from_path(model_path=model_path, loader=loader) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index bfdbf1e025..c7602760f7 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -496,8 +496,6 @@ class ModelsInterface(InvocationContextInterface): def load_ckpt_from_url( self, source: str | AnyHttpUrl, - access_token: Optional[str] = None, - timeout: Optional[int] = 0, loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None, ) -> LoadedModel: """ @@ -515,17 +513,12 @@ class ModelsInterface(InvocationContextInterface): Args: source: A URL or a string that can be converted in one. Repo_ids do not work here. - access_token: Optional access token for restricted resources. - timeout: Wait up to the indicated number of seconds before timing - out long downloads. loader: A Callable that expects a Path and returns a Dict[str|int, Any] Returns: A LoadedModel object. """ - result: LoadedModel = self._services.model_manager.load_ckpt_from_url( - source=source, access_token=access_token, timeout=timeout, loader=loader - ) + result: LoadedModel = self._services.model_manager.load_ckpt_from_url(source=source, loader=loader) return result diff --git a/tests/app/services/download/test_download_queue.py b/tests/app/services/download/test_download_queue.py index 307238fd61..07c473b183 100644 --- a/tests/app/services/download/test_download_queue.py +++ b/tests/app/services/download/test_download_queue.py @@ -2,14 +2,19 @@ import re import time +from contextlib import contextmanager from pathlib import Path +from typing import Generator import pytest from pydantic.networks import AnyHttpUrl from requests.sessions import Session from requests_testadapter import TestAdapter, TestSession +from invokeai.app.services.config import get_config +from invokeai.app.services.config.config_default import URLRegexTokenPair from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService +from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 from tests.test_nodes import TestEventService # Prevent pytest deprecation warnings @@ -34,6 +39,17 @@ def session() -> Session: ), ) + sess.mount( + "http://www.huggingface.co/foo.txt", + TestAdapter( + content, + headers={ + "Content-Length": len(content), + "Content-Disposition": 'filename="foo.safetensors"', + }, + ), + ) + # here are some malformed URLs to test # missing the content length sess.mount( @@ -205,3 +221,37 @@ def test_cancel(tmp_path: Path, session: Session) -> None: assert events[-1].event_name == "download_cancelled" assert events[-1].payload["source"] == "http://www.civitai.com/models/12345" queue.stop() + + +@contextmanager +def clear_config() -> Generator[None, None, None]: + try: + yield None + finally: + get_config.cache_clear() + + +def test_tokens(tmp_path: Path, session: Session): + with clear_config(): + config = get_config() + config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")] + queue = DownloadQueueService(requests_session=session) + queue.start() + # this one has an access token assigned + job1 = queue.download( + source=AnyHttpUrl("http://www.civitai.com/models/12345"), + dest=tmp_path, + ) + # this one doesn't + job2 = queue.download( + source=AnyHttpUrl( + "http://www.huggingface.co/foo.txt", + ), + dest=tmp_path, + ) + queue.join() + # this token is defined in the temporary root invokeai.yaml + # see tests/backend/model_manager/data/invokeai_root/invokeai.yaml + assert job1.access_token == "cv_12345" + assert job2.access_token is None + queue.stop()