mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-07 06:54:58 -05:00
The graph, metadata and workflow arguments all accept stringified JSON only. This makes the API consistent and means we don't need a round-trip of pydantic parsing when handling this data. It also prevents a failure mode where an uploaded image's metadata, workflow or graph is outdated and doesn't match the current schema. As before, the frontend does strict validation and parsing when loading these values.
163 lines
6.0 KiB
Python
163 lines
6.0 KiB
Python
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
|
from pathlib import Path
|
|
from queue import Queue
|
|
from typing import Dict, Optional, Union
|
|
|
|
from PIL import Image, PngImagePlugin
|
|
from PIL.Image import Image as PILImageType
|
|
from send2trash import send2trash
|
|
|
|
from invokeai.app.services.invoker import Invoker
|
|
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
|
|
|
from .image_files_base import ImageFileStorageBase
|
|
from .image_files_common import ImageFileDeleteException, ImageFileNotFoundException, ImageFileSaveException
|
|
|
|
|
|
class DiskImageFileStorage(ImageFileStorageBase):
    """Stores images on disk.

    Full-size images are written as PNGs into ``output_folder`` and thumbnails into
    ``output_folder / "thumbnails"``. Recently saved or loaded PIL images are kept in a
    small in-memory FIFO cache keyed by file path.
    """

    __output_folder: Path
    __thumbnails_folder: Path
    __cache_ids: Queue  # TODO: this is an incredibly naive cache
    __cache: Dict[Path, PILImageType]
    __max_cache_size: int
    __invoker: Invoker

    def __init__(self, output_folder: Union[str, Path]):
        """Set up the storage rooted at ``output_folder`` and create missing folders."""
        self.__cache = {}
        self.__cache_ids = Queue()
        self.__max_cache_size = 10  # TODO: get this from config

        self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
        self.__thumbnails_folder = self.__output_folder / "thumbnails"
        # Validate required output folders at launch
        self.__validate_storage_folders()

    def start(self, invoker: Invoker) -> None:
        """Keep a reference to the invoker; ``save`` reads the PNG compress level from its config."""
        self.__invoker = invoker

    def get(self, image_name: str) -> PILImageType:
        """Return the full-size image named ``image_name``, serving it from the cache when possible.

        Raises ImageFileNotFoundException if the image file does not exist.
        """
        try:
            image_path = self.get_path(image_name)

            cache_item = self.__get_cache(image_path)
            if cache_item is not None:
                return cache_item

            image = Image.open(image_path)
            self.__set_cache(image_path, image)
            return image
        except FileNotFoundError as e:
            raise ImageFileNotFoundException from e

    def save(
        self,
        image: PILImageType,
        image_name: str,
        metadata: Optional[str] = None,
        workflow: Optional[str] = None,
        graph: Optional[str] = None,
        thumbnail_size: int = 256,
    ) -> None:
        """Write ``image`` as a PNG named ``image_name`` plus a thumbnail, and cache both.

        ``metadata``, ``workflow`` and ``graph`` are stringified JSON. Each one that is not None
        is stored verbatim as a PNG text entry (``invokeai_metadata``, ``invokeai_workflow``,
        ``invokeai_graph``) and is also copied into ``image.info``.

        Raises ImageFileSaveException wrapping any error raised while saving.
        """
        try:
            self.__validate_storage_folders()
            image_path = self.get_path(image_name)

            pnginfo = PngImagePlugin.PngInfo()
            info_dict = {}

            # Order matters only for the order of the text entries in the written file.
            for key, value in (
                ("invokeai_metadata", metadata),
                ("invokeai_workflow", workflow),
                ("invokeai_graph", graph),
            ):
                if value is not None:
                    info_dict[key] = value
                    pnginfo.add_text(key, value)

            # When saving the image, the image object's info field is not populated. We need to set it
            image.info = info_dict
            image.save(
                image_path,
                "PNG",
                pnginfo=pnginfo,
                compress_level=self.__invoker.services.configuration.pil_compress_level,
            )

            # get_path derives the thumbnail file name itself, so pass the original image name.
            thumbnail_path = self.get_path(image_name, thumbnail=True)
            thumbnail_image = make_thumbnail(image, thumbnail_size)
            thumbnail_image.save(thumbnail_path)

            self.__set_cache(image_path, image)
            self.__set_cache(thumbnail_path, thumbnail_image)
        except Exception as e:
            raise ImageFileSaveException from e

    def delete(self, image_name: str) -> None:
        """Send an image and its thumbnail to the trash and drop both from the cache.

        Files that do not exist are skipped. Cache entries are dropped regardless, so a later
        ``get`` cannot return a cached image whose file is gone.

        Raises ImageFileDeleteException wrapping any error raised while deleting.
        """
        try:
            image_path = self.get_path(image_name)
            thumbnail_path = self.get_path(image_name, thumbnail=True)

            for path in (image_path, thumbnail_path):
                if path.exists():
                    send2trash(path)
                self.__cache.pop(path, None)
        except Exception as e:
            raise ImageFileDeleteException from e

    # TODO: make this a bit more flexible for e.g. cloud storage
    def get_path(self, image_name: str, thumbnail: bool = False) -> Path:
        """Return the on-disk path for ``image_name``, or for its thumbnail when ``thumbnail`` is True."""
        if thumbnail:
            return self.__thumbnails_folder / get_thumbnail_name(image_name)
        return self.__output_folder / image_name

    def validate_path(self, path: Union[str, Path]) -> bool:
        """Validates the path given for an image or thumbnail (i.e. returns whether it exists)."""
        path = path if isinstance(path, Path) else Path(path)
        return path.exists()

    def get_workflow(self, image_name: str) -> str | None:
        """Return the ``invokeai_workflow`` string stored with the image, or None if absent."""
        return self.__get_info_str(image_name, "invokeai_workflow")

    def get_graph(self, image_name: str) -> str | None:
        """Return the ``invokeai_graph`` string stored with the image, or None if absent."""
        return self.__get_info_str(image_name, "invokeai_graph")

    def __get_info_str(self, image_name: str, key: str) -> str | None:
        """Return ``image.info[key]`` for the named image if it is a string, else None."""
        value = self.get(image_name).info.get(key, None)
        return value if isinstance(value, str) else None

    def __validate_storage_folders(self) -> None:
        """Checks if the required output folders exist and create them if they don't"""
        folders: list[Path] = [self.__output_folder, self.__thumbnails_folder]
        for folder in folders:
            folder.mkdir(parents=True, exist_ok=True)

    def __get_cache(self, image_path: Path) -> Optional[PILImageType]:
        """Return the cached image for ``image_path``, or None on a cache miss."""
        return self.__cache.get(image_path)

    def __set_cache(self, image_path: Path, image: PILImageType) -> None:
        """Cache ``image`` under ``image_path``, evicting the oldest entries beyond the size limit."""
        if image_path in self.__cache:
            # Replace the value so a file overwritten by `save` is not served stale from the cache.
            # TODO: this should refresh position for LRU cache
            self.__cache[image_path] = image
            return

        self.__cache[image_path] = image
        self.__cache_ids.put(image_path)

        # `delete` drops cache entries without removing their ids from the queue, so a popped id
        # may no longer be cached. Keep evicting until the cache is back within bounds. Every
        # cached key has at least one id in the queue, so the queue is never empty here.
        while len(self.__cache) > self.__max_cache_size:
            cache_id = self.__cache_ids.get()
            self.__cache.pop(cache_id, None)