From 5fa2cf59e2e95002977ca7c581f101d65dfe5866 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 27 Mar 2025 08:11:36 +1000 Subject: [PATCH] fix(app): add trusted classes to torch safe globals to prevent errors when loading them In `ObjectSerializerDisk`, we use `torch.load` to load serialized objects from disk. With torch 2.6.0, torch defaults to `weights_only=True`. As a result, torch will raise when attempting to deserialize anything with an unrecognized class. For example, our `ConditioningFieldData` class is untrusted. When we load conditioning from disk, we will get a runtime error. Torch provides a method to add trusted classes to an allowlist. This change adds an arg to `ObjectSerializerDisk` to add a list of safe globals to the allowlist and uses it for both `ObjectSerializerDisk` instances. Note: My first attempt inferred the class from the generic type arg that `ObjectSerializerDisk` accepts, and added that to the allowlist. Unfortunately, this doesn't work. For example, `ConditioningFieldData` has a `conditionings` attribute that may be one some other untrusted classes representing model-specific conditioning data. So, even if we allowlist `ConditioningFieldData`, loading will fail when torch deserializes the `conditionings` attribute. --- invokeai/app/api/dependencies.py | 27 ++++++++++++++++--- .../object_serializer_disk.py | 10 ++++++- .../diffusion/conditioning_data.py | 3 +++ 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 2d5a9004f7..453f04dd8e 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -37,7 +37,13 @@ from invokeai.app.services.style_preset_records.style_preset_records_sqlite impo from invokeai.app.services.urls.urls_default import LocalUrlService from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_disk import WorkflowThumbnailFileStorageDisk -from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( + BasicConditioningInfo, + ConditioningFieldData, + FLUXConditioningInfo, + SD3ConditioningInfo, + SDXLConditioningInfo, +) from invokeai.backend.util.logging import InvokeAILogger from invokeai.version.invokeai_version import __version__ @@ -101,10 +107,25 @@ class ApiDependencies: images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) tensors = ObjectSerializerForwardCache( - ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", ephemeral=True) + ObjectSerializerDisk[torch.Tensor]( + output_folder / "tensors", + safe_globals=[torch.Tensor], + ephemeral=True, + ), + max_cache_size=0, ) conditioning = ObjectSerializerForwardCache( - ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True) + ObjectSerializerDisk[ConditioningFieldData]( + output_folder / "conditioning", + safe_globals=[ + ConditioningFieldData, + BasicConditioningInfo, + SDXLConditioningInfo, + FLUXConditioningInfo, + SD3ConditioningInfo, + ], + ephemeral=True, + ), ) download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events) model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images") diff --git a/invokeai/app/services/object_serializer/object_serializer_disk.py b/invokeai/app/services/object_serializer/object_serializer_disk.py index 8edd29e150..bbd3f78550 100644 --- a/invokeai/app/services/object_serializer/object_serializer_disk.py +++ b/invokeai/app/services/object_serializer/object_serializer_disk.py @@ -21,10 +21,16 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]): """Disk-backed storage for arbitrary python objects. Serialization is handled by `torch.save` and `torch.load`. :param output_dir: The folder where the serialized objects will be stored + :param safe_globals: A list of types to be added to the safe globals for torch serialization :param ephemeral: If True, objects will be stored in a temporary directory inside the given output_dir and cleaned up on exit """ - def __init__(self, output_dir: Path, ephemeral: bool = False): + def __init__( + self, + output_dir: Path, + safe_globals: list[type], + ephemeral: bool = False, + ) -> None: super().__init__() self._ephemeral = ephemeral self._base_output_dir = output_dir @@ -42,6 +48,8 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]): self._output_dir = Path(self._tempdir.name) if self._tempdir else self._base_output_dir self.__obj_class_name: Optional[str] = None + torch.serialization.add_safe_globals(safe_globals) if safe_globals else None + def load(self, name: str) -> T: file_path = self._get_path(name) try: diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 184cdb9b02..3fc0e0092b 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -69,6 +69,9 @@ class SD3ConditioningInfo: @dataclass class ConditioningFieldData: + # If you change this class, adding more types, you _must_ update the instantiation of ObjectSerializerDisk in + # invokeai/app/api/dependencies.py, adding the types to the list of safe globals. If you do not, torch will be + # unable to deserialize the object and will raise an error. conditionings: ( List[BasicConditioningInfo] | List[SDXLConditioningInfo]