mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Compare commits
87 Commits
v3.1.1
...
feat/batch
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
732780c376 | ||
|
|
ed7deee8f1 | ||
|
|
d22c4734ee | ||
|
|
3e7dadd7b3 | ||
|
|
b777dba430 | ||
|
|
531c3bb1e2 | ||
|
|
331743ca0c | ||
|
|
13429e66b3 | ||
|
|
2185c85287 | ||
|
|
e8a4a654ac | ||
|
|
26f9ac9f21 | ||
|
|
8d78af5db7 | ||
|
|
babd26feab | ||
|
|
e9b26e5e7d | ||
|
|
6b946f53c4 | ||
|
|
70479b9827 | ||
|
|
35099dcdd8 | ||
|
|
670600a863 | ||
|
|
6d5403e19d | ||
|
|
0f7695a081 | ||
|
|
d567d9f804 | ||
|
|
68f6140685 | ||
|
|
be971617e3 | ||
|
|
1652143671 | ||
|
|
88ae19a768 | ||
|
|
50816432dc | ||
|
|
b98c9b516a | ||
|
|
a15a5bc3b8 | ||
|
|
018ff56314 | ||
|
|
137fbacb92 | ||
|
|
4b6d9a73ed | ||
|
|
3e26214b83 | ||
|
|
0282f46c71 | ||
|
|
99e03fe92e | ||
|
|
cb65526880 | ||
|
|
59bc9ed399 | ||
|
|
e62d5478fd | ||
|
|
2cf0d61b3e | ||
|
|
cc3c2756bd | ||
|
|
67cf594bb3 | ||
|
|
c5b963f1a6 | ||
|
|
4d2dd6bb10 | ||
|
|
7e4beab4ff | ||
|
|
e16b5f7cdc | ||
|
|
1f355d5810 | ||
|
|
df7370f9d9 | ||
|
|
5bec64d65b | ||
|
|
8cf9bd47b2 | ||
|
|
c91621b46c | ||
|
|
f246b236dd | ||
|
|
f7277a8b21 | ||
|
|
796ee1246b | ||
|
|
29fceb960d | ||
|
|
796ff34c8a | ||
|
|
d6a5c2dbe3 | ||
|
|
ef8dc2e8c5 | ||
|
|
314891a125 | ||
|
|
2d3094f988 | ||
|
|
abf09fc8fa | ||
|
|
15e7ca1baa | ||
|
|
6cb90e01de | ||
|
|
faa4574970 | ||
|
|
cc5755d5b1 | ||
|
|
85105fc070 | ||
|
|
ed40aee4c5 | ||
|
|
f8d8b16267 | ||
|
|
846e52f2ea | ||
|
|
69f541075c | ||
|
|
1debc31e3d | ||
|
|
1d798d4119 | ||
|
|
c1dde83abb | ||
|
|
280ac15da2 | ||
|
|
e751f7d815 | ||
|
|
e26e4740b3 | ||
|
|
835d76af45 | ||
|
|
a3e099bbc0 | ||
|
|
a61685696f | ||
|
|
02aa93c67c | ||
|
|
55b921818d | ||
|
|
bb681a8a11 | ||
|
|
74e0fbce42 | ||
|
|
f080c56771 | ||
|
|
d2f968b902 | ||
|
|
e81601acf3 | ||
|
|
7073dc0d5d | ||
|
|
d090be60e8 | ||
|
|
4bad96d9d6 |
12
.github/ISSUE_TEMPLATE/FEATURE_REQUEST.yml
vendored
12
.github/ISSUE_TEMPLATE/FEATURE_REQUEST.yml
vendored
@@ -1,5 +1,5 @@
|
||||
name: Feature Request
|
||||
description: Contribute a idea or request a new feature
|
||||
description: Commit a idea or Request a new feature
|
||||
title: '[enhancement]: '
|
||||
labels: ['enhancement']
|
||||
# assignees:
|
||||
@@ -9,14 +9,14 @@ body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Thanks for taking the time to fill out this feature request!
|
||||
Thanks for taking the time to fill out this Feature request!
|
||||
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Is there an existing issue for this?
|
||||
description: |
|
||||
Please make use of the [search function](https://github.com/invoke-ai/InvokeAI/labels/enhancement)
|
||||
to see if a similar issue already exists for the feature you want to request
|
||||
to see if a simmilar issue already exists for the feature you want to request
|
||||
options:
|
||||
- label: I have searched the existing issues
|
||||
required: true
|
||||
@@ -36,7 +36,7 @@ body:
|
||||
label: What should this feature add?
|
||||
description: Please try to explain the functionality this feature should add
|
||||
placeholder: |
|
||||
Instead of one huge text field, it would be nice to have forms for bug-reports, feature-requests, ...
|
||||
Instead of one huge textfield, it would be nice to have forms for bug-reports, feature-requests, ...
|
||||
Great benefits with automatic labeling, assigning and other functionalitys not available in that form
|
||||
via old-fashioned markdown-templates. I would also love to see the use of a moderator bot 🤖 like
|
||||
https://github.com/marketplace/actions/issue-moderator-with-commands to auto close old issues and other things
|
||||
@@ -51,6 +51,6 @@ body:
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Additional Content
|
||||
label: Aditional Content
|
||||
description: Add any other context or screenshots about the feature request here.
|
||||
placeholder: This is a mockup of the design how I imagine it <screenshot>
|
||||
placeholder: This is a Mockup of the design how I imagine it <screenshot>
|
||||
|
||||
@@ -57,30 +57,6 @@ familiar with containerization technologies such as Docker.
|
||||
For downloads and instructions, visit the [NVIDIA CUDA Container
|
||||
Runtime Site](https://developer.nvidia.com/nvidia-container-runtime)
|
||||
|
||||
### cuDNN Installation for 40/30 Series Optimization* (Optional)
|
||||
|
||||
1. Find the InvokeAI folder
|
||||
2. Click on .venv folder - e.g., YourInvokeFolderHere\\.venv
|
||||
3. Click on Lib folder - e.g., YourInvokeFolderHere\\.venv\Lib
|
||||
4. Click on site-packages folder - e.g., YourInvokeFolderHere\\.venv\Lib\site-packages
|
||||
5. Click on Torch directory - e.g., YourInvokeFolderHere\InvokeAI\\.venv\Lib\site-packages\torch
|
||||
6. Click on the lib folder - e.g., YourInvokeFolderHere\\.venv\Lib\site-packages\torch\lib
|
||||
7. Copy everything inside the folder and save it elsewhere as a backup.
|
||||
8. Go to __https://developer.nvidia.com/cudnn__
|
||||
9. Login or create an Account.
|
||||
10. Choose the newer version of cuDNN. **Note:**
|
||||
There are two versions, 11.x or 12.x for the differents architectures(Turing,Maxwell Etc...) of GPUs.
|
||||
You can find which version you should download from [this link](https://docs.nvidia.com/deeplearning/cudnn/support-matrix/index.html).
|
||||
13. Download the latest version and extract it from the download location
|
||||
14. Find the bin folder E\cudnn-windows-x86_64-__Whatever Version__\bin
|
||||
15. Copy and paste the .dll files into YourInvokeFolderHere\\.venv\Lib\site-packages\torch\lib **Make sure to copy, and not move the files**
|
||||
16. If prompted, replace any existing files
|
||||
|
||||
**Notes:**
|
||||
* If no change is seen or any issues are encountered, follow the same steps as above and paste the torch/lib backup folder you made earlier and replace it. If you didn't make a backup, you can also uninstall and reinstall torch through the command line to repair this folder.
|
||||
* This optimization is intended for the newer version of graphics card (40/30 series) but results have been seen with older graphics card.
|
||||
|
||||
|
||||
### Torch Installation
|
||||
|
||||
When installing torch and torchvision manually with `pip`, remember to provide
|
||||
|
||||
@@ -14,7 +14,7 @@ fi
|
||||
VERSION=$(cd ..; python -c "from invokeai.version import __version__ as version; print(version)")
|
||||
PATCH=""
|
||||
VERSION="v${VERSION}${PATCH}"
|
||||
LATEST_TAG="v3-latest"
|
||||
LATEST_TAG="v3.0-latest"
|
||||
|
||||
echo Building installer for version $VERSION
|
||||
echo "Be certain that you're in the 'installer' directory before continuing."
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from logging import Logger
|
||||
import sqlite3
|
||||
from invokeai.app.services.board_image_record_storage import (
|
||||
SqliteBoardImageRecordStorage,
|
||||
)
|
||||
@@ -28,6 +29,8 @@ from ..services.invoker import Invoker
|
||||
from ..services.processor import DefaultInvocationProcessor
|
||||
from ..services.sqlite import SqliteItemStorage
|
||||
from ..services.model_manager_service import ModelManagerService
|
||||
from ..services.batch_manager import BatchManager
|
||||
from ..services.batch_manager_storage import SqliteBatchProcessStorage
|
||||
from ..services.invocation_stats import InvocationStatsService
|
||||
from .events import FastAPIEventService
|
||||
|
||||
@@ -71,18 +74,18 @@ class ApiDependencies:
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
db_location = str(db_path)
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
)
|
||||
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
|
||||
|
||||
urls = LocalUrlService()
|
||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||
image_record_storage = SqliteImageRecordStorage(conn=db_conn)
|
||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||
names = SimpleNameService()
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
||||
|
||||
board_record_storage = SqliteBoardRecordStorage(db_location)
|
||||
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
|
||||
board_record_storage = SqliteBoardRecordStorage(conn=db_conn)
|
||||
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn)
|
||||
|
||||
boards = BoardService(
|
||||
services=BoardServiceDependencies(
|
||||
@@ -116,15 +119,19 @@ class ApiDependencies:
|
||||
)
|
||||
)
|
||||
|
||||
batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
|
||||
batch_manager = BatchManager(batch_manager_storage)
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=ModelManagerService(config, logger),
|
||||
events=events,
|
||||
latents=latents,
|
||||
images=images,
|
||||
batch_manager=batch_manager,
|
||||
boards=boards,
|
||||
board_images=board_images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
configuration=config,
|
||||
|
||||
106
invokeai/app/api/routers/batches.py
Normal file
106
invokeai/app/api/routers/batches.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from fastapi import Body, HTTPException, Path, Response
|
||||
from fastapi.routing import APIRouter
|
||||
|
||||
from invokeai.app.services.batch_manager_storage import BatchSession, BatchSessionNotFoundException
|
||||
|
||||
# Importing * is bad karma but needed here for node detection
|
||||
from ...invocations import * # noqa: F401 F403
|
||||
from ...services.batch_manager import Batch, BatchProcessResponse
|
||||
from ...services.graph import Graph
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
batches_router = APIRouter(prefix="/v1/batches", tags=["sessions"])
|
||||
|
||||
|
||||
@batches_router.post(
|
||||
"/",
|
||||
operation_id="create_batch",
|
||||
responses={
|
||||
200: {"model": BatchProcessResponse},
|
||||
400: {"description": "Invalid json"},
|
||||
},
|
||||
)
|
||||
async def create_batch(
|
||||
graph: Graph = Body(description="The graph to initialize the session with"),
|
||||
batch: Batch = Body(description="Batch config to apply to the given graph"),
|
||||
) -> BatchProcessResponse:
|
||||
"""Creates a batch process"""
|
||||
return ApiDependencies.invoker.services.batch_manager.create_batch_process(batch, graph)
|
||||
|
||||
|
||||
@batches_router.put(
|
||||
"/b/{batch_process_id}/invoke",
|
||||
operation_id="start_batch",
|
||||
responses={
|
||||
202: {"description": "Batch process started"},
|
||||
404: {"description": "Batch session not found"},
|
||||
},
|
||||
)
|
||||
async def start_batch(
|
||||
batch_process_id: str = Path(description="ID of Batch to start"),
|
||||
) -> Response:
|
||||
"""Executes a batch process"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_process_id)
|
||||
return Response(status_code=202)
|
||||
except BatchSessionNotFoundException:
|
||||
raise HTTPException(status_code=404, detail="Batch session not found")
|
||||
|
||||
|
||||
@batches_router.delete(
|
||||
"/b/{batch_process_id}",
|
||||
operation_id="cancel_batch",
|
||||
responses={202: {"description": "The batch is canceled"}},
|
||||
)
|
||||
async def cancel_batch(
|
||||
batch_process_id: str = Path(description="The id of the batch process to cancel"),
|
||||
) -> Response:
|
||||
"""Cancels a batch process"""
|
||||
ApiDependencies.invoker.services.batch_manager.cancel_batch_process(batch_process_id)
|
||||
return Response(status_code=202)
|
||||
|
||||
|
||||
@batches_router.get(
|
||||
"/incomplete",
|
||||
operation_id="list_incomplete_batches",
|
||||
responses={200: {"model": list[BatchProcessResponse]}},
|
||||
)
|
||||
async def list_incomplete_batches() -> list[BatchProcessResponse]:
|
||||
"""Lists incomplete batch processes"""
|
||||
return ApiDependencies.invoker.services.batch_manager.get_incomplete_batch_processes()
|
||||
|
||||
|
||||
@batches_router.get(
|
||||
"/",
|
||||
operation_id="list_batches",
|
||||
responses={200: {"model": list[BatchProcessResponse]}},
|
||||
)
|
||||
async def list_batches() -> list[BatchProcessResponse]:
|
||||
"""Lists all batch processes"""
|
||||
return ApiDependencies.invoker.services.batch_manager.get_batch_processes()
|
||||
|
||||
|
||||
@batches_router.get(
|
||||
"/b/{batch_process_id}",
|
||||
operation_id="get_batch",
|
||||
responses={200: {"model": BatchProcessResponse}},
|
||||
)
|
||||
async def get_batch(
|
||||
batch_process_id: str = Path(description="The id of the batch process to get"),
|
||||
) -> BatchProcessResponse:
|
||||
"""Gets a Batch Process"""
|
||||
return ApiDependencies.invoker.services.batch_manager.get_batch(batch_process_id)
|
||||
|
||||
|
||||
@batches_router.get(
|
||||
"/b/{batch_process_id}/sessions",
|
||||
operation_id="get_batch_sessions",
|
||||
responses={200: {"model": list[BatchSession]}},
|
||||
)
|
||||
async def get_batch_sessions(
|
||||
batch_process_id: str = Path(description="The id of the batch process to get"),
|
||||
) -> list[BatchSession]:
|
||||
"""Gets a list of batch sessions for a given batch process"""
|
||||
return ApiDependencies.invoker.services.batch_manager.get_sessions(batch_process_id)
|
||||
@@ -9,13 +9,7 @@ from pydantic.fields import Field
|
||||
# Importing * is bad karma but needed here for node detection
|
||||
from ...invocations import * # noqa: F401 F403
|
||||
from ...invocations.baseinvocation import BaseInvocation
|
||||
from ...services.graph import (
|
||||
Edge,
|
||||
EdgeConnection,
|
||||
Graph,
|
||||
GraphExecutionState,
|
||||
NodeAlreadyExecutedError,
|
||||
)
|
||||
from ...services.graph import Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
|
||||
from ...services.item_storage import PaginatedResults
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
|
||||
@@ -13,11 +13,15 @@ class SocketIO:
|
||||
|
||||
def __init__(self, app: FastAPI):
|
||||
self.__sio = SocketManager(app=app)
|
||||
self.__sio.on("subscribe", handler=self._handle_sub)
|
||||
self.__sio.on("unsubscribe", handler=self._handle_unsub)
|
||||
|
||||
self.__sio.on("subscribe_session", handler=self._handle_sub_session)
|
||||
self.__sio.on("unsubscribe_session", handler=self._handle_unsub_session)
|
||||
local_handler.register(event_name=EventServiceBase.session_event, _func=self._handle_session_event)
|
||||
|
||||
self.__sio.on("subscribe_batch", handler=self._handle_sub_batch)
|
||||
self.__sio.on("unsubscribe_batch", handler=self._handle_unsub_batch)
|
||||
local_handler.register(event_name=EventServiceBase.batch_event, _func=self._handle_batch_event)
|
||||
|
||||
async def _handle_session_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
event=event[1]["event"],
|
||||
@@ -25,12 +29,25 @@ class SocketIO:
|
||||
room=event[1]["data"]["graph_execution_state_id"],
|
||||
)
|
||||
|
||||
async def _handle_sub(self, sid, data, *args, **kwargs):
|
||||
async def _handle_sub_session(self, sid, data, *args, **kwargs):
|
||||
if "session" in data:
|
||||
self.__sio.enter_room(sid, data["session"])
|
||||
|
||||
# @app.sio.on('unsubscribe')
|
||||
|
||||
async def _handle_unsub(self, sid, data, *args, **kwargs):
|
||||
async def _handle_unsub_session(self, sid, data, *args, **kwargs):
|
||||
if "session" in data:
|
||||
self.__sio.leave_room(sid, data["session"])
|
||||
|
||||
async def _handle_batch_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
event=event[1]["event"],
|
||||
data=event[1]["data"],
|
||||
room=event[1]["data"]["batch_id"],
|
||||
)
|
||||
|
||||
async def _handle_sub_batch(self, sid, data, *args, **kwargs):
|
||||
if "batch_id" in data:
|
||||
self.__sio.enter_room(sid, data["batch_id"])
|
||||
|
||||
async def _handle_unsub_batch(self, sid, data, *args, **kwargs):
|
||||
if "batch_id" in data:
|
||||
self.__sio.enter_room(sid, data["batch_id"])
|
||||
|
||||
@@ -24,7 +24,7 @@ import invokeai.frontend.web as web_dir
|
||||
import mimetypes
|
||||
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import sessions, models, images, boards, board_images, app_info
|
||||
from .api.routers import sessions, batches, models, images, boards, board_images, app_info
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
|
||||
|
||||
@@ -90,6 +90,8 @@ async def shutdown_event():
|
||||
|
||||
app.include_router(sessions.session_router, prefix="/api")
|
||||
|
||||
app.include_router(batches.batches_router, prefix="/api")
|
||||
|
||||
app.include_router(models.models_router, prefix="/api")
|
||||
|
||||
app.include_router(images.images_router, prefix="/api")
|
||||
|
||||
@@ -5,6 +5,7 @@ import re
|
||||
import shlex
|
||||
import sys
|
||||
import time
|
||||
import sqlite3
|
||||
from typing import Union, get_type_hints, Optional
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
@@ -29,6 +30,8 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
from invokeai.app.services.batch_manager import BatchManager
|
||||
from invokeai.app.services.batch_manager_storage import SqliteBatchProcessStorage
|
||||
from invokeai.app.services.invocation_stats import InvocationStatsService
|
||||
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
@@ -252,19 +255,18 @@ def invoke_cli():
|
||||
db_location = config.db_path
|
||||
db_location.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
|
||||
logger.info(f'InvokeAI database location is "{db_location}"')
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
)
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
|
||||
|
||||
urls = LocalUrlService()
|
||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||
image_record_storage = SqliteImageRecordStorage(conn=db_conn)
|
||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||
names = SimpleNameService()
|
||||
|
||||
board_record_storage = SqliteBoardRecordStorage(db_location)
|
||||
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
|
||||
board_record_storage = SqliteBoardRecordStorage(conn=db_conn)
|
||||
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn)
|
||||
|
||||
boards = BoardService(
|
||||
services=BoardServiceDependencies(
|
||||
@@ -298,15 +300,19 @@ def invoke_cli():
|
||||
)
|
||||
)
|
||||
|
||||
batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
|
||||
batch_manager = BatchManager(batch_manager_storage)
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=model_manager,
|
||||
events=events,
|
||||
latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
|
||||
images=images,
|
||||
boards=boards,
|
||||
batch_manager=batch_manager,
|
||||
board_images=board_images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||
|
||||
215
invokeai/app/services/batch_manager.py
Normal file
215
invokeai/app/services/batch_manager.py
Normal file
@@ -0,0 +1,215 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from itertools import product
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.services.batch_manager_storage import (
|
||||
Batch,
|
||||
BatchProcess,
|
||||
BatchProcessStorageBase,
|
||||
BatchSession,
|
||||
BatchSessionChanges,
|
||||
BatchSessionNotFoundException,
|
||||
)
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.graph import Graph, GraphExecutionState
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
|
||||
class BatchProcessResponse(BaseModel):
|
||||
batch_id: str = Field(description="ID for the batch")
|
||||
session_ids: list[str] = Field(description="List of session IDs created for this batch")
|
||||
|
||||
|
||||
class BatchManagerBase(ABC):
|
||||
@abstractmethod
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
"""Starts the BatchManager service"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_batch_process(self, batch: Batch, graph: Graph) -> BatchProcessResponse:
|
||||
"""Creates a batch process"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def run_batch_process(self, batch_id: str) -> None:
|
||||
"""Runs a batch process"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_batch_process(self, batch_process_id: str) -> None:
|
||||
"""Cancels a batch process"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_batch(self, batch_id: str) -> BatchProcessResponse:
|
||||
"""Gets a batch process"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_batch_processes(self) -> list[BatchProcessResponse]:
|
||||
"""Gets all batch processes"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_incomplete_batch_processes(self) -> list[BatchProcessResponse]:
|
||||
"""Gets all incomplete batch processes"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_sessions(self, batch_id: str) -> list[BatchSession]:
|
||||
"""Gets the sessions associated with a batch"""
|
||||
pass
|
||||
|
||||
|
||||
class BatchManager(BatchManagerBase):
|
||||
"""Responsible for managing currently running and scheduled batch jobs"""
|
||||
|
||||
__invoker: Invoker
|
||||
__batch_process_storage: BatchProcessStorageBase
|
||||
|
||||
def __init__(self, batch_process_storage: BatchProcessStorageBase) -> None:
|
||||
super().__init__()
|
||||
self.__batch_process_storage = batch_process_storage
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self.__invoker = invoker
|
||||
local_handler.register(event_name=EventServiceBase.session_event, _func=self.on_event)
|
||||
|
||||
async def on_event(self, event: Event):
|
||||
event_name = event[1]["event"]
|
||||
|
||||
match event_name:
|
||||
case "graph_execution_state_complete":
|
||||
await self._process(event, False)
|
||||
case "invocation_error":
|
||||
await self._process(event, True)
|
||||
|
||||
return event
|
||||
|
||||
async def _process(self, event: Event, err: bool) -> None:
|
||||
data = event[1]["data"]
|
||||
try:
|
||||
batch_session = self.__batch_process_storage.get_session_by_session_id(data["graph_execution_state_id"])
|
||||
except BatchSessionNotFoundException:
|
||||
return None
|
||||
changes = BatchSessionChanges(state="error" if err else "completed")
|
||||
batch_session = self.__batch_process_storage.update_session_state(
|
||||
batch_session.batch_id,
|
||||
batch_session.session_id,
|
||||
changes,
|
||||
)
|
||||
sessions = self.get_sessions(batch_session.batch_id)
|
||||
batch_process = self.__batch_process_storage.get(batch_session.batch_id)
|
||||
if not batch_process.canceled:
|
||||
self.run_batch_process(batch_process.batch_id)
|
||||
|
||||
def _create_graph_execution_state(
|
||||
self, batch_process: BatchProcess, batch_indices: tuple[int, ...]
|
||||
) -> GraphExecutionState:
|
||||
graph = batch_process.graph.copy(deep=True)
|
||||
batch = batch_process.batch
|
||||
for index, bdl in enumerate(batch.data):
|
||||
for bd in bdl:
|
||||
node = graph.get_node(bd.node_path)
|
||||
if node is None:
|
||||
continue
|
||||
batch_index = batch_indices[index]
|
||||
datum = bd.items[batch_index]
|
||||
key = bd.field_name
|
||||
node.__dict__[key] = datum
|
||||
graph.update_node(bd.node_path, node)
|
||||
|
||||
return GraphExecutionState(graph=graph)
|
||||
|
||||
def run_batch_process(self, batch_id: str) -> None:
|
||||
self.__batch_process_storage.start(batch_id)
|
||||
batch_process = self.__batch_process_storage.get(batch_id)
|
||||
next_batch_index = self._get_batch_index_tuple(batch_process)
|
||||
if next_batch_index is None:
|
||||
# finished with current run
|
||||
if batch_process.current_run >= (batch_process.batch.runs - 1):
|
||||
# finished with all runs
|
||||
return
|
||||
batch_process.current_batch_index = 0
|
||||
batch_process.current_run += 1
|
||||
next_batch_index = self._get_batch_index_tuple(batch_process)
|
||||
if next_batch_index is None:
|
||||
# shouldn't happen; satisfy types
|
||||
return
|
||||
# remember to increment the batch index
|
||||
batch_process.current_batch_index += 1
|
||||
self.__batch_process_storage.save(batch_process)
|
||||
ges = self._create_graph_execution_state(batch_process=batch_process, batch_indices=next_batch_index)
|
||||
next_session = self.__batch_process_storage.create_session(
|
||||
BatchSession(
|
||||
batch_id=batch_id,
|
||||
session_id=str(uuid4()),
|
||||
state="uninitialized",
|
||||
batch_index=batch_process.current_batch_index,
|
||||
)
|
||||
)
|
||||
ges.id = next_session.session_id
|
||||
self.__invoker.services.graph_execution_manager.set(ges)
|
||||
self.__batch_process_storage.update_session_state(
|
||||
batch_id=next_session.batch_id,
|
||||
session_id=next_session.session_id,
|
||||
changes=BatchSessionChanges(state="in_progress"),
|
||||
)
|
||||
self.__invoker.services.events.emit_batch_session_created(next_session.batch_id, next_session.session_id)
|
||||
self.__invoker.invoke(ges, invoke_all=True)
|
||||
|
||||
def create_batch_process(self, batch: Batch, graph: Graph) -> BatchProcessResponse:
|
||||
batch_process = BatchProcess(
|
||||
batch=batch,
|
||||
graph=graph,
|
||||
)
|
||||
batch_process = self.__batch_process_storage.save(batch_process)
|
||||
return BatchProcessResponse(
|
||||
batch_id=batch_process.batch_id,
|
||||
session_ids=[],
|
||||
)
|
||||
|
||||
def get_sessions(self, batch_id: str) -> list[BatchSession]:
|
||||
return self.__batch_process_storage.get_sessions_by_batch_id(batch_id)
|
||||
|
||||
def get_batch(self, batch_id: str) -> BatchProcess:
|
||||
return self.__batch_process_storage.get(batch_id)
|
||||
|
||||
def get_batch_processes(self) -> list[BatchProcessResponse]:
|
||||
bps = self.__batch_process_storage.get_all()
|
||||
return self._get_batch_process_responses(bps)
|
||||
|
||||
def get_incomplete_batch_processes(self) -> list[BatchProcessResponse]:
|
||||
bps = self.__batch_process_storage.get_incomplete()
|
||||
return self._get_batch_process_responses(bps)
|
||||
|
||||
def cancel_batch_process(self, batch_process_id: str) -> None:
|
||||
self.__batch_process_storage.cancel(batch_process_id)
|
||||
|
||||
def _get_batch_process_responses(self, batch_processes: list[BatchProcess]) -> list[BatchProcessResponse]:
|
||||
sessions = list()
|
||||
res: list[BatchProcessResponse] = list()
|
||||
for bp in batch_processes:
|
||||
sessions = self.__batch_process_storage.get_sessions_by_batch_id(bp.batch_id)
|
||||
res.append(
|
||||
BatchProcessResponse(
|
||||
batch_id=bp.batch_id,
|
||||
session_ids=[session.session_id for session in sessions],
|
||||
)
|
||||
)
|
||||
return res
|
||||
|
||||
def _get_batch_index_tuple(self, batch_process: BatchProcess) -> Optional[tuple[int, ...]]:
|
||||
batch_indices = list()
|
||||
for batchdata in batch_process.batch.data:
|
||||
batch_indices.append(list(range(len(batchdata[0].items))))
|
||||
try:
|
||||
return list(product(*batch_indices))[batch_process.current_batch_index]
|
||||
except IndexError:
|
||||
return None
|
||||
707
invokeai/app/services/batch_manager_storage.py
Normal file
707
invokeai/app/services/batch_manager_storage.py
Normal file
@@ -0,0 +1,707 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Literal, Optional, Union, cast
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr, parse_raw_as, validator
|
||||
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.services.graph import Graph
|
||||
|
||||
BatchDataType = Union[StrictStr, StrictInt, StrictFloat, ImageField]
|
||||
|
||||
|
||||
class BatchData(BaseModel):
|
||||
"""
|
||||
A batch data collection.
|
||||
"""
|
||||
|
||||
node_path: str = Field(description="The node into which this batch data collection will be substituted.")
|
||||
field_name: str = Field(description="The field into which this batch data collection will be substituted.")
|
||||
items: list[BatchDataType] = Field(
|
||||
default_factory=list, description="The list of items to substitute into the node/field."
|
||||
)
|
||||
|
||||
|
||||
class Batch(BaseModel):
|
||||
"""
|
||||
A batch, consisting of a list of a list of batch data collections.
|
||||
|
||||
First, each inner list[BatchData] is zipped into a single batch data collection.
|
||||
|
||||
Then, the final batch collection is created by taking the Cartesian product of all batch data collections.
|
||||
"""
|
||||
|
||||
data: list[list[BatchData]] = Field(default_factory=list, description="The list of batch data collections.")
|
||||
runs: int = Field(default=1, description="Int stating how many times to iterate through all possible batch indices")
|
||||
|
||||
@validator("runs")
|
||||
def validate_positive_runs(cls, r: int):
|
||||
if r < 1:
|
||||
raise ValueError("runs must be a positive integer")
|
||||
return r
|
||||
|
||||
@validator("data")
|
||||
def validate_len(cls, v: list[list[BatchData]]):
|
||||
for batch_data in v:
|
||||
if any(len(batch_data[0].items) != len(i.items) for i in batch_data):
|
||||
raise ValueError("Zipped batch items must have all have same length")
|
||||
return v
|
||||
|
||||
@validator("data")
|
||||
def validate_types(cls, v: list[list[BatchData]]):
|
||||
for batch_data in v:
|
||||
for datum in batch_data:
|
||||
for item in datum.items:
|
||||
if not all(isinstance(item, type(i)) for i in datum.items):
|
||||
raise TypeError("All items in a batch must have have same type")
|
||||
return v
|
||||
|
||||
@validator("data")
|
||||
def validate_unique_field_mappings(cls, v: list[list[BatchData]]):
|
||||
paths: set[tuple[str, str]] = set()
|
||||
count: int = 0
|
||||
for batch_data in v:
|
||||
for datum in batch_data:
|
||||
paths.add((datum.node_path, datum.field_name))
|
||||
count += 1
|
||||
if len(paths) != count:
|
||||
raise ValueError("Each batch data must have unique node_id and field_name")
|
||||
return v
|
||||
|
||||
|
||||
def uuid_string():
|
||||
res = uuid.uuid4()
|
||||
return str(res)
|
||||
|
||||
|
||||
BATCH_SESSION_STATE = Literal["uninitialized", "in_progress", "completed", "error"]
|
||||
|
||||
|
||||
class BatchSession(BaseModel):
|
||||
batch_id: str = Field(defaultdescription="The Batch to which this BatchSession is attached.")
|
||||
session_id: str = Field(
|
||||
default_factory=uuid_string, description="The Session to which this BatchSession is attached."
|
||||
)
|
||||
batch_index: int = Field(description="The index of this batch session in its parent batch process")
|
||||
state: BATCH_SESSION_STATE = Field(default="uninitialized", description="The state of this BatchSession")
|
||||
|
||||
|
||||
class BatchProcess(BaseModel):
|
||||
batch_id: str = Field(default_factory=uuid_string, description="Identifier for this batch.")
|
||||
batch: Batch = Field(description="The Batch to apply to this session.")
|
||||
current_batch_index: int = Field(default=0, description="The last executed batch index")
|
||||
current_run: int = Field(default=0, description="The current run of the batch")
|
||||
canceled: bool = Field(description="Whether or not to run sessions from this batch.", default=False)
|
||||
graph: Graph = Field(description="The graph into which batch data will be inserted before being executed.")
|
||||
|
||||
|
||||
class BatchSessionChanges(BaseModel, extra=Extra.forbid):
|
||||
state: BATCH_SESSION_STATE = Field(description="The state of this BatchSession")
|
||||
|
||||
|
||||
class BatchProcessNotFoundException(Exception):
|
||||
"""Raised when an Batch Process record is not found."""
|
||||
|
||||
def __init__(self, message="BatchProcess record not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BatchProcessSaveException(Exception):
|
||||
"""Raised when an Batch Process record cannot be saved."""
|
||||
|
||||
def __init__(self, message="BatchProcess record not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BatchProcessDeleteException(Exception):
|
||||
"""Raised when an Batch Process record cannot be deleted."""
|
||||
|
||||
def __init__(self, message="BatchProcess record not deleted"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BatchSessionNotFoundException(Exception):
|
||||
"""Raised when an Batch Session record is not found."""
|
||||
|
||||
def __init__(self, message="BatchSession record not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BatchSessionSaveException(Exception):
|
||||
"""Raised when an Batch Session record cannot be saved."""
|
||||
|
||||
def __init__(self, message="BatchSession record not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BatchSessionDeleteException(Exception):
|
||||
"""Raised when an Batch Session record cannot be deleted."""
|
||||
|
||||
def __init__(self, message="BatchSession record not deleted"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BatchProcessStorageBase(ABC):
|
||||
"""Low-level service responsible for interfacing with the Batch Process record store."""
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, batch_id: str) -> None:
|
||||
"""Deletes a BatchProcess record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
batch_process: BatchProcess,
|
||||
) -> BatchProcess:
|
||||
"""Saves a BatchProcess record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(
|
||||
self,
|
||||
batch_id: str,
|
||||
) -> BatchProcess:
|
||||
"""Gets a BatchProcess record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_all(
|
||||
self,
|
||||
) -> list[BatchProcess]:
|
||||
"""Gets a BatchProcess record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_incomplete(
|
||||
self,
|
||||
) -> list[BatchProcess]:
|
||||
"""Gets a BatchProcess record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start(
|
||||
self,
|
||||
batch_id: str,
|
||||
) -> None:
|
||||
"""'Starts' a BatchProcess record by marking its `canceled` attribute to False."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel(
|
||||
self,
|
||||
batch_id: str,
|
||||
) -> None:
|
||||
"""'Cancels' a BatchProcess record by setting its `canceled` attribute to True."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_session(
|
||||
self,
|
||||
session: BatchSession,
|
||||
) -> BatchSession:
|
||||
"""Creates a BatchSession attached to a BatchProcess."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_sessions(
|
||||
self,
|
||||
sessions: list[BatchSession],
|
||||
) -> list[BatchSession]:
|
||||
"""Creates many BatchSessions attached to a BatchProcess."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_session_by_session_id(self, session_id: str) -> BatchSession:
|
||||
"""Gets a BatchSession by session_id"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_sessions_by_batch_id(self, batch_id: str) -> List[BatchSession]:
|
||||
"""Gets all BatchSession's for a given BatchProcess id."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_sessions_by_session_ids(self, session_ids: list[str]) -> List[BatchSession]:
|
||||
"""Gets all BatchSession's for a given list of session ids."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_next_session(self, batch_id: str) -> BatchSession:
|
||||
"""Gets the next BatchSession with state `uninitialized`, for a given BatchProcess id."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_session_state(
|
||||
self,
|
||||
batch_id: str,
|
||||
session_id: str,
|
||||
changes: BatchSessionChanges,
|
||||
) -> BatchSession:
|
||||
"""Updates the state of a BatchSession record."""
|
||||
pass
|
||||
|
||||
|
||||
class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, conn: sqlite3.Connection) -> None:
|
||||
super().__init__()
|
||||
self._conn = conn
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Enable foreign keys
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Creates the `batch_process` table and `batch_session` junction table."""
|
||||
|
||||
# Create the `batch_process` table.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS batch_process (
|
||||
batch_id TEXT NOT NULL PRIMARY KEY,
|
||||
batch TEXT NOT NULL,
|
||||
graph TEXT NOT NULL,
|
||||
current_batch_index NUMBER NOT NULL,
|
||||
current_run NUMBER NOT NULL,
|
||||
canceled BOOLEAN NOT NULL DEFAULT(0),
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Soft delete, currently unused
|
||||
deleted_at DATETIME
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_batch_process_created_at ON batch_process (created_at);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_batch_process_updated_at
|
||||
AFTER UPDATE
|
||||
ON batch_process FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE batch_process SET updated_at = current_timestamp
|
||||
WHERE batch_id = old.batch_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
# Create the `batch_session` junction table.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS batch_session (
|
||||
batch_id TEXT NOT NULL,
|
||||
session_id TEXT NOT NULL,
|
||||
state TEXT NOT NULL,
|
||||
batch_index NUMBER NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Soft delete, currently unused
|
||||
deleted_at DATETIME,
|
||||
-- enforce one-to-many relationship between batch_process and batch_session using PK
|
||||
-- (we can extend this to many-to-many later)
|
||||
PRIMARY KEY (batch_id,session_id),
|
||||
FOREIGN KEY (batch_id) REFERENCES batch_process (batch_id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add index for batch id
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id ON batch_session (batch_id);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add index for batch id, sorted by created_at
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id_created_at ON batch_session (batch_id,created_at);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_batch_session_updated_at
|
||||
AFTER UPDATE
|
||||
ON batch_session FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE batch_session SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE batch_id = old.batch_id AND session_id = old.session_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
def delete(self, batch_id: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM batch_process
|
||||
WHERE batch_id = ?;
|
||||
""",
|
||||
(batch_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchProcessDeleteException from e
|
||||
except Exception as e:
|
||||
self._conn.rollback()
|
||||
raise BatchProcessDeleteException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def save(
|
||||
self,
|
||||
batch_process: BatchProcess,
|
||||
) -> BatchProcess:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR REPLACE INTO batch_process (batch_id, batch, graph, current_batch_index, current_run)
|
||||
VALUES (?, ?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
batch_process.batch_id,
|
||||
batch_process.batch.json(),
|
||||
batch_process.graph.json(),
|
||||
batch_process.current_batch_index,
|
||||
batch_process.current_run,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchProcessSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(batch_process.batch_id)
|
||||
|
||||
def _deserialize_batch_process(self, session_dict: dict) -> BatchProcess:
|
||||
"""Deserializes a batch session."""
|
||||
|
||||
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||
|
||||
batch_id = session_dict.get("batch_id", "unknown")
|
||||
batch_raw = session_dict.get("batch", "unknown")
|
||||
graph_raw = session_dict.get("graph", "unknown")
|
||||
current_batch_index = session_dict.get("current_batch_index", 0)
|
||||
current_run = session_dict.get("current_run", 0)
|
||||
canceled = session_dict.get("canceled", 0)
|
||||
return BatchProcess(
|
||||
batch_id=batch_id,
|
||||
batch=parse_raw_as(Batch, batch_raw),
|
||||
graph=parse_raw_as(Graph, graph_raw),
|
||||
current_batch_index=current_batch_index,
|
||||
current_run=current_run,
|
||||
canceled=canceled == 1,
|
||||
)
|
||||
|
||||
def get(
|
||||
self,
|
||||
batch_id: str,
|
||||
) -> BatchProcess:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM batch_process
|
||||
WHERE batch_id = ?;
|
||||
""",
|
||||
(batch_id,),
|
||||
)
|
||||
|
||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchProcessNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
raise BatchProcessNotFoundException
|
||||
return self._deserialize_batch_process(dict(result))
|
||||
|
||||
def get_all(
|
||||
self,
|
||||
) -> list[BatchProcess]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM batch_process
|
||||
"""
|
||||
)
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchProcessNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
return list()
|
||||
return list(map(lambda r: self._deserialize_batch_process(dict(r)), result))
|
||||
|
||||
def get_incomplete(
|
||||
self,
|
||||
) -> list[BatchProcess]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT bp.*
|
||||
FROM batch_process bp
|
||||
WHERE bp.batch_id IN
|
||||
(
|
||||
SELECT batch_id
|
||||
FROM batch_session bs
|
||||
WHERE state IN ('uninitialized', 'in_progress')
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchProcessNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
return list()
|
||||
return list(map(lambda r: self._deserialize_batch_process(dict(r)), result))
|
||||
|
||||
def start(
|
||||
self,
|
||||
batch_id: str,
|
||||
) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE batch_process
|
||||
SET canceled = 0
|
||||
WHERE batch_id = ?;
|
||||
""",
|
||||
(batch_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchSessionSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def cancel(
|
||||
self,
|
||||
batch_id: str,
|
||||
) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE batch_process
|
||||
SET canceled = 1
|
||||
WHERE batch_id = ?;
|
||||
""",
|
||||
(batch_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchSessionSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def create_session(
|
||||
self,
|
||||
session: BatchSession,
|
||||
) -> BatchSession:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
|
||||
VALUES (?, ?, ?, ?);
|
||||
""",
|
||||
(session.batch_id, session.session_id, session.state, session.batch_index),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchSessionSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get_session_by_session_id(session.session_id)
|
||||
|
||||
def create_sessions(
|
||||
self,
|
||||
sessions: list[BatchSession],
|
||||
) -> list[BatchSession]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
session_data = [(session.batch_id, session.session_id, session.state) for session in sessions]
|
||||
self._cursor.executemany(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
|
||||
VALUES (?, ?, ?, ?);
|
||||
""",
|
||||
session_data,
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchSessionSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get_sessions_by_session_ids([session.session_id for session in sessions])
|
||||
|
||||
def get_session_by_session_id(self, session_id: str) -> BatchSession:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM batch_session
|
||||
WHERE session_id= ?;
|
||||
""",
|
||||
(session_id,),
|
||||
)
|
||||
|
||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchSessionNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
raise BatchSessionNotFoundException
|
||||
return self._deserialize_batch_session(dict(result))
|
||||
|
||||
def _deserialize_batch_session(self, session_dict: dict) -> BatchSession:
|
||||
"""Deserializes a batch session."""
|
||||
|
||||
return BatchSession.parse_obj(session_dict)
|
||||
|
||||
def get_next_session(self, batch_id: str) -> BatchSession:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM batch_session
|
||||
WHERE batch_id = ? AND state = 'uninitialized';
|
||||
""",
|
||||
(batch_id,),
|
||||
)
|
||||
|
||||
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchSessionNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
raise BatchSessionNotFoundException
|
||||
session = self._deserialize_batch_session(dict(result))
|
||||
return session
|
||||
|
||||
def get_sessions_by_batch_id(self, batch_id: str) -> List[BatchSession]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM batch_session
|
||||
WHERE batch_id = ?;
|
||||
""",
|
||||
(batch_id,),
|
||||
)
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchSessionNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
raise BatchSessionNotFoundException
|
||||
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
|
||||
return sessions
|
||||
|
||||
def get_sessions_by_session_ids(self, session_ids: list[str]) -> List[BatchSession]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
placeholders = ",".join("?" * len(session_ids))
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT * FROM batch_session
|
||||
WHERE session_id
|
||||
IN ({placeholders})
|
||||
""",
|
||||
tuple(session_ids),
|
||||
)
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchSessionNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
raise BatchSessionNotFoundException
|
||||
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
|
||||
return sessions
|
||||
|
||||
def update_session_state(
|
||||
self,
|
||||
batch_id: str,
|
||||
session_id: str,
|
||||
changes: BatchSessionChanges,
|
||||
) -> BatchSession:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
# Change the state of a batch session
|
||||
if changes.state is not None:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE batch_session
|
||||
SET state = ?
|
||||
WHERE batch_id = ? AND session_id = ?;
|
||||
""",
|
||||
(changes.state, batch_id, session_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchSessionSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get_session_by_session_id(session_id)
|
||||
@@ -56,15 +56,13 @@ class BoardImageRecordStorageBase(ABC):
|
||||
|
||||
|
||||
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, filename: str) -> None:
|
||||
def __init__(self, conn: sqlite3.Connection) -> None:
|
||||
super().__init__()
|
||||
self._filename = filename
|
||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||
self._conn = conn
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
@@ -89,15 +89,13 @@ class BoardRecordStorageBase(ABC):
|
||||
|
||||
|
||||
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, filename: str) -> None:
|
||||
def __init__(self, conn: sqlite3.Connection) -> None:
|
||||
super().__init__()
|
||||
self._filename = filename
|
||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||
self._conn = conn
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
@@ -13,6 +13,7 @@ from invokeai.app.services.model_manager_service import (
|
||||
|
||||
class EventServiceBase:
|
||||
session_event: str = "session_event"
|
||||
batch_event: str = "batch_event"
|
||||
|
||||
"""Basic event bus, to have an empty stand-in when not needed"""
|
||||
|
||||
@@ -20,12 +21,21 @@ class EventServiceBase:
|
||||
pass
|
||||
|
||||
def __emit_session_event(self, event_name: str, payload: dict) -> None:
|
||||
"""Session events are emitted to a room with the session_id as the room name"""
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.session_event,
|
||||
payload=dict(event=event_name, data=payload),
|
||||
)
|
||||
|
||||
def __emit_batch_event(self, event_name: str, payload: dict) -> None:
|
||||
"""Batch events are emitted to a room with the batch_id as the room name"""
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.batch_event,
|
||||
payload=dict(event=event_name, data=payload),
|
||||
)
|
||||
|
||||
# Define events here for every event in the system.
|
||||
# This will make them easier to integrate until we find a schema generator.
|
||||
def emit_generator_progress(
|
||||
@@ -187,3 +197,14 @@ class EventServiceBase:
|
||||
error=error,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_batch_session_created(
|
||||
self,
|
||||
batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
) -> None:
|
||||
"""Emitted when a batch session is created"""
|
||||
self.__emit_batch_event(
|
||||
event_name="batch_session_created",
|
||||
payload=dict(batch_id=batch_id, graph_execution_state_id=graph_execution_state_id),
|
||||
)
|
||||
|
||||
@@ -152,15 +152,13 @@ class ImageRecordStorageBase(ABC):
|
||||
|
||||
|
||||
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, filename: str) -> None:
|
||||
def __init__(self, conn: sqlite3.Connection) -> None:
|
||||
super().__init__()
|
||||
self._filename = filename
|
||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||
self._conn = conn
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
from invokeai.app.services.batch_manager import BatchManagerBase
|
||||
from invokeai.app.services.board_images import BoardImagesServiceABC
|
||||
from invokeai.app.services.boards import BoardServiceABC
|
||||
from invokeai.app.services.images import ImageServiceABC
|
||||
@@ -22,6 +23,7 @@ class InvocationServices:
|
||||
"""Services that can be used by invocations"""
|
||||
|
||||
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
||||
batch_manager: "BatchManagerBase"
|
||||
board_images: "BoardImagesServiceABC"
|
||||
boards: "BoardServiceABC"
|
||||
configuration: "InvokeAIAppConfig"
|
||||
@@ -38,6 +40,7 @@ class InvocationServices:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_manager: "BatchManagerBase",
|
||||
board_images: "BoardImagesServiceABC",
|
||||
boards: "BoardServiceABC",
|
||||
configuration: "InvokeAIAppConfig",
|
||||
@@ -52,6 +55,7 @@ class InvocationServices:
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
queue: "InvocationQueueABC",
|
||||
):
|
||||
self.batch_manager = batch_manager
|
||||
self.board_images = board_images
|
||||
self.boards = boards
|
||||
self.boards = boards
|
||||
|
||||
@@ -12,23 +12,19 @@ sqlite_memory = ":memory:"
|
||||
|
||||
|
||||
class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
_filename: str
|
||||
_table_name: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_id_field: str
|
||||
_lock: Lock
|
||||
|
||||
def __init__(self, filename: str, table_name: str, id_field: str = "id"):
|
||||
def __init__(self, conn: sqlite3.Connection, table_name: str, id_field: str = "id"):
|
||||
super().__init__()
|
||||
|
||||
self._filename = filename
|
||||
self._table_name = table_name
|
||||
self._id_field = id_field # TODO: validate that T has this field
|
||||
self._lock = Lock()
|
||||
self._conn = sqlite3.connect(
|
||||
self._filename, check_same_thread=False
|
||||
) # TODO: figure out a better threading solution
|
||||
self._conn = conn
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
self._create_table()
|
||||
@@ -49,8 +45,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
|
||||
def _parse_item(self, item: str) -> T:
|
||||
item_type = get_args(self.__orig_class__)[0]
|
||||
parsed = parse_raw_as(item_type, item)
|
||||
return parsed
|
||||
return parse_raw_as(item_type, item)
|
||||
|
||||
def set(self, item: T):
|
||||
try:
|
||||
|
||||
171
invokeai/frontend/web/dist/assets/App-78495256.js
vendored
Normal file
171
invokeai/frontend/web/dist/assets/App-78495256.js
vendored
Normal file
File diff suppressed because one or more lines are too long
169
invokeai/frontend/web/dist/assets/App-d1567775.js
vendored
169
invokeai/frontend/web/dist/assets/App-d1567775.js
vendored
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
126
invokeai/frontend/web/dist/assets/index-08cda350.js
vendored
Normal file
126
invokeai/frontend/web/dist/assets/index-08cda350.js
vendored
Normal file
File diff suppressed because one or more lines are too long
128
invokeai/frontend/web/dist/assets/index-f83c2c5c.js
vendored
128
invokeai/frontend/web/dist/assets/index-f83c2c5c.js
vendored
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
1
invokeai/frontend/web/dist/assets/menu-3d10c968.js
vendored
Normal file
1
invokeai/frontend/web/dist/assets/menu-3d10c968.js
vendored
Normal file
File diff suppressed because one or more lines are too long
2
invokeai/frontend/web/dist/index.html
vendored
2
invokeai/frontend/web/dist/index.html
vendored
@@ -12,7 +12,7 @@
|
||||
margin: 0;
|
||||
}
|
||||
</style>
|
||||
<script type="module" crossorigin src="./assets/index-f83c2c5c.js"></script>
|
||||
<script type="module" crossorigin src="./assets/index-08cda350.js"></script>
|
||||
</head>
|
||||
|
||||
<body dir="ltr">
|
||||
|
||||
2
invokeai/frontend/web/dist/locales/en.json
vendored
2
invokeai/frontend/web/dist/locales/en.json
vendored
@@ -511,7 +511,6 @@
|
||||
"maskBlur": "Blur",
|
||||
"maskBlurMethod": "Blur Method",
|
||||
"coherencePassHeader": "Coherence Pass",
|
||||
"coherenceMode": "Mode",
|
||||
"coherenceSteps": "Steps",
|
||||
"coherenceStrength": "Strength",
|
||||
"seamLowThreshold": "Low",
|
||||
@@ -521,7 +520,6 @@
|
||||
"scaledHeight": "Scaled H",
|
||||
"infillMethod": "Infill Method",
|
||||
"tileSize": "Tile Size",
|
||||
"patchmatchDownScaleSize": "Downscale",
|
||||
"boundingBoxHeader": "Bounding Box",
|
||||
"seamCorrectionHeader": "Seam Correction",
|
||||
"infillScalingHeader": "Infill and Scaling",
|
||||
|
||||
@@ -45,7 +45,6 @@ export type AppConfig = {
|
||||
* Whether or not we should update image urls when image loading errors
|
||||
*/
|
||||
shouldUpdateImagesOnConnect: boolean;
|
||||
shouldFetchMetadataFromApi: boolean;
|
||||
disabledTabs: InvokeTabName[];
|
||||
disabledFeatures: AppFeature[];
|
||||
disabledSDFeatures: SDFeature[];
|
||||
|
||||
@@ -27,7 +27,7 @@ import {
|
||||
setShouldShowImageDetails,
|
||||
setShouldShowProgressInViewer,
|
||||
} from 'features/ui/store/uiSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
@@ -49,7 +49,7 @@ import SingleSelectionMenuItems from '../ImageContextMenu/SingleSelectionMenuIte
|
||||
|
||||
const currentImageButtonsSelector = createSelector(
|
||||
[stateSelector, activeTabNameSelector],
|
||||
({ gallery, system, ui, config }, activeTabName) => {
|
||||
({ gallery, system, ui }, activeTabName) => {
|
||||
const { isProcessing, isConnected, shouldConfirmOnDelete, progressImage } =
|
||||
system;
|
||||
|
||||
@@ -59,8 +59,6 @@ const currentImageButtonsSelector = createSelector(
|
||||
shouldShowProgressInViewer,
|
||||
} = ui;
|
||||
|
||||
const { shouldFetchMetadataFromApi } = config;
|
||||
|
||||
const lastSelectedImage = gallery.selection[gallery.selection.length - 1];
|
||||
|
||||
return {
|
||||
@@ -74,7 +72,6 @@ const currentImageButtonsSelector = createSelector(
|
||||
shouldHidePreview,
|
||||
shouldShowProgressInViewer,
|
||||
lastSelectedImage,
|
||||
shouldFetchMetadataFromApi,
|
||||
};
|
||||
},
|
||||
{
|
||||
@@ -95,7 +92,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
shouldShowImageDetails,
|
||||
lastSelectedImage,
|
||||
shouldShowProgressInViewer,
|
||||
shouldFetchMetadataFromApi,
|
||||
} = useAppSelector(currentImageButtonsSelector);
|
||||
|
||||
const isUpscalingEnabled = useFeatureStatus('upscaling').isFeatureEnabled;
|
||||
@@ -110,16 +106,8 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
lastSelectedImage?.image_name ?? skipToken
|
||||
);
|
||||
|
||||
const getMetadataArg = useMemo(() => {
|
||||
if (lastSelectedImage) {
|
||||
return { image: lastSelectedImage, shouldFetchMetadataFromApi };
|
||||
} else {
|
||||
return skipToken;
|
||||
}
|
||||
}, [lastSelectedImage, shouldFetchMetadataFromApi]);
|
||||
|
||||
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
|
||||
getMetadataArg,
|
||||
lastSelectedImage ?? skipToken,
|
||||
{
|
||||
selectFromResult: (res) => ({
|
||||
isLoading: res.isFetching,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Flex, MenuItem, Spinner } from '@chakra-ui/react';
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||
import {
|
||||
imagesToChangeSelected,
|
||||
@@ -34,7 +34,6 @@ import {
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
|
||||
import { workflowLoadRequested } from 'features/nodes/store/actions';
|
||||
import { configSelector } from '../../../system/store/configSelectors';
|
||||
|
||||
type SingleSelectionMenuItemsProps = {
|
||||
imageDTO: ImageDTO;
|
||||
@@ -49,10 +48,9 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
const toaster = useAppToaster();
|
||||
|
||||
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
||||
const { shouldFetchMetadataFromApi } = useAppSelector(configSelector);
|
||||
|
||||
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
|
||||
{ image: imageDTO, shouldFetchMetadataFromApi },
|
||||
imageDTO,
|
||||
{
|
||||
selectFromResult: (res) => ({
|
||||
isLoading: res.isFetching,
|
||||
|
||||
@@ -15,8 +15,6 @@ import { useGetImageMetadataFromFileQuery } from 'services/api/endpoints/images'
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import DataViewer from './DataViewer';
|
||||
import ImageMetadataActions from './ImageMetadataActions';
|
||||
import { useAppSelector } from '../../../../app/store/storeHooks';
|
||||
import { configSelector } from '../../../system/store/configSelectors';
|
||||
|
||||
type ImageMetadataViewerProps = {
|
||||
image: ImageDTO;
|
||||
@@ -29,17 +27,12 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
|
||||
// dispatch(setShouldShowImageDetails(false));
|
||||
// });
|
||||
|
||||
const { shouldFetchMetadataFromApi } = useAppSelector(configSelector);
|
||||
|
||||
const { metadata, workflow } = useGetImageMetadataFromFileQuery(
|
||||
{ image, shouldFetchMetadataFromApi },
|
||||
{
|
||||
selectFromResult: (res) => ({
|
||||
metadata: res?.currentData?.metadata,
|
||||
workflow: res?.currentData?.workflow,
|
||||
}),
|
||||
}
|
||||
);
|
||||
const { metadata, workflow } = useGetImageMetadataFromFileQuery(image, {
|
||||
selectFromResult: (res) => ({
|
||||
metadata: res?.currentData?.metadata,
|
||||
workflow: res?.currentData?.workflow,
|
||||
}),
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex
|
||||
|
||||
@@ -5,7 +5,6 @@ import { merge } from 'lodash-es';
|
||||
|
||||
export const initialConfigState: AppConfig = {
|
||||
shouldUpdateImagesOnConnect: false,
|
||||
shouldFetchMetadataFromApi: false,
|
||||
disabledTabs: [],
|
||||
disabledFeatures: ['lightbox', 'faceRestore', 'batches'],
|
||||
disabledSDFeatures: [
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { EntityState, Update } from '@reduxjs/toolkit';
|
||||
import { fetchBaseQuery } from '@reduxjs/toolkit/dist/query';
|
||||
import { PatchCollection } from '@reduxjs/toolkit/dist/query/core/buildThunks';
|
||||
import {
|
||||
ASSETS_CATEGORIES,
|
||||
@@ -7,14 +6,9 @@ import {
|
||||
IMAGE_CATEGORIES,
|
||||
IMAGE_LIMIT,
|
||||
} from 'features/gallery/store/types';
|
||||
import {
|
||||
ImageMetadataAndWorkflow,
|
||||
zCoreMetadata,
|
||||
} from 'features/nodes/types/types';
|
||||
import { getMetadataAndWorkflowFromImageBlob } from 'features/nodes/util/getMetadataAndWorkflowFromImageBlob';
|
||||
import { keyBy } from 'lodash-es';
|
||||
import { ApiFullTagDescription, LIST_TAG, api } from '..';
|
||||
import { $authToken, $projectId } from '../client';
|
||||
import { components, paths } from '../schema';
|
||||
import {
|
||||
DeleteBoardResult,
|
||||
@@ -33,6 +27,9 @@ import {
|
||||
imagesSelectors,
|
||||
} from '../util';
|
||||
import { boardsApi } from './boards';
|
||||
import { ImageMetadataAndWorkflow } from 'features/nodes/types/types';
|
||||
import { fetchBaseQuery } from '@reduxjs/toolkit/dist/query';
|
||||
import { $authToken, $projectId } from '../client';
|
||||
|
||||
export const imagesApi = api.injectEndpoints({
|
||||
endpoints: (build) => ({
|
||||
@@ -120,16 +117,8 @@ export const imagesApi = api.injectEndpoints({
|
||||
],
|
||||
keepUnusedDataFor: 86400, // 24 hours
|
||||
}),
|
||||
getImageMetadataFromFile: build.query<
|
||||
ImageMetadataAndWorkflow,
|
||||
{ image: ImageDTO; shouldFetchMetadataFromApi: boolean }
|
||||
>({
|
||||
queryFn: async (
|
||||
args: { image: ImageDTO; shouldFetchMetadataFromApi: boolean },
|
||||
api,
|
||||
extraOptions,
|
||||
fetchWithBaseQuery
|
||||
) => {
|
||||
getImageMetadataFromFile: build.query<ImageMetadataAndWorkflow, ImageDTO>({
|
||||
queryFn: async (args: ImageDTO, api, extraOptions) => {
|
||||
const authToken = $authToken.get();
|
||||
const projectId = $projectId.get();
|
||||
const customBaseQuery = fetchBaseQuery({
|
||||
@@ -150,35 +139,17 @@ export const imagesApi = api.injectEndpoints({
|
||||
});
|
||||
|
||||
const response = await customBaseQuery(
|
||||
args.image.image_url,
|
||||
args.image_url,
|
||||
api,
|
||||
extraOptions
|
||||
);
|
||||
const blobData = await getMetadataAndWorkflowFromImageBlob(
|
||||
const data = await getMetadataAndWorkflowFromImageBlob(
|
||||
response.data as Blob
|
||||
);
|
||||
|
||||
let metadata = blobData.metadata;
|
||||
|
||||
if (args.shouldFetchMetadataFromApi) {
|
||||
const metadataResponse = await fetchWithBaseQuery(
|
||||
`images/i/${args.image.image_name}/metadata`
|
||||
);
|
||||
if (metadataResponse.data) {
|
||||
const metadataResult = zCoreMetadata.safeParse(
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(metadataResponse.data as any)?.metadata
|
||||
);
|
||||
if (metadataResult.success) {
|
||||
metadata = metadataResult.data;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { data: { ...blobData, metadata } };
|
||||
return { data };
|
||||
},
|
||||
providesTags: (result, error, { image }) => [
|
||||
{ type: 'ImageMetadataFromFile', id: image.image_name },
|
||||
providesTags: (result, error, image_dto) => [
|
||||
{ type: 'ImageMetadataFromFile', id: image_dto.image_name },
|
||||
],
|
||||
keepUnusedDataFor: 86400, // 24 hours
|
||||
}),
|
||||
|
||||
344
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
344
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
File diff suppressed because one or more lines are too long
@@ -1 +1 @@
|
||||
__version__ = "3.1.1"
|
||||
__version__ = "3.1.0"
|
||||
|
||||
@@ -102,6 +102,7 @@ dependencies = [
|
||||
"flake8",
|
||||
"Flake8-pyproject",
|
||||
"pytest>6.0.0",
|
||||
"pytest-asyncio",
|
||||
"pytest-cov",
|
||||
"pytest-datadir",
|
||||
]
|
||||
@@ -176,6 +177,7 @@ version = { attr = "invokeai.version.__version__" }
|
||||
#=== Begin: PyTest and Coverage
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--cov-report term --cov-report html --cov-report xml"
|
||||
asyncio_mode = "auto"
|
||||
[tool.coverage.run]
|
||||
branch = true
|
||||
source = ["invokeai"]
|
||||
|
||||
@@ -25,6 +25,7 @@ from invokeai.app.services.graph import (
|
||||
LibraryGraph,
|
||||
)
|
||||
import pytest
|
||||
import sqlite3
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -42,9 +43,8 @@ def simple_graph():
|
||||
@pytest.fixture
|
||||
def mock_services() -> InvocationServices:
|
||||
# NOTE: none of these are actually called by the test invocations
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
filename=sqlite_memory, table_name="graph_executions"
|
||||
)
|
||||
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
|
||||
return InvocationServices(
|
||||
model_manager=None, # type: ignore
|
||||
events=TestEventService(),
|
||||
@@ -52,9 +52,10 @@ def mock_services() -> InvocationServices:
|
||||
images=None, # type: ignore
|
||||
latents=None, # type: ignore
|
||||
boards=None, # type: ignore
|
||||
batch_manager=None, # type: ignore
|
||||
board_images=None, # type: ignore
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||
processor=DefaultInvocationProcessor(),
|
||||
|
||||
@@ -6,18 +6,34 @@ from .test_nodes import (
|
||||
create_edge,
|
||||
wait_until,
|
||||
)
|
||||
# from fastapi_events.handlers.local import
|
||||
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
|
||||
from invokeai.app.services.processor import DefaultInvocationProcessor
|
||||
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.invocation_stats import InvocationStatsService
|
||||
from invokeai.app.api.events import FastAPIEventService
|
||||
from invokeai.app.services.batch_manager_storage import BatchData, SqliteBatchProcessStorage
|
||||
from invokeai.app.services.batch_manager import (
|
||||
Batch,
|
||||
BatchManager,
|
||||
)
|
||||
from invokeai.app.services.graph import (
|
||||
Graph,
|
||||
GraphExecutionState,
|
||||
GraphInvocation,
|
||||
LibraryGraph,
|
||||
)
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import sqlite3
|
||||
import time
|
||||
|
||||
from httpx import AsyncClient
|
||||
from fastapi import FastAPI
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -29,25 +45,128 @@ def simple_graph():
|
||||
return g
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def simple_batch():
|
||||
return Batch(
|
||||
data=[
|
||||
[
|
||||
BatchData(
|
||||
node_path="1",
|
||||
field_name="prompt",
|
||||
items=[
|
||||
"Tomato sushi",
|
||||
"Strawberry sushi",
|
||||
"Broccoli sushi",
|
||||
"Asparagus sushi",
|
||||
"Tea sushi",
|
||||
],
|
||||
)
|
||||
],
|
||||
[
|
||||
BatchData(
|
||||
node_path="2",
|
||||
field_name="prompt",
|
||||
items=[
|
||||
"Ume sushi",
|
||||
"Ichigo sushi",
|
||||
"Momo sushi",
|
||||
"Mikan sushi",
|
||||
"Cha sushi",
|
||||
],
|
||||
)
|
||||
],
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def graph_with_subgraph():
|
||||
sub_g = Graph()
|
||||
sub_g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
|
||||
sub_g.add_node(TextToImageTestInvocation(id="2"))
|
||||
sub_g.add_edge(create_edge("1", "prompt", "2", "prompt"))
|
||||
g = Graph()
|
||||
g.add_node(GraphInvocation(id="1", graph=sub_g))
|
||||
return g
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def batch_with_subgraph():
|
||||
return Batch(
|
||||
data=[
|
||||
[
|
||||
BatchData(
|
||||
node_path="1.1",
|
||||
field_name="prompt",
|
||||
items=[
|
||||
"Tomato sushi",
|
||||
"Strawberry sushi",
|
||||
"Broccoli sushi",
|
||||
"Asparagus sushi",
|
||||
"Tea sushi",
|
||||
],
|
||||
)
|
||||
],
|
||||
[
|
||||
BatchData(
|
||||
node_path="1.2",
|
||||
field_name="prompt",
|
||||
items=[
|
||||
"Ume sushi",
|
||||
"Ichigo sushi",
|
||||
"Momo sushi",
|
||||
"Mikan sushi",
|
||||
"Cha sushi",
|
||||
],
|
||||
)
|
||||
],
|
||||
]
|
||||
)
|
||||
|
||||
# @pytest_asyncio.fixture(scope="module")
|
||||
# def event_loop():
|
||||
# import asyncio
|
||||
# try:
|
||||
# loop = asyncio.get_running_loop()
|
||||
# except RuntimeError as e:
|
||||
# loop = asyncio.new_event_loop()
|
||||
# asyncio.set_event_loop(loop)
|
||||
# # fastapi_events.event_store
|
||||
# yield loop
|
||||
# loop.close()
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def db_conn():
|
||||
return sqlite3.connect(sqlite_memory, check_same_thread=False)
|
||||
|
||||
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
|
||||
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
|
||||
# the test invocations.
|
||||
@pytest.fixture
|
||||
def mock_services() -> InvocationServices:
|
||||
# NOTE: none of these are actually called by the test invocations
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
filename=sqlite_memory, table_name="graph_executions"
|
||||
@pytest.fixture(autouse=True)
|
||||
async def mock_services(db_conn : sqlite3.Connection) -> InvocationServices:
|
||||
app = FastAPI()
|
||||
event_handler_id: int = id(app)
|
||||
app.add_middleware(
|
||||
EventHandlerASGIMiddleware,
|
||||
handlers=[local_handler], # TODO: consider doing this in services to support different configurations
|
||||
middleware_id=event_handler_id,
|
||||
)
|
||||
client = AsyncClient(app=app)
|
||||
# NOTE: none of these are actually called by the test invocations
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
|
||||
batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
return InvocationServices(
|
||||
model_manager=None, # type: ignore
|
||||
events=TestEventService(),
|
||||
events=events,
|
||||
logger=None, # type: ignore
|
||||
images=None, # type: ignore
|
||||
latents=None, # type: ignore
|
||||
batch_manager=BatchManager(batch_manager_storage),
|
||||
boards=None, # type: ignore
|
||||
board_images=None, # type: ignore
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||
@@ -130,3 +249,135 @@ def test_handles_errors(mock_invoker: Invoker):
|
||||
assert g.is_complete()
|
||||
|
||||
assert all((i in g.errors for i in g.source_prepared_mapping["1"]))
|
||||
|
||||
|
||||
def test_can_create_batch_with_subgraph(mock_invoker: Invoker, graph_with_subgraph, batch_with_subgraph):
|
||||
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
|
||||
batch=batch_with_subgraph,
|
||||
graph=graph_with_subgraph,
|
||||
)
|
||||
assert batch_process_res.batch_id
|
||||
# TODO: without the mock events service emitting the `graph_execution_state` events,
|
||||
# the batch sessions do not know when they have finished, so this logic will fail
|
||||
|
||||
# assert len(batch_process_res.session_ids) == 25
|
||||
# mock_invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id)
|
||||
|
||||
# def has_executed_all_batches(batch_id: str):
|
||||
# batch_sessions = mock_invoker.services.batch_manager.get_sessions(batch_id)
|
||||
# print(batch_sessions)
|
||||
# return all((s.state == "completed" for s in batch_sessions))
|
||||
|
||||
# wait_until(lambda: has_executed_all_batches(batch_process_res.batch_id), timeout=10, interval=1)
|
||||
|
||||
async def test_can_run_batch_with_subgraph(mock_invoker: Invoker, graph_with_subgraph, batch_with_subgraph):
|
||||
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
|
||||
batch=batch_with_subgraph,
|
||||
graph=graph_with_subgraph,
|
||||
)
|
||||
mock_invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id)
|
||||
sessions = []
|
||||
attempts = 0
|
||||
import asyncio
|
||||
while len(sessions) != 25 and attempts < 20:
|
||||
batch = mock_invoker.services.batch_manager.get_batch(batch_process_res.batch_id)
|
||||
sessions = mock_invoker.services.batch_manager.get_sessions(batch_process_res.batch_id)
|
||||
await asyncio.sleep(1)
|
||||
attempts += 1
|
||||
assert len(sessions) == 25
|
||||
|
||||
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batch):
|
||||
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
|
||||
batch=simple_batch,
|
||||
graph=simple_graph,
|
||||
)
|
||||
assert batch_process_res.batch_id
|
||||
# TODO: without the mock events service emitting the `graph_execution_state` events,
|
||||
# the batch sessions do not know when they have finished, so this logic will fail
|
||||
|
||||
# assert len(batch_process_res.session_ids) == 25
|
||||
# mock_invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id)
|
||||
|
||||
# def has_executed_all_batches(batch_id: str):
|
||||
# batch_sessions = mock_invoker.services.batch_manager.get_sessions(batch_id)
|
||||
# print(batch_sessions)
|
||||
# return all((s.state == "completed" for s in batch_sessions))
|
||||
|
||||
# wait_until(lambda: has_executed_all_batches(batch_process_res.batch_id), timeout=10, interval=1)
|
||||
|
||||
|
||||
def test_cannot_create_bad_batches():
|
||||
batch = None
|
||||
try:
|
||||
batch = Batch( # This batch has a duplicate node_path|fieldname combo
|
||||
data=[
|
||||
[
|
||||
BatchData(
|
||||
node_path="1",
|
||||
field_name="prompt",
|
||||
items=[
|
||||
"Tomato sushi",
|
||||
],
|
||||
)
|
||||
],
|
||||
[
|
||||
BatchData(
|
||||
node_path="1",
|
||||
field_name="prompt",
|
||||
items=[
|
||||
"Ume sushi",
|
||||
],
|
||||
)
|
||||
],
|
||||
]
|
||||
)
|
||||
except Exception as e:
|
||||
assert e
|
||||
try:
|
||||
batch = Batch( # This batch has different item list lengths in the same group
|
||||
data=[
|
||||
[
|
||||
BatchData(
|
||||
node_path="1",
|
||||
field_name="prompt",
|
||||
items=[
|
||||
"Tomato sushi",
|
||||
],
|
||||
),
|
||||
BatchData(
|
||||
node_path="1",
|
||||
field_name="prompt",
|
||||
items=[
|
||||
"Tomato sushi",
|
||||
"Courgette sushi",
|
||||
],
|
||||
),
|
||||
],
|
||||
[
|
||||
BatchData(
|
||||
node_path="1",
|
||||
field_name="prompt",
|
||||
items=[
|
||||
"Ume sushi",
|
||||
],
|
||||
)
|
||||
],
|
||||
]
|
||||
)
|
||||
except Exception as e:
|
||||
assert e
|
||||
try:
|
||||
batch = Batch( # This batch has a type mismatch in single items list
|
||||
data=[
|
||||
[
|
||||
BatchData(
|
||||
node_path="1",
|
||||
field_name="prompt",
|
||||
items=["Tomato sushi", 5],
|
||||
),
|
||||
],
|
||||
]
|
||||
)
|
||||
except Exception as e:
|
||||
assert e
|
||||
assert not batch
|
||||
|
||||
@@ -51,6 +51,7 @@ class ImageTestInvocationOutput(BaseInvocationOutput):
|
||||
@invocation("test_text_to_image")
|
||||
class TextToImageTestInvocation(BaseInvocation):
|
||||
prompt: str = Field(default="")
|
||||
prompt2: str = Field(default="")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
||||
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
||||
|
||||
@@ -1,20 +1,27 @@
|
||||
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import pytest
|
||||
import sqlite3
|
||||
|
||||
|
||||
class TestModel(BaseModel):
|
||||
id: str = Field(description="ID")
|
||||
name: str = Field(description="Name")
|
||||
|
||||
|
||||
def test_sqlite_service_can_create_and_get():
|
||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
||||
@pytest.fixture
|
||||
def db() -> SqliteItemStorage[TestModel]:
|
||||
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
||||
return SqliteItemStorage[TestModel](db_conn, "test", "id")
|
||||
|
||||
|
||||
def test_sqlite_service_can_create_and_get(db: SqliteItemStorage[TestModel]):
|
||||
db.set(TestModel(id="1", name="Test"))
|
||||
assert db.get("1") == TestModel(id="1", name="Test")
|
||||
|
||||
|
||||
def test_sqlite_service_can_list():
|
||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
||||
def test_sqlite_service_can_list(db: SqliteItemStorage[TestModel]):
|
||||
db.set(TestModel(id="1", name="Test"))
|
||||
db.set(TestModel(id="2", name="Test"))
|
||||
db.set(TestModel(id="3", name="Test"))
|
||||
@@ -30,15 +37,13 @@ def test_sqlite_service_can_list():
|
||||
]
|
||||
|
||||
|
||||
def test_sqlite_service_can_delete():
|
||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
||||
def test_sqlite_service_can_delete(db: SqliteItemStorage[TestModel]):
|
||||
db.set(TestModel(id="1", name="Test"))
|
||||
db.delete("1")
|
||||
assert db.get("1") is None
|
||||
|
||||
|
||||
def test_sqlite_service_calls_set_callback():
|
||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
||||
def test_sqlite_service_calls_set_callback(db: SqliteItemStorage[TestModel]):
|
||||
called = False
|
||||
|
||||
def on_changed(item: TestModel):
|
||||
@@ -50,8 +55,7 @@ def test_sqlite_service_calls_set_callback():
|
||||
assert called
|
||||
|
||||
|
||||
def test_sqlite_service_calls_delete_callback():
|
||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
||||
def test_sqlite_service_calls_delete_callback(db: SqliteItemStorage[TestModel]):
|
||||
called = False
|
||||
|
||||
def on_deleted(item_id: str):
|
||||
@@ -64,8 +68,7 @@ def test_sqlite_service_calls_delete_callback():
|
||||
assert called
|
||||
|
||||
|
||||
def test_sqlite_service_can_list_with_pagination():
|
||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
||||
def test_sqlite_service_can_list_with_pagination(db: SqliteItemStorage[TestModel]):
|
||||
db.set(TestModel(id="1", name="Test"))
|
||||
db.set(TestModel(id="2", name="Test"))
|
||||
db.set(TestModel(id="3", name="Test"))
|
||||
@@ -77,8 +80,7 @@ def test_sqlite_service_can_list_with_pagination():
|
||||
assert results.items == [TestModel(id="1", name="Test"), TestModel(id="2", name="Test")]
|
||||
|
||||
|
||||
def test_sqlite_service_can_list_with_pagination_and_offset():
|
||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
||||
def test_sqlite_service_can_list_with_pagination_and_offset(db: SqliteItemStorage[TestModel]):
|
||||
db.set(TestModel(id="1", name="Test"))
|
||||
db.set(TestModel(id="2", name="Test"))
|
||||
db.set(TestModel(id="3", name="Test"))
|
||||
@@ -90,8 +92,7 @@ def test_sqlite_service_can_list_with_pagination_and_offset():
|
||||
assert results.items == [TestModel(id="3", name="Test")]
|
||||
|
||||
|
||||
def test_sqlite_service_can_search():
|
||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
||||
def test_sqlite_service_can_search(db: SqliteItemStorage[TestModel]):
|
||||
db.set(TestModel(id="1", name="Test"))
|
||||
db.set(TestModel(id="2", name="Test"))
|
||||
db.set(TestModel(id="3", name="Test"))
|
||||
@@ -107,8 +108,7 @@ def test_sqlite_service_can_search():
|
||||
]
|
||||
|
||||
|
||||
def test_sqlite_service_can_search_with_pagination():
|
||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
||||
def test_sqlite_service_can_search_with_pagination(db: SqliteItemStorage[TestModel]):
|
||||
db.set(TestModel(id="1", name="Test"))
|
||||
db.set(TestModel(id="2", name="Test"))
|
||||
db.set(TestModel(id="3", name="Test"))
|
||||
@@ -120,8 +120,7 @@ def test_sqlite_service_can_search_with_pagination():
|
||||
assert results.items == [TestModel(id="1", name="Test"), TestModel(id="2", name="Test")]
|
||||
|
||||
|
||||
def test_sqlite_service_can_search_with_pagination_and_offset():
|
||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
||||
def test_sqlite_service_can_search_with_pagination_and_offset(db: SqliteItemStorage[TestModel]):
|
||||
db.set(TestModel(id="1", name="Test"))
|
||||
db.set(TestModel(id="2", name="Test"))
|
||||
db.set(TestModel(id="3", name="Test"))
|
||||
|
||||
Reference in New Issue
Block a user