tests: fix broken tests

This commit is contained in:
psychedelicious
2024-02-08 00:36:53 +11:00
parent aff44c0e58
commit 6d25789705
4 changed files with 17 additions and 10 deletions

View File

@@ -1,15 +1,18 @@
import typing
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, TypeVar
from typing import TYPE_CHECKING, Optional, TypeVar
import torch
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
from invokeai.app.util.misc import uuid_string
if TYPE_CHECKING:
from invokeai.app.services.invoker import Invoker
T = TypeVar("T")
@@ -31,7 +34,7 @@ class ObjectSerializerEphemeralDisk(ObjectSerializerBase[T]):
self._output_dir.mkdir(parents=True, exist_ok=True)
self.__obj_class_name: Optional[str] = None
def start(self, invoker: Invoker) -> None:
def start(self, invoker: "Invoker") -> None:
self._invoker = invoker
delete_all_result = self._delete_all()
if delete_all_result.deleted_count > 0:

View File

@@ -1,11 +1,13 @@
from queue import Queue
from typing import Optional, TypeVar
from typing import TYPE_CHECKING, Optional, TypeVar
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
T = TypeVar("T")
if TYPE_CHECKING:
from invokeai.app.services.invoker import Invoker
class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
"""Provides a simple forward cache for an underlying storage. The cache is LRU and has a maximum size."""
@@ -17,13 +19,13 @@ class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
self._cache_ids = Queue[str]()
self._max_cache_size = max_cache_size
def start(self, invoker: Invoker) -> None:
def start(self, invoker: "Invoker") -> None:
self._invoker = invoker
start_op = getattr(self._underlying_storage, "start", None)
if callable(start_op):
start_op(invoker)
def stop(self, invoker: Invoker) -> None:
def stop(self, invoker: "Invoker") -> None:
self._invoker = invoker
stop_op = getattr(self._underlying_storage, "stop", None)
if callable(stop_op):