Files
InvokeAI/invokeai/app/services/object_serializer/object_serializer_disk.py
2024-04-23 17:12:14 +10:00

86 lines
3.3 KiB
Python

import shutil
import tempfile
import typing
from pathlib import Path
from typing import TYPE_CHECKING, Optional, TypeVar
import torch
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
from invokeai.app.util.misc import uuid_string
if TYPE_CHECKING:
from invokeai.app.services.invoker import Invoker
T = TypeVar("T")
class ObjectSerializerDisk(ObjectSerializerBase[T]):
"""Disk-backed storage for arbitrary python objects. Serialization is handled by `torch.save` and `torch.load`.
:param output_dir: The folder where the serialized objects will be stored
:param ephemeral: If True, objects will be stored in a temporary directory inside the given output_dir and cleaned up on exit
"""
def __init__(self, output_dir: Path, ephemeral: bool = False):
super().__init__()
self._ephemeral = ephemeral
self._base_output_dir = output_dir
self._base_output_dir.mkdir(parents=True, exist_ok=True)
if self._ephemeral:
# Remove dangling tempdirs that might have been left over from an earlier unplanned shutdown.
for temp_dir in filter(Path.is_dir, self._base_output_dir.glob("tmp*")):
shutil.rmtree(temp_dir)
# Must specify `ignore_cleanup_errors` to avoid fatal errors during cleanup on Windows
self._tempdir = (
tempfile.TemporaryDirectory(dir=self._base_output_dir, ignore_cleanup_errors=True) if ephemeral else None
)
self._output_dir = Path(self._tempdir.name) if self._tempdir else self._base_output_dir
self.__obj_class_name: Optional[str] = None
def load(self, name: str) -> T:
file_path = self._get_path(name)
try:
return torch.load(file_path) # pyright: ignore [reportUnknownMemberType]
except FileNotFoundError as e:
raise ObjectNotFoundError(name) from e
def save(self, obj: T) -> str:
name = self._new_name()
file_path = self._get_path(name)
torch.save(obj, file_path) # pyright: ignore [reportUnknownMemberType]
return name
def delete(self, name: str) -> None:
file_path = self._get_path(name)
file_path.unlink()
@property
def _obj_class_name(self) -> str:
if not self.__obj_class_name:
# `__orig_class__` is not available in the constructor for some technical, undoubtedly very pythonic reason
self.__obj_class_name = typing.get_args(self.__orig_class__)[0].__name__ # pyright: ignore [reportUnknownMemberType, reportAttributeAccessIssue]
return self.__obj_class_name
def _get_path(self, name: str) -> Path:
return self._output_dir / name
def _new_name(self) -> str:
return f"{self._obj_class_name}_{uuid_string()}"
def _tempdir_cleanup(self) -> None:
"""Calls `cleanup` on the temporary directory, if it exists."""
if self._tempdir:
self._tempdir.cleanup()
def __del__(self) -> None:
# In case the service is not properly stopped, clean up the temporary directory when the class instance is GC'd.
self._tempdir_cleanup()
def stop(self, invoker: "Invoker") -> None:
self._tempdir_cleanup()