mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 03:58:09 -05:00
In `ObjectSerializerDisk`, we use `torch.load` to load serialized objects from disk. With torch 2.6.0, torch defaults to `weights_only=True`. As a result, torch will raise when attempting to deserialize anything with an unrecognized class. For example, our `ConditioningFieldData` class is untrusted. When we load conditioning from disk, we will get a runtime error. Torch provides a method to add trusted classes to an allowlist. This change adds an arg to `ObjectSerializerDisk` to add a list of safe globals to the allowlist and uses it for both `ObjectSerializerDisk` instances. Note: My first attempt inferred the class from the generic type arg that `ObjectSerializerDisk` accepts, and added that to the allowlist. Unfortunately, this doesn't work. For example, `ConditioningFieldData` has a `conditionings` attribute that may be one of several other untrusted classes representing model-specific conditioning data. So, even if we allowlist `ConditioningFieldData`, loading will fail when torch deserializes the `conditionings` attribute.
94 lines
3.5 KiB
Python
94 lines
3.5 KiB
Python
import shutil
|
|
import tempfile
|
|
import typing
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Optional, TypeVar
|
|
|
|
import torch
|
|
|
|
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
|
|
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
|
|
from invokeai.app.util.misc import uuid_string
|
|
|
|
if TYPE_CHECKING:
|
|
from invokeai.app.services.invoker import Invoker
|
|
|
|
|
|
# Type of the objects a serializer stores: `save` accepts a `T` and `load` returns a `T`.
T = TypeVar("T")
|
|
|
|
|
|
class ObjectSerializerDisk(ObjectSerializerBase[T]):
    """Disk-backed storage for arbitrary python objects. Serialization is handled by `torch.save` and `torch.load`.

    Each saved object is written to its own file inside the output directory. File names have the form
    `<generic type arg name>_<uuid>`, so the serializer must be instantiated from a parameterized alias
    (e.g. `ObjectSerializerDisk[SomeType](...)`) before `save` is called.

    :param output_dir: The folder where the serialized objects will be stored
    :param safe_globals: A list of types to be added to the safe globals for torch serialization
    :param ephemeral: If True, objects will be stored in a temporary directory inside the given output_dir and cleaned up on exit
    """

    def __init__(
        self,
        output_dir: Path,
        safe_globals: list[type],
        ephemeral: bool = False,
    ) -> None:
        super().__init__()
        self._ephemeral = ephemeral
        self._base_output_dir = output_dir
        self._base_output_dir.mkdir(parents=True, exist_ok=True)

        if self._ephemeral:
            # Remove dangling tempdirs that might have been left over from an earlier unplanned shutdown.
            for temp_dir in filter(Path.is_dir, self._base_output_dir.glob("tmp*")):
                shutil.rmtree(temp_dir)

        # Must specify `ignore_cleanup_errors` to avoid fatal errors during cleanup on Windows
        self._tempdir = (
            tempfile.TemporaryDirectory(dir=self._base_output_dir, ignore_cleanup_errors=True) if ephemeral else None
        )
        # Objects go into the tempdir when ephemeral, otherwise directly into the base output dir.
        self._output_dir = Path(self._tempdir.name) if self._tempdir else self._base_output_dir
        # Lazily resolved by the `_obj_class_name` property; see the note there.
        self.__obj_class_name: Optional[str] = None

        # Register the caller-supplied types with torch's safe-globals allowlist so `torch.load` will
        # deserialize them. The allowlist lives in `torch.serialization`, not on this instance.
        if safe_globals:
            torch.serialization.add_safe_globals(safe_globals)

    def load(self, name: str) -> T:
        """Loads the object previously saved under `name`.

        :param name: The name returned by `save`
        :raises ObjectNotFoundError: if no file exists for `name`
        """
        file_path = self._get_path(name)
        try:
            return torch.load(file_path)  # pyright: ignore [reportUnknownMemberType]
        except FileNotFoundError as e:
            raise ObjectNotFoundError(name) from e

    def save(self, obj: T) -> str:
        """Serializes `obj` to a new file and returns the generated name.

        :param obj: The object to serialize with `torch.save`
        :returns: The name under which the object can later be loaded or deleted
        """
        name = self._new_name()
        file_path = self._get_path(name)
        torch.save(obj, file_path)  # pyright: ignore [reportUnknownMemberType]
        return name

    def delete(self, name: str) -> None:
        """Deletes the file stored under `name`.

        `Path.unlink` is called without `missing_ok`, so a `FileNotFoundError` propagates if the file does not exist.

        :param name: The name returned by `save`
        """
        file_path = self._get_path(name)
        file_path.unlink()

    @property
    def _obj_class_name(self) -> str:
        """The `__name__` of the first generic type argument this instance was parameterized with (cached)."""
        if not self.__obj_class_name:
            # `__orig_class__` is not available in the constructor for some technical, undoubtedly very pythonic reason
            self.__obj_class_name = typing.get_args(self.__orig_class__)[0].__name__  # pyright: ignore [reportUnknownMemberType, reportAttributeAccessIssue]
        return self.__obj_class_name

    def _get_path(self, name: str) -> Path:
        """Returns the path of the file for `name` inside the active output directory."""
        return self._output_dir / name

    def _new_name(self) -> str:
        """Generates a fresh object name of the form `<class name>_<uuid string>`."""
        return f"{self._obj_class_name}_{uuid_string()}"

    def _tempdir_cleanup(self) -> None:
        """Calls `cleanup` on the temporary directory, if it exists."""
        # `__del__` also runs on instances whose `__init__` raised before `_tempdir` was assigned; use `getattr`
        # so the finalizer does not raise an AttributeError of its own in that case.
        tempdir = getattr(self, "_tempdir", None)
        if tempdir is not None:
            tempdir.cleanup()

    def __del__(self) -> None:
        # In case the service is not properly stopped, clean up the temporary directory when the class instance is GC'd.
        self._tempdir_cleanup()

    def stop(self, invoker: "Invoker") -> None:
        """Cleans up the temporary directory, if any. `invoker` is not used."""
        self._tempdir_cleanup()
|