mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-18 03:28:05 -05:00
Compare commits
231 Commits
dev/ci/upd
...
feat/nodes
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
07c9b598bd | ||
|
|
4d37ce31fc | ||
|
|
20e853084f | ||
|
|
76fe1d0103 | ||
|
|
c4fad12ac1 | ||
|
|
db44d1431c | ||
|
|
8d79610be2 | ||
|
|
466980812b | ||
|
|
ca8f8b9162 | ||
|
|
8ae769fdad | ||
|
|
039f2b00df | ||
|
|
0299f0c4c6 | ||
|
|
df4abdcd81 | ||
|
|
ed8dfdf996 | ||
|
|
4004916af9 | ||
|
|
24638a71da | ||
|
|
c7392e7948 | ||
|
|
f92afaac7c | ||
|
|
6a30b8ec99 | ||
|
|
40a30f8fee | ||
|
|
fe07a0846e | ||
|
|
78acd185b5 | ||
|
|
57db0816d6 | ||
|
|
71347df765 | ||
|
|
751b4f249d | ||
|
|
5136b83049 | ||
|
|
6e4e0fe29e | ||
|
|
f0a9a4fb88 | ||
|
|
34b50e11b6 | ||
|
|
1d9c115225 | ||
|
|
30af20a056 | ||
|
|
cc21fb216c | ||
|
|
6fe62a2705 | ||
|
|
da87378713 | ||
|
|
b6f5267385 | ||
|
|
f9e78d3c64 | ||
|
|
b7b5bd1b46 | ||
|
|
9a3727d3ad | ||
|
|
d68c14516c | ||
|
|
9f4d39aa42 | ||
|
|
84b801d88f | ||
|
|
2fc70c509b | ||
|
|
34fb1c4b19 | ||
|
|
80bdd550cf | ||
|
|
2359b92b46 | ||
|
|
a404fb2d32 | ||
|
|
513eb11616 | ||
|
|
d2c9140e69 | ||
|
|
d95fe5925a | ||
|
|
835922ea8f | ||
|
|
e1e5266fc3 | ||
|
|
5e4457445f | ||
|
|
0221ca8f49 | ||
|
|
cf36e4029e | ||
|
|
c8a98a9a22 | ||
|
|
38ecca9362 | ||
|
|
c4681774a5 | ||
|
|
050add58d2 | ||
|
|
3d60c958c7 | ||
|
|
f5df150097 | ||
|
|
dac82adb5b | ||
|
|
b72c9787a9 | ||
|
|
2623941d91 | ||
|
|
d3a7fea939 | ||
|
|
5a7b687c84 | ||
|
|
0020457fc7 | ||
|
|
658b556544 | ||
|
|
37da0fc075 | ||
|
|
6d3e8507cc | ||
|
|
0e9470503f | ||
|
|
d2ebc6741b | ||
|
|
026d3260b4 | ||
|
|
78533714e3 | ||
|
|
691e1bf829 | ||
|
|
47a088d685 | ||
|
|
63db3fc22f | ||
|
|
ad0bb3f61a | ||
|
|
8f8cd90787 | ||
|
|
d796ea7bec | ||
|
|
e5b7dd63e9 | ||
|
|
af060188bd | ||
|
|
4270e7ae25 | ||
|
|
60a565d7de | ||
|
|
78cf70eaad | ||
|
|
eebaa50710 | ||
|
|
7d582553f2 | ||
|
|
4d6eea7e81 | ||
|
|
f44593331d | ||
|
|
3d9ecbf3c7 | ||
|
|
032aa1d59c | ||
|
|
35e0863bdb | ||
|
|
14070d674e | ||
|
|
108ce06c62 | ||
|
|
da364f3444 | ||
|
|
df5ba75c14 | ||
|
|
e4fb9cb33f | ||
|
|
65b527eb20 | ||
|
|
7dc9d18052 | ||
|
|
5013a4b9f3 | ||
|
|
f929359322 | ||
|
|
6522c71971 | ||
|
|
9c1e65f3a3 | ||
|
|
ebec200ba6 | ||
|
|
e559730b6e | ||
|
|
0acb8ed85d | ||
|
|
8c1c9cd702 | ||
|
|
0ece4686aa | ||
|
|
af95cef7f9 | ||
|
|
1eca7a918a | ||
|
|
9e6b958023 | ||
|
|
f7b99d93ae | ||
|
|
85d03dcd90 | ||
|
|
032555bcfe | ||
|
|
4caa1f19b2 | ||
|
|
95d4bd3012 | ||
|
|
037078c8ad | ||
|
|
6de2f66b50 | ||
|
|
cd7b248eda | ||
|
|
6d8c077f4e | ||
|
|
97127e560e | ||
|
|
27dc07d95a | ||
|
|
f7dc171c4f | ||
|
|
4b957edfec | ||
|
|
46ca7718d9 | ||
|
|
b928d7a6e6 | ||
|
|
8a836247c8 | ||
|
|
95c3644564 | ||
|
|
799cd07174 | ||
|
|
9af385468d | ||
|
|
3487388788 | ||
|
|
9a383e456d | ||
|
|
805f9f8f4a | ||
|
|
52aa0c9bbd | ||
|
|
7f5f4689cc | ||
|
|
a3f81f4b98 | ||
|
|
15c59e606f | ||
|
|
40d4cabecd | ||
|
|
3493c8119b | ||
|
|
c1e7460d39 | ||
|
|
3ffff023b2 | ||
|
|
f9384be59b | ||
|
|
6cf308004a | ||
|
|
d1029138d2 | ||
|
|
06b5800d28 | ||
|
|
483f2ccb56 | ||
|
|
93ced0bec6 | ||
|
|
4333852c37 | ||
|
|
3baa230077 | ||
|
|
9e594f9018 | ||
|
|
b0c41b4828 | ||
|
|
e0d6946b6b | ||
|
|
bf7ea8309f | ||
|
|
54b65f725f | ||
|
|
8ef49c2640 | ||
|
|
f488b1a7f2 | ||
|
|
d2edb7c402 | ||
|
|
f0a3f07b45 | ||
|
|
b42b630583 | ||
|
|
31a78d571b | ||
|
|
fdc2232ea0 | ||
|
|
e94d0b2d40 | ||
|
|
75ccbaee9c | ||
|
|
2848c8397c | ||
|
|
fe8b5193de | ||
|
|
3d1470399c | ||
|
|
fcf9c63049 | ||
|
|
7bfb5640ad | ||
|
|
15e57e3a3d | ||
|
|
279468c0e8 | ||
|
|
c565812723 | ||
|
|
ec6c8e2a38 | ||
|
|
77f2690711 | ||
|
|
c4b3a24ed7 | ||
|
|
33c69359c2 | ||
|
|
864f4bb4af | ||
|
|
5365f42a04 | ||
|
|
3dc60254b9 | ||
|
|
027a8562d7 | ||
|
|
34f3a0f0e3 | ||
|
|
d0bac1675e | ||
|
|
4e56c962f4 | ||
|
|
4ef0e43759 | ||
|
|
6945d10297 | ||
|
|
4d6cef7ac8 | ||
|
|
a7786d5ff2 | ||
|
|
6c1de975d9 | ||
|
|
a1079e455a | ||
|
|
5457c7f069 | ||
|
|
b8c1a3f96c | ||
|
|
cee8e85f76 | ||
|
|
09f166577e | ||
|
|
bcc21531fb | ||
|
|
da4eacdffe | ||
|
|
6102e560ba | ||
|
|
ff3aa57117 | ||
|
|
49db6f4fac | ||
|
|
20f6a597ab | ||
|
|
04c453721c | ||
|
|
350ffecc1f | ||
|
|
b0557aa16b | ||
|
|
1c9429a6ea | ||
|
|
206e6b1730 | ||
|
|
357cee2849 | ||
|
|
0b49997bb6 | ||
|
|
5e09dd380d | ||
|
|
c7303adb0d | ||
|
|
ed1f096a6f | ||
|
|
6ab5d28cf3 | ||
|
|
a75148cb16 | ||
|
|
f7bbc4004a | ||
|
|
cee21ca082 | ||
|
|
08ec12b391 | ||
|
|
ff5e2a9a8c | ||
|
|
e0b9b5cc6c | ||
|
|
aca4770481 | ||
|
|
5d5157fc65 | ||
|
|
fb6ef61a4d | ||
|
|
ee24ad7b13 | ||
|
|
f8e90ba3f0 | ||
|
|
ad0b70ca23 | ||
|
|
7dfa135b2c | ||
|
|
beeaa05658 | ||
|
|
6b6d654f60 | ||
|
|
853c83d0c2 | ||
|
|
1809990ed4 | ||
|
|
79d49853d2 | ||
|
|
bd0ad59c27 | ||
|
|
cce40acba5 | ||
|
|
bc9491ab69 | ||
|
|
b909bac0dc | ||
|
|
8f80ba9520 |
@@ -247,8 +247,8 @@ class InvokeAiInstance:
|
||||
pip[
|
||||
"install",
|
||||
"--require-virtualenv",
|
||||
"torch",
|
||||
"torchvision",
|
||||
"torch~=2.0.0",
|
||||
"torchvision>=0.14.1",
|
||||
"--force-reinstall",
|
||||
"--find-links" if find_links is not None else None,
|
||||
find_links,
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from logging import Logger
|
||||
import os
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from typing import types
|
||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from ..services.default_graphs import create_system_graphs
|
||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
@@ -11,7 +13,7 @@ from ...backend import Globals
|
||||
from ..services.model_manager_initializer import get_model_manager
|
||||
from ..services.restoration_services import RestorationServices
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph
|
||||
from ..services.image_storage import DiskImageStorage
|
||||
from ..services.image_file_storage import DiskImageFileStorage
|
||||
from ..services.invocation_queue import MemoryInvocationQueue
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from ..services.invoker import Invoker
|
||||
@@ -37,13 +39,16 @@ def check_internet() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
logger = InvokeAILogger.getLogger()
|
||||
|
||||
|
||||
class ApiDependencies:
|
||||
"""Contains and initializes all dependencies for the API"""
|
||||
|
||||
invoker: Invoker = None
|
||||
|
||||
@staticmethod
|
||||
def initialize(config, event_handler_id: int, logger: types.ModuleType=logger):
|
||||
def initialize(config, event_handler_id: int, logger: Logger = logger):
|
||||
Globals.try_patchmatch = config.patchmatch
|
||||
Globals.always_use_cpu = config.always_use_cpu
|
||||
Globals.internet_available = config.internet_available and check_internet()
|
||||
@@ -60,31 +65,47 @@ class ApiDependencies:
|
||||
os.path.join(os.path.dirname(__file__), "../../../../outputs")
|
||||
)
|
||||
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents'))
|
||||
latents = ForwardCacheLatentsStorage(
|
||||
DiskLatentsStorage(f"{output_folder}/latents")
|
||||
)
|
||||
|
||||
metadata = PngMetadataService()
|
||||
|
||||
images = DiskImageStorage(f'{output_folder}/images', metadata_service=metadata)
|
||||
urls = LocalUrlService()
|
||||
|
||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||
|
||||
# TODO: build a file/path manager?
|
||||
db_location = os.path.join(output_folder, "invokeai.db")
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
)
|
||||
|
||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||
|
||||
images_new = ImageService(
|
||||
image_record_storage=image_record_storage,
|
||||
image_file_storage=image_file_storage,
|
||||
metadata=metadata,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=get_model_manager(config,logger),
|
||||
model_manager=get_model_manager(config, logger),
|
||||
events=events,
|
||||
logger=logger,
|
||||
latents=latents,
|
||||
images=images,
|
||||
metadata=metadata,
|
||||
images=image_file_storage,
|
||||
images_new=images_new,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
),
|
||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
restoration=RestorationServices(config,logger),
|
||||
restoration=RestorationServices(config, logger),
|
||||
)
|
||||
|
||||
create_system_graphs(services.graph_library)
|
||||
|
||||
47
invokeai/app/api/routers/image_files.py
Normal file
47
invokeai/app/api/routers/image_files.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
from fastapi import HTTPException, Path
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.routing import APIRouter
|
||||
from invokeai.app.models.image import ImageType
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
image_files_router = APIRouter(prefix="/v1/files/images", tags=["images", "files"])
|
||||
|
||||
|
||||
@image_files_router.get("/{image_type}/{image_name}", operation_id="get_image")
|
||||
async def get_image(
|
||||
image_type: ImageType = Path(description="The type of the image to get"),
|
||||
image_name: str = Path(description="The id of the image to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets an image"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.images_new.get_path(
|
||||
image_type=image_type, image_name=image_name
|
||||
)
|
||||
|
||||
return FileResponse(path)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@image_files_router.get(
|
||||
"/{image_type}/{image_name}/thumbnail", operation_id="get_thumbnail"
|
||||
)
|
||||
async def get_thumbnail(
|
||||
image_type: ImageType = Path(
|
||||
description="The type of the image whose thumbnail to get"
|
||||
),
|
||||
image_name: str = Path(description="The id of the image whose thumbnail to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets a thumbnail"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.images_new.get_path(
|
||||
image_type=image_type, image_name=image_name, thumbnail=True
|
||||
)
|
||||
|
||||
return FileResponse(path)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=404)
|
||||
71
invokeai/app/api/routers/image_records.py
Normal file
71
invokeai/app/api/routers/image_records.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from fastapi import HTTPException, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from invokeai.app.models.image import (
|
||||
ImageCategory,
|
||||
ImageType,
|
||||
)
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
from invokeai.app.services.models.image_record import ImageDTO
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
image_records_router = APIRouter(
|
||||
prefix="/v1/images/records", tags=["images", "records"]
|
||||
)
|
||||
|
||||
|
||||
@image_records_router.get("/{image_type}/{image_name}", operation_id="get_image_record")
|
||||
async def get_image_record(
|
||||
image_type: ImageType = Path(description="The type of the image record to get"),
|
||||
image_name: str = Path(description="The id of the image record to get"),
|
||||
) -> ImageDTO:
|
||||
"""Gets an image record by id"""
|
||||
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images_new.get_dto(
|
||||
image_type=image_type, image_name=image_name
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@image_records_router.get(
|
||||
"/",
|
||||
operation_id="list_image_records",
|
||||
)
|
||||
async def list_image_records(
|
||||
image_type: ImageType = Query(description="The type of image records to get"),
|
||||
image_category: ImageCategory = Query(
|
||||
description="The kind of image records to get"
|
||||
),
|
||||
page: int = Query(default=0, description="The page of image records to get"),
|
||||
per_page: int = Query(
|
||||
default=10, description="The number of image records per page"
|
||||
),
|
||||
) -> PaginatedResults[ImageDTO]:
|
||||
"""Gets a list of image records by type and category"""
|
||||
|
||||
image_dtos = ApiDependencies.invoker.services.images_new.get_many(
|
||||
image_type=image_type,
|
||||
image_category=image_category,
|
||||
page=page,
|
||||
per_page=per_page,
|
||||
)
|
||||
|
||||
return image_dtos
|
||||
|
||||
|
||||
@image_records_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
|
||||
async def delete_image_record(
|
||||
image_type: ImageType = Query(description="The type of image to delete"),
|
||||
image_name: str = Path(description="The name of the image to delete"),
|
||||
) -> None:
|
||||
"""Deletes an image record"""
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.images_new.delete(
|
||||
image_type=image_type, image_name=image_name
|
||||
)
|
||||
except Exception as e:
|
||||
# TODO: Does this need any exception handling at all?
|
||||
pass
|
||||
@@ -1,90 +1,39 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
import io
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
import uuid
|
||||
|
||||
from fastapi import Body, HTTPException, Path, Query, Request, UploadFile
|
||||
from fastapi.responses import FileResponse, Response
|
||||
from fastapi import HTTPException, Path, Query, Request, Response, UploadFile
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from invokeai.app.api.models.images import (
|
||||
ImageResponse,
|
||||
ImageResponseMetadata,
|
||||
from invokeai.app.models.image import (
|
||||
ImageCategory,
|
||||
ImageType,
|
||||
)
|
||||
from invokeai.app.services.image_record_storage import ImageRecordStorageBase
|
||||
from invokeai.app.services.image_file_storage import ImageFileStorageBase
|
||||
from invokeai.app.services.models.image_record import ImageRecord
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
|
||||
from ...services.image_storage import ImageType
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
images_router = APIRouter(prefix="/v1/images", tags=["images"])
|
||||
|
||||
|
||||
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
|
||||
async def get_image(
|
||||
image_type: ImageType = Path(description="The type of image to get"),
|
||||
image_name: str = Path(description="The name of the image to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets an image"""
|
||||
|
||||
path = ApiDependencies.invoker.services.images.get_path(
|
||||
image_type=image_type, image_name=image_name
|
||||
)
|
||||
|
||||
if ApiDependencies.invoker.services.images.validate_path(path):
|
||||
return FileResponse(path)
|
||||
else:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
|
||||
async def delete_image(
|
||||
image_type: ImageType = Path(description="The type of image to delete"),
|
||||
image_name: str = Path(description="The name of the image to delete"),
|
||||
) -> None:
|
||||
"""Deletes an image and its thumbnail"""
|
||||
|
||||
ApiDependencies.invoker.services.images.delete(
|
||||
image_type=image_type, image_name=image_name
|
||||
)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{thumbnail_type}/thumbnails/{thumbnail_name}", operation_id="get_thumbnail"
|
||||
)
|
||||
async def get_thumbnail(
|
||||
thumbnail_type: ImageType = Path(description="The type of thumbnail to get"),
|
||||
thumbnail_name: str = Path(description="The name of the thumbnail to get"),
|
||||
) -> FileResponse | Response:
|
||||
"""Gets a thumbnail"""
|
||||
|
||||
path = ApiDependencies.invoker.services.images.get_path(
|
||||
image_type=thumbnail_type, image_name=thumbnail_name, is_thumbnail=True
|
||||
)
|
||||
|
||||
if ApiDependencies.invoker.services.images.validate_path(path):
|
||||
return FileResponse(path)
|
||||
else:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.post(
|
||||
"/uploads/",
|
||||
"/",
|
||||
operation_id="upload_image",
|
||||
responses={
|
||||
201: {
|
||||
"description": "The image was uploaded successfully",
|
||||
"model": ImageResponse,
|
||||
},
|
||||
201: {"description": "The image was uploaded successfully"},
|
||||
415: {"description": "Image upload failed"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def upload_image(
|
||||
file: UploadFile, request: Request, response: Response
|
||||
) -> ImageResponse:
|
||||
file: UploadFile,
|
||||
image_type: ImageType,
|
||||
request: Request,
|
||||
response: Response,
|
||||
image_category: ImageCategory = ImageCategory.IMAGE,
|
||||
) -> ImageRecord:
|
||||
"""Uploads an image"""
|
||||
if not file.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
@@ -96,53 +45,33 @@ async def upload_image(
|
||||
# Error opening the image
|
||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||
|
||||
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
|
||||
try:
|
||||
image_record = ApiDependencies.invoker.services.images_new.create(
|
||||
image=img,
|
||||
image_type=image_type,
|
||||
image_category=image_category,
|
||||
)
|
||||
|
||||
saved_image = ApiDependencies.invoker.services.images.save(
|
||||
ImageType.UPLOAD, filename, img
|
||||
)
|
||||
response.status_code = 201
|
||||
response.headers["Location"] = image_record.image_url
|
||||
|
||||
invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
|
||||
|
||||
image_url = ApiDependencies.invoker.services.images.get_uri(
|
||||
ImageType.UPLOAD, saved_image.image_name
|
||||
)
|
||||
|
||||
thumbnail_url = ApiDependencies.invoker.services.images.get_uri(
|
||||
ImageType.UPLOAD, saved_image.image_name, True
|
||||
)
|
||||
|
||||
res = ImageResponse(
|
||||
image_type=ImageType.UPLOAD,
|
||||
image_name=saved_image.image_name,
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
metadata=ImageResponseMetadata(
|
||||
created=saved_image.created,
|
||||
width=img.width,
|
||||
height=img.height,
|
||||
invokeai=invokeai_metadata,
|
||||
),
|
||||
)
|
||||
|
||||
response.status_code = 201
|
||||
response.headers["Location"] = image_url
|
||||
|
||||
return res
|
||||
return image_record
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/",
|
||||
operation_id="list_images",
|
||||
responses={200: {"model": PaginatedResults[ImageResponse]}},
|
||||
)
|
||||
async def list_images(
|
||||
image_type: ImageType = Query(
|
||||
default=ImageType.RESULT, description="The type of images to get"
|
||||
),
|
||||
page: int = Query(default=0, description="The page of images to get"),
|
||||
per_page: int = Query(default=10, description="The number of images per page"),
|
||||
) -> PaginatedResults[ImageResponse]:
|
||||
"""Gets a list of images"""
|
||||
result = ApiDependencies.invoker.services.images.list(image_type, page, per_page)
|
||||
return result
|
||||
|
||||
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
|
||||
async def delete_image_record(
|
||||
image_type: ImageType = Query(description="The type of image to delete"),
|
||||
image_name: str = Path(description="The name of the image to delete"),
|
||||
) -> None:
|
||||
"""Deletes an image record"""
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.images_new.delete(
|
||||
image_type=image_type, image_name=image_name
|
||||
)
|
||||
except Exception as e:
|
||||
# TODO: Does this need any exception handling at all?
|
||||
pass
|
||||
|
||||
42
invokeai/app/api/routers/results.py
Normal file
42
invokeai/app/api/routers/results.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from fastapi import HTTPException, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from invokeai.app.services.results import ResultType, ResultWithSession
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
results_router = APIRouter(prefix="/v1/results", tags=["results"])
|
||||
|
||||
|
||||
@results_router.get("/{result_type}/{result_name}", operation_id="get_result")
|
||||
async def get_result(
|
||||
result_type: ResultType = Path(description="The type of result to get"),
|
||||
result_name: str = Path(description="The name of the result to get"),
|
||||
) -> ResultWithSession:
|
||||
"""Gets a result"""
|
||||
|
||||
result = ApiDependencies.invoker.services.results.get(
|
||||
result_id=result_name, result_type=result_type
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
else:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@results_router.get(
|
||||
"/",
|
||||
operation_id="list_results",
|
||||
responses={200: {"model": PaginatedResults[ResultWithSession]}},
|
||||
)
|
||||
async def list_results(
|
||||
result_type: ResultType = Query(description="The type of results to get"),
|
||||
page: int = Query(default=0, description="The page of results to get"),
|
||||
per_page: int = Query(default=10, description="The number of results per page"),
|
||||
) -> PaginatedResults[ResultWithSession]:
|
||||
"""Gets a list of results"""
|
||||
results = ApiDependencies.invoker.services.results.get_many(
|
||||
result_type=result_type, page=page, per_page=per_page
|
||||
)
|
||||
return results
|
||||
@@ -3,6 +3,7 @@ import asyncio
|
||||
from inspect import signature
|
||||
|
||||
import uvicorn
|
||||
from invokeai.app.models import resources
|
||||
import invokeai.backend.util.logging as logger
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@@ -15,10 +16,11 @@ from pydantic.schema import schema
|
||||
|
||||
from ..backend import Args
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import images, sessions, models
|
||||
from .api.routers import image_files, image_records, sessions, models
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
|
||||
|
||||
# Create the app
|
||||
# TODO: create this all in a method so configuration/etc. can be passed in?
|
||||
app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)
|
||||
@@ -74,10 +76,12 @@ async def shutdown_event():
|
||||
|
||||
app.include_router(sessions.session_router, prefix="/api")
|
||||
|
||||
app.include_router(images.images_router, prefix="/api")
|
||||
app.include_router(image_files.image_files_router, prefix="/api")
|
||||
|
||||
app.include_router(models.models_router, prefix="/api")
|
||||
|
||||
app.include_router(image_records.image_records_router, prefix="/api")
|
||||
|
||||
|
||||
# Build a custom OpenAPI to include all outputs
|
||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||
@@ -145,6 +149,12 @@ def overridden_redoc():
|
||||
)
|
||||
|
||||
|
||||
# Must mount *after* the other routes else it borks em
|
||||
app.mount(
|
||||
"/", StaticFiles(directory="invokeai/frontend/web/dist", html=True), name="ui"
|
||||
)
|
||||
|
||||
|
||||
def invoke_api():
|
||||
# Start our own event loop for eventing usage
|
||||
# TODO: determine if there's a better way to do this
|
||||
|
||||
@@ -28,7 +28,7 @@ from .services.model_manager_initializer import get_model_manager
|
||||
from .services.restoration_services import RestorationServices
|
||||
from .services.graph import Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
|
||||
from .services.default_graphs import default_text_to_image_graph_id
|
||||
from .services.image_storage import DiskImageStorage
|
||||
from .services.image_file_storage import DiskImageFileStorage
|
||||
from .services.invocation_queue import MemoryInvocationQueue
|
||||
from .services.invocation_services import InvocationServices
|
||||
from .services.invoker import Invoker
|
||||
@@ -214,7 +214,7 @@ def invoke_cli():
|
||||
model_manager=model_manager,
|
||||
events=events,
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
|
||||
images=DiskImageStorage(f'{output_folder}/images', metadata_service=metadata),
|
||||
images=DiskImageFileStorage(f'{output_folder}/images', metadata_service=metadata),
|
||||
metadata=metadata,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict
|
||||
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict, TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..services.invocation_services import InvocationServices
|
||||
if TYPE_CHECKING:
|
||||
from ..services.invocation_services import InvocationServices
|
||||
|
||||
|
||||
class InvocationContext:
|
||||
|
||||
@@ -3,12 +3,12 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
import numpy.random
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
InvocationConfig,
|
||||
InvocationContext,
|
||||
BaseInvocationOutput,
|
||||
)
|
||||
@@ -50,11 +50,11 @@ class RandomRangeInvocation(BaseInvocation):
|
||||
default=np.iinfo(np.int32).max, description="The exclusive high value"
|
||||
)
|
||||
size: int = Field(default=1, description="The number of values to generate")
|
||||
seed: Optional[int] = Field(
|
||||
seed: int = Field(
|
||||
ge=0,
|
||||
le=np.iinfo(np.int32).max,
|
||||
description="The seed for the RNG",
|
||||
default_factory=lambda: numpy.random.randint(0, np.iinfo(np.int32).max),
|
||||
le=SEED_MAX,
|
||||
description="The seed for the RNG (omit for random)",
|
||||
default_factory=get_random_seed,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||
|
||||
@@ -100,7 +100,8 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
# TODO: support legacy blend?
|
||||
|
||||
prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(prompt_str)
|
||||
conjunction = Compel.parse_prompt_string(prompt_str)
|
||||
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
|
||||
|
||||
if getattr(Globals, "log_tokenization", False):
|
||||
log_tokenization_for_prompt_object(prompt, tokenizer)
|
||||
@@ -119,7 +120,7 @@ class CompelInvocation(BaseInvocation):
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
|
||||
# TODO: hacky but works ;D maybe rename latents somehow?
|
||||
context.services.latents.set(conditioning_name, (c, ec))
|
||||
context.services.latents.save(conditioning_name, (c, ec))
|
||||
|
||||
return CompelOutput(
|
||||
conditioning=ConditioningField(
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from functools import partial
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Literal, Optional, Union, get_args
|
||||
|
||||
import numpy as np
|
||||
from torch import Tensor
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from invokeai.app.models.image import ColorField, ImageField, ImageType
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from invokeai.app.models.image import ImageCategory, ImageType
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.generator.inpaint import infill_methods
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput, build_image_output
|
||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
||||
@@ -17,7 +20,8 @@ from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ..util.step_callback import stable_diffusion_step_callback
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
||||
|
||||
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||
DEFAULT_INFILL_METHOD = 'patchmatch' if 'patchmatch' in get_args(INFILL_METHODS) else 'tile'
|
||||
|
||||
class SDImageInvocation(BaseModel):
|
||||
"""Helper class to provide all Stable Diffusion raster image invocations with additional config"""
|
||||
@@ -44,15 +48,13 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
# TODO: consider making prompt optional to enable providing prompt through a link
|
||||
# fmt: off
|
||||
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
||||
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
|
||||
steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
|
||||
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
|
||||
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
|
||||
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
|
||||
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
|
||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
||||
# fmt: on
|
||||
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
@@ -90,24 +92,42 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
# each time it is called. We only need the first one.
|
||||
generate_output = next(outputs)
|
||||
|
||||
image_dto = context.services.images_new.create(
|
||||
image=generate_output.image,
|
||||
image_type=ImageType.RESULT,
|
||||
image_category=ImageCategory.IMAGE,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
)
|
||||
|
||||
# Results are image and seed, unwrap for now and ignore the seed
|
||||
# TODO: pre-seed?
|
||||
# TODO: can this return multiple results? Should it?
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
# image_type = ImageType.RESULT
|
||||
# image_name = context.services.images.create_name(
|
||||
# context.graph_execution_state_id, self.id
|
||||
# )
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
# metadata = context.services.metadata.build_metadata(
|
||||
# session_id=context.graph_execution_state_id, node=self
|
||||
# )
|
||||
|
||||
# context.services.images.save(
|
||||
# image_type, image_name, generate_output.image, metadata
|
||||
# )
|
||||
|
||||
# context.services.images_db.set(
|
||||
# id=image_name,
|
||||
# image_type=ImageType.RESULT,
|
||||
# image_category=ImageCategory.IMAGE,
|
||||
# session_id=context.graph_execution_state_id,
|
||||
# node_id=self.id,
|
||||
# metadata=GeneratedImageOrLatentsMetadata(),
|
||||
# )
|
||||
|
||||
context.services.images.save(
|
||||
image_type, image_name, generate_output.image, metadata
|
||||
)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image_type=image_dto.image_type,
|
||||
image_name=image_dto.image_name,
|
||||
image=generate_output.image,
|
||||
)
|
||||
|
||||
@@ -148,7 +168,6 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
)
|
||||
mask = None
|
||||
|
||||
if self.fit:
|
||||
image = image.resize((self.width, self.height))
|
||||
@@ -165,7 +184,6 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
outputs = Img2Img(model).generate(
|
||||
prompt=self.prompt,
|
||||
init_image=image,
|
||||
init_mask=mask,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
@@ -197,7 +215,6 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
image=result_image,
|
||||
)
|
||||
|
||||
|
||||
class InpaintInvocation(ImageToImageInvocation):
|
||||
"""Generates an image using inpaint."""
|
||||
|
||||
@@ -205,6 +222,17 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
|
||||
# Inputs
|
||||
mask: Union[ImageField, None] = Field(description="The mask")
|
||||
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
|
||||
seam_blur: int = Field(default=16, ge=0, description="The seam inpaint blur radius (px)")
|
||||
seam_strength: float = Field(
|
||||
default=0.75, gt=0, le=1, description="The seam inpaint strength"
|
||||
)
|
||||
seam_steps: int = Field(default=30, ge=1, description="The number of steps to use for seam inpaint")
|
||||
tile_size: int = Field(default=32, ge=1, description="The tile infill method size (px)")
|
||||
infill_method: INFILL_METHODS = Field(default=DEFAULT_INFILL_METHOD, description="The method used to infill empty regions (px)")
|
||||
inpaint_width: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The width of the inpaint region (px)")
|
||||
inpaint_height: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The height of the inpaint region (px)")
|
||||
inpaint_fill: Optional[ColorField] = Field(default=ColorField(r=127, g=127, b=127, a=255), description="The solid infill method color")
|
||||
inpaint_replace: float = Field(
|
||||
default=0.0,
|
||||
ge=0.0,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import io
|
||||
from typing import Literal, Optional
|
||||
|
||||
import numpy
|
||||
@@ -30,16 +31,14 @@ class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["image"] = "image"
|
||||
type: Literal["image_output"] = "image_output"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
width: Optional[int] = Field(default=None, description="The width of the image in pixels")
|
||||
height: Optional[int] = Field(default=None, description="The height of the image in pixels")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": ["type", "image", "width", "height", "mode"]
|
||||
}
|
||||
schema_extra = {"required": ["type", "image", "width", "height"]}
|
||||
|
||||
|
||||
def build_image_output(
|
||||
@@ -54,7 +53,6 @@ def build_image_output(
|
||||
image=image_field,
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
mode=image.mode,
|
||||
)
|
||||
|
||||
|
||||
@@ -151,7 +149,7 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
|
||||
context.services.images.save(image_type, image_name, image_crop, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
@@ -209,7 +207,7 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
|
||||
context.services.images.save(image_type, image_name, new_image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
|
||||
233
invokeai/app/invocations/infill.py
Normal file
233
invokeai/app/invocations/infill.py
Normal file
@@ -0,0 +1,233 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal, Optional, Union, get_args
|
||||
|
||||
import numpy as np
|
||||
import math
|
||||
from PIL import Image, ImageOps
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.invocations.image import ImageOutput, build_image_output
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
|
||||
from ..models.image import ColorField, ImageField, ImageType
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
InvocationContext,
|
||||
)
|
||||
|
||||
|
||||
def infill_methods() -> list[str]:
|
||||
methods = [
|
||||
"tile",
|
||||
"solid",
|
||||
]
|
||||
if PatchMatch.patchmatch_available():
|
||||
methods.insert(0, "patchmatch")
|
||||
return methods
|
||||
|
||||
|
||||
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||
DEFAULT_INFILL_METHOD = (
|
||||
"patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
||||
)
|
||||
|
||||
|
||||
def infill_patchmatch(im: Image.Image) -> Image.Image:
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
|
||||
# Skip patchmatch if patchmatch isn't available
|
||||
if not PatchMatch.patchmatch_available():
|
||||
return im
|
||||
|
||||
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
|
||||
im_patched_np = PatchMatch.inpaint(
|
||||
im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3
|
||||
)
|
||||
im_patched = Image.fromarray(im_patched_np, mode="RGB")
|
||||
return im_patched
|
||||
|
||||
|
||||
def get_tile_images(image: np.ndarray, width=8, height=8):
|
||||
_nrows, _ncols, depth = image.shape
|
||||
_strides = image.strides
|
||||
|
||||
nrows, _m = divmod(_nrows, height)
|
||||
ncols, _n = divmod(_ncols, width)
|
||||
if _m != 0 or _n != 0:
|
||||
return None
|
||||
|
||||
return np.lib.stride_tricks.as_strided(
|
||||
np.ravel(image),
|
||||
shape=(nrows, ncols, height, width, depth),
|
||||
strides=(height * _strides[0], width * _strides[1], *_strides),
|
||||
writeable=False,
|
||||
)
|
||||
|
||||
|
||||
def tile_fill_missing(
|
||||
im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None
|
||||
) -> Image.Image:
|
||||
# Only fill if there's an alpha layer
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
|
||||
a = np.asarray(im, dtype=np.uint8)
|
||||
|
||||
tile_size_tuple = (tile_size, tile_size)
|
||||
|
||||
# Get the image as tiles of a specified size
|
||||
tiles = get_tile_images(a, *tile_size_tuple).copy()
|
||||
|
||||
# Get the mask as tiles
|
||||
tiles_mask = tiles[:, :, :, :, 3]
|
||||
|
||||
# Find any mask tiles with any fully transparent pixels (we will be replacing these later)
|
||||
tmask_shape = tiles_mask.shape
|
||||
tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
|
||||
n, ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
|
||||
tiles_mask = tiles_mask > 0
|
||||
tiles_mask = tiles_mask.reshape((n, ny)).all(axis=1)
|
||||
|
||||
# Get RGB tiles in single array and filter by the mask
|
||||
tshape = tiles.shape
|
||||
tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), *tiles.shape[2:]))
|
||||
filtered_tiles = tiles_all[tiles_mask]
|
||||
|
||||
if len(filtered_tiles) == 0:
|
||||
return im
|
||||
|
||||
# Find all invalid tiles and replace with a random valid tile
|
||||
replace_count = (tiles_mask == False).sum()
|
||||
rng = np.random.default_rng(seed=seed)
|
||||
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[
|
||||
rng.choice(filtered_tiles.shape[0], replace_count), :, :, :
|
||||
]
|
||||
|
||||
# Convert back to an image
|
||||
tiles_all = tiles_all.reshape(tshape)
|
||||
tiles_all = tiles_all.swapaxes(1, 2)
|
||||
st = tiles_all.reshape(
|
||||
(
|
||||
math.prod(tiles_all.shape[0:2]),
|
||||
math.prod(tiles_all.shape[2:4]),
|
||||
tiles_all.shape[4],
|
||||
)
|
||||
)
|
||||
si = Image.fromarray(st, mode="RGBA")
|
||||
|
||||
return si
|
||||
|
||||
|
||||
class InfillColorInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image with a solid color"""
|
||||
|
||||
type: Literal["infill_rgba"] = "infill_rgba"
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
||||
color: Optional[ColorField] = Field(
|
||||
default=ColorField(r=127, g=127, b=127, a=255),
|
||||
description="The color to use to infill",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
|
||||
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
||||
infilled = Image.alpha_composite(solid_bg, image)
|
||||
|
||||
infilled.paste(image, (0, 0), image.split()[-1])
|
||||
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, infilled, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image,
|
||||
)
|
||||
|
||||
|
||||
class InfillTileInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image with tiles of the image"""
|
||||
|
||||
type: Literal["infill_tile"] = "infill_tile"
|
||||
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
||||
tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
|
||||
seed: int = Field(
|
||||
ge=0,
|
||||
le=SEED_MAX,
|
||||
description="The seed to use for tile generation (omit for random)",
|
||||
default_factory=get_random_seed,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
|
||||
infilled = tile_fill_missing(
|
||||
image.copy(), seed=self.seed, tile_size=self.tile_size
|
||||
)
|
||||
infilled.paste(image, (0, 0), image.split()[-1])
|
||||
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, infilled, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image,
|
||||
)
|
||||
|
||||
|
||||
class InfillPatchMatchInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
||||
|
||||
type: Literal["infill_patchmatch"] = "infill_patchmatch"
|
||||
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
|
||||
if PatchMatch.patchmatch_available():
|
||||
infilled = infill_patchmatch(image.copy())
|
||||
else:
|
||||
raise ValueError("PatchMatch is not available on this system")
|
||||
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, infilled, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image,
|
||||
)
|
||||
@@ -1,11 +1,13 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import random
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal, Optional, Union
|
||||
import einops
|
||||
from pydantic import BaseModel, Field
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
|
||||
@@ -13,10 +15,12 @@ from ...backend.model_management.model_manager import ModelManager
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from ...backend.image_util.seamless import configure_model_padding
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
|
||||
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
import numpy as np
|
||||
from ..services.image_storage import ImageType
|
||||
from ..services.image_file_storage import ImageType
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput, build_image_output
|
||||
from .compel import ConditioningField
|
||||
@@ -37,41 +41,55 @@ class LatentsField(BaseModel):
|
||||
class LatentsOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output latents"""
|
||||
#fmt: off
|
||||
type: Literal["latent_output"] = "latent_output"
|
||||
latents: LatentsField = Field(default=None, description="The output latents")
|
||||
type: Literal["latents_output"] = "latents_output"
|
||||
|
||||
# Inputs
|
||||
latents: LatentsField = Field(default=None, description="The output latents")
|
||||
width: int = Field(description="The width of the latents in pixels")
|
||||
height: int = Field(description="The height of the latents in pixels")
|
||||
#fmt: on
|
||||
|
||||
|
||||
def build_latents_output(latents_name: str, latents: torch.Tensor):
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=latents_name),
|
||||
width=latents.size()[3] * 8,
|
||||
height=latents.size()[2] * 8,
|
||||
)
|
||||
|
||||
class NoiseOutput(BaseInvocationOutput):
|
||||
"""Invocation noise output"""
|
||||
#fmt: off
|
||||
type: Literal["noise_output"] = "noise_output"
|
||||
type: Literal["noise_output"] = "noise_output"
|
||||
|
||||
# Inputs
|
||||
noise: LatentsField = Field(default=None, description="The output noise")
|
||||
width: int = Field(description="The width of the noise in pixels")
|
||||
height: int = Field(description="The height of the noise in pixels")
|
||||
#fmt: on
|
||||
|
||||
|
||||
# TODO: this seems like a hack
|
||||
scheduler_map = dict(
|
||||
ddim=diffusers.DDIMScheduler,
|
||||
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
|
||||
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
|
||||
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_euler=diffusers.EulerDiscreteScheduler,
|
||||
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
|
||||
k_heun=diffusers.HeunDiscreteScheduler,
|
||||
k_lms=diffusers.LMSDiscreteScheduler,
|
||||
plms=diffusers.PNDMScheduler,
|
||||
)
|
||||
def build_noise_output(latents_name: str, latents: torch.Tensor):
|
||||
return NoiseOutput(
|
||||
noise=LatentsField(latents_name=latents_name),
|
||||
width=latents.size()[3] * 8,
|
||||
height=latents.size()[2] * 8,
|
||||
)
|
||||
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[
|
||||
tuple(list(scheduler_map.keys()))
|
||||
tuple(list(SCHEDULER_MAP.keys()))
|
||||
]
|
||||
|
||||
|
||||
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||
scheduler_class = scheduler_map.get(scheduler_name,'ddim')
|
||||
scheduler = scheduler_class.from_config(model.scheduler.config)
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
|
||||
|
||||
scheduler_config = model.scheduler.config
|
||||
if "_backup" in scheduler_config:
|
||||
scheduler_config = scheduler_config["_backup"]
|
||||
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
|
||||
# hack copied over from generate.py
|
||||
if not hasattr(scheduler, 'uses_inpainting_model'):
|
||||
scheduler.uses_inpainting_model = lambda: False
|
||||
@@ -102,17 +120,13 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
|
||||
return x
|
||||
|
||||
|
||||
def random_seed():
|
||||
return random.randint(0, np.iinfo(np.uint32).max)
|
||||
|
||||
|
||||
class NoiseInvocation(BaseInvocation):
|
||||
"""Generates latent noise."""
|
||||
|
||||
type: Literal["noise"] = "noise"
|
||||
|
||||
# Inputs
|
||||
seed: int = Field(ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", default_factory=random_seed)
|
||||
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use", default_factory=get_random_seed)
|
||||
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting noise", )
|
||||
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting noise", )
|
||||
|
||||
@@ -130,10 +144,8 @@ class NoiseInvocation(BaseInvocation):
|
||||
noise = get_noise(self.width, self.height, device, self.seed)
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.set(name, noise)
|
||||
return NoiseOutput(
|
||||
noise=LatentsField(latents_name=name)
|
||||
)
|
||||
context.services.latents.save(name, noise)
|
||||
return build_noise_output(latents_name=name, latents=noise)
|
||||
|
||||
|
||||
# Text to image
|
||||
@@ -149,11 +161,10 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
|
||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
||||
# fmt: on
|
||||
|
||||
# Schema customisation
|
||||
@@ -218,7 +229,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
h_symmetry_time_pct=None,#h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=None#v_symmetry_time_pct,
|
||||
),
|
||||
).add_scheduler_args_if_applicable(model.scheduler, eta=None)#ddim_eta)
|
||||
).add_scheduler_args_if_applicable(model.scheduler, eta=0.0)#ddim_eta)
|
||||
return conditioning_data
|
||||
|
||||
|
||||
@@ -249,10 +260,8 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.set(name, result_latents)
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=name)
|
||||
)
|
||||
context.services.latents.save(name, result_latents)
|
||||
return build_latents_output(latents_name=name, latents=result_latents)
|
||||
|
||||
|
||||
class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
@@ -260,6 +269,10 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
|
||||
type: Literal["l2l"] = "l2l"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
||||
strength: float = Field(default=0.5, description="The strength of the latents to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
@@ -271,10 +284,6 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
},
|
||||
}
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
||||
strength: float = Field(default=0.5, description="The strength of the latents to use")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
latent = context.services.latents.get(self.latents.latents_name)
|
||||
@@ -287,7 +296,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
self.dispatch_progress(context, source_node_id, state)
|
||||
|
||||
model = self.get_model(context.services.model_manager)
|
||||
conditioning_data = self.get_conditioning_data(model)
|
||||
conditioning_data = self.get_conditioning_data(context, model)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
|
||||
@@ -295,11 +304,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
latent, device=model.device, dtype=latent.dtype
|
||||
)
|
||||
|
||||
timesteps, _ = model.get_img2img_timesteps(
|
||||
self.steps,
|
||||
self.strength,
|
||||
device=model.device,
|
||||
)
|
||||
timesteps, _ = model.get_img2img_timesteps(self.steps, self.strength)
|
||||
|
||||
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
||||
latents=initial_latents,
|
||||
@@ -314,10 +319,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.set(name, result_latents)
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=name)
|
||||
)
|
||||
context.services.latents.save(name, result_latents)
|
||||
return build_latents_output(latents_name=name, latents=result_latents)
|
||||
|
||||
|
||||
# Latent to image
|
||||
@@ -381,11 +384,11 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
type: Literal["lresize"] = "lresize"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to resize")
|
||||
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||
mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
|
||||
antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||
latents: Optional[LatentsField] = Field(description="The latents to resize")
|
||||
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
||||
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
@@ -401,8 +404,8 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.set(name, resized_latents)
|
||||
return LatentsOutput(latents=LatentsField(latents_name=name))
|
||||
context.services.latents.save(name, resized_latents)
|
||||
return build_latents_output(latents_name=name, latents=resized_latents)
|
||||
|
||||
|
||||
class ScaleLatentsInvocation(BaseInvocation):
|
||||
@@ -411,10 +414,10 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
type: Literal["lscale"] = "lscale"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to scale")
|
||||
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
|
||||
mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
|
||||
antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||
latents: Optional[LatentsField] = Field(description="The latents to scale")
|
||||
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
|
||||
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
||||
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
@@ -431,5 +434,49 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.set(name, resized_latents)
|
||||
return LatentsOutput(latents=LatentsField(latents_name=name))
|
||||
context.services.latents.save(name, resized_latents)
|
||||
return build_latents_output(latents_name=name, latents=resized_latents)
|
||||
|
||||
|
||||
class ImageToLatentsInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
|
||||
type: Literal["i2l"] = "i2l"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(description="The image to encode")
|
||||
model: str = Field(default="", description="The model to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents", "image"],
|
||||
"type_hints": {"model": "model"},
|
||||
},
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
|
||||
# TODO: this only really needs the vae
|
||||
model_info = choose_model(context.services.model_manager, self.model)
|
||||
model: StableDiffusionGeneratorPipeline = model_info["model"]
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
latents = model.non_noised_latents_from_image(
|
||||
image_tensor,
|
||||
device=model._model_group.device_for(model.unet),
|
||||
dtype=model.unet.dtype,
|
||||
)
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.save(name, latents)
|
||||
return build_latents_output(latents_name=name, latents=latents)
|
||||
|
||||
@@ -3,8 +3,14 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
import numpy as np
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
InvocationConfig,
|
||||
)
|
||||
|
||||
|
||||
class MathInvocationConfig(BaseModel):
|
||||
@@ -21,19 +27,21 @@ class MathInvocationConfig(BaseModel):
|
||||
|
||||
class IntOutput(BaseInvocationOutput):
|
||||
"""An integer output"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["int_output"] = "int_output"
|
||||
a: int = Field(default=None, description="The output integer")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
|
||||
class AddInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Adds two numbers"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["add"] = "add"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a + self.b)
|
||||
@@ -41,11 +49,12 @@ class AddInvocation(BaseInvocation, MathInvocationConfig):
|
||||
|
||||
class SubtractInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Subtracts two numbers"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["sub"] = "sub"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a - self.b)
|
||||
@@ -53,11 +62,12 @@ class SubtractInvocation(BaseInvocation, MathInvocationConfig):
|
||||
|
||||
class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Multiplies two numbers"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["mul"] = "mul"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a * self.b)
|
||||
@@ -65,11 +75,26 @@ class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
|
||||
|
||||
class DivideInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Divides two numbers"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["div"] = "div"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=int(self.a / self.b))
|
||||
|
||||
|
||||
class RandomIntInvocation(BaseInvocation):
|
||||
"""Outputs a single random integer."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["rand_int"] = "rand_int"
|
||||
low: int = Field(default=0, description="The inclusive low value")
|
||||
high: int = Field(
|
||||
default=np.iinfo(np.int32).max, description="The exclusive high value"
|
||||
)
|
||||
# fmt: on
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=np.random.randint(self.low, self.high))
|
||||
|
||||
@@ -4,10 +4,11 @@ from invokeai.backend.model_management.model_manager import ModelManager
|
||||
def choose_model(model_manager: ModelManager, model_name: str):
|
||||
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
|
||||
logger = model_manager.logger
|
||||
if model_manager.valid_model(model_name):
|
||||
model = model_manager.get_model(model_name)
|
||||
else:
|
||||
if model_name and not model_manager.valid_model(model_name):
|
||||
default_model_name = model_manager.default_model()
|
||||
logger.warning(f"\'{model_name}\' is not a valid model name. Using default model \'{default_model_name}\' instead.")
|
||||
model = model_manager.get_model()
|
||||
logger.warning(f"{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead.")
|
||||
else:
|
||||
model = model_manager.get_model(model_name)
|
||||
|
||||
return model
|
||||
|
||||
@@ -1,12 +1,24 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.util.enum import MetaEnum
|
||||
|
||||
|
||||
class ImageType(str, Enum, metaclass=MetaEnum):
|
||||
"""The type of an image."""
|
||||
|
||||
class ImageType(str, Enum):
|
||||
RESULT = "results"
|
||||
INTERMEDIATE = "intermediates"
|
||||
UPLOAD = "uploads"
|
||||
INTERMEDIATE = "intermediates"
|
||||
|
||||
|
||||
class ImageCategory(str, Enum, metaclass=MetaEnum):
|
||||
"""The category of an image. Use ImageCategory.OTHER for non-default categories."""
|
||||
|
||||
IMAGE = "image"
|
||||
CONTROL_IMAGE = "control_image"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
def is_image_type(obj):
|
||||
@@ -27,3 +39,13 @@ class ImageField(BaseModel):
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["image_type", "image_name"]}
|
||||
|
||||
|
||||
class ColorField(BaseModel):
|
||||
r: int = Field(ge=0, le=255, description="The red component")
|
||||
g: int = Field(ge=0, le=255, description="The green component")
|
||||
b: int = Field(ge=0, le=255, description="The blue component")
|
||||
a: int = Field(ge=0, le=255, description="The alpha component")
|
||||
|
||||
def tuple(self) -> Tuple[int, int, int, int]:
|
||||
return (self.r, self.g, self.b, self.a)
|
||||
|
||||
59
invokeai/app/models/metadata.py
Normal file
59
invokeai/app/models/metadata.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field, StrictFloat, StrictInt, StrictStr
|
||||
|
||||
|
||||
class ImageMetadata(BaseModel):
|
||||
"""
|
||||
Core generation metadata for an image/tensor generated in InvokeAI.
|
||||
|
||||
Also includes any metadata from the image's PNG tEXt chunks.
|
||||
|
||||
Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node.
|
||||
|
||||
Full metadata may be accessed by querying for the session in the `graph_executions` table.
|
||||
"""
|
||||
|
||||
positive_conditioning: Optional[StrictStr] = Field(
|
||||
default=None, description="The positive conditioning."
|
||||
)
|
||||
negative_conditioning: Optional[StrictStr] = Field(
|
||||
default=None, description="The negative conditioning."
|
||||
)
|
||||
width: Optional[StrictInt] = Field(
|
||||
default=None, description="Width of the image/tensor in pixels."
|
||||
)
|
||||
height: Optional[StrictInt] = Field(
|
||||
default=None, description="Height of the image/tensor in pixels."
|
||||
)
|
||||
seed: Optional[StrictInt] = Field(
|
||||
default=None, description="The seed used for noise generation."
|
||||
)
|
||||
cfg_scale: Optional[StrictFloat] = Field(
|
||||
default=None, description="The classifier-free guidance scale."
|
||||
)
|
||||
steps: Optional[StrictInt] = Field(
|
||||
default=None, description="The number of steps used for inference."
|
||||
)
|
||||
scheduler: Optional[StrictStr] = Field(
|
||||
default=None, description="The scheduler used for inference."
|
||||
)
|
||||
model: Optional[StrictStr] = Field(
|
||||
default=None, description="The model used for inference."
|
||||
)
|
||||
strength: Optional[StrictFloat] = Field(
|
||||
default=None,
|
||||
description="The strength used for image-to-image/tensor-to-tensor.",
|
||||
)
|
||||
image: Optional[StrictStr] = Field(
|
||||
default=None, description="The ID of the initial image."
|
||||
)
|
||||
tensor: Optional[StrictStr] = Field(
|
||||
default=None, description="The ID of the initial tensor."
|
||||
)
|
||||
# Pending model refactor:
|
||||
# vae: Optional[str] = Field(default=None,description="The VAE used for decoding.")
|
||||
# unet: Optional[str] = Field(default=None,description="The UNet used dor inference.")
|
||||
# clip: Optional[str] = Field(default=None,description="The CLIP Encoder used for conditioning.")
|
||||
extra: Optional[StrictStr] = Field(
|
||||
default=None, description="Extra metadata, extracted from the PNG tEXt chunk."
|
||||
)
|
||||
28
invokeai/app/models/resources.py
Normal file
28
invokeai/app/models/resources.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# TODO: Make a new model for this
|
||||
from enum import Enum
|
||||
|
||||
from invokeai.app.util.enum import MetaEnum
|
||||
|
||||
|
||||
class ResourceType(str, Enum, metaclass=MetaEnum):
|
||||
"""The type of a resource."""
|
||||
|
||||
IMAGES = "images"
|
||||
TENSORS = "tensors"
|
||||
|
||||
|
||||
# class ResourceOrigin(str, Enum, metaclass=MetaEnum):
|
||||
# """The origin of a resource (eg image or tensor)."""
|
||||
|
||||
# RESULTS = "results"
|
||||
# UPLOADS = "uploads"
|
||||
# INTERMEDIATES = "intermediates"
|
||||
|
||||
|
||||
|
||||
class TensorKind(str, Enum, metaclass=MetaEnum):
|
||||
"""The kind of a tensor. Use TensorKind.OTHER for non-default kinds."""
|
||||
|
||||
IMAGE_LATENTS = "image_latents"
|
||||
CONDITIONING = "conditioning"
|
||||
OTHER = "other"
|
||||
578
invokeai/app/services/db.ipynb
Normal file
578
invokeai/app/services/db.ipynb
Normal file
@@ -0,0 +1,578 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 40,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from abc import ABC, abstractmethod\n",
|
||||
"from enum import Enum\n",
|
||||
"import enum\n",
|
||||
"import sqlite3\n",
|
||||
"import threading\n",
|
||||
"from typing import Optional, Type, TypeVar, Union\n",
|
||||
"from PIL.Image import Image as PILImage\n",
|
||||
"from pydantic import BaseModel, Field\n",
|
||||
"from torch import Tensor"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 41,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"class ResourceOrigin(str, Enum):\n",
|
||||
" \"\"\"The origin of a resource (eg image or tensor).\"\"\"\n",
|
||||
"\n",
|
||||
" RESULTS = \"results\"\n",
|
||||
" UPLOADS = \"uploads\"\n",
|
||||
" INTERMEDIATES = \"intermediates\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class ImageKind(str, Enum):\n",
|
||||
" \"\"\"The kind of an image. Use ImageKind.OTHER for non-default kinds.\"\"\"\n",
|
||||
"\n",
|
||||
" IMAGE = \"image\"\n",
|
||||
" CONTROL_IMAGE = \"control_image\"\n",
|
||||
" OTHER = \"other\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class TensorKind(str, Enum):\n",
|
||||
" \"\"\"The kind of a tensor. Use TensorKind.OTHER for non-default kinds.\"\"\"\n",
|
||||
"\n",
|
||||
" IMAGE_LATENTS = \"image_latents\"\n",
|
||||
" CONDITIONING = \"conditioning\"\n",
|
||||
" OTHER = \"other\"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"def create_sql_values_string_from_string_enum(enum: Type[Enum]):\n",
|
||||
" \"\"\"\n",
|
||||
" Creates a string of the form \"('value1'), ('value2'), ..., ('valueN')\" from a StrEnum.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" delimiter = \", \"\n",
|
||||
" values = [f\"('{e.value}')\" for e in enum]\n",
|
||||
" return delimiter.join(values)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_sql_table_from_enum(\n",
|
||||
" enum: Type[Enum],\n",
|
||||
" table_name: str,\n",
|
||||
" primary_key_name: str,\n",
|
||||
" conn: sqlite3.Connection,\n",
|
||||
" cursor: sqlite3.Cursor,\n",
|
||||
" lock: threading.Lock,\n",
|
||||
"):\n",
|
||||
" \"\"\"\n",
|
||||
" Creates and populates a table to be used as a functional enum.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" try:\n",
|
||||
" lock.acquire()\n",
|
||||
"\n",
|
||||
" values_string = create_sql_values_string_from_string_enum(enum)\n",
|
||||
"\n",
|
||||
" cursor.execute(\n",
|
||||
" f\"\"\"--sql\n",
|
||||
" CREATE TABLE IF NOT EXISTS {table_name} (\n",
|
||||
" {primary_key_name} TEXT PRIMARY KEY\n",
|
||||
" );\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" cursor.execute(\n",
|
||||
" f\"\"\"--sql\n",
|
||||
" INSERT OR IGNORE INTO {table_name} ({primary_key_name}) VALUES {values_string};\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" conn.commit()\n",
|
||||
" finally:\n",
|
||||
" lock.release()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\"\"\"\n",
|
||||
"`resource_origins` functions as an enum for the ResourceOrigin model.\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# def create_resource_origins_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
||||
"# create_sql_table_from_enum(\n",
|
||||
"# enum=ResourceOrigin,\n",
|
||||
"# table_name=\"resource_origins\",\n",
|
||||
"# primary_key_name=\"origin_name\",\n",
|
||||
"# conn=conn,\n",
|
||||
"# cursor=cursor,\n",
|
||||
"# lock=lock,\n",
|
||||
"# )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\"\"\"\n",
|
||||
"`image_kinds` functions as an enum for the ImageType model.\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# def create_image_kinds_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
||||
" # create_sql_table_from_enum(\n",
|
||||
" # enum=ImageKind,\n",
|
||||
" # table_name=\"image_kinds\",\n",
|
||||
" # primary_key_name=\"kind_name\",\n",
|
||||
" # conn=conn,\n",
|
||||
" # cursor=cursor,\n",
|
||||
" # lock=lock,\n",
|
||||
" # )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\"\"\"\n",
|
||||
"`tensor_kinds` functions as an enum for the TensorType model.\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# def create_tensor_kinds_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
||||
" # create_sql_table_from_enum(\n",
|
||||
" # enum=TensorKind,\n",
|
||||
" # table_name=\"tensor_kinds\",\n",
|
||||
" # primary_key_name=\"kind_name\",\n",
|
||||
" # conn=conn,\n",
|
||||
" # cursor=cursor,\n",
|
||||
" # lock=lock,\n",
|
||||
" # )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\"\"\"\n",
|
||||
"`images` stores all images, regardless of type\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_images_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
||||
" try:\n",
|
||||
" lock.acquire()\n",
|
||||
"\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE TABLE IF NOT EXISTS images (\n",
|
||||
" id TEXT PRIMARY KEY,\n",
|
||||
" origin TEXT,\n",
|
||||
" image_kind TEXT,\n",
|
||||
" created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\n",
|
||||
" FOREIGN KEY(origin) REFERENCES resource_origins(origin_name),\n",
|
||||
" FOREIGN KEY(image_kind) REFERENCES image_kinds(kind_name)\n",
|
||||
" );\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_id ON images(id);\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE INDEX IF NOT EXISTS idx_images_origin ON images(origin);\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE INDEX IF NOT EXISTS idx_images_image_kind ON images(image_kind);\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" conn.commit()\n",
|
||||
" finally:\n",
|
||||
" lock.release()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\"\"\"\n",
|
||||
"`images_results` stores additional data specific to `results` images.\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_images_results_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
||||
" try:\n",
|
||||
" lock.acquire()\n",
|
||||
"\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE TABLE IF NOT EXISTS images_results (\n",
|
||||
" images_id TEXT PRIMARY KEY,\n",
|
||||
" session_id TEXT NOT NULL,\n",
|
||||
" node_id TEXT NOT NULL,\n",
|
||||
" FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE\n",
|
||||
" );\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_results_images_id ON images_results(images_id);\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" conn.commit()\n",
|
||||
" finally:\n",
|
||||
" lock.release()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\"\"\"\n",
|
||||
"`images_intermediates` stores additional data specific to `intermediates` images\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_images_intermediates_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
||||
" try:\n",
|
||||
" lock.acquire()\n",
|
||||
"\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE TABLE IF NOT EXISTS images_intermediates (\n",
|
||||
" images_id TEXT PRIMARY KEY,\n",
|
||||
" session_id TEXT NOT NULL,\n",
|
||||
" node_id TEXT NOT NULL,\n",
|
||||
" FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE\n",
|
||||
" );\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_intermediates_images_id ON images_intermediates(images_id);\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" conn.commit()\n",
|
||||
" finally:\n",
|
||||
" lock.release()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\"\"\"\n",
|
||||
"`images_metadata` stores basic metadata for any image type\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_images_metadata_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
||||
" try:\n",
|
||||
" lock.acquire()\n",
|
||||
"\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE TABLE IF NOT EXISTS images_metadata (\n",
|
||||
" images_id TEXT PRIMARY KEY,\n",
|
||||
" metadata TEXT,\n",
|
||||
" FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE\n",
|
||||
" );\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_metadata_images_id ON images_metadata(images_id);\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" conn.commit()\n",
|
||||
" finally:\n",
|
||||
" lock.release()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# `tensors` table: stores references to tensor\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_tensors_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
||||
" try:\n",
|
||||
" lock.acquire()\n",
|
||||
"\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE TABLE IF NOT EXISTS tensors (\n",
|
||||
" id TEXT PRIMARY KEY,\n",
|
||||
" origin TEXT,\n",
|
||||
" tensor_kind TEXT,\n",
|
||||
" created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\n",
|
||||
" FOREIGN KEY(origin) REFERENCES resource_origins(origin_name),\n",
|
||||
" FOREIGN KEY(tensor_kind) REFERENCES tensor_kinds(kind_name)\n",
|
||||
" );\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_id ON tensors(id);\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE INDEX IF NOT EXISTS idx_tensors_origin ON tensors(origin);\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE INDEX IF NOT EXISTS idx_tensors_tensor_kind ON tensors(tensor_kind);\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" conn.commit()\n",
|
||||
" finally:\n",
|
||||
" lock.release()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# `tensors_results` stores additional data specific to `result` tensor\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_tensors_results_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
||||
" try:\n",
|
||||
" lock.acquire()\n",
|
||||
"\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE TABLE IF NOT EXISTS tensors_results (\n",
|
||||
" tensors_id TEXT PRIMARY KEY,\n",
|
||||
" session_id TEXT NOT NULL,\n",
|
||||
" node_id TEXT NOT NULL,\n",
|
||||
" FOREIGN KEY(tensors_id) REFERENCES tensors(id) ON DELETE CASCADE\n",
|
||||
" );\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_results_tensors_id ON tensors_results(tensors_id);\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" conn.commit()\n",
|
||||
" finally:\n",
|
||||
" lock.release()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# `tensors_intermediates` stores additional data specific to `intermediate` tensor\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_tensors_intermediates_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
||||
" try:\n",
|
||||
" lock.acquire()\n",
|
||||
"\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE TABLE IF NOT EXISTS tensors_intermediates (\n",
|
||||
" tensors_id TEXT PRIMARY KEY,\n",
|
||||
" session_id TEXT NOT NULL,\n",
|
||||
" node_id TEXT NOT NULL,\n",
|
||||
" FOREIGN KEY(tensors_id) REFERENCES tensors(id) ON DELETE CASCADE\n",
|
||||
" );\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_intermediates_tensors_id ON tensors_intermediates(tensors_id);\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" conn.commit()\n",
|
||||
" finally:\n",
|
||||
" lock.release()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# `tensors_metadata` table: stores generated/transformed metadata for tensor\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_tensors_metadata_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
||||
" try:\n",
|
||||
" lock.acquire()\n",
|
||||
"\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE TABLE IF NOT EXISTS tensors_metadata (\n",
|
||||
" tensors_id TEXT PRIMARY KEY,\n",
|
||||
" metadata TEXT,\n",
|
||||
" FOREIGN KEY(tensors_id) REFERENCES tensors(id) ON DELETE CASCADE\n",
|
||||
" );\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_metadata_tensors_id ON tensors_metadata(tensors_id);\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" conn.commit()\n",
|
||||
" finally:\n",
|
||||
" lock.release()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 43,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"db_path = '/home/bat/Documents/Code/outputs/test.db'\n",
|
||||
"if (os.path.exists(db_path)):\n",
|
||||
" os.remove(db_path)\n",
|
||||
"\n",
|
||||
"conn = sqlite3.connect(\n",
|
||||
" db_path, check_same_thread=False\n",
|
||||
")\n",
|
||||
"cursor = conn.cursor()\n",
|
||||
"lock = threading.Lock()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 44,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"create_sql_table_from_enum(\n",
|
||||
" enum=ResourceOrigin,\n",
|
||||
" table_name=\"resource_origins\",\n",
|
||||
" primary_key_name=\"origin_name\",\n",
|
||||
" conn=conn,\n",
|
||||
" cursor=cursor,\n",
|
||||
" lock=lock,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"create_sql_table_from_enum(\n",
|
||||
" enum=ImageKind,\n",
|
||||
" table_name=\"image_kinds\",\n",
|
||||
" primary_key_name=\"kind_name\",\n",
|
||||
" conn=conn,\n",
|
||||
" cursor=cursor,\n",
|
||||
" lock=lock,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"create_sql_table_from_enum(\n",
|
||||
" enum=TensorKind,\n",
|
||||
" table_name=\"tensor_kinds\",\n",
|
||||
" primary_key_name=\"kind_name\",\n",
|
||||
" conn=conn,\n",
|
||||
" cursor=cursor,\n",
|
||||
" lock=lock,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 45,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"create_images_table(conn, cursor, lock)\n",
|
||||
"create_images_results_table(conn, cursor, lock)\n",
|
||||
"create_images_intermediates_table(conn, cursor, lock)\n",
|
||||
"create_images_metadata_table(conn, cursor, lock)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 46,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"create_tensors_table(conn, cursor, lock)\n",
|
||||
"create_tensors_results_table(conn, cursor, lock)\n",
|
||||
"create_tensors_intermediates_table(conn, cursor, lock)\n",
|
||||
"create_tensors_metadata_table(conn, cursor, lock)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 59,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"from pydantic import StrictStr\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class GeneratedImageOrLatentsMetadata(BaseModel):\n",
|
||||
" \"\"\"Core generation metadata for an image/tensor generated in InvokeAI.\n",
|
||||
"\n",
|
||||
" Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node.\n",
|
||||
"\n",
|
||||
" Full metadata may be accessed by querying for the session in the `graph_executions` table.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" positive_conditioning: Optional[StrictStr] = Field(\n",
|
||||
" default=None, description=\"The positive conditioning.\"\n",
|
||||
" )\n",
|
||||
" negative_conditioning: Optional[str] = Field(\n",
|
||||
" default=None, description=\"The negative conditioning.\"\n",
|
||||
" )\n",
|
||||
" width: Optional[int] = Field(\n",
|
||||
" default=None, description=\"Width of the image/tensor in pixels.\"\n",
|
||||
" )\n",
|
||||
" height: Optional[int] = Field(\n",
|
||||
" default=None, description=\"Height of the image/tensor in pixels.\"\n",
|
||||
" )\n",
|
||||
" seed: Optional[int] = Field(\n",
|
||||
" default=None, description=\"The seed used for noise generation.\"\n",
|
||||
" )\n",
|
||||
" cfg_scale: Optional[float] = Field(\n",
|
||||
" default=None, description=\"The classifier-free guidance scale.\"\n",
|
||||
" )\n",
|
||||
" steps: Optional[int] = Field(\n",
|
||||
" default=None, description=\"The number of steps used for inference.\"\n",
|
||||
" )\n",
|
||||
" scheduler: Optional[str] = Field(\n",
|
||||
" default=None, description=\"The scheduler used for inference.\"\n",
|
||||
" )\n",
|
||||
" model: Optional[str] = Field(\n",
|
||||
" default=None, description=\"The model used for inference.\"\n",
|
||||
" )\n",
|
||||
" strength: Optional[float] = Field(\n",
|
||||
" default=None,\n",
|
||||
" description=\"The strength used for image-to-image/tensor-to-tensor.\",\n",
|
||||
" )\n",
|
||||
" image: Optional[str] = Field(\n",
|
||||
" default=None, description=\"The ID of the initial image.\"\n",
|
||||
" )\n",
|
||||
" tensor: Optional[str] = Field(\n",
|
||||
" default=None, description=\"The ID of the initial tensor.\"\n",
|
||||
" )\n",
|
||||
" # Pending model refactor:\n",
|
||||
" # vae: Optional[str] = Field(default=None,description=\"The VAE used for decoding.\")\n",
|
||||
" # unet: Optional[str] = Field(default=None,description=\"The UNet used dor inference.\")\n",
|
||||
" # clip: Optional[str] = Field(default=None,description=\"The CLIP Encoder used for conditioning.\")\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 61,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"GeneratedImageOrLatentsMetadata(positive_conditioning='123', negative_conditioning=None, width=None, height=None, seed=None, cfg_scale=None, steps=None, scheduler=None, model=None, strength=None, image=None, tensor=None)"
|
||||
]
|
||||
},
|
||||
"execution_count": 61,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"GeneratedImageOrLatentsMetadata(positive_conditioning='123')"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.6"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -48,13 +48,14 @@ def create_text_to_image() -> LibraryGraph:
|
||||
|
||||
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
|
||||
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
|
||||
|
||||
|
||||
# TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
|
||||
graphs: list[LibraryGraph] = list()
|
||||
|
||||
text_to_image = graph_library.get(default_text_to_image_graph_id)
|
||||
# text_to_image = graph_library.get(default_text_to_image_graph_id)
|
||||
|
||||
# TODO: Check if the graph is the same as the default one, and if not, update it
|
||||
#if text_to_image is None:
|
||||
# # TODO: Check if the graph is the same as the default one, and if not, update it
|
||||
# #if text_to_image is None:
|
||||
text_to_image = create_text_to_image()
|
||||
graph_library.set(text_to_image)
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
from invokeai.app.api.models.images import ProgressImage
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
|
||||
@@ -50,6 +50,8 @@ class EventServiceBase:
|
||||
result: dict,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
image_url: Optional[str] = None,
|
||||
thumbnail_url: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_session_event(
|
||||
@@ -59,6 +61,8 @@ class EventServiceBase:
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
result=result,
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
180
invokeai/app/services/image_file_storage.py
Normal file
180
invokeai/app/services/image_file_storage.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Dict, Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
from PIL import Image
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
from send2trash import send2trash
|
||||
|
||||
from invokeai.app.models.image import ImageType
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||
|
||||
|
||||
class ImageFileStorageBase(ABC):
|
||||
"""Low-level service responsible for storing and retrieving image files."""
|
||||
|
||||
class ImageFileNotFoundException(Exception):
|
||||
"""Raised when an image file is not found in storage."""
|
||||
|
||||
def __init__(self, message="Image file not found"):
|
||||
super().__init__(message)
|
||||
|
||||
class ImageFileSaveException(Exception):
|
||||
"""Raised when an image cannot be saved."""
|
||||
|
||||
def __init__(self, message="Image file not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
class ImageFileDeleteException(Exception):
|
||||
"""Raised when an image cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Image file not deleted"):
|
||||
super().__init__(message)
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_type: ImageType, image_name: str) -> PILImageType:
|
||||
"""Retrieves an image as PIL Image."""
|
||||
pass
|
||||
|
||||
# # TODO: make this a bit more flexible for e.g. cloud storage
|
||||
@abstractmethod
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
"""Gets the internal path to an image or thumbnail."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_type: ImageType,
|
||||
image_name: str,
|
||||
pnginfo: Optional[PngInfo] = None,
|
||||
thumbnail_size: int = 256,
|
||||
) -> None:
|
||||
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
"""Deletes an image and its thumbnail (if one exists)."""
|
||||
pass
|
||||
|
||||
|
||||
class DiskImageFileStorage(ImageFileStorageBase):
|
||||
"""Stores images on disk"""
|
||||
|
||||
__output_folder: str
|
||||
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
||||
__cache: Dict[str, PILImageType]
|
||||
__max_cache_size: int
|
||||
|
||||
def __init__(self, output_folder: str):
|
||||
self.__output_folder = output_folder
|
||||
self.__cache = dict()
|
||||
self.__cache_ids = Queue()
|
||||
self.__max_cache_size = 10 # TODO: get this from config
|
||||
|
||||
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# TODO: don't hard-code. get/save/delete should maybe take subpath?
|
||||
for image_type in ImageType:
|
||||
Path(os.path.join(output_folder, image_type)).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
|
||||
def get(self, image_type: ImageType, image_name: str) -> PILImageType:
|
||||
try:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
cache_item = self.__get_cache(image_path)
|
||||
if cache_item:
|
||||
return cache_item
|
||||
|
||||
image = Image.open(image_path)
|
||||
self.__set_cache(image_path, image)
|
||||
return image
|
||||
except FileNotFoundError as e:
|
||||
raise ImageFileStorageBase.ImageFileNotFoundException from e
|
||||
|
||||
def save(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_type: ImageType,
|
||||
image_name: str,
|
||||
pnginfo: Optional[PngInfo] = None,
|
||||
thumbnail_size: int = 256,
|
||||
) -> None:
|
||||
try:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
image.save(image_path, "PNG", pnginfo=pnginfo)
|
||||
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
thumbnail_path = self.get_path(image_type, thumbnail_name, thumbnail=True)
|
||||
thumbnail_image = make_thumbnail(image, thumbnail_size)
|
||||
thumbnail_image.save(thumbnail_path)
|
||||
|
||||
self.__set_cache(image_path, image)
|
||||
self.__set_cache(thumbnail_path, thumbnail_image)
|
||||
except Exception as e:
|
||||
raise ImageFileStorageBase.ImageFileSaveException from e
|
||||
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
try:
|
||||
basename = os.path.basename(image_name)
|
||||
image_path = self.get_path(image_type, basename)
|
||||
|
||||
if os.path.exists(image_path):
|
||||
send2trash(image_path)
|
||||
if image_path in self.__cache:
|
||||
del self.__cache[image_path]
|
||||
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
thumbnail_path = self.get_path(image_type, thumbnail_name, True)
|
||||
|
||||
if os.path.exists(thumbnail_path):
|
||||
send2trash(thumbnail_path)
|
||||
if thumbnail_path in self.__cache:
|
||||
del self.__cache[thumbnail_path]
|
||||
except Exception as e:
|
||||
raise ImageFileStorageBase.ImageFileDeleteException from e
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
# strip out any relative path shenanigans
|
||||
basename = os.path.basename(image_name)
|
||||
|
||||
if thumbnail:
|
||||
thumbnail_name = get_thumbnail_name(basename)
|
||||
path = os.path.join(
|
||||
self.__output_folder, image_type, "thumbnails", thumbnail_name
|
||||
)
|
||||
else:
|
||||
path = os.path.join(self.__output_folder, image_type, basename)
|
||||
|
||||
abspath = os.path.abspath(path)
|
||||
|
||||
return abspath
|
||||
|
||||
def __get_cache(self, image_name: str) -> PILImageType | None:
|
||||
return None if image_name not in self.__cache else self.__cache[image_name]
|
||||
|
||||
def __set_cache(self, image_name: str, image: PILImageType):
|
||||
if not image_name in self.__cache:
|
||||
self.__cache[image_name] = image
|
||||
self.__cache_ids.put(
|
||||
image_name
|
||||
) # TODO: this should refresh position for LRU cache
|
||||
if len(self.__cache) > self.__max_cache_size:
|
||||
cache_id = self.__cache_ids.get()
|
||||
if cache_id in self.__cache:
|
||||
del self.__cache[cache_id]
|
||||
318
invokeai/app/services/image_record_storage.py
Normal file
318
invokeai/app/services/image_record_storage.py
Normal file
@@ -0,0 +1,318 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import datetime
|
||||
from typing import Optional
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Optional, Union
|
||||
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.models.image import (
|
||||
ImageCategory,
|
||||
ImageType,
|
||||
)
|
||||
from invokeai.app.services.util.create_enum_table import create_enum_table
|
||||
from invokeai.app.services.models.image_record import (
|
||||
ImageRecord,
|
||||
deserialize_image_record,
|
||||
)
|
||||
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
|
||||
|
||||
class ImageRecordStorageBase(ABC):
|
||||
"""Low-level service responsible for interfacing with the image record store."""
|
||||
|
||||
class ImageRecordNotFoundException(Exception):
|
||||
"""Raised when an image record is not found."""
|
||||
|
||||
def __init__(self, message="Image record not found"):
|
||||
super().__init__(message)
|
||||
|
||||
class ImageRecordSaveException(Exception):
|
||||
"""Raised when an image record cannot be saved."""
|
||||
|
||||
def __init__(self, message="Image record not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
class ImageRecordDeleteException(Exception):
|
||||
"""Raised when an image record cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Image record not deleted"):
|
||||
super().__init__(message)
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
||||
"""Gets an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_category: ImageCategory,
|
||||
page: int = 0,
|
||||
per_page: int = 10,
|
||||
) -> PaginatedResults[ImageRecord]:
|
||||
"""Gets a page of image records."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
"""Deletes an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
image_name: str,
|
||||
image_type: ImageType,
|
||||
image_category: ImageCategory,
|
||||
session_id: Optional[str],
|
||||
node_id: Optional[str],
|
||||
metadata: Optional[ImageMetadata],
|
||||
created_at: str = datetime.datetime.utcnow().isoformat(),
|
||||
) -> None:
|
||||
"""Saves an image record."""
|
||||
pass
|
||||
|
||||
|
||||
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, filename: str) -> None:
|
||||
super().__init__()
|
||||
|
||||
self._filename = filename
|
||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Enable foreign keys
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Creates the tables for the `images` database."""
|
||||
|
||||
# Create the `images` table.
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
CREATE TABLE IF NOT EXISTS images (
|
||||
id TEXT PRIMARY KEY,
|
||||
image_type TEXT, -- non-nullable via foreign key constraint
|
||||
image_category TEXT, -- non-nullable via foreign key constraint
|
||||
session_id TEXT, -- nullable
|
||||
node_id TEXT, -- nullable
|
||||
metadata TEXT, -- nullable
|
||||
created_at TEXT NOT NULL,
|
||||
FOREIGN KEY(image_type) REFERENCES image_types(type_name),
|
||||
FOREIGN KEY(image_category) REFERENCES image_categories(category_name)
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# Create the `images` table indices.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_id ON images(id);
|
||||
"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_images_image_type ON images(image_type);
|
||||
"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);
|
||||
"""
|
||||
)
|
||||
|
||||
# Create the tables for image-related enums
|
||||
create_enum_table(
|
||||
enum=ImageType,
|
||||
table_name="image_types",
|
||||
primary_key_name="type_name",
|
||||
cursor=self._cursor,
|
||||
)
|
||||
|
||||
create_enum_table(
|
||||
enum=ImageCategory,
|
||||
table_name="image_categories",
|
||||
primary_key_name="category_name",
|
||||
cursor=self._cursor,
|
||||
)
|
||||
|
||||
# Create the `tags` table. TODO: do this elsewhere, shouldn't be in images db service
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS tags (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
tag_name TEXT UNIQUE NOT NULL
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# Create the `images_tags` junction table.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS images_tags (
|
||||
image_id TEXT,
|
||||
tag_id INTEGER,
|
||||
PRIMARY KEY (image_id, tag_id),
|
||||
FOREIGN KEY(image_id) REFERENCES images(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY(tag_id) REFERENCES tags(id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# Create the `images_favorites` table.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS images_favorites (
|
||||
image_id TEXT PRIMARY KEY,
|
||||
favorited_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY(image_id) REFERENCES images(id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
def get(self, image_type: ImageType, image_name: str) -> Union[ImageRecord, None]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT * FROM images
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
result = self._cursor.fetchone()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise self.ImageRecordNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
if not result:
|
||||
raise self.ImageRecordNotFoundException
|
||||
|
||||
return deserialize_image_record(result)
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_category: ImageCategory,
|
||||
page: int = 0,
|
||||
per_page: int = 10,
|
||||
) -> PaginatedResults[ImageRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT * FROM images
|
||||
WHERE image_type = ? AND image_category = ?
|
||||
LIMIT ? OFFSET ?;
|
||||
""",
|
||||
(image_type.value, image_category.value, per_page, page * per_page),
|
||||
)
|
||||
|
||||
result = self._cursor.fetchall()
|
||||
|
||||
images = list(map(lambda r: deserialize_image_record(r), result))
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*) FROM images
|
||||
WHERE image_type = ? AND image_category = ?
|
||||
""",
|
||||
(image_type.value, image_category.value),
|
||||
)
|
||||
|
||||
count = self._cursor.fetchone()[0]
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
return PaginatedResults(
|
||||
items=images, page=page, pages=pageCount, per_page=per_page, total=count
|
||||
)
|
||||
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM images
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordStorageBase.ImageRecordDeleteException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def save(
|
||||
self,
|
||||
image_name: str,
|
||||
image_type: ImageType,
|
||||
image_category: ImageCategory,
|
||||
session_id: Optional[str],
|
||||
node_id: Optional[str],
|
||||
metadata: Optional[ImageMetadata],
|
||||
created_at: str,
|
||||
) -> None:
|
||||
try:
|
||||
metadata_json = (
|
||||
None if metadata is None else metadata.json(exclude_none=True)
|
||||
)
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO images (
|
||||
id,
|
||||
image_type,
|
||||
image_category,
|
||||
node_id,
|
||||
session_id,
|
||||
metadata,
|
||||
created_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
image_name,
|
||||
image_type.value,
|
||||
image_category.value,
|
||||
node_id,
|
||||
session_id,
|
||||
metadata_json,
|
||||
created_at,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordStorageBase.ImageRecordNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
@@ -1,273 +0,0 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import os
|
||||
from glob import glob
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Dict, List
|
||||
|
||||
from PIL.Image import Image
|
||||
import PIL.Image as PILImage
|
||||
from send2trash import send2trash
|
||||
from invokeai.app.api.models.images import (
|
||||
ImageResponse,
|
||||
ImageResponseMetadata,
|
||||
SavedImage,
|
||||
)
|
||||
from invokeai.app.models.image import ImageType
|
||||
from invokeai.app.services.metadata import (
|
||||
InvokeAIMetadata,
|
||||
MetadataServiceBase,
|
||||
build_invokeai_metadata_pnginfo,
|
||||
)
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||
|
||||
|
||||
class ImageStorageBase(ABC):
|
||||
"""Responsible for storing and retrieving images."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_type: ImageType, image_name: str) -> Image:
|
||||
"""Retrieves an image as PIL Image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list(
|
||||
self, image_type: ImageType, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[ImageResponse]:
|
||||
"""Gets a paginated list of images."""
|
||||
pass
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
@abstractmethod
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||
) -> str:
|
||||
"""Gets the internal path to an image or its thumbnail."""
|
||||
pass
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
@abstractmethod
|
||||
def get_uri(
|
||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||
) -> str:
|
||||
"""Gets the external URI to an image or its thumbnail."""
|
||||
pass
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
@abstractmethod
|
||||
def validate_path(self, path: str) -> bool:
|
||||
"""Validates an image path."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_name: str,
|
||||
image: Image,
|
||||
metadata: InvokeAIMetadata | None = None,
|
||||
) -> SavedImage:
|
||||
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
"""Deletes an image and its thumbnail (if one exists)."""
|
||||
pass
|
||||
|
||||
def create_name(self, context_id: str, node_id: str) -> str:
|
||||
"""Creates a unique contextual image filename."""
|
||||
return f"{context_id}_{node_id}_{str(get_timestamp())}.png"
|
||||
|
||||
|
||||
class DiskImageStorage(ImageStorageBase):
|
||||
"""Stores images on disk"""
|
||||
|
||||
__output_folder: str
|
||||
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
||||
__cache: Dict[str, Image]
|
||||
__max_cache_size: int
|
||||
__metadata_service: MetadataServiceBase
|
||||
|
||||
def __init__(self, output_folder: str, metadata_service: MetadataServiceBase):
|
||||
self.__output_folder = output_folder
|
||||
self.__cache = dict()
|
||||
self.__cache_ids = Queue()
|
||||
self.__max_cache_size = 10 # TODO: get this from config
|
||||
self.__metadata_service = metadata_service
|
||||
|
||||
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# TODO: don't hard-code. get/save/delete should maybe take subpath?
|
||||
for image_type in ImageType:
|
||||
Path(os.path.join(output_folder, image_type)).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
|
||||
def list(
|
||||
self, image_type: ImageType, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[ImageResponse]:
|
||||
dir_path = os.path.join(self.__output_folder, image_type)
|
||||
image_paths = glob(f"{dir_path}/*.png")
|
||||
count = len(image_paths)
|
||||
|
||||
sorted_image_paths = sorted(
|
||||
glob(f"{dir_path}/*.png"), key=os.path.getctime, reverse=True
|
||||
)
|
||||
|
||||
page_of_image_paths = sorted_image_paths[
|
||||
page * per_page : (page + 1) * per_page
|
||||
]
|
||||
|
||||
page_of_images: List[ImageResponse] = []
|
||||
|
||||
for path in page_of_image_paths:
|
||||
filename = os.path.basename(path)
|
||||
img = PILImage.open(path)
|
||||
|
||||
invokeai_metadata = self.__metadata_service.get_metadata(img)
|
||||
|
||||
page_of_images.append(
|
||||
ImageResponse(
|
||||
image_type=image_type.value,
|
||||
image_name=filename,
|
||||
# TODO: DiskImageStorage should not be building URLs...?
|
||||
image_url=self.get_uri(image_type, filename),
|
||||
thumbnail_url=self.get_uri(image_type, filename, True),
|
||||
# TODO: Creation of this object should happen elsewhere (?), just making it fit here so it works
|
||||
metadata=ImageResponseMetadata(
|
||||
created=int(os.path.getctime(path)),
|
||||
width=img.width,
|
||||
height=img.height,
|
||||
invokeai=invokeai_metadata,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
page_count_trunc = int(count / per_page)
|
||||
page_count_mod = count % per_page
|
||||
page_count = page_count_trunc if page_count_mod == 0 else page_count_trunc + 1
|
||||
|
||||
return PaginatedResults[ImageResponse](
|
||||
items=page_of_images,
|
||||
page=page,
|
||||
pages=page_count,
|
||||
per_page=per_page,
|
||||
total=count,
|
||||
)
|
||||
|
||||
def get(self, image_type: ImageType, image_name: str) -> Image:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
cache_item = self.__get_cache(image_path)
|
||||
if cache_item:
|
||||
return cache_item
|
||||
|
||||
image = PILImage.open(image_path)
|
||||
self.__set_cache(image_path, image)
|
||||
return image
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||
) -> str:
|
||||
# strip out any relative path shenanigans
|
||||
basename = os.path.basename(image_name)
|
||||
|
||||
if is_thumbnail:
|
||||
path = os.path.join(
|
||||
self.__output_folder, image_type, "thumbnails", basename
|
||||
)
|
||||
else:
|
||||
path = os.path.join(self.__output_folder, image_type, basename)
|
||||
|
||||
abspath = os.path.abspath(path)
|
||||
|
||||
return abspath
|
||||
|
||||
def get_uri(
|
||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||
) -> str:
|
||||
# strip out any relative path shenanigans
|
||||
basename = os.path.basename(image_name)
|
||||
|
||||
if is_thumbnail:
|
||||
thumbnail_basename = get_thumbnail_name(basename)
|
||||
uri = f"api/v1/images/{image_type.value}/thumbnails/{thumbnail_basename}"
|
||||
else:
|
||||
uri = f"api/v1/images/{image_type.value}/{basename}"
|
||||
|
||||
return uri
|
||||
|
||||
def validate_path(self, path: str) -> bool:
|
||||
try:
|
||||
os.stat(path)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def save(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_name: str,
|
||||
image: Image,
|
||||
metadata: InvokeAIMetadata | None = None,
|
||||
) -> SavedImage:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
|
||||
# TODO: Reading the image and then saving it strips the metadata...
|
||||
if metadata:
|
||||
pnginfo = build_invokeai_metadata_pnginfo(metadata=metadata)
|
||||
image.save(image_path, "PNG", pnginfo=pnginfo)
|
||||
else:
|
||||
image.save(image_path) # this saved image has an empty info
|
||||
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
thumbnail_path = self.get_path(image_type, thumbnail_name, is_thumbnail=True)
|
||||
thumbnail_image = make_thumbnail(image)
|
||||
thumbnail_image.save(thumbnail_path)
|
||||
|
||||
self.__set_cache(image_path, image)
|
||||
self.__set_cache(thumbnail_path, thumbnail_image)
|
||||
|
||||
return SavedImage(
|
||||
image_name=image_name,
|
||||
thumbnail_name=thumbnail_name,
|
||||
created=int(os.path.getctime(image_path)),
|
||||
)
|
||||
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
basename = os.path.basename(image_name)
|
||||
image_path = self.get_path(image_type, basename)
|
||||
|
||||
if os.path.exists(image_path):
|
||||
send2trash(image_path)
|
||||
if image_path in self.__cache:
|
||||
del self.__cache[image_path]
|
||||
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
thumbnail_path = self.get_path(image_type, thumbnail_name, True)
|
||||
|
||||
if os.path.exists(thumbnail_path):
|
||||
send2trash(thumbnail_path)
|
||||
if thumbnail_path in self.__cache:
|
||||
del self.__cache[thumbnail_path]
|
||||
|
||||
def __get_cache(self, image_name: str) -> Image | None:
|
||||
return None if image_name not in self.__cache else self.__cache[image_name]
|
||||
|
||||
def __set_cache(self, image_name: str, image: Image):
|
||||
if not image_name in self.__cache:
|
||||
self.__cache[image_name] = image
|
||||
self.__cache_ids.put(
|
||||
image_name
|
||||
) # TODO: this should refresh position for LRU cache
|
||||
if len(self.__cache) > self.__max_cache_size:
|
||||
cache_id = self.__cache_ids.get()
|
||||
del self.__cache[cache_id]
|
||||
355
invokeai/app/services/images.py
Normal file
355
invokeai/app/services/images.py
Normal file
@@ -0,0 +1,355 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import json
|
||||
from logging import Logger
|
||||
from typing import Optional, Union
|
||||
import uuid
|
||||
from PIL.Image import Image as PILImageType
|
||||
from PIL import PngImagePlugin
|
||||
|
||||
from invokeai.app.models.image import ImageCategory, ImageType
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.services.image_record_storage import (
|
||||
ImageRecordStorageBase,
|
||||
)
|
||||
from invokeai.app.services.models.image_record import (
|
||||
ImageRecord,
|
||||
ImageDTO,
|
||||
image_record_to_dto,
|
||||
)
|
||||
from invokeai.app.services.image_file_storage import ImageFileStorageBase
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
from invokeai.app.services.metadata import MetadataServiceBase
|
||||
from invokeai.app.services.urls import UrlServiceBase
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
|
||||
|
||||
class ImageServiceABC(ABC):
|
||||
"""
|
||||
High-level service for image management.
|
||||
|
||||
Provides methods for creating, retrieving, and deleting images.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def create(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_type: ImageType,
|
||||
image_category: ImageCategory,
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
metadata: Optional[ImageMetadata] = None,
|
||||
) -> ImageDTO:
|
||||
"""Creates an image, storing the file and its metadata."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
|
||||
"""Gets an image as a PIL image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
||||
"""Gets an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
||||
"""Gets an image's path"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_url(self, image_type: ImageType, image_name: str, thumbnail: bool = False) -> str:
|
||||
"""Gets an image's or thumbnail's URL"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
|
||||
"""Gets an image DTO."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_category: ImageCategory,
|
||||
page: int = 0,
|
||||
per_page: int = 10,
|
||||
) -> PaginatedResults[ImageDTO]:
|
||||
"""Gets a paginated list of image DTOs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_type: ImageType, image_name: str):
|
||||
"""Deletes an image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
|
||||
"""Adds a tag to an image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remove_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
|
||||
"""Removes a tag from an image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def favorite(self, image_type: ImageType, image_id: str) -> None:
|
||||
"""Favorites an image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def unfavorite(self, image_type: ImageType, image_id: str) -> None:
|
||||
"""Unfavorites an image."""
|
||||
pass
|
||||
|
||||
|
||||
class ImageServiceDependencies:
|
||||
"""Service dependencies for the ImageService."""
|
||||
|
||||
records: ImageRecordStorageBase
|
||||
files: ImageFileStorageBase
|
||||
metadata: MetadataServiceBase
|
||||
urls: UrlServiceBase
|
||||
logger: Logger
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_record_storage: ImageRecordStorageBase,
|
||||
image_file_storage: ImageFileStorageBase,
|
||||
metadata: MetadataServiceBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
):
|
||||
self.records = image_record_storage
|
||||
self.files = image_file_storage
|
||||
self.metadata = metadata
|
||||
self.urls = url
|
||||
self.logger = logger
|
||||
|
||||
|
||||
class ImageService(ImageServiceABC):
|
||||
_services: ImageServiceDependencies
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_record_storage: ImageRecordStorageBase,
|
||||
image_file_storage: ImageFileStorageBase,
|
||||
metadata: MetadataServiceBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
):
|
||||
self._services = ImageServiceDependencies(
|
||||
image_record_storage=image_record_storage,
|
||||
image_file_storage=image_file_storage,
|
||||
metadata=metadata,
|
||||
url=url,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
def create(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_type: ImageType,
|
||||
image_category: ImageCategory,
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
metadata: Optional[ImageMetadata] = None,
|
||||
) -> ImageDTO:
|
||||
image_name = self._create_image_name(
|
||||
image_type=image_type,
|
||||
image_category=image_category,
|
||||
node_id=node_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
timestamp = get_iso_timestamp()
|
||||
|
||||
if metadata is not None:
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
pnginfo.add_text("invokeai", json.dumps(metadata))
|
||||
else:
|
||||
pnginfo = None
|
||||
|
||||
try:
|
||||
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
||||
self._services.files.save(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image,
|
||||
pnginfo=pnginfo,
|
||||
)
|
||||
|
||||
self._services.records.save(
|
||||
image_name=image_name,
|
||||
image_type=image_type,
|
||||
image_category=image_category,
|
||||
node_id=node_id,
|
||||
session_id=session_id,
|
||||
metadata=metadata,
|
||||
created_at=timestamp,
|
||||
)
|
||||
|
||||
image_url = self._services.urls.get_image_url(image_type, image_name)
|
||||
thumbnail_url = self._services.urls.get_image_url(
|
||||
image_type, image_name, True
|
||||
)
|
||||
|
||||
return ImageDTO(
|
||||
image_name=image_name,
|
||||
image_type=image_type,
|
||||
image_category=image_category,
|
||||
node_id=node_id,
|
||||
session_id=session_id,
|
||||
metadata=metadata,
|
||||
created_at=timestamp,
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
)
|
||||
except ImageRecordStorageBase.ImageRecordSaveException:
|
||||
self._services.logger.error("Failed to save image record")
|
||||
raise
|
||||
except ImageFileStorageBase.ImageFileSaveException:
|
||||
self._services.logger.error("Failed to save image file")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem saving image record and file")
|
||||
raise e
|
||||
|
||||
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
|
||||
try:
|
||||
return self._services.files.get(image_type, image_name)
|
||||
except ImageFileStorageBase.ImageFileNotFoundException:
|
||||
self._services.logger.error("Failed to get image file")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image file")
|
||||
raise e
|
||||
|
||||
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
||||
try:
|
||||
return self._services.records.get(image_type, image_name)
|
||||
except ImageRecordStorageBase.ImageRecordNotFoundException:
|
||||
self._services.logger.error("Image record not found")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image record")
|
||||
raise e
|
||||
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
try:
|
||||
return self._services.files.get_path(image_type, image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
|
||||
def get_url(
|
||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
try:
|
||||
return self._services.urls.get_image_url(image_type, image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
|
||||
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
|
||||
try:
|
||||
image_record = self._services.records.get(image_type, image_name)
|
||||
|
||||
image_dto = image_record_to_dto(
|
||||
image_record,
|
||||
self._services.urls.get_image_url(image_type, image_name),
|
||||
self._services.urls.get_image_url(image_type, image_name, True),
|
||||
)
|
||||
|
||||
return image_dto
|
||||
except ImageRecordStorageBase.ImageRecordNotFoundException:
|
||||
self._services.logger.error("Image record not found")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image DTO")
|
||||
raise e
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_category: ImageCategory,
|
||||
page: int = 0,
|
||||
per_page: int = 10,
|
||||
) -> PaginatedResults[ImageDTO]:
|
||||
try:
|
||||
results = self._services.records.get_many(
|
||||
image_type,
|
||||
image_category,
|
||||
page,
|
||||
per_page,
|
||||
)
|
||||
|
||||
image_dtos = list(
|
||||
map(
|
||||
lambda r: image_record_to_dto(
|
||||
r,
|
||||
self._services.urls.get_image_url(image_type, r.image_name),
|
||||
self._services.urls.get_image_url(
|
||||
image_type, r.image_name, True
|
||||
),
|
||||
),
|
||||
results.items,
|
||||
)
|
||||
)
|
||||
|
||||
return PaginatedResults[ImageDTO](
|
||||
items=image_dtos,
|
||||
page=results.page,
|
||||
pages=results.pages,
|
||||
per_page=results.per_page,
|
||||
total=results.total,
|
||||
)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting paginated image DTOs")
|
||||
raise e
|
||||
|
||||
def delete(self, image_type: ImageType, image_name: str):
|
||||
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
||||
try:
|
||||
self._services.files.delete(image_type, image_name)
|
||||
self._services.records.delete(image_type, image_name)
|
||||
except ImageRecordStorageBase.ImageRecordDeleteException:
|
||||
self._services.logger.error(f"Failed to delete image record")
|
||||
raise
|
||||
except ImageFileStorageBase.ImageFileDeleteException:
|
||||
self._services.logger.error(f"Failed to delete image file")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem deleting image record and file")
|
||||
raise e
|
||||
|
||||
def add_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
|
||||
raise NotImplementedError("The 'add_tag' method is not implemented yet.")
|
||||
|
||||
def remove_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
|
||||
raise NotImplementedError("The 'remove_tag' method is not implemented yet.")
|
||||
|
||||
def favorite(self, image_type: ImageType, image_id: str) -> None:
|
||||
raise NotImplementedError("The 'favorite' method is not implemented yet.")
|
||||
|
||||
def unfavorite(self, image_type: ImageType, image_id: str) -> None:
|
||||
raise NotImplementedError("The 'unfavorite' method is not implemented yet.")
|
||||
|
||||
def _create_image_name(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_category: ImageCategory,
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Create a unique image name."""
|
||||
uuid_str = str(uuid.uuid4())
|
||||
|
||||
if node_id is not None and session_id is not None:
|
||||
return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png"
|
||||
|
||||
return f"{image_type.value}_{image_category.value}_{uuid_str}.png"
|
||||
@@ -1,26 +1,32 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
from typing import TYPE_CHECKING
|
||||
from logging import Logger
|
||||
|
||||
from typing import types
|
||||
from invokeai.app.services.metadata import MetadataServiceBase
|
||||
from invokeai.app.services.images import ImageService
|
||||
from invokeai.backend import ModelManager
|
||||
|
||||
from .events import EventServiceBase
|
||||
from .latent_storage import LatentsStorageBase
|
||||
from .image_storage import ImageStorageBase
|
||||
from .image_file_storage import ImageFileStorageBase
|
||||
from .restoration_services import RestorationServices
|
||||
from .invocation_queue import InvocationQueueABC
|
||||
from .item_storage import ItemStorageABC
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
|
||||
from invokeai.app.services.invoker import InvocationProcessorABC
|
||||
|
||||
|
||||
class InvocationServices:
|
||||
"""Services that can be used by invocations"""
|
||||
|
||||
events: EventServiceBase
|
||||
latents: LatentsStorageBase
|
||||
images: ImageStorageBase
|
||||
metadata: MetadataServiceBase
|
||||
images: ImageFileStorageBase
|
||||
queue: InvocationQueueABC
|
||||
model_manager: ModelManager
|
||||
restoration: RestorationServices
|
||||
images_new: ImageService
|
||||
|
||||
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
||||
graph_library: ItemStorageABC["LibraryGraph"]
|
||||
@@ -28,26 +34,26 @@ class InvocationServices:
|
||||
processor: "InvocationProcessorABC"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_manager: ModelManager,
|
||||
events: EventServiceBase,
|
||||
logger: types.ModuleType,
|
||||
latents: LatentsStorageBase,
|
||||
images: ImageStorageBase,
|
||||
metadata: MetadataServiceBase,
|
||||
queue: InvocationQueueABC,
|
||||
graph_library: ItemStorageABC["LibraryGraph"],
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
processor: "InvocationProcessorABC",
|
||||
restoration: RestorationServices,
|
||||
self,
|
||||
model_manager: ModelManager,
|
||||
events: EventServiceBase,
|
||||
logger: Logger,
|
||||
latents: LatentsStorageBase,
|
||||
images: ImageFileStorageBase,
|
||||
queue: InvocationQueueABC,
|
||||
images_new: ImageService,
|
||||
graph_library: ItemStorageABC["LibraryGraph"],
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
processor: "InvocationProcessorABC",
|
||||
restoration: RestorationServices,
|
||||
):
|
||||
self.model_manager = model_manager
|
||||
self.events = events
|
||||
self.logger = logger
|
||||
self.latents = latents
|
||||
self.images = images
|
||||
self.metadata = metadata
|
||||
self.queue = queue
|
||||
self.images_new = images_new
|
||||
self.graph_library = graph_library
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
self.processor = processor
|
||||
|
||||
@@ -16,7 +16,7 @@ class LatentsStorageBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(self, name: str, data: torch.Tensor) -> None:
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -47,8 +47,8 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||
self.__set_cache(name, latent)
|
||||
return latent
|
||||
|
||||
def set(self, name: str, data: torch.Tensor) -> None:
|
||||
self.__underlying_storage.set(name, data)
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
self.__underlying_storage.save(name, data)
|
||||
self.__set_cache(name, data)
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
@@ -80,7 +80,7 @@ class DiskLatentsStorage(LatentsStorageBase):
|
||||
latent_path = self.get_path(name)
|
||||
return torch.load(latent_path)
|
||||
|
||||
def set(self, name: str, data: torch.Tensor) -> None:
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
latent_path = self.get_path(name)
|
||||
torch.save(data, latent_path)
|
||||
|
||||
|
||||
@@ -20,9 +20,26 @@ class MetadataLatentsField(TypedDict):
|
||||
latents_name: str
|
||||
|
||||
|
||||
class MetadataColorField(TypedDict):
|
||||
"""Pydantic-less ColorField, used for metadata parsing"""
|
||||
|
||||
r: int
|
||||
g: int
|
||||
b: int
|
||||
a: int
|
||||
|
||||
|
||||
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
|
||||
NodeMetadata = Dict[
|
||||
str, str | int | float | bool | MetadataImageField | MetadataLatentsField
|
||||
str,
|
||||
None
|
||||
| str
|
||||
| int
|
||||
| float
|
||||
| bool
|
||||
| MetadataImageField
|
||||
| MetadataLatentsField
|
||||
| MetadataColorField,
|
||||
]
|
||||
|
||||
|
||||
@@ -58,6 +75,11 @@ class MetadataServiceBase(ABC):
|
||||
"""Builds an InvokeAIMetadata object"""
|
||||
pass
|
||||
|
||||
# @abstractmethod
|
||||
# def create_metadata(self, session_id: str, node_id: str) -> dict:
|
||||
# """Creates metadata for a result"""
|
||||
# pass
|
||||
|
||||
|
||||
class PngMetadataService(MetadataServiceBase):
|
||||
"""Handles loading and building metadata for images."""
|
||||
|
||||
71
invokeai/app/services/models/image_record.py
Normal file
71
invokeai/app/services/models/image_record.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import datetime
|
||||
import sqlite3
|
||||
from typing import Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from invokeai.app.models.image import ImageCategory, ImageType
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
|
||||
|
||||
class ImageRecord(BaseModel):
|
||||
"""Deserialized image record."""
|
||||
|
||||
image_name: str = Field(description="The name of the image.")
|
||||
image_type: ImageType = Field(description="The type of the image.")
|
||||
image_category: ImageCategory = Field(description="The category of the image.")
|
||||
created_at: Union[datetime.datetime, str] = Field(
|
||||
description="The created timestamp of the image."
|
||||
)
|
||||
session_id: Optional[str] = Field(default=None, description="The session ID.")
|
||||
node_id: Optional[str] = Field(default=None, description="The node ID.")
|
||||
metadata: Optional[ImageMetadata] = Field(
|
||||
default=None, description="The image's metadata."
|
||||
)
|
||||
|
||||
|
||||
class ImageDTO(ImageRecord):
|
||||
"""Deserialized image record with URLs."""
|
||||
|
||||
image_url: str = Field(description="The URL of the image.")
|
||||
thumbnail_url: str = Field(description="The thumbnail URL of the image.")
|
||||
|
||||
|
||||
def image_record_to_dto(
|
||||
image_record: ImageRecord, image_url: str, thumbnail_url: str
|
||||
) -> ImageDTO:
|
||||
"""Converts an image record to an image DTO."""
|
||||
return ImageDTO(
|
||||
image_name=image_record.image_name,
|
||||
image_type=image_record.image_type,
|
||||
image_category=image_record.image_category,
|
||||
created_at=image_record.created_at,
|
||||
session_id=image_record.session_id,
|
||||
node_id=image_record.node_id,
|
||||
metadata=image_record.metadata,
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
)
|
||||
|
||||
|
||||
def deserialize_image_record(image_row: sqlite3.Row) -> ImageRecord:
|
||||
"""Deserializes an image record."""
|
||||
|
||||
image_dict = dict(image_row)
|
||||
|
||||
image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value))
|
||||
|
||||
raw_metadata = image_dict.get("metadata", "{}")
|
||||
|
||||
metadata = ImageMetadata.parse_raw(raw_metadata)
|
||||
|
||||
return ImageRecord(
|
||||
image_name=image_dict.get("id", "unknown"),
|
||||
session_id=image_dict.get("session_id", None),
|
||||
node_id=image_dict.get("node_id", None),
|
||||
metadata=metadata,
|
||||
image_type=image_type,
|
||||
image_category=ImageCategory(
|
||||
image_dict.get("image_category", ImageCategory.IMAGE.value)
|
||||
),
|
||||
created_at=image_dict.get("created_at", get_iso_timestamp()),
|
||||
)
|
||||
@@ -1,11 +1,16 @@
|
||||
import time
|
||||
import traceback
|
||||
from threading import Event, Thread, BoundedSemaphore
|
||||
from typing import Any, TypeGuard
|
||||
|
||||
from invokeai.app.invocations.image import ImageOutput
|
||||
from invokeai.app.models.image import ImageType
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from .invocation_queue import InvocationQueueItem
|
||||
from .invoker import InvocationProcessorABC, Invoker
|
||||
from ..models.exceptions import CanceledException
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
__invoker_thread: Thread
|
||||
__stop_event: Event
|
||||
@@ -34,8 +39,14 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
try:
|
||||
self.__threadLimit.acquire()
|
||||
while not stop_event.is_set():
|
||||
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
|
||||
try:
|
||||
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
|
||||
except Exception as e:
|
||||
logger.debug("Exception while getting from queue: %s" % e)
|
||||
|
||||
if not queue_item: # Probably stopping
|
||||
# do not hammer the queue
|
||||
time.sleep(0.5)
|
||||
continue
|
||||
|
||||
graph_execution_state = (
|
||||
@@ -80,12 +91,30 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
graph_execution_state
|
||||
)
|
||||
|
||||
def is_image_output(obj: Any) -> TypeGuard[ImageOutput]:
|
||||
return obj.__class__ == ImageOutput
|
||||
|
||||
outputs_dict = outputs.dict()
|
||||
|
||||
if is_image_output(outputs):
|
||||
image_url = self.__invoker.services.images_new.get_url(
|
||||
ImageType.RESULT, outputs.image.image_name
|
||||
)
|
||||
thumbnail_url = self.__invoker.services.images_new.get_url(
|
||||
ImageType.RESULT, outputs.image.image_name, True
|
||||
)
|
||||
else:
|
||||
image_url = None
|
||||
thumbnail_url = None
|
||||
|
||||
# Send complete event
|
||||
self.__invoker.services.events.emit_invocation_complete(
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
result=outputs.dict(),
|
||||
result=outputs_dict,
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
@@ -124,7 +153,16 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
# Queue any further commands if invoking all
|
||||
is_complete = graph_execution_state.is_complete()
|
||||
if queue_item.invoke_all and not is_complete:
|
||||
self.__invoker.invoke(graph_execution_state, invoke_all=True)
|
||||
try:
|
||||
self.__invoker.invoke(graph_execution_state, invoke_all=True)
|
||||
except Exception as e:
|
||||
logger.error("Error while invoking: %s" % e)
|
||||
self.__invoker.services.events.emit_invocation_error(
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
error=traceback.format_exc()
|
||||
)
|
||||
elif is_complete:
|
||||
self.__invoker.services.events.emit_graph_execution_complete(
|
||||
graph_execution_state.id
|
||||
|
||||
657
invokeai/app/services/proposeddesign.py
Normal file
657
invokeai/app/services/proposeddesign.py
Normal file
@@ -0,0 +1,657 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
import enum
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Optional, Type, TypeVar, Union
|
||||
from PIL.Image import Image as PILImage
|
||||
from pydantic import BaseModel, Field
|
||||
from torch import Tensor
|
||||
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
|
||||
|
||||
"""
|
||||
Substantial proposed changes to the management of images and tensor.
|
||||
|
||||
tl;dr:
|
||||
With the upcoming move to latents-only nodes, we need to handle metadata differently. After struggling with this unsuccessfully - trying to smoosh it in to the existing setup - I believe we need to expand the scope of the refactor to include the management of images and latents - and make `latents` a special case of `tensor`.
|
||||
|
||||
full story:
|
||||
The consensus for tensor-only nodes' metadata was to traverse the execution graph and grab the core parameters to write to the image. This was straightforward, and I've written functions to find the nearest t2l/l2l, noise, and compel nodes and build the metadata from those.
|
||||
|
||||
But struggling to integrate this and the associated edge cases this brought up a number of issues deeper in the system (some of which I had previously implemented). The ImageStorageService is doing way too much, and we have a need to be able to retrieve sessions the session given image/latents id, which is not currently feasible due to SQLite's JSON parsing performance.
|
||||
|
||||
I made a new ResultsService and `results` table in the db to facilitate this. This first attempt failed because it doesn't handle uploads and leaves the codebase messy.
|
||||
|
||||
So I've spent the day trying to figure out to handle this in a sane way and think I've got something decent. I've described some changes to service bases and the database below.
|
||||
|
||||
The gist of it is to store the core parameters for an image in its metadata when the image is saved, but never to read from it. Instead, the same metadata is stored in the database, which will be set up for efficient access. So when a page of images is requested, the metadata comes from the db instead of a filesystem operation.
|
||||
|
||||
The URL generation responsibilities have been split off the image storage service in to a URL service. New database services/tables for images and tensor are added. These services will provide paginated images/tensors for the API to serve. This also paves the way for handling tensors as first-class outputs.
|
||||
"""
|
||||
|
||||
|
||||
# TODO: Make a new model for this
|
||||
class ResourceOrigin(str, Enum):
|
||||
"""The origin of a resource (eg image or tensor)."""
|
||||
|
||||
RESULTS = "results"
|
||||
UPLOADS = "uploads"
|
||||
INTERMEDIATES = "intermediates"
|
||||
|
||||
|
||||
class ImageKind(str, Enum):
|
||||
"""The kind of an image."""
|
||||
|
||||
IMAGE = "image"
|
||||
CONTROL_IMAGE = "control_image"
|
||||
|
||||
|
||||
class TensorKind(str, Enum):
|
||||
"""The kind of a tensor."""
|
||||
|
||||
IMAGE_TENSOR = "tensor"
|
||||
CONDITIONING = "conditioning"
|
||||
|
||||
|
||||
"""
|
||||
Core Generation Metadata Pydantic Model
|
||||
|
||||
I've already implemented the code to traverse a session to build this object.
|
||||
"""
|
||||
|
||||
|
||||
class CoreGenerationMetadata(BaseModel):
|
||||
"""Core generation metadata for an image/tensor generated in InvokeAI.
|
||||
|
||||
Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node.
|
||||
|
||||
Full metadata may be accessed by querying for the session in the `graph_executions` table.
|
||||
"""
|
||||
|
||||
positive_conditioning: Optional[str] = Field(
|
||||
description="The positive conditioning."
|
||||
)
|
||||
negative_conditioning: Optional[str] = Field(
|
||||
description="The negative conditioning."
|
||||
)
|
||||
width: Optional[int] = Field(description="Width of the image/tensor in pixels.")
|
||||
height: Optional[int] = Field(description="Height of the image/tensor in pixels.")
|
||||
seed: Optional[int] = Field(description="The seed used for noise generation.")
|
||||
cfg_scale: Optional[float] = Field(
|
||||
description="The classifier-free guidance scale."
|
||||
)
|
||||
steps: Optional[int] = Field(description="The number of steps used for inference.")
|
||||
scheduler: Optional[str] = Field(description="The scheduler used for inference.")
|
||||
model: Optional[str] = Field(description="The model used for inference.")
|
||||
strength: Optional[float] = Field(
|
||||
description="The strength used for image-to-image/tensor-to-tensor."
|
||||
)
|
||||
image: Optional[str] = Field(description="The ID of the initial image.")
|
||||
tensor: Optional[str] = Field(description="The ID of the initial tensor.")
|
||||
# Pending model refactor:
|
||||
# vae: Optional[str] = Field(description="The VAE used for decoding.")
|
||||
# unet: Optional[str] = Field(description="The UNet used dor inference.")
|
||||
# clip: Optional[str] = Field(description="The CLIP Encoder used for conditioning.")
|
||||
|
||||
|
||||
"""
|
||||
Minimal Uploads Metadata Model
|
||||
"""
|
||||
|
||||
|
||||
class UploadsMetadata(BaseModel):
|
||||
"""Limited metadata for an uploaded image/tensor."""
|
||||
|
||||
width: Optional[int] = Field(description="Width of the image/tensor in pixels.")
|
||||
height: Optional[int] = Field(description="Height of the image/tensor in pixels.")
|
||||
# The extra field will be the contents of the PNG file's tEXt chunk. It may have come
|
||||
# from another SD application or InvokeAI, so we need to make it very flexible. I think it's
|
||||
# best to just store it as a string and let the frontend parse it.
|
||||
# If the upload is a tensor type, this will be omitted.
|
||||
extra: Optional[str] = Field(
|
||||
description="Extra metadata, extracted from the PNG tEXt chunk."
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
Slimmed-down Image Storage Service Base
|
||||
- No longer lists images or generates URLs - only stores and retrieves images.
|
||||
- OSS implementation for disk storage
|
||||
"""
|
||||
|
||||
|
||||
class ImageStorageBase(ABC):
|
||||
"""Responsible for storing and retrieving images."""
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
image: PILImage,
|
||||
image_kind: ImageKind,
|
||||
origin: ResourceOrigin,
|
||||
context_id: str,
|
||||
node_id: str,
|
||||
metadata: CoreGenerationMetadata,
|
||||
) -> str:
|
||||
"""Saves an image and its thumbnail, returning its unique identifier."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(self, id: str, thumbnail: bool = False) -> Union[PILImage, None]:
|
||||
"""Retrieves an image as a PIL Image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, id: str) -> None:
|
||||
"""Deletes an image."""
|
||||
pass
|
||||
|
||||
|
||||
class TensorStorageBase(ABC):
|
||||
"""Responsible for storing and retrieving tensors."""
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
tensor: Tensor,
|
||||
tensor_kind: TensorKind,
|
||||
origin: ResourceOrigin,
|
||||
context_id: str,
|
||||
node_id: str,
|
||||
metadata: CoreGenerationMetadata,
|
||||
) -> str:
|
||||
"""Saves a tensor, returning its unique identifier."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(self, id: str, thumbnail: bool = False) -> Union[Tensor, None]:
|
||||
"""Retrieves a tensor as a torch Tensor."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, id: str) -> None:
|
||||
"""Deletes a tensor."""
|
||||
pass
|
||||
|
||||
|
||||
"""
|
||||
New Url Service Base
|
||||
- Abstracts the logic for generating URLs out of the storage service
|
||||
- OSS implementation for locally-hosted URLs
|
||||
- Also provides a method to get the internal path to a resource (for OSS, the FS path)
|
||||
"""
|
||||
|
||||
|
||||
class ResourceLocationServiceBase(ABC):
|
||||
"""Responsible for locating resources (eg images or tensors)."""
|
||||
|
||||
@abstractmethod
|
||||
def get_url(self, id: str) -> str:
|
||||
"""Gets the URL for a resource."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, id: str) -> str:
|
||||
"""Gets the path for a resource."""
|
||||
pass
|
||||
|
||||
|
||||
"""
|
||||
New Images Database Service Base
|
||||
|
||||
This is a new service that will be responsible for the new `images` table(s):
|
||||
- Storing images in the table
|
||||
- Retrieving individual images and pages of images
|
||||
- Deleting individual images
|
||||
|
||||
Operations will typically use joins with the various `images` tables.
|
||||
"""
|
||||
|
||||
|
||||
class ImagesDbServiceBase(ABC):
|
||||
"""Responsible for interfacing with `images` table."""
|
||||
|
||||
class GeneratedImageEntity(BaseModel):
|
||||
id: str = Field(description="The unique identifier for the image.")
|
||||
session_id: str = Field(description="The session ID.")
|
||||
node_id: str = Field(description="The node ID.")
|
||||
metadata: CoreGenerationMetadata = Field(
|
||||
description="The metadata for the image."
|
||||
)
|
||||
|
||||
class UploadedImageEntity(BaseModel):
|
||||
id: str = Field(description="The unique identifier for the image.")
|
||||
metadata: UploadsMetadata = Field(description="The metadata for the image.")
|
||||
|
||||
@abstractmethod
|
||||
def get(self, id: str) -> Union[GeneratedImageEntity, UploadedImageEntity, None]:
|
||||
"""Gets an image from the `images` table."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self, image_kind: ImageKind, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[Union[GeneratedImageEntity, UploadedImageEntity]]:
|
||||
"""Gets a page of images from the `images` table."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, id: str) -> None:
|
||||
"""Deletes an image from the `images` table."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(
|
||||
self,
|
||||
id: str,
|
||||
image_kind: ImageKind,
|
||||
session_id: Optional[str],
|
||||
node_id: Optional[str],
|
||||
metadata: CoreGenerationMetadata | UploadsMetadata,
|
||||
) -> None:
|
||||
"""Sets an image in the `images` table."""
|
||||
pass
|
||||
|
||||
|
||||
"""
|
||||
New Tensor Database Service Base
|
||||
|
||||
This is a new service that will be responsible for the new `tensor` table:
|
||||
- Storing tensor in the table
|
||||
- Retrieving individual tensor and pages of tensor
|
||||
- Deleting individual tensor
|
||||
|
||||
Operations will always use joins with the `tensor_metadata` table.
|
||||
"""
|
||||
|
||||
|
||||
class TensorDbServiceBase(ABC):
|
||||
"""Responsible for interfacing with `tensor` table."""
|
||||
|
||||
class GeneratedTensorEntity(BaseModel):
|
||||
id: str = Field(description="The unique identifier for the tensor.")
|
||||
session_id: str = Field(description="The session ID.")
|
||||
node_id: str = Field(description="The node ID.")
|
||||
metadata: CoreGenerationMetadata = Field(
|
||||
description="The metadata for the tensor."
|
||||
)
|
||||
|
||||
class UploadedTensorEntity(BaseModel):
|
||||
id: str = Field(description="The unique identifier for the tensor.")
|
||||
metadata: UploadsMetadata = Field(description="The metadata for the tensor.")
|
||||
|
||||
@abstractmethod
|
||||
def get(self, id: str) -> Union[GeneratedTensorEntity, UploadedTensorEntity, None]:
|
||||
"""Gets a tensor from the `tensor` table."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self, tensor_kind: TensorKind, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[Union[GeneratedTensorEntity, UploadedTensorEntity]]:
|
||||
"""Gets a page of tensor from the `tensor` table."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, id: str) -> None:
|
||||
"""Deletes a tensor from the `tensor` table."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(
|
||||
self,
|
||||
id: str,
|
||||
tensor_kind: TensorKind,
|
||||
session_id: Optional[str],
|
||||
node_id: Optional[str],
|
||||
metadata: CoreGenerationMetadata | UploadsMetadata,
|
||||
) -> None:
|
||||
"""Sets a tensor in the `tensor` table."""
|
||||
pass
|
||||
|
||||
|
||||
"""
|
||||
Database Changes
|
||||
|
||||
The existing tables will remain as-is, new tables will be added.
|
||||
|
||||
Tensor now also have the same types as images - `results`, `intermediates`, `uploads`. Storage, retrieval, and operations may diverge from images in the future, so they are managed separately.
|
||||
|
||||
A few `images` tables are created to store all images:
|
||||
- `results` and `intermediates` images have additional data: `session_id` and `node_id`, and may be further differentiated in the future. For this reason, they each get their own table.
|
||||
- `uploads` do not get their own table, as they are never going to have more than an `id`, `image_kind` and `timestamp`.
|
||||
- `images_metadata` holds the same image metadata that is written to the image. This table, along with the URL service, allow us to more efficiently serve images without having to read the image from storage.
|
||||
|
||||
The same tables are made for `tensor` and for the moment, implementation is expected to be identical.
|
||||
|
||||
Schemas for each table below.
|
||||
|
||||
Insertions and updates of ancillary tables (e.g. `results_images`, `images_metadata`, etc) will need to be done manually in the services, but should be straightforward. Deletion via cascading will be handled by the database.
|
||||
"""
|
||||
|
||||
|
||||
def create_sql_values_string_from_string_enum(enum: Type[Enum]):
|
||||
"""
|
||||
Creates a string of the form "('value1'), ('value2'), ..., ('valueN')" from a StrEnum.
|
||||
"""
|
||||
|
||||
delimiter = ", "
|
||||
values = [f"('{e.value}')" for e in enum]
|
||||
return delimiter.join(values)
|
||||
|
||||
|
||||
def create_sql_table_from_enum(
|
||||
enum: Type[Enum],
|
||||
table_name: str,
|
||||
primary_key_name: str,
|
||||
cursor: sqlite3.Cursor,
|
||||
lock: threading.Lock,
|
||||
):
|
||||
"""
|
||||
Creates and populates a table to be used as a functional enum.
|
||||
"""
|
||||
|
||||
try:
|
||||
lock.acquire()
|
||||
|
||||
values_string = create_sql_values_string_from_string_enum(enum)
|
||||
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
CREATE TABLE IF NOT EXISTS {table_name} (
|
||||
{primary_key_name} TEXT PRIMARY KEY
|
||||
);
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
INSERT OR IGNORE INTO {table_name} ({primary_key_name}) VALUES {values_string};
|
||||
"""
|
||||
)
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
|
||||
"""
|
||||
`resource_origins` functions as an enum for the ResourceOrigin model.
|
||||
"""
|
||||
|
||||
|
||||
def create_resource_origins_table(cursor: sqlite3.Cursor, lock: threading.Lock):
|
||||
create_sql_table_from_enum(
|
||||
enum=ResourceOrigin,
|
||||
table_name="resource_origins",
|
||||
primary_key_name="origin_name",
|
||||
cursor=cursor,
|
||||
lock=lock,
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
`image_kinds` functions as an enum for the ImageType model.
|
||||
"""
|
||||
|
||||
|
||||
def create_image_kinds_table(cursor: sqlite3.Cursor, lock: threading.Lock):
|
||||
create_sql_table_from_enum(
|
||||
enum=ImageKind,
|
||||
table_name="image_kinds",
|
||||
primary_key_name="kind_name",
|
||||
cursor=cursor,
|
||||
lock=lock,
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
`tensor_kinds` functions as an enum for the TensorType model.
|
||||
"""
|
||||
|
||||
|
||||
def create_tensor_kinds_table(cursor: sqlite3.Cursor, lock: threading.Lock):
|
||||
create_sql_table_from_enum(
|
||||
enum=TensorKind,
|
||||
table_name="tensor_kinds",
|
||||
primary_key_name="kind_name",
|
||||
cursor=cursor,
|
||||
lock=lock,
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
`images` stores all images, regardless of type
|
||||
"""
|
||||
|
||||
|
||||
def create_images_table(cursor: sqlite3.Cursor, lock: threading.Lock):
|
||||
try:
|
||||
lock.acquire()
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS images (
|
||||
id TEXT PRIMARY KEY,
|
||||
origin TEXT,
|
||||
image_kind TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY(origin) REFERENCES resource_origins(origin_name),
|
||||
FOREIGN KEY(image_kind) REFERENCES image_kinds(kind_name)
|
||||
);
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_id ON images(id);
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_images_origin ON images(origin);
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_images_image_kind ON images(image_kind);
|
||||
"""
|
||||
)
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
|
||||
"""
|
||||
`image_results` stores additional data specific to `results` images.
|
||||
"""
|
||||
|
||||
|
||||
def create_image_results_table(cursor: sqlite3.Cursor, lock: threading.Lock):
|
||||
try:
|
||||
lock.acquire()
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS image_results (
|
||||
images_id TEXT PRIMARY KEY,
|
||||
session_id TEXT NOT NULL,
|
||||
node_id TEXT NOT NULL,
|
||||
FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_image_results_images_id ON image_results(id);
|
||||
"""
|
||||
)
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
|
||||
"""
|
||||
`image_intermediates` stores additional data specific to `intermediates` images
|
||||
"""
|
||||
|
||||
|
||||
def create_image_intermediates_table(cursor: sqlite3.Cursor, lock: threading.Lock):
|
||||
try:
|
||||
lock.acquire()
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS image_intermediates (
|
||||
images_id TEXT PRIMARY KEY,
|
||||
session_id TEXT NOT NULL,
|
||||
node_id TEXT NOT NULL,
|
||||
FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_image_intermediates_images_id ON image_intermediates(id);
|
||||
"""
|
||||
)
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
|
||||
"""
|
||||
`images_metadata` stores basic metadata for any image type
|
||||
"""
|
||||
|
||||
|
||||
def create_images_metadata_table(cursor: sqlite3.Cursor, lock: threading.Lock):
|
||||
try:
|
||||
lock.acquire()
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS images_metadata (
|
||||
images_id TEXT PRIMARY KEY,
|
||||
metadata TEXT,
|
||||
FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_metadata_images_id ON images_metadata(images_id);
|
||||
"""
|
||||
)
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
|
||||
# `tensor` table: stores references to tensor
|
||||
|
||||
|
||||
def create_tensors_table(cursor: sqlite3.Cursor, lock: threading.Lock):
|
||||
try:
|
||||
lock.acquire()
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS tensors (
|
||||
id TEXT PRIMARY KEY,
|
||||
origin TEXT,
|
||||
tensor_kind TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY(origin) REFERENCES resource_origins(origin_name),
|
||||
FOREIGN KEY(tensor_kind) REFERENCES tensor_kinds(kind_name),
|
||||
);
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_id ON tensors(id);
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_tensors_origin ON tensors(origin);
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_tensors_tensor_kind ON tensors(tensor_kind);
|
||||
"""
|
||||
)
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
|
||||
# `results_tensor` stores additional data specific to `result` tensor
|
||||
|
||||
|
||||
def create_tensor_results_table(cursor: sqlite3.Cursor, lock: threading.Lock):
|
||||
try:
|
||||
lock.acquire()
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS tensor_results (
|
||||
tensor_id TEXT PRIMARY KEY,
|
||||
session_id TEXT NOT NULL,
|
||||
node_id TEXT NOT NULL,
|
||||
FOREIGN KEY(tensor_id) REFERENCES tensors(id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_tensor_results_tensor_id ON tensor_results(tensor_id);
|
||||
"""
|
||||
)
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
|
||||
# `tensor_intermediates` stores additional data specific to `intermediate` tensor
|
||||
|
||||
|
||||
def create_tensor_intermediates_table(cursor: sqlite3.Cursor, lock: threading.Lock):
|
||||
try:
|
||||
lock.acquire()
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS tensor_intermediates (
|
||||
tensor_id TEXT PRIMARY KEY,
|
||||
session_id TEXT NOT NULL,
|
||||
node_id TEXT NOT NULL,
|
||||
FOREIGN KEY(tensor_id) REFERENCES tensors(id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_tensor_intermediates_tensor_id ON tensor_intermediates(tensor_id);
|
||||
"""
|
||||
)
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
|
||||
# `tensors_metadata` table: stores generated/transformed metadata for tensor
|
||||
|
||||
|
||||
def create_tensors_metadata_table(cursor: sqlite3.Cursor, lock: threading.Lock):
|
||||
try:
|
||||
lock.acquire()
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS tensors_metadata (
|
||||
tensor_id TEXT PRIMARY KEY,
|
||||
metadata TEXT,
|
||||
FOREIGN KEY(tensor_id) REFERENCES tensors(id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_metadata_tensor_id ON tensors_metadata(tensor_id);
|
||||
"""
|
||||
)
|
||||
finally:
|
||||
lock.release()
|
||||
466
invokeai/app/services/results.py
Normal file
466
invokeai/app/services/results.py
Normal file
@@ -0,0 +1,466 @@
|
||||
from enum import Enum
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import json
|
||||
import sqlite3
|
||||
from threading import Lock
|
||||
from typing import Any, Union
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from pydantic import BaseModel, Field, parse_obj_as, parse_raw_as
|
||||
from invokeai.app.invocations.image import ImageOutput
|
||||
from invokeai.app.services.graph import Edge, GraphExecutionState
|
||||
from invokeai.app.invocations.latent import LatentsOutput
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
|
||||
|
||||
class ResultType(str, Enum):
|
||||
image_output = "image_output"
|
||||
latents_output = "latents_output"
|
||||
|
||||
|
||||
class Result(BaseModel):
|
||||
"""A session result"""
|
||||
|
||||
id: str = Field(description="Result ID")
|
||||
session_id: str = Field(description="Session ID")
|
||||
node_id: str = Field(description="Node ID")
|
||||
data: Union[LatentsOutput, ImageOutput] = Field(description="The result data")
|
||||
|
||||
|
||||
class ResultWithSession(BaseModel):
|
||||
"""A result with its session"""
|
||||
|
||||
result: Result = Field(description="The result")
|
||||
session: GraphExecutionState = Field(description="The session")
|
||||
|
||||
|
||||
# Create a directed graph
|
||||
from typing import Any, TypedDict, Union
|
||||
from networkx import DiGraph
|
||||
import networkx as nx
|
||||
import json
|
||||
|
||||
|
||||
# We need to use a loose class for nodes to allow for graceful parsing - we cannot use the stricter
|
||||
# model used by the system, because we may be a graph in an old format. We can, however, use the
|
||||
# Edge model, because the edge format does not change.
|
||||
class LooseGraph(BaseModel):
|
||||
id: str
|
||||
nodes: dict[str, dict[str, Any]]
|
||||
edges: list[Edge]
|
||||
|
||||
|
||||
# An intermediate type used during parsing
|
||||
class NearestAncestor(TypedDict):
|
||||
node_id: str
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
# The ancestor types that contain the core metadata
|
||||
ANCESTOR_TYPES = ['t2l', 'l2l']
|
||||
|
||||
# The core metadata parameters in the ancestor types
|
||||
ANCESTOR_PARAMS = ['steps', 'model', 'cfg_scale', 'scheduler', 'strength']
|
||||
|
||||
# The core metadata parameters in the noise node
|
||||
NOISE_FIELDS = ['seed', 'width', 'height']
|
||||
|
||||
# Find nearest t2l or l2l ancestor from a given l2i node
|
||||
def find_nearest_ancestor(G: DiGraph, node_id: str) -> Union[NearestAncestor, None]:
|
||||
"""Returns metadata for the nearest ancestor of a given node.
|
||||
|
||||
Parameters:
|
||||
G (DiGraph): A directed graph.
|
||||
node_id (str): The ID of the starting node.
|
||||
|
||||
Returns:
|
||||
NearestAncestor | None: An object with the ID and metadata of the nearest ancestor.
|
||||
"""
|
||||
|
||||
# Retrieve the node from the graph
|
||||
node = G.nodes[node_id]
|
||||
|
||||
# If the node type is one of the core metadata node types, gather necessary metadata and return
|
||||
if node.get('type') in ANCESTOR_TYPES:
|
||||
parsed_metadata = {param: val for param, val in node.items() if param in ANCESTOR_PARAMS}
|
||||
return NearestAncestor(node_id=node_id, metadata=parsed_metadata)
|
||||
|
||||
|
||||
# Else, look for the ancestor in the predecessor nodes
|
||||
for predecessor in G.predecessors(node_id):
|
||||
result = find_nearest_ancestor(G, predecessor)
|
||||
if result:
|
||||
return result
|
||||
|
||||
# If there are no valid ancestors, return None
|
||||
return None
|
||||
|
||||
|
||||
def get_additional_metadata(graph: LooseGraph, node_id: str) -> Union[dict[str, Any], None]:
|
||||
"""Collects additional metadata from nodes connected to a given node.
|
||||
|
||||
Parameters:
|
||||
graph (LooseGraph): The graph.
|
||||
node_id (str): The ID of the node.
|
||||
|
||||
Returns:
|
||||
dict | None: A dictionary containing additional metadata.
|
||||
"""
|
||||
|
||||
metadata = {}
|
||||
|
||||
# Iterate over all edges in the graph
|
||||
for edge in graph.edges:
|
||||
dest_node_id = edge.destination.node_id
|
||||
dest_field = edge.destination.field
|
||||
source_node = graph.nodes[edge.source.node_id]
|
||||
|
||||
# If the destination node ID matches the given node ID, gather necessary metadata
|
||||
if dest_node_id == node_id:
|
||||
# If the destination field is 'positive_conditioning', add the 'prompt' from the source node
|
||||
if dest_field == 'positive_conditioning':
|
||||
metadata['positive_conditioning'] = source_node.get('prompt')
|
||||
# If the destination field is 'negative_conditioning', add the 'prompt' from the source node
|
||||
if dest_field == 'negative_conditioning':
|
||||
metadata['negative_conditioning'] = source_node.get('prompt')
|
||||
# If the destination field is 'noise', add the core noise fields from the source node
|
||||
if dest_field == 'noise':
|
||||
for field in NOISE_FIELDS:
|
||||
metadata[field] = source_node.get(field)
|
||||
return metadata
|
||||
|
||||
def build_core_metadata(graph_raw: str, node_id: str) -> Union[dict, None]:
|
||||
"""Builds the core metadata for a given node.
|
||||
|
||||
Parameters:
|
||||
graph_raw (str): The graph structure as a raw string.
|
||||
node_id (str): The ID of the node.
|
||||
|
||||
Returns:
|
||||
dict | None: A dictionary containing core metadata.
|
||||
"""
|
||||
|
||||
# Create a directed graph to facilitate traversal
|
||||
G = nx.DiGraph()
|
||||
|
||||
# Convert the raw graph string into a JSON object
|
||||
graph = parse_obj_as(LooseGraph, graph_raw)
|
||||
|
||||
# Add nodes and edges to the graph
|
||||
for node_id, node_data in graph.nodes.items():
|
||||
G.add_node(node_id, **node_data)
|
||||
for edge in graph.edges:
|
||||
G.add_edge(edge.source.node_id, edge.destination.node_id)
|
||||
|
||||
# Find the nearest ancestor of the given node
|
||||
ancestor = find_nearest_ancestor(G, node_id)
|
||||
|
||||
# If no ancestor was found, return None
|
||||
if ancestor is None:
|
||||
return None
|
||||
|
||||
metadata = ancestor['metadata']
|
||||
ancestor_id = ancestor['node_id']
|
||||
|
||||
# Get additional metadata related to the ancestor
|
||||
addl_metadata = get_additional_metadata(graph, ancestor_id)
|
||||
|
||||
# If additional metadata was found, add it to the main metadata
|
||||
if addl_metadata is not None:
|
||||
metadata.update(addl_metadata)
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
|
||||
class ResultsServiceABC(ABC):
|
||||
"""The Results service is responsible for retrieving results."""
|
||||
|
||||
@abstractmethod
|
||||
def get(
|
||||
self, result_id: str, result_type: ResultType
|
||||
) -> Union[ResultWithSession, None]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self, result_type: ResultType, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[ResultWithSession]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self, query: str, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[ResultWithSession]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def handle_graph_execution_state_change(self, session: GraphExecutionState) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class SqliteResultsService(ResultsServiceABC):
|
||||
"""SQLite implementation of the Results service."""
|
||||
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: Lock
|
||||
|
||||
def __init__(self, filename: str):
|
||||
super().__init__()
|
||||
|
||||
self._filename = filename
|
||||
self._lock = Lock()
|
||||
|
||||
self._conn = sqlite3.connect(
|
||||
self._filename, check_same_thread=False
|
||||
) # TODO: figure out a better threading solution
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
self._create_table()
|
||||
|
||||
def _create_table(self):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS results (
|
||||
id TEXT PRIMARY KEY, -- the result's name
|
||||
result_type TEXT, -- `image_output` | `latents_output`
|
||||
node_id TEXT, -- the node that produced this result
|
||||
session_id TEXT, -- the session that produced this result
|
||||
created_at INTEGER, -- the time at which this result was created
|
||||
data TEXT -- the result itself
|
||||
);
|
||||
"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_result_id ON results(id);
|
||||
"""
|
||||
)
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _parse_joined_result(self, result_row: Any, column_names: list[str]):
|
||||
result_raw = {}
|
||||
session_raw = {}
|
||||
|
||||
for idx, name in enumerate(column_names):
|
||||
if name == "session":
|
||||
session_raw = json.loads(result_row[idx])
|
||||
elif name == "data":
|
||||
result_raw[name] = json.loads(result_row[idx])
|
||||
else:
|
||||
result_raw[name] = result_row[idx]
|
||||
|
||||
graph_raw = session_raw['execution_graph']
|
||||
|
||||
result = parse_obj_as(Result, result_raw)
|
||||
session = parse_obj_as(GraphExecutionState, session_raw)
|
||||
|
||||
m = build_core_metadata(graph_raw, result.node_id)
|
||||
print(m)
|
||||
|
||||
# g = session.execution_graph.nx_graph()
|
||||
# ancestors = nx.dag.ancestors(g, result.node_id)
|
||||
|
||||
# nodes = [session.execution_graph.get_node(result.node_id)]
|
||||
# for ancestor in ancestors:
|
||||
# nodes.append(session.execution_graph.get_node(ancestor))
|
||||
|
||||
# filtered_nodes = filter(lambda n: n.type in NODE_TYPE_ALLOWLIST, nodes)
|
||||
# print(list(map(lambda n: n.dict(), filtered_nodes)))
|
||||
# metadata = {}
|
||||
# for node in nodes:
|
||||
# if (node.type in ['txt2img', 'img2img',])
|
||||
# for field, value in node.dict().items():
|
||||
# if field not in ['type', 'id']:
|
||||
# if field not in metadata:
|
||||
# metadata[field] = value
|
||||
|
||||
# print(ancestors)
|
||||
# print(nodes)
|
||||
# print(metadata)
|
||||
|
||||
# for node in nodes:
|
||||
# print(node.dict())
|
||||
|
||||
# print(nodes)
|
||||
|
||||
return ResultWithSession(
|
||||
result=result,
|
||||
session=session,
|
||||
)
|
||||
|
||||
def get(
|
||||
self, result_id: str, result_type: ResultType
|
||||
) -> Union[ResultWithSession, None]:
|
||||
"""Retrieves a result by ID and type."""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT
|
||||
results.id AS id,
|
||||
results.result_type AS result_type,
|
||||
results.node_id AS node_id,
|
||||
results.session_id AS session_id,
|
||||
results.data AS data,
|
||||
graph_executions.item AS session
|
||||
FROM results
|
||||
JOIN graph_executions ON results.session_id = graph_executions.id
|
||||
WHERE results.id = ? AND results.result_type = ?
|
||||
""",
|
||||
(result_id, result_type),
|
||||
)
|
||||
|
||||
result_row = self._cursor.fetchone()
|
||||
|
||||
if result_row is None:
|
||||
return None
|
||||
|
||||
column_names = list(map(lambda x: x[0], self._cursor.description))
|
||||
result_parsed = self._parse_joined_result(result_row, column_names)
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
if not result_parsed:
|
||||
return None
|
||||
|
||||
return result_parsed
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
result_type: ResultType,
|
||||
page: int = 0,
|
||||
per_page: int = 10,
|
||||
) -> PaginatedResults[ResultWithSession]:
|
||||
"""Lists results of a given type."""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT
|
||||
results.id AS id,
|
||||
results.result_type AS result_type,
|
||||
results.node_id AS node_id,
|
||||
results.session_id AS session_id,
|
||||
results.data AS data,
|
||||
graph_executions.item AS session
|
||||
FROM results
|
||||
JOIN graph_executions ON results.session_id = graph_executions.id
|
||||
WHERE results.result_type = ?
|
||||
LIMIT ? OFFSET ?;
|
||||
""",
|
||||
(result_type.value, per_page, page * per_page),
|
||||
)
|
||||
|
||||
result_rows = self._cursor.fetchall()
|
||||
column_names = list(map(lambda c: c[0], self._cursor.description))
|
||||
|
||||
result_parsed = []
|
||||
|
||||
for result_row in result_rows:
|
||||
result_parsed.append(
|
||||
self._parse_joined_result(result_row, column_names)
|
||||
)
|
||||
|
||||
self._cursor.execute("""SELECT count(*) FROM results;""")
|
||||
count = self._cursor.fetchone()[0]
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
return PaginatedResults[ResultWithSession](
|
||||
items=result_parsed,
|
||||
page=page,
|
||||
pages=pageCount,
|
||||
per_page=per_page,
|
||||
total=count,
|
||||
)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
page: int = 0,
|
||||
per_page: int = 10,
|
||||
) -> PaginatedResults[ResultWithSession]:
|
||||
"""Finds results by query."""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT results.data, graph_executions.item
|
||||
FROM results
|
||||
JOIN graph_executions ON results.session_id = graph_executions.id
|
||||
WHERE item LIKE ?
|
||||
LIMIT ? OFFSET ?;
|
||||
""",
|
||||
(f"%{query}%", per_page, page * per_page),
|
||||
)
|
||||
|
||||
result_rows = self._cursor.fetchall()
|
||||
|
||||
items = list(
|
||||
map(
|
||||
lambda r: ResultWithSession(
|
||||
result=parse_raw_as(Result, r[0]),
|
||||
session=parse_raw_as(GraphExecutionState, r[1]),
|
||||
),
|
||||
result_rows,
|
||||
)
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*) FROM results WHERE item LIKE ?;
|
||||
""",
|
||||
(f"%{query}%",),
|
||||
)
|
||||
count = self._cursor.fetchone()[0]
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
return PaginatedResults[ResultWithSession](
|
||||
items=items, page=page, pages=pageCount, per_page=per_page, total=count
|
||||
)
|
||||
|
||||
def handle_graph_execution_state_change(self, session: GraphExecutionState) -> None:
|
||||
"""Updates the results table with the results from the session."""
|
||||
with self._conn as conn:
|
||||
for node_id, result in session.results.items():
|
||||
# We'll only process 'image_output' or 'latents_output'
|
||||
if result.type not in ["image_output", "latents_output"]:
|
||||
continue
|
||||
|
||||
# The id depends on the result type
|
||||
if result.type == "image_output":
|
||||
id = result.image.image_name
|
||||
result_type = "image_output"
|
||||
else:
|
||||
id = result.latents.latents_name
|
||||
result_type = "latents_output"
|
||||
|
||||
# Insert the result into the results table, ignoring if it already exists
|
||||
conn.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO results (id, result_type, node_id, session_id, created_at, data)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
id,
|
||||
result_type,
|
||||
node_id,
|
||||
session.id,
|
||||
get_timestamp(),
|
||||
result.json(),
|
||||
),
|
||||
)
|
||||
30
invokeai/app/services/urls.py
Normal file
30
invokeai/app/services/urls.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from invokeai.app.models.image import ImageType
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name
|
||||
|
||||
|
||||
class UrlServiceBase(ABC):
|
||||
"""Responsible for building URLs for resources."""
|
||||
|
||||
@abstractmethod
|
||||
def get_image_url(
|
||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
"""Gets the URL for an image or thumbnail."""
|
||||
pass
|
||||
|
||||
|
||||
class LocalUrlService(UrlServiceBase):
|
||||
def __init__(self, base_url: str = "api/v1"):
|
||||
self._base_url = base_url
|
||||
|
||||
def get_image_url(
|
||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
image_basename = os.path.basename(image_name)
|
||||
if thumbnail:
|
||||
return f"{self._base_url}/files/images/{image_type.value}/{image_basename}/thumbnail"
|
||||
|
||||
return f"{self._base_url}/files/images/{image_type.value}/{image_basename}"
|
||||
39
invokeai/app/services/util/create_enum_table.py
Normal file
39
invokeai/app/services/util/create_enum_table.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from enum import Enum
|
||||
import sqlite3
|
||||
from typing import Type
|
||||
|
||||
|
||||
def create_sql_values_string_from_string_enum(enum: Type[Enum]):
|
||||
"""
|
||||
Creates a string of the form "('value1'), ('value2'), ..., ('valueN')" from a StrEnum.
|
||||
"""
|
||||
|
||||
delimiter = ", "
|
||||
values = [f"('{e.value}')" for e in enum]
|
||||
return delimiter.join(values)
|
||||
|
||||
|
||||
def create_enum_table(
|
||||
enum: Type[Enum],
|
||||
table_name: str,
|
||||
primary_key_name: str,
|
||||
cursor: sqlite3.Cursor,
|
||||
):
|
||||
"""
|
||||
Creates and populates a table to be used as a functional enum.
|
||||
"""
|
||||
|
||||
values_string = create_sql_values_string_from_string_enum(enum)
|
||||
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
CREATE TABLE IF NOT EXISTS {table_name} (
|
||||
{primary_key_name} TEXT PRIMARY KEY
|
||||
);
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
INSERT OR IGNORE INTO {table_name} ({primary_key_name}) VALUES {values_string};
|
||||
"""
|
||||
)
|
||||
12
invokeai/app/util/enum.py
Normal file
12
invokeai/app/util/enum.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from enum import EnumMeta
|
||||
|
||||
|
||||
class MetaEnum(EnumMeta):
|
||||
"""Metaclass to support `in` syntax value checking in String Enums"""
|
||||
|
||||
def __contains__(cls, item):
|
||||
try:
|
||||
cls(item)
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
@@ -1,5 +1,21 @@
|
||||
import datetime
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_timestamp():
|
||||
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
|
||||
|
||||
|
||||
def get_iso_timestamp() -> str:
|
||||
return datetime.datetime.utcnow().isoformat()
|
||||
|
||||
|
||||
def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
|
||||
return datetime.datetime.fromisoformat(iso_timestamp)
|
||||
|
||||
|
||||
SEED_MAX = np.iinfo(np.int32).max
|
||||
|
||||
|
||||
def get_random_seed():
|
||||
return np.random.randint(0, SEED_MAX)
|
||||
|
||||
@@ -108,17 +108,21 @@ APP_VERSION = invokeai.version.__version__
|
||||
|
||||
SAMPLER_CHOICES = [
|
||||
"ddim",
|
||||
"k_dpm_2_a",
|
||||
"k_dpm_2",
|
||||
"k_dpmpp_2_a",
|
||||
"k_dpmpp_2",
|
||||
"k_euler_a",
|
||||
"k_euler",
|
||||
"k_heun",
|
||||
"k_lms",
|
||||
"plms",
|
||||
# diffusers:
|
||||
"ddpm",
|
||||
"deis",
|
||||
"lms",
|
||||
"pndm",
|
||||
"heun",
|
||||
"heun_k",
|
||||
"euler",
|
||||
"euler_k",
|
||||
"euler_a",
|
||||
"kdpm_2",
|
||||
"kdpm_2_a",
|
||||
"dpmpp_2s",
|
||||
"dpmpp_2m",
|
||||
"dpmpp_2m_k",
|
||||
"unipc",
|
||||
]
|
||||
|
||||
PRECISION_CHOICES = [
|
||||
@@ -631,7 +635,7 @@ class Args(object):
|
||||
choices=SAMPLER_CHOICES,
|
||||
metavar="SAMPLER_NAME",
|
||||
help=f'Set the default sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
|
||||
default="k_lms",
|
||||
default="lms",
|
||||
)
|
||||
render_group.add_argument(
|
||||
"--log_tokenization",
|
||||
|
||||
@@ -37,6 +37,7 @@ from .safety_checker import SafetyChecker
|
||||
from .prompting import get_uc_and_c_and_ec
|
||||
from .prompting.conditioning import log_tokenization
|
||||
from .stable_diffusion import HuggingFaceConceptsLibrary
|
||||
from .stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from .util import choose_precision, choose_torch_device
|
||||
|
||||
def fix_func(orig):
|
||||
@@ -141,7 +142,7 @@ class Generate:
|
||||
model=None,
|
||||
conf="configs/models.yaml",
|
||||
embedding_path=None,
|
||||
sampler_name="k_lms",
|
||||
sampler_name="lms",
|
||||
ddim_eta=0.0, # deterministic
|
||||
full_precision=False,
|
||||
precision="auto",
|
||||
@@ -1047,29 +1048,12 @@ class Generate:
|
||||
def _set_scheduler(self):
|
||||
default = self.model.scheduler
|
||||
|
||||
# See https://github.com/huggingface/diffusers/issues/277#issuecomment-1371428672
|
||||
scheduler_map = dict(
|
||||
ddim=diffusers.DDIMScheduler,
|
||||
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
|
||||
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
|
||||
# DPMSolverMultistepScheduler is technically not `k_` anything, as it is neither
|
||||
# the k-diffusers implementation nor included in EDM (Karras 2022), but we can
|
||||
# provide an alias for compatibility.
|
||||
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_euler=diffusers.EulerDiscreteScheduler,
|
||||
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
|
||||
k_heun=diffusers.HeunDiscreteScheduler,
|
||||
k_lms=diffusers.LMSDiscreteScheduler,
|
||||
plms=diffusers.PNDMScheduler,
|
||||
)
|
||||
|
||||
if self.sampler_name in scheduler_map:
|
||||
sampler_class = scheduler_map[self.sampler_name]
|
||||
if self.sampler_name in SCHEDULER_MAP:
|
||||
sampler_class, sampler_extra_config = SCHEDULER_MAP[self.sampler_name]
|
||||
msg = (
|
||||
f"Setting Sampler to {self.sampler_name} ({sampler_class.__name__})"
|
||||
)
|
||||
self.sampler = sampler_class.from_config(self.model.scheduler.config)
|
||||
self.sampler = sampler_class.from_config({**self.model.scheduler.config, **sampler_extra_config})
|
||||
else:
|
||||
msg = (
|
||||
f" Unsupported Sampler: {self.sampler_name} "+
|
||||
|
||||
@@ -31,6 +31,7 @@ from ..util.util import rand_perlin_2d
|
||||
from ..safety_checker import SafetyChecker
|
||||
from ..prompting.conditioning import get_uc_and_c_and_ec
|
||||
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from ..stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
|
||||
downsampling = 8
|
||||
|
||||
@@ -71,19 +72,6 @@ class InvokeAIGeneratorOutput:
|
||||
# we are interposing a wrapper around the original Generator classes so that
|
||||
# old code that calls Generate will continue to work.
|
||||
class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
scheduler_map = dict(
|
||||
ddim=diffusers.DDIMScheduler,
|
||||
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
|
||||
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
|
||||
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_euler=diffusers.EulerDiscreteScheduler,
|
||||
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
|
||||
k_heun=diffusers.HeunDiscreteScheduler,
|
||||
k_lms=diffusers.LMSDiscreteScheduler,
|
||||
plms=diffusers.PNDMScheduler,
|
||||
)
|
||||
|
||||
def __init__(self,
|
||||
model_info: dict,
|
||||
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
||||
@@ -175,14 +163,20 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
'''
|
||||
Return list of all the schedulers that we currently handle.
|
||||
'''
|
||||
return list(self.scheduler_map.keys())
|
||||
return list(SCHEDULER_MAP.keys())
|
||||
|
||||
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
|
||||
return generator_class(model, self.params.precision)
|
||||
|
||||
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||
scheduler_class = self.scheduler_map.get(scheduler_name,'ddim')
|
||||
scheduler = scheduler_class.from_config(model.scheduler.config)
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
|
||||
|
||||
scheduler_config = model.scheduler.config
|
||||
if "_backup" in scheduler_config:
|
||||
scheduler_config = scheduler_config["_backup"]
|
||||
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
|
||||
# hack copied over from generate.py
|
||||
if not hasattr(scheduler, 'uses_inpainting_model'):
|
||||
scheduler.uses_inpainting_model = lambda: False
|
||||
@@ -226,10 +220,10 @@ class Inpaint(Img2Img):
|
||||
def generate(self,
|
||||
mask_image: Image.Image | torch.FloatTensor,
|
||||
# Seam settings - when 0, doesn't fill seam
|
||||
seam_size: int = 0,
|
||||
seam_blur: int = 0,
|
||||
seam_size: int = 96,
|
||||
seam_blur: int = 16,
|
||||
seam_strength: float = 0.7,
|
||||
seam_steps: int = 10,
|
||||
seam_steps: int = 30,
|
||||
tile_size: int = 32,
|
||||
inpaint_replace=False,
|
||||
infill_method=None,
|
||||
|
||||
@@ -4,6 +4,7 @@ invokeai.backend.generator.inpaint descends from .generator
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import Tuple, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@@ -59,7 +60,7 @@ class Inpaint(Img2Img):
|
||||
writeable=False,
|
||||
)
|
||||
|
||||
def infill_patchmatch(self, im: Image.Image) -> Image:
|
||||
def infill_patchmatch(self, im: Image.Image) -> Image.Image:
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
|
||||
@@ -75,18 +76,18 @@ class Inpaint(Img2Img):
|
||||
return im_patched
|
||||
|
||||
def tile_fill_missing(
|
||||
self, im: Image.Image, tile_size: int = 16, seed: int = None
|
||||
) -> Image:
|
||||
self, im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None
|
||||
) -> Image.Image:
|
||||
# Only fill if there's an alpha layer
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
|
||||
a = np.asarray(im, dtype=np.uint8)
|
||||
|
||||
tile_size = (tile_size, tile_size)
|
||||
tile_size_tuple = (tile_size, tile_size)
|
||||
|
||||
# Get the image as tiles of a specified size
|
||||
tiles = self.get_tile_images(a, *tile_size).copy()
|
||||
tiles = self.get_tile_images(a, *tile_size_tuple).copy()
|
||||
|
||||
# Get the mask as tiles
|
||||
tiles_mask = tiles[:, :, :, :, 3]
|
||||
@@ -127,7 +128,9 @@ class Inpaint(Img2Img):
|
||||
|
||||
return si
|
||||
|
||||
def mask_edge(self, mask: Image, edge_size: int, edge_blur: int) -> Image:
|
||||
def mask_edge(
|
||||
self, mask: Image.Image, edge_size: int, edge_blur: int
|
||||
) -> Image.Image:
|
||||
npimg = np.asarray(mask, dtype=np.uint8)
|
||||
|
||||
# Detect any partially transparent regions
|
||||
@@ -206,15 +209,15 @@ class Inpaint(Img2Img):
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
init_image: PIL.Image.Image | torch.FloatTensor,
|
||||
mask_image: PIL.Image.Image | torch.FloatTensor,
|
||||
init_image: Image.Image | torch.FloatTensor,
|
||||
mask_image: Image.Image | torch.FloatTensor,
|
||||
strength: float,
|
||||
mask_blur_radius: int = 8,
|
||||
# Seam settings - when 0, doesn't fill seam
|
||||
seam_size: int = 0,
|
||||
seam_blur: int = 0,
|
||||
seam_size: int = 96,
|
||||
seam_blur: int = 16,
|
||||
seam_strength: float = 0.7,
|
||||
seam_steps: int = 10,
|
||||
seam_steps: int = 30,
|
||||
tile_size: int = 32,
|
||||
step_callback=None,
|
||||
inpaint_replace=False,
|
||||
@@ -222,7 +225,7 @@ class Inpaint(Img2Img):
|
||||
infill_method=None,
|
||||
inpaint_width=None,
|
||||
inpaint_height=None,
|
||||
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
|
||||
inpaint_fill: Tuple[int, int, int, int] = (0x7F, 0x7F, 0x7F, 0xFF),
|
||||
attention_maps_callback=None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -239,7 +242,7 @@ class Inpaint(Img2Img):
|
||||
self.inpaint_width = inpaint_width
|
||||
self.inpaint_height = inpaint_height
|
||||
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
if isinstance(init_image, Image.Image):
|
||||
self.pil_image = init_image.copy()
|
||||
|
||||
# Do infill
|
||||
@@ -250,8 +253,8 @@ class Inpaint(Img2Img):
|
||||
self.pil_image.copy(), seed=self.seed, tile_size=tile_size
|
||||
)
|
||||
elif infill_method == "solid":
|
||||
solid_bg = PIL.Image.new("RGBA", init_image.size, inpaint_fill)
|
||||
init_filled = PIL.Image.alpha_composite(solid_bg, init_image)
|
||||
solid_bg = Image.new("RGBA", init_image.size, inpaint_fill)
|
||||
init_filled = Image.alpha_composite(solid_bg, init_image)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Non-supported infill type {infill_method}", infill_method
|
||||
@@ -269,7 +272,7 @@ class Inpaint(Img2Img):
|
||||
# Create init tensor
|
||||
init_image = image_resized_to_grid_as_tensor(init_filled.convert("RGB"))
|
||||
|
||||
if isinstance(mask_image, PIL.Image.Image):
|
||||
if isinstance(mask_image, Image.Image):
|
||||
self.pil_mask = mask_image.copy()
|
||||
debug_image(
|
||||
mask_image,
|
||||
|
||||
@@ -47,6 +47,7 @@ from diffusers import (
|
||||
LDMTextToImagePipeline,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
UniPCMultistepScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
@@ -1209,6 +1210,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "dpm":
|
||||
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == 'unipc':
|
||||
scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "ddim":
|
||||
scheduler = scheduler
|
||||
else:
|
||||
|
||||
@@ -30,7 +30,7 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
SchedulerMixin,
|
||||
logging as dlogging,
|
||||
)
|
||||
)
|
||||
from huggingface_hub import scan_cache_dir
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
@@ -68,7 +68,7 @@ class SDModelComponent(Enum):
|
||||
scheduler="scheduler"
|
||||
safety_checker="safety_checker"
|
||||
feature_extractor="feature_extractor"
|
||||
|
||||
|
||||
DEFAULT_MAX_MODELS = 2
|
||||
|
||||
class ModelManager(object):
|
||||
@@ -182,7 +182,7 @@ class ModelManager(object):
|
||||
vae from the model currently in the GPU.
|
||||
"""
|
||||
return self._get_sub_model(model_name, SDModelComponent.vae)
|
||||
|
||||
|
||||
def get_model_tokenizer(self, model_name: str=None)->CLIPTokenizer:
|
||||
"""Given a model name identified in models.yaml, load the model into
|
||||
GPU if necessary and return its assigned CLIPTokenizer. If no
|
||||
@@ -190,12 +190,12 @@ class ModelManager(object):
|
||||
currently in the GPU.
|
||||
"""
|
||||
return self._get_sub_model(model_name, SDModelComponent.tokenizer)
|
||||
|
||||
|
||||
def get_model_unet(self, model_name: str=None)->UNet2DConditionModel:
|
||||
"""Given a model name identified in models.yaml, load the model into
|
||||
GPU if necessary and return its assigned UNet2DConditionModel. If no model
|
||||
name is provided, return the UNet from the model
|
||||
currently in the GPU.
|
||||
currently in the GPU.
|
||||
"""
|
||||
return self._get_sub_model(model_name, SDModelComponent.unet)
|
||||
|
||||
@@ -222,7 +222,7 @@ class ModelManager(object):
|
||||
currently in the GPU.
|
||||
"""
|
||||
return self._get_sub_model(model_name, SDModelComponent.scheduler)
|
||||
|
||||
|
||||
def _get_sub_model(
|
||||
self,
|
||||
model_name: str=None,
|
||||
@@ -1228,7 +1228,7 @@ class ModelManager(object):
|
||||
sha.update(chunk)
|
||||
hash = sha.hexdigest()
|
||||
toc = time.time()
|
||||
self.logger.debug(f"sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
|
||||
self.logger.debug(f"sha256 = {hash} ({count} files hashed in {toc - tic:4.2f}s)")
|
||||
with open(hashpath, "w") as f:
|
||||
f.write(hash)
|
||||
return hash
|
||||
|
||||
@@ -16,6 +16,7 @@ from compel.prompt_parser import (
|
||||
FlattenedPrompt,
|
||||
Fragment,
|
||||
PromptParser,
|
||||
Conjunction,
|
||||
)
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
@@ -25,58 +26,48 @@ from ..stable_diffusion import InvokeAIDiffuserComponent
|
||||
from ..util import torch_dtype
|
||||
|
||||
|
||||
def get_uc_and_c_and_ec(
|
||||
prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False
|
||||
):
|
||||
def get_uc_and_c_and_ec(prompt_string,
|
||||
model: InvokeAIDiffuserComponent,
|
||||
log_tokens=False, skip_normalize_legacy_blend=False):
|
||||
# lazy-load any deferred textual inversions.
|
||||
# this might take a couple of seconds the first time a textual inversion is used.
|
||||
model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
|
||||
prompt_string
|
||||
)
|
||||
model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)
|
||||
|
||||
tokenizer = model.tokenizer
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=model.text_encoder,
|
||||
textual_inversion_manager=model.textual_inversion_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=False
|
||||
)
|
||||
compel = Compel(tokenizer=model.tokenizer,
|
||||
text_encoder=model.text_encoder,
|
||||
textual_inversion_manager=model.textual_inversion_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=False,
|
||||
)
|
||||
|
||||
# get rid of any newline characters
|
||||
prompt_string = prompt_string.replace("\n", " ")
|
||||
(
|
||||
positive_prompt_string,
|
||||
negative_prompt_string,
|
||||
) = split_prompt_to_positive_and_negative(prompt_string)
|
||||
legacy_blend = try_parse_legacy_blend(
|
||||
positive_prompt_string, skip_normalize_legacy_blend
|
||||
)
|
||||
positive_prompt: Union[FlattenedPrompt, Blend]
|
||||
if legacy_blend is not None:
|
||||
positive_prompt = legacy_blend
|
||||
else:
|
||||
positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
|
||||
negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(
|
||||
negative_prompt_string
|
||||
)
|
||||
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
|
||||
|
||||
legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
|
||||
positive_conjunction: Conjunction
|
||||
if legacy_blend is not None:
|
||||
positive_conjunction = legacy_blend
|
||||
else:
|
||||
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
|
||||
positive_prompt = positive_conjunction.prompts[0]
|
||||
|
||||
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
|
||||
negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]
|
||||
|
||||
tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
|
||||
if log_tokens or getattr(Globals, "log_tokenization", False):
|
||||
log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
|
||||
log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)
|
||||
|
||||
c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
|
||||
uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
|
||||
[c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
||||
|
||||
tokens_count = get_max_token_count(tokenizer, positive_prompt)
|
||||
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||
tokens_count_including_eos_bos=tokens_count,
|
||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
||||
)
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
|
||||
cross_attention_control_args=options.get(
|
||||
'cross_attention_control', None))
|
||||
return uc, c, ec
|
||||
|
||||
|
||||
def get_prompt_structure(
|
||||
prompt_string, skip_normalize_legacy_blend: bool = False
|
||||
) -> (Union[FlattenedPrompt, Blend], FlattenedPrompt):
|
||||
@@ -87,18 +78,17 @@ def get_prompt_structure(
|
||||
legacy_blend = try_parse_legacy_blend(
|
||||
positive_prompt_string, skip_normalize_legacy_blend
|
||||
)
|
||||
positive_prompt: Union[FlattenedPrompt, Blend]
|
||||
positive_prompt: Conjunction
|
||||
if legacy_blend is not None:
|
||||
positive_prompt = legacy_blend
|
||||
positive_conjunction = legacy_blend
|
||||
else:
|
||||
positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
|
||||
negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(
|
||||
negative_prompt_string
|
||||
)
|
||||
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
|
||||
positive_prompt = positive_conjunction.prompts[0]
|
||||
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
|
||||
negative_prompt: FlattenedPrompt|Blend = negative_conjunction.prompts[0]
|
||||
|
||||
return positive_prompt, negative_prompt
|
||||
|
||||
|
||||
def get_max_token_count(
|
||||
tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
|
||||
) -> int:
|
||||
@@ -245,22 +235,21 @@ def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_t
|
||||
logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
|
||||
logger.debug(f"{discarded}\x1b[0m")
|
||||
|
||||
|
||||
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]:
|
||||
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Conjunction]:
|
||||
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
|
||||
if len(weighted_subprompts) <= 1:
|
||||
return None
|
||||
strings = [x[0] for x in weighted_subprompts]
|
||||
weights = [x[1] for x in weighted_subprompts]
|
||||
|
||||
pp = PromptParser()
|
||||
parsed_conjunctions = [pp.parse_conjunction(x) for x in strings]
|
||||
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
|
||||
|
||||
return Blend(
|
||||
prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize
|
||||
)
|
||||
|
||||
flattened_prompts = []
|
||||
weights = []
|
||||
for i, x in enumerate(parsed_conjunctions):
|
||||
if len(x.prompts)>0:
|
||||
flattened_prompts.append(x.prompts[0])
|
||||
weights.append(weighted_subprompts[i][1])
|
||||
return Conjunction([Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)])
|
||||
|
||||
def split_weighted_subprompts(text, skip_normalize=False) -> list:
|
||||
"""
|
||||
|
||||
@@ -509,10 +509,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
run_id=None,
|
||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
||||
if self.scheduler.config.get("cpu_only", False):
|
||||
scheduler_device = torch.device('cpu')
|
||||
else:
|
||||
scheduler_device = self._model_group.device_for(self.unet)
|
||||
|
||||
if timesteps is None:
|
||||
self.scheduler.set_timesteps(
|
||||
num_inference_steps, device=self._model_group.device_for(self.unet)
|
||||
)
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
infer_latents_from_embeddings = GeneratorToCallbackinator(
|
||||
self.generate_latents_from_embeddings, PipelineIntermediateState
|
||||
@@ -545,8 +548,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
additional_guidance = []
|
||||
extra_conditioning_info = conditioning_data.extra
|
||||
with self.invokeai_diffuser.custom_attention_context(
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
step_count=len(self.scheduler.timesteps),
|
||||
self.invokeai_diffuser.model,
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
step_count=len(self.scheduler.timesteps),
|
||||
):
|
||||
yield PipelineIntermediateState(
|
||||
run_id=run_id,
|
||||
@@ -725,12 +729,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
noise: torch.Tensor,
|
||||
run_id=None,
|
||||
callback=None,
|
||||
) -> InvokeAIStableDiffusionPipelineOutput:
|
||||
timesteps, _ = self.get_img2img_timesteps(
|
||||
num_inference_steps,
|
||||
strength,
|
||||
device=self._model_group.device_for(self.unet),
|
||||
)
|
||||
) -> InvokeAIStableDiffusionPipelineOutput:
|
||||
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
|
||||
result_latents, result_attention_maps = self.latents_from_embeddings(
|
||||
latents=initial_latents if strength < 1.0 else torch.zeros_like(
|
||||
initial_latents, device=initial_latents.device, dtype=initial_latents.dtype
|
||||
@@ -756,13 +756,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
return self.check_for_safety(output, dtype=conditioning_data.dtype)
|
||||
|
||||
def get_img2img_timesteps(
|
||||
self, num_inference_steps: int, strength: float, device
|
||||
self, num_inference_steps: int, strength: float, device=None
|
||||
) -> (torch.Tensor, int):
|
||||
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
|
||||
assert img2img_pipeline.scheduler is self.scheduler
|
||||
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
if self.scheduler.config.get("cpu_only", False):
|
||||
scheduler_device = torch.device('cpu')
|
||||
else:
|
||||
scheduler_device = self._model_group.device_for(self.unet)
|
||||
|
||||
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
|
||||
timesteps, adjusted_steps = img2img_pipeline.get_timesteps(
|
||||
num_inference_steps, strength, device=device
|
||||
num_inference_steps, strength, device=scheduler_device
|
||||
)
|
||||
# Workaround for low strength resulting in zero timesteps.
|
||||
# TODO: submit upstream fix for zero-step img2img
|
||||
@@ -796,9 +802,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
if init_image.dim() == 3:
|
||||
init_image = init_image.unsqueeze(0)
|
||||
|
||||
timesteps, _ = self.get_img2img_timesteps(
|
||||
num_inference_steps, strength, device=device
|
||||
)
|
||||
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
|
||||
|
||||
# 6. Prepare latent variables
|
||||
# can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
|
||||
|
||||
@@ -10,6 +10,7 @@ import diffusers
|
||||
import psutil
|
||||
import torch
|
||||
from compel.cross_attention_control import Arguments
|
||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import AttentionProcessor
|
||||
from torch import nn
|
||||
|
||||
@@ -352,8 +353,7 @@ def restore_default_cross_attention(
|
||||
else:
|
||||
remove_attention_function(model)
|
||||
|
||||
|
||||
def override_cross_attention(model, context: Context, is_running_diffusers=False):
|
||||
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
|
||||
"""
|
||||
Inject attention parameters and functions into the passed in model to enable cross attention editing.
|
||||
|
||||
@@ -372,37 +372,22 @@ def override_cross_attention(model, context: Context, is_running_diffusers=False
|
||||
indices = torch.arange(max_length, dtype=torch.long)
|
||||
for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
|
||||
if b0 < max_length:
|
||||
if name == "equal": # or (name == "replace" and a1 - a0 == b1 - b0):
|
||||
if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
|
||||
# these tokens have not been edited
|
||||
indices[b0:b1] = indices_target[a0:a1]
|
||||
mask[b0:b1] = 1
|
||||
|
||||
context.cross_attention_mask = mask.to(device)
|
||||
context.cross_attention_index_map = indices.to(device)
|
||||
if is_running_diffusers:
|
||||
unet = model
|
||||
old_attn_processors = unet.attn_processors
|
||||
if torch.backends.mps.is_available():
|
||||
# see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
|
||||
unet.set_attn_processor(SwapCrossAttnProcessor())
|
||||
else:
|
||||
# try to re-use an existing slice size
|
||||
default_slice_size = 4
|
||||
slice_size = next(
|
||||
(
|
||||
p.slice_size
|
||||
for p in old_attn_processors.values()
|
||||
if type(p) is SlicedAttnProcessor
|
||||
),
|
||||
default_slice_size,
|
||||
)
|
||||
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
|
||||
return old_attn_processors
|
||||
old_attn_processors = unet.attn_processors
|
||||
if torch.backends.mps.is_available():
|
||||
# see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
|
||||
unet.set_attn_processor(SwapCrossAttnProcessor())
|
||||
else:
|
||||
context.register_cross_attention_modules(model)
|
||||
inject_attention_function(model, context)
|
||||
return None
|
||||
|
||||
# try to re-use an existing slice size
|
||||
default_slice_size = 4
|
||||
slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
|
||||
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
|
||||
|
||||
def get_cross_attention_modules(
|
||||
model, which: CrossAttentionType
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import AttentionProcessor
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
@@ -17,8 +18,8 @@ from .cross_attention_control import (
|
||||
CrossAttentionType,
|
||||
SwapCrossAttnContext,
|
||||
get_cross_attention_modules,
|
||||
override_cross_attention,
|
||||
restore_default_cross_attention,
|
||||
setup_cross_attention_control_attention_processors,
|
||||
)
|
||||
from .cross_attention_map_saving import AttentionMapSaver
|
||||
|
||||
@@ -79,24 +80,35 @@ class InvokeAIDiffuserComponent:
|
||||
self.cross_attention_control_context = None
|
||||
self.sequential_guidance = Globals.sequential_guidance
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def custom_attention_context(
|
||||
self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
|
||||
cls,
|
||||
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
|
||||
extra_conditioning_info: Optional[ExtraConditioningInfo],
|
||||
step_count: int
|
||||
):
|
||||
do_swap = (
|
||||
extra_conditioning_info is not None
|
||||
and extra_conditioning_info.wants_cross_attention_control
|
||||
)
|
||||
old_attn_processor = None
|
||||
if do_swap:
|
||||
old_attn_processor = self.override_cross_attention(
|
||||
extra_conditioning_info, step_count=step_count
|
||||
)
|
||||
old_attn_processors = None
|
||||
if extra_conditioning_info and (
|
||||
extra_conditioning_info.wants_cross_attention_control
|
||||
):
|
||||
old_attn_processors = unet.attn_processors
|
||||
# Load lora conditions into the model
|
||||
if extra_conditioning_info.wants_cross_attention_control:
|
||||
cross_attention_control_context = Context(
|
||||
arguments=extra_conditioning_info.cross_attention_control_args,
|
||||
step_count=step_count,
|
||||
)
|
||||
setup_cross_attention_control_attention_processors(
|
||||
unet,
|
||||
cross_attention_control_context,
|
||||
)
|
||||
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if old_attn_processor is not None:
|
||||
self.restore_default_cross_attention(old_attn_processor)
|
||||
if old_attn_processors is not None:
|
||||
unet.set_attn_processor(old_attn_processors)
|
||||
# TODO resuscitate attention map saving
|
||||
# self.remove_attention_map_saving()
|
||||
|
||||
|
||||
1
invokeai/backend/stable_diffusion/schedulers/__init__.py
Normal file
1
invokeai/backend/stable_diffusion/schedulers/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .schedulers import SCHEDULER_MAP
|
||||
23
invokeai/backend/stable_diffusion/schedulers/schedulers.py
Normal file
23
invokeai/backend/stable_diffusion/schedulers/schedulers.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, KDPM2DiscreteScheduler, \
|
||||
KDPM2AncestralDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, \
|
||||
HeunDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, UniPCMultistepScheduler, \
|
||||
DPMSolverSinglestepScheduler, DEISMultistepScheduler, DDPMScheduler
|
||||
|
||||
SCHEDULER_MAP = dict(
|
||||
ddim=(DDIMScheduler, dict()),
|
||||
ddpm=(DDPMScheduler, dict()),
|
||||
deis=(DEISMultistepScheduler, dict()),
|
||||
lms=(LMSDiscreteScheduler, dict()),
|
||||
pndm=(PNDMScheduler, dict()),
|
||||
heun=(HeunDiscreteScheduler, dict(use_karras_sigmas=False)),
|
||||
heun_k=(HeunDiscreteScheduler, dict(use_karras_sigmas=True)),
|
||||
euler=(EulerDiscreteScheduler, dict(use_karras_sigmas=False)),
|
||||
euler_k=(EulerDiscreteScheduler, dict(use_karras_sigmas=True)),
|
||||
euler_a=(EulerAncestralDiscreteScheduler, dict()),
|
||||
kdpm_2=(KDPM2DiscreteScheduler, dict()),
|
||||
kdpm_2_a=(KDPM2AncestralDiscreteScheduler, dict()),
|
||||
dpmpp_2s=(DPMSolverSinglestepScheduler, dict()),
|
||||
dpmpp_2m=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False)),
|
||||
dpmpp_2m_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)),
|
||||
unipc=(UniPCMultistepScheduler, dict(cpu_only=True))
|
||||
)
|
||||
@@ -2,34 +2,37 @@
|
||||
|
||||
"""invokeai.util.logging
|
||||
|
||||
Logging class for InvokeAI that produces console messages that follow
|
||||
the conventions established in InvokeAI 1.X through 2.X.
|
||||
Logging class for InvokeAI that produces console messages
|
||||
|
||||
|
||||
One way to use it:
|
||||
Usage:
|
||||
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
logger = InvokeAILogger.getLogger(__name__)
|
||||
logger.critical('this is critical')
|
||||
logger.error('this is an error')
|
||||
logger.warning('this is a warning')
|
||||
logger.info('this is info')
|
||||
logger.debug('this is debugging')
|
||||
logger = InvokeAILogger.getLogger(name='InvokeAI') // Initialization
|
||||
(or)
|
||||
logger = InvokeAILogger.getLogger(__name__) // To use the filename
|
||||
|
||||
logger.critical('this is critical') // Critical Message
|
||||
logger.error('this is an error') // Error Message
|
||||
logger.warning('this is a warning') // Warning Message
|
||||
logger.info('this is info') // Info Message
|
||||
logger.debug('this is debugging') // Debug Message
|
||||
|
||||
Console messages:
|
||||
### this is critical
|
||||
*** this is an error ***
|
||||
** this is a warning
|
||||
>> this is info
|
||||
| this is debugging
|
||||
[12-05-2023 20]::[InvokeAI]::CRITICAL --> This is an info message [In Bold Red]
|
||||
[12-05-2023 20]::[InvokeAI]::ERROR --> This is an info message [In Red]
|
||||
[12-05-2023 20]::[InvokeAI]::WARNING --> This is an info message [In Yellow]
|
||||
[12-05-2023 20]::[InvokeAI]::INFO --> This is an info message [In Grey]
|
||||
[12-05-2023 20]::[InvokeAI]::DEBUG --> This is an info message [In Grey]
|
||||
|
||||
Another way:
|
||||
import invokeai.backend.util.logging as ialog
|
||||
ialogger.debug('this is a debugging message')
|
||||
Alternate Method (in this case the logger name will be set to InvokeAI):
|
||||
import invokeai.backend.util.logging as IAILogger
|
||||
IAILogger.debug('this is a debugging message')
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
# module level functions
|
||||
def debug(msg, *args, **kwargs):
|
||||
InvokeAILogger.getLogger().debug(msg, *args, **kwargs)
|
||||
@@ -42,7 +45,7 @@ def warning(msg, *args, **kwargs):
|
||||
|
||||
def error(msg, *args, **kwargs):
|
||||
InvokeAILogger.getLogger().error(msg, *args, **kwargs)
|
||||
|
||||
|
||||
def critical(msg, *args, **kwargs):
|
||||
InvokeAILogger.getLogger().critical(msg, *args, **kwargs)
|
||||
|
||||
@@ -55,55 +58,53 @@ def disable(level=logging.CRITICAL):
|
||||
def basicConfig(**kwargs):
|
||||
InvokeAILogger.getLogger().basicConfig(**kwargs)
|
||||
|
||||
def getLogger(name: str=None)->logging.Logger:
|
||||
def getLogger(name: str = None) -> logging.Logger:
|
||||
return InvokeAILogger.getLogger(name)
|
||||
|
||||
|
||||
class InvokeAILogFormatter(logging.Formatter):
|
||||
'''
|
||||
Repurposed from:
|
||||
https://stackoverflow.com/questions/14844970/modifying-logging-message-format-based-on-message-logging-level-in-python3
|
||||
Custom Formatting for the InvokeAI Logger
|
||||
'''
|
||||
crit_fmt = "### %(msg)s"
|
||||
err_fmt = "*** %(msg)s"
|
||||
warn_fmt = "** %(msg)s"
|
||||
info_fmt = ">> %(msg)s"
|
||||
dbg_fmt = " | %(msg)s"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(fmt="%(levelno)d: %(msg)s", datefmt=None, style='%')
|
||||
# Color Codes
|
||||
grey = "\x1b[38;20m"
|
||||
yellow = "\x1b[33;20m"
|
||||
red = "\x1b[31;20m"
|
||||
cyan = "\x1b[36;20m"
|
||||
bold_red = "\x1b[31;1m"
|
||||
reset = "\x1b[0m"
|
||||
|
||||
# Log Format
|
||||
log_format = "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
|
||||
## More Formatting Options: %(pathname)s, %(filename)s, %(module)s, %(lineno)d
|
||||
|
||||
# Format Map
|
||||
FORMATS = {
|
||||
logging.DEBUG: cyan + log_format + reset,
|
||||
logging.INFO: grey + log_format + reset,
|
||||
logging.WARNING: yellow + log_format + reset,
|
||||
logging.ERROR: red + log_format + reset,
|
||||
logging.CRITICAL: bold_red + log_format + reset
|
||||
}
|
||||
|
||||
def format(self, record):
|
||||
# Remember the format used when the logging module
|
||||
# was installed (in the event that this formatter is
|
||||
# used with the vanilla logging module.
|
||||
format_orig = self._style._fmt
|
||||
if record.levelno == logging.DEBUG:
|
||||
self._style._fmt = InvokeAILogFormatter.dbg_fmt
|
||||
if record.levelno == logging.INFO:
|
||||
self._style._fmt = InvokeAILogFormatter.info_fmt
|
||||
if record.levelno == logging.WARNING:
|
||||
self._style._fmt = InvokeAILogFormatter.warn_fmt
|
||||
if record.levelno == logging.ERROR:
|
||||
self._style._fmt = InvokeAILogFormatter.err_fmt
|
||||
if record.levelno == logging.CRITICAL:
|
||||
self._style._fmt = InvokeAILogFormatter.crit_fmt
|
||||
log_fmt = self.FORMATS.get(record.levelno)
|
||||
formatter = logging.Formatter(log_fmt, datefmt="%d-%m-%Y %H:%M:%S")
|
||||
return formatter.format(record)
|
||||
|
||||
# parent class does the work
|
||||
result = super().format(record)
|
||||
self._style._fmt = format_orig
|
||||
return result
|
||||
|
||||
class InvokeAILogger(object):
|
||||
loggers = dict()
|
||||
|
||||
|
||||
@classmethod
|
||||
def getLogger(self, name:str='invokeai')->logging.Logger:
|
||||
if name not in self.loggers:
|
||||
def getLogger(cls, name: str = 'InvokeAI') -> logging.Logger:
|
||||
if name not in cls.loggers:
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
ch = logging.StreamHandler()
|
||||
fmt = InvokeAILogFormatter()
|
||||
ch.setFormatter(fmt)
|
||||
logger.addHandler(ch)
|
||||
self.loggers[name] = logger
|
||||
return self.loggers[name]
|
||||
cls.loggers[name] = logger
|
||||
return cls.loggers[name]
|
||||
|
||||
@@ -4,17 +4,21 @@ from .parse_seed_weights import parse_seed_weights
|
||||
|
||||
SAMPLER_CHOICES = [
|
||||
"ddim",
|
||||
"k_dpm_2_a",
|
||||
"k_dpm_2",
|
||||
"k_dpmpp_2_a",
|
||||
"k_dpmpp_2",
|
||||
"k_euler_a",
|
||||
"k_euler",
|
||||
"k_heun",
|
||||
"k_lms",
|
||||
"plms",
|
||||
# diffusers:
|
||||
"ddpm",
|
||||
"deis",
|
||||
"lms",
|
||||
"pndm",
|
||||
"heun",
|
||||
'heun_k',
|
||||
"euler",
|
||||
"euler_k",
|
||||
"euler_a",
|
||||
"kdpm_2",
|
||||
"kdpm_2_a",
|
||||
"dpmpp_2s",
|
||||
"dpmpp_2m",
|
||||
"dpmpp_2m_k",
|
||||
"unipc",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
{
|
||||
"plugins": [
|
||||
[
|
||||
"transform-imports",
|
||||
{
|
||||
"lodash": {
|
||||
"transform": "lodash/${member}",
|
||||
"preventFullImport": true
|
||||
}
|
||||
}
|
||||
]
|
||||
]
|
||||
}
|
||||
6
invokeai/frontend/web/.gitignore
vendored
6
invokeai/frontend/web/.gitignore
vendored
@@ -34,4 +34,8 @@ stats.html
|
||||
!.yarn/plugins
|
||||
!.yarn/releases
|
||||
!.yarn/sdks
|
||||
!.yarn/versions
|
||||
!.yarn/versions
|
||||
|
||||
# Yalc
|
||||
.yalc
|
||||
yalc.lock
|
||||
@@ -5,6 +5,7 @@ import { PluginOption, UserConfig } from 'vite';
|
||||
import dts from 'vite-plugin-dts';
|
||||
import eslint from 'vite-plugin-eslint';
|
||||
import tsconfigPaths from 'vite-tsconfig-paths';
|
||||
import cssInjectedByJsPlugin from 'vite-plugin-css-injected-by-js';
|
||||
|
||||
export const packageConfig: UserConfig = {
|
||||
base: './',
|
||||
@@ -16,9 +17,10 @@ export const packageConfig: UserConfig = {
|
||||
dts({
|
||||
insertTypesEntry: true,
|
||||
}),
|
||||
cssInjectedByJsPlugin(),
|
||||
],
|
||||
build: {
|
||||
chunkSizeWarningLimit: 1500,
|
||||
cssCodeSplit: true,
|
||||
lib: {
|
||||
entry: path.resolve(__dirname, '../src/index.ts'),
|
||||
name: 'InvokeAIUI',
|
||||
@@ -30,6 +32,7 @@ export const packageConfig: UserConfig = {
|
||||
globals: {
|
||||
react: 'React',
|
||||
'react-dom': 'ReactDOM',
|
||||
'@emotion/react': 'EmotionReact',
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -15,15 +15,3 @@ The `postinstall` script patches a few packages and runs the Chakra CLI to gener
|
||||
### Patch `@chakra-ui/cli`
|
||||
|
||||
See: <https://github.com/chakra-ui/chakra-ui/issues/7394>
|
||||
|
||||
### Patch `redux-persist`
|
||||
|
||||
We want to persist the canvas state to `localStorage` but many canvas operations change data very quickly, so we need to debounce the writes to `localStorage`.
|
||||
|
||||
`redux-persist` is unfortunately unmaintained. The repo's current code is nonfunctional, but the last release's code depends on a package that was removed from `npm` for being malware, so we cannot just fork it.
|
||||
|
||||
So, we have to patch it directly. Perhaps a better way would be to write a debounced storage adapter, but I couldn't figure out how to do that.
|
||||
|
||||
### Patch `redux-deep-persist`
|
||||
|
||||
This package makes blacklisting and whitelisting persist configs very simple, but we have to patch it to match `redux-persist` for the types to work.
|
||||
|
||||
@@ -37,7 +37,7 @@ From `invokeai/frontend/web/` run `yarn install` to get everything set up.
|
||||
Start everything in dev mode:
|
||||
|
||||
1. Start the dev server: `yarn dev`
|
||||
2. Start the InvokeAI UI per usual: `invokeai --web`
|
||||
2. Start the InvokeAI Nodes backend: `python scripts/invokeai-new.py --web # run from the repo root`
|
||||
3. Point your browser to the dev server address e.g. <http://localhost:5173/>
|
||||
|
||||
### Production builds
|
||||
|
||||
@@ -21,7 +21,6 @@
|
||||
"scripts": {
|
||||
"prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky",
|
||||
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
|
||||
"dev:nodes": "concurrently \"vite dev --mode nodes\" \"yarn run theme:watch\"",
|
||||
"dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
|
||||
"build": "yarn run lint && vite build",
|
||||
"api:web": "openapi -i http://localhost:9090/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
|
||||
@@ -63,11 +62,13 @@
|
||||
"@dagrejs/graphlib": "^2.1.12",
|
||||
"@emotion/react": "^11.10.6",
|
||||
"@emotion/styled": "^11.10.6",
|
||||
"@floating-ui/react-dom": "^2.0.0",
|
||||
"@fontsource/inter": "^4.5.15",
|
||||
"@reduxjs/toolkit": "^1.9.5",
|
||||
"@roarr/browser-log-writer": "^1.1.5",
|
||||
"chakra-ui-contextmenu": "^1.0.5",
|
||||
"dateformat": "^5.0.3",
|
||||
"downshift": "^7.6.0",
|
||||
"formik": "^2.2.9",
|
||||
"framer-motion": "^10.12.4",
|
||||
"fuse.js": "^6.6.2",
|
||||
@@ -88,17 +89,14 @@
|
||||
"react-i18next": "^12.2.2",
|
||||
"react-icons": "^4.7.1",
|
||||
"react-konva": "^18.2.7",
|
||||
"react-konva-utils": "^1.0.4",
|
||||
"react-redux": "^8.0.5",
|
||||
"react-rnd": "^10.4.1",
|
||||
"react-transition-group": "^4.4.5",
|
||||
"react-resizable-panels": "^0.0.42",
|
||||
"react-use": "^17.4.0",
|
||||
"react-virtuoso": "^4.3.5",
|
||||
"react-zoom-pan-pinch": "^3.0.7",
|
||||
"reactflow": "^11.7.0",
|
||||
"redux-deep-persist": "^1.0.7",
|
||||
"redux-dynamic-middlewares": "^2.2.0",
|
||||
"redux-persist": "^6.0.0",
|
||||
"redux-remember": "^3.3.1",
|
||||
"roarr": "^7.15.0",
|
||||
"serialize-error": "^11.0.0",
|
||||
"socket.io-client": "^4.6.0",
|
||||
@@ -118,6 +116,7 @@
|
||||
"@types/node": "^18.16.2",
|
||||
"@types/react": "^18.2.0",
|
||||
"@types/react-dom": "^18.2.1",
|
||||
"@types/react-redux": "^7.1.25",
|
||||
"@types/react-transition-group": "^4.4.5",
|
||||
"@types/uuid": "^9.0.0",
|
||||
"@typescript-eslint/eslint-plugin": "^5.59.1",
|
||||
@@ -143,6 +142,7 @@
|
||||
"terser": "^5.17.1",
|
||||
"ts-toolbelt": "^9.6.0",
|
||||
"vite": "^4.3.3",
|
||||
"vite-plugin-css-injected-by-js": "^3.1.1",
|
||||
"vite-plugin-dts": "^2.3.0",
|
||||
"vite-plugin-eslint": "^1.8.1",
|
||||
"vite-tsconfig-paths": "^4.2.0",
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
diff --git a/node_modules/redux-deep-persist/lib/types.d.ts b/node_modules/redux-deep-persist/lib/types.d.ts
|
||||
index b67b8c2..7fc0fa1 100644
|
||||
--- a/node_modules/redux-deep-persist/lib/types.d.ts
|
||||
+++ b/node_modules/redux-deep-persist/lib/types.d.ts
|
||||
@@ -35,6 +35,7 @@ export interface PersistConfig<S, RS = any, HSS = any, ESS = any> {
|
||||
whitelist?: Array<string>;
|
||||
transforms?: Array<Transform<HSS, ESS, S, RS>>;
|
||||
throttle?: number;
|
||||
+ debounce?: number;
|
||||
migrate?: PersistMigrate;
|
||||
stateReconciler?: false | StateReconciler<S>;
|
||||
getStoredState?: (config: PersistConfig<S, RS, HSS, ESS>) => Promise<PersistedState>;
|
||||
diff --git a/node_modules/redux-deep-persist/src/types.ts b/node_modules/redux-deep-persist/src/types.ts
|
||||
index 398ac19..cbc5663 100644
|
||||
--- a/node_modules/redux-deep-persist/src/types.ts
|
||||
+++ b/node_modules/redux-deep-persist/src/types.ts
|
||||
@@ -91,6 +91,7 @@ export interface PersistConfig<S, RS = any, HSS = any, ESS = any> {
|
||||
whitelist?: Array<string>;
|
||||
transforms?: Array<Transform<HSS, ESS, S, RS>>;
|
||||
throttle?: number;
|
||||
+ debounce?: number;
|
||||
migrate?: PersistMigrate;
|
||||
stateReconciler?: false | StateReconciler<S>;
|
||||
/**
|
||||
@@ -1,116 +0,0 @@
|
||||
diff --git a/node_modules/redux-persist/es/createPersistoid.js b/node_modules/redux-persist/es/createPersistoid.js
|
||||
index 8b43b9a..184faab 100644
|
||||
--- a/node_modules/redux-persist/es/createPersistoid.js
|
||||
+++ b/node_modules/redux-persist/es/createPersistoid.js
|
||||
@@ -6,6 +6,7 @@ export default function createPersistoid(config) {
|
||||
var whitelist = config.whitelist || null;
|
||||
var transforms = config.transforms || [];
|
||||
var throttle = config.throttle || 0;
|
||||
+ var debounce = config.debounce || 0;
|
||||
var storageKey = "".concat(config.keyPrefix !== undefined ? config.keyPrefix : KEY_PREFIX).concat(config.key);
|
||||
var storage = config.storage;
|
||||
var serialize;
|
||||
@@ -28,30 +29,37 @@ export default function createPersistoid(config) {
|
||||
var timeIterator = null;
|
||||
var writePromise = null;
|
||||
|
||||
- var update = function update(state) {
|
||||
- // add any changed keys to the queue
|
||||
- Object.keys(state).forEach(function (key) {
|
||||
- if (!passWhitelistBlacklist(key)) return; // is keyspace ignored? noop
|
||||
+ // Timer for debounced `update()`
|
||||
+ let timer = 0;
|
||||
|
||||
- if (lastState[key] === state[key]) return; // value unchanged? noop
|
||||
+ function update(state) {
|
||||
+ // Debounce the update
|
||||
+ clearTimeout(timer);
|
||||
+ timer = setTimeout(() => {
|
||||
+ // add any changed keys to the queue
|
||||
+ Object.keys(state).forEach(function (key) {
|
||||
+ if (!passWhitelistBlacklist(key)) return; // is keyspace ignored? noop
|
||||
|
||||
- if (keysToProcess.indexOf(key) !== -1) return; // is key already queued? noop
|
||||
+ if (lastState[key] === state[key]) return; // value unchanged? noop
|
||||
|
||||
- keysToProcess.push(key); // add key to queue
|
||||
- }); //if any key is missing in the new state which was present in the lastState,
|
||||
- //add it for processing too
|
||||
+ if (keysToProcess.indexOf(key) !== -1) return; // is key already queued? noop
|
||||
|
||||
- Object.keys(lastState).forEach(function (key) {
|
||||
- if (state[key] === undefined && passWhitelistBlacklist(key) && keysToProcess.indexOf(key) === -1 && lastState[key] !== undefined) {
|
||||
- keysToProcess.push(key);
|
||||
- }
|
||||
- }); // start the time iterator if not running (read: throttle)
|
||||
+ keysToProcess.push(key); // add key to queue
|
||||
+ }); //if any key is missing in the new state which was present in the lastState,
|
||||
+ //add it for processing too
|
||||
|
||||
- if (timeIterator === null) {
|
||||
- timeIterator = setInterval(processNextKey, throttle);
|
||||
- }
|
||||
+ Object.keys(lastState).forEach(function (key) {
|
||||
+ if (state[key] === undefined && passWhitelistBlacklist(key) && keysToProcess.indexOf(key) === -1 && lastState[key] !== undefined) {
|
||||
+ keysToProcess.push(key);
|
||||
+ }
|
||||
+ }); // start the time iterator if not running (read: throttle)
|
||||
+
|
||||
+ if (timeIterator === null) {
|
||||
+ timeIterator = setInterval(processNextKey, throttle);
|
||||
+ }
|
||||
|
||||
- lastState = state;
|
||||
+ lastState = state;
|
||||
+ }, debounce)
|
||||
};
|
||||
|
||||
function processNextKey() {
|
||||
diff --git a/node_modules/redux-persist/es/types.js.flow b/node_modules/redux-persist/es/types.js.flow
|
||||
index c50d3cd..39d8be2 100644
|
||||
--- a/node_modules/redux-persist/es/types.js.flow
|
||||
+++ b/node_modules/redux-persist/es/types.js.flow
|
||||
@@ -19,6 +19,7 @@ export type PersistConfig = {
|
||||
whitelist?: Array<string>,
|
||||
transforms?: Array<Transform>,
|
||||
throttle?: number,
|
||||
+ debounce?: number,
|
||||
migrate?: (PersistedState, number) => Promise<PersistedState>,
|
||||
stateReconciler?: false | Function,
|
||||
getStoredState?: PersistConfig => Promise<PersistedState>, // used for migrations
|
||||
diff --git a/node_modules/redux-persist/lib/types.js.flow b/node_modules/redux-persist/lib/types.js.flow
|
||||
index c50d3cd..39d8be2 100644
|
||||
--- a/node_modules/redux-persist/lib/types.js.flow
|
||||
+++ b/node_modules/redux-persist/lib/types.js.flow
|
||||
@@ -19,6 +19,7 @@ export type PersistConfig = {
|
||||
whitelist?: Array<string>,
|
||||
transforms?: Array<Transform>,
|
||||
throttle?: number,
|
||||
+ debounce?: number,
|
||||
migrate?: (PersistedState, number) => Promise<PersistedState>,
|
||||
stateReconciler?: false | Function,
|
||||
getStoredState?: PersistConfig => Promise<PersistedState>, // used for migrations
|
||||
diff --git a/node_modules/redux-persist/src/types.js b/node_modules/redux-persist/src/types.js
|
||||
index c50d3cd..39d8be2 100644
|
||||
--- a/node_modules/redux-persist/src/types.js
|
||||
+++ b/node_modules/redux-persist/src/types.js
|
||||
@@ -19,6 +19,7 @@ export type PersistConfig = {
|
||||
whitelist?: Array<string>,
|
||||
transforms?: Array<Transform>,
|
||||
throttle?: number,
|
||||
+ debounce?: number,
|
||||
migrate?: (PersistedState, number) => Promise<PersistedState>,
|
||||
stateReconciler?: false | Function,
|
||||
getStoredState?: PersistConfig => Promise<PersistedState>, // used for migrations
|
||||
diff --git a/node_modules/redux-persist/types/types.d.ts b/node_modules/redux-persist/types/types.d.ts
|
||||
index b3733bc..2a1696c 100644
|
||||
--- a/node_modules/redux-persist/types/types.d.ts
|
||||
+++ b/node_modules/redux-persist/types/types.d.ts
|
||||
@@ -35,6 +35,7 @@ declare module "redux-persist/es/types" {
|
||||
whitelist?: Array<string>;
|
||||
transforms?: Array<Transform<HSS, ESS, S, RS>>;
|
||||
throttle?: number;
|
||||
+ debounce?: number;
|
||||
migrate?: PersistMigrate;
|
||||
stateReconciler?: false | StateReconciler<S>;
|
||||
/**
|
||||
@@ -25,7 +25,7 @@
|
||||
"common": {
|
||||
"hotkeysLabel": "Hotkeys",
|
||||
"themeLabel": "Theme",
|
||||
"languagePickerLabel": "Language Picker",
|
||||
"languagePickerLabel": "Language",
|
||||
"reportBugLabel": "Report Bug",
|
||||
"githubLabel": "Github",
|
||||
"discordLabel": "Discord",
|
||||
@@ -54,7 +54,7 @@
|
||||
"img2img": "Image To Image",
|
||||
"unifiedCanvas": "Unified Canvas",
|
||||
"linear": "Linear",
|
||||
"nodes": "Nodes",
|
||||
"nodes": "Node Editor",
|
||||
"postprocessing": "Post Processing",
|
||||
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
|
||||
"postProcessing": "Post Processing",
|
||||
@@ -102,7 +102,8 @@
|
||||
"generate": "Generate",
|
||||
"openInNewTab": "Open in New Tab",
|
||||
"dontAskMeAgain": "Don't ask me again",
|
||||
"areYouSure": "Are you sure?"
|
||||
"areYouSure": "Are you sure?",
|
||||
"imagePrompt": "Image Prompt"
|
||||
},
|
||||
"gallery": {
|
||||
"generations": "Generations",
|
||||
@@ -449,13 +450,14 @@
|
||||
"cfgScale": "CFG Scale",
|
||||
"width": "Width",
|
||||
"height": "Height",
|
||||
"sampler": "Sampler",
|
||||
"scheduler": "Scheduler",
|
||||
"seed": "Seed",
|
||||
"imageToImage": "Image to Image",
|
||||
"randomizeSeed": "Randomize Seed",
|
||||
"shuffle": "Shuffle",
|
||||
"shuffle": "Shuffle Seed",
|
||||
"noiseThreshold": "Noise Threshold",
|
||||
"perlinNoise": "Perlin Noise",
|
||||
"noiseSettings": "Noise",
|
||||
"variations": "Variations",
|
||||
"variationAmount": "Variation Amount",
|
||||
"seedWeights": "Seed Weights",
|
||||
@@ -470,6 +472,8 @@
|
||||
"scale": "Scale",
|
||||
"otherOptions": "Other Options",
|
||||
"seamlessTiling": "Seamless Tiling",
|
||||
"seamlessXAxis": "X Axis",
|
||||
"seamlessYAxis": "Y Axis",
|
||||
"hiresOptim": "High Res Optimization",
|
||||
"hiresStrength": "High Res Strength",
|
||||
"imageFit": "Fit Initial Image To Output Size",
|
||||
@@ -527,7 +531,8 @@
|
||||
"useCanvasBeta": "Use Canvas Beta Layout",
|
||||
"enableImageDebugging": "Enable Image Debugging",
|
||||
"useSlidersForAll": "Use Sliders For All Options",
|
||||
"autoShowProgress": "Auto Show Progress Images",
|
||||
"showProgressInViewer": "Show Progress Images in Viewer",
|
||||
"antialiasProgressImages": "Antialias Progress Images",
|
||||
"resetWebUI": "Reset Web UI",
|
||||
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
|
||||
"resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.",
|
||||
@@ -535,7 +540,10 @@
|
||||
"consoleLogLevel": "Log Level",
|
||||
"shouldLogToConsole": "Console Logging",
|
||||
"developer": "Developer",
|
||||
"general": "General"
|
||||
"general": "General",
|
||||
"generation": "Generation",
|
||||
"ui": "User Interface",
|
||||
"availableSchedulers": "Available Schedulers"
|
||||
},
|
||||
"toast": {
|
||||
"serverError": "Server Error",
|
||||
@@ -544,13 +552,14 @@
|
||||
"canceled": "Processing Canceled",
|
||||
"tempFoldersEmptied": "Temp Folder Emptied",
|
||||
"uploadFailed": "Upload failed",
|
||||
"uploadFailedMultipleImagesDesc": "Multiple images pasted, may only upload one image at a time",
|
||||
"uploadFailedUnableToLoadDesc": "Unable to load file",
|
||||
"uploadFailedInvalidUploadDesc": "Must be single PNG or JPEG image",
|
||||
"downloadImageStarted": "Image Download Started",
|
||||
"imageCopied": "Image Copied",
|
||||
"imageLinkCopied": "Image Link Copied",
|
||||
"problemCopyingImageLink": "Unable to Copy Image Link",
|
||||
"imageNotLoaded": "No Image Loaded",
|
||||
"imageNotLoadedDesc": "No image found to send to image to image module",
|
||||
"imageNotLoadedDesc": "Could not find image",
|
||||
"imageSavedToGallery": "Image Saved to Gallery",
|
||||
"canvasMerged": "Canvas Merged",
|
||||
"sentToImageToImage": "Sent To Image To Image",
|
||||
@@ -645,7 +654,8 @@
|
||||
"betaClear": "Clear",
|
||||
"betaDarkenOutside": "Darken Outside",
|
||||
"betaLimitToBox": "Limit To Box",
|
||||
"betaPreserveMasked": "Preserve Masked"
|
||||
"betaPreserveMasked": "Preserve Masked",
|
||||
"antialiasing": "Antialiasing"
|
||||
},
|
||||
"ui": {
|
||||
"showProgressImages": "Show Progress Images",
|
||||
|
||||
@@ -1,46 +1,44 @@
|
||||
import ImageUploader from 'common/components/ImageUploader';
|
||||
import ProgressBar from 'features/system/components/ProgressBar';
|
||||
import SiteHeader from 'features/system/components/SiteHeader';
|
||||
import ProgressBar from 'features/system/components/ProgressBar';
|
||||
import InvokeTabs from 'features/ui/components/InvokeTabs';
|
||||
|
||||
import useToastWatcher from 'features/system/hooks/useToastWatcher';
|
||||
|
||||
import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton';
|
||||
import FloatingParametersPanelButtons from 'features/ui/components/FloatingParametersPanelButtons';
|
||||
import { Box, Flex, Grid, Portal, useColorMode } from '@chakra-ui/react';
|
||||
import { Box, Flex, Grid, Portal } from '@chakra-ui/react';
|
||||
import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
|
||||
import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel';
|
||||
import GalleryDrawer from 'features/gallery/components/GalleryPanel';
|
||||
import Lightbox from 'features/lightbox/components/Lightbox';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
memo,
|
||||
PropsWithChildren,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useState,
|
||||
} from 'react';
|
||||
import { memo, ReactNode, useCallback, useEffect, useState } from 'react';
|
||||
import { motion, AnimatePresence } from 'framer-motion';
|
||||
import Loading from 'common/components/Loading/Loading';
|
||||
import { useIsApplicationReady } from 'features/system/hooks/useIsApplicationReady';
|
||||
import { PartialAppConfig } from 'app/types/invokeai';
|
||||
import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
|
||||
import { configChanged } from 'features/system/store/configSlice';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { useLogger } from 'app/logging/useLogger';
|
||||
import ProgressImagePreview from 'features/parameters/components/ProgressImagePreview';
|
||||
import ParametersDrawer from 'features/ui/components/ParametersDrawer';
|
||||
import { languageSelector } from 'features/system/store/systemSelectors';
|
||||
import i18n from 'i18n';
|
||||
import Toaster from './Toaster';
|
||||
import GlobalHotkeys from './GlobalHotkeys';
|
||||
|
||||
const DEFAULT_CONFIG = {};
|
||||
|
||||
interface Props extends PropsWithChildren {
|
||||
interface Props {
|
||||
config?: PartialAppConfig;
|
||||
headerComponent?: ReactNode;
|
||||
setIsReady?: (isReady: boolean) => void;
|
||||
}
|
||||
|
||||
const App = ({ config = DEFAULT_CONFIG, children }: Props) => {
|
||||
useToastWatcher();
|
||||
useGlobalHotkeys();
|
||||
const log = useLogger();
|
||||
const App = ({
|
||||
config = DEFAULT_CONFIG,
|
||||
headerComponent,
|
||||
setIsReady,
|
||||
}: Props) => {
|
||||
const language = useAppSelector(languageSelector);
|
||||
|
||||
const currentTheme = useAppSelector((state) => state.ui.currentTheme);
|
||||
const log = useLogger();
|
||||
|
||||
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
|
||||
|
||||
@@ -48,81 +46,95 @@ const App = ({ config = DEFAULT_CONFIG, children }: Props) => {
|
||||
|
||||
const [loadingOverridden, setLoadingOverridden] = useState(false);
|
||||
|
||||
const { setColorMode } = useColorMode();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
useEffect(() => {
|
||||
i18n.changeLanguage(language);
|
||||
}, [language]);
|
||||
|
||||
useEffect(() => {
|
||||
log.info({ namespace: 'App', data: config }, 'Received config');
|
||||
dispatch(configChanged(config));
|
||||
}, [dispatch, config, log]);
|
||||
|
||||
useEffect(() => {
|
||||
setColorMode(['light'].includes(currentTheme) ? 'light' : 'dark');
|
||||
}, [setColorMode, currentTheme]);
|
||||
|
||||
const handleOverrideClicked = useCallback(() => {
|
||||
setLoadingOverridden(true);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (isApplicationReady && setIsReady) {
|
||||
setIsReady(true);
|
||||
}
|
||||
|
||||
return () => {
|
||||
setIsReady && setIsReady(false);
|
||||
};
|
||||
}, [isApplicationReady, setIsReady]);
|
||||
|
||||
return (
|
||||
<Grid w="100vw" h="100vh" position="relative" overflow="hidden">
|
||||
{isLightboxEnabled && <Lightbox />}
|
||||
<ImageUploader>
|
||||
<ProgressBar />
|
||||
<Grid
|
||||
gap={4}
|
||||
p={4}
|
||||
gridAutoRows="min-content auto"
|
||||
w={APP_WIDTH}
|
||||
h={APP_HEIGHT}
|
||||
>
|
||||
{children || <SiteHeader />}
|
||||
<Flex
|
||||
<>
|
||||
<Grid w="100vw" h="100vh" position="relative" overflow="hidden">
|
||||
{isLightboxEnabled && <Lightbox />}
|
||||
<ImageUploader>
|
||||
<ProgressBar />
|
||||
<Grid
|
||||
gap={4}
|
||||
w={{ base: '100vw', xl: 'full' }}
|
||||
h="full"
|
||||
flexDir={{ base: 'column', xl: 'row' }}
|
||||
p={4}
|
||||
gridAutoRows="min-content auto"
|
||||
w={APP_WIDTH}
|
||||
h={APP_HEIGHT}
|
||||
>
|
||||
<InvokeTabs />
|
||||
<ImageGalleryPanel />
|
||||
</Flex>
|
||||
</Grid>
|
||||
</ImageUploader>
|
||||
{headerComponent || <SiteHeader />}
|
||||
<Flex
|
||||
gap={4}
|
||||
w={{ base: '100vw', xl: 'full' }}
|
||||
h="full"
|
||||
flexDir={{ base: 'column', xl: 'row' }}
|
||||
>
|
||||
<InvokeTabs />
|
||||
</Flex>
|
||||
</Grid>
|
||||
</ImageUploader>
|
||||
|
||||
<AnimatePresence>
|
||||
{!isApplicationReady && !loadingOverridden && (
|
||||
<motion.div
|
||||
key="loading"
|
||||
initial={{ opacity: 1 }}
|
||||
animate={{ opacity: 1 }}
|
||||
exit={{ opacity: 0 }}
|
||||
transition={{ duration: 0.3 }}
|
||||
style={{ zIndex: 3 }}
|
||||
>
|
||||
<Box position="absolute" top={0} left={0} w="100vw" h="100vh">
|
||||
<Loading />
|
||||
</Box>
|
||||
<Box
|
||||
onClick={handleOverrideClicked}
|
||||
position="absolute"
|
||||
top={0}
|
||||
right={0}
|
||||
cursor="pointer"
|
||||
w="2rem"
|
||||
h="2rem"
|
||||
/>
|
||||
</motion.div>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
<GalleryDrawer />
|
||||
<ParametersDrawer />
|
||||
|
||||
<Portal>
|
||||
<FloatingParametersPanelButtons />
|
||||
</Portal>
|
||||
<Portal>
|
||||
<FloatingGalleryButton />
|
||||
</Portal>
|
||||
<ProgressImagePreview />
|
||||
</Grid>
|
||||
<AnimatePresence>
|
||||
{!isApplicationReady && !loadingOverridden && (
|
||||
<motion.div
|
||||
key="loading"
|
||||
initial={{ opacity: 1 }}
|
||||
animate={{ opacity: 1 }}
|
||||
exit={{ opacity: 0 }}
|
||||
transition={{ duration: 0.3 }}
|
||||
style={{ zIndex: 3 }}
|
||||
>
|
||||
<Box position="absolute" top={0} left={0} w="100vw" h="100vh">
|
||||
<Loading />
|
||||
</Box>
|
||||
<Box
|
||||
onClick={handleOverrideClicked}
|
||||
position="absolute"
|
||||
top={0}
|
||||
right={0}
|
||||
cursor="pointer"
|
||||
w="2rem"
|
||||
h="2rem"
|
||||
/>
|
||||
</motion.div>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
|
||||
<Portal>
|
||||
<FloatingParametersPanelButtons />
|
||||
</Portal>
|
||||
<Portal>
|
||||
<FloatingGalleryButton />
|
||||
</Portal>
|
||||
</Grid>
|
||||
<Toaster />
|
||||
<GlobalHotkeys />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
import { Flex, Spinner, Tooltip } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { memo } from 'react';
|
||||
|
||||
const selector = createSelector(systemSelector, (system) => {
|
||||
const { isUploading } = system;
|
||||
|
||||
let tooltip = '';
|
||||
|
||||
if (isUploading) {
|
||||
tooltip = 'Uploading...';
|
||||
}
|
||||
|
||||
return {
|
||||
tooltip,
|
||||
shouldShow: isUploading,
|
||||
};
|
||||
});
|
||||
|
||||
export const AuxiliaryProgressIndicator = () => {
|
||||
const { shouldShow, tooltip } = useAppSelector(selector);
|
||||
|
||||
if (!shouldShow) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
color: 'base.600',
|
||||
}}
|
||||
>
|
||||
<Tooltip label={tooltip} placement="right" hasArrow>
|
||||
<Spinner />
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(AuxiliaryProgressIndicator);
|
||||
@@ -2,7 +2,15 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
|
||||
import {
|
||||
setActiveTab,
|
||||
toggleGalleryPanel,
|
||||
toggleParametersPanel,
|
||||
togglePinGalleryPanel,
|
||||
togglePinParametersPanel,
|
||||
} from 'features/ui/store/uiSlice';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import React, { memo } from 'react';
|
||||
import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
|
||||
|
||||
const globalHotkeysSelector = createSelector(
|
||||
@@ -20,7 +28,11 @@ const globalHotkeysSelector = createSelector(
|
||||
|
||||
// TODO: Does not catch keypresses while focused in an input. Maybe there is a way?
|
||||
|
||||
export const useGlobalHotkeys = () => {
|
||||
/**
|
||||
* Logical component. Handles app-level global hotkeys.
|
||||
* @returns null
|
||||
*/
|
||||
const GlobalHotkeys: React.FC = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { shift } = useAppSelector(globalHotkeysSelector);
|
||||
|
||||
@@ -36,4 +48,40 @@ export const useGlobalHotkeys = () => {
|
||||
{ keyup: true, keydown: true },
|
||||
[shift]
|
||||
);
|
||||
|
||||
useHotkeys('o', () => {
|
||||
dispatch(toggleParametersPanel());
|
||||
});
|
||||
|
||||
useHotkeys(['shift+o'], () => {
|
||||
dispatch(togglePinParametersPanel());
|
||||
});
|
||||
|
||||
useHotkeys('g', () => {
|
||||
dispatch(toggleGalleryPanel());
|
||||
});
|
||||
|
||||
useHotkeys(['shift+g'], () => {
|
||||
dispatch(togglePinGalleryPanel());
|
||||
});
|
||||
|
||||
useHotkeys('1', () => {
|
||||
dispatch(setActiveTab('txt2img'));
|
||||
});
|
||||
|
||||
useHotkeys('2', () => {
|
||||
dispatch(setActiveTab('img2img'));
|
||||
});
|
||||
|
||||
useHotkeys('3', () => {
|
||||
dispatch(setActiveTab('unifiedCanvas'));
|
||||
});
|
||||
|
||||
useHotkeys('4', () => {
|
||||
dispatch(setActiveTab('nodes'));
|
||||
});
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
export default memo(GlobalHotkeys);
|
||||
@@ -1,18 +1,13 @@
|
||||
import React, { lazy, memo, PropsWithChildren, useEffect } from 'react';
|
||||
import React, {
|
||||
lazy,
|
||||
memo,
|
||||
PropsWithChildren,
|
||||
ReactNode,
|
||||
useEffect,
|
||||
} from 'react';
|
||||
import { Provider } from 'react-redux';
|
||||
import { PersistGate } from 'redux-persist/integration/react';
|
||||
import { store } from 'app/store/store';
|
||||
import { persistor } from '../store/persistor';
|
||||
import { OpenAPI } from 'services/api';
|
||||
import '@fontsource/inter/100.css';
|
||||
import '@fontsource/inter/200.css';
|
||||
import '@fontsource/inter/300.css';
|
||||
import '@fontsource/inter/400.css';
|
||||
import '@fontsource/inter/500.css';
|
||||
import '@fontsource/inter/600.css';
|
||||
import '@fontsource/inter/700.css';
|
||||
import '@fontsource/inter/800.css';
|
||||
import '@fontsource/inter/900.css';
|
||||
|
||||
import Loading from '../../common/components/Loading/Loading';
|
||||
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
|
||||
@@ -28,9 +23,17 @@ interface Props extends PropsWithChildren {
|
||||
apiUrl?: string;
|
||||
token?: string;
|
||||
config?: PartialAppConfig;
|
||||
headerComponent?: ReactNode;
|
||||
setIsReady?: (isReady: boolean) => void;
|
||||
}
|
||||
|
||||
const InvokeAIUI = ({ apiUrl, token, config, children }: Props) => {
|
||||
const InvokeAIUI = ({
|
||||
apiUrl,
|
||||
token,
|
||||
config,
|
||||
headerComponent,
|
||||
setIsReady,
|
||||
}: Props) => {
|
||||
useEffect(() => {
|
||||
// configure API client token
|
||||
if (token) {
|
||||
@@ -57,13 +60,15 @@ const InvokeAIUI = ({ apiUrl, token, config, children }: Props) => {
|
||||
return (
|
||||
<React.StrictMode>
|
||||
<Provider store={store}>
|
||||
<PersistGate loading={<Loading />} persistor={persistor}>
|
||||
<React.Suspense fallback={<Loading />}>
|
||||
<ThemeLocaleProvider>
|
||||
<App config={config}>{children}</App>
|
||||
</ThemeLocaleProvider>
|
||||
</React.Suspense>
|
||||
</PersistGate>
|
||||
<React.Suspense fallback={<Loading />}>
|
||||
<ThemeLocaleProvider>
|
||||
<App
|
||||
config={config}
|
||||
headerComponent={headerComponent}
|
||||
setIsReady={setIsReady}
|
||||
/>
|
||||
</ThemeLocaleProvider>
|
||||
</React.Suspense>
|
||||
</Provider>
|
||||
</React.StrictMode>
|
||||
);
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
import { ChakraProvider, extendTheme } from '@chakra-ui/react';
|
||||
import {
|
||||
ChakraProvider,
|
||||
createLocalStorageManager,
|
||||
extendTheme,
|
||||
} from '@chakra-ui/react';
|
||||
import { ReactNode, useEffect } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { theme as invokeAITheme } from 'theme/theme';
|
||||
@@ -9,15 +13,8 @@ import { greenTeaThemeColors } from 'theme/colors/greenTea';
|
||||
import { invokeAIThemeColors } from 'theme/colors/invokeAI';
|
||||
import { lightThemeColors } from 'theme/colors/lightTheme';
|
||||
import { oceanBlueColors } from 'theme/colors/oceanBlue';
|
||||
import '@fontsource/inter/100.css';
|
||||
import '@fontsource/inter/200.css';
|
||||
import '@fontsource/inter/300.css';
|
||||
import '@fontsource/inter/400.css';
|
||||
import '@fontsource/inter/500.css';
|
||||
import '@fontsource/inter/600.css';
|
||||
import '@fontsource/inter/700.css';
|
||||
import '@fontsource/inter/800.css';
|
||||
import '@fontsource/inter/900.css';
|
||||
|
||||
import '@fontsource/inter/variable.css';
|
||||
import 'overlayscrollbars/overlayscrollbars.css';
|
||||
import 'theme/css/overlayscrollbars.css';
|
||||
|
||||
@@ -32,6 +29,8 @@ const THEMES = {
|
||||
ocean: oceanBlueColors,
|
||||
};
|
||||
|
||||
const manager = createLocalStorageManager('@@invokeai-color-mode');
|
||||
|
||||
function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
|
||||
const { i18n } = useTranslation();
|
||||
|
||||
@@ -51,7 +50,11 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
|
||||
document.body.dir = direction;
|
||||
}, [direction]);
|
||||
|
||||
return <ChakraProvider theme={theme}>{children}</ChakraProvider>;
|
||||
return (
|
||||
<ChakraProvider theme={theme} colorModeManager={manager}>
|
||||
{children}
|
||||
</ChakraProvider>
|
||||
);
|
||||
}
|
||||
|
||||
export default ThemeLocaleProvider;
|
||||
|
||||
65
invokeai/frontend/web/src/app/components/Toaster.ts
Normal file
65
invokeai/frontend/web/src/app/components/Toaster.ts
Normal file
@@ -0,0 +1,65 @@
|
||||
import { useToast, UseToastOptions } from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { toastQueueSelector } from 'features/system/store/systemSelectors';
|
||||
import { addToast, clearToastQueue } from 'features/system/store/systemSlice';
|
||||
import { useCallback, useEffect } from 'react';
|
||||
|
||||
export type MakeToastArg = string | UseToastOptions;
|
||||
|
||||
/**
|
||||
* Makes a toast from a string or a UseToastOptions object.
|
||||
* If a string is passed, the toast will have the status 'info' and will be closable with a duration of 2500ms.
|
||||
*/
|
||||
export const makeToast = (arg: MakeToastArg): UseToastOptions => {
|
||||
if (typeof arg === 'string') {
|
||||
return {
|
||||
title: arg,
|
||||
status: 'info',
|
||||
isClosable: true,
|
||||
duration: 2500,
|
||||
};
|
||||
}
|
||||
|
||||
return { status: 'info', isClosable: true, duration: 2500, ...arg };
|
||||
};
|
||||
|
||||
/**
|
||||
* Logical component. Watches the toast queue and makes toasts when the queue is not empty.
|
||||
* @returns null
|
||||
*/
|
||||
const Toaster = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const toastQueue = useAppSelector(toastQueueSelector);
|
||||
const toast = useToast();
|
||||
useEffect(() => {
|
||||
toastQueue.forEach((t) => {
|
||||
toast(t);
|
||||
});
|
||||
toastQueue.length > 0 && dispatch(clearToastQueue());
|
||||
}, [dispatch, toast, toastQueue]);
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
/**
|
||||
* Returns a function that can be used to make a toast.
|
||||
* @example
|
||||
* const toaster = useAppToaster();
|
||||
* toaster('Hello world!');
|
||||
* toaster({ title: 'Hello world!', status: 'success' });
|
||||
* @returns A function that can be used to make a toast.
|
||||
* @see makeToast
|
||||
* @see MakeToastArg
|
||||
* @see UseToastOptions
|
||||
*/
|
||||
export const useAppToaster = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const toaster = useCallback(
|
||||
(arg: MakeToastArg) => dispatch(addToast(makeToast(arg))),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return toaster;
|
||||
};
|
||||
|
||||
export default Toaster;
|
||||
@@ -1,17 +1,28 @@
|
||||
// TODO: use Enums?
|
||||
|
||||
export const DIFFUSERS_SCHEDULERS: Array<string> = [
|
||||
export const SCHEDULERS = [
|
||||
'ddim',
|
||||
'plms',
|
||||
'k_lms',
|
||||
'dpmpp_2',
|
||||
'k_dpm_2',
|
||||
'k_dpm_2_a',
|
||||
'k_dpmpp_2',
|
||||
'k_euler',
|
||||
'k_euler_a',
|
||||
'k_heun',
|
||||
];
|
||||
'lms',
|
||||
'euler',
|
||||
'euler_k',
|
||||
'euler_a',
|
||||
'dpmpp_2s',
|
||||
'dpmpp_2m',
|
||||
'dpmpp_2m_k',
|
||||
'kdpm_2',
|
||||
'kdpm_2_a',
|
||||
'deis',
|
||||
'ddpm',
|
||||
'pndm',
|
||||
'heun',
|
||||
'heun_k',
|
||||
'unipc',
|
||||
] as const;
|
||||
|
||||
export type Scheduler = (typeof SCHEDULERS)[number];
|
||||
|
||||
export const isScheduler = (x: string): x is Scheduler =>
|
||||
SCHEDULERS.includes(x as Scheduler);
|
||||
|
||||
// Valid image widths
|
||||
export const WIDTHS: Array<number> = Array.from(Array(64)).map(
|
||||
|
||||
@@ -1,26 +1,20 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { validateSeedWeights } from 'common/util/seedWeightPairs';
|
||||
import { initialCanvasImageSelector } from 'features/canvas/store/canvasSelectors';
|
||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { isEqual } from 'lodash-es';
|
||||
|
||||
export const readinessSelector = createSelector(
|
||||
[
|
||||
generationSelector,
|
||||
systemSelector,
|
||||
initialCanvasImageSelector,
|
||||
activeTabNameSelector,
|
||||
],
|
||||
(generation, system, initialCanvasImage, activeTabName) => {
|
||||
[generationSelector, systemSelector, activeTabNameSelector],
|
||||
(generation, system, activeTabName) => {
|
||||
const {
|
||||
prompt,
|
||||
shouldGenerateVariations,
|
||||
seedWeights,
|
||||
initialImage,
|
||||
seed,
|
||||
isImageToImageEnabled,
|
||||
} = generation;
|
||||
|
||||
const { isProcessing, isConnected } = system;
|
||||
@@ -34,7 +28,7 @@ export const readinessSelector = createSelector(
|
||||
reasonsWhyNotReady.push('Missing prompt');
|
||||
}
|
||||
|
||||
if (isImageToImageEnabled && !initialImage) {
|
||||
if (activeTabName === 'img2img' && !initialImage) {
|
||||
isReady = false;
|
||||
reasonsWhyNotReady.push('No initial image selected');
|
||||
}
|
||||
@@ -64,10 +58,5 @@ export const readinessSelector = createSelector(
|
||||
// All good
|
||||
return { isReady, reasonsWhyNotReady };
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
equalityCheck: isEqual,
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
@@ -1,209 +1,209 @@
|
||||
// import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
|
||||
// import * as InvokeAI from 'app/types/invokeai';
|
||||
// import type { RootState } from 'app/store/store';
|
||||
// import {
|
||||
// frontendToBackendParameters,
|
||||
// FrontendToBackendParametersConfig,
|
||||
// } from 'common/util/parameterTranslation';
|
||||
// import dateFormat from 'dateformat';
|
||||
// import {
|
||||
// GalleryCategory,
|
||||
// GalleryState,
|
||||
// removeImage,
|
||||
// } from 'features/gallery/store/gallerySlice';
|
||||
// import {
|
||||
// generationRequested,
|
||||
// modelChangeRequested,
|
||||
// modelConvertRequested,
|
||||
// modelMergingRequested,
|
||||
// setIsProcessing,
|
||||
// } from 'features/system/store/systemSlice';
|
||||
// import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
// import { Socket } from 'socket.io-client';
|
||||
import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
|
||||
import * as InvokeAI from 'app/types/invokeai';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import {
|
||||
frontendToBackendParameters,
|
||||
FrontendToBackendParametersConfig,
|
||||
} from 'common/util/parameterTranslation';
|
||||
import dateFormat from 'dateformat';
|
||||
import {
|
||||
GalleryCategory,
|
||||
GalleryState,
|
||||
removeImage,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import {
|
||||
generationRequested,
|
||||
modelChangeRequested,
|
||||
modelConvertRequested,
|
||||
modelMergingRequested,
|
||||
setIsProcessing,
|
||||
} from 'features/system/store/systemSlice';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import { Socket } from 'socket.io-client';
|
||||
|
||||
// /**
|
||||
// * Returns an object containing all functions which use `socketio.emit()`.
|
||||
// * i.e. those which make server requests.
|
||||
// */
|
||||
// const makeSocketIOEmitters = (
|
||||
// store: MiddlewareAPI<Dispatch<AnyAction>, RootState>,
|
||||
// socketio: Socket
|
||||
// ) => {
|
||||
// // We need to dispatch actions to redux and get pieces of state from the store.
|
||||
// const { dispatch, getState } = store;
|
||||
/**
|
||||
* Returns an object containing all functions which use `socketio.emit()`.
|
||||
* i.e. those which make server requests.
|
||||
*/
|
||||
const makeSocketIOEmitters = (
|
||||
store: MiddlewareAPI<Dispatch<AnyAction>, RootState>,
|
||||
socketio: Socket
|
||||
) => {
|
||||
// We need to dispatch actions to redux and get pieces of state from the store.
|
||||
const { dispatch, getState } = store;
|
||||
|
||||
// return {
|
||||
// emitGenerateImage: (generationMode: InvokeTabName) => {
|
||||
// dispatch(setIsProcessing(true));
|
||||
return {
|
||||
emitGenerateImage: (generationMode: InvokeTabName) => {
|
||||
dispatch(setIsProcessing(true));
|
||||
|
||||
// const state: RootState = getState();
|
||||
const state: RootState = getState();
|
||||
|
||||
// const {
|
||||
// generation: generationState,
|
||||
// postprocessing: postprocessingState,
|
||||
// system: systemState,
|
||||
// canvas: canvasState,
|
||||
// } = state;
|
||||
const {
|
||||
generation: generationState,
|
||||
postprocessing: postprocessingState,
|
||||
system: systemState,
|
||||
canvas: canvasState,
|
||||
} = state;
|
||||
|
||||
// const frontendToBackendParametersConfig: FrontendToBackendParametersConfig =
|
||||
// {
|
||||
// generationMode,
|
||||
// generationState,
|
||||
// postprocessingState,
|
||||
// canvasState,
|
||||
// systemState,
|
||||
// };
|
||||
const frontendToBackendParametersConfig: FrontendToBackendParametersConfig =
|
||||
{
|
||||
generationMode,
|
||||
generationState,
|
||||
postprocessingState,
|
||||
canvasState,
|
||||
systemState,
|
||||
};
|
||||
|
||||
// dispatch(generationRequested());
|
||||
dispatch(generationRequested());
|
||||
|
||||
// const { generationParameters, esrganParameters, facetoolParameters } =
|
||||
// frontendToBackendParameters(frontendToBackendParametersConfig);
|
||||
const { generationParameters, esrganParameters, facetoolParameters } =
|
||||
frontendToBackendParameters(frontendToBackendParametersConfig);
|
||||
|
||||
// socketio.emit(
|
||||
// 'generateImage',
|
||||
// generationParameters,
|
||||
// esrganParameters,
|
||||
// facetoolParameters
|
||||
// );
|
||||
socketio.emit(
|
||||
'generateImage',
|
||||
generationParameters,
|
||||
esrganParameters,
|
||||
facetoolParameters
|
||||
);
|
||||
|
||||
// // we need to truncate the init_mask base64 else it takes up the whole log
|
||||
// // TODO: handle maintaining masks for reproducibility in future
|
||||
// if (generationParameters.init_mask) {
|
||||
// generationParameters.init_mask = generationParameters.init_mask
|
||||
// .substr(0, 64)
|
||||
// .concat('...');
|
||||
// }
|
||||
// if (generationParameters.init_img) {
|
||||
// generationParameters.init_img = generationParameters.init_img
|
||||
// .substr(0, 64)
|
||||
// .concat('...');
|
||||
// }
|
||||
// we need to truncate the init_mask base64 else it takes up the whole log
|
||||
// TODO: handle maintaining masks for reproducibility in future
|
||||
if (generationParameters.init_mask) {
|
||||
generationParameters.init_mask = generationParameters.init_mask
|
||||
.substr(0, 64)
|
||||
.concat('...');
|
||||
}
|
||||
if (generationParameters.init_img) {
|
||||
generationParameters.init_img = generationParameters.init_img
|
||||
.substr(0, 64)
|
||||
.concat('...');
|
||||
}
|
||||
|
||||
// dispatch(
|
||||
// addLogEntry({
|
||||
// timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
// message: `Image generation requested: ${JSON.stringify({
|
||||
// ...generationParameters,
|
||||
// ...esrganParameters,
|
||||
// ...facetoolParameters,
|
||||
// })}`,
|
||||
// })
|
||||
// );
|
||||
// },
|
||||
// emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
|
||||
// dispatch(setIsProcessing(true));
|
||||
dispatch(
|
||||
addLogEntry({
|
||||
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
message: `Image generation requested: ${JSON.stringify({
|
||||
...generationParameters,
|
||||
...esrganParameters,
|
||||
...facetoolParameters,
|
||||
})}`,
|
||||
})
|
||||
);
|
||||
},
|
||||
emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
|
||||
dispatch(setIsProcessing(true));
|
||||
|
||||
// const {
|
||||
// postprocessing: {
|
||||
// upscalingLevel,
|
||||
// upscalingDenoising,
|
||||
// upscalingStrength,
|
||||
// },
|
||||
// } = getState();
|
||||
const {
|
||||
postprocessing: {
|
||||
upscalingLevel,
|
||||
upscalingDenoising,
|
||||
upscalingStrength,
|
||||
},
|
||||
} = getState();
|
||||
|
||||
// const esrganParameters = {
|
||||
// upscale: [upscalingLevel, upscalingDenoising, upscalingStrength],
|
||||
// };
|
||||
// socketio.emit('runPostprocessing', imageToProcess, {
|
||||
// type: 'esrgan',
|
||||
// ...esrganParameters,
|
||||
// });
|
||||
// dispatch(
|
||||
// addLogEntry({
|
||||
// timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
// message: `ESRGAN upscale requested: ${JSON.stringify({
|
||||
// file: imageToProcess.url,
|
||||
// ...esrganParameters,
|
||||
// })}`,
|
||||
// })
|
||||
// );
|
||||
// },
|
||||
// emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
|
||||
// dispatch(setIsProcessing(true));
|
||||
const esrganParameters = {
|
||||
upscale: [upscalingLevel, upscalingDenoising, upscalingStrength],
|
||||
};
|
||||
socketio.emit('runPostprocessing', imageToProcess, {
|
||||
type: 'esrgan',
|
||||
...esrganParameters,
|
||||
});
|
||||
dispatch(
|
||||
addLogEntry({
|
||||
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
message: `ESRGAN upscale requested: ${JSON.stringify({
|
||||
file: imageToProcess.url,
|
||||
...esrganParameters,
|
||||
})}`,
|
||||
})
|
||||
);
|
||||
},
|
||||
emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
|
||||
dispatch(setIsProcessing(true));
|
||||
|
||||
// const {
|
||||
// postprocessing: { facetoolType, facetoolStrength, codeformerFidelity },
|
||||
// } = getState();
|
||||
const {
|
||||
postprocessing: { facetoolType, facetoolStrength, codeformerFidelity },
|
||||
} = getState();
|
||||
|
||||
// const facetoolParameters: Record<string, unknown> = {
|
||||
// facetool_strength: facetoolStrength,
|
||||
// };
|
||||
const facetoolParameters: Record<string, unknown> = {
|
||||
facetool_strength: facetoolStrength,
|
||||
};
|
||||
|
||||
// if (facetoolType === 'codeformer') {
|
||||
// facetoolParameters.codeformer_fidelity = codeformerFidelity;
|
||||
// }
|
||||
if (facetoolType === 'codeformer') {
|
||||
facetoolParameters.codeformer_fidelity = codeformerFidelity;
|
||||
}
|
||||
|
||||
// socketio.emit('runPostprocessing', imageToProcess, {
|
||||
// type: facetoolType,
|
||||
// ...facetoolParameters,
|
||||
// });
|
||||
// dispatch(
|
||||
// addLogEntry({
|
||||
// timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
// message: `Face restoration (${facetoolType}) requested: ${JSON.stringify(
|
||||
// {
|
||||
// file: imageToProcess.url,
|
||||
// ...facetoolParameters,
|
||||
// }
|
||||
// )}`,
|
||||
// })
|
||||
// );
|
||||
// },
|
||||
// emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
|
||||
// const { url, uuid, category, thumbnail } = imageToDelete;
|
||||
// dispatch(removeImage(imageToDelete));
|
||||
// socketio.emit('deleteImage', url, thumbnail, uuid, category);
|
||||
// },
|
||||
// emitRequestImages: (category: GalleryCategory) => {
|
||||
// const gallery: GalleryState = getState().gallery;
|
||||
// const { earliest_mtime } = gallery.categories[category];
|
||||
// socketio.emit('requestImages', category, earliest_mtime);
|
||||
// },
|
||||
// emitRequestNewImages: (category: GalleryCategory) => {
|
||||
// const gallery: GalleryState = getState().gallery;
|
||||
// const { latest_mtime } = gallery.categories[category];
|
||||
// socketio.emit('requestLatestImages', category, latest_mtime);
|
||||
// },
|
||||
// emitCancelProcessing: () => {
|
||||
// socketio.emit('cancel');
|
||||
// },
|
||||
// emitRequestSystemConfig: () => {
|
||||
// socketio.emit('requestSystemConfig');
|
||||
// },
|
||||
// emitSearchForModels: (modelFolder: string) => {
|
||||
// socketio.emit('searchForModels', modelFolder);
|
||||
// },
|
||||
// emitAddNewModel: (modelConfig: InvokeAI.InvokeModelConfigProps) => {
|
||||
// socketio.emit('addNewModel', modelConfig);
|
||||
// },
|
||||
// emitDeleteModel: (modelName: string) => {
|
||||
// socketio.emit('deleteModel', modelName);
|
||||
// },
|
||||
// emitConvertToDiffusers: (
|
||||
// modelToConvert: InvokeAI.InvokeModelConversionProps
|
||||
// ) => {
|
||||
// dispatch(modelConvertRequested());
|
||||
// socketio.emit('convertToDiffusers', modelToConvert);
|
||||
// },
|
||||
// emitMergeDiffusersModels: (
|
||||
// modelMergeInfo: InvokeAI.InvokeModelMergingProps
|
||||
// ) => {
|
||||
// dispatch(modelMergingRequested());
|
||||
// socketio.emit('mergeDiffusersModels', modelMergeInfo);
|
||||
// },
|
||||
// emitRequestModelChange: (modelName: string) => {
|
||||
// dispatch(modelChangeRequested());
|
||||
// socketio.emit('requestModelChange', modelName);
|
||||
// },
|
||||
// emitSaveStagingAreaImageToGallery: (url: string) => {
|
||||
// socketio.emit('requestSaveStagingAreaImageToGallery', url);
|
||||
// },
|
||||
// emitRequestEmptyTempFolder: () => {
|
||||
// socketio.emit('requestEmptyTempFolder');
|
||||
// },
|
||||
// };
|
||||
// };
|
||||
socketio.emit('runPostprocessing', imageToProcess, {
|
||||
type: facetoolType,
|
||||
...facetoolParameters,
|
||||
});
|
||||
dispatch(
|
||||
addLogEntry({
|
||||
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
message: `Face restoration (${facetoolType}) requested: ${JSON.stringify(
|
||||
{
|
||||
file: imageToProcess.url,
|
||||
...facetoolParameters,
|
||||
}
|
||||
)}`,
|
||||
})
|
||||
);
|
||||
},
|
||||
emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
|
||||
const { url, uuid, category, thumbnail } = imageToDelete;
|
||||
dispatch(removeImage(imageToDelete));
|
||||
socketio.emit('deleteImage', url, thumbnail, uuid, category);
|
||||
},
|
||||
emitRequestImages: (category: GalleryCategory) => {
|
||||
const gallery: GalleryState = getState().gallery;
|
||||
const { earliest_mtime } = gallery.categories[category];
|
||||
socketio.emit('requestImages', category, earliest_mtime);
|
||||
},
|
||||
emitRequestNewImages: (category: GalleryCategory) => {
|
||||
const gallery: GalleryState = getState().gallery;
|
||||
const { latest_mtime } = gallery.categories[category];
|
||||
socketio.emit('requestLatestImages', category, latest_mtime);
|
||||
},
|
||||
emitCancelProcessing: () => {
|
||||
socketio.emit('cancel');
|
||||
},
|
||||
emitRequestSystemConfig: () => {
|
||||
socketio.emit('requestSystemConfig');
|
||||
},
|
||||
emitSearchForModels: (modelFolder: string) => {
|
||||
socketio.emit('searchForModels', modelFolder);
|
||||
},
|
||||
emitAddNewModel: (modelConfig: InvokeAI.InvokeModelConfigProps) => {
|
||||
socketio.emit('addNewModel', modelConfig);
|
||||
},
|
||||
emitDeleteModel: (modelName: string) => {
|
||||
socketio.emit('deleteModel', modelName);
|
||||
},
|
||||
emitConvertToDiffusers: (
|
||||
modelToConvert: InvokeAI.InvokeModelConversionProps
|
||||
) => {
|
||||
dispatch(modelConvertRequested());
|
||||
socketio.emit('convertToDiffusers', modelToConvert);
|
||||
},
|
||||
emitMergeDiffusersModels: (
|
||||
modelMergeInfo: InvokeAI.InvokeModelMergingProps
|
||||
) => {
|
||||
dispatch(modelMergingRequested());
|
||||
socketio.emit('mergeDiffusersModels', modelMergeInfo);
|
||||
},
|
||||
emitRequestModelChange: (modelName: string) => {
|
||||
dispatch(modelChangeRequested());
|
||||
socketio.emit('requestModelChange', modelName);
|
||||
},
|
||||
emitSaveStagingAreaImageToGallery: (url: string) => {
|
||||
socketio.emit('requestSaveStagingAreaImageToGallery', url);
|
||||
},
|
||||
emitRequestEmptyTempFolder: () => {
|
||||
socketio.emit('requestEmptyTempFolder');
|
||||
},
|
||||
};
|
||||
};
|
||||
|
||||
// export default makeSocketIOEmitters;
|
||||
export default makeSocketIOEmitters;
|
||||
|
||||
export default {};
|
||||
|
||||
4
invokeai/frontend/web/src/app/store/actions.ts
Normal file
4
invokeai/frontend/web/src/app/store/actions.ts
Normal file
@@ -0,0 +1,4 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
|
||||
export const userInvoked = createAction<InvokeTabName>('app/userInvoked');
|
||||
8
invokeai/frontend/web/src/app/store/constants.ts
Normal file
8
invokeai/frontend/web/src/app/store/constants.ts
Normal file
@@ -0,0 +1,8 @@
|
||||
export const LOCALSTORAGE_KEYS = [
|
||||
'chakra-ui-color-mode',
|
||||
'i18nextLng',
|
||||
'ROARR_FILTER',
|
||||
'ROARR_LOG',
|
||||
];
|
||||
|
||||
export const LOCALSTORAGE_PREFIX = '@@invokeai-';
|
||||
@@ -0,0 +1,36 @@
|
||||
import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist';
|
||||
import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist';
|
||||
import { resultsPersistDenylist } from 'features/gallery/store/resultsPersistDenylist';
|
||||
import { uploadsPersistDenylist } from 'features/gallery/store/uploadsPersistDenylist';
|
||||
import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersistDenylist';
|
||||
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
|
||||
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
|
||||
import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
|
||||
import { modelsPersistDenylist } from 'features/system/store/modelsPersistDenylist';
|
||||
import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist';
|
||||
import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist';
|
||||
import { omit } from 'lodash-es';
|
||||
import { SerializeFunction } from 'redux-remember';
|
||||
|
||||
const serializationDenylist: {
|
||||
[key: string]: string[];
|
||||
} = {
|
||||
canvas: canvasPersistDenylist,
|
||||
gallery: galleryPersistDenylist,
|
||||
generation: generationPersistDenylist,
|
||||
lightbox: lightboxPersistDenylist,
|
||||
models: modelsPersistDenylist,
|
||||
nodes: nodesPersistDenylist,
|
||||
postprocessing: postprocessingPersistDenylist,
|
||||
results: resultsPersistDenylist,
|
||||
system: systemPersistDenylist,
|
||||
// config: configPersistDenyList,
|
||||
ui: uiPersistDenylist,
|
||||
uploads: uploadsPersistDenylist,
|
||||
// hotkeys: hotkeysPersistDenylist,
|
||||
};
|
||||
|
||||
export const serialize: SerializeFunction = (data, key) => {
|
||||
const result = omit(data, serializationDenylist[key]);
|
||||
return JSON.stringify(result);
|
||||
};
|
||||
@@ -0,0 +1,38 @@
|
||||
import { initialCanvasState } from 'features/canvas/store/canvasSlice';
|
||||
import { initialGalleryState } from 'features/gallery/store/gallerySlice';
|
||||
import { initialResultsState } from 'features/gallery/store/resultsSlice';
|
||||
import { initialUploadsState } from 'features/gallery/store/uploadsSlice';
|
||||
import { initialLightboxState } from 'features/lightbox/store/lightboxSlice';
|
||||
import { initialNodesState } from 'features/nodes/store/nodesSlice';
|
||||
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
||||
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
|
||||
import { initialConfigState } from 'features/system/store/configSlice';
|
||||
import { initialModelsState } from 'features/system/store/modelSlice';
|
||||
import { initialSystemState } from 'features/system/store/systemSlice';
|
||||
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
|
||||
import { initialUIState } from 'features/ui/store/uiSlice';
|
||||
import { defaultsDeep } from 'lodash-es';
|
||||
import { UnserializeFunction } from 'redux-remember';
|
||||
|
||||
const initialStates: {
|
||||
[key: string]: any;
|
||||
} = {
|
||||
canvas: initialCanvasState,
|
||||
gallery: initialGalleryState,
|
||||
generation: initialGenerationState,
|
||||
lightbox: initialLightboxState,
|
||||
models: initialModelsState,
|
||||
nodes: initialNodesState,
|
||||
postprocessing: initialPostprocessingState,
|
||||
results: initialResultsState,
|
||||
system: initialSystemState,
|
||||
config: initialConfigState,
|
||||
ui: initialUIState,
|
||||
uploads: initialUploadsState,
|
||||
hotkeys: initialHotkeysState,
|
||||
};
|
||||
|
||||
export const unserialize: UnserializeFunction = (data, key) => {
|
||||
const result = defaultsDeep(JSON.parse(data), initialStates[key]);
|
||||
return result;
|
||||
};
|
||||
@@ -0,0 +1,30 @@
|
||||
import { AnyAction } from '@reduxjs/toolkit';
|
||||
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { Graph } from 'services/api';
|
||||
|
||||
export const actionSanitizer = <A extends AnyAction>(action: A): A => {
|
||||
if (isAnyGraphBuilt(action)) {
|
||||
if (action.payload.nodes) {
|
||||
const sanitizedNodes: Graph['nodes'] = {};
|
||||
|
||||
// Sanitize nodes as needed
|
||||
forEach(action.payload.nodes, (node, key) => {
|
||||
// Don't log the whole freaking dataURL
|
||||
if (node.type === 'dataURL_image') {
|
||||
const { dataURL, ...rest } = node;
|
||||
sanitizedNodes[key] = { ...rest, dataURL: '<dataURL>' };
|
||||
} else {
|
||||
sanitizedNodes[key] = { ...node };
|
||||
}
|
||||
});
|
||||
|
||||
return {
|
||||
...action,
|
||||
payload: { ...action.payload, nodes: sanitizedNodes },
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return action;
|
||||
};
|
||||
@@ -0,0 +1,11 @@
|
||||
export const actionsDenylist = [
|
||||
'canvas/setCursorPosition',
|
||||
'canvas/setStageCoordinates',
|
||||
'canvas/setStageScale',
|
||||
'canvas/setIsDrawing',
|
||||
'canvas/setBoundingBoxCoordinates',
|
||||
'canvas/setBoundingBoxDimensions',
|
||||
'canvas/setIsDrawing',
|
||||
'canvas/addPointToCurrentLine',
|
||||
'socket/generatorProgress',
|
||||
];
|
||||
@@ -0,0 +1,3 @@
|
||||
export const stateSanitizer = <S>(state: S): S => {
|
||||
return state;
|
||||
};
|
||||
@@ -0,0 +1,54 @@
|
||||
import {
|
||||
createListenerMiddleware,
|
||||
addListener,
|
||||
ListenerEffect,
|
||||
AnyAction,
|
||||
} from '@reduxjs/toolkit';
|
||||
import type { TypedStartListening, TypedAddListener } from '@reduxjs/toolkit';
|
||||
|
||||
import type { RootState, AppDispatch } from '../../store';
|
||||
import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
|
||||
import { addImageResultReceivedListener } from './listeners/invocationComplete';
|
||||
import { addImageUploadedListener } from './listeners/imageUploaded';
|
||||
import { addRequestedImageDeletionListener } from './listeners/imageDeleted';
|
||||
import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
|
||||
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
|
||||
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
|
||||
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
|
||||
import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery';
|
||||
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
|
||||
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
|
||||
import { addCanvasMergedListener } from './listeners/canvasMerged';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
|
||||
|
||||
export const startAppListening =
|
||||
listenerMiddleware.startListening as AppStartListening;
|
||||
|
||||
export const addAppListener = addListener as TypedAddListener<
|
||||
RootState,
|
||||
AppDispatch
|
||||
>;
|
||||
|
||||
export type AppListenerEffect = ListenerEffect<
|
||||
AnyAction,
|
||||
RootState,
|
||||
AppDispatch
|
||||
>;
|
||||
|
||||
addImageUploadedListener();
|
||||
addInitialImageSelectedListener();
|
||||
addImageResultReceivedListener();
|
||||
addRequestedImageDeletionListener();
|
||||
|
||||
addUserInvokedCanvasListener();
|
||||
addUserInvokedNodesListener();
|
||||
addUserInvokedTextToImageListener();
|
||||
addUserInvokedImageToImageListener();
|
||||
|
||||
addCanvasSavedToGalleryListener();
|
||||
addCanvasDownloadedAsImageListener();
|
||||
addCanvasCopiedToClipboardListener();
|
||||
addCanvasMergedListener();
|
||||
@@ -0,0 +1,33 @@
|
||||
import { canvasCopiedToClipboard } from 'features/canvas/store/actions';
|
||||
import { startAppListening } from '..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { copyBlobToClipboard } from 'features/canvas/util/copyBlobToClipboard';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'canvasCopiedToClipboardListener' });
|
||||
|
||||
export const addCanvasCopiedToClipboardListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: canvasCopiedToClipboard,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const state = getState();
|
||||
|
||||
const blob = await getBaseLayerBlob(state);
|
||||
|
||||
if (!blob) {
|
||||
moduleLog.error('Problem getting base layer blob');
|
||||
dispatch(
|
||||
addToast({
|
||||
title: 'Problem Copying Canvas',
|
||||
description: 'Unable to export base layer',
|
||||
status: 'error',
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
copyBlobToClipboard(blob);
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -0,0 +1,33 @@
|
||||
import { canvasDownloadedAsImage } from 'features/canvas/store/actions';
|
||||
import { startAppListening } from '..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { downloadBlob } from 'features/canvas/util/downloadBlob';
|
||||
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' });
|
||||
|
||||
export const addCanvasDownloadedAsImageListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: canvasDownloadedAsImage,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const state = getState();
|
||||
|
||||
const blob = await getBaseLayerBlob(state);
|
||||
|
||||
if (!blob) {
|
||||
moduleLog.error('Problem getting base layer blob');
|
||||
dispatch(
|
||||
addToast({
|
||||
title: 'Problem Downloading Canvas',
|
||||
description: 'Unable to export base layer',
|
||||
status: 'error',
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
downloadBlob(blob, 'mergedCanvas.png');
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -0,0 +1,88 @@
|
||||
import { canvasMerged } from 'features/canvas/store/actions';
|
||||
import { startAppListening } from '..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
|
||||
import { setMergedCanvas } from 'features/canvas/store/canvasSlice';
|
||||
import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'canvasCopiedToClipboardListener' });
|
||||
|
||||
export const addCanvasMergedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: canvasMerged,
|
||||
effect: async (action, { dispatch, getState, take }) => {
|
||||
const state = getState();
|
||||
|
||||
const blob = await getBaseLayerBlob(state, true);
|
||||
|
||||
if (!blob) {
|
||||
moduleLog.error('Problem getting base layer blob');
|
||||
dispatch(
|
||||
addToast({
|
||||
title: 'Problem Merging Canvas',
|
||||
description: 'Unable to export base layer',
|
||||
status: 'error',
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const canvasBaseLayer = getCanvasBaseLayer();
|
||||
|
||||
if (!canvasBaseLayer) {
|
||||
moduleLog.error('Problem getting canvas base layer');
|
||||
dispatch(
|
||||
addToast({
|
||||
title: 'Problem Merging Canvas',
|
||||
description: 'Unable to export base layer',
|
||||
status: 'error',
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const baseLayerRect = canvasBaseLayer.getClientRect({
|
||||
relativeTo: canvasBaseLayer.getParent(),
|
||||
});
|
||||
|
||||
const filename = `mergedCanvas_${uuidv4()}.png`;
|
||||
|
||||
dispatch(
|
||||
imageUploaded({
|
||||
imageType: 'intermediates',
|
||||
formData: {
|
||||
file: new File([blob], filename, { type: 'image/png' }),
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
const [{ payload }] = await take(
|
||||
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
||||
imageUploaded.fulfilled.match(action) &&
|
||||
action.meta.arg.formData.file.name === filename
|
||||
);
|
||||
|
||||
const mergedCanvasImage = deserializeImageResponse(payload.response);
|
||||
|
||||
dispatch(
|
||||
setMergedCanvas({
|
||||
kind: 'image',
|
||||
layer: 'base',
|
||||
image: mergedCanvasImage,
|
||||
...baseLayerRect,
|
||||
})
|
||||
);
|
||||
|
||||
dispatch(
|
||||
addToast({
|
||||
title: 'Canvas Merged',
|
||||
status: 'success',
|
||||
})
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -0,0 +1,40 @@
|
||||
import { canvasSavedToGallery } from 'features/canvas/store/actions';
|
||||
import { startAppListening } from '..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' });
|
||||
|
||||
export const addCanvasSavedToGalleryListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: canvasSavedToGallery,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const state = getState();
|
||||
|
||||
const blob = await getBaseLayerBlob(state);
|
||||
|
||||
if (!blob) {
|
||||
moduleLog.error('Problem getting base layer blob');
|
||||
dispatch(
|
||||
addToast({
|
||||
title: 'Problem Saving Canvas',
|
||||
description: 'Unable to export base layer',
|
||||
status: 'error',
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
imageUploaded({
|
||||
imageType: 'results',
|
||||
formData: {
|
||||
file: new File([blob], 'mergedCanvas.png', { type: 'image/png' }),
|
||||
},
|
||||
})
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -0,0 +1,59 @@
|
||||
import { requestedImageDeletion } from 'features/gallery/store/actions';
|
||||
import { startAppListening } from '..';
|
||||
import { imageDeleted } from 'services/thunks/image';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { clamp } from 'lodash-es';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
|
||||
|
||||
export const addRequestedImageDeletionListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: requestedImageDeletion,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const image = action.payload;
|
||||
if (!image) {
|
||||
moduleLog.warn('No image provided');
|
||||
return;
|
||||
}
|
||||
|
||||
const { name, type } = image;
|
||||
|
||||
if (type !== 'uploads' && type !== 'results') {
|
||||
moduleLog.warn({ data: image }, `Invalid image type ${type}`);
|
||||
return;
|
||||
}
|
||||
|
||||
const selectedImageName = getState().gallery.selectedImage?.name;
|
||||
|
||||
if (selectedImageName === name) {
|
||||
const allIds = getState()[type].ids;
|
||||
const allEntities = getState()[type].entities;
|
||||
|
||||
const deletedImageIndex = allIds.findIndex(
|
||||
(result) => result.toString() === name
|
||||
);
|
||||
|
||||
const filteredIds = allIds.filter((id) => id.toString() !== name);
|
||||
|
||||
const newSelectedImageIndex = clamp(
|
||||
deletedImageIndex,
|
||||
0,
|
||||
filteredIds.length - 1
|
||||
);
|
||||
|
||||
const newSelectedImageId = filteredIds[newSelectedImageIndex];
|
||||
|
||||
const newSelectedImage = allEntities[newSelectedImageId];
|
||||
|
||||
if (newSelectedImageId) {
|
||||
dispatch(imageSelected(newSelectedImage));
|
||||
} else {
|
||||
dispatch(imageSelected());
|
||||
}
|
||||
}
|
||||
|
||||
dispatch(imageDeleted({ imageName: name, imageType: type }));
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -0,0 +1,46 @@
|
||||
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
|
||||
import { startAppListening } from '..';
|
||||
import { uploadAdded } from 'features/gallery/store/uploadsSlice';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||
import { resultAdded } from 'features/gallery/store/resultsSlice';
|
||||
|
||||
export const addImageUploadedListener = () => {
|
||||
startAppListening({
|
||||
predicate: (action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
||||
imageUploaded.fulfilled.match(action) &&
|
||||
action.payload.response.image_type !== 'intermediates',
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const { response } = action.payload;
|
||||
const { imageType } = action.meta.arg;
|
||||
|
||||
const state = getState();
|
||||
const image = deserializeImageResponse(response);
|
||||
|
||||
if (imageType === 'uploads') {
|
||||
dispatch(uploadAdded(image));
|
||||
|
||||
dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));
|
||||
|
||||
if (state.gallery.shouldAutoSwitchToNewImages) {
|
||||
dispatch(imageSelected(image));
|
||||
}
|
||||
|
||||
if (action.meta.arg.activeTabName === 'img2img') {
|
||||
dispatch(initialImageSelected(image));
|
||||
}
|
||||
|
||||
if (action.meta.arg.activeTabName === 'unifiedCanvas') {
|
||||
dispatch(setInitialCanvasImage(image));
|
||||
}
|
||||
}
|
||||
|
||||
if (imageType === 'results') {
|
||||
dispatch(resultAdded(image));
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -0,0 +1,54 @@
|
||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||
import { Image, isInvokeAIImage } from 'app/types/invokeai';
|
||||
import { selectResultsById } from 'features/gallery/store/resultsSlice';
|
||||
import { selectUploadsById } from 'features/gallery/store/uploadsSlice';
|
||||
import { t } from 'i18next';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { startAppListening } from '..';
|
||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||
import { makeToast } from 'app/components/Toaster';
|
||||
|
||||
export const addInitialImageSelectedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: initialImageSelected,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
if (!action.payload) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({ title: t('toast.imageNotLoadedDesc'), status: 'error' })
|
||||
)
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (isInvokeAIImage(action.payload)) {
|
||||
dispatch(initialImageChanged(action.payload));
|
||||
dispatch(addToast(makeToast(t('toast.sentToImageToImage'))));
|
||||
return;
|
||||
}
|
||||
|
||||
const { name, type } = action.payload;
|
||||
|
||||
let image: Image | undefined;
|
||||
const state = getState();
|
||||
|
||||
if (type === 'results') {
|
||||
image = selectResultsById(state, name);
|
||||
} else if (type === 'uploads') {
|
||||
image = selectUploadsById(state, name);
|
||||
}
|
||||
|
||||
if (!image) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({ title: t('toast.imageNotLoadedDesc'), status: 'error' })
|
||||
)
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(initialImageChanged(image));
|
||||
dispatch(addToast(makeToast(t('toast.sentToImageToImage'))));
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -0,0 +1,88 @@
|
||||
import { invocationComplete } from 'services/events/actions';
|
||||
import { isImageOutput } from 'services/types/guards';
|
||||
import {
|
||||
buildImageUrls,
|
||||
extractTimestampFromImageName,
|
||||
} from 'services/util/deserializeImageField';
|
||||
import { Image } from 'app/types/invokeai';
|
||||
import { resultAdded } from 'features/gallery/store/resultsSlice';
|
||||
import { imageReceived, thumbnailReceived } from 'services/thunks/image';
|
||||
import { startAppListening } from '..';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
|
||||
|
||||
const nodeDenylist = ['dataURL_image'];
|
||||
|
||||
export const addImageResultReceivedListener = () => {
|
||||
startAppListening({
|
||||
predicate: (action) => {
|
||||
if (
|
||||
invocationComplete.match(action) &&
|
||||
isImageOutput(action.payload.data.result)
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
if (!invocationComplete.match(action)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { data, shouldFetchImages } = action.payload;
|
||||
const { result, node, graph_execution_state_id } = data;
|
||||
|
||||
if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
|
||||
const name = result.image.image_name;
|
||||
const type = result.image.image_type;
|
||||
const state = getState();
|
||||
|
||||
// if we need to refetch, set URLs to placeholder for now
|
||||
const { url, thumbnail } = shouldFetchImages
|
||||
? { url: '', thumbnail: '' }
|
||||
: buildImageUrls(type, name);
|
||||
|
||||
const timestamp = extractTimestampFromImageName(name);
|
||||
|
||||
const image: Image = {
|
||||
name,
|
||||
type,
|
||||
url,
|
||||
thumbnail,
|
||||
metadata: {
|
||||
created: timestamp,
|
||||
width: result.width,
|
||||
height: result.height,
|
||||
invokeai: {
|
||||
session_id: graph_execution_state_id,
|
||||
...(node ? { node } : {}),
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
dispatch(resultAdded(image));
|
||||
|
||||
if (state.gallery.shouldAutoSwitchToNewImages) {
|
||||
dispatch(imageSelected(image));
|
||||
}
|
||||
|
||||
if (state.config.shouldFetchImages) {
|
||||
dispatch(imageReceived({ imageName: name, imageType: type }));
|
||||
dispatch(
|
||||
thumbnailReceived({
|
||||
thumbnailName: name,
|
||||
thumbnailType: type,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
graph_execution_state_id ===
|
||||
state.canvas.layerState.stagingArea.sessionId
|
||||
) {
|
||||
dispatch(addImageToStagingArea(image));
|
||||
}
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -0,0 +1,164 @@
|
||||
import { startAppListening } from '..';
|
||||
import { sessionCreated, sessionInvoked } from 'services/thunks/session';
|
||||
import { buildCanvasGraphComponents } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { canvasGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { Graph } from 'services/api';
|
||||
import {
|
||||
canvasSessionIdChanged,
|
||||
stagingAreaInitialized,
|
||||
} from 'features/canvas/store/canvasSlice';
|
||||
import { userInvoked } from 'app/store/actions';
|
||||
import { getCanvasData } from 'features/canvas/util/getCanvasData';
|
||||
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
|
||||
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
|
||||
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'invoke' });
|
||||
|
||||
/**
|
||||
* This listener is responsible for building the canvas graph and blobs when the user invokes the canvas.
|
||||
* It is also responsible for uploading the base and mask layers to the server.
|
||||
*/
|
||||
export const addUserInvokedCanvasListener = () => {
|
||||
startAppListening({
|
||||
predicate: (action): action is ReturnType<typeof userInvoked> =>
|
||||
userInvoked.match(action) && action.payload === 'unifiedCanvas',
|
||||
effect: async (action, { getState, dispatch, take }) => {
|
||||
const state = getState();
|
||||
|
||||
// Build canvas blobs
|
||||
const canvasBlobsAndImageData = await getCanvasData(state);
|
||||
|
||||
if (!canvasBlobsAndImageData) {
|
||||
moduleLog.error('Unable to create canvas data');
|
||||
return;
|
||||
}
|
||||
|
||||
const { baseBlob, baseImageData, maskBlob, maskImageData } =
|
||||
canvasBlobsAndImageData;
|
||||
|
||||
// Determine the generation mode
|
||||
const generationMode = getCanvasGenerationMode(
|
||||
baseImageData,
|
||||
maskImageData
|
||||
);
|
||||
|
||||
if (state.system.enableImageDebugging) {
|
||||
const baseDataURL = await blobToDataURL(baseBlob);
|
||||
const maskDataURL = await blobToDataURL(maskBlob);
|
||||
openBase64ImageInTab([
|
||||
{ base64: maskDataURL, caption: 'mask b64' },
|
||||
{ base64: baseDataURL, caption: 'image b64' },
|
||||
]);
|
||||
}
|
||||
|
||||
moduleLog.debug(`Generation mode: ${generationMode}`);
|
||||
|
||||
// Build the canvas graph
|
||||
const graphComponents = await buildCanvasGraphComponents(
|
||||
state,
|
||||
generationMode
|
||||
);
|
||||
|
||||
if (!graphComponents) {
|
||||
moduleLog.error('Problem building graph');
|
||||
return;
|
||||
}
|
||||
|
||||
const { rangeNode, iterateNode, baseNode, edges } = graphComponents;
|
||||
|
||||
// Upload the base layer, to be used as init image
|
||||
const baseFilename = `${uuidv4()}.png`;
|
||||
|
||||
dispatch(
|
||||
imageUploaded({
|
||||
imageType: 'intermediates',
|
||||
formData: {
|
||||
file: new File([baseBlob], baseFilename, { type: 'image/png' }),
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') {
|
||||
const [{ payload: basePayload }] = await take(
|
||||
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
||||
imageUploaded.fulfilled.match(action) &&
|
||||
action.meta.arg.formData.file.name === baseFilename
|
||||
);
|
||||
|
||||
const { image_name: baseName, image_type: baseType } =
|
||||
basePayload.response;
|
||||
|
||||
baseNode.image = {
|
||||
image_name: baseName,
|
||||
image_type: baseType,
|
||||
};
|
||||
}
|
||||
|
||||
// Upload the mask layer image
|
||||
const maskFilename = `${uuidv4()}.png`;
|
||||
|
||||
if (baseNode.type === 'inpaint') {
|
||||
dispatch(
|
||||
imageUploaded({
|
||||
imageType: 'intermediates',
|
||||
formData: {
|
||||
file: new File([maskBlob], maskFilename, { type: 'image/png' }),
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
const [{ payload: maskPayload }] = await take(
|
||||
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
||||
imageUploaded.fulfilled.match(action) &&
|
||||
action.meta.arg.formData.file.name === maskFilename
|
||||
);
|
||||
|
||||
const { image_name: maskName, image_type: maskType } =
|
||||
maskPayload.response;
|
||||
|
||||
baseNode.mask = {
|
||||
image_name: maskName,
|
||||
image_type: maskType,
|
||||
};
|
||||
}
|
||||
|
||||
// Assemble!
|
||||
const nodes: Graph['nodes'] = {
|
||||
[rangeNode.id]: rangeNode,
|
||||
[iterateNode.id]: iterateNode,
|
||||
[baseNode.id]: baseNode,
|
||||
};
|
||||
|
||||
const graph = { nodes, edges };
|
||||
|
||||
dispatch(canvasGraphBuilt(graph));
|
||||
moduleLog({ data: graph }, 'Canvas graph built');
|
||||
|
||||
// Actually create the session
|
||||
dispatch(sessionCreated({ graph }));
|
||||
|
||||
// Wait for the session to be invoked (this is just the HTTP request to start processing)
|
||||
const [{ meta }] = await take(sessionInvoked.fulfilled.match);
|
||||
|
||||
const { sessionId } = meta.arg;
|
||||
|
||||
if (!state.canvas.layerState.stagingArea.boundingBox) {
|
||||
dispatch(
|
||||
stagingAreaInitialized({
|
||||
sessionId,
|
||||
boundingBox: {
|
||||
...state.canvas.boundingBoxCoordinates,
|
||||
...state.canvas.boundingBoxDimensions,
|
||||
},
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
dispatch(canvasSessionIdChanged(sessionId));
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -0,0 +1,24 @@
|
||||
import { startAppListening } from '..';
|
||||
import { buildImageToImageGraph } from 'features/nodes/util/graphBuilders/buildImageToImageGraph';
|
||||
import { sessionCreated } from 'services/thunks/session';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { imageToImageGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { userInvoked } from 'app/store/actions';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'invoke' });
|
||||
|
||||
export const addUserInvokedImageToImageListener = () => {
|
||||
startAppListening({
|
||||
predicate: (action): action is ReturnType<typeof userInvoked> =>
|
||||
userInvoked.match(action) && action.payload === 'img2img',
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const state = getState();
|
||||
|
||||
const graph = buildImageToImageGraph(state);
|
||||
dispatch(imageToImageGraphBuilt(graph));
|
||||
moduleLog({ data: graph }, 'Image to Image graph built');
|
||||
|
||||
dispatch(sessionCreated({ graph }));
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -0,0 +1,24 @@
|
||||
import { startAppListening } from '..';
|
||||
import { sessionCreated } from 'services/thunks/session';
|
||||
import { buildNodesGraph } from 'features/nodes/util/graphBuilders/buildNodesGraph';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { nodesGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { userInvoked } from 'app/store/actions';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'invoke' });
|
||||
|
||||
export const addUserInvokedNodesListener = () => {
|
||||
startAppListening({
|
||||
predicate: (action): action is ReturnType<typeof userInvoked> =>
|
||||
userInvoked.match(action) && action.payload === 'nodes',
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const state = getState();
|
||||
|
||||
const graph = buildNodesGraph(state);
|
||||
dispatch(nodesGraphBuilt(graph));
|
||||
moduleLog({ data: graph }, 'Nodes graph built');
|
||||
|
||||
dispatch(sessionCreated({ graph }));
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -0,0 +1,24 @@
|
||||
import { startAppListening } from '..';
|
||||
import { buildTextToImageGraph } from 'features/nodes/util/graphBuilders/buildTextToImageGraph';
|
||||
import { sessionCreated } from 'services/thunks/session';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { textToImageGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { userInvoked } from 'app/store/actions';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'invoke' });
|
||||
|
||||
export const addUserInvokedTextToImageListener = () => {
|
||||
startAppListening({
|
||||
predicate: (action): action is ReturnType<typeof userInvoked> =>
|
||||
userInvoked.match(action) && action.payload === 'txt2img',
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const state = getState();
|
||||
|
||||
const graph = buildTextToImageGraph(state);
|
||||
dispatch(textToImageGraphBuilt(graph));
|
||||
moduleLog({ data: graph }, 'Text to Image graph built');
|
||||
|
||||
dispatch(sessionCreated({ graph }));
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -1,4 +0,0 @@
|
||||
import { store } from 'app/store/store';
|
||||
import { persistStore } from 'redux-persist';
|
||||
|
||||
export const persistor = persistStore(store);
|
||||
@@ -1,9 +1,12 @@
|
||||
import { combineReducers, configureStore } from '@reduxjs/toolkit';
|
||||
import {
|
||||
AnyAction,
|
||||
ThunkDispatch,
|
||||
combineReducers,
|
||||
configureStore,
|
||||
} from '@reduxjs/toolkit';
|
||||
|
||||
import { persistReducer } from 'redux-persist';
|
||||
import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web
|
||||
import { rememberReducer, rememberEnhancer } from 'redux-remember';
|
||||
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
||||
import { getPersistConfig } from 'redux-deep-persist';
|
||||
|
||||
import canvasReducer from 'features/canvas/store/canvasSlice';
|
||||
import galleryReducer from 'features/gallery/store/gallerySlice';
|
||||
@@ -19,33 +22,17 @@ import hotkeysReducer from 'features/ui/store/hotkeysSlice';
|
||||
import modelsReducer from 'features/system/store/modelSlice';
|
||||
import nodesReducer from 'features/nodes/store/nodesSlice';
|
||||
|
||||
import { canvasDenylist } from 'features/canvas/store/canvasPersistDenylist';
|
||||
import { galleryDenylist } from 'features/gallery/store/galleryPersistDenylist';
|
||||
import { generationDenylist } from 'features/parameters/store/generationPersistDenylist';
|
||||
import { lightboxDenylist } from 'features/lightbox/store/lightboxPersistDenylist';
|
||||
import { modelsDenylist } from 'features/system/store/modelsPersistDenylist';
|
||||
import { nodesDenylist } from 'features/nodes/store/nodesPersistDenylist';
|
||||
import { postprocessingDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
|
||||
import { systemDenylist } from 'features/system/store/systemPersistDenylist';
|
||||
import { uiDenylist } from 'features/ui/store/uiPersistDenylist';
|
||||
import { resultsDenylist } from 'features/gallery/store/resultsPersistDenylist';
|
||||
import { uploadsDenylist } from 'features/gallery/store/uploadsPersistDenylist';
|
||||
import { listenerMiddleware } from './middleware/listenerMiddleware';
|
||||
|
||||
/**
|
||||
* redux-persist provides an easy and reliable way to persist state across reloads.
|
||||
*
|
||||
* While we definitely want generation parameters to be persisted, there are a number
|
||||
* of things we do *not* want to be persisted across reloads:
|
||||
* - Gallery/selected image (user may add/delete images from disk between page loads)
|
||||
* - Connection/processing status
|
||||
* - Availability of external libraries like ESRGAN/GFPGAN
|
||||
*
|
||||
* These can be denylisted in redux-persist.
|
||||
*
|
||||
* The necesssary nested persistors with denylists are configured below.
|
||||
*/
|
||||
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
|
||||
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
||||
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
|
||||
|
||||
const rootReducer = combineReducers({
|
||||
import { serialize } from './enhancers/reduxRemember/serialize';
|
||||
import { unserialize } from './enhancers/reduxRemember/unserialize';
|
||||
import { LOCALSTORAGE_PREFIX } from './constants';
|
||||
|
||||
const allReducers = {
|
||||
canvas: canvasReducer,
|
||||
gallery: galleryReducer,
|
||||
generation: generationReducer,
|
||||
@@ -59,65 +46,54 @@ const rootReducer = combineReducers({
|
||||
ui: uiReducer,
|
||||
uploads: uploadsReducer,
|
||||
hotkeys: hotkeysReducer,
|
||||
});
|
||||
};
|
||||
|
||||
const rootPersistConfig = getPersistConfig({
|
||||
key: 'root',
|
||||
storage,
|
||||
rootReducer,
|
||||
blacklist: [
|
||||
...canvasDenylist,
|
||||
...galleryDenylist,
|
||||
...generationDenylist,
|
||||
...lightboxDenylist,
|
||||
...modelsDenylist,
|
||||
...nodesDenylist,
|
||||
...postprocessingDenylist,
|
||||
// ...resultsDenylist,
|
||||
'results',
|
||||
...systemDenylist,
|
||||
...uiDenylist,
|
||||
// ...uploadsDenylist,
|
||||
'uploads',
|
||||
'hotkeys',
|
||||
'config',
|
||||
],
|
||||
});
|
||||
const rootReducer = combineReducers(allReducers);
|
||||
|
||||
const persistedReducer = persistReducer(rootPersistConfig, rootReducer);
|
||||
const rememberedRootReducer = rememberReducer(rootReducer);
|
||||
|
||||
// TODO: rip the old middleware out when nodes is complete
|
||||
// export function buildMiddleware() {
|
||||
// if (import.meta.env.MODE === 'nodes' || import.meta.env.MODE === 'package') {
|
||||
// return socketMiddleware();
|
||||
// } else {
|
||||
// return socketioMiddleware();
|
||||
// }
|
||||
// }
|
||||
const rememberedKeys: (keyof typeof allReducers)[] = [
|
||||
'canvas',
|
||||
'gallery',
|
||||
'generation',
|
||||
'lightbox',
|
||||
// 'models',
|
||||
'nodes',
|
||||
'postprocessing',
|
||||
'system',
|
||||
'ui',
|
||||
// 'hotkeys',
|
||||
// 'results',
|
||||
// 'uploads',
|
||||
// 'config',
|
||||
];
|
||||
|
||||
export const store = configureStore({
|
||||
reducer: persistedReducer,
|
||||
reducer: rememberedRootReducer,
|
||||
enhancers: [
|
||||
rememberEnhancer(window.localStorage, rememberedKeys, {
|
||||
persistDebounce: 300,
|
||||
serialize,
|
||||
unserialize,
|
||||
prefix: LOCALSTORAGE_PREFIX,
|
||||
}),
|
||||
],
|
||||
middleware: (getDefaultMiddleware) =>
|
||||
getDefaultMiddleware({
|
||||
immutableCheck: false,
|
||||
serializableCheck: false,
|
||||
}).concat(dynamicMiddlewares),
|
||||
})
|
||||
.concat(dynamicMiddlewares)
|
||||
.prepend(listenerMiddleware.middleware),
|
||||
devTools: {
|
||||
// Uncommenting these very rapidly called actions makes the redux dev tools output much more readable
|
||||
actionsDenylist: [
|
||||
'canvas/setCursorPosition',
|
||||
'canvas/setStageCoordinates',
|
||||
'canvas/setStageScale',
|
||||
'canvas/setIsDrawing',
|
||||
'canvas/setBoundingBoxCoordinates',
|
||||
'canvas/setBoundingBoxDimensions',
|
||||
'canvas/setIsDrawing',
|
||||
'canvas/addPointToCurrentLine',
|
||||
'socket/generatorProgress',
|
||||
],
|
||||
actionsDenylist,
|
||||
actionSanitizer,
|
||||
stateSanitizer,
|
||||
trace: true,
|
||||
},
|
||||
});
|
||||
|
||||
export type AppGetState = typeof store.getState;
|
||||
export type RootState = ReturnType<typeof store.getState>;
|
||||
export type AppThunkDispatch = ThunkDispatch<RootState, any, AnyAction>;
|
||||
export type AppDispatch = typeof store.dispatch;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { TypedUseSelectorHook, useDispatch, useSelector } from 'react-redux';
|
||||
import { AppDispatch, RootState } from 'app/store/store';
|
||||
import { AppThunkDispatch, RootState } from 'app/store/store';
|
||||
|
||||
// Use throughout your app instead of plain `useDispatch` and `useSelector`
|
||||
export const useAppDispatch: () => AppDispatch = useDispatch;
|
||||
export const useAppDispatch = () => useDispatch<AppThunkDispatch>();
|
||||
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
import { isEqual } from 'lodash-es';
|
||||
|
||||
export const defaultSelectorOptions = {
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
};
|
||||
@@ -12,12 +12,10 @@
|
||||
* 'gfpgan'.
|
||||
*/
|
||||
|
||||
import { GalleryCategory } from 'features/gallery/store/gallerySlice';
|
||||
import { FacetoolType } from 'features/parameters/store/postprocessingSlice';
|
||||
import { SelectedImage } from 'features/parameters/store/actions';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import { IRect } from 'konva/lib/types';
|
||||
import { ImageResponseMetadata, ImageType } from 'services/api';
|
||||
import { AnyInvocation } from 'services/events/types';
|
||||
import { O } from 'ts-toolbelt';
|
||||
|
||||
/**
|
||||
@@ -49,15 +47,21 @@ export type CommonGeneratedImageMetadata = {
|
||||
postprocessing: null | Array<ESRGANMetadata | FacetoolMetadata>;
|
||||
sampler:
|
||||
| 'ddim'
|
||||
| 'k_dpm_2_a'
|
||||
| 'k_dpm_2'
|
||||
| 'k_dpmpp_2_a'
|
||||
| 'k_dpmpp_2'
|
||||
| 'k_euler_a'
|
||||
| 'k_euler'
|
||||
| 'k_heun'
|
||||
| 'k_lms'
|
||||
| 'plms';
|
||||
| 'ddpm'
|
||||
| 'deis'
|
||||
| 'lms'
|
||||
| 'pndm'
|
||||
| 'heun'
|
||||
| 'heun_k'
|
||||
| 'euler'
|
||||
| 'euler_k'
|
||||
| 'euler_a'
|
||||
| 'kdpm_2'
|
||||
| 'kdpm_2_a'
|
||||
| 'dpmpp_2s'
|
||||
| 'dpmpp_2m'
|
||||
| 'dpmpp_2m_k'
|
||||
| 'unipc';
|
||||
prompt: Prompt;
|
||||
seed: number;
|
||||
variations: SeedWeights;
|
||||
@@ -126,6 +130,14 @@ export type Image = {
|
||||
metadata: ImageResponseMetadata;
|
||||
};
|
||||
|
||||
export const isInvokeAIImage = (obj: Image | SelectedImage): obj is Image => {
|
||||
if ('url' in obj && 'thumbnail' in obj) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
/**
|
||||
* Types related to the system status.
|
||||
*/
|
||||
@@ -270,7 +282,7 @@ export type FoundModelResponse = {
|
||||
|
||||
// export type SystemConfigResponse = SystemConfig;
|
||||
|
||||
export type ImageResultResponse = Omit<_Image, 'uuid'> & {
|
||||
export type ImageResultResponse = Omit<Image, 'uuid'> & {
|
||||
boundingBox?: IRect;
|
||||
generationMode: InvokeTabName;
|
||||
};
|
||||
@@ -315,11 +327,11 @@ export type AppFeature =
|
||||
/**
|
||||
* A disable-able Stable Diffusion feature
|
||||
*/
|
||||
export type StableDiffusionFeature =
|
||||
| 'noiseConfig'
|
||||
| 'variations'
|
||||
export type SDFeature =
|
||||
| 'noise'
|
||||
| 'variation'
|
||||
| 'symmetry'
|
||||
| 'tiling'
|
||||
| 'seamless'
|
||||
| 'hires';
|
||||
|
||||
/**
|
||||
@@ -337,6 +349,7 @@ export type AppConfig = {
|
||||
shouldFetchImages: boolean;
|
||||
disabledTabs: InvokeTabName[];
|
||||
disabledFeatures: AppFeature[];
|
||||
disabledSDFeatures: SDFeature[];
|
||||
canRestoreDeletedImagesFromBin: boolean;
|
||||
sd: {
|
||||
iterations: {
|
||||
|
||||
61
invokeai/frontend/web/src/common/components/IAICollapse.tsx
Normal file
61
invokeai/frontend/web/src/common/components/IAICollapse.tsx
Normal file
@@ -0,0 +1,61 @@
|
||||
import { ChevronUpIcon } from '@chakra-ui/icons';
|
||||
import { Box, Collapse, Flex, Spacer, Switch } from '@chakra-ui/react';
|
||||
import { PropsWithChildren, memo } from 'react';
|
||||
|
||||
export type IAIToggleCollapseProps = PropsWithChildren & {
|
||||
label: string;
|
||||
isOpen: boolean;
|
||||
onToggle: () => void;
|
||||
withSwitch?: boolean;
|
||||
};
|
||||
|
||||
const IAICollapse = (props: IAIToggleCollapseProps) => {
|
||||
const { label, isOpen, onToggle, children, withSwitch = false } = props;
|
||||
return (
|
||||
<Box>
|
||||
<Flex
|
||||
onClick={onToggle}
|
||||
sx={{
|
||||
alignItems: 'center',
|
||||
p: 2,
|
||||
px: 4,
|
||||
borderTopRadius: 'base',
|
||||
borderBottomRadius: isOpen ? 0 : 'base',
|
||||
bg: isOpen ? 'base.750' : 'base.800',
|
||||
color: 'base.100',
|
||||
_hover: {
|
||||
bg: isOpen ? 'base.700' : 'base.750',
|
||||
},
|
||||
fontSize: 'sm',
|
||||
fontWeight: 600,
|
||||
cursor: 'pointer',
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: 'normal',
|
||||
userSelect: 'none',
|
||||
}}
|
||||
>
|
||||
{label}
|
||||
<Spacer />
|
||||
{withSwitch && <Switch isChecked={isOpen} pointerEvents="none" />}
|
||||
{!withSwitch && (
|
||||
<ChevronUpIcon
|
||||
sx={{
|
||||
w: '1rem',
|
||||
h: '1rem',
|
||||
transform: isOpen ? 'rotate(0deg)' : 'rotate(180deg)',
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: 'normal',
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
<Collapse in={isOpen} animateOpacity>
|
||||
<Box sx={{ p: 4, borderBottomRadius: 'base', bg: 'base.800' }}>
|
||||
{children}
|
||||
</Box>
|
||||
</Collapse>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(IAICollapse);
|
||||
172
invokeai/frontend/web/src/common/components/IAICustomSelect.tsx
Normal file
172
invokeai/frontend/web/src/common/components/IAICustomSelect.tsx
Normal file
@@ -0,0 +1,172 @@
|
||||
import { CheckIcon } from '@chakra-ui/icons';
|
||||
import {
|
||||
Box,
|
||||
Flex,
|
||||
FlexProps,
|
||||
FormControl,
|
||||
FormControlProps,
|
||||
FormLabel,
|
||||
Grid,
|
||||
GridItem,
|
||||
List,
|
||||
ListItem,
|
||||
Select,
|
||||
Text,
|
||||
Tooltip,
|
||||
TooltipProps,
|
||||
} from '@chakra-ui/react';
|
||||
import { autoUpdate, offset, shift, useFloating } from '@floating-ui/react-dom';
|
||||
import { useSelect } from 'downshift';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
|
||||
import { memo } from 'react';
|
||||
|
||||
type IAICustomSelectProps = {
|
||||
label?: string;
|
||||
items: string[];
|
||||
selectedItem: string;
|
||||
setSelectedItem: (v: string | null | undefined) => void;
|
||||
withCheckIcon?: boolean;
|
||||
formControlProps?: FormControlProps;
|
||||
buttonProps?: FlexProps;
|
||||
tooltip?: string;
|
||||
tooltipProps?: Omit<TooltipProps, 'children'>;
|
||||
};
|
||||
|
||||
const IAICustomSelect = (props: IAICustomSelectProps) => {
|
||||
const {
|
||||
label,
|
||||
items,
|
||||
setSelectedItem,
|
||||
selectedItem,
|
||||
withCheckIcon,
|
||||
formControlProps,
|
||||
tooltip,
|
||||
buttonProps,
|
||||
tooltipProps,
|
||||
} = props;
|
||||
|
||||
const {
|
||||
isOpen,
|
||||
getToggleButtonProps,
|
||||
getLabelProps,
|
||||
getMenuProps,
|
||||
highlightedIndex,
|
||||
getItemProps,
|
||||
} = useSelect({
|
||||
items,
|
||||
selectedItem,
|
||||
onSelectedItemChange: ({ selectedItem: newSelectedItem }) =>
|
||||
setSelectedItem(newSelectedItem),
|
||||
});
|
||||
|
||||
const { refs, floatingStyles } = useFloating<HTMLButtonElement>({
|
||||
whileElementsMounted: autoUpdate,
|
||||
middleware: [offset(4), shift({ crossAxis: true, padding: 8 })],
|
||||
});
|
||||
|
||||
return (
|
||||
<FormControl sx={{ w: 'full' }} {...formControlProps}>
|
||||
{label && (
|
||||
<FormLabel
|
||||
{...getLabelProps()}
|
||||
onClick={() => {
|
||||
refs.floating.current && refs.floating.current.focus();
|
||||
}}
|
||||
>
|
||||
{label}
|
||||
</FormLabel>
|
||||
)}
|
||||
<Tooltip label={tooltip} {...tooltipProps}>
|
||||
<Select
|
||||
{...getToggleButtonProps({ ref: refs.setReference })}
|
||||
{...buttonProps}
|
||||
as={Flex}
|
||||
sx={{
|
||||
alignItems: 'center',
|
||||
userSelect: 'none',
|
||||
cursor: 'pointer',
|
||||
}}
|
||||
>
|
||||
<Text sx={{ fontSize: 'sm', fontWeight: 500, color: 'base.100' }}>
|
||||
{selectedItem}
|
||||
</Text>
|
||||
</Select>
|
||||
</Tooltip>
|
||||
<Box {...getMenuProps()}>
|
||||
{isOpen && (
|
||||
<List
|
||||
as={Flex}
|
||||
ref={refs.setFloating}
|
||||
sx={{
|
||||
...floatingStyles,
|
||||
width: 'max-content',
|
||||
top: 0,
|
||||
left: 0,
|
||||
flexDirection: 'column',
|
||||
zIndex: 1,
|
||||
bg: 'base.800',
|
||||
borderRadius: 'base',
|
||||
border: '1px',
|
||||
borderColor: 'base.700',
|
||||
shadow: 'dark-lg',
|
||||
py: 2,
|
||||
px: 0,
|
||||
h: 'fit-content',
|
||||
maxH: 64,
|
||||
}}
|
||||
>
|
||||
<OverlayScrollbarsComponent>
|
||||
{items.map((item, index) => (
|
||||
<ListItem
|
||||
sx={{
|
||||
bg: highlightedIndex === index ? 'base.700' : undefined,
|
||||
py: 1,
|
||||
paddingInlineStart: 3,
|
||||
paddingInlineEnd: 6,
|
||||
cursor: 'pointer',
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: '0.15s',
|
||||
}}
|
||||
key={`${item}${index}`}
|
||||
{...getItemProps({ item, index })}
|
||||
>
|
||||
{withCheckIcon ? (
|
||||
<Grid gridTemplateColumns="1.25rem auto">
|
||||
<GridItem>
|
||||
{selectedItem === item && <CheckIcon boxSize={2} />}
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<Text
|
||||
sx={{
|
||||
fontSize: 'sm',
|
||||
color: 'base.100',
|
||||
fontWeight: 500,
|
||||
}}
|
||||
>
|
||||
{item}
|
||||
</Text>
|
||||
</GridItem>
|
||||
</Grid>
|
||||
) : (
|
||||
<Text
|
||||
sx={{
|
||||
fontSize: 'sm',
|
||||
color: 'base.100',
|
||||
fontWeight: 500,
|
||||
}}
|
||||
>
|
||||
{item}
|
||||
</Text>
|
||||
)}
|
||||
</ListItem>
|
||||
))}
|
||||
</OverlayScrollbarsComponent>
|
||||
</List>
|
||||
)}
|
||||
</Box>
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(IAICustomSelect);
|
||||
@@ -5,6 +5,7 @@ import {
|
||||
Input,
|
||||
InputProps,
|
||||
} from '@chakra-ui/react';
|
||||
import { stopPastePropagation } from 'common/util/stopPastePropagation';
|
||||
import { ChangeEvent, memo } from 'react';
|
||||
|
||||
interface IAIInputProps extends InputProps {
|
||||
@@ -31,7 +32,7 @@ const IAIInput = (props: IAIInputProps) => {
|
||||
{...formControlProps}
|
||||
>
|
||||
{label !== '' && <FormLabel>{label}</FormLabel>}
|
||||
<Input {...rest} />
|
||||
<Input {...rest} onPaste={stopPastePropagation} />
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -14,6 +14,7 @@ import {
|
||||
Tooltip,
|
||||
TooltipProps,
|
||||
} from '@chakra-ui/react';
|
||||
import { stopPastePropagation } from 'common/util/stopPastePropagation';
|
||||
import { clamp } from 'lodash-es';
|
||||
|
||||
import { FocusEvent, memo, useEffect, useState } from 'react';
|
||||
@@ -125,6 +126,7 @@ const IAINumberInput = (props: Props) => {
|
||||
onChange={handleOnChange}
|
||||
onBlur={handleBlur}
|
||||
{...rest}
|
||||
onPaste={stopPastePropagation}
|
||||
>
|
||||
<NumberInputField {...numberInputFieldProps} />
|
||||
{showStepper && (
|
||||
|
||||
@@ -27,7 +27,7 @@ const IAIPopover = (props: IAIPopoverProps) => {
|
||||
return (
|
||||
<Popover isLazy={isLazy} {...rest}>
|
||||
<PopoverTrigger>{triggerComponent}</PopoverTrigger>
|
||||
<PopoverContent>
|
||||
<PopoverContent shadow="dark-lg">
|
||||
{hasArrow && <PopoverArrow />}
|
||||
{children}
|
||||
</PopoverContent>
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user