mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
download of individual files working
This commit is contained in:
@@ -9,14 +9,16 @@ from invokeai.app.services.model_manager_service import (
|
||||
SubModelType,
|
||||
ModelInfo,
|
||||
)
|
||||
from invokeai.backend.model_manager.download import DownloadJobBase
|
||||
|
||||
|
||||
class EventServiceBase:
|
||||
"""Basic event bus, to have an empty stand-in when not needed."""
|
||||
|
||||
session_event: str = "session_event"
|
||||
|
||||
"""Basic event bus, to have an empty stand-in when not needed"""
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
"""Dispatch an event."""
|
||||
pass
|
||||
|
||||
def __emit_session_event(self, event_name: str, payload: dict) -> None:
|
||||
@@ -187,3 +189,15 @@ class EventServiceBase:
|
||||
error=error,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_model_download_event(self, job: "DownloadJobBase") -> None:
    """Emit event when the status of a download job changes.

    :param job: The download job to report; it is forwarded unchanged under
        the ``job`` key of the event payload.
    """
    # Download events are not session events, so they are sent through
    # dispatch() directly rather than through the session-event helper.
    payload = {"job": job}
    self.dispatch(event_name="download_job_event", payload=payload)
|
||||
|
||||
@@ -3,17 +3,27 @@
|
||||
Abstract base class for a multithreaded model download queue.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import requests
|
||||
import threading
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from functools import total_ordering
|
||||
from queue import PriorityQueue
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
from typing import Set, List, Optional, Dict
|
||||
from typing import Set, List, Optional, Dict, Callable
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic import BaseModel, Field, validator, PrivateAttr
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
|
||||
# Empty placeholder class so that annotations in this module (the queue's
# `events` argument and `_event_bus` attribute) have a name to refer to.
# NOTE(review): the event bus class shown elsewhere is spelled
# `EventServiceBase` (no "s") — confirm which spelling is intended.
class EventServicesBase: # forward declaration
|
||||
pass
|
||||
|
||||
class DownloadJobStatus(str, Enum):
|
||||
"""State of a download job."""
|
||||
@@ -34,17 +44,22 @@ class CancelledJobException(Exception):
|
||||
"""Raised when a job is cancelled."""
|
||||
|
||||
|
||||
DownloadEventHandler = Callable[["DownloadJobBase"], None]
|
||||
|
||||
|
||||
@total_ordering
|
||||
class DownloadJobBase(ABC, BaseModel):
|
||||
class DownloadJob(BaseModel):
|
||||
"""Class to monitor and control a model download request."""
|
||||
|
||||
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
|
||||
id: int = Field(description="Numeric ID of this job")
|
||||
url: AnyHttpUrl = Field(description="URL to download")
|
||||
destination: Path = Field(description="Destination of URL on local disk")
|
||||
access_token: Optional[str] = Field(description="access token needed to access this resource")
|
||||
status: DownloadJobStatus = Field(default=DownloadJobStatus.IDLE, description="Status of the download")
|
||||
bytes: int = Field(default=0, description="Bytes downloaded so far")
|
||||
total_bytes: int = Field(default=0, description="Total bytes to download")
|
||||
event_handler: Optional[DownloadEventHandler] = Field(description="Callable will be called whenever job status changes")
|
||||
error: Exception = Field(default=None, description="Exception that caused an error")
|
||||
|
||||
class Config():
|
||||
@@ -53,46 +68,18 @@ class DownloadJobBase(ABC, BaseModel):
|
||||
arbitrary_types_allowed = True
|
||||
validate_assignment = True
|
||||
|
||||
@validator('destination')
|
||||
def path_doesnt_exist(cls, v):
|
||||
"""Don't allow a destination to clobber an existing file."""
|
||||
if v.exists():
|
||||
raise ValueError(f"{v} already exists")
|
||||
return v
|
||||
# @validator('destination')
|
||||
# def path_doesnt_exist(cls, v):
|
||||
# """Don't allow a destination to clobber an existing file."""
|
||||
# if v.exists():
|
||||
# raise ValueError(f"{v} already exists")
|
||||
# return v
|
||||
|
||||
@abstractmethod
|
||||
def start(self):
|
||||
"""Start the job putting it into ENQUEUED state."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pause(self):
|
||||
"""Pause the job, putting it into PAUSED state."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel(self):
|
||||
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def change_priority(self, delta: int=+1):
|
||||
"""
|
||||
Change the job's priority.
|
||||
|
||||
:param delta: Value to increment or decrement priority.
|
||||
|
||||
Lower values are higher priority. The default starting value is 10.
|
||||
Thus to make this a really high priority job:
|
||||
job.change_priority(-10).
|
||||
"""
|
||||
pass
|
||||
|
||||
def __lt__(self, other: "DownloadJobBase") -> bool:
|
||||
def __lt__(self, other: "DownloadJob") -> bool:
|
||||
"""
|
||||
Return True if self.priority < other.priority.
|
||||
|
||||
:param other: The DownloadJobBase that this will be compared against.
|
||||
:param other: The DownloadJob that this will be compared against.
|
||||
"""
|
||||
if not hasattr(other, "id"):
|
||||
return NotImplemented
|
||||
@@ -108,7 +95,10 @@ class DownloadQueueBase(ABC):
|
||||
url: str,
|
||||
destdir: Path,
|
||||
filename: Optional[Path] = None,
|
||||
) -> DownloadJobBase:
|
||||
start: bool = True,
|
||||
access_token: Optional[str] = None,
|
||||
event_handler: Optional[DownloadEventHandler] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Create a download job.
|
||||
|
||||
@@ -116,34 +106,27 @@ class DownloadQueueBase(ABC):
|
||||
:param destdir: Directory to download into.
|
||||
:param filename: Optional name of file, if not provided
|
||||
will use the content-disposition field to assign the name.
|
||||
:returns DownloadJob: The DownloadJob object for this task.
|
||||
:param start: Immediately start job [True]
|
||||
:param event_handler: Optional callable that will be called whenever job status changes.
|
||||
:returns job id: The numeric ID of the DownloadJob object for this task.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_jobs(self) -> List[DownloadJobBase]:
|
||||
def list_jobs(self) -> List[DownloadJob]:
|
||||
"""
|
||||
List active DownloadJobBases.
|
||||
List active DownloadJobs.
|
||||
|
||||
:returns List[DownloadJobBase]: List of download jobs whose state is not "completed."
|
||||
:returns List[DownloadJob]: List of download jobs whose state is not "completed."
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_job(self, job: DownloadJobBase):
|
||||
def id_to_job(self, id: int) -> DownloadJob:
|
||||
"""
|
||||
Cancel a download and delete its job.
|
||||
Return the DownloadJob corresponding to the string ID.
|
||||
|
||||
:param job: DownloadJobBase to delete
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def id_to_job(self, id: int) -> DownloadJobBase:
|
||||
"""
|
||||
Return the DownloadJobBase corresponding to the string ID.
|
||||
|
||||
:param id: ID of the DownloadJobBase.
|
||||
:param id: ID of the DownloadJob.
|
||||
|
||||
Exceptions:
|
||||
* UnknownJobIDException
|
||||
@@ -156,52 +139,273 @@ class DownloadQueueBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stop_all_jobs(self):
|
||||
def pause_all_jobs(self):
|
||||
"""Pause and dequeue all active jobs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_all_jobs(self):
|
||||
"""Cancel all active and enqueued jobs."""
|
||||
pass
|
||||
|
||||
class DownloadJob(DownloadJobBase):
|
||||
"""Implementation of DownloadJobBase"""
|
||||
@abstractmethod
|
||||
def start_job(self, id: int):
|
||||
"""Start the job putting it into ENQUEUED state."""
|
||||
pass
|
||||
|
||||
_queue: DownloadQueueBase = PrivateAttr()
|
||||
@abstractmethod
|
||||
def pause_job(self, id: int):
|
||||
"""Pause the job, putting it into PAUSED state."""
|
||||
pass
|
||||
|
||||
def __init__(self, queue: DownloadQueueBase, **kwargs):
|
||||
@abstractmethod
|
||||
def cancel_job(self, id: int):
|
||||
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def change_priority(self, id: int, delta: int):
|
||||
"""
|
||||
Create a new DownloadJob.
|
||||
Change the job's priority.
|
||||
|
||||
:param queue: Reference to the DownloadQueueBase object that created us.
|
||||
:param id: ID of the job
|
||||
:param delta: Value to increment or decrement priority.
|
||||
|
||||
Lower values are higher priority. The default starting value is 10.
|
||||
Thus to make this a really high priority job:
|
||||
job.change_priority(-10).
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._queue = queue
|
||||
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def join(self):
|
||||
"""Wait until all jobs are off the queue."""
|
||||
pass
|
||||
|
||||
|
||||
class DownloadQueue(DownloadQueueBase):
|
||||
"""Class for queued download of models."""
|
||||
|
||||
_jobs: Dict[int, DownloadJobBase]
|
||||
_jobs: Dict[int, DownloadJob]
|
||||
_worker_pool: Set[Thread]
|
||||
_queue: PriorityQueue
|
||||
_next_job: int = 0
|
||||
_lock: threading.Lock
|
||||
_logger: InvokeAILogger
|
||||
_event_bus: Optional[EventServicesBase]
|
||||
_event_handler: Optional[DownloadEventHandler]
|
||||
_next_job_id: int = 0
|
||||
|
||||
def __init__(self, max_parallel_dl: int = 5):
|
||||
def __init__(self,
|
||||
max_parallel_dl: int = 5,
|
||||
events: Optional["EventServicesBase"] = None,
|
||||
event_handler: Optional[DownloadEventHandler] = None,
|
||||
):
|
||||
"""
|
||||
Initialize DownloadQueue.
|
||||
|
||||
:param max_parallel_dl: Number of simultaneous downloads allowed.
|
||||
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
|
||||
:param events: Optional EventServices bus for reporting events.
|
||||
:param event_handler: Optional callable that will be called each time a job status changes.
|
||||
"""
|
||||
self._jobs = dict()
|
||||
self._next_job = 0
|
||||
self._next_job_id = 0
|
||||
self._queue = PriorityQueue()
|
||||
self._worker_pool = set()
|
||||
for i in range(0, max_parallel_dl):
|
||||
worker = Thread(target=self._download_next_item, daemon=True)
|
||||
worker.start()
|
||||
self._worker_pool.add(worker)
|
||||
self._lock = threading.RLock()
|
||||
self._logger = InvokeAILogger.getLogger()
|
||||
self._event_bus = events
|
||||
self._event_handler = event_handler
|
||||
|
||||
def create_download_job(
|
||||
self._start_workers(max_parallel_dl)
|
||||
|
||||
def create_download_job(
|
||||
self,
|
||||
url: str,
|
||||
destdir: Path,
|
||||
filename: Optional[Path] = None,
|
||||
) -> DownloadJob:
|
||||
|
||||
start: bool = True,
|
||||
access_token: Optional[str] = None,
|
||||
event_handler: Optional[DownloadEventHandler] = None,
|
||||
) -> int:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
id = self._next_job_id
|
||||
self._jobs[id] = DownloadJob(
|
||||
id=id,
|
||||
url=url,
|
||||
destination=Path(destdir) / (filename or "."),
|
||||
access_token=access_token,
|
||||
event_handler=(event_handler or self._event_handler),
|
||||
)
|
||||
self._next_job_id += 1
|
||||
job = self._jobs[id]
|
||||
finally:
|
||||
self._lock.release()
|
||||
if start:
|
||||
self.start_job(id)
|
||||
return job.id
|
||||
|
||||
def join(self):
    """Wait until all jobs are off the queue.

    Delegates to the ``join()`` method of the internal priority queue.
    """
    pending = self._queue
    pending.join()
|
||||
|
||||
def list_jobs(self) -> "List[DownloadJob]":
    """List the jobs currently tracked by the queue.

    :returns: A new list holding the tracked jobs. It is a snapshot: jobs
        added to or removed from the queue afterwards do not change it.
    """
    # dict.values() is a live view, not a list. Materialise it so the
    # result matches the declared return type and cannot change size under
    # a caller while worker threads remove finished jobs.
    return list(self._jobs.values())
|
||||
|
||||
|
||||
def change_priority(self, id: int, delta: int):
    """Change the priority of a job.

    :param id: ID of the job.
    :param delta: Value added to the job's current priority (may be negative).

    Raises UnknownJobIDException when no job with this ID is tracked.
    """
    # The context manager releases the lock on every exit path, exactly as
    # the explicit acquire()/finally-release() pair would.
    with self._lock:
        try:
            target = self._jobs[id]
            target.priority += delta
        except KeyError as excp:
            raise UnknownJobIDException("Unrecognized job") from excp
|
||||
|
||||
def cancel_job(self, job: "DownloadJob"):
    """Cancel a job and stop tracking it.

    The job is put into the ERROR state with a CancelledJobException
    attached, a status event is emitted, and the job is removed from the
    queue's job table. This method does not delete a partially downloaded
    file.

    :param job: The job to cancel. A numeric job ID is accepted as well,
        because the abstract interface and cancel_all_jobs() refer to jobs
        by ID.

    Raises UnknownJobIDException when a numeric ID is given that is not tracked.
    """
    with self._lock:
        if isinstance(job, int):
            try:
                job = self._jobs[job]
            except KeyError as excp:
                raise UnknownJobIDException("Unrecognized job") from excp
        # Attach the error before the status change so that event handlers
        # fired by the status update can already see why the job stopped.
        job.error = CancelledJobException(f"Job {job.id} cancelled at caller's request")
        # Previously this was a bare attribute reference (no call), so no
        # status event was ever emitted on cancellation.
        self._update_job_status(job, DownloadJobStatus.ERROR)
        # pop() rather than del: cancelling a job that is already gone is harmless.
        self._jobs.pop(job.id, None)
|
||||
|
||||
def id_to_job(self, id: int) -> "DownloadJob":
    """Return the DownloadJob registered under the given numeric ID.

    :param id: ID of the DownloadJob.

    Raises UnknownJobIDException when no job with this ID is tracked.
    """
    try:
        found = self._jobs[id]
    except KeyError as excp:
        raise UnknownJobIDException("Unrecognized job") from excp
    return found
|
||||
|
||||
def start_job(self, id: int):
    """Start the job, putting it into ENQUEUED state.

    :param id: ID of the job to enqueue.

    Raises UnknownJobIDException when no job with this ID is tracked.
    """
    # Only the lookup is guarded. A KeyError raised later by an event
    # handler or the event bus must not be reported as an unknown job ID.
    try:
        job = self._jobs[id]
    except KeyError as excp:
        raise UnknownJobIDException("Unrecognized job") from excp
    self._update_job_status(job, DownloadJobStatus.ENQUEUED)
    self._queue.put(job)
|
||||
|
||||
def pause_job(self, id: int):
    """Pause the job, putting it into PAUSED state.

    :param id: ID of the job to pause.

    Raises UnknownJobIDException when no job with this ID is tracked.
    """
    # The context manager releases the lock on every exit path, exactly as
    # the explicit acquire()/finally-release() pair would.
    with self._lock:
        try:
            target = self._jobs[id]
            self._update_job_status(target, DownloadJobStatus.PAUSED)
        except KeyError as excp:
            raise UnknownJobIDException("Unrecognized job") from excp
|
||||
|
||||
def start_all_jobs(self):
    """Call start_job() for every job currently tracked by the queue."""
    with self._lock:
        # Iterate over a snapshot of the IDs: worker threads remove finished
        # jobs from the table, and iterating the live dict while its size
        # changes raises RuntimeError.
        for job_id in list(self._jobs):
            self.start_job(job_id)
|
||||
|
||||
def pause_all_jobs(self, id: Optional[int] = None):
    """Call pause_job() for every job currently tracked by the queue.

    :param id: Ignored. Kept only so existing callers that pass a value keep
        working; the abstract interface declares this method without arguments.
    """
    with self._lock:
        # Snapshot the IDs so that concurrent removal of finished jobs cannot
        # invalidate the iteration.
        for job_id in list(self._jobs):
            self.pause_job(job_id)
|
||||
|
||||
def cancel_all_jobs(self, id: Optional[int] = None):
    """Call cancel_job() for every job currently tracked by the queue.

    :param id: Ignored. Kept only so existing callers that pass a value keep
        working; the abstract interface declares this method without arguments.
    """
    with self._lock:
        # cancel_job() takes the job object and removes it from the table, so
        # iterate over a snapshot of the job objects, not the live dict of IDs.
        for job in list(self._jobs.values()):
            self.cancel_job(job)
|
||||
|
||||
def _start_workers(self, max_workers: int):
|
||||
for i in range(0, max_workers):
|
||||
worker = Thread(target=self._download_next_item, daemon=True)
|
||||
worker.start()
|
||||
self._worker_pool.add(worker)
|
||||
|
||||
def _download_next_item(self):
    """Worker thread loop: take the next job off the priority queue and run it.

    Jobs whose status is no longer ENQUEUED when dequeued are dropped
    without downloading. The loop never returns.
    """
    while True:
        job = self._queue.get()
        try:
            if job.status == DownloadJobStatus.ENQUEUED:  # Don't do anything for cancelled or errored jobs
                self._download_with_resume(job)
        except Exception as excp:
            # Worker boundary: a failure in one job must not kill this
            # thread, otherwise the pool shrinks for the rest of the process.
            self._logger.error(f"An error occurred while downloading {job.destination}: {str(excp)}")
            job.error = excp
            self._update_job_status(job, DownloadJobStatus.ERROR)
        finally:
            # Every get() must be matched by task_done(), or join() blocks forever.
            self._queue.task_done()
|
||||
|
||||
def _download_with_resume(self, job: "DownloadJob"):
    """Download ``job.url`` to disk, resuming a partial file when possible.

    If ``job.destination`` is a directory, the file name is taken from the
    Content-Disposition response header, falling back to the basename of the
    URL. Progress is reported through _update_job_status() roughly every 1%
    of the expected size. On success the job is marked COMPLETED and removed
    from the job table. On any failure ``job.error`` is set and the job is
    marked ERROR. The download stops quietly if the job's status stops being
    RUNNING (paused, cancelled or errored elsewhere).

    :param job: The job to run; its ``bytes``, ``total_bytes``, ``status``
        and ``error`` fields are updated as the download proceeds.
    """
    dest = job.destination  # refined below; also used in the error message
    try:
        header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
        open_mode = "wb"
        exist_size = 0

        # The headers must be passed by keyword: the second positional
        # argument of requests.get() is `params`, which would put the bearer
        # token in the URL query string instead of authenticating.
        resp = requests.get(job.url, headers=header, stream=True)
        content_length = int(resp.headers.get("content-length", 0))
        job.total_bytes = content_length

        if job.destination.is_dir():
            # A missing Content-Disposition header yields None; normalise it
            # so that re.search() does not raise TypeError.
            disposition = resp.headers.get("Content-Disposition") or ""
            named = re.search('filename="(.+)"', disposition)
            file_name = named.group(1) if named else os.path.basename(job.url)
            dest = job.destination / file_name
        else:
            dest = job.destination
        dest.parent.mkdir(parents=True, exist_ok=True)

        if dest.exists():
            exist_size = dest.stat().st_size
            if content_length > 0 and exist_size > content_length:
                # Larger than the remote file: cannot be a valid partial download.
                self._logger.warning("corrupt existing file found. re-downloading")
                os.remove(dest)
                exist_size = 0
            elif content_length > 0 and exist_size == content_length:
                resp.close()
                job.bytes = exist_size
                self._logger.warning(f"{dest}: complete file found. Skipping.")
                self._update_job_status(job, DownloadJobStatus.COMPLETED)
                return
            else:
                # Partial file (or remote size unknown): ask only for the rest.
                resp.close()
                header["Range"] = f"bytes={exist_size}-"
                open_mode = "ab"
                resp = requests.get(job.url, headers=header, stream=True)  # new request with range
        job.bytes = exist_size

        if resp.status_code == 416:
            # Requested range starts at or past the end: nothing left to fetch.
            resp.close()
            self._logger.warning(f"{dest}: complete file found. Skipping.")
            self._update_job_status(job, DownloadJobStatus.COMPLETED)
            return

        if resp.status_code == 206:
            self._logger.warning(f"{dest}: partial file found. Resuming...")
        elif resp.status_code == 200:
            if open_mode == "ab":
                # The server ignored the Range header and is sending the whole
                # file; appending it to the partial file would corrupt it.
                self._logger.warning(f"{dest}: server ignored range request. Restarting download.")
                open_mode = "wb"
                job.bytes = 0
            else:
                self._logger.info(f"{dest}: Downloading...")
        else:
            # Do not write an HTTP error body to disk as if it were the file.
            self._logger.error(f"An error occurred during downloading {dest}: {resp.reason}")
            job.error = requests.HTTPError(f"{resp.status_code} {resp.reason}", response=resp)
            resp.close()
            self._update_job_status(job, DownloadJobStatus.ERROR)
            return

        self._update_job_status(job, DownloadJobStatus.RUNNING)
        report_delta = job.total_bytes / 100  # report every 1% change
        last_report_bytes = 0
        with resp, open(dest, open_mode) as file:
            for data in resp.iter_content(chunk_size=1024):
                if job.status != DownloadJobStatus.RUNNING:  # cancelled, paused or errored
                    return
                job.bytes += file.write(data)
                if job.bytes - last_report_bytes >= report_delta:
                    last_report_bytes = job.bytes
                    self._update_job_status(job)

        self._update_job_status(job, DownloadJobStatus.COMPLETED)
        with self._lock:
            # pop() rather than del: a job removed meanwhile (e.g. cancelled)
            # must not turn a finished download into an ERROR.
            self._jobs.pop(job.id, None)
    except Exception as excp:
        self._logger.error(f"An error occurred while downloading {dest}: {str(excp)}")
        job.error = excp
        self._update_job_status(job, DownloadJobStatus.ERROR)
|
||||
|
||||
def _update_job_status(self,
|
||||
job: DownloadJob,
|
||||
new_status: Optional[DownloadJobStatus] = None
|
||||
):
|
||||
"""Optionally change the job status and send an event indicating a change of state."""
|
||||
if new_status:
|
||||
job.status = new_status
|
||||
if bus := self._event_bus:
|
||||
bus.emit_model_download_event(job)
|
||||
self._logger.debug(f"Status update for download job {job.id}: {job}")
|
||||
if job.event_handler:
|
||||
job.event_handler(job)
|
||||
|
||||
Reference in New Issue
Block a user