mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-23 17:18:02 -05:00
Compare commits
47 Commits
test
...
feat/restr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
09609cd553 | ||
|
|
0aedd6d9f0 | ||
|
|
70a1202deb | ||
|
|
9a1aea9caf | ||
|
|
388d36b839 | ||
|
|
bedb35af8c | ||
|
|
dc232438fb | ||
|
|
d7edf5aaad | ||
|
|
3ad1226d1e | ||
|
|
86ca9f122d | ||
|
|
2c6772f92f | ||
|
|
e6c1e03b8b | ||
|
|
c9d95e5758 | ||
|
|
10755718b8 | ||
|
|
459c7b3b74 | ||
|
|
353719f81d | ||
|
|
bd4b260c23 | ||
|
|
3e389d3f60 | ||
|
|
ffb01f1345 | ||
|
|
faa0a8236c | ||
|
|
e4d73d3659 | ||
|
|
6994783c17 | ||
|
|
3f9708f166 | ||
|
|
bcf0d8a590 | ||
|
|
2060ee22f2 | ||
|
|
3fd79b837f | ||
|
|
1c099e0abb | ||
|
|
95cca9493c | ||
|
|
779c902402 | ||
|
|
99e6bb48ba | ||
|
|
c3d6ff5b11 | ||
|
|
bba962b82f | ||
|
|
78b8cfede3 | ||
|
|
e9879b9e1f | ||
|
|
e21f3af5ab | ||
|
|
2ab7c5f783 | ||
|
|
8bbd938be9 | ||
|
|
b4cee46936 | ||
|
|
48626c40fd | ||
|
|
a1001b6d10 | ||
|
|
50df641e1b | ||
|
|
22dd64dfa4 | ||
|
|
0a929ca3de | ||
|
|
15cabc4968 | ||
|
|
21d5969942 | ||
|
|
334dcf71c4 | ||
|
|
52274087f3 |
@@ -1,35 +1,35 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
import sqlite3
|
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
|
|
||||||
from invokeai.app.services.board_image_record_storage import SqliteBoardImageRecordStorage
|
|
||||||
from invokeai.app.services.board_images import BoardImagesService, BoardImagesServiceDependencies
|
|
||||||
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
|
|
||||||
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
|
||||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
|
||||||
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
|
||||||
from invokeai.app.services.resource_name import SimpleNameService
|
|
||||||
from invokeai.app.services.session_processor.session_processor_default import DefaultSessionProcessor
|
|
||||||
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
|
||||||
from invokeai.app.services.urls import LocalUrlService
|
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from invokeai.version.invokeai_version import __version__
|
from invokeai.version.invokeai_version import __version__
|
||||||
|
|
||||||
from ..services.default_graphs import create_system_graphs
|
from ..services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
|
||||||
from ..services.graph import GraphExecutionState, LibraryGraph
|
from ..services.board_images.board_images_default import BoardImagesService
|
||||||
from ..services.image_file_storage import DiskImageFileStorage
|
from ..services.board_records.board_records_sqlite import SqliteBoardRecordStorage
|
||||||
from ..services.invocation_queue import MemoryInvocationQueue
|
from ..services.boards.boards_default import BoardService
|
||||||
|
from ..services.config import InvokeAIAppConfig
|
||||||
|
from ..services.image_files.image_files_disk import DiskImageFileStorage
|
||||||
|
from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
|
||||||
|
from ..services.images.images_default import ImageService
|
||||||
|
from ..services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
||||||
|
from ..services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor
|
||||||
|
from ..services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
|
||||||
from ..services.invocation_services import InvocationServices
|
from ..services.invocation_services import InvocationServices
|
||||||
from ..services.invocation_stats import InvocationStatsService
|
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||||
from ..services.invoker import Invoker
|
from ..services.invoker import Invoker
|
||||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
from ..services.item_storage.item_storage_sqlite import SqliteItemStorage
|
||||||
from ..services.model_manager_service import ModelManagerService
|
from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
|
||||||
from ..services.processor import DefaultInvocationProcessor
|
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
|
||||||
from ..services.sqlite import SqliteItemStorage
|
from ..services.model_manager.model_manager_default import ModelManagerService
|
||||||
from ..services.thread import lock
|
from ..services.names.names_default import SimpleNameService
|
||||||
|
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
|
||||||
|
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||||
|
from ..services.shared.default_graphs import create_system_graphs
|
||||||
|
from ..services.shared.graph import GraphExecutionState, LibraryGraph
|
||||||
|
from ..services.shared.sqlite import SqliteDatabase
|
||||||
|
from ..services.urls.urls_default import LocalUrlService
|
||||||
from .events import FastAPIEventService
|
from .events import FastAPIEventService
|
||||||
|
|
||||||
|
|
||||||
@@ -63,100 +63,64 @@ class ApiDependencies:
|
|||||||
logger.info(f"Root directory = {str(config.root_path)}")
|
logger.info(f"Root directory = {str(config.root_path)}")
|
||||||
logger.debug(f"Internet connectivity is {config.internet_available}")
|
logger.debug(f"Internet connectivity is {config.internet_available}")
|
||||||
|
|
||||||
events = FastAPIEventService(event_handler_id)
|
|
||||||
|
|
||||||
output_folder = config.output_path
|
output_folder = config.output_path
|
||||||
|
|
||||||
# TODO: build a file/path manager?
|
db = SqliteDatabase(config, logger)
|
||||||
if config.use_memory_db:
|
|
||||||
db_location = ":memory:"
|
|
||||||
else:
|
|
||||||
db_path = config.db_path
|
|
||||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
db_location = str(db_path)
|
|
||||||
|
|
||||||
logger.info(f"Using database at {db_location}")
|
configuration = config
|
||||||
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
|
logger = logger
|
||||||
|
|
||||||
if config.log_sql:
|
board_image_records = SqliteBoardImageRecordStorage(db=db)
|
||||||
db_conn.set_trace_callback(print)
|
board_images = BoardImagesService()
|
||||||
db_conn.execute("PRAGMA foreign_keys = ON;")
|
board_records = SqliteBoardRecordStorage(db=db)
|
||||||
|
boards = BoardService()
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
events = FastAPIEventService(event_handler_id)
|
||||||
conn=db_conn, table_name="graph_executions", lock=lock
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
|
||||||
)
|
graph_library = SqliteItemStorage[LibraryGraph](db=db, table_name="graphs")
|
||||||
|
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||||
urls = LocalUrlService()
|
image_records = SqliteImageRecordStorage(db=db)
|
||||||
image_record_storage = SqliteImageRecordStorage(conn=db_conn, lock=lock)
|
images = ImageService()
|
||||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||||
names = SimpleNameService()
|
|
||||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
||||||
|
model_manager = ModelManagerService(config, logger)
|
||||||
board_record_storage = SqliteBoardRecordStorage(conn=db_conn, lock=lock)
|
names = SimpleNameService()
|
||||||
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn, lock=lock)
|
performance_statistics = InvocationStatsService()
|
||||||
|
processor = DefaultInvocationProcessor()
|
||||||
boards = BoardService(
|
queue = MemoryInvocationQueue()
|
||||||
services=BoardServiceDependencies(
|
session_processor = DefaultSessionProcessor()
|
||||||
board_image_record_storage=board_image_record_storage,
|
session_queue = SqliteSessionQueue(db=db)
|
||||||
board_record_storage=board_record_storage,
|
urls = LocalUrlService()
|
||||||
image_record_storage=image_record_storage,
|
|
||||||
url=urls,
|
|
||||||
logger=logger,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
board_images = BoardImagesService(
|
|
||||||
services=BoardImagesServiceDependencies(
|
|
||||||
board_image_record_storage=board_image_record_storage,
|
|
||||||
board_record_storage=board_record_storage,
|
|
||||||
image_record_storage=image_record_storage,
|
|
||||||
url=urls,
|
|
||||||
logger=logger,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
images = ImageService(
|
|
||||||
services=ImageServiceDependencies(
|
|
||||||
board_image_record_storage=board_image_record_storage,
|
|
||||||
image_record_storage=image_record_storage,
|
|
||||||
image_file_storage=image_file_storage,
|
|
||||||
url=urls,
|
|
||||||
logger=logger,
|
|
||||||
names=names,
|
|
||||||
graph_execution_manager=graph_execution_manager,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
model_manager=ModelManagerService(config, logger),
|
board_image_records=board_image_records,
|
||||||
events=events,
|
|
||||||
latents=latents,
|
|
||||||
images=images,
|
|
||||||
boards=boards,
|
|
||||||
board_images=board_images,
|
board_images=board_images,
|
||||||
queue=MemoryInvocationQueue(),
|
board_records=board_records,
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, lock=lock, table_name="graphs"),
|
boards=boards,
|
||||||
|
configuration=configuration,
|
||||||
|
events=events,
|
||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
processor=DefaultInvocationProcessor(),
|
graph_library=graph_library,
|
||||||
configuration=config,
|
image_files=image_files,
|
||||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
image_records=image_records,
|
||||||
|
images=images,
|
||||||
|
invocation_cache=invocation_cache,
|
||||||
|
latents=latents,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
session_queue=SqliteSessionQueue(conn=db_conn, lock=lock),
|
model_manager=model_manager,
|
||||||
session_processor=DefaultSessionProcessor(),
|
names=names,
|
||||||
invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size),
|
performance_statistics=performance_statistics,
|
||||||
|
processor=processor,
|
||||||
|
queue=queue,
|
||||||
|
session_processor=session_processor,
|
||||||
|
session_queue=session_queue,
|
||||||
|
urls=urls,
|
||||||
)
|
)
|
||||||
|
|
||||||
create_system_graphs(services.graph_library)
|
create_system_graphs(services.graph_library)
|
||||||
|
|
||||||
ApiDependencies.invoker = Invoker(services)
|
ApiDependencies.invoker = Invoker(services)
|
||||||
|
|
||||||
try:
|
db.clean()
|
||||||
lock.acquire()
|
|
||||||
db_conn.execute("VACUUM;")
|
|
||||||
db_conn.commit()
|
|
||||||
logger.info("Cleaned database")
|
|
||||||
finally:
|
|
||||||
lock.release()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def shutdown():
|
def shutdown():
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from typing import Any
|
|||||||
|
|
||||||
from fastapi_events.dispatcher import dispatch
|
from fastapi_events.dispatcher import dispatch
|
||||||
|
|
||||||
from ..services.events import EventServiceBase
|
from ..services.events.events_base import EventServiceBase
|
||||||
|
|
||||||
|
|
||||||
class FastAPIEventService(EventServiceBase):
|
class FastAPIEventService(EventServiceBase):
|
||||||
|
|||||||
@@ -4,9 +4,9 @@ from fastapi import Body, HTTPException, Path, Query
|
|||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.app.services.board_record_storage import BoardChanges
|
from invokeai.app.services.board_records.board_records_common import BoardChanges
|
||||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
from invokeai.app.services.boards.boards_common import BoardDTO
|
||||||
from invokeai.app.services.models.board_record import BoardDTO
|
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||||
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
|
|||||||
@@ -8,9 +8,9 @@ from PIL import Image
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.app.invocations.metadata import ImageMetadata
|
from invokeai.app.invocations.metadata import ImageMetadata
|
||||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
|
||||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
|
||||||
from invokeai.app.services.models.image_record import ImageDTO, ImageRecordChanges, ImageUrlsDTO
|
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||||
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
@@ -42,7 +42,7 @@ async def upload_image(
|
|||||||
crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
|
crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
"""Uploads an image"""
|
"""Uploads an image"""
|
||||||
if not file.content_type.startswith("image"):
|
if not file.content_type or not file.content_type.startswith("image"):
|
||||||
raise HTTPException(status_code=415, detail="Not an image")
|
raise HTTPException(status_code=415, detail="Not an image")
|
||||||
|
|
||||||
contents = await file.read()
|
contents = await file.read()
|
||||||
|
|||||||
@@ -2,11 +2,11 @@
|
|||||||
|
|
||||||
|
|
||||||
import pathlib
|
import pathlib
|
||||||
from typing import List, Literal, Optional, Union
|
from typing import Annotated, List, Literal, Optional, Union
|
||||||
|
|
||||||
from fastapi import Body, Path, Query, Response
|
from fastapi import Body, Path, Query, Response
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from pydantic import BaseModel, parse_obj_as
|
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||||
from starlette.exceptions import HTTPException
|
from starlette.exceptions import HTTPException
|
||||||
|
|
||||||
from invokeai.backend import BaseModelType, ModelType
|
from invokeai.backend import BaseModelType, ModelType
|
||||||
@@ -23,8 +23,14 @@ from ..dependencies import ApiDependencies
|
|||||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||||
|
|
||||||
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
|
update_models_response_adapter = TypeAdapter(UpdateModelResponse)
|
||||||
|
|
||||||
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
|
import_models_response_adapter = TypeAdapter(ImportModelResponse)
|
||||||
|
|
||||||
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
|
convert_models_response_adapter = TypeAdapter(ConvertModelResponse)
|
||||||
|
|
||||||
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
|
|
||||||
@@ -32,6 +38,11 @@ ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
|||||||
class ModelsList(BaseModel):
|
class ModelsList(BaseModel):
|
||||||
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
||||||
|
|
||||||
|
model_config = ConfigDict(use_enum_values=True)
|
||||||
|
|
||||||
|
|
||||||
|
models_list_adapter = TypeAdapter(ModelsList)
|
||||||
|
|
||||||
|
|
||||||
@models_router.get(
|
@models_router.get(
|
||||||
"/",
|
"/",
|
||||||
@@ -49,7 +60,7 @@ async def list_models(
|
|||||||
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
|
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
|
||||||
else:
|
else:
|
||||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
|
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
|
||||||
models = parse_obj_as(ModelsList, {"models": models_raw})
|
models = models_list_adapter.validate_python({"models": models_raw})
|
||||||
return models
|
return models
|
||||||
|
|
||||||
|
|
||||||
@@ -105,11 +116,14 @@ async def update_model(
|
|||||||
info.path = new_info.get("path")
|
info.path = new_info.get("path")
|
||||||
|
|
||||||
# replace empty string values with None/null to avoid phenomenon of vae: ''
|
# replace empty string values with None/null to avoid phenomenon of vae: ''
|
||||||
info_dict = info.dict()
|
info_dict = info.model_dump()
|
||||||
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
|
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
|
||||||
|
|
||||||
ApiDependencies.invoker.services.model_manager.update_model(
|
ApiDependencies.invoker.services.model_manager.update_model(
|
||||||
model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info_dict
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
model_attributes=info_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
@@ -117,7 +131,7 @@ async def update_model(
|
|||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
)
|
)
|
||||||
model_response = parse_obj_as(UpdateModelResponse, model_raw)
|
model_response = update_models_response_adapter.validate_python(model_raw)
|
||||||
except ModelNotFoundException as e:
|
except ModelNotFoundException as e:
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
@@ -152,13 +166,15 @@ async def import_model(
|
|||||||
) -> ImportModelResponse:
|
) -> ImportModelResponse:
|
||||||
"""Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""
|
"""Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""
|
||||||
|
|
||||||
|
location = location.strip("\"' ")
|
||||||
items_to_import = {location}
|
items_to_import = {location}
|
||||||
prediction_types = {x.value: x for x in SchedulerPredictionType}
|
prediction_types = {x.value: x for x in SchedulerPredictionType}
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
try:
|
try:
|
||||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||||
items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type)
|
items_to_import=items_to_import,
|
||||||
|
prediction_type_helper=lambda x: prediction_types.get(prediction_type),
|
||||||
)
|
)
|
||||||
info = installed_models.get(location)
|
info = installed_models.get(location)
|
||||||
|
|
||||||
@@ -170,7 +186,7 @@ async def import_model(
|
|||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
model_name=info.name, base_model=info.base_model, model_type=info.model_type
|
model_name=info.name, base_model=info.base_model, model_type=info.model_type
|
||||||
)
|
)
|
||||||
return parse_obj_as(ImportModelResponse, model_raw)
|
return import_models_response_adapter.validate_python(model_raw)
|
||||||
|
|
||||||
except ModelNotFoundException as e:
|
except ModelNotFoundException as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
@@ -204,13 +220,18 @@ async def add_model(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
ApiDependencies.invoker.services.model_manager.add_model(
|
ApiDependencies.invoker.services.model_manager.add_model(
|
||||||
info.model_name, info.base_model, info.model_type, model_attributes=info.dict()
|
info.model_name,
|
||||||
|
info.base_model,
|
||||||
|
info.model_type,
|
||||||
|
model_attributes=info.model_dump(),
|
||||||
)
|
)
|
||||||
logger.info(f"Successfully added {info.model_name}")
|
logger.info(f"Successfully added {info.model_name}")
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
model_name=info.model_name, base_model=info.base_model, model_type=info.model_type
|
model_name=info.model_name,
|
||||||
|
base_model=info.base_model,
|
||||||
|
model_type=info.model_type,
|
||||||
)
|
)
|
||||||
return parse_obj_as(ImportModelResponse, model_raw)
|
return import_models_response_adapter.validate_python(model_raw)
|
||||||
except ModelNotFoundException as e:
|
except ModelNotFoundException as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
@@ -222,7 +243,10 @@ async def add_model(
|
|||||||
@models_router.delete(
|
@models_router.delete(
|
||||||
"/{base_model}/{model_type}/{model_name}",
|
"/{base_model}/{model_type}/{model_name}",
|
||||||
operation_id="del_model",
|
operation_id="del_model",
|
||||||
responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
|
responses={
|
||||||
|
204: {"description": "Model deleted successfully"},
|
||||||
|
404: {"description": "Model not found"},
|
||||||
|
},
|
||||||
status_code=204,
|
status_code=204,
|
||||||
response_model=None,
|
response_model=None,
|
||||||
)
|
)
|
||||||
@@ -278,7 +302,7 @@ async def convert_model(
|
|||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
model_name, base_model=base_model, model_type=model_type
|
model_name, base_model=base_model, model_type=model_type
|
||||||
)
|
)
|
||||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
response = convert_models_response_adapter.validate_python(model_raw)
|
||||||
except ModelNotFoundException as e:
|
except ModelNotFoundException as e:
|
||||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
|
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
@@ -301,7 +325,8 @@ async def search_for_models(
|
|||||||
) -> List[pathlib.Path]:
|
) -> List[pathlib.Path]:
|
||||||
if not search_path.is_dir():
|
if not search_path.is_dir():
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory"
|
status_code=404,
|
||||||
|
detail=f"The search path '{search_path}' does not exist or is not directory",
|
||||||
)
|
)
|
||||||
return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)
|
return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)
|
||||||
|
|
||||||
@@ -336,6 +361,26 @@ async def sync_to_config() -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
# There's some weird pydantic-fastapi behaviour that requires this to be a separate class
|
||||||
|
# TODO: After a few updates, see if it works inside the route operation handler?
|
||||||
|
class MergeModelsBody(BaseModel):
|
||||||
|
model_names: List[str] = Field(description="model name", min_length=2, max_length=3)
|
||||||
|
merged_model_name: Optional[str] = Field(description="Name of destination model")
|
||||||
|
alpha: Optional[float] = Field(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5)
|
||||||
|
interp: Optional[MergeInterpolationMethod] = Field(description="Interpolation method")
|
||||||
|
force: Optional[bool] = Field(
|
||||||
|
description="Force merging of models created with different versions of diffusers",
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
merge_dest_directory: Optional[str] = Field(
|
||||||
|
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
@models_router.put(
|
@models_router.put(
|
||||||
"/merge/{base_model}",
|
"/merge/{base_model}",
|
||||||
operation_id="merge_models",
|
operation_id="merge_models",
|
||||||
@@ -348,31 +393,23 @@ async def sync_to_config() -> bool:
|
|||||||
response_model=MergeModelResponse,
|
response_model=MergeModelResponse,
|
||||||
)
|
)
|
||||||
async def merge_models(
|
async def merge_models(
|
||||||
|
body: Annotated[MergeModelsBody, Body(description="Model configuration", embed=True)],
|
||||||
base_model: BaseModelType = Path(description="Base model"),
|
base_model: BaseModelType = Path(description="Base model"),
|
||||||
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
|
|
||||||
merged_model_name: Optional[str] = Body(description="Name of destination model"),
|
|
||||||
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
|
||||||
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
|
|
||||||
force: Optional[bool] = Body(
|
|
||||||
description="Force merging of models created with different versions of diffusers", default=False
|
|
||||||
),
|
|
||||||
merge_dest_directory: Optional[str] = Body(
|
|
||||||
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
|
||||||
default=None,
|
|
||||||
),
|
|
||||||
) -> MergeModelResponse:
|
) -> MergeModelResponse:
|
||||||
"""Convert a checkpoint model into a diffusers model"""
|
"""Convert a checkpoint model into a diffusers model"""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
try:
|
try:
|
||||||
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
logger.info(
|
||||||
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
f"Merging models: {body.model_names} into {body.merge_dest_directory or '<MODELS>'}/{body.merged_model_name}"
|
||||||
|
)
|
||||||
|
dest = pathlib.Path(body.merge_dest_directory) if body.merge_dest_directory else None
|
||||||
result = ApiDependencies.invoker.services.model_manager.merge_models(
|
result = ApiDependencies.invoker.services.model_manager.merge_models(
|
||||||
model_names,
|
model_names=body.model_names,
|
||||||
base_model,
|
base_model=base_model,
|
||||||
merged_model_name=merged_model_name or "+".join(model_names),
|
merged_model_name=body.merged_model_name or "+".join(body.model_names),
|
||||||
alpha=alpha,
|
alpha=body.alpha,
|
||||||
interp=interp,
|
interp=body.interp,
|
||||||
force=force,
|
force=body.force,
|
||||||
merge_dest_directory=dest,
|
merge_dest_directory=dest,
|
||||||
)
|
)
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
@@ -380,9 +417,12 @@ async def merge_models(
|
|||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=ModelType.Main,
|
model_type=ModelType.Main,
|
||||||
)
|
)
|
||||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
response = convert_models_response_adapter.validate_python(model_raw)
|
||||||
except ModelNotFoundException:
|
except ModelNotFoundException:
|
||||||
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"One or more of the models '{body.model_names}' not found",
|
||||||
|
)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
return response
|
return response
|
||||||
|
|||||||
@@ -18,9 +18,9 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
|||||||
SessionQueueItemDTO,
|
SessionQueueItemDTO,
|
||||||
SessionQueueStatus,
|
SessionQueueStatus,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.shared.models import CursorPaginatedResults
|
from invokeai.app.services.shared.graph import Graph
|
||||||
|
from invokeai.app.services.shared.pagination import CursorPaginatedResults
|
||||||
|
|
||||||
from ...services.graph import Graph
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
session_queue_router = APIRouter(prefix="/v1/queue", tags=["queue"])
|
session_queue_router = APIRouter(prefix="/v1/queue", tags=["queue"])
|
||||||
|
|||||||
@@ -6,11 +6,12 @@ from fastapi import Body, HTTPException, Path, Query, Response
|
|||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
|
|
||||||
|
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||||
|
|
||||||
# Importing * is bad karma but needed here for node detection
|
# Importing * is bad karma but needed here for node detection
|
||||||
from ...invocations import * # noqa: F401 F403
|
from ...invocations import * # noqa: F401 F403
|
||||||
from ...invocations.baseinvocation import BaseInvocation
|
from ...invocations.baseinvocation import BaseInvocation
|
||||||
from ...services.graph import Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
|
from ...services.shared.graph import Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
|
||||||
from ...services.item_storage import PaginatedResults
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
|
session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
|
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
|
||||||
from fastapi import Body
|
from fastapi import Body
|
||||||
@@ -27,6 +27,7 @@ async def parse_dynamicprompts(
|
|||||||
combinatorial: bool = Body(default=True, description="Whether to use the combinatorial generator"),
|
combinatorial: bool = Body(default=True, description="Whether to use the combinatorial generator"),
|
||||||
) -> DynamicPromptsResponse:
|
) -> DynamicPromptsResponse:
|
||||||
"""Creates a batch process"""
|
"""Creates a batch process"""
|
||||||
|
generator: Union[RandomPromptGenerator, CombinatorialPromptGenerator]
|
||||||
try:
|
try:
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
if combinatorial:
|
if combinatorial:
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from fastapi_events.handlers.local import local_handler
|
|||||||
from fastapi_events.typing import Event
|
from fastapi_events.typing import Event
|
||||||
from socketio import ASGIApp, AsyncServer
|
from socketio import ASGIApp, AsyncServer
|
||||||
|
|
||||||
from ..services.events import EventServiceBase
|
from ..services.events.events_base import EventServiceBase
|
||||||
|
|
||||||
|
|
||||||
class SocketIO:
|
class SocketIO:
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
|||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from fastapi_events.handlers.local import local_handler
|
from fastapi_events.handlers.local import local_handler
|
||||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||||
from pydantic.schema import schema
|
from pydantic.json_schema import models_json_schema
|
||||||
|
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||||
@@ -31,7 +31,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
|||||||
|
|
||||||
from ..backend.util.logging import InvokeAILogger
|
from ..backend.util.logging import InvokeAILogger
|
||||||
from .api.dependencies import ApiDependencies
|
from .api.dependencies import ApiDependencies
|
||||||
from .api.routers import app_info, board_images, boards, images, models, session_queue, sessions, utilities
|
from .api.routers import app_info, board_images, boards, images, models, session_queue, utilities
|
||||||
from .api.sockets import SocketIO
|
from .api.sockets import SocketIO
|
||||||
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
|
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
|
||||||
|
|
||||||
@@ -51,7 +51,7 @@ mimetypes.add_type("text/css", ".css")
|
|||||||
|
|
||||||
# Create the app
|
# Create the app
|
||||||
# TODO: create this all in a method so configuration/etc. can be passed in?
|
# TODO: create this all in a method so configuration/etc. can be passed in?
|
||||||
app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)
|
app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None, separate_input_output_schemas=False)
|
||||||
|
|
||||||
# Add event handler
|
# Add event handler
|
||||||
event_handler_id: int = id(app)
|
event_handler_id: int = id(app)
|
||||||
@@ -63,18 +63,18 @@ app.add_middleware(
|
|||||||
|
|
||||||
socket_io = SocketIO(app)
|
socket_io = SocketIO(app)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=app_config.allow_origins,
|
||||||
|
allow_credentials=app_config.allow_credentials,
|
||||||
|
allow_methods=app_config.allow_methods,
|
||||||
|
allow_headers=app_config.allow_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Add startup event to load dependencies
|
# Add startup event to load dependencies
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def startup_event():
|
async def startup_event():
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=app_config.allow_origins,
|
|
||||||
allow_credentials=app_config.allow_credentials,
|
|
||||||
allow_methods=app_config.allow_methods,
|
|
||||||
allow_headers=app_config.allow_headers,
|
|
||||||
)
|
|
||||||
|
|
||||||
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
|
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
|
||||||
|
|
||||||
|
|
||||||
@@ -85,12 +85,7 @@ async def shutdown_event():
|
|||||||
|
|
||||||
|
|
||||||
# Include all routers
|
# Include all routers
|
||||||
# TODO: REMOVE
|
# app.include_router(sessions.session_router, prefix="/api")
|
||||||
# app.include_router(
|
|
||||||
# invocation.invocation_router,
|
|
||||||
# prefix = '/api')
|
|
||||||
|
|
||||||
app.include_router(sessions.session_router, prefix="/api")
|
|
||||||
|
|
||||||
app.include_router(utilities.utilities_router, prefix="/api")
|
app.include_router(utilities.utilities_router, prefix="/api")
|
||||||
|
|
||||||
@@ -117,6 +112,7 @@ def custom_openapi():
|
|||||||
description="An API for invoking AI image operations",
|
description="An API for invoking AI image operations",
|
||||||
version="1.0.0",
|
version="1.0.0",
|
||||||
routes=app.routes,
|
routes=app.routes,
|
||||||
|
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add all outputs
|
# Add all outputs
|
||||||
@@ -127,29 +123,32 @@ def custom_openapi():
|
|||||||
output_type = signature(invoker.invoke).return_annotation
|
output_type = signature(invoker.invoke).return_annotation
|
||||||
output_types.add(output_type)
|
output_types.add(output_type)
|
||||||
|
|
||||||
output_schemas = schema(output_types, ref_prefix="#/components/schemas/")
|
output_schemas = models_json_schema(
|
||||||
for schema_key, output_schema in output_schemas["definitions"].items():
|
models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}"
|
||||||
output_schema["class"] = "output"
|
)
|
||||||
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
for schema_key, output_schema in output_schemas[1]["$defs"].items():
|
||||||
|
|
||||||
# TODO: note that we assume the schema_key here is the TYPE.__name__
|
# TODO: note that we assume the schema_key here is the TYPE.__name__
|
||||||
# This could break in some cases, figure out a better way to do it
|
# This could break in some cases, figure out a better way to do it
|
||||||
output_type_titles[schema_key] = output_schema["title"]
|
output_type_titles[schema_key] = output_schema["title"]
|
||||||
|
|
||||||
# Add Node Editor UI helper schemas
|
# Add Node Editor UI helper schemas
|
||||||
ui_config_schemas = schema([UIConfigBase, _InputField, _OutputField], ref_prefix="#/components/schemas/")
|
ui_config_schemas = models_json_schema(
|
||||||
for schema_key, ui_config_schema in ui_config_schemas["definitions"].items():
|
[(UIConfigBase, "serialization"), (_InputField, "serialization"), (_OutputField, "serialization")],
|
||||||
|
ref_template="#/components/schemas/{model}",
|
||||||
|
)
|
||||||
|
for schema_key, ui_config_schema in ui_config_schemas[1]["$defs"].items():
|
||||||
openapi_schema["components"]["schemas"][schema_key] = ui_config_schema
|
openapi_schema["components"]["schemas"][schema_key] = ui_config_schema
|
||||||
|
|
||||||
# Add a reference to the output type to additionalProperties of the invoker schema
|
# Add a reference to the output type to additionalProperties of the invoker schema
|
||||||
for invoker in all_invocations:
|
for invoker in all_invocations:
|
||||||
invoker_name = invoker.__name__
|
invoker_name = invoker.__name__
|
||||||
output_type = signature(invoker.invoke).return_annotation
|
output_type = signature(obj=invoker.invoke).return_annotation
|
||||||
output_type_title = output_type_titles[output_type.__name__]
|
output_type_title = output_type_titles[output_type.__name__]
|
||||||
invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
|
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
|
||||||
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
|
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
|
||||||
invoker_schema["output"] = outputs_ref
|
invoker_schema["output"] = outputs_ref
|
||||||
invoker_schema["class"] = "invocation"
|
invoker_schema["class"] = "invocation"
|
||||||
|
openapi_schema["components"]["schemas"][f"{output_type_title}"]["class"] = "output"
|
||||||
|
|
||||||
from invokeai.backend.model_management.models import get_model_config_enums
|
from invokeai.backend.model_management.models import get_model_config_enums
|
||||||
|
|
||||||
@@ -172,7 +171,7 @@ def custom_openapi():
|
|||||||
return app.openapi_schema
|
return app.openapi_schema
|
||||||
|
|
||||||
|
|
||||||
app.openapi = custom_openapi
|
app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment
|
||||||
|
|
||||||
# Override API doc favicons
|
# Override API doc favicons
|
||||||
app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], "static/dream_web")), name="static")
|
app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], "static/dream_web")), name="static")
|
||||||
|
|||||||
@@ -24,8 +24,8 @@ def add_field_argument(command_parser, name: str, field, default_override=None):
|
|||||||
if field.default_factory is None
|
if field.default_factory is None
|
||||||
else field.default_factory()
|
else field.default_factory()
|
||||||
)
|
)
|
||||||
if get_origin(field.type_) == Literal:
|
if get_origin(field.annotation) == Literal:
|
||||||
allowed_values = get_args(field.type_)
|
allowed_values = get_args(field.annotation)
|
||||||
allowed_types = set()
|
allowed_types = set()
|
||||||
for val in allowed_values:
|
for val in allowed_values:
|
||||||
allowed_types.add(type(val))
|
allowed_types.add(type(val))
|
||||||
@@ -38,15 +38,15 @@ def add_field_argument(command_parser, name: str, field, default_override=None):
|
|||||||
type=field_type,
|
type=field_type,
|
||||||
default=default,
|
default=default,
|
||||||
choices=allowed_values,
|
choices=allowed_values,
|
||||||
help=field.field_info.description,
|
help=field.description,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
command_parser.add_argument(
|
command_parser.add_argument(
|
||||||
f"--{name}",
|
f"--{name}",
|
||||||
dest=name,
|
dest=name,
|
||||||
type=field.type_,
|
type=field.annotation,
|
||||||
default=default,
|
default=default,
|
||||||
help=field.field_info.description,
|
help=field.description,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -142,7 +142,6 @@ class BaseCommand(ABC, BaseModel):
|
|||||||
"""A CLI command"""
|
"""A CLI command"""
|
||||||
|
|
||||||
# All commands must include a type name like this:
|
# All commands must include a type name like this:
|
||||||
# type: Literal['your_command_name'] = 'your_command_name'
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_all_subclasses(cls):
|
def get_all_subclasses(cls):
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import validator
|
from pydantic import ValidationInfo, field_validator
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import IntegerCollectionOutput
|
from invokeai.app.invocations.primitives import IntegerCollectionOutput
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
@@ -20,9 +20,9 @@ class RangeInvocation(BaseInvocation):
|
|||||||
stop: int = InputField(default=10, description="The stop of the range")
|
stop: int = InputField(default=10, description="The stop of the range")
|
||||||
step: int = InputField(default=1, description="The step of the range")
|
step: int = InputField(default=1, description="The step of the range")
|
||||||
|
|
||||||
@validator("stop")
|
@field_validator("stop")
|
||||||
def stop_gt_start(cls, v, values):
|
def stop_gt_start(cls, v: int, info: ValidationInfo):
|
||||||
if "start" in values and v <= values["start"]:
|
if "start" in info.data and v <= info.data["start"]:
|
||||||
raise ValueError("stop must be greater than start")
|
raise ValueError("stop must be greater than start")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from compel import Compel, ReturnedEmbeddingsType
|
from compel import Compel, ReturnedEmbeddingsType
|
||||||
@@ -43,7 +43,13 @@ class ConditioningFieldData:
|
|||||||
# PerpNeg = "perp_neg"
|
# PerpNeg = "perp_neg"
|
||||||
|
|
||||||
|
|
||||||
@invocation("compel", title="Prompt", tags=["prompt", "compel"], category="conditioning", version="1.0.0")
|
@invocation(
|
||||||
|
"compel",
|
||||||
|
title="Prompt",
|
||||||
|
tags=["prompt", "compel"],
|
||||||
|
category="conditioning",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class CompelInvocation(BaseInvocation):
|
class CompelInvocation(BaseInvocation):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
|
|
||||||
@@ -60,23 +66,21 @@ class CompelInvocation(BaseInvocation):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.get_model(
|
||||||
**self.clip.tokenizer.dict(),
|
**self.clip.tokenizer.model_dump(),
|
||||||
context=context,
|
|
||||||
)
|
)
|
||||||
text_encoder_info = context.services.model_manager.get_model(
|
text_encoder_info = context.get_model(
|
||||||
**self.clip.text_encoder.dict(),
|
**self.clip.text_encoder.model_dump(),
|
||||||
context=context,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in self.clip.loras:
|
for lora in self.clip.loras:
|
||||||
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
|
lora_info = context.get_model(**lora.model_dump(exclude={"weight"}))
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
|
|
||||||
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
# loras = [(context.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||||
|
|
||||||
ti_list = []
|
ti_list = []
|
||||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
||||||
@@ -85,11 +89,10 @@ class CompelInvocation(BaseInvocation):
|
|||||||
ti_list.append(
|
ti_list.append(
|
||||||
(
|
(
|
||||||
name,
|
name,
|
||||||
context.services.model_manager.get_model(
|
context.get_model(
|
||||||
model_name=name,
|
model_name=name,
|
||||||
base_model=self.clip.text_encoder.base_model,
|
base_model=self.clip.text_encoder.base_model,
|
||||||
model_type=ModelType.TextualInversion,
|
model_type=ModelType.TextualInversion,
|
||||||
context=context,
|
|
||||||
).context.model,
|
).context.model,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -118,7 +121,7 @@ class CompelInvocation(BaseInvocation):
|
|||||||
|
|
||||||
conjunction = Compel.parse_prompt_string(self.prompt)
|
conjunction = Compel.parse_prompt_string(self.prompt)
|
||||||
|
|
||||||
if context.services.configuration.log_tokenization:
|
if context.config.log_tokenization:
|
||||||
log_tokenization_for_conjunction(conjunction, tokenizer)
|
log_tokenization_for_conjunction(conjunction, tokenizer)
|
||||||
|
|
||||||
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
||||||
@@ -139,8 +142,7 @@ class CompelInvocation(BaseInvocation):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
conditioning_name = context.save_conditioning(conditioning_data)
|
||||||
context.services.latents.save(conditioning_name, conditioning_data)
|
|
||||||
|
|
||||||
return ConditioningOutput(
|
return ConditioningOutput(
|
||||||
conditioning=ConditioningField(
|
conditioning=ConditioningField(
|
||||||
@@ -160,11 +162,11 @@ class SDXLPromptInvocationBase:
|
|||||||
zero_on_empty: bool,
|
zero_on_empty: bool,
|
||||||
):
|
):
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
**clip_field.tokenizer.dict(),
|
**clip_field.tokenizer.model_dump(),
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
text_encoder_info = context.services.model_manager.get_model(
|
text_encoder_info = context.services.model_manager.get_model(
|
||||||
**clip_field.text_encoder.dict(),
|
**clip_field.text_encoder.model_dump(),
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -172,7 +174,11 @@ class SDXLPromptInvocationBase:
|
|||||||
if prompt == "" and zero_on_empty:
|
if prompt == "" and zero_on_empty:
|
||||||
cpu_text_encoder = text_encoder_info.context.model
|
cpu_text_encoder = text_encoder_info.context.model
|
||||||
c = torch.zeros(
|
c = torch.zeros(
|
||||||
(1, cpu_text_encoder.config.max_position_embeddings, cpu_text_encoder.config.hidden_size),
|
(
|
||||||
|
1,
|
||||||
|
cpu_text_encoder.config.max_position_embeddings,
|
||||||
|
cpu_text_encoder.config.hidden_size,
|
||||||
|
),
|
||||||
dtype=text_encoder_info.context.cache.precision,
|
dtype=text_encoder_info.context.cache.precision,
|
||||||
)
|
)
|
||||||
if get_pooled:
|
if get_pooled:
|
||||||
@@ -186,7 +192,9 @@ class SDXLPromptInvocationBase:
|
|||||||
|
|
||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in clip_field.loras:
|
for lora in clip_field.loras:
|
||||||
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
|
lora_info = context.services.model_manager.get_model(
|
||||||
|
**lora.model_dump(exclude={"weight"}), context=context
|
||||||
|
)
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
@@ -273,8 +281,16 @@ class SDXLPromptInvocationBase:
|
|||||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
|
|
||||||
prompt: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
|
prompt: str = InputField(
|
||||||
style: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
|
default="",
|
||||||
|
description=FieldDescriptions.compel_prompt,
|
||||||
|
ui_component=UIComponent.Textarea,
|
||||||
|
)
|
||||||
|
style: str = InputField(
|
||||||
|
default="",
|
||||||
|
description=FieldDescriptions.compel_prompt,
|
||||||
|
ui_component=UIComponent.Textarea,
|
||||||
|
)
|
||||||
original_width: int = InputField(default=1024, description="")
|
original_width: int = InputField(default=1024, description="")
|
||||||
original_height: int = InputField(default=1024, description="")
|
original_height: int = InputField(default=1024, description="")
|
||||||
crop_top: int = InputField(default=0, description="")
|
crop_top: int = InputField(default=0, description="")
|
||||||
@@ -310,7 +326,9 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
[
|
[
|
||||||
c1,
|
c1,
|
||||||
torch.zeros(
|
torch.zeros(
|
||||||
(c1.shape[0], c2.shape[1] - c1.shape[1], c1.shape[2]), device=c1.device, dtype=c1.dtype
|
(c1.shape[0], c2.shape[1] - c1.shape[1], c1.shape[2]),
|
||||||
|
device=c1.device,
|
||||||
|
dtype=c1.dtype,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
dim=1,
|
dim=1,
|
||||||
@@ -321,7 +339,9 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
[
|
[
|
||||||
c2,
|
c2,
|
||||||
torch.zeros(
|
torch.zeros(
|
||||||
(c2.shape[0], c1.shape[1] - c2.shape[1], c2.shape[2]), device=c2.device, dtype=c2.dtype
|
(c2.shape[0], c1.shape[1] - c2.shape[1], c2.shape[2]),
|
||||||
|
device=c2.device,
|
||||||
|
dtype=c2.dtype,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
dim=1,
|
dim=1,
|
||||||
@@ -359,7 +379,9 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
|||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
|
|
||||||
style: str = InputField(
|
style: str = InputField(
|
||||||
default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea
|
default="",
|
||||||
|
description=FieldDescriptions.compel_prompt,
|
||||||
|
ui_component=UIComponent.Textarea,
|
||||||
) # TODO: ?
|
) # TODO: ?
|
||||||
original_width: int = InputField(default=1024, description="")
|
original_width: int = InputField(default=1024, description="")
|
||||||
original_height: int = InputField(default=1024, description="")
|
original_height: int = InputField(default=1024, description="")
|
||||||
@@ -403,10 +425,16 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
|||||||
class ClipSkipInvocationOutput(BaseInvocationOutput):
|
class ClipSkipInvocationOutput(BaseInvocationOutput):
|
||||||
"""Clip skip node output"""
|
"""Clip skip node output"""
|
||||||
|
|
||||||
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||||
|
|
||||||
|
|
||||||
@invocation("clip_skip", title="CLIP Skip", tags=["clipskip", "clip", "skip"], category="conditioning", version="1.0.0")
|
@invocation(
|
||||||
|
"clip_skip",
|
||||||
|
title="CLIP Skip",
|
||||||
|
tags=["clipskip", "clip", "skip"],
|
||||||
|
category="conditioning",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ClipSkipInvocation(BaseInvocation):
|
class ClipSkipInvocation(BaseInvocation):
|
||||||
"""Skip layers in clip text_encoder model."""
|
"""Skip layers in clip text_encoder model."""
|
||||||
|
|
||||||
@@ -421,7 +449,9 @@ class ClipSkipInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
def get_max_token_count(
|
def get_max_token_count(
|
||||||
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False
|
tokenizer,
|
||||||
|
prompt: Union[FlattenedPrompt, Blend, Conjunction],
|
||||||
|
truncate_if_too_long=False,
|
||||||
) -> int:
|
) -> int:
|
||||||
if type(prompt) is Blend:
|
if type(prompt) is Blend:
|
||||||
blend: Blend = prompt
|
blend: Blend = prompt
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
# initial implementation by Gregg Helt, 2023
|
# initial implementation by Gregg Helt, 2023
|
||||||
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
||||||
from builtins import bool, float
|
from builtins import bool, float
|
||||||
from typing import Dict, List, Literal, Optional, Union
|
from typing import Dict, List, Literal, Union
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -24,12 +24,12 @@ from controlnet_aux import (
|
|||||||
)
|
)
|
||||||
from controlnet_aux.util import HWC3, ade_palette
|
from controlnet_aux.util import HWC3, ade_palette
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pydantic import BaseModel, Field, validator
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||||
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||||
|
|
||||||
from ...backend.model_management import BaseModelType
|
from ...backend.model_management import BaseModelType
|
||||||
from ..models.image import ImageCategory, ResourceOrigin
|
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
@@ -57,6 +57,8 @@ class ControlNetModelField(BaseModel):
|
|||||||
model_name: str = Field(description="Name of the ControlNet model")
|
model_name: str = Field(description="Name of the ControlNet model")
|
||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
class ControlField(BaseModel):
|
class ControlField(BaseModel):
|
||||||
image: ImageField = Field(description="The control image")
|
image: ImageField = Field(description="The control image")
|
||||||
@@ -71,7 +73,7 @@ class ControlField(BaseModel):
|
|||||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
||||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||||
|
|
||||||
@validator("control_weight")
|
@field_validator("control_weight")
|
||||||
def validate_control_weight(cls, v):
|
def validate_control_weight(cls, v):
|
||||||
"""Validate that all control weights in the valid range"""
|
"""Validate that all control weights in the valid range"""
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
@@ -124,9 +126,7 @@ class ControlNetInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
# This invocation exists for other invocations to subclass it - do not register with @invocation!
|
||||||
"image_processor", title="Base Image Processor", tags=["controlnet"], category="controlnet", version="1.0.0"
|
|
||||||
)
|
|
||||||
class ImageProcessorInvocation(BaseInvocation):
|
class ImageProcessorInvocation(BaseInvocation):
|
||||||
"""Base class for invocations that preprocess images for ControlNet"""
|
"""Base class for invocations that preprocess images for ControlNet"""
|
||||||
|
|
||||||
@@ -393,9 +393,9 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
|
|
||||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
h: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
|
h: int = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
|
||||||
w: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
|
w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
|
||||||
f: Optional[int] = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
|
f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
content_shuffle_processor = ContentShuffleDetector()
|
content_shuffle_processor = ContentShuffleDetector()
|
||||||
@@ -575,14 +575,14 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
|
|
||||||
def run_processor(self, image: Image.Image):
|
def run_processor(self, image: Image.Image):
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
image = np.array(image, dtype=np.uint8)
|
np_image = np.array(image, dtype=np.uint8)
|
||||||
height, width = image.shape[:2]
|
height, width = np_image.shape[:2]
|
||||||
|
|
||||||
width_tile_size = min(self.color_map_tile_size, width)
|
width_tile_size = min(self.color_map_tile_size, width)
|
||||||
height_tile_size = min(self.color_map_tile_size, height)
|
height_tile_size = min(self.color_map_tile_size, height)
|
||||||
|
|
||||||
color_map = cv2.resize(
|
color_map = cv2.resize(
|
||||||
image,
|
np_image,
|
||||||
(width // width_tile_size, height // height_tile_size),
|
(width // width_tile_size, height // height_tile_size),
|
||||||
interpolation=cv2.INTER_CUBIC,
|
interpolation=cv2.INTER_CUBIC,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import numpy
|
|||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import numpy as np
|
|||||||
from mediapipe.python.solutions.face_mesh import FaceMesh # type: ignore[import]
|
from mediapipe.python.solutions.face_mesh import FaceMesh # type: ignore[import]
|
||||||
from PIL import Image, ImageDraw, ImageFilter, ImageFont, ImageOps
|
from PIL import Image, ImageDraw, ImageFilter, ImageFont, ImageOps
|
||||||
from PIL.Image import Image as ImageType
|
from PIL.Image import Image as ImageType
|
||||||
from pydantic import validator
|
from pydantic import field_validator
|
||||||
|
|
||||||
import invokeai.assets.fonts as font_assets
|
import invokeai.assets.fonts as font_assets
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
@@ -20,7 +20,7 @@ from invokeai.app.invocations.baseinvocation import (
|
|||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("face_mask_output")
|
@invocation_output("face_mask_output")
|
||||||
@@ -550,7 +550,7 @@ class FaceMaskInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
invert_mask: bool = InputField(default=False, description="Toggle to invert the mask")
|
invert_mask: bool = InputField(default=False, description="Toggle to invert the mask")
|
||||||
|
|
||||||
@validator("face_ids")
|
@field_validator("face_ids")
|
||||||
def validate_comma_separated_ints(cls, v) -> str:
|
def validate_comma_separated_ints(cls, v) -> str:
|
||||||
comma_separated_ints_regex = re.compile(r"^\d*(,\d+)*$")
|
comma_separated_ints_regex = re.compile(r"^\d*(,\d+)*$")
|
||||||
if comma_separated_ints_regex.match(v) is None:
|
if comma_separated_ints_regex.match(v) is None:
|
||||||
|
|||||||
@@ -9,10 +9,10 @@ from PIL import Image, ImageChops, ImageFilter, ImageOps
|
|||||||
|
|
||||||
from invokeai.app.invocations.metadata import CoreMetadata
|
from invokeai.app.invocations.metadata import CoreMetadata
|
||||||
from invokeai.app.invocations.primitives import BoardField, ColorField, ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import BoardField, ColorField, ImageField, ImageOutput
|
||||||
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||||
|
|
||||||
from ..models.image import ImageCategory, ResourceOrigin
|
|
||||||
from .baseinvocation import BaseInvocation, FieldDescriptions, Input, InputField, InvocationContext, invocation
|
from .baseinvocation import BaseInvocation, FieldDescriptions, Input, InputField, InvocationContext, invocation
|
||||||
|
|
||||||
|
|
||||||
@@ -36,7 +36,13 @@ class ShowImageInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("blank_image", title="Blank Image", tags=["image"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"blank_image",
|
||||||
|
title="Blank Image",
|
||||||
|
tags=["image"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class BlankImageInvocation(BaseInvocation):
|
class BlankImageInvocation(BaseInvocation):
|
||||||
"""Creates a blank image and forwards it to the pipeline"""
|
"""Creates a blank image and forwards it to the pipeline"""
|
||||||
|
|
||||||
@@ -65,7 +71,13 @@ class BlankImageInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"img_crop",
|
||||||
|
title="Crop Image",
|
||||||
|
tags=["image", "crop"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ImageCropInvocation(BaseInvocation):
|
class ImageCropInvocation(BaseInvocation):
|
||||||
"""Crops an image to a specified box. The box can be outside of the image."""
|
"""Crops an image to a specified box. The box can be outside of the image."""
|
||||||
|
|
||||||
@@ -98,7 +110,13 @@ class ImageCropInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image", version="1.0.1")
|
@invocation(
|
||||||
|
"img_paste",
|
||||||
|
title="Paste Image",
|
||||||
|
tags=["image", "paste"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.1",
|
||||||
|
)
|
||||||
class ImagePasteInvocation(BaseInvocation):
|
class ImagePasteInvocation(BaseInvocation):
|
||||||
"""Pastes an image into another image."""
|
"""Pastes an image into another image."""
|
||||||
|
|
||||||
@@ -151,7 +169,13 @@ class ImagePasteInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"tomask",
|
||||||
|
title="Mask from Alpha",
|
||||||
|
tags=["image", "mask"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class MaskFromAlphaInvocation(BaseInvocation):
|
class MaskFromAlphaInvocation(BaseInvocation):
|
||||||
"""Extracts the alpha channel of an image as a mask."""
|
"""Extracts the alpha channel of an image as a mask."""
|
||||||
|
|
||||||
@@ -182,7 +206,13 @@ class MaskFromAlphaInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"img_mul",
|
||||||
|
title="Multiply Images",
|
||||||
|
tags=["image", "multiply"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ImageMultiplyInvocation(BaseInvocation):
|
class ImageMultiplyInvocation(BaseInvocation):
|
||||||
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
||||||
|
|
||||||
@@ -215,7 +245,13 @@ class ImageMultiplyInvocation(BaseInvocation):
|
|||||||
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"img_chan",
|
||||||
|
title="Extract Image Channel",
|
||||||
|
tags=["image", "channel"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ImageChannelInvocation(BaseInvocation):
|
class ImageChannelInvocation(BaseInvocation):
|
||||||
"""Gets a channel from an image."""
|
"""Gets a channel from an image."""
|
||||||
|
|
||||||
@@ -247,7 +283,13 @@ class ImageChannelInvocation(BaseInvocation):
|
|||||||
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
|
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"img_conv",
|
||||||
|
title="Convert Image Mode",
|
||||||
|
tags=["image", "convert"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ImageConvertInvocation(BaseInvocation):
|
class ImageConvertInvocation(BaseInvocation):
|
||||||
"""Converts an image to a different mode."""
|
"""Converts an image to a different mode."""
|
||||||
|
|
||||||
@@ -276,7 +318,13 @@ class ImageConvertInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"img_blur",
|
||||||
|
title="Blur Image",
|
||||||
|
tags=["image", "blur"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ImageBlurInvocation(BaseInvocation):
|
class ImageBlurInvocation(BaseInvocation):
|
||||||
"""Blurs an image"""
|
"""Blurs an image"""
|
||||||
|
|
||||||
@@ -330,7 +378,13 @@ PIL_RESAMPLING_MAP = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"img_resize",
|
||||||
|
title="Resize Image",
|
||||||
|
tags=["image", "resize"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ImageResizeInvocation(BaseInvocation):
|
class ImageResizeInvocation(BaseInvocation):
|
||||||
"""Resizes an image to specific dimensions"""
|
"""Resizes an image to specific dimensions"""
|
||||||
|
|
||||||
@@ -343,7 +397,7 @@ class ImageResizeInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.get_image(self.image.image_name)
|
||||||
|
|
||||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||||
|
|
||||||
@@ -352,25 +406,22 @@ class ImageResizeInvocation(BaseInvocation):
|
|||||||
resample=resample_mode,
|
resample=resample_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_name = context.save_image(image=resize_image)
|
||||||
image=resize_image,
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
metadata=self.metadata.dict() if self.metadata else None,
|
|
||||||
workflow=self.workflow,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(image_name=image_dto.image_name),
|
image=ImageField(image_name=image_name),
|
||||||
width=image_dto.width,
|
width=resize_image.width,
|
||||||
height=image_dto.height,
|
height=resize_image.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"img_scale",
|
||||||
|
title="Scale Image",
|
||||||
|
tags=["image", "scale"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ImageScaleInvocation(BaseInvocation):
|
class ImageScaleInvocation(BaseInvocation):
|
||||||
"""Scales an image by a factor"""
|
"""Scales an image by a factor"""
|
||||||
|
|
||||||
@@ -411,7 +462,13 @@ class ImageScaleInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"img_lerp",
|
||||||
|
title="Lerp Image",
|
||||||
|
tags=["image", "lerp"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ImageLerpInvocation(BaseInvocation):
|
class ImageLerpInvocation(BaseInvocation):
|
||||||
"""Linear interpolation of all pixels of an image"""
|
"""Linear interpolation of all pixels of an image"""
|
||||||
|
|
||||||
@@ -444,7 +501,13 @@ class ImageLerpInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"img_ilerp",
|
||||||
|
title="Inverse Lerp Image",
|
||||||
|
tags=["image", "ilerp"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ImageInverseLerpInvocation(BaseInvocation):
|
class ImageInverseLerpInvocation(BaseInvocation):
|
||||||
"""Inverse linear interpolation of all pixels of an image"""
|
"""Inverse linear interpolation of all pixels of an image"""
|
||||||
|
|
||||||
@@ -456,7 +519,7 @@ class ImageInverseLerpInvocation(BaseInvocation):
|
|||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
image_arr = numpy.asarray(image, dtype=numpy.float32)
|
image_arr = numpy.asarray(image, dtype=numpy.float32)
|
||||||
image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255
|
image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255 # type: ignore [assignment]
|
||||||
|
|
||||||
ilerp_image = Image.fromarray(numpy.uint8(image_arr))
|
ilerp_image = Image.fromarray(numpy.uint8(image_arr))
|
||||||
|
|
||||||
@@ -477,7 +540,13 @@ class ImageInverseLerpInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"img_nsfw",
|
||||||
|
title="Blur NSFW Image",
|
||||||
|
tags=["image", "nsfw"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ImageNSFWBlurInvocation(BaseInvocation):
|
class ImageNSFWBlurInvocation(BaseInvocation):
|
||||||
"""Add blur to NSFW-flagged images"""
|
"""Add blur to NSFW-flagged images"""
|
||||||
|
|
||||||
@@ -505,7 +574,7 @@ class ImageNSFWBlurInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
metadata=self.metadata.dict() if self.metadata else None,
|
metadata=self.metadata.model_dump() if self.metadata else None,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -515,7 +584,7 @@ class ImageNSFWBlurInvocation(BaseInvocation):
|
|||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_caution_img(self) -> Image:
|
def _get_caution_img(self) -> Image.Image:
|
||||||
import invokeai.app.assets.images as image_assets
|
import invokeai.app.assets.images as image_assets
|
||||||
|
|
||||||
caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
|
caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
|
||||||
@@ -523,7 +592,11 @@ class ImageNSFWBlurInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"img_watermark", title="Add Invisible Watermark", tags=["image", "watermark"], category="image", version="1.0.0"
|
"img_watermark",
|
||||||
|
title="Add Invisible Watermark",
|
||||||
|
tags=["image", "watermark"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class ImageWatermarkInvocation(BaseInvocation):
|
class ImageWatermarkInvocation(BaseInvocation):
|
||||||
"""Add an invisible watermark to an image"""
|
"""Add an invisible watermark to an image"""
|
||||||
@@ -544,7 +617,7 @@ class ImageWatermarkInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
metadata=self.metadata.dict() if self.metadata else None,
|
metadata=self.metadata.model_dump() if self.metadata else None,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -555,7 +628,13 @@ class ImageWatermarkInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"mask_edge",
|
||||||
|
title="Mask Edge",
|
||||||
|
tags=["image", "mask", "inpaint"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class MaskEdgeInvocation(BaseInvocation):
|
class MaskEdgeInvocation(BaseInvocation):
|
||||||
"""Applies an edge mask to an image"""
|
"""Applies an edge mask to an image"""
|
||||||
|
|
||||||
@@ -601,7 +680,11 @@ class MaskEdgeInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"mask_combine", title="Combine Masks", tags=["image", "mask", "multiply"], category="image", version="1.0.0"
|
"mask_combine",
|
||||||
|
title="Combine Masks",
|
||||||
|
tags=["image", "mask", "multiply"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class MaskCombineInvocation(BaseInvocation):
|
class MaskCombineInvocation(BaseInvocation):
|
||||||
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
|
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
|
||||||
@@ -632,7 +715,13 @@ class MaskCombineInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"color_correct",
|
||||||
|
title="Color Correct",
|
||||||
|
tags=["image", "color"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ColorCorrectInvocation(BaseInvocation):
|
class ColorCorrectInvocation(BaseInvocation):
|
||||||
"""
|
"""
|
||||||
Shifts the colors of a target image to match the reference image, optionally
|
Shifts the colors of a target image to match the reference image, optionally
|
||||||
@@ -742,7 +831,13 @@ class ColorCorrectInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image", version="1.0.0")
|
@invocation(
|
||||||
|
"img_hue_adjust",
|
||||||
|
title="Adjust Image Hue",
|
||||||
|
tags=["image", "hue"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ImageHueAdjustmentInvocation(BaseInvocation):
|
class ImageHueAdjustmentInvocation(BaseInvocation):
|
||||||
"""Adjusts the Hue of an image."""
|
"""Adjusts the Hue of an image."""
|
||||||
|
|
||||||
@@ -980,7 +1075,7 @@ class SaveImageInvocation(BaseInvocation):
|
|||||||
|
|
||||||
image: ImageField = InputField(description=FieldDescriptions.image)
|
image: ImageField = InputField(description=FieldDescriptions.image)
|
||||||
board: Optional[BoardField] = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct)
|
board: Optional[BoardField] = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct)
|
||||||
metadata: CoreMetadata = InputField(
|
metadata: Optional[CoreMetadata] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description=FieldDescriptions.core_metadata,
|
description=FieldDescriptions.core_metadata,
|
||||||
ui_hidden=True,
|
ui_hidden=True,
|
||||||
@@ -997,7 +1092,7 @@ class SaveImageInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
metadata=self.metadata.dict() if self.metadata else None,
|
metadata=self.metadata.model_dump() if self.metadata else None,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -7,12 +7,12 @@ import numpy as np
|
|||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
|
||||||
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
|
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
|
||||||
from invokeai.backend.image_util.lama import LaMA
|
from invokeai.backend.image_util.lama import LaMA
|
||||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||||
|
|
||||||
from ..models.image import ImageCategory, ResourceOrigin
|
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||||
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
|
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import os
|
|||||||
from builtins import float
|
from builtins import float
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
@@ -25,11 +25,15 @@ class IPAdapterModelField(BaseModel):
|
|||||||
model_name: str = Field(description="Name of the IP-Adapter model")
|
model_name: str = Field(description="Name of the IP-Adapter model")
|
||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
class CLIPVisionModelField(BaseModel):
|
class CLIPVisionModelField(BaseModel):
|
||||||
model_name: str = Field(description="Name of the CLIP Vision image encoder model")
|
model_name: str = Field(description="Name of the CLIP Vision image encoder model")
|
||||||
base_model: BaseModelType = Field(description="Base model (usually 'Any')")
|
base_model: BaseModelType = Field(description="Base model (usually 'Any')")
|
||||||
|
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
class IPAdapterField(BaseModel):
|
class IPAdapterField(BaseModel):
|
||||||
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from diffusers.models.attention_processor import (
|
|||||||
)
|
)
|
||||||
from diffusers.schedulers import DPMSolverSDEScheduler
|
from diffusers.schedulers import DPMSolverSDEScheduler
|
||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
from pydantic import validator
|
from pydantic import field_validator
|
||||||
from torchvision.transforms.functional import resize as tv_resize
|
from torchvision.transforms.functional import resize as tv_resize
|
||||||
|
|
||||||
from invokeai.app.invocations.ip_adapter import IPAdapterField
|
from invokeai.app.invocations.ip_adapter import IPAdapterField
|
||||||
@@ -34,6 +34,7 @@ from invokeai.app.invocations.primitives import (
|
|||||||
build_latents_output,
|
build_latents_output,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||||
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||||
@@ -54,7 +55,6 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
|
|||||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
from ...backend.util.devices import choose_precision, choose_torch_device
|
from ...backend.util.devices import choose_precision, choose_torch_device
|
||||||
from ..models.image import ImageCategory, ResourceOrigin
|
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
@@ -84,12 +84,20 @@ class SchedulerOutput(BaseInvocationOutput):
|
|||||||
scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
|
scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
|
||||||
|
|
||||||
|
|
||||||
@invocation("scheduler", title="Scheduler", tags=["scheduler"], category="latents", version="1.0.0")
|
@invocation(
|
||||||
|
"scheduler",
|
||||||
|
title="Scheduler",
|
||||||
|
tags=["scheduler"],
|
||||||
|
category="latents",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class SchedulerInvocation(BaseInvocation):
|
class SchedulerInvocation(BaseInvocation):
|
||||||
"""Selects a scheduler."""
|
"""Selects a scheduler."""
|
||||||
|
|
||||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||||
default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler
|
default="euler",
|
||||||
|
description=FieldDescriptions.scheduler,
|
||||||
|
ui_type=UIType.Scheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> SchedulerOutput:
|
def invoke(self, context: InvocationContext) -> SchedulerOutput:
|
||||||
@@ -97,7 +105,11 @@ class SchedulerInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"create_denoise_mask", title="Create Denoise Mask", tags=["mask", "denoise"], category="latents", version="1.0.0"
|
"create_denoise_mask",
|
||||||
|
title="Create Denoise Mask",
|
||||||
|
tags=["mask", "denoise"],
|
||||||
|
category="latents",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class CreateDenoiseMaskInvocation(BaseInvocation):
|
class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||||
"""Creates mask for denoising model run."""
|
"""Creates mask for denoising model run."""
|
||||||
@@ -106,7 +118,11 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
|||||||
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
|
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
|
||||||
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
|
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
|
||||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
|
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
|
||||||
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32, ui_order=4)
|
fp32: bool = InputField(
|
||||||
|
default=DEFAULT_PRECISION == "float32",
|
||||||
|
description=FieldDescriptions.fp32,
|
||||||
|
ui_order=4,
|
||||||
|
)
|
||||||
|
|
||||||
def prep_mask_tensor(self, mask_image):
|
def prep_mask_tensor(self, mask_image):
|
||||||
if mask_image.mode != "L":
|
if mask_image.mode != "L":
|
||||||
@@ -134,7 +150,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
|||||||
|
|
||||||
if image is not None:
|
if image is not None:
|
||||||
vae_info = context.services.model_manager.get_model(
|
vae_info = context.services.model_manager.get_model(
|
||||||
**self.vae.vae.dict(),
|
**self.vae.vae.model_dump(),
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -166,9 +182,8 @@ def get_scheduler(
|
|||||||
seed: int,
|
seed: int,
|
||||||
) -> Scheduler:
|
) -> Scheduler:
|
||||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||||
orig_scheduler_info = context.services.model_manager.get_model(
|
orig_scheduler_info = context.get_model(
|
||||||
**scheduler_info.dict(),
|
**scheduler_info.model_dump(),
|
||||||
context=context,
|
|
||||||
)
|
)
|
||||||
with orig_scheduler_info as orig_scheduler:
|
with orig_scheduler_info as orig_scheduler:
|
||||||
scheduler_config = orig_scheduler.config
|
scheduler_config = orig_scheduler.config
|
||||||
@@ -209,34 +224,64 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
negative_conditioning: ConditioningField = InputField(
|
negative_conditioning: ConditioningField = InputField(
|
||||||
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
|
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
|
||||||
)
|
)
|
||||||
noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection, ui_order=3)
|
noise: Optional[LatentsField] = InputField(
|
||||||
|
default=None,
|
||||||
|
description=FieldDescriptions.noise,
|
||||||
|
input=Input.Connection,
|
||||||
|
ui_order=3,
|
||||||
|
)
|
||||||
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
||||||
cfg_scale: Union[float, List[float]] = InputField(
|
cfg_scale: Union[float, List[float]] = InputField(
|
||||||
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, title="CFG Scale"
|
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, title="CFG Scale"
|
||||||
)
|
)
|
||||||
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
|
denoising_start: float = InputField(
|
||||||
|
default=0.0,
|
||||||
|
ge=0,
|
||||||
|
le=1,
|
||||||
|
description=FieldDescriptions.denoising_start,
|
||||||
|
)
|
||||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||||
default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler
|
default="euler",
|
||||||
|
description=FieldDescriptions.scheduler,
|
||||||
|
ui_type=UIType.Scheduler,
|
||||||
)
|
)
|
||||||
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet", ui_order=2)
|
unet: UNetField = InputField(
|
||||||
control: Union[ControlField, list[ControlField]] = InputField(
|
description=FieldDescriptions.unet,
|
||||||
|
input=Input.Connection,
|
||||||
|
title="UNet",
|
||||||
|
ui_order=2,
|
||||||
|
)
|
||||||
|
control: Optional[Union[ControlField, list[ControlField]]] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
ui_order=5,
|
ui_order=5,
|
||||||
)
|
)
|
||||||
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]] = InputField(
|
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]] = InputField(
|
||||||
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection, ui_order=6
|
description=FieldDescriptions.ip_adapter,
|
||||||
|
title="IP-Adapter",
|
||||||
|
default=None,
|
||||||
|
input=Input.Connection,
|
||||||
|
ui_order=6,
|
||||||
)
|
)
|
||||||
t2i_adapter: Union[T2IAdapterField, list[T2IAdapterField]] = InputField(
|
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]] = InputField(
|
||||||
description=FieldDescriptions.t2i_adapter, title="T2I-Adapter", default=None, input=Input.Connection, ui_order=7
|
description=FieldDescriptions.t2i_adapter,
|
||||||
|
title="T2I-Adapter",
|
||||||
|
default=None,
|
||||||
|
input=Input.Connection,
|
||||||
|
ui_order=7,
|
||||||
|
)
|
||||||
|
latents: Optional[LatentsField] = InputField(
|
||||||
|
default=None, description=FieldDescriptions.latents, input=Input.Connection
|
||||||
)
|
)
|
||||||
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
|
||||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||||
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=8
|
default=None,
|
||||||
|
description=FieldDescriptions.mask,
|
||||||
|
input=Input.Connection,
|
||||||
|
ui_order=8,
|
||||||
)
|
)
|
||||||
|
|
||||||
@validator("cfg_scale")
|
@field_validator("cfg_scale")
|
||||||
def ge_one(cls, v):
|
def ge_one(cls, v):
|
||||||
"""validate that all cfg_scale values are >= 1"""
|
"""validate that all cfg_scale values are >= 1"""
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
@@ -252,15 +297,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
def dispatch_progress(
|
def dispatch_progress(
|
||||||
self,
|
self,
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
source_node_id: str,
|
|
||||||
intermediate_state: PipelineIntermediateState,
|
intermediate_state: PipelineIntermediateState,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
) -> None:
|
) -> None:
|
||||||
stable_diffusion_step_callback(
|
stable_diffusion_step_callback(
|
||||||
context=context,
|
context=context,
|
||||||
intermediate_state=intermediate_state,
|
intermediate_state=intermediate_state,
|
||||||
node=self.dict(),
|
|
||||||
source_node_id=source_node_id,
|
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -271,11 +313,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
unet,
|
unet,
|
||||||
seed,
|
seed,
|
||||||
) -> ConditioningData:
|
) -> ConditioningData:
|
||||||
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
positive_cond_data = context.get_conditioning(self.positive_conditioning.conditioning_name)
|
||||||
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||||
extra_conditioning_info = c.extra_conditioning
|
extra_conditioning_info = c.extra_conditioning
|
||||||
|
|
||||||
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
negative_cond_data = context.get_conditioning(self.negative_conditioning.conditioning_name)
|
||||||
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||||
|
|
||||||
conditioning_data = ConditioningData(
|
conditioning_data = ConditioningData(
|
||||||
@@ -362,17 +404,16 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
controlnet_data = []
|
controlnet_data = []
|
||||||
for control_info in control_list:
|
for control_info in control_list:
|
||||||
control_model = exit_stack.enter_context(
|
control_model = exit_stack.enter_context(
|
||||||
context.services.model_manager.get_model(
|
context.get_model(
|
||||||
model_name=control_info.control_model.model_name,
|
model_name=control_info.control_model.model_name,
|
||||||
model_type=ModelType.ControlNet,
|
model_type=ModelType.ControlNet,
|
||||||
base_model=control_info.control_model.base_model,
|
base_model=control_info.control_model.base_model,
|
||||||
context=context,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# control_models.append(control_model)
|
# control_models.append(control_model)
|
||||||
control_image_field = control_info.image
|
control_image_field = control_info.image
|
||||||
input_image = context.services.images.get_pil_image(control_image_field.image_name)
|
input_image = context.get_image(control_image_field.image_name)
|
||||||
# self.image.image_type, self.image.image_name
|
# self.image.image_type, self.image.image_name
|
||||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||||
# and add in batch_size, num_images_per_prompt?
|
# and add in batch_size, num_images_per_prompt?
|
||||||
@@ -430,30 +471,29 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
conditioning_data.ip_adapter_conditioning = []
|
conditioning_data.ip_adapter_conditioning = []
|
||||||
for single_ip_adapter in ip_adapter:
|
for single_ip_adapter in ip_adapter:
|
||||||
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
||||||
context.services.model_manager.get_model(
|
context.get_model(
|
||||||
model_name=single_ip_adapter.ip_adapter_model.model_name,
|
model_name=single_ip_adapter.ip_adapter_model.model_name,
|
||||||
model_type=ModelType.IPAdapter,
|
model_type=ModelType.IPAdapter,
|
||||||
base_model=single_ip_adapter.ip_adapter_model.base_model,
|
base_model=single_ip_adapter.ip_adapter_model.base_model,
|
||||||
context=context,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
image_encoder_model_info = context.services.model_manager.get_model(
|
image_encoder_model_info = context.get_model(
|
||||||
model_name=single_ip_adapter.image_encoder_model.model_name,
|
model_name=single_ip_adapter.image_encoder_model.model_name,
|
||||||
model_type=ModelType.CLIPVision,
|
model_type=ModelType.CLIPVision,
|
||||||
base_model=single_ip_adapter.image_encoder_model.base_model,
|
base_model=single_ip_adapter.image_encoder_model.base_model,
|
||||||
context=context,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
input_image = context.services.images.get_pil_image(single_ip_adapter.image.image_name)
|
input_image = context.get_image(single_ip_adapter.image.image_name)
|
||||||
|
|
||||||
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
|
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
|
||||||
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
||||||
with image_encoder_model_info as image_encoder_model:
|
with image_encoder_model_info as image_encoder_model:
|
||||||
# Get image embeddings from CLIP and ImageProjModel.
|
# Get image embeddings from CLIP and ImageProjModel.
|
||||||
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
|
(
|
||||||
input_image, image_encoder_model
|
image_prompt_embeds,
|
||||||
)
|
uncond_image_prompt_embeds,
|
||||||
|
) = ip_adapter_model.get_image_embeds(input_image, image_encoder_model)
|
||||||
conditioning_data.ip_adapter_conditioning.append(
|
conditioning_data.ip_adapter_conditioning.append(
|
||||||
IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)
|
IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)
|
||||||
)
|
)
|
||||||
@@ -488,13 +528,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
t2i_adapter_data = []
|
t2i_adapter_data = []
|
||||||
for t2i_adapter_field in t2i_adapter:
|
for t2i_adapter_field in t2i_adapter:
|
||||||
t2i_adapter_model_info = context.services.model_manager.get_model(
|
t2i_adapter_model_info = context.get_model(
|
||||||
model_name=t2i_adapter_field.t2i_adapter_model.model_name,
|
model_name=t2i_adapter_field.t2i_adapter_model.model_name,
|
||||||
model_type=ModelType.T2IAdapter,
|
model_type=ModelType.T2IAdapter,
|
||||||
base_model=t2i_adapter_field.t2i_adapter_model.base_model,
|
base_model=t2i_adapter_field.t2i_adapter_model.base_model,
|
||||||
context=context,
|
|
||||||
)
|
)
|
||||||
image = context.services.images.get_pil_image(t2i_adapter_field.image.image_name)
|
image = context.get_image(t2i_adapter_field.image.image_name)
|
||||||
|
|
||||||
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
||||||
if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1:
|
if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1:
|
||||||
@@ -604,11 +643,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
seed = None
|
seed = None
|
||||||
noise = None
|
noise = None
|
||||||
if self.noise is not None:
|
if self.noise is not None:
|
||||||
noise = context.services.latents.get(self.noise.latents_name)
|
noise = context.get_latents(self.noise.latents_name)
|
||||||
seed = self.noise.seed
|
seed = self.noise.seed
|
||||||
|
|
||||||
if self.latents is not None:
|
if self.latents is not None:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.get_latents(self.latents.latents_name)
|
||||||
if seed is None:
|
if seed is None:
|
||||||
seed = self.latents.seed
|
seed = self.latents.seed
|
||||||
|
|
||||||
@@ -628,29 +667,26 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
||||||
# below. Investigate whether this is appropriate.
|
# below. Investigate whether this is appropriate.
|
||||||
t2i_adapter_data = self.run_t2i_adapters(
|
t2i_adapter_data = self.run_t2i_adapters(
|
||||||
context, self.t2i_adapter, latents.shape, do_classifier_free_guidance=True
|
context,
|
||||||
|
self.t2i_adapter,
|
||||||
|
latents.shape,
|
||||||
|
do_classifier_free_guidance=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the source node id (we are invoking the prepared node)
|
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
|
||||||
|
|
||||||
def step_callback(state: PipelineIntermediateState):
|
def step_callback(state: PipelineIntermediateState):
|
||||||
self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model)
|
self.dispatch_progress(context, state, self.unet.unet.base_model)
|
||||||
|
|
||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in self.unet.loras:
|
for lora in self.unet.loras:
|
||||||
lora_info = context.services.model_manager.get_model(
|
lora_info = context.get_model(
|
||||||
**lora.dict(exclude={"weight"}),
|
**lora.model_dump(exclude={"weight"}),
|
||||||
context=context,
|
|
||||||
)
|
)
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(
|
unet_info = context.get_model(
|
||||||
**self.unet.unet.dict(),
|
**self.unet.unet.model_dump(),
|
||||||
context=context,
|
|
||||||
)
|
)
|
||||||
with (
|
with (
|
||||||
ExitStack() as exit_stack,
|
ExitStack() as exit_stack,
|
||||||
@@ -700,7 +736,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
denoising_end=self.denoising_end,
|
denoising_end=self.denoising_end,
|
||||||
)
|
)
|
||||||
|
|
||||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
(
|
||||||
|
result_latents,
|
||||||
|
result_attention_map_saver,
|
||||||
|
) = pipeline.latents_from_embeddings(
|
||||||
latents=latents,
|
latents=latents,
|
||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
init_timestep=init_timestep,
|
init_timestep=init_timestep,
|
||||||
@@ -722,13 +761,16 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
if choose_torch_device() == torch.device("mps"):
|
if choose_torch_device() == torch.device("mps"):
|
||||||
mps.empty_cache()
|
mps.empty_cache()
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
latents_name = context.save_latents(result_latents)
|
||||||
context.services.latents.save(name, result_latents)
|
return build_latents_output(latents_name=latents_name, latents=result_latents, seed=seed)
|
||||||
return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
|
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"l2i", title="Latents to Image", tags=["latents", "image", "vae", "l2i"], category="latents", version="1.0.0"
|
"l2i",
|
||||||
|
title="Latents to Image",
|
||||||
|
tags=["latents", "image", "vae", "l2i"],
|
||||||
|
category="latents",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class LatentsToImageInvocation(BaseInvocation):
|
class LatentsToImageInvocation(BaseInvocation):
|
||||||
"""Generates an image from latents."""
|
"""Generates an image from latents."""
|
||||||
@@ -743,7 +785,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
||||||
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
|
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
|
||||||
metadata: CoreMetadata = InputField(
|
metadata: Optional[CoreMetadata] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description=FieldDescriptions.core_metadata,
|
description=FieldDescriptions.core_metadata,
|
||||||
ui_hidden=True,
|
ui_hidden=True,
|
||||||
@@ -751,11 +793,10 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.get_latents(self.latents.latents_name)
|
||||||
|
|
||||||
vae_info = context.services.model_manager.get_model(
|
vae_info = context.get_model(
|
||||||
**self.vae.vae.dict(),
|
**self.vae.vae.model_dump(),
|
||||||
context=context,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
|
with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
|
||||||
@@ -785,7 +826,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
vae.to(dtype=torch.float16)
|
vae.to(dtype=torch.float16)
|
||||||
latents = latents.half()
|
latents = latents.half()
|
||||||
|
|
||||||
if self.tiled or context.services.configuration.tiled_decode:
|
if self.tiled or context.config.tiled_decode:
|
||||||
vae.enable_tiling()
|
vae.enable_tiling()
|
||||||
else:
|
else:
|
||||||
vae.disable_tiling()
|
vae.disable_tiling()
|
||||||
@@ -809,28 +850,25 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
if choose_torch_device() == torch.device("mps"):
|
if choose_torch_device() == torch.device("mps"):
|
||||||
mps.empty_cache()
|
mps.empty_cache()
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_name = context.save_image(image, category=context.categories.GENERAL)
|
||||||
image=image,
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
metadata=self.metadata.dict() if self.metadata else None,
|
|
||||||
workflow=self.workflow,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(image_name=image_dto.image_name),
|
image=ImageField(image_name=image_name),
|
||||||
width=image_dto.width,
|
width=image.width,
|
||||||
height=image_dto.height,
|
height=image.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
||||||
|
|
||||||
|
|
||||||
@invocation("lresize", title="Resize Latents", tags=["latents", "resize"], category="latents", version="1.0.0")
|
@invocation(
|
||||||
|
"lresize",
|
||||||
|
title="Resize Latents",
|
||||||
|
tags=["latents", "resize"],
|
||||||
|
category="latents",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ResizeLatentsInvocation(BaseInvocation):
|
class ResizeLatentsInvocation(BaseInvocation):
|
||||||
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
||||||
|
|
||||||
@@ -876,7 +914,13 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||||
|
|
||||||
|
|
||||||
@invocation("lscale", title="Scale Latents", tags=["latents", "resize"], category="latents", version="1.0.0")
|
@invocation(
|
||||||
|
"lscale",
|
||||||
|
title="Scale Latents",
|
||||||
|
tags=["latents", "resize"],
|
||||||
|
category="latents",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class ScaleLatentsInvocation(BaseInvocation):
|
class ScaleLatentsInvocation(BaseInvocation):
|
||||||
"""Scales latents by a given factor."""
|
"""Scales latents by a given factor."""
|
||||||
|
|
||||||
@@ -915,7 +959,11 @@ class ScaleLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"i2l", title="Image to Latents", tags=["latents", "image", "vae", "i2l"], category="latents", version="1.0.0"
|
"i2l",
|
||||||
|
title="Image to Latents",
|
||||||
|
tags=["latents", "image", "vae", "i2l"],
|
||||||
|
category="latents",
|
||||||
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class ImageToLatentsInvocation(BaseInvocation):
|
class ImageToLatentsInvocation(BaseInvocation):
|
||||||
"""Encodes an image into latents."""
|
"""Encodes an image into latents."""
|
||||||
@@ -979,7 +1027,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
vae_info = context.services.model_manager.get_model(
|
vae_info = context.services.model_manager.get_model(
|
||||||
**self.vae.vae.dict(),
|
**self.vae.vae.model_dump(),
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1007,7 +1055,13 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
return vae.encode(image_tensor).latents
|
return vae.encode(image_tensor).latents
|
||||||
|
|
||||||
|
|
||||||
@invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents", version="1.0.0")
|
@invocation(
|
||||||
|
"lblend",
|
||||||
|
title="Blend Latents",
|
||||||
|
tags=["latents", "blend"],
|
||||||
|
category="latents",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class BlendLatentsInvocation(BaseInvocation):
|
class BlendLatentsInvocation(BaseInvocation):
|
||||||
"""Blend two latents using a given alpha. Latents must have same size."""
|
"""Blend two latents using a given alpha. Latents must have same size."""
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import validator
|
from pydantic import field_validator
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput
|
from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput
|
||||||
|
|
||||||
@@ -72,7 +72,14 @@ class RandomIntInvocation(BaseInvocation):
|
|||||||
return IntegerOutput(value=np.random.randint(self.low, self.high))
|
return IntegerOutput(value=np.random.randint(self.low, self.high))
|
||||||
|
|
||||||
|
|
||||||
@invocation("rand_float", title="Random Float", tags=["math", "float", "random"], category="math", version="1.0.0")
|
@invocation(
|
||||||
|
"rand_float",
|
||||||
|
title="Random Float",
|
||||||
|
tags=["math", "float", "random"],
|
||||||
|
category="math",
|
||||||
|
version="1.0.1",
|
||||||
|
use_cache=False,
|
||||||
|
)
|
||||||
class RandomFloatInvocation(BaseInvocation):
|
class RandomFloatInvocation(BaseInvocation):
|
||||||
"""Outputs a single random float"""
|
"""Outputs a single random float"""
|
||||||
|
|
||||||
@@ -178,7 +185,7 @@ class IntegerMathInvocation(BaseInvocation):
|
|||||||
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||||
|
|
||||||
@validator("b")
|
@field_validator("b")
|
||||||
def no_unrepresentable_results(cls, v, values):
|
def no_unrepresentable_results(cls, v, values):
|
||||||
if values["operation"] == "DIV" and v == 0:
|
if values["operation"] == "DIV" and v == 0:
|
||||||
raise ValueError("Cannot divide by zero")
|
raise ValueError("Cannot divide by zero")
|
||||||
@@ -252,7 +259,7 @@ class FloatMathInvocation(BaseInvocation):
|
|||||||
a: float = InputField(default=0, description=FieldDescriptions.num_1)
|
a: float = InputField(default=0, description=FieldDescriptions.num_1)
|
||||||
b: float = InputField(default=0, description=FieldDescriptions.num_2)
|
b: float = InputField(default=0, description=FieldDescriptions.num_2)
|
||||||
|
|
||||||
@validator("b")
|
@field_validator("b")
|
||||||
def no_unrepresentable_results(cls, v, values):
|
def no_unrepresentable_results(cls, v, values):
|
||||||
if values["operation"] == "DIV" and v == 0:
|
if values["operation"] == "DIV" and v == 0:
|
||||||
raise ValueError("Cannot divide by zero")
|
raise ValueError("Cannot divide by zero")
|
||||||
|
|||||||
@@ -223,4 +223,4 @@ class MetadataAccumulatorInvocation(BaseInvocation):
|
|||||||
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
|
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
|
||||||
"""Collects and outputs a CoreMetadata object"""
|
"""Collects and outputs a CoreMetadata object"""
|
||||||
|
|
||||||
return MetadataAccumulatorOutput(metadata=CoreMetadata(**self.dict()))
|
return MetadataAccumulatorOutput(metadata=CoreMetadata(**self.model_dump()))
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import copy
|
import copy
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
@@ -24,6 +24,8 @@ class ModelInfo(BaseModel):
|
|||||||
model_type: ModelType = Field(description="Info to load submodel")
|
model_type: ModelType = Field(description="Info to load submodel")
|
||||||
submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
|
submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
|
||||||
|
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
class LoraInfo(ModelInfo):
|
class LoraInfo(ModelInfo):
|
||||||
weight: float = Field(description="Lora's weight which to use when apply to model")
|
weight: float = Field(description="Lora's weight which to use when apply to model")
|
||||||
@@ -65,6 +67,8 @@ class MainModelField(BaseModel):
|
|||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
model_type: ModelType = Field(description="Model Type")
|
model_type: ModelType = Field(description="Model Type")
|
||||||
|
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
class LoRAModelField(BaseModel):
|
class LoRAModelField(BaseModel):
|
||||||
"""LoRA model field"""
|
"""LoRA model field"""
|
||||||
@@ -72,8 +76,16 @@ class LoRAModelField(BaseModel):
|
|||||||
model_name: str = Field(description="Name of the LoRA model")
|
model_name: str = Field(description="Name of the LoRA model")
|
||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
@invocation("main_model_loader", title="Main Model", tags=["model"], category="model", version="1.0.0")
|
|
||||||
|
@invocation(
|
||||||
|
"main_model_loader",
|
||||||
|
title="Main Model",
|
||||||
|
tags=["model"],
|
||||||
|
category="model",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class MainModelLoaderInvocation(BaseInvocation):
|
class MainModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads a main model, outputting its submodels."""
|
"""Loads a main model, outputting its submodels."""
|
||||||
|
|
||||||
@@ -86,7 +98,7 @@ class MainModelLoaderInvocation(BaseInvocation):
|
|||||||
model_type = ModelType.Main
|
model_type = ModelType.Main
|
||||||
|
|
||||||
# TODO: not found exceptions
|
# TODO: not found exceptions
|
||||||
if not context.services.model_manager.model_exists(
|
if not context.model_exists(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
@@ -180,10 +192,16 @@ class LoraLoaderInvocation(BaseInvocation):
|
|||||||
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||||
unet: Optional[UNetField] = InputField(
|
unet: Optional[UNetField] = InputField(
|
||||||
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
|
default=None,
|
||||||
|
description=FieldDescriptions.unet,
|
||||||
|
input=Input.Connection,
|
||||||
|
title="UNet",
|
||||||
)
|
)
|
||||||
clip: Optional[ClipField] = InputField(
|
clip: Optional[ClipField] = InputField(
|
||||||
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP"
|
default=None,
|
||||||
|
description=FieldDescriptions.clip,
|
||||||
|
input=Input.Connection,
|
||||||
|
title="CLIP",
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
||||||
@@ -244,20 +262,35 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
|||||||
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
|
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
|
||||||
|
|
||||||
|
|
||||||
@invocation("sdxl_lora_loader", title="SDXL LoRA", tags=["lora", "model"], category="model", version="1.0.0")
|
@invocation(
|
||||||
|
"sdxl_lora_loader",
|
||||||
|
title="SDXL LoRA",
|
||||||
|
tags=["lora", "model"],
|
||||||
|
category="model",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class SDXLLoraLoaderInvocation(BaseInvocation):
|
class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||||
"""Apply selected lora to unet and text_encoder."""
|
"""Apply selected lora to unet and text_encoder."""
|
||||||
|
|
||||||
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||||
unet: Optional[UNetField] = InputField(
|
unet: Optional[UNetField] = InputField(
|
||||||
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
|
default=None,
|
||||||
|
description=FieldDescriptions.unet,
|
||||||
|
input=Input.Connection,
|
||||||
|
title="UNet",
|
||||||
)
|
)
|
||||||
clip: Optional[ClipField] = InputField(
|
clip: Optional[ClipField] = InputField(
|
||||||
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1"
|
default=None,
|
||||||
|
description=FieldDescriptions.clip,
|
||||||
|
input=Input.Connection,
|
||||||
|
title="CLIP 1",
|
||||||
)
|
)
|
||||||
clip2: Optional[ClipField] = InputField(
|
clip2: Optional[ClipField] = InputField(
|
||||||
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2"
|
default=None,
|
||||||
|
description=FieldDescriptions.clip,
|
||||||
|
input=Input.Connection,
|
||||||
|
title="CLIP 2",
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
|
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
|
||||||
@@ -330,6 +363,8 @@ class VAEModelField(BaseModel):
|
|||||||
model_name: str = Field(description="Name of the model")
|
model_name: str = Field(description="Name of the model")
|
||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("vae_loader_output")
|
@invocation_output("vae_loader_output")
|
||||||
class VaeLoaderOutput(BaseInvocationOutput):
|
class VaeLoaderOutput(BaseInvocationOutput):
|
||||||
@@ -343,7 +378,10 @@ class VaeLoaderInvocation(BaseInvocation):
|
|||||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||||
|
|
||||||
vae_model: VAEModelField = InputField(
|
vae_model: VAEModelField = InputField(
|
||||||
description=FieldDescriptions.vae_model, input=Input.Direct, ui_type=UIType.VaeModel, title="VAE"
|
description=FieldDescriptions.vae_model,
|
||||||
|
input=Input.Direct,
|
||||||
|
ui_type=UIType.VaeModel,
|
||||||
|
title="VAE",
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
|
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
|
||||||
@@ -372,19 +410,31 @@ class VaeLoaderInvocation(BaseInvocation):
|
|||||||
class SeamlessModeOutput(BaseInvocationOutput):
|
class SeamlessModeOutput(BaseInvocationOutput):
|
||||||
"""Modified Seamless Model output"""
|
"""Modified Seamless Model output"""
|
||||||
|
|
||||||
unet: Optional[UNetField] = OutputField(description=FieldDescriptions.unet, title="UNet")
|
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||||
vae: Optional[VaeField] = OutputField(description=FieldDescriptions.vae, title="VAE")
|
vae: Optional[VaeField] = OutputField(default=None, description=FieldDescriptions.vae, title="VAE")
|
||||||
|
|
||||||
|
|
||||||
@invocation("seamless", title="Seamless", tags=["seamless", "model"], category="model", version="1.0.0")
|
@invocation(
|
||||||
|
"seamless",
|
||||||
|
title="Seamless",
|
||||||
|
tags=["seamless", "model"],
|
||||||
|
category="model",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class SeamlessModeInvocation(BaseInvocation):
|
class SeamlessModeInvocation(BaseInvocation):
|
||||||
"""Applies the seamless transformation to the Model UNet and VAE."""
|
"""Applies the seamless transformation to the Model UNet and VAE."""
|
||||||
|
|
||||||
unet: Optional[UNetField] = InputField(
|
unet: Optional[UNetField] = InputField(
|
||||||
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
|
default=None,
|
||||||
|
description=FieldDescriptions.unet,
|
||||||
|
input=Input.Connection,
|
||||||
|
title="UNet",
|
||||||
)
|
)
|
||||||
vae: Optional[VaeField] = InputField(
|
vae: Optional[VaeField] = InputField(
|
||||||
default=None, description=FieldDescriptions.vae_model, input=Input.Connection, title="VAE"
|
default=None,
|
||||||
|
description=FieldDescriptions.vae_model,
|
||||||
|
input=Input.Connection,
|
||||||
|
title="VAE",
|
||||||
)
|
)
|
||||||
seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
|
seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
|
||||||
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
|
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import validator
|
from pydantic import field_validator
|
||||||
|
|
||||||
from invokeai.app.invocations.latent import LatentsField
|
from invokeai.app.invocations.latent import LatentsField
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
@@ -65,7 +65,7 @@ Nodes
|
|||||||
class NoiseOutput(BaseInvocationOutput):
|
class NoiseOutput(BaseInvocationOutput):
|
||||||
"""Invocation noise output"""
|
"""Invocation noise output"""
|
||||||
|
|
||||||
noise: LatentsField = OutputField(default=None, description=FieldDescriptions.noise)
|
noise: LatentsField = OutputField(description=FieldDescriptions.noise)
|
||||||
width: int = OutputField(description=FieldDescriptions.width)
|
width: int = OutputField(description=FieldDescriptions.width)
|
||||||
height: int = OutputField(description=FieldDescriptions.height)
|
height: int = OutputField(description=FieldDescriptions.height)
|
||||||
|
|
||||||
@@ -78,7 +78,13 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("noise", title="Noise", tags=["latents", "noise"], category="latents", version="1.0.0")
|
@invocation(
|
||||||
|
"noise",
|
||||||
|
title="Noise",
|
||||||
|
tags=["latents", "noise"],
|
||||||
|
category="latents",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class NoiseInvocation(BaseInvocation):
|
class NoiseInvocation(BaseInvocation):
|
||||||
"""Generates latent noise."""
|
"""Generates latent noise."""
|
||||||
|
|
||||||
@@ -105,7 +111,7 @@ class NoiseInvocation(BaseInvocation):
|
|||||||
description="Use CPU for noise generation (for reproducible results across platforms)",
|
description="Use CPU for noise generation (for reproducible results across platforms)",
|
||||||
)
|
)
|
||||||
|
|
||||||
@validator("seed", pre=True)
|
@field_validator("seed", mode="before")
|
||||||
def modulo_seed(cls, v):
|
def modulo_seed(cls, v):
|
||||||
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
|
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
|
||||||
return v % (SEED_MAX + 1)
|
return v % (SEED_MAX + 1)
|
||||||
@@ -118,6 +124,5 @@ class NoiseInvocation(BaseInvocation):
|
|||||||
seed=self.seed,
|
seed=self.seed,
|
||||||
use_cpu=self.use_cpu,
|
use_cpu=self.use_cpu,
|
||||||
)
|
)
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
latents_name = context.save_latents(noise)
|
||||||
context.services.latents.save(name, noise)
|
return build_noise_output(latents_name=latents_name, latents=noise, seed=self.seed)
|
||||||
return build_noise_output(latents_name=name, latents=noise, seed=self.seed)
|
|
||||||
|
|||||||
@@ -9,18 +9,18 @@ from typing import List, Literal, Optional, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from diffusers.image_processor import VaeImageProcessor
|
from diffusers.image_processor import VaeImageProcessor
|
||||||
from pydantic import BaseModel, Field, validator
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from invokeai.app.invocations.metadata import CoreMetadata
|
from invokeai.app.invocations.metadata import CoreMetadata
|
||||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
|
||||||
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
from invokeai.backend import BaseModelType, ModelType, SubModelType
|
from invokeai.backend import BaseModelType, ModelType, SubModelType
|
||||||
|
|
||||||
from ...backend.model_management import ONNXModelPatcher
|
from ...backend.model_management import ONNXModelPatcher
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from ...backend.util import choose_torch_device
|
from ...backend.util import choose_torch_device
|
||||||
from ..models.image import ImageCategory, ResourceOrigin
|
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
@@ -63,14 +63,17 @@ class ONNXPromptInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
**self.clip.tokenizer.dict(),
|
**self.clip.tokenizer.model_dump(),
|
||||||
)
|
)
|
||||||
text_encoder_info = context.services.model_manager.get_model(
|
text_encoder_info = context.services.model_manager.get_model(
|
||||||
**self.clip.text_encoder.dict(),
|
**self.clip.text_encoder.model_dump(),
|
||||||
)
|
)
|
||||||
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack:
|
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack:
|
||||||
loras = [
|
loras = [
|
||||||
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
(
|
||||||
|
context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model,
|
||||||
|
lora.weight,
|
||||||
|
)
|
||||||
for lora in self.clip.loras
|
for lora in self.clip.loras
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -175,14 +178,14 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
|||||||
description=FieldDescriptions.unet,
|
description=FieldDescriptions.unet,
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
)
|
)
|
||||||
control: Optional[Union[ControlField, list[ControlField]]] = InputField(
|
control: Union[ControlField, list[ControlField]] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description=FieldDescriptions.control,
|
description=FieldDescriptions.control,
|
||||||
)
|
)
|
||||||
# seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", )
|
# seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||||
# seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
# seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||||
|
|
||||||
@validator("cfg_scale")
|
@field_validator("cfg_scale")
|
||||||
def ge_one(cls, v):
|
def ge_one(cls, v):
|
||||||
"""validate that all cfg_scale values are >= 1"""
|
"""validate that all cfg_scale values are >= 1"""
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
@@ -241,7 +244,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
|||||||
stable_diffusion_step_callback(
|
stable_diffusion_step_callback(
|
||||||
context=context,
|
context=context,
|
||||||
intermediate_state=intermediate_state,
|
intermediate_state=intermediate_state,
|
||||||
node=self.dict(),
|
node=self.model_dump(),
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -254,12 +257,15 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
|||||||
eta=0.0,
|
eta=0.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
|
unet_info = context.services.model_manager.get_model(**self.unet.unet.model_dump())
|
||||||
|
|
||||||
with unet_info as unet: # , ExitStack() as stack:
|
with unet_info as unet: # , ExitStack() as stack:
|
||||||
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||||
loras = [
|
loras = [
|
||||||
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
(
|
||||||
|
context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model,
|
||||||
|
lora.weight,
|
||||||
|
)
|
||||||
for lora in self.unet.loras
|
for lora in self.unet.loras
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -346,7 +352,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
|
|||||||
raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}")
|
raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}")
|
||||||
|
|
||||||
vae_info = context.services.model_manager.get_model(
|
vae_info = context.services.model_manager.get_model(
|
||||||
**self.vae.vae.dict(),
|
**self.vae.vae.model_dump(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# clear memory as vae decode can request a lot
|
# clear memory as vae decode can request a lot
|
||||||
@@ -375,7 +381,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
metadata=self.metadata.dict() if self.metadata else None,
|
metadata=self.metadata.model_dump() if self.metadata else None,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -403,6 +409,8 @@ class OnnxModelField(BaseModel):
|
|||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
model_type: ModelType = Field(description="Model Type")
|
model_type: ModelType = Field(description="Model Type")
|
||||||
|
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model", version="1.0.0")
|
@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model", version="1.0.0")
|
||||||
class OnnxModelLoaderInvocation(BaseInvocation):
|
class OnnxModelLoaderInvocation(BaseInvocation):
|
||||||
|
|||||||
@@ -44,13 +44,22 @@ from invokeai.app.invocations.primitives import FloatCollectionOutput
|
|||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||||
|
|
||||||
|
|
||||||
@invocation("float_range", title="Float Range", tags=["math", "range"], category="math", version="1.0.0")
|
@invocation(
|
||||||
|
"float_range",
|
||||||
|
title="Float Range",
|
||||||
|
tags=["math", "range"],
|
||||||
|
category="math",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class FloatLinearRangeInvocation(BaseInvocation):
|
class FloatLinearRangeInvocation(BaseInvocation):
|
||||||
"""Creates a range"""
|
"""Creates a range"""
|
||||||
|
|
||||||
start: float = InputField(default=5, description="The first value of the range")
|
start: float = InputField(default=5, description="The first value of the range")
|
||||||
stop: float = InputField(default=10, description="The last value of the range")
|
stop: float = InputField(default=10, description="The last value of the range")
|
||||||
steps: int = InputField(default=30, description="number of values to interpolate over (including start and stop)")
|
steps: int = InputField(
|
||||||
|
default=30,
|
||||||
|
description="number of values to interpolate over (including start and stop)",
|
||||||
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||||
param_list = list(np.linspace(self.start, self.stop, self.steps))
|
param_list = list(np.linspace(self.start, self.stop, self.steps))
|
||||||
@@ -95,7 +104,13 @@ EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
|
|||||||
|
|
||||||
|
|
||||||
# actually I think for now could just use CollectionOutput (which is list[Any]
|
# actually I think for now could just use CollectionOutput (which is list[Any]
|
||||||
@invocation("step_param_easing", title="Step Param Easing", tags=["step", "easing"], category="step", version="1.0.0")
|
@invocation(
|
||||||
|
"step_param_easing",
|
||||||
|
title="Step Param Easing",
|
||||||
|
tags=["step", "easing"],
|
||||||
|
category="step",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class StepParamEasingInvocation(BaseInvocation):
|
class StepParamEasingInvocation(BaseInvocation):
|
||||||
"""Experimental per-step parameter easing for denoising steps"""
|
"""Experimental per-step parameter easing for denoising steps"""
|
||||||
|
|
||||||
@@ -159,7 +174,9 @@ class StepParamEasingInvocation(BaseInvocation):
|
|||||||
context.services.logger.debug("base easing duration: " + str(base_easing_duration))
|
context.services.logger.debug("base easing duration: " + str(base_easing_duration))
|
||||||
even_num_steps = num_easing_steps % 2 == 0 # even number of steps
|
even_num_steps = num_easing_steps % 2 == 0 # even number of steps
|
||||||
easing_function = easing_class(
|
easing_function = easing_class(
|
||||||
start=self.start_value, end=self.end_value, duration=base_easing_duration - 1
|
start=self.start_value,
|
||||||
|
end=self.end_value,
|
||||||
|
duration=base_easing_duration - 1,
|
||||||
)
|
)
|
||||||
base_easing_vals = list()
|
base_easing_vals = list()
|
||||||
for step_index in range(base_easing_duration):
|
for step_index in range(base_easing_duration):
|
||||||
@@ -199,7 +216,11 @@ class StepParamEasingInvocation(BaseInvocation):
|
|||||||
#
|
#
|
||||||
|
|
||||||
else: # no mirroring (default)
|
else: # no mirroring (default)
|
||||||
easing_function = easing_class(start=self.start_value, end=self.end_value, duration=num_easing_steps - 1)
|
easing_function = easing_class(
|
||||||
|
start=self.start_value,
|
||||||
|
end=self.end_value,
|
||||||
|
duration=num_easing_steps - 1,
|
||||||
|
)
|
||||||
for step_index in range(num_easing_steps):
|
for step_index in range(num_easing_steps):
|
||||||
step_val = easing_function.ease(step_index)
|
step_val = easing_function.ease(step_index)
|
||||||
easing_list.append(step_val)
|
easing_list.append(step_val)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from typing import Optional, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
|
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
|
||||||
from pydantic import validator
|
from pydantic import field_validator
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import StringCollectionOutput
|
from invokeai.app.invocations.primitives import StringCollectionOutput
|
||||||
|
|
||||||
@@ -21,7 +21,10 @@ from .baseinvocation import BaseInvocation, InputField, InvocationContext, UICom
|
|||||||
class DynamicPromptInvocation(BaseInvocation):
|
class DynamicPromptInvocation(BaseInvocation):
|
||||||
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
|
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
|
||||||
|
|
||||||
prompt: str = InputField(description="The prompt to parse with dynamicprompts", ui_component=UIComponent.Textarea)
|
prompt: str = InputField(
|
||||||
|
description="The prompt to parse with dynamicprompts",
|
||||||
|
ui_component=UIComponent.Textarea,
|
||||||
|
)
|
||||||
max_prompts: int = InputField(default=1, description="The number of prompts to generate")
|
max_prompts: int = InputField(default=1, description="The number of prompts to generate")
|
||||||
combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator")
|
combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator")
|
||||||
|
|
||||||
@@ -36,21 +39,31 @@ class DynamicPromptInvocation(BaseInvocation):
|
|||||||
return StringCollectionOutput(collection=prompts)
|
return StringCollectionOutput(collection=prompts)
|
||||||
|
|
||||||
|
|
||||||
@invocation("prompt_from_file", title="Prompts from File", tags=["prompt", "file"], category="prompt", version="1.0.0")
|
@invocation(
|
||||||
|
"prompt_from_file",
|
||||||
|
title="Prompts from File",
|
||||||
|
tags=["prompt", "file"],
|
||||||
|
category="prompt",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
class PromptsFromFileInvocation(BaseInvocation):
|
class PromptsFromFileInvocation(BaseInvocation):
|
||||||
"""Loads prompts from a text file"""
|
"""Loads prompts from a text file"""
|
||||||
|
|
||||||
file_path: str = InputField(description="Path to prompt text file")
|
file_path: str = InputField(description="Path to prompt text file")
|
||||||
pre_prompt: Optional[str] = InputField(
|
pre_prompt: Optional[str] = InputField(
|
||||||
default=None, description="String to prepend to each prompt", ui_component=UIComponent.Textarea
|
default=None,
|
||||||
|
description="String to prepend to each prompt",
|
||||||
|
ui_component=UIComponent.Textarea,
|
||||||
)
|
)
|
||||||
post_prompt: Optional[str] = InputField(
|
post_prompt: Optional[str] = InputField(
|
||||||
default=None, description="String to append to each prompt", ui_component=UIComponent.Textarea
|
default=None,
|
||||||
|
description="String to append to each prompt",
|
||||||
|
ui_component=UIComponent.Textarea,
|
||||||
)
|
)
|
||||||
start_line: int = InputField(default=1, ge=1, description="Line in the file to start start from")
|
start_line: int = InputField(default=1, ge=1, description="Line in the file to start start from")
|
||||||
max_prompts: int = InputField(default=1, ge=0, description="Max lines to read from file (0=all)")
|
max_prompts: int = InputField(default=1, ge=0, description="Max lines to read from file (0=all)")
|
||||||
|
|
||||||
@validator("file_path")
|
@field_validator("file_path")
|
||||||
def file_path_exists(cls, v):
|
def file_path_exists(cls, v):
|
||||||
if not exists(v):
|
if not exists(v):
|
||||||
raise ValueError(FileNotFoundError)
|
raise ValueError(FileNotFoundError)
|
||||||
@@ -79,6 +92,10 @@ class PromptsFromFileInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
||||||
prompts = self.promptsFromFile(
|
prompts = self.promptsFromFile(
|
||||||
self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts
|
self.file_path,
|
||||||
|
self.pre_prompt,
|
||||||
|
self.post_prompt,
|
||||||
|
self.start_line,
|
||||||
|
self.max_prompts,
|
||||||
)
|
)
|
||||||
return StringCollectionOutput(collection=prompts)
|
return StringCollectionOutput(collection=prompts)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
@@ -23,6 +23,8 @@ class T2IAdapterModelField(BaseModel):
|
|||||||
model_name: str = Field(description="Name of the T2I-Adapter model")
|
model_name: str = Field(description="Name of the T2I-Adapter model")
|
||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
class T2IAdapterField(BaseModel):
|
class T2IAdapterField(BaseModel):
|
||||||
image: ImageField = Field(description="The T2I-Adapter image prompt.")
|
image: ImageField = Field(description="The T2I-Adapter image prompt.")
|
||||||
|
|||||||
@@ -7,10 +7,11 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from pydantic import ConfigDict
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||||
from invokeai.backend.util.devices import choose_torch_device
|
from invokeai.backend.util.devices import choose_torch_device
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||||
@@ -38,6 +39,8 @@ class ESRGANInvocation(BaseInvocation):
|
|||||||
default=400, ge=0, description="Tile size for tiled ESRGAN upscaling (0=tiling disabled)"
|
default=400, ge=0, description="Tile size for tiled ESRGAN upscaling (0=tiling disabled)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
models_path = context.services.configuration.models_path
|
models_path = context.services.configuration.models_path
|
||||||
|
|||||||
@@ -1,4 +0,0 @@
|
|||||||
class CanceledException(Exception):
|
|
||||||
"""Execution canceled by user."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
@@ -1,71 +0,0 @@
|
|||||||
from enum import Enum
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from invokeai.app.util.metaenum import MetaEnum
|
|
||||||
|
|
||||||
|
|
||||||
class ProgressImage(BaseModel):
|
|
||||||
"""The progress image sent intermittently during processing"""
|
|
||||||
|
|
||||||
width: int = Field(description="The effective width of the image in pixels")
|
|
||||||
height: int = Field(description="The effective height of the image in pixels")
|
|
||||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
|
||||||
|
|
||||||
|
|
||||||
class ResourceOrigin(str, Enum, metaclass=MetaEnum):
|
|
||||||
"""The origin of a resource (eg image).
|
|
||||||
|
|
||||||
- INTERNAL: The resource was created by the application.
|
|
||||||
- EXTERNAL: The resource was not created by the application.
|
|
||||||
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
|
|
||||||
"""
|
|
||||||
|
|
||||||
INTERNAL = "internal"
|
|
||||||
"""The resource was created by the application."""
|
|
||||||
EXTERNAL = "external"
|
|
||||||
"""The resource was not created by the application.
|
|
||||||
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidOriginException(ValueError):
|
|
||||||
"""Raised when a provided value is not a valid ResourceOrigin.
|
|
||||||
|
|
||||||
Subclasses `ValueError`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, message="Invalid resource origin."):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class ImageCategory(str, Enum, metaclass=MetaEnum):
|
|
||||||
"""The category of an image.
|
|
||||||
|
|
||||||
- GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose.
|
|
||||||
- MASK: The image is a mask image.
|
|
||||||
- CONTROL: The image is a ControlNet control image.
|
|
||||||
- USER: The image is a user-provide image.
|
|
||||||
- OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
GENERAL = "general"
|
|
||||||
"""GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose."""
|
|
||||||
MASK = "mask"
|
|
||||||
"""MASK: The image is a mask image."""
|
|
||||||
CONTROL = "control"
|
|
||||||
"""CONTROL: The image is a ControlNet control image."""
|
|
||||||
USER = "user"
|
|
||||||
"""USER: The image is a user-provide image."""
|
|
||||||
OTHER = "other"
|
|
||||||
"""OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes."""
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidImageCategoryException(ValueError):
|
|
||||||
"""Raised when a provided value is not a valid ImageCategory.
|
|
||||||
|
|
||||||
Subclasses `ValueError`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, message="Invalid image category."):
|
|
||||||
super().__init__(message)
|
|
||||||
@@ -0,0 +1,47 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class BoardImageRecordStorageBase(ABC):
|
||||||
|
"""Abstract base class for the one-to-many board-image relationship record storage."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def add_image_to_board(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
image_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""Adds an image to a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def remove_image_from_board(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""Removes an image from a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_all_board_image_names_for_board(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Gets all board images for a board, as a list of the image names."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_board_for_image(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""Gets an image's board id, if it has one."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_image_count_for_board(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
) -> int:
|
||||||
|
"""Gets the number of images for a board."""
|
||||||
|
pass
|
||||||
@@ -1,55 +1,12 @@
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
import threading
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Optional, cast
|
from typing import Optional, cast
|
||||||
|
|
||||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
from invokeai.app.services.image_records.image_records_common import ImageRecord, deserialize_image_record
|
||||||
from invokeai.app.services.models.image_record import ImageRecord, deserialize_image_record
|
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||||
|
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||||
|
|
||||||
|
from .board_image_records_base import BoardImageRecordStorageBase
|
||||||
class BoardImageRecordStorageBase(ABC):
|
|
||||||
"""Abstract base class for the one-to-many board-image relationship record storage."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def add_image_to_board(
|
|
||||||
self,
|
|
||||||
board_id: str,
|
|
||||||
image_name: str,
|
|
||||||
) -> None:
|
|
||||||
"""Adds an image to a board."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def remove_image_from_board(
|
|
||||||
self,
|
|
||||||
image_name: str,
|
|
||||||
) -> None:
|
|
||||||
"""Removes an image from a board."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_all_board_image_names_for_board(
|
|
||||||
self,
|
|
||||||
board_id: str,
|
|
||||||
) -> list[str]:
|
|
||||||
"""Gets all board images for a board, as a list of the image names."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_board_for_image(
|
|
||||||
self,
|
|
||||||
image_name: str,
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""Gets an image's board id, if it has one."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_image_count_for_board(
|
|
||||||
self,
|
|
||||||
board_id: str,
|
|
||||||
) -> int:
|
|
||||||
"""Gets the number of images for a board."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||||
@@ -57,13 +14,11 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
|||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_lock: threading.RLock
|
_lock: threading.RLock
|
||||||
|
|
||||||
def __init__(self, conn: sqlite3.Connection, lock: threading.RLock) -> None:
|
def __init__(self, db: SqliteDatabase) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._conn = conn
|
self._lock = db.lock
|
||||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
self._conn = db.conn
|
||||||
self._conn.row_factory = sqlite3.Row
|
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
self._lock = lock
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
@@ -1,112 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from logging import Logger
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
|
||||||
from invokeai.app.services.board_record_storage import BoardRecord, BoardRecordStorageBase
|
|
||||||
from invokeai.app.services.image_record_storage import ImageRecordStorageBase
|
|
||||||
from invokeai.app.services.models.board_record import BoardDTO
|
|
||||||
from invokeai.app.services.urls import UrlServiceBase
|
|
||||||
|
|
||||||
|
|
||||||
class BoardImagesServiceABC(ABC):
|
|
||||||
"""High-level service for board-image relationship management."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def add_image_to_board(
|
|
||||||
self,
|
|
||||||
board_id: str,
|
|
||||||
image_name: str,
|
|
||||||
) -> None:
|
|
||||||
"""Adds an image to a board."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def remove_image_from_board(
|
|
||||||
self,
|
|
||||||
image_name: str,
|
|
||||||
) -> None:
|
|
||||||
"""Removes an image from a board."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_all_board_image_names_for_board(
|
|
||||||
self,
|
|
||||||
board_id: str,
|
|
||||||
) -> list[str]:
|
|
||||||
"""Gets all board images for a board, as a list of the image names."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_board_for_image(
|
|
||||||
self,
|
|
||||||
image_name: str,
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""Gets an image's board id, if it has one."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class BoardImagesServiceDependencies:
|
|
||||||
"""Service dependencies for the BoardImagesService."""
|
|
||||||
|
|
||||||
board_image_records: BoardImageRecordStorageBase
|
|
||||||
board_records: BoardRecordStorageBase
|
|
||||||
image_records: ImageRecordStorageBase
|
|
||||||
urls: UrlServiceBase
|
|
||||||
logger: Logger
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
board_image_record_storage: BoardImageRecordStorageBase,
|
|
||||||
image_record_storage: ImageRecordStorageBase,
|
|
||||||
board_record_storage: BoardRecordStorageBase,
|
|
||||||
url: UrlServiceBase,
|
|
||||||
logger: Logger,
|
|
||||||
):
|
|
||||||
self.board_image_records = board_image_record_storage
|
|
||||||
self.image_records = image_record_storage
|
|
||||||
self.board_records = board_record_storage
|
|
||||||
self.urls = url
|
|
||||||
self.logger = logger
|
|
||||||
|
|
||||||
|
|
||||||
class BoardImagesService(BoardImagesServiceABC):
|
|
||||||
_services: BoardImagesServiceDependencies
|
|
||||||
|
|
||||||
def __init__(self, services: BoardImagesServiceDependencies):
|
|
||||||
self._services = services
|
|
||||||
|
|
||||||
def add_image_to_board(
|
|
||||||
self,
|
|
||||||
board_id: str,
|
|
||||||
image_name: str,
|
|
||||||
) -> None:
|
|
||||||
self._services.board_image_records.add_image_to_board(board_id, image_name)
|
|
||||||
|
|
||||||
def remove_image_from_board(
|
|
||||||
self,
|
|
||||||
image_name: str,
|
|
||||||
) -> None:
|
|
||||||
self._services.board_image_records.remove_image_from_board(image_name)
|
|
||||||
|
|
||||||
def get_all_board_image_names_for_board(
|
|
||||||
self,
|
|
||||||
board_id: str,
|
|
||||||
) -> list[str]:
|
|
||||||
return self._services.board_image_records.get_all_board_image_names_for_board(board_id)
|
|
||||||
|
|
||||||
def get_board_for_image(
|
|
||||||
self,
|
|
||||||
image_name: str,
|
|
||||||
) -> Optional[str]:
|
|
||||||
board_id = self._services.board_image_records.get_board_for_image(image_name)
|
|
||||||
return board_id
|
|
||||||
|
|
||||||
|
|
||||||
def board_record_to_dto(board_record: BoardRecord, cover_image_name: Optional[str], image_count: int) -> BoardDTO:
|
|
||||||
"""Converts a board record to a board DTO."""
|
|
||||||
return BoardDTO(
|
|
||||||
**board_record.dict(exclude={"cover_image_name"}),
|
|
||||||
cover_image_name=cover_image_name,
|
|
||||||
image_count=image_count,
|
|
||||||
)
|
|
||||||
0
invokeai/app/services/board_images/__init__.py
Normal file
0
invokeai/app/services/board_images/__init__.py
Normal file
39
invokeai/app/services/board_images/board_images_base.py
Normal file
39
invokeai/app/services/board_images/board_images_base.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class BoardImagesServiceABC(ABC):
|
||||||
|
"""High-level service for board-image relationship management."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def add_image_to_board(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
image_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""Adds an image to a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def remove_image_from_board(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""Removes an image from a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_all_board_image_names_for_board(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Gets all board images for a board, as a list of the image names."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_board_for_image(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""Gets an image's board id, if it has one."""
|
||||||
|
pass
|
||||||
38
invokeai/app/services/board_images/board_images_default.py
Normal file
38
invokeai/app/services/board_images/board_images_default.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from invokeai.app.services.invoker import Invoker
|
||||||
|
|
||||||
|
from .board_images_base import BoardImagesServiceABC
|
||||||
|
|
||||||
|
|
||||||
|
class BoardImagesService(BoardImagesServiceABC):
|
||||||
|
__invoker: Invoker
|
||||||
|
|
||||||
|
def start(self, invoker: Invoker) -> None:
|
||||||
|
self.__invoker = invoker
|
||||||
|
|
||||||
|
def add_image_to_board(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
image_name: str,
|
||||||
|
) -> None:
|
||||||
|
self.__invoker.services.board_image_records.add_image_to_board(board_id, image_name)
|
||||||
|
|
||||||
|
def remove_image_from_board(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
) -> None:
|
||||||
|
self.__invoker.services.board_image_records.remove_image_from_board(image_name)
|
||||||
|
|
||||||
|
def get_all_board_image_names_for_board(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
) -> list[str]:
|
||||||
|
return self.__invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
|
||||||
|
|
||||||
|
def get_board_for_image(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
) -> Optional[str]:
|
||||||
|
board_id = self.__invoker.services.board_image_records.get_board_for_image(image_name)
|
||||||
|
return board_id
|
||||||
55
invokeai/app/services/board_records/board_records_base.py
Normal file
55
invokeai/app/services/board_records/board_records_base.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||||
|
|
||||||
|
from .board_records_common import BoardChanges, BoardRecord
|
||||||
|
|
||||||
|
|
||||||
|
class BoardRecordStorageBase(ABC):
|
||||||
|
"""Low-level service responsible for interfacing with the board record store."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete(self, board_id: str) -> None:
|
||||||
|
"""Deletes a board record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save(
|
||||||
|
self,
|
||||||
|
board_name: str,
|
||||||
|
) -> BoardRecord:
|
||||||
|
"""Saves a board record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
) -> BoardRecord:
|
||||||
|
"""Gets a board record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
changes: BoardChanges,
|
||||||
|
) -> BoardRecord:
|
||||||
|
"""Updates a board record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_many(
|
||||||
|
self,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 10,
|
||||||
|
) -> OffsetPaginatedResults[BoardRecord]:
|
||||||
|
"""Gets many board records."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_all(
|
||||||
|
self,
|
||||||
|
) -> list[BoardRecord]:
|
||||||
|
"""Gets all board records."""
|
||||||
|
pass
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.app.util.misc import get_iso_timestamp
|
from invokeai.app.util.misc import get_iso_timestamp
|
||||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||||
@@ -18,21 +18,12 @@ class BoardRecord(BaseModelExcludeNull):
|
|||||||
"""The created timestamp of the image."""
|
"""The created timestamp of the image."""
|
||||||
updated_at: Union[datetime, str] = Field(description="The updated timestamp of the board.")
|
updated_at: Union[datetime, str] = Field(description="The updated timestamp of the board.")
|
||||||
"""The updated timestamp of the image."""
|
"""The updated timestamp of the image."""
|
||||||
deleted_at: Union[datetime, str, None] = Field(description="The deleted timestamp of the board.")
|
deleted_at: Optional[Union[datetime, str]] = Field(default=None, description="The deleted timestamp of the board.")
|
||||||
"""The updated timestamp of the image."""
|
"""The updated timestamp of the image."""
|
||||||
cover_image_name: Optional[str] = Field(description="The name of the cover image of the board.")
|
cover_image_name: Optional[str] = Field(default=None, description="The name of the cover image of the board.")
|
||||||
"""The name of the cover image of the board."""
|
"""The name of the cover image of the board."""
|
||||||
|
|
||||||
|
|
||||||
class BoardDTO(BoardRecord):
|
|
||||||
"""Deserialized board record with cover image URL and image count."""
|
|
||||||
|
|
||||||
cover_image_name: Optional[str] = Field(description="The name of the board's cover image.")
|
|
||||||
"""The URL of the thumbnail of the most recent image in the board."""
|
|
||||||
image_count: int = Field(description="The number of images in the board.")
|
|
||||||
"""The number of images in the board."""
|
|
||||||
|
|
||||||
|
|
||||||
def deserialize_board_record(board_dict: dict) -> BoardRecord:
|
def deserialize_board_record(board_dict: dict) -> BoardRecord:
|
||||||
"""Deserializes a board record."""
|
"""Deserializes a board record."""
|
||||||
|
|
||||||
@@ -53,3 +44,29 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
|
|||||||
updated_at=updated_at,
|
updated_at=updated_at,
|
||||||
deleted_at=deleted_at,
|
deleted_at=deleted_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BoardChanges(BaseModel, extra="forbid"):
|
||||||
|
board_name: Optional[str] = Field(default=None, description="The board's new name.")
|
||||||
|
cover_image_name: Optional[str] = Field(default=None, description="The name of the board's new cover image.")
|
||||||
|
|
||||||
|
|
||||||
|
class BoardRecordNotFoundException(Exception):
|
||||||
|
"""Raised when an board record is not found."""
|
||||||
|
|
||||||
|
def __init__(self, message="Board record not found"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class BoardRecordSaveException(Exception):
|
||||||
|
"""Raised when an board record cannot be saved."""
|
||||||
|
|
||||||
|
def __init__(self, message="Board record not saved"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class BoardRecordDeleteException(Exception):
|
||||||
|
"""Raised when an board record cannot be deleted."""
|
||||||
|
|
||||||
|
def __init__(self, message="Board record not deleted"):
|
||||||
|
super().__init__(message)
|
||||||
@@ -1,89 +1,20 @@
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
import threading
|
||||||
from abc import ABC, abstractmethod
|
from typing import Union, cast
|
||||||
from typing import Optional, Union, cast
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra, Field
|
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||||
|
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
|
||||||
from invokeai.app.services.models.board_record import BoardRecord, deserialize_board_record
|
|
||||||
from invokeai.app.util.misc import uuid_string
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
|
||||||
|
from .board_records_base import BoardRecordStorageBase
|
||||||
class BoardChanges(BaseModel, extra=Extra.forbid):
|
from .board_records_common import (
|
||||||
board_name: Optional[str] = Field(description="The board's new name.")
|
BoardChanges,
|
||||||
cover_image_name: Optional[str] = Field(description="The name of the board's new cover image.")
|
BoardRecord,
|
||||||
|
BoardRecordDeleteException,
|
||||||
|
BoardRecordNotFoundException,
|
||||||
class BoardRecordNotFoundException(Exception):
|
BoardRecordSaveException,
|
||||||
"""Raised when an board record is not found."""
|
deserialize_board_record,
|
||||||
|
)
|
||||||
def __init__(self, message="Board record not found"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class BoardRecordSaveException(Exception):
|
|
||||||
"""Raised when an board record cannot be saved."""
|
|
||||||
|
|
||||||
def __init__(self, message="Board record not saved"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class BoardRecordDeleteException(Exception):
|
|
||||||
"""Raised when an board record cannot be deleted."""
|
|
||||||
|
|
||||||
def __init__(self, message="Board record not deleted"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class BoardRecordStorageBase(ABC):
|
|
||||||
"""Low-level service responsible for interfacing with the board record store."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def delete(self, board_id: str) -> None:
|
|
||||||
"""Deletes a board record."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def save(
|
|
||||||
self,
|
|
||||||
board_name: str,
|
|
||||||
) -> BoardRecord:
|
|
||||||
"""Saves a board record."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get(
|
|
||||||
self,
|
|
||||||
board_id: str,
|
|
||||||
) -> BoardRecord:
|
|
||||||
"""Gets a board record."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update(
|
|
||||||
self,
|
|
||||||
board_id: str,
|
|
||||||
changes: BoardChanges,
|
|
||||||
) -> BoardRecord:
|
|
||||||
"""Updates a board record."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_many(
|
|
||||||
self,
|
|
||||||
offset: int = 0,
|
|
||||||
limit: int = 10,
|
|
||||||
) -> OffsetPaginatedResults[BoardRecord]:
|
|
||||||
"""Gets many board records."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_all(
|
|
||||||
self,
|
|
||||||
) -> list[BoardRecord]:
|
|
||||||
"""Gets all board records."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||||
@@ -91,13 +22,11 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
|||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_lock: threading.RLock
|
_lock: threading.RLock
|
||||||
|
|
||||||
def __init__(self, conn: sqlite3.Connection, lock: threading.RLock) -> None:
|
def __init__(self, db: SqliteDatabase) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._conn = conn
|
self._lock = db.lock
|
||||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
self._conn = db.conn
|
||||||
self._conn.row_factory = sqlite3.Row
|
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
self._lock = lock
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
@@ -1,158 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from logging import Logger
|
|
||||||
|
|
||||||
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
|
||||||
from invokeai.app.services.board_images import board_record_to_dto
|
|
||||||
from invokeai.app.services.board_record_storage import BoardChanges, BoardRecordStorageBase
|
|
||||||
from invokeai.app.services.image_record_storage import ImageRecordStorageBase, OffsetPaginatedResults
|
|
||||||
from invokeai.app.services.models.board_record import BoardDTO
|
|
||||||
from invokeai.app.services.urls import UrlServiceBase
|
|
||||||
|
|
||||||
|
|
||||||
class BoardServiceABC(ABC):
|
|
||||||
"""High-level service for board management."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def create(
|
|
||||||
self,
|
|
||||||
board_name: str,
|
|
||||||
) -> BoardDTO:
|
|
||||||
"""Creates a board."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_dto(
|
|
||||||
self,
|
|
||||||
board_id: str,
|
|
||||||
) -> BoardDTO:
|
|
||||||
"""Gets a board."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update(
|
|
||||||
self,
|
|
||||||
board_id: str,
|
|
||||||
changes: BoardChanges,
|
|
||||||
) -> BoardDTO:
|
|
||||||
"""Updates a board."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def delete(
|
|
||||||
self,
|
|
||||||
board_id: str,
|
|
||||||
) -> None:
|
|
||||||
"""Deletes a board."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_many(
|
|
||||||
self,
|
|
||||||
offset: int = 0,
|
|
||||||
limit: int = 10,
|
|
||||||
) -> OffsetPaginatedResults[BoardDTO]:
|
|
||||||
"""Gets many boards."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_all(
|
|
||||||
self,
|
|
||||||
) -> list[BoardDTO]:
|
|
||||||
"""Gets all boards."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class BoardServiceDependencies:
|
|
||||||
"""Service dependencies for the BoardService."""
|
|
||||||
|
|
||||||
board_image_records: BoardImageRecordStorageBase
|
|
||||||
board_records: BoardRecordStorageBase
|
|
||||||
image_records: ImageRecordStorageBase
|
|
||||||
urls: UrlServiceBase
|
|
||||||
logger: Logger
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
board_image_record_storage: BoardImageRecordStorageBase,
|
|
||||||
image_record_storage: ImageRecordStorageBase,
|
|
||||||
board_record_storage: BoardRecordStorageBase,
|
|
||||||
url: UrlServiceBase,
|
|
||||||
logger: Logger,
|
|
||||||
):
|
|
||||||
self.board_image_records = board_image_record_storage
|
|
||||||
self.image_records = image_record_storage
|
|
||||||
self.board_records = board_record_storage
|
|
||||||
self.urls = url
|
|
||||||
self.logger = logger
|
|
||||||
|
|
||||||
|
|
||||||
class BoardService(BoardServiceABC):
|
|
||||||
_services: BoardServiceDependencies
|
|
||||||
|
|
||||||
def __init__(self, services: BoardServiceDependencies):
|
|
||||||
self._services = services
|
|
||||||
|
|
||||||
def create(
|
|
||||||
self,
|
|
||||||
board_name: str,
|
|
||||||
) -> BoardDTO:
|
|
||||||
board_record = self._services.board_records.save(board_name)
|
|
||||||
return board_record_to_dto(board_record, None, 0)
|
|
||||||
|
|
||||||
def get_dto(self, board_id: str) -> BoardDTO:
|
|
||||||
board_record = self._services.board_records.get(board_id)
|
|
||||||
cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id)
|
|
||||||
if cover_image:
|
|
||||||
cover_image_name = cover_image.image_name
|
|
||||||
else:
|
|
||||||
cover_image_name = None
|
|
||||||
image_count = self._services.board_image_records.get_image_count_for_board(board_id)
|
|
||||||
return board_record_to_dto(board_record, cover_image_name, image_count)
|
|
||||||
|
|
||||||
def update(
|
|
||||||
self,
|
|
||||||
board_id: str,
|
|
||||||
changes: BoardChanges,
|
|
||||||
) -> BoardDTO:
|
|
||||||
board_record = self._services.board_records.update(board_id, changes)
|
|
||||||
cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id)
|
|
||||||
if cover_image:
|
|
||||||
cover_image_name = cover_image.image_name
|
|
||||||
else:
|
|
||||||
cover_image_name = None
|
|
||||||
|
|
||||||
image_count = self._services.board_image_records.get_image_count_for_board(board_id)
|
|
||||||
return board_record_to_dto(board_record, cover_image_name, image_count)
|
|
||||||
|
|
||||||
def delete(self, board_id: str) -> None:
|
|
||||||
self._services.board_records.delete(board_id)
|
|
||||||
|
|
||||||
def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardDTO]:
|
|
||||||
board_records = self._services.board_records.get_many(offset, limit)
|
|
||||||
board_dtos = []
|
|
||||||
for r in board_records.items:
|
|
||||||
cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id)
|
|
||||||
if cover_image:
|
|
||||||
cover_image_name = cover_image.image_name
|
|
||||||
else:
|
|
||||||
cover_image_name = None
|
|
||||||
|
|
||||||
image_count = self._services.board_image_records.get_image_count_for_board(r.board_id)
|
|
||||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
|
||||||
|
|
||||||
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
|
|
||||||
|
|
||||||
def get_all(self) -> list[BoardDTO]:
|
|
||||||
board_records = self._services.board_records.get_all()
|
|
||||||
board_dtos = []
|
|
||||||
for r in board_records:
|
|
||||||
cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id)
|
|
||||||
if cover_image:
|
|
||||||
cover_image_name = cover_image.image_name
|
|
||||||
else:
|
|
||||||
cover_image_name = None
|
|
||||||
|
|
||||||
image_count = self._services.board_image_records.get_image_count_for_board(r.board_id)
|
|
||||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
|
||||||
|
|
||||||
return board_dtos
|
|
||||||
0
invokeai/app/services/boards/__init__.py
Normal file
0
invokeai/app/services/boards/__init__.py
Normal file
59
invokeai/app/services/boards/boards_base.py
Normal file
59
invokeai/app/services/boards/boards_base.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from invokeai.app.services.board_records.board_records_common import BoardChanges
|
||||||
|
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||||
|
|
||||||
|
from .boards_common import BoardDTO
|
||||||
|
|
||||||
|
|
||||||
|
class BoardServiceABC(ABC):
|
||||||
|
"""High-level service for board management."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
board_name: str,
|
||||||
|
) -> BoardDTO:
|
||||||
|
"""Creates a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_dto(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
) -> BoardDTO:
|
||||||
|
"""Gets a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
changes: BoardChanges,
|
||||||
|
) -> BoardDTO:
|
||||||
|
"""Updates a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Deletes a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_many(
|
||||||
|
self,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 10,
|
||||||
|
) -> OffsetPaginatedResults[BoardDTO]:
|
||||||
|
"""Gets many boards."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_all(
|
||||||
|
self,
|
||||||
|
) -> list[BoardDTO]:
|
||||||
|
"""Gets all boards."""
|
||||||
|
pass
|
||||||
23
invokeai/app/services/boards/boards_common.py
Normal file
23
invokeai/app/services/boards/boards_common.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from ..board_records.board_records_common import BoardRecord
|
||||||
|
|
||||||
|
|
||||||
|
class BoardDTO(BoardRecord):
|
||||||
|
"""Deserialized board record with cover image URL and image count."""
|
||||||
|
|
||||||
|
cover_image_name: Optional[str] = Field(description="The name of the board's cover image.")
|
||||||
|
"""The URL of the thumbnail of the most recent image in the board."""
|
||||||
|
image_count: int = Field(description="The number of images in the board.")
|
||||||
|
"""The number of images in the board."""
|
||||||
|
|
||||||
|
|
||||||
|
def board_record_to_dto(board_record: BoardRecord, cover_image_name: Optional[str], image_count: int) -> BoardDTO:
|
||||||
|
"""Converts a board record to a board DTO."""
|
||||||
|
return BoardDTO(
|
||||||
|
**board_record.model_dump(exclude={"cover_image_name"}),
|
||||||
|
cover_image_name=cover_image_name,
|
||||||
|
image_count=image_count,
|
||||||
|
)
|
||||||
79
invokeai/app/services/boards/boards_default.py
Normal file
79
invokeai/app/services/boards/boards_default.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
from invokeai.app.services.board_records.board_records_common import BoardChanges
|
||||||
|
from invokeai.app.services.boards.boards_common import BoardDTO
|
||||||
|
from invokeai.app.services.invoker import Invoker
|
||||||
|
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||||
|
|
||||||
|
from .boards_base import BoardServiceABC
|
||||||
|
from .boards_common import board_record_to_dto
|
||||||
|
|
||||||
|
|
||||||
|
class BoardService(BoardServiceABC):
|
||||||
|
__invoker: Invoker
|
||||||
|
|
||||||
|
def start(self, invoker: Invoker) -> None:
|
||||||
|
self.__invoker = invoker
|
||||||
|
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
board_name: str,
|
||||||
|
) -> BoardDTO:
|
||||||
|
board_record = self.__invoker.services.board_records.save(board_name)
|
||||||
|
return board_record_to_dto(board_record, None, 0)
|
||||||
|
|
||||||
|
def get_dto(self, board_id: str) -> BoardDTO:
|
||||||
|
board_record = self.__invoker.services.board_records.get(board_id)
|
||||||
|
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id)
|
||||||
|
if cover_image:
|
||||||
|
cover_image_name = cover_image.image_name
|
||||||
|
else:
|
||||||
|
cover_image_name = None
|
||||||
|
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id)
|
||||||
|
return board_record_to_dto(board_record, cover_image_name, image_count)
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
changes: BoardChanges,
|
||||||
|
) -> BoardDTO:
|
||||||
|
board_record = self.__invoker.services.board_records.update(board_id, changes)
|
||||||
|
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id)
|
||||||
|
if cover_image:
|
||||||
|
cover_image_name = cover_image.image_name
|
||||||
|
else:
|
||||||
|
cover_image_name = None
|
||||||
|
|
||||||
|
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id)
|
||||||
|
return board_record_to_dto(board_record, cover_image_name, image_count)
|
||||||
|
|
||||||
|
def delete(self, board_id: str) -> None:
|
||||||
|
self.__invoker.services.board_records.delete(board_id)
|
||||||
|
|
||||||
|
def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardDTO]:
|
||||||
|
board_records = self.__invoker.services.board_records.get_many(offset, limit)
|
||||||
|
board_dtos = []
|
||||||
|
for r in board_records.items:
|
||||||
|
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
|
||||||
|
if cover_image:
|
||||||
|
cover_image_name = cover_image.image_name
|
||||||
|
else:
|
||||||
|
cover_image_name = None
|
||||||
|
|
||||||
|
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
|
||||||
|
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
||||||
|
|
||||||
|
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
|
||||||
|
|
||||||
|
def get_all(self) -> list[BoardDTO]:
|
||||||
|
board_records = self.__invoker.services.board_records.get_all()
|
||||||
|
board_dtos = []
|
||||||
|
for r in board_records:
|
||||||
|
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
|
||||||
|
if cover_image:
|
||||||
|
cover_image_name = cover_image.image_name
|
||||||
|
else:
|
||||||
|
cover_image_name = None
|
||||||
|
|
||||||
|
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
|
||||||
|
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
||||||
|
|
||||||
|
return board_dtos
|
||||||
@@ -2,5 +2,5 @@
|
|||||||
Init file for InvokeAI configure package
|
Init file for InvokeAI configure package
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .base import PagingArgumentParser # noqa F401
|
from .config_base import PagingArgumentParser # noqa F401
|
||||||
from .invokeai_config import InvokeAIAppConfig, get_invokeai_config # noqa F401
|
from .config_default import InvokeAIAppConfig, get_invokeai_config # noqa F401
|
||||||
|
|||||||
@@ -12,25 +12,15 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import pydoc
|
|
||||||
import sys
|
import sys
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
|
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
|
||||||
|
|
||||||
from omegaconf import DictConfig, ListConfig, OmegaConf
|
from omegaconf import DictConfig, ListConfig, OmegaConf
|
||||||
from pydantic import BaseSettings
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
from invokeai.app.services.config.config_common import PagingArgumentParser, int_or_float_or_str
|
||||||
class PagingArgumentParser(argparse.ArgumentParser):
|
|
||||||
"""
|
|
||||||
A custom ArgumentParser that uses pydoc to page its output.
|
|
||||||
It also supports reading defaults from an init file.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def print_help(self, file=None):
|
|
||||||
text = self.format_help()
|
|
||||||
pydoc.pager(text)
|
|
||||||
|
|
||||||
|
|
||||||
class InvokeAISettings(BaseSettings):
|
class InvokeAISettings(BaseSettings):
|
||||||
@@ -42,12 +32,14 @@ class InvokeAISettings(BaseSettings):
|
|||||||
initconf: ClassVar[Optional[DictConfig]] = None
|
initconf: ClassVar[Optional[DictConfig]] = None
|
||||||
argparse_groups: ClassVar[Dict] = {}
|
argparse_groups: ClassVar[Dict] = {}
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True)
|
||||||
|
|
||||||
def parse_args(self, argv: Optional[list] = sys.argv[1:]):
|
def parse_args(self, argv: Optional[list] = sys.argv[1:]):
|
||||||
parser = self.get_parser()
|
parser = self.get_parser()
|
||||||
opt, unknown_opts = parser.parse_known_args(argv)
|
opt, unknown_opts = parser.parse_known_args(argv)
|
||||||
if len(unknown_opts) > 0:
|
if len(unknown_opts) > 0:
|
||||||
print("Unknown args:", unknown_opts)
|
print("Unknown args:", unknown_opts)
|
||||||
for name in self.__fields__:
|
for name in self.model_fields:
|
||||||
if name not in self._excluded():
|
if name not in self._excluded():
|
||||||
value = getattr(opt, name)
|
value = getattr(opt, name)
|
||||||
if isinstance(value, ListConfig):
|
if isinstance(value, ListConfig):
|
||||||
@@ -64,10 +56,12 @@ class InvokeAISettings(BaseSettings):
|
|||||||
cls = self.__class__
|
cls = self.__class__
|
||||||
type = get_args(get_type_hints(cls)["type"])[0]
|
type = get_args(get_type_hints(cls)["type"])[0]
|
||||||
field_dict = dict({type: dict()})
|
field_dict = dict({type: dict()})
|
||||||
for name, field in self.__fields__.items():
|
for name, field in self.model_fields.items():
|
||||||
if name in cls._excluded_from_yaml():
|
if name in cls._excluded_from_yaml():
|
||||||
continue
|
continue
|
||||||
category = field.field_info.extra.get("category") or "Uncategorized"
|
category = (
|
||||||
|
field.json_schema_extra.get("category", "Uncategorized") if field.json_schema_extra else "Uncategorized"
|
||||||
|
)
|
||||||
value = getattr(self, name)
|
value = getattr(self, name)
|
||||||
if category not in field_dict[type]:
|
if category not in field_dict[type]:
|
||||||
field_dict[type][category] = dict()
|
field_dict[type][category] = dict()
|
||||||
@@ -83,7 +77,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
else:
|
else:
|
||||||
settings_stanza = "Uncategorized"
|
settings_stanza = "Uncategorized"
|
||||||
|
|
||||||
env_prefix = getattr(cls.Config, "env_prefix", None)
|
env_prefix = getattr(cls.model_config, "env_prefix", None)
|
||||||
env_prefix = env_prefix if env_prefix is not None else settings_stanza.upper()
|
env_prefix = env_prefix if env_prefix is not None else settings_stanza.upper()
|
||||||
|
|
||||||
initconf = (
|
initconf = (
|
||||||
@@ -99,14 +93,18 @@ class InvokeAISettings(BaseSettings):
|
|||||||
for key, value in os.environ.items():
|
for key, value in os.environ.items():
|
||||||
upcase_environ[key.upper()] = value
|
upcase_environ[key.upper()] = value
|
||||||
|
|
||||||
fields = cls.__fields__
|
fields = cls.model_fields
|
||||||
cls.argparse_groups = {}
|
cls.argparse_groups = {}
|
||||||
|
|
||||||
for name, field in fields.items():
|
for name, field in fields.items():
|
||||||
if name not in cls._excluded():
|
if name not in cls._excluded():
|
||||||
current_default = field.default
|
current_default = field.default
|
||||||
|
|
||||||
category = field.field_info.extra.get("category", "Uncategorized")
|
category = (
|
||||||
|
field.json_schema_extra.get("category", "Uncategorized")
|
||||||
|
if field.json_schema_extra
|
||||||
|
else "Uncategorized"
|
||||||
|
)
|
||||||
env_name = env_prefix + "_" + name
|
env_name = env_prefix + "_" + name
|
||||||
if category in initconf and name in initconf.get(category):
|
if category in initconf and name in initconf.get(category):
|
||||||
field.default = initconf.get(category).get(name)
|
field.default = initconf.get(category).get(name)
|
||||||
@@ -156,11 +154,6 @@ class InvokeAISettings(BaseSettings):
|
|||||||
"tiled_decode",
|
"tiled_decode",
|
||||||
]
|
]
|
||||||
|
|
||||||
class Config:
|
|
||||||
env_file_encoding = "utf-8"
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
case_sensitive = True
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
|
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
|
||||||
field_type = get_type_hints(cls).get(name)
|
field_type = get_type_hints(cls).get(name)
|
||||||
@@ -171,7 +164,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
if field.default_factory is None
|
if field.default_factory is None
|
||||||
else field.default_factory()
|
else field.default_factory()
|
||||||
)
|
)
|
||||||
if category := field.field_info.extra.get("category"):
|
if category := (field.json_schema_extra.get("category", None) if field.json_schema_extra else None):
|
||||||
if category not in cls.argparse_groups:
|
if category not in cls.argparse_groups:
|
||||||
cls.argparse_groups[category] = command_parser.add_argument_group(category)
|
cls.argparse_groups[category] = command_parser.add_argument_group(category)
|
||||||
argparse_group = cls.argparse_groups[category]
|
argparse_group = cls.argparse_groups[category]
|
||||||
@@ -179,7 +172,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
argparse_group = command_parser
|
argparse_group = command_parser
|
||||||
|
|
||||||
if get_origin(field_type) == Literal:
|
if get_origin(field_type) == Literal:
|
||||||
allowed_values = get_args(field.type_)
|
allowed_values = get_args(field.annotation)
|
||||||
allowed_types = set()
|
allowed_types = set()
|
||||||
for val in allowed_values:
|
for val in allowed_values:
|
||||||
allowed_types.add(type(val))
|
allowed_types.add(type(val))
|
||||||
@@ -192,7 +185,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
type=field_type,
|
type=field_type,
|
||||||
default=default,
|
default=default,
|
||||||
choices=allowed_values,
|
choices=allowed_values,
|
||||||
help=field.field_info.description,
|
help=field.description,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif get_origin(field_type) == Union:
|
elif get_origin(field_type) == Union:
|
||||||
@@ -201,7 +194,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
dest=name,
|
dest=name,
|
||||||
type=int_or_float_or_str,
|
type=int_or_float_or_str,
|
||||||
default=default,
|
default=default,
|
||||||
help=field.field_info.description,
|
help=field.description,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif get_origin(field_type) == list:
|
elif get_origin(field_type) == list:
|
||||||
@@ -209,32 +202,17 @@ class InvokeAISettings(BaseSettings):
|
|||||||
f"--{name}",
|
f"--{name}",
|
||||||
dest=name,
|
dest=name,
|
||||||
nargs="*",
|
nargs="*",
|
||||||
type=field.type_,
|
type=field.annotation,
|
||||||
default=default,
|
default=default,
|
||||||
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
|
action=argparse.BooleanOptionalAction if field.annotation == bool else "store",
|
||||||
help=field.field_info.description,
|
help=field.description,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
argparse_group.add_argument(
|
argparse_group.add_argument(
|
||||||
f"--{name}",
|
f"--{name}",
|
||||||
dest=name,
|
dest=name,
|
||||||
type=field.type_,
|
type=field.annotation,
|
||||||
default=default,
|
default=default,
|
||||||
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
|
action=argparse.BooleanOptionalAction if field.annotation == bool else "store",
|
||||||
help=field.field_info.description,
|
help=field.description,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def int_or_float_or_str(value: str) -> Union[int, float, str]:
|
|
||||||
"""
|
|
||||||
Workaround for argparse type checking.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return int(value)
|
|
||||||
except Exception as e: # noqa F841
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
return float(value)
|
|
||||||
except Exception as e: # noqa F841
|
|
||||||
pass
|
|
||||||
return str(value)
|
|
||||||
41
invokeai/app/services/config/config_common.py
Normal file
41
invokeai/app/services/config/config_common.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team
|
||||||
|
|
||||||
|
"""
|
||||||
|
Base class for the InvokeAI configuration system.
|
||||||
|
It defines a type of pydantic BaseSettings object that
|
||||||
|
is able to read and write from an omegaconf-based config file,
|
||||||
|
with overriding of settings from environment variables and/or
|
||||||
|
the command line.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import pydoc
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
|
||||||
|
class PagingArgumentParser(argparse.ArgumentParser):
|
||||||
|
"""
|
||||||
|
A custom ArgumentParser that uses pydoc to page its output.
|
||||||
|
It also supports reading defaults from an init file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def print_help(self, file=None):
|
||||||
|
text = self.format_help()
|
||||||
|
pydoc.pager(text)
|
||||||
|
|
||||||
|
|
||||||
|
def int_or_float_or_str(value: str) -> Union[int, float, str]:
|
||||||
|
"""
|
||||||
|
Workaround for argparse type checking.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return int(value)
|
||||||
|
except Exception as e: # noqa F841
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
return float(value)
|
||||||
|
except Exception as e: # noqa F841
|
||||||
|
pass
|
||||||
|
return str(value)
|
||||||
@@ -144,8 +144,8 @@ which is set to the desired top-level name. For example, to create a
|
|||||||
|
|
||||||
class InvokeBatch(InvokeAISettings):
|
class InvokeBatch(InvokeAISettings):
|
||||||
type: Literal["InvokeBatch"] = "InvokeBatch"
|
type: Literal["InvokeBatch"] = "InvokeBatch"
|
||||||
node_count : int = Field(default=1, description="Number of nodes to run on", category='Resources')
|
node_count : int = Field(default=1, description="Number of nodes to run on", json_schema_extra=dict(category='Resources'))
|
||||||
cpu_count : int = Field(default=8, description="Number of GPUs to run on per node", category='Resources')
|
cpu_count : int = Field(default=8, description="Number of GPUs to run on per node", json_schema_extra=dict(category='Resources'))
|
||||||
|
|
||||||
This will now read and write from the "InvokeBatch" section of the
|
This will now read and write from the "InvokeBatch" section of the
|
||||||
config file, look for environment variables named INVOKEBATCH_*, and
|
config file, look for environment variables named INVOKEBATCH_*, and
|
||||||
@@ -175,9 +175,10 @@ from pathlib import Path
|
|||||||
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_type_hints
|
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_type_hints
|
||||||
|
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from pydantic import Field, parse_obj_as
|
from pydantic import Field, TypeAdapter
|
||||||
|
from pydantic_settings import SettingsConfigDict
|
||||||
|
|
||||||
from .base import InvokeAISettings
|
from .config_base import InvokeAISettings
|
||||||
|
|
||||||
INIT_FILE = Path("invokeai.yaml")
|
INIT_FILE = Path("invokeai.yaml")
|
||||||
DB_FILE = Path("invokeai.db")
|
DB_FILE = Path("invokeai.db")
|
||||||
@@ -185,6 +186,21 @@ LEGACY_INIT_FILE = Path("invokeai.init")
|
|||||||
DEFAULT_MAX_VRAM = 0.5
|
DEFAULT_MAX_VRAM = 0.5
|
||||||
|
|
||||||
|
|
||||||
|
class Categories(object):
|
||||||
|
WebServer = dict(category="Web Server")
|
||||||
|
Features = dict(category="Features")
|
||||||
|
Paths = dict(category="Paths")
|
||||||
|
Logging = dict(category="Logging")
|
||||||
|
Development = dict(category="Development")
|
||||||
|
Other = dict(category="Other")
|
||||||
|
ModelCache = dict(category="Model Cache")
|
||||||
|
Device = dict(category="Device")
|
||||||
|
Generation = dict(category="Generation")
|
||||||
|
Queue = dict(category="Queue")
|
||||||
|
Nodes = dict(category="Nodes")
|
||||||
|
MemoryPerformance = dict(category="Memory/Performance")
|
||||||
|
|
||||||
|
|
||||||
class InvokeAIAppConfig(InvokeAISettings):
|
class InvokeAIAppConfig(InvokeAISettings):
|
||||||
"""
|
"""
|
||||||
Generate images using Stable Diffusion. Use "invokeai" to launch
|
Generate images using Stable Diffusion. Use "invokeai" to launch
|
||||||
@@ -201,86 +217,88 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
type: Literal["InvokeAI"] = "InvokeAI"
|
type: Literal["InvokeAI"] = "InvokeAI"
|
||||||
|
|
||||||
# WEB
|
# WEB
|
||||||
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
|
host : str = Field(default="127.0.0.1", description="IP address to bind to", json_schema_extra=Categories.WebServer)
|
||||||
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
|
port : int = Field(default=9090, description="Port to bind to", json_schema_extra=Categories.WebServer)
|
||||||
allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", category='Web Server')
|
allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", json_schema_extra=Categories.WebServer)
|
||||||
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", category='Web Server')
|
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", json_schema_extra=Categories.WebServer)
|
||||||
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", category='Web Server')
|
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", json_schema_extra=Categories.WebServer)
|
||||||
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", category='Web Server')
|
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", json_schema_extra=Categories.WebServer)
|
||||||
|
|
||||||
# FEATURES
|
# FEATURES
|
||||||
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", category='Features')
|
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", json_schema_extra=Categories.Features)
|
||||||
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
|
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", json_schema_extra=Categories.Features)
|
||||||
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
|
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", json_schema_extra=Categories.Features)
|
||||||
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
|
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", json_schema_extra=Categories.Features)
|
||||||
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', category='Features')
|
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', json_schema_extra=Categories.Features)
|
||||||
|
|
||||||
# PATHS
|
# PATHS
|
||||||
root : Path = Field(default=None, description='InvokeAI runtime root directory', category='Paths')
|
root : Optional[Path] = Field(default=None, description='InvokeAI runtime root directory', json_schema_extra=Categories.Paths)
|
||||||
autoimport_dir : Path = Field(default='autoimport', description='Path to a directory of models files to be imported on startup.', category='Paths')
|
autoimport_dir : Optional[Path] = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||||
lora_dir : Path = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths')
|
lora_dir : Optional[Path] = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||||
embedding_dir : Path = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', category='Paths')
|
embedding_dir : Optional[Path] = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||||
controlnet_dir : Path = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', category='Paths')
|
controlnet_dir : Optional[Path] = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||||
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
|
conf_path : Optional[Path] = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
|
||||||
models_dir : Path = Field(default='models', description='Path to the models directory', category='Paths')
|
models_dir : Optional[Path] = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths)
|
||||||
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
|
legacy_conf_dir : Optional[Path] = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths)
|
||||||
db_dir : Path = Field(default='databases', description='Path to InvokeAI databases directory', category='Paths')
|
db_dir : Optional[Path] = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths)
|
||||||
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
|
outdir : Optional[Path] = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths)
|
||||||
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
|
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', json_schema_extra=Categories.Paths)
|
||||||
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
|
from_file : Optional[Path] = Field(default=None, description='Take command input from the indicated file (command-line client only)', json_schema_extra=Categories.Paths)
|
||||||
|
|
||||||
# LOGGING
|
# LOGGING
|
||||||
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
|
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', json_schema_extra=Categories.Logging)
|
||||||
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
||||||
log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
|
log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', json_schema_extra=Categories.Logging)
|
||||||
log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
|
log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", json_schema_extra=Categories.Logging)
|
||||||
log_sql : bool = Field(default=False, description="Log SQL queries", category="Logging")
|
log_sql : bool = Field(default=False, description="Log SQL queries", json_schema_extra=Categories.Logging)
|
||||||
|
|
||||||
dev_reload : bool = Field(default=False, description="Automatically reload when Python sources are changed.", category="Development")
|
dev_reload : bool = Field(default=False, description="Automatically reload when Python sources are changed.", json_schema_extra=Categories.Development)
|
||||||
|
|
||||||
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
|
version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other)
|
||||||
|
|
||||||
# CACHE
|
# CACHE
|
||||||
ram : float = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", category="Model Cache", )
|
ram : float = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
|
||||||
vram : float = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", category="Model Cache", )
|
vram : float = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
|
||||||
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", )
|
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, )
|
||||||
|
|
||||||
# DEVICE
|
# DEVICE
|
||||||
device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", category="Device", )
|
device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", json_schema_extra=Categories.Device)
|
||||||
precision : Literal["auto", "float16", "float32", "autocast"] = Field(default="auto", description="Floating point precision", category="Device", )
|
precision : Literal["auto", "float16", "float32", "autocast"] = Field(default="auto", description="Floating point precision", json_schema_extra=Categories.Device)
|
||||||
|
|
||||||
# GENERATION
|
# GENERATION
|
||||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category="Generation", )
|
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", json_schema_extra=Categories.Generation)
|
||||||
attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type", category="Generation", )
|
attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type", json_schema_extra=Categories.Generation)
|
||||||
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
|
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', json_schema_extra=Categories.Generation)
|
||||||
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
|
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.Generation)
|
||||||
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
|
png_compress_level : int = Field(default=6, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize", json_schema_extra=Categories.Generation)
|
||||||
png_compress_level : int = Field(default=6, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize", category="Generation", )
|
|
||||||
|
|
||||||
# QUEUE
|
# QUEUE
|
||||||
max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", category="Queue", )
|
max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", json_schema_extra=Categories.Queue)
|
||||||
|
|
||||||
# NODES
|
# NODES
|
||||||
allow_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.", category="Nodes")
|
allow_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.", json_schema_extra=Categories.Nodes)
|
||||||
deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", category="Nodes")
|
deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", json_schema_extra=Categories.Nodes)
|
||||||
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", category="Nodes", )
|
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", json_schema_extra=Categories.Nodes)
|
||||||
|
|
||||||
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
|
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
|
||||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.MemoryPerformance)
|
||||||
free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", json_schema_extra=Categories.MemoryPerformance)
|
||||||
max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
|
max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", json_schema_extra=Categories.MemoryPerformance)
|
||||||
max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
|
max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", json_schema_extra=Categories.MemoryPerformance)
|
||||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", json_schema_extra=Categories.MemoryPerformance)
|
||||||
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
|
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.MemoryPerformance)
|
||||||
|
|
||||||
# See InvokeAIAppConfig subclass below for CACHE and DEVICE categories
|
# See InvokeAIAppConfig subclass below for CACHE and DEVICE categories
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
class Config:
|
model_config = SettingsConfigDict(validate_assignment=True, env_prefix="INVOKEAI")
|
||||||
validate_assignment = True
|
|
||||||
env_prefix = "INVOKEAI"
|
|
||||||
|
|
||||||
def parse_args(self, argv: Optional[list[str]] = None, conf: Optional[DictConfig] = None, clobber=False):
|
def parse_args(
|
||||||
|
self,
|
||||||
|
argv: Optional[list[str]] = None,
|
||||||
|
conf: Optional[DictConfig] = None,
|
||||||
|
clobber=False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Update settings with contents of init file, environment, and
|
Update settings with contents of init file, environment, and
|
||||||
command-line settings.
|
command-line settings.
|
||||||
@@ -308,7 +326,11 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
if self.singleton_init and not clobber:
|
if self.singleton_init and not clobber:
|
||||||
hints = get_type_hints(self.__class__)
|
hints = get_type_hints(self.__class__)
|
||||||
for k in self.singleton_init:
|
for k in self.singleton_init:
|
||||||
setattr(self, k, parse_obj_as(hints[k], self.singleton_init[k]))
|
setattr(
|
||||||
|
self,
|
||||||
|
k,
|
||||||
|
TypeAdapter(hints[k]).validate_python(self.singleton_init[k]),
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config(cls, **kwargs) -> InvokeAIAppConfig:
|
def get_config(cls, **kwargs) -> InvokeAIAppConfig:
|
||||||
0
invokeai/app/services/events/__init__.py
Normal file
0
invokeai/app/services/events/__init__.py
Normal file
@@ -2,8 +2,7 @@
|
|||||||
|
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from invokeai.app.models.image import ProgressImage
|
from invokeai.app.services.invocation_processor.invocation_processor_common import ProgressImage
|
||||||
from invokeai.app.services.model_manager_service import BaseModelType, ModelInfo, ModelType, SubModelType
|
|
||||||
from invokeai.app.services.session_queue.session_queue_common import (
|
from invokeai.app.services.session_queue.session_queue_common import (
|
||||||
BatchStatus,
|
BatchStatus,
|
||||||
EnqueueBatchResult,
|
EnqueueBatchResult,
|
||||||
@@ -11,6 +10,8 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
|||||||
SessionQueueStatus,
|
SessionQueueStatus,
|
||||||
)
|
)
|
||||||
from invokeai.app.util.misc import get_timestamp
|
from invokeai.app.util.misc import get_timestamp
|
||||||
|
from invokeai.backend.model_management.model_manager import ModelInfo
|
||||||
|
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
|
||||||
|
|
||||||
|
|
||||||
class EventServiceBase:
|
class EventServiceBase:
|
||||||
@@ -54,7 +55,7 @@ class EventServiceBase:
|
|||||||
graph_execution_state_id=graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
node_id=node.get("id"),
|
node_id=node.get("id"),
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
progress_image=progress_image.dict() if progress_image is not None else None,
|
progress_image=progress_image.model_dump() if progress_image is not None else None,
|
||||||
step=step,
|
step=step,
|
||||||
order=order,
|
order=order,
|
||||||
total_steps=total_steps,
|
total_steps=total_steps,
|
||||||
@@ -290,8 +291,8 @@ class EventServiceBase:
|
|||||||
started_at=str(session_queue_item.started_at) if session_queue_item.started_at else None,
|
started_at=str(session_queue_item.started_at) if session_queue_item.started_at else None,
|
||||||
completed_at=str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
|
completed_at=str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
|
||||||
),
|
),
|
||||||
batch_status=batch_status.dict(),
|
batch_status=batch_status.model_dump(),
|
||||||
queue_status=queue_status.dict(),
|
queue_status=queue_status.model_dump(),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
0
invokeai/app/services/image_files/__init__.py
Normal file
0
invokeai/app/services/image_files/__init__.py
Normal file
43
invokeai/app/services/image_files/image_files_base.py
Normal file
43
invokeai/app/services/image_files/image_files_base.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from PIL.Image import Image as PILImageType
|
||||||
|
|
||||||
|
|
||||||
|
class ImageFileStorageBase(ABC):
|
||||||
|
"""Low-level service responsible for storing and retrieving image files."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(self, image_name: str) -> PILImageType:
|
||||||
|
"""Retrieves an image as PIL Image."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_path(self, image_name: str, thumbnail: bool = False) -> Path:
|
||||||
|
"""Gets the internal path to an image or thumbnail."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
# TODO: We need to validate paths before starlette makes the FileResponse, else we get a
|
||||||
|
# 500 internal server error. I don't like having this method on the service.
|
||||||
|
@abstractmethod
|
||||||
|
def validate_path(self, path: str) -> bool:
|
||||||
|
"""Validates the path given for an image or thumbnail."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save(
|
||||||
|
self,
|
||||||
|
image: PILImageType,
|
||||||
|
image_name: str,
|
||||||
|
metadata: Optional[dict] = None,
|
||||||
|
workflow: Optional[str] = None,
|
||||||
|
thumbnail_size: int = 256,
|
||||||
|
) -> None:
|
||||||
|
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete(self, image_name: str) -> None:
|
||||||
|
"""Deletes an image and its thumbnail (if one exists)."""
|
||||||
|
pass
|
||||||
20
invokeai/app/services/image_files/image_files_common.py
Normal file
20
invokeai/app/services/image_files/image_files_common.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
# TODO: Should these excpetions subclass existing python exceptions?
|
||||||
|
class ImageFileNotFoundException(Exception):
|
||||||
|
"""Raised when an image file is not found in storage."""
|
||||||
|
|
||||||
|
def __init__(self, message="Image file not found"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageFileSaveException(Exception):
|
||||||
|
"""Raised when an image cannot be saved."""
|
||||||
|
|
||||||
|
def __init__(self, message="Image file not saved"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageFileDeleteException(Exception):
|
||||||
|
"""Raised when an image cannot be deleted."""
|
||||||
|
|
||||||
|
def __init__(self, message="Image file not deleted"):
|
||||||
|
super().__init__(message)
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||||
import json
|
import json
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Optional, Union
|
||||||
@@ -9,68 +8,11 @@ from PIL import Image, PngImagePlugin
|
|||||||
from PIL.Image import Image as PILImageType
|
from PIL.Image import Image as PILImageType
|
||||||
from send2trash import send2trash
|
from send2trash import send2trash
|
||||||
|
|
||||||
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||||
|
|
||||||
|
from .image_files_base import ImageFileStorageBase
|
||||||
# TODO: Should these excpetions subclass existing python exceptions?
|
from .image_files_common import ImageFileDeleteException, ImageFileNotFoundException, ImageFileSaveException
|
||||||
class ImageFileNotFoundException(Exception):
|
|
||||||
"""Raised when an image file is not found in storage."""
|
|
||||||
|
|
||||||
def __init__(self, message="Image file not found"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class ImageFileSaveException(Exception):
|
|
||||||
"""Raised when an image cannot be saved."""
|
|
||||||
|
|
||||||
def __init__(self, message="Image file not saved"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class ImageFileDeleteException(Exception):
|
|
||||||
"""Raised when an image cannot be deleted."""
|
|
||||||
|
|
||||||
def __init__(self, message="Image file not deleted"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class ImageFileStorageBase(ABC):
|
|
||||||
"""Low-level service responsible for storing and retrieving image files."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get(self, image_name: str) -> PILImageType:
|
|
||||||
"""Retrieves an image as PIL Image."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
|
||||||
"""Gets the internal path to an image or thumbnail."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
# TODO: We need to validate paths before starlette makes the FileResponse, else we get a
|
|
||||||
# 500 internal server error. I don't like having this method on the service.
|
|
||||||
@abstractmethod
|
|
||||||
def validate_path(self, path: str) -> bool:
|
|
||||||
"""Validates the path given for an image or thumbnail."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def save(
|
|
||||||
self,
|
|
||||||
image: PILImageType,
|
|
||||||
image_name: str,
|
|
||||||
metadata: Optional[dict] = None,
|
|
||||||
workflow: Optional[str] = None,
|
|
||||||
thumbnail_size: int = 256,
|
|
||||||
) -> None:
|
|
||||||
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def delete(self, image_name: str) -> None:
|
|
||||||
"""Deletes an image and its thumbnail (if one exists)."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class DiskImageFileStorage(ImageFileStorageBase):
|
class DiskImageFileStorage(ImageFileStorageBase):
|
||||||
@@ -80,7 +22,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
||||||
__cache: Dict[Path, PILImageType]
|
__cache: Dict[Path, PILImageType]
|
||||||
__max_cache_size: int
|
__max_cache_size: int
|
||||||
__compress_level: int
|
__invoker: Invoker
|
||||||
|
|
||||||
def __init__(self, output_folder: Union[str, Path]):
|
def __init__(self, output_folder: Union[str, Path]):
|
||||||
self.__cache = dict()
|
self.__cache = dict()
|
||||||
@@ -89,10 +31,12 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
|
|
||||||
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||||
self.__thumbnails_folder = self.__output_folder / "thumbnails"
|
self.__thumbnails_folder = self.__output_folder / "thumbnails"
|
||||||
self.__compress_level = InvokeAIAppConfig.get_config().png_compress_level
|
|
||||||
# Validate required output folders at launch
|
# Validate required output folders at launch
|
||||||
self.__validate_storage_folders()
|
self.__validate_storage_folders()
|
||||||
|
|
||||||
|
def start(self, invoker: Invoker) -> None:
|
||||||
|
self.__invoker = invoker
|
||||||
|
|
||||||
def get(self, image_name: str) -> PILImageType:
|
def get(self, image_name: str) -> PILImageType:
|
||||||
try:
|
try:
|
||||||
image_path = self.get_path(image_name)
|
image_path = self.get_path(image_name)
|
||||||
@@ -136,7 +80,12 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
if original_workflow is not None:
|
if original_workflow is not None:
|
||||||
pnginfo.add_text("invokeai_workflow", original_workflow)
|
pnginfo.add_text("invokeai_workflow", original_workflow)
|
||||||
|
|
||||||
image.save(image_path, "PNG", pnginfo=pnginfo, compress_level=self.__compress_level)
|
image.save(
|
||||||
|
image_path,
|
||||||
|
"PNG",
|
||||||
|
pnginfo=pnginfo,
|
||||||
|
compress_level=self.__invoker.services.configuration.png_compress_level,
|
||||||
|
)
|
||||||
|
|
||||||
thumbnail_name = get_thumbnail_name(image_name)
|
thumbnail_name = get_thumbnail_name(image_name)
|
||||||
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
|
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
|
||||||
0
invokeai/app/services/image_records/__init__.py
Normal file
0
invokeai/app/services/image_records/__init__.py
Normal file
84
invokeai/app/services/image_records/image_records_base.py
Normal file
84
invokeai/app/services/image_records/image_records_base.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||||
|
|
||||||
|
from .image_records_common import ImageCategory, ImageRecord, ImageRecordChanges, ResourceOrigin
|
||||||
|
|
||||||
|
|
||||||
|
class ImageRecordStorageBase(ABC):
|
||||||
|
"""Low-level service responsible for interfacing with the image record store."""
|
||||||
|
|
||||||
|
# TODO: Implement an `update()` method
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(self, image_name: str) -> ImageRecord:
|
||||||
|
"""Gets an image record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_metadata(self, image_name: str) -> Optional[dict]:
|
||||||
|
"""Gets an image's metadata'."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
changes: ImageRecordChanges,
|
||||||
|
) -> None:
|
||||||
|
"""Updates an image record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_many(
|
||||||
|
self,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 10,
|
||||||
|
image_origin: Optional[ResourceOrigin] = None,
|
||||||
|
categories: Optional[list[ImageCategory]] = None,
|
||||||
|
is_intermediate: Optional[bool] = None,
|
||||||
|
board_id: Optional[str] = None,
|
||||||
|
) -> OffsetPaginatedResults[ImageRecord]:
|
||||||
|
"""Gets a page of image records."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
# TODO: The database has a nullable `deleted_at` column, currently unused.
|
||||||
|
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
|
||||||
|
@abstractmethod
|
||||||
|
def delete(self, image_name: str) -> None:
|
||||||
|
"""Deletes an image record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete_many(self, image_names: list[str]) -> None:
|
||||||
|
"""Deletes many image records."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete_intermediates(self) -> list[str]:
|
||||||
|
"""Deletes all intermediate image records, returning a list of deleted image names."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
image_origin: ResourceOrigin,
|
||||||
|
image_category: ImageCategory,
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
|
is_intermediate: Optional[bool] = False,
|
||||||
|
starred: Optional[bool] = False,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
node_id: Optional[str] = None,
|
||||||
|
metadata: Optional[dict] = None,
|
||||||
|
) -> datetime:
|
||||||
|
"""Saves an image record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
|
||||||
|
"""Gets the most recent image for a board."""
|
||||||
|
pass
|
||||||
@@ -1,13 +1,117 @@
|
|||||||
|
# TODO: Should these excpetions subclass existing python exceptions?
|
||||||
import datetime
|
import datetime
|
||||||
|
from enum import Enum
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from pydantic import Extra, Field, StrictBool, StrictStr
|
from pydantic import Field, StrictBool, StrictStr
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
from invokeai.app.util.metaenum import MetaEnum
|
||||||
from invokeai.app.util.misc import get_iso_timestamp
|
from invokeai.app.util.misc import get_iso_timestamp
|
||||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||||
|
|
||||||
|
|
||||||
|
class ResourceOrigin(str, Enum, metaclass=MetaEnum):
|
||||||
|
"""The origin of a resource (eg image).
|
||||||
|
|
||||||
|
- INTERNAL: The resource was created by the application.
|
||||||
|
- EXTERNAL: The resource was not created by the application.
|
||||||
|
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
|
||||||
|
"""
|
||||||
|
|
||||||
|
INTERNAL = "internal"
|
||||||
|
"""The resource was created by the application."""
|
||||||
|
EXTERNAL = "external"
|
||||||
|
"""The resource was not created by the application.
|
||||||
|
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidOriginException(ValueError):
|
||||||
|
"""Raised when a provided value is not a valid ResourceOrigin.
|
||||||
|
|
||||||
|
Subclasses `ValueError`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, message="Invalid resource origin."):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageCategory(str, Enum, metaclass=MetaEnum):
|
||||||
|
"""The category of an image.
|
||||||
|
|
||||||
|
- GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose.
|
||||||
|
- MASK: The image is a mask image.
|
||||||
|
- CONTROL: The image is a ControlNet control image.
|
||||||
|
- USER: The image is a user-provide image.
|
||||||
|
- OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
GENERAL = "general"
|
||||||
|
"""GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose."""
|
||||||
|
MASK = "mask"
|
||||||
|
"""MASK: The image is a mask image."""
|
||||||
|
CONTROL = "control"
|
||||||
|
"""CONTROL: The image is a ControlNet control image."""
|
||||||
|
USER = "user"
|
||||||
|
"""USER: The image is a user-provide image."""
|
||||||
|
OTHER = "other"
|
||||||
|
"""OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidImageCategoryException(ValueError):
|
||||||
|
"""Raised when a provided value is not a valid ImageCategory.
|
||||||
|
|
||||||
|
Subclasses `ValueError`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, message="Invalid image category."):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageRecordNotFoundException(Exception):
|
||||||
|
"""Raised when an image record is not found."""
|
||||||
|
|
||||||
|
def __init__(self, message="Image record not found"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageRecordSaveException(Exception):
|
||||||
|
"""Raised when an image record cannot be saved."""
|
||||||
|
|
||||||
|
def __init__(self, message="Image record not saved"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageRecordDeleteException(Exception):
|
||||||
|
"""Raised when an image record cannot be deleted."""
|
||||||
|
|
||||||
|
def __init__(self, message="Image record not deleted"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
IMAGE_DTO_COLS = ", ".join(
|
||||||
|
list(
|
||||||
|
map(
|
||||||
|
lambda c: "images." + c,
|
||||||
|
[
|
||||||
|
"image_name",
|
||||||
|
"image_origin",
|
||||||
|
"image_category",
|
||||||
|
"width",
|
||||||
|
"height",
|
||||||
|
"session_id",
|
||||||
|
"node_id",
|
||||||
|
"is_intermediate",
|
||||||
|
"created_at",
|
||||||
|
"updated_at",
|
||||||
|
"deleted_at",
|
||||||
|
"starred",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageRecord(BaseModelExcludeNull):
|
class ImageRecord(BaseModelExcludeNull):
|
||||||
"""Deserialized image record without metadata."""
|
"""Deserialized image record without metadata."""
|
||||||
|
|
||||||
@@ -25,7 +129,9 @@ class ImageRecord(BaseModelExcludeNull):
|
|||||||
"""The created timestamp of the image."""
|
"""The created timestamp of the image."""
|
||||||
updated_at: Union[datetime.datetime, str] = Field(description="The updated timestamp of the image.")
|
updated_at: Union[datetime.datetime, str] = Field(description="The updated timestamp of the image.")
|
||||||
"""The updated timestamp of the image."""
|
"""The updated timestamp of the image."""
|
||||||
deleted_at: Union[datetime.datetime, str, None] = Field(description="The deleted timestamp of the image.")
|
deleted_at: Optional[Union[datetime.datetime, str]] = Field(
|
||||||
|
default=None, description="The deleted timestamp of the image."
|
||||||
|
)
|
||||||
"""The deleted timestamp of the image."""
|
"""The deleted timestamp of the image."""
|
||||||
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
|
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
|
||||||
"""Whether this is an intermediate image."""
|
"""Whether this is an intermediate image."""
|
||||||
@@ -43,7 +149,7 @@ class ImageRecord(BaseModelExcludeNull):
|
|||||||
"""Whether this image is starred."""
|
"""Whether this image is starred."""
|
||||||
|
|
||||||
|
|
||||||
class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
|
class ImageRecordChanges(BaseModelExcludeNull, extra="allow"):
|
||||||
"""A set of changes to apply to an image record.
|
"""A set of changes to apply to an image record.
|
||||||
|
|
||||||
Only limited changes are valid:
|
Only limited changes are valid:
|
||||||
@@ -66,41 +172,6 @@ class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
|
|||||||
"""The image's new `starred` state."""
|
"""The image's new `starred` state."""
|
||||||
|
|
||||||
|
|
||||||
class ImageUrlsDTO(BaseModelExcludeNull):
|
|
||||||
"""The URLs for an image and its thumbnail."""
|
|
||||||
|
|
||||||
image_name: str = Field(description="The unique name of the image.")
|
|
||||||
"""The unique name of the image."""
|
|
||||||
image_url: str = Field(description="The URL of the image.")
|
|
||||||
"""The URL of the image."""
|
|
||||||
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
|
|
||||||
"""The URL of the image's thumbnail."""
|
|
||||||
|
|
||||||
|
|
||||||
class ImageDTO(ImageRecord, ImageUrlsDTO):
|
|
||||||
"""Deserialized image record, enriched for the frontend."""
|
|
||||||
|
|
||||||
board_id: Optional[str] = Field(description="The id of the board the image belongs to, if one exists.")
|
|
||||||
"""The id of the board the image belongs to, if one exists."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def image_record_to_dto(
|
|
||||||
image_record: ImageRecord,
|
|
||||||
image_url: str,
|
|
||||||
thumbnail_url: str,
|
|
||||||
board_id: Optional[str],
|
|
||||||
) -> ImageDTO:
|
|
||||||
"""Converts an image record to an image DTO."""
|
|
||||||
return ImageDTO(
|
|
||||||
**image_record.dict(),
|
|
||||||
image_url=image_url,
|
|
||||||
thumbnail_url=thumbnail_url,
|
|
||||||
board_id=board_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||||
"""Deserializes an image record."""
|
"""Deserializes an image record."""
|
||||||
|
|
||||||
@@ -1,164 +1,36 @@
|
|||||||
import json
|
import json
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
import threading
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Generic, Optional, TypeVar, cast
|
from typing import Optional, Union, cast
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||||
from pydantic.generics import GenericModel
|
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
from .image_records_base import ImageRecordStorageBase
|
||||||
from invokeai.app.services.models.image_record import ImageRecord, ImageRecordChanges, deserialize_image_record
|
from .image_records_common import (
|
||||||
|
IMAGE_DTO_COLS,
|
||||||
T = TypeVar("T", bound=BaseModel)
|
ImageCategory,
|
||||||
|
ImageRecord,
|
||||||
|
ImageRecordChanges,
|
||||||
class OffsetPaginatedResults(GenericModel, Generic[T]):
|
ImageRecordDeleteException,
|
||||||
"""Offset-paginated results"""
|
ImageRecordNotFoundException,
|
||||||
|
ImageRecordSaveException,
|
||||||
# fmt: off
|
ResourceOrigin,
|
||||||
items: list[T] = Field(description="Items")
|
deserialize_image_record,
|
||||||
offset: int = Field(description="Offset from which to retrieve items")
|
|
||||||
limit: int = Field(description="Limit of items to get")
|
|
||||||
total: int = Field(description="Total number of items in result")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Should these excpetions subclass existing python exceptions?
|
|
||||||
class ImageRecordNotFoundException(Exception):
|
|
||||||
"""Raised when an image record is not found."""
|
|
||||||
|
|
||||||
def __init__(self, message="Image record not found"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class ImageRecordSaveException(Exception):
|
|
||||||
"""Raised when an image record cannot be saved."""
|
|
||||||
|
|
||||||
def __init__(self, message="Image record not saved"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class ImageRecordDeleteException(Exception):
|
|
||||||
"""Raised when an image record cannot be deleted."""
|
|
||||||
|
|
||||||
def __init__(self, message="Image record not deleted"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
IMAGE_DTO_COLS = ", ".join(
|
|
||||||
list(
|
|
||||||
map(
|
|
||||||
lambda c: "images." + c,
|
|
||||||
[
|
|
||||||
"image_name",
|
|
||||||
"image_origin",
|
|
||||||
"image_category",
|
|
||||||
"width",
|
|
||||||
"height",
|
|
||||||
"session_id",
|
|
||||||
"node_id",
|
|
||||||
"is_intermediate",
|
|
||||||
"created_at",
|
|
||||||
"updated_at",
|
|
||||||
"deleted_at",
|
|
||||||
"starred",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageRecordStorageBase(ABC):
|
|
||||||
"""Low-level service responsible for interfacing with the image record store."""
|
|
||||||
|
|
||||||
# TODO: Implement an `update()` method
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get(self, image_name: str) -> ImageRecord:
|
|
||||||
"""Gets an image record."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_metadata(self, image_name: str) -> Optional[dict]:
|
|
||||||
"""Gets an image's metadata'."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update(
|
|
||||||
self,
|
|
||||||
image_name: str,
|
|
||||||
changes: ImageRecordChanges,
|
|
||||||
) -> None:
|
|
||||||
"""Updates an image record."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_many(
|
|
||||||
self,
|
|
||||||
offset: Optional[int] = None,
|
|
||||||
limit: Optional[int] = None,
|
|
||||||
image_origin: Optional[ResourceOrigin] = None,
|
|
||||||
categories: Optional[list[ImageCategory]] = None,
|
|
||||||
is_intermediate: Optional[bool] = None,
|
|
||||||
board_id: Optional[str] = None,
|
|
||||||
) -> OffsetPaginatedResults[ImageRecord]:
|
|
||||||
"""Gets a page of image records."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
# TODO: The database has a nullable `deleted_at` column, currently unused.
|
|
||||||
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
|
|
||||||
@abstractmethod
|
|
||||||
def delete(self, image_name: str) -> None:
|
|
||||||
"""Deletes an image record."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def delete_many(self, image_names: list[str]) -> None:
|
|
||||||
"""Deletes many image records."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def delete_intermediates(self) -> list[str]:
|
|
||||||
"""Deletes all intermediate image records, returning a list of deleted image names."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def save(
|
|
||||||
self,
|
|
||||||
image_name: str,
|
|
||||||
image_origin: ResourceOrigin,
|
|
||||||
image_category: ImageCategory,
|
|
||||||
width: int,
|
|
||||||
height: int,
|
|
||||||
session_id: Optional[str],
|
|
||||||
node_id: Optional[str],
|
|
||||||
metadata: Optional[dict],
|
|
||||||
is_intermediate: bool = False,
|
|
||||||
starred: bool = False,
|
|
||||||
) -> datetime:
|
|
||||||
"""Saves an image record."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
|
|
||||||
"""Gets the most recent image for a board."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||||
_conn: sqlite3.Connection
|
_conn: sqlite3.Connection
|
||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_lock: threading.RLock
|
_lock: threading.RLock
|
||||||
|
|
||||||
def __init__(self, conn: sqlite3.Connection, lock: threading.RLock) -> None:
|
def __init__(self, db: SqliteDatabase) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._conn = conn
|
self._lock = db.lock
|
||||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
self._conn = db.conn
|
||||||
self._conn.row_factory = sqlite3.Row
|
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
self._lock = lock
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
@@ -245,7 +117,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def get(self, image_name: str) -> Optional[ImageRecord]:
|
def get(self, image_name: str) -> ImageRecord:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
|
|
||||||
@@ -351,8 +223,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
|
|
||||||
def get_many(
|
def get_many(
|
||||||
self,
|
self,
|
||||||
offset: Optional[int] = None,
|
offset: int = 0,
|
||||||
limit: Optional[int] = None,
|
limit: int = 10,
|
||||||
image_origin: Optional[ResourceOrigin] = None,
|
image_origin: Optional[ResourceOrigin] = None,
|
||||||
categories: Optional[list[ImageCategory]] = None,
|
categories: Optional[list[ImageCategory]] = None,
|
||||||
is_intermediate: Optional[bool] = None,
|
is_intermediate: Optional[bool] = None,
|
||||||
@@ -377,7 +249,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
query_conditions = ""
|
query_conditions = ""
|
||||||
query_params = []
|
query_params: list[Union[int, str, bool]] = []
|
||||||
|
|
||||||
if image_origin is not None:
|
if image_origin is not None:
|
||||||
query_conditions += """--sql
|
query_conditions += """--sql
|
||||||
@@ -515,13 +387,13 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
image_name: str,
|
image_name: str,
|
||||||
image_origin: ResourceOrigin,
|
image_origin: ResourceOrigin,
|
||||||
image_category: ImageCategory,
|
image_category: ImageCategory,
|
||||||
session_id: Optional[str],
|
|
||||||
width: int,
|
width: int,
|
||||||
height: int,
|
height: int,
|
||||||
node_id: Optional[str],
|
is_intermediate: Optional[bool] = False,
|
||||||
metadata: Optional[dict],
|
starred: Optional[bool] = False,
|
||||||
is_intermediate: bool = False,
|
session_id: Optional[str] = None,
|
||||||
starred: bool = False,
|
node_id: Optional[str] = None,
|
||||||
|
metadata: Optional[dict] = None,
|
||||||
) -> datetime:
|
) -> datetime:
|
||||||
try:
|
try:
|
||||||
metadata_json = None if metadata is None else json.dumps(metadata)
|
metadata_json = None if metadata is None else json.dumps(metadata)
|
||||||
@@ -1,449 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from logging import Logger
|
|
||||||
from typing import TYPE_CHECKING, Callable, Optional
|
|
||||||
|
|
||||||
from PIL.Image import Image as PILImageType
|
|
||||||
|
|
||||||
from invokeai.app.invocations.metadata import ImageMetadata
|
|
||||||
from invokeai.app.models.image import (
|
|
||||||
ImageCategory,
|
|
||||||
InvalidImageCategoryException,
|
|
||||||
InvalidOriginException,
|
|
||||||
ResourceOrigin,
|
|
||||||
)
|
|
||||||
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
|
||||||
from invokeai.app.services.image_file_storage import (
|
|
||||||
ImageFileDeleteException,
|
|
||||||
ImageFileNotFoundException,
|
|
||||||
ImageFileSaveException,
|
|
||||||
ImageFileStorageBase,
|
|
||||||
)
|
|
||||||
from invokeai.app.services.image_record_storage import (
|
|
||||||
ImageRecordDeleteException,
|
|
||||||
ImageRecordNotFoundException,
|
|
||||||
ImageRecordSaveException,
|
|
||||||
ImageRecordStorageBase,
|
|
||||||
OffsetPaginatedResults,
|
|
||||||
)
|
|
||||||
from invokeai.app.services.item_storage import ItemStorageABC
|
|
||||||
from invokeai.app.services.models.image_record import ImageDTO, ImageRecord, ImageRecordChanges, image_record_to_dto
|
|
||||||
from invokeai.app.services.resource_name import NameServiceBase
|
|
||||||
from invokeai.app.services.urls import UrlServiceBase
|
|
||||||
from invokeai.app.util.metadata import get_metadata_graph_from_raw_session
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from invokeai.app.services.graph import GraphExecutionState
|
|
||||||
|
|
||||||
|
|
||||||
class ImageServiceABC(ABC):
|
|
||||||
"""High-level service for image management."""
|
|
||||||
|
|
||||||
_on_changed_callbacks: list[Callable[[ImageDTO], None]]
|
|
||||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._on_changed_callbacks = list()
|
|
||||||
self._on_deleted_callbacks = list()
|
|
||||||
|
|
||||||
def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
|
|
||||||
"""Register a callback for when an image is changed"""
|
|
||||||
self._on_changed_callbacks.append(on_changed)
|
|
||||||
|
|
||||||
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
|
||||||
"""Register a callback for when an image is deleted"""
|
|
||||||
self._on_deleted_callbacks.append(on_deleted)
|
|
||||||
|
|
||||||
def _on_changed(self, item: ImageDTO) -> None:
|
|
||||||
for callback in self._on_changed_callbacks:
|
|
||||||
callback(item)
|
|
||||||
|
|
||||||
def _on_deleted(self, item_id: str) -> None:
|
|
||||||
for callback in self._on_deleted_callbacks:
|
|
||||||
callback(item_id)
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def create(
|
|
||||||
self,
|
|
||||||
image: PILImageType,
|
|
||||||
image_origin: ResourceOrigin,
|
|
||||||
image_category: ImageCategory,
|
|
||||||
node_id: Optional[str] = None,
|
|
||||||
session_id: Optional[str] = None,
|
|
||||||
board_id: Optional[str] = None,
|
|
||||||
is_intermediate: bool = False,
|
|
||||||
metadata: Optional[dict] = None,
|
|
||||||
workflow: Optional[str] = None,
|
|
||||||
) -> ImageDTO:
|
|
||||||
"""Creates an image, storing the file and its metadata."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update(
|
|
||||||
self,
|
|
||||||
image_name: str,
|
|
||||||
changes: ImageRecordChanges,
|
|
||||||
) -> ImageDTO:
|
|
||||||
"""Updates an image."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_pil_image(self, image_name: str) -> PILImageType:
|
|
||||||
"""Gets an image as a PIL image."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_record(self, image_name: str) -> ImageRecord:
|
|
||||||
"""Gets an image record."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_dto(self, image_name: str) -> ImageDTO:
|
|
||||||
"""Gets an image DTO."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_metadata(self, image_name: str) -> ImageMetadata:
|
|
||||||
"""Gets an image's metadata."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
|
||||||
"""Gets an image's path."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def validate_path(self, path: str) -> bool:
|
|
||||||
"""Validates an image's path."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
|
|
||||||
"""Gets an image's or thumbnail's URL."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_many(
|
|
||||||
self,
|
|
||||||
offset: int = 0,
|
|
||||||
limit: int = 10,
|
|
||||||
image_origin: Optional[ResourceOrigin] = None,
|
|
||||||
categories: Optional[list[ImageCategory]] = None,
|
|
||||||
is_intermediate: Optional[bool] = None,
|
|
||||||
board_id: Optional[str] = None,
|
|
||||||
) -> OffsetPaginatedResults[ImageDTO]:
|
|
||||||
"""Gets a paginated list of image DTOs."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def delete(self, image_name: str):
|
|
||||||
"""Deletes an image."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def delete_intermediates(self) -> int:
|
|
||||||
"""Deletes all intermediate images."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def delete_images_on_board(self, board_id: str):
|
|
||||||
"""Deletes all images on a board."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ImageServiceDependencies:
|
|
||||||
"""Service dependencies for the ImageService."""
|
|
||||||
|
|
||||||
image_records: ImageRecordStorageBase
|
|
||||||
image_files: ImageFileStorageBase
|
|
||||||
board_image_records: BoardImageRecordStorageBase
|
|
||||||
urls: UrlServiceBase
|
|
||||||
logger: Logger
|
|
||||||
names: NameServiceBase
|
|
||||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
image_record_storage: ImageRecordStorageBase,
|
|
||||||
image_file_storage: ImageFileStorageBase,
|
|
||||||
board_image_record_storage: BoardImageRecordStorageBase,
|
|
||||||
url: UrlServiceBase,
|
|
||||||
logger: Logger,
|
|
||||||
names: NameServiceBase,
|
|
||||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
|
||||||
):
|
|
||||||
self.image_records = image_record_storage
|
|
||||||
self.image_files = image_file_storage
|
|
||||||
self.board_image_records = board_image_record_storage
|
|
||||||
self.urls = url
|
|
||||||
self.logger = logger
|
|
||||||
self.names = names
|
|
||||||
self.graph_execution_manager = graph_execution_manager
|
|
||||||
|
|
||||||
|
|
||||||
class ImageService(ImageServiceABC):
|
|
||||||
_services: ImageServiceDependencies
|
|
||||||
|
|
||||||
def __init__(self, services: ImageServiceDependencies):
|
|
||||||
super().__init__()
|
|
||||||
self._services = services
|
|
||||||
|
|
||||||
def create(
|
|
||||||
self,
|
|
||||||
image: PILImageType,
|
|
||||||
image_origin: ResourceOrigin,
|
|
||||||
image_category: ImageCategory,
|
|
||||||
node_id: Optional[str] = None,
|
|
||||||
session_id: Optional[str] = None,
|
|
||||||
board_id: Optional[str] = None,
|
|
||||||
is_intermediate: bool = False,
|
|
||||||
metadata: Optional[dict] = None,
|
|
||||||
workflow: Optional[str] = None,
|
|
||||||
) -> ImageDTO:
|
|
||||||
if image_origin not in ResourceOrigin:
|
|
||||||
raise InvalidOriginException
|
|
||||||
|
|
||||||
if image_category not in ImageCategory:
|
|
||||||
raise InvalidImageCategoryException
|
|
||||||
|
|
||||||
image_name = self._services.names.create_image_name()
|
|
||||||
|
|
||||||
# TODO: Do we want to store the graph in the image at all? I don't think so...
|
|
||||||
# graph = None
|
|
||||||
# if session_id is not None:
|
|
||||||
# session_raw = self._services.graph_execution_manager.get_raw(session_id)
|
|
||||||
# if session_raw is not None:
|
|
||||||
# try:
|
|
||||||
# graph = get_metadata_graph_from_raw_session(session_raw)
|
|
||||||
# except Exception as e:
|
|
||||||
# self._services.logger.warn(f"Failed to parse session graph: {e}")
|
|
||||||
# graph = None
|
|
||||||
|
|
||||||
(width, height) = image.size
|
|
||||||
|
|
||||||
try:
|
|
||||||
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
|
||||||
self._services.image_records.save(
|
|
||||||
# Non-nullable fields
|
|
||||||
image_name=image_name,
|
|
||||||
image_origin=image_origin,
|
|
||||||
image_category=image_category,
|
|
||||||
width=width,
|
|
||||||
height=height,
|
|
||||||
# Meta fields
|
|
||||||
is_intermediate=is_intermediate,
|
|
||||||
# Nullable fields
|
|
||||||
node_id=node_id,
|
|
||||||
metadata=metadata,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
if board_id is not None:
|
|
||||||
self._services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
|
|
||||||
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, workflow=workflow)
|
|
||||||
image_dto = self.get_dto(image_name)
|
|
||||||
|
|
||||||
self._on_changed(image_dto)
|
|
||||||
return image_dto
|
|
||||||
except ImageRecordSaveException:
|
|
||||||
self._services.logger.error("Failed to save image record")
|
|
||||||
raise
|
|
||||||
except ImageFileSaveException:
|
|
||||||
self._services.logger.error("Failed to save image file")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
self._services.logger.error(f"Problem saving image record and file: {str(e)}")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def update(
|
|
||||||
self,
|
|
||||||
image_name: str,
|
|
||||||
changes: ImageRecordChanges,
|
|
||||||
) -> ImageDTO:
|
|
||||||
try:
|
|
||||||
self._services.image_records.update(image_name, changes)
|
|
||||||
image_dto = self.get_dto(image_name)
|
|
||||||
self._on_changed(image_dto)
|
|
||||||
return image_dto
|
|
||||||
except ImageRecordSaveException:
|
|
||||||
self._services.logger.error("Failed to update image record")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
self._services.logger.error("Problem updating image record")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def get_pil_image(self, image_name: str) -> PILImageType:
|
|
||||||
try:
|
|
||||||
return self._services.image_files.get(image_name)
|
|
||||||
except ImageFileNotFoundException:
|
|
||||||
self._services.logger.error("Failed to get image file")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
self._services.logger.error("Problem getting image file")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def get_record(self, image_name: str) -> ImageRecord:
|
|
||||||
try:
|
|
||||||
return self._services.image_records.get(image_name)
|
|
||||||
except ImageRecordNotFoundException:
|
|
||||||
self._services.logger.error("Image record not found")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
self._services.logger.error("Problem getting image record")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def get_dto(self, image_name: str) -> ImageDTO:
|
|
||||||
try:
|
|
||||||
image_record = self._services.image_records.get(image_name)
|
|
||||||
|
|
||||||
image_dto = image_record_to_dto(
|
|
||||||
image_record,
|
|
||||||
self._services.urls.get_image_url(image_name),
|
|
||||||
self._services.urls.get_image_url(image_name, True),
|
|
||||||
self._services.board_image_records.get_board_for_image(image_name),
|
|
||||||
)
|
|
||||||
|
|
||||||
return image_dto
|
|
||||||
except ImageRecordNotFoundException:
|
|
||||||
self._services.logger.error("Image record not found")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
self._services.logger.error("Problem getting image DTO")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def get_metadata(self, image_name: str) -> Optional[ImageMetadata]:
|
|
||||||
try:
|
|
||||||
image_record = self._services.image_records.get(image_name)
|
|
||||||
metadata = self._services.image_records.get_metadata(image_name)
|
|
||||||
|
|
||||||
if not image_record.session_id:
|
|
||||||
return ImageMetadata(metadata=metadata)
|
|
||||||
|
|
||||||
session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id)
|
|
||||||
graph = None
|
|
||||||
|
|
||||||
if session_raw:
|
|
||||||
try:
|
|
||||||
graph = get_metadata_graph_from_raw_session(session_raw)
|
|
||||||
except Exception as e:
|
|
||||||
self._services.logger.warn(f"Failed to parse session graph: {e}")
|
|
||||||
graph = None
|
|
||||||
|
|
||||||
return ImageMetadata(graph=graph, metadata=metadata)
|
|
||||||
except ImageRecordNotFoundException:
|
|
||||||
self._services.logger.error("Image record not found")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
self._services.logger.error("Problem getting image DTO")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
|
||||||
try:
|
|
||||||
return self._services.image_files.get_path(image_name, thumbnail)
|
|
||||||
except Exception as e:
|
|
||||||
self._services.logger.error("Problem getting image path")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def validate_path(self, path: str) -> bool:
|
|
||||||
try:
|
|
||||||
return self._services.image_files.validate_path(path)
|
|
||||||
except Exception as e:
|
|
||||||
self._services.logger.error("Problem validating image path")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
|
|
||||||
try:
|
|
||||||
return self._services.urls.get_image_url(image_name, thumbnail)
|
|
||||||
except Exception as e:
|
|
||||||
self._services.logger.error("Problem getting image path")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def get_many(
|
|
||||||
self,
|
|
||||||
offset: int = 0,
|
|
||||||
limit: int = 10,
|
|
||||||
image_origin: Optional[ResourceOrigin] = None,
|
|
||||||
categories: Optional[list[ImageCategory]] = None,
|
|
||||||
is_intermediate: Optional[bool] = None,
|
|
||||||
board_id: Optional[str] = None,
|
|
||||||
) -> OffsetPaginatedResults[ImageDTO]:
|
|
||||||
try:
|
|
||||||
results = self._services.image_records.get_many(
|
|
||||||
offset,
|
|
||||||
limit,
|
|
||||||
image_origin,
|
|
||||||
categories,
|
|
||||||
is_intermediate,
|
|
||||||
board_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
image_dtos = list(
|
|
||||||
map(
|
|
||||||
lambda r: image_record_to_dto(
|
|
||||||
r,
|
|
||||||
self._services.urls.get_image_url(r.image_name),
|
|
||||||
self._services.urls.get_image_url(r.image_name, True),
|
|
||||||
self._services.board_image_records.get_board_for_image(r.image_name),
|
|
||||||
),
|
|
||||||
results.items,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return OffsetPaginatedResults[ImageDTO](
|
|
||||||
items=image_dtos,
|
|
||||||
offset=results.offset,
|
|
||||||
limit=results.limit,
|
|
||||||
total=results.total,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
self._services.logger.error("Problem getting paginated image DTOs")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def delete(self, image_name: str):
|
|
||||||
try:
|
|
||||||
self._services.image_files.delete(image_name)
|
|
||||||
self._services.image_records.delete(image_name)
|
|
||||||
self._on_deleted(image_name)
|
|
||||||
except ImageRecordDeleteException:
|
|
||||||
self._services.logger.error("Failed to delete image record")
|
|
||||||
raise
|
|
||||||
except ImageFileDeleteException:
|
|
||||||
self._services.logger.error("Failed to delete image file")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
self._services.logger.error("Problem deleting image record and file")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def delete_images_on_board(self, board_id: str):
|
|
||||||
try:
|
|
||||||
image_names = self._services.board_image_records.get_all_board_image_names_for_board(board_id)
|
|
||||||
for image_name in image_names:
|
|
||||||
self._services.image_files.delete(image_name)
|
|
||||||
self._services.image_records.delete_many(image_names)
|
|
||||||
for image_name in image_names:
|
|
||||||
self._on_deleted(image_name)
|
|
||||||
except ImageRecordDeleteException:
|
|
||||||
self._services.logger.error("Failed to delete image records")
|
|
||||||
raise
|
|
||||||
except ImageFileDeleteException:
|
|
||||||
self._services.logger.error("Failed to delete image files")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
self._services.logger.error("Problem deleting image records and files")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def delete_intermediates(self) -> int:
|
|
||||||
try:
|
|
||||||
image_names = self._services.image_records.delete_intermediates()
|
|
||||||
count = len(image_names)
|
|
||||||
for image_name in image_names:
|
|
||||||
self._services.image_files.delete(image_name)
|
|
||||||
self._on_deleted(image_name)
|
|
||||||
return count
|
|
||||||
except ImageRecordDeleteException:
|
|
||||||
self._services.logger.error("Failed to delete image records")
|
|
||||||
raise
|
|
||||||
except ImageFileDeleteException:
|
|
||||||
self._services.logger.error("Failed to delete image files")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
self._services.logger.error("Problem deleting image records and files")
|
|
||||||
raise e
|
|
||||||
0
invokeai/app/services/images/__init__.py
Normal file
0
invokeai/app/services/images/__init__.py
Normal file
129
invokeai/app/services/images/images_base.py
Normal file
129
invokeai/app/services/images/images_base.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
from PIL.Image import Image as PILImageType
|
||||||
|
|
||||||
|
from invokeai.app.invocations.metadata import ImageMetadata
|
||||||
|
from invokeai.app.services.image_records.image_records_common import (
|
||||||
|
ImageCategory,
|
||||||
|
ImageRecord,
|
||||||
|
ImageRecordChanges,
|
||||||
|
ResourceOrigin,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.images.images_common import ImageDTO
|
||||||
|
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||||
|
|
||||||
|
|
||||||
|
class ImageServiceABC(ABC):
|
||||||
|
"""High-level service for image management."""
|
||||||
|
|
||||||
|
_on_changed_callbacks: list[Callable[[ImageDTO], None]]
|
||||||
|
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._on_changed_callbacks = list()
|
||||||
|
self._on_deleted_callbacks = list()
|
||||||
|
|
||||||
|
def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
|
||||||
|
"""Register a callback for when an image is changed"""
|
||||||
|
self._on_changed_callbacks.append(on_changed)
|
||||||
|
|
||||||
|
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
||||||
|
"""Register a callback for when an image is deleted"""
|
||||||
|
self._on_deleted_callbacks.append(on_deleted)
|
||||||
|
|
||||||
|
def _on_changed(self, item: ImageDTO) -> None:
|
||||||
|
for callback in self._on_changed_callbacks:
|
||||||
|
callback(item)
|
||||||
|
|
||||||
|
def _on_deleted(self, item_id: str) -> None:
|
||||||
|
for callback in self._on_deleted_callbacks:
|
||||||
|
callback(item_id)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
image: PILImageType,
|
||||||
|
image_origin: ResourceOrigin,
|
||||||
|
image_category: ImageCategory,
|
||||||
|
node_id: Optional[str] = None,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
board_id: Optional[str] = None,
|
||||||
|
is_intermediate: Optional[bool] = False,
|
||||||
|
metadata: Optional[dict] = None,
|
||||||
|
workflow: Optional[str] = None,
|
||||||
|
) -> ImageDTO:
|
||||||
|
"""Creates an image, storing the file and its metadata."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
changes: ImageRecordChanges,
|
||||||
|
) -> ImageDTO:
|
||||||
|
"""Updates an image."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_pil_image(self, image_name: str) -> PILImageType:
|
||||||
|
"""Gets an image as a PIL image."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_record(self, image_name: str) -> ImageRecord:
|
||||||
|
"""Gets an image record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_dto(self, image_name: str) -> ImageDTO:
|
||||||
|
"""Gets an image DTO."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_metadata(self, image_name: str) -> ImageMetadata:
|
||||||
|
"""Gets an image's metadata."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
|
"""Gets an image's path."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def validate_path(self, path: str) -> bool:
|
||||||
|
"""Validates an image's path."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
|
"""Gets an image's or thumbnail's URL."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_many(
|
||||||
|
self,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 10,
|
||||||
|
image_origin: Optional[ResourceOrigin] = None,
|
||||||
|
categories: Optional[list[ImageCategory]] = None,
|
||||||
|
is_intermediate: Optional[bool] = None,
|
||||||
|
board_id: Optional[str] = None,
|
||||||
|
) -> OffsetPaginatedResults[ImageDTO]:
|
||||||
|
"""Gets a paginated list of image DTOs."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete(self, image_name: str):
|
||||||
|
"""Deletes an image."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete_intermediates(self) -> int:
|
||||||
|
"""Deletes all intermediate images."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete_images_on_board(self, board_id: str):
|
||||||
|
"""Deletes all images on a board."""
|
||||||
|
pass
|
||||||
43
invokeai/app/services/images/images_common.py
Normal file
43
invokeai/app/services/images/images_common.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from invokeai.app.services.image_records.image_records_common import ImageRecord
|
||||||
|
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||||
|
|
||||||
|
|
||||||
|
class ImageUrlsDTO(BaseModelExcludeNull):
|
||||||
|
"""The URLs for an image and its thumbnail."""
|
||||||
|
|
||||||
|
image_name: str = Field(description="The unique name of the image.")
|
||||||
|
"""The unique name of the image."""
|
||||||
|
image_url: str = Field(description="The URL of the image.")
|
||||||
|
"""The URL of the image."""
|
||||||
|
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
|
||||||
|
"""The URL of the image's thumbnail."""
|
||||||
|
|
||||||
|
|
||||||
|
class ImageDTO(ImageRecord, ImageUrlsDTO):
|
||||||
|
"""Deserialized image record, enriched for the frontend."""
|
||||||
|
|
||||||
|
board_id: Optional[str] = Field(
|
||||||
|
default=None, description="The id of the board the image belongs to, if one exists."
|
||||||
|
)
|
||||||
|
"""The id of the board the image belongs to, if one exists."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def image_record_to_dto(
|
||||||
|
image_record: ImageRecord,
|
||||||
|
image_url: str,
|
||||||
|
thumbnail_url: str,
|
||||||
|
board_id: Optional[str],
|
||||||
|
) -> ImageDTO:
|
||||||
|
"""Converts an image record to an image DTO."""
|
||||||
|
return ImageDTO(
|
||||||
|
**image_record.model_dump(),
|
||||||
|
image_url=image_url,
|
||||||
|
thumbnail_url=thumbnail_url,
|
||||||
|
board_id=board_id,
|
||||||
|
)
|
||||||
286
invokeai/app/services/images/images_default.py
Normal file
286
invokeai/app/services/images/images_default.py
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from PIL.Image import Image as PILImageType
|
||||||
|
|
||||||
|
from invokeai.app.invocations.metadata import ImageMetadata
|
||||||
|
from invokeai.app.services.invoker import Invoker
|
||||||
|
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||||
|
from invokeai.app.util.metadata import get_metadata_graph_from_raw_session
|
||||||
|
|
||||||
|
from ..image_files.image_files_common import (
|
||||||
|
ImageFileDeleteException,
|
||||||
|
ImageFileNotFoundException,
|
||||||
|
ImageFileSaveException,
|
||||||
|
)
|
||||||
|
from ..image_records.image_records_common import (
|
||||||
|
ImageCategory,
|
||||||
|
ImageRecord,
|
||||||
|
ImageRecordChanges,
|
||||||
|
ImageRecordDeleteException,
|
||||||
|
ImageRecordNotFoundException,
|
||||||
|
ImageRecordSaveException,
|
||||||
|
InvalidImageCategoryException,
|
||||||
|
InvalidOriginException,
|
||||||
|
ResourceOrigin,
|
||||||
|
)
|
||||||
|
from .images_base import ImageServiceABC
|
||||||
|
from .images_common import ImageDTO, image_record_to_dto
|
||||||
|
|
||||||
|
|
||||||
|
class ImageService(ImageServiceABC):
|
||||||
|
__invoker: Invoker
|
||||||
|
|
||||||
|
def start(self, invoker: Invoker) -> None:
|
||||||
|
self.__invoker = invoker
|
||||||
|
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
image: PILImageType,
|
||||||
|
image_origin: ResourceOrigin,
|
||||||
|
image_category: ImageCategory,
|
||||||
|
node_id: Optional[str] = None,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
board_id: Optional[str] = None,
|
||||||
|
is_intermediate: Optional[bool] = False,
|
||||||
|
metadata: Optional[dict] = None,
|
||||||
|
workflow: Optional[str] = None,
|
||||||
|
) -> ImageDTO:
|
||||||
|
if image_origin not in ResourceOrigin:
|
||||||
|
raise InvalidOriginException
|
||||||
|
|
||||||
|
if image_category not in ImageCategory:
|
||||||
|
raise InvalidImageCategoryException
|
||||||
|
|
||||||
|
image_name = self.__invoker.services.names.create_image_name()
|
||||||
|
|
||||||
|
(width, height) = image.size
|
||||||
|
|
||||||
|
try:
|
||||||
|
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
||||||
|
self.__invoker.services.image_records.save(
|
||||||
|
# Non-nullable fields
|
||||||
|
image_name=image_name,
|
||||||
|
image_origin=image_origin,
|
||||||
|
image_category=image_category,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
# Meta fields
|
||||||
|
is_intermediate=is_intermediate,
|
||||||
|
# Nullable fields
|
||||||
|
node_id=node_id,
|
||||||
|
metadata=metadata,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
if board_id is not None:
|
||||||
|
self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
|
||||||
|
self.__invoker.services.image_files.save(
|
||||||
|
image_name=image_name, image=image, metadata=metadata, workflow=workflow
|
||||||
|
)
|
||||||
|
image_dto = self.get_dto(image_name)
|
||||||
|
|
||||||
|
self._on_changed(image_dto)
|
||||||
|
return image_dto
|
||||||
|
except ImageRecordSaveException:
|
||||||
|
self.__invoker.services.logger.error("Failed to save image record")
|
||||||
|
raise
|
||||||
|
except ImageFileSaveException:
|
||||||
|
self.__invoker.services.logger.error("Failed to save image file")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self.__invoker.services.logger.error(f"Problem saving image record and file: {str(e)}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
changes: ImageRecordChanges,
|
||||||
|
) -> ImageDTO:
|
||||||
|
try:
|
||||||
|
self.__invoker.services.image_records.update(image_name, changes)
|
||||||
|
image_dto = self.get_dto(image_name)
|
||||||
|
self._on_changed(image_dto)
|
||||||
|
return image_dto
|
||||||
|
except ImageRecordSaveException:
|
||||||
|
self.__invoker.services.logger.error("Failed to update image record")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self.__invoker.services.logger.error("Problem updating image record")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def get_pil_image(self, image_name: str) -> PILImageType:
|
||||||
|
try:
|
||||||
|
return self.__invoker.services.image_files.get(image_name)
|
||||||
|
except ImageFileNotFoundException:
|
||||||
|
self.__invoker.services.logger.error("Failed to get image file")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self.__invoker.services.logger.error("Problem getting image file")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def get_record(self, image_name: str) -> ImageRecord:
|
||||||
|
try:
|
||||||
|
return self.__invoker.services.image_records.get(image_name)
|
||||||
|
except ImageRecordNotFoundException:
|
||||||
|
self.__invoker.services.logger.error("Image record not found")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self.__invoker.services.logger.error("Problem getting image record")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def get_dto(self, image_name: str) -> ImageDTO:
|
||||||
|
try:
|
||||||
|
image_record = self.__invoker.services.image_records.get(image_name)
|
||||||
|
|
||||||
|
image_dto = image_record_to_dto(
|
||||||
|
image_record,
|
||||||
|
self.__invoker.services.urls.get_image_url(image_name),
|
||||||
|
self.__invoker.services.urls.get_image_url(image_name, True),
|
||||||
|
self.__invoker.services.board_image_records.get_board_for_image(image_name),
|
||||||
|
)
|
||||||
|
|
||||||
|
return image_dto
|
||||||
|
except ImageRecordNotFoundException:
|
||||||
|
self.__invoker.services.logger.error("Image record not found")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self.__invoker.services.logger.error("Problem getting image DTO")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def get_metadata(self, image_name: str) -> ImageMetadata:
|
||||||
|
try:
|
||||||
|
image_record = self.__invoker.services.image_records.get(image_name)
|
||||||
|
metadata = self.__invoker.services.image_records.get_metadata(image_name)
|
||||||
|
|
||||||
|
if not image_record.session_id:
|
||||||
|
return ImageMetadata(metadata=metadata)
|
||||||
|
|
||||||
|
session_raw = self.__invoker.services.graph_execution_manager.get_raw(image_record.session_id)
|
||||||
|
graph = None
|
||||||
|
|
||||||
|
if session_raw:
|
||||||
|
try:
|
||||||
|
graph = get_metadata_graph_from_raw_session(session_raw)
|
||||||
|
except Exception as e:
|
||||||
|
self.__invoker.services.logger.warn(f"Failed to parse session graph: {e}")
|
||||||
|
graph = None
|
||||||
|
|
||||||
|
return ImageMetadata(graph=graph, metadata=metadata)
|
||||||
|
except ImageRecordNotFoundException:
|
||||||
|
self.__invoker.services.logger.error("Image record not found")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self.__invoker.services.logger.error("Problem getting image DTO")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
|
try:
|
||||||
|
return str(self.__invoker.services.image_files.get_path(image_name, thumbnail))
|
||||||
|
except Exception as e:
|
||||||
|
self.__invoker.services.logger.error("Problem getting image path")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def validate_path(self, path: str) -> bool:
|
||||||
|
try:
|
||||||
|
return self.__invoker.services.image_files.validate_path(path)
|
||||||
|
except Exception as e:
|
||||||
|
self.__invoker.services.logger.error("Problem validating image path")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
|
try:
|
||||||
|
return self.__invoker.services.urls.get_image_url(image_name, thumbnail)
|
||||||
|
except Exception as e:
|
||||||
|
self.__invoker.services.logger.error("Problem getting image path")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def get_many(
|
||||||
|
self,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 10,
|
||||||
|
image_origin: Optional[ResourceOrigin] = None,
|
||||||
|
categories: Optional[list[ImageCategory]] = None,
|
||||||
|
is_intermediate: Optional[bool] = None,
|
||||||
|
board_id: Optional[str] = None,
|
||||||
|
) -> OffsetPaginatedResults[ImageDTO]:
|
||||||
|
try:
|
||||||
|
results = self.__invoker.services.image_records.get_many(
|
||||||
|
offset,
|
||||||
|
limit,
|
||||||
|
image_origin,
|
||||||
|
categories,
|
||||||
|
is_intermediate,
|
||||||
|
board_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
image_dtos = list(
|
||||||
|
map(
|
||||||
|
lambda r: image_record_to_dto(
|
||||||
|
r,
|
||||||
|
self.__invoker.services.urls.get_image_url(r.image_name),
|
||||||
|
self.__invoker.services.urls.get_image_url(r.image_name, True),
|
||||||
|
self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
|
||||||
|
),
|
||||||
|
results.items,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return OffsetPaginatedResults[ImageDTO](
|
||||||
|
items=image_dtos,
|
||||||
|
offset=results.offset,
|
||||||
|
limit=results.limit,
|
||||||
|
total=results.total,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
self.__invoker.services.logger.error("Problem getting paginated image DTOs")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def delete(self, image_name: str):
|
||||||
|
try:
|
||||||
|
self.__invoker.services.image_files.delete(image_name)
|
||||||
|
self.__invoker.services.image_records.delete(image_name)
|
||||||
|
self._on_deleted(image_name)
|
||||||
|
except ImageRecordDeleteException:
|
||||||
|
self.__invoker.services.logger.error("Failed to delete image record")
|
||||||
|
raise
|
||||||
|
except ImageFileDeleteException:
|
||||||
|
self.__invoker.services.logger.error("Failed to delete image file")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self.__invoker.services.logger.error("Problem deleting image record and file")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def delete_images_on_board(self, board_id: str):
|
||||||
|
try:
|
||||||
|
image_names = self.__invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
|
||||||
|
for image_name in image_names:
|
||||||
|
self.__invoker.services.image_files.delete(image_name)
|
||||||
|
self.__invoker.services.image_records.delete_many(image_names)
|
||||||
|
for image_name in image_names:
|
||||||
|
self._on_deleted(image_name)
|
||||||
|
except ImageRecordDeleteException:
|
||||||
|
self.__invoker.services.logger.error("Failed to delete image records")
|
||||||
|
raise
|
||||||
|
except ImageFileDeleteException:
|
||||||
|
self.__invoker.services.logger.error("Failed to delete image files")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self.__invoker.services.logger.error("Problem deleting image records and files")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def delete_intermediates(self) -> int:
|
||||||
|
try:
|
||||||
|
image_names = self.__invoker.services.image_records.delete_intermediates()
|
||||||
|
count = len(image_names)
|
||||||
|
for image_name in image_names:
|
||||||
|
self.__invoker.services.image_files.delete(image_name)
|
||||||
|
self._on_deleted(image_name)
|
||||||
|
return count
|
||||||
|
except ImageRecordDeleteException:
|
||||||
|
self.__invoker.services.logger.error("Failed to delete image records")
|
||||||
|
raise
|
||||||
|
except ImageFileDeleteException:
|
||||||
|
self.__invoker.services.logger.error("Failed to delete image files")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self.__invoker.services.logger.error("Problem deleting image records and files")
|
||||||
|
raise e
|
||||||
@@ -58,7 +58,12 @@ class MemoryInvocationCache(InvocationCacheBase):
|
|||||||
# If the cache is full, we need to remove the least used
|
# If the cache is full, we need to remove the least used
|
||||||
number_to_delete = len(self._cache) + 1 - self._max_cache_size
|
number_to_delete = len(self._cache) + 1 - self._max_cache_size
|
||||||
self._delete_oldest_access(number_to_delete)
|
self._delete_oldest_access(number_to_delete)
|
||||||
self._cache[key] = CachedItem(invocation_output, invocation_output.json())
|
self._cache[key] = CachedItem(
|
||||||
|
invocation_output,
|
||||||
|
invocation_output.model_dump_json(
|
||||||
|
warnings=False, exclude_defaults=True, exclude_unset=True, include={"type"}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def _delete_oldest_access(self, number_to_delete: int) -> None:
|
def _delete_oldest_access(self, number_to_delete: int) -> None:
|
||||||
number_to_delete = min(number_to_delete, len(self._cache))
|
number_to_delete = min(number_to_delete, len(self._cache))
|
||||||
@@ -85,7 +90,7 @@ class MemoryInvocationCache(InvocationCacheBase):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_key(invocation: BaseInvocation) -> int:
|
def create_key(invocation: BaseInvocation) -> int:
|
||||||
return hash(invocation.json(exclude={"id"}))
|
return hash(invocation.model_dump_json(exclude={"id"}, warnings=False))
|
||||||
|
|
||||||
def disable(self) -> None:
|
def disable(self) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
|||||||
@@ -0,0 +1,5 @@
|
|||||||
|
from abc import ABC
|
||||||
|
|
||||||
|
|
||||||
|
class InvocationProcessorABC(ABC):
|
||||||
|
pass
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class ProgressImage(BaseModel):
|
||||||
|
"""The progress image sent intermittently during processing"""
|
||||||
|
|
||||||
|
width: int = Field(description="The effective width of the image in pixels")
|
||||||
|
height: int = Field(description="The effective height of the image in pixels")
|
||||||
|
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||||
|
|
||||||
|
|
||||||
|
class CanceledException(Exception):
|
||||||
|
"""Execution canceled by user."""
|
||||||
|
|
||||||
|
pass
|
||||||
@@ -4,12 +4,12 @@ from threading import BoundedSemaphore, Event, Thread
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
|
from invokeai.app.invocations.baseinvocation import AppInvocationContext
|
||||||
|
from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem
|
||||||
|
|
||||||
from ..invocations.baseinvocation import InvocationContext
|
from ..invoker import Invoker
|
||||||
from ..models.exceptions import CanceledException
|
from .invocation_processor_base import InvocationProcessorABC
|
||||||
from .invocation_queue import InvocationQueueItem
|
from .invocation_processor_common import CanceledException
|
||||||
from .invocation_stats import InvocationStatsServiceBase
|
|
||||||
from .invoker import InvocationProcessorABC, Invoker
|
|
||||||
|
|
||||||
|
|
||||||
class DefaultInvocationProcessor(InvocationProcessorABC):
|
class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||||
@@ -37,7 +37,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
def __process(self, stop_event: Event):
|
def __process(self, stop_event: Event):
|
||||||
try:
|
try:
|
||||||
self.__threadLimit.acquire()
|
self.__threadLimit.acquire()
|
||||||
statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics
|
|
||||||
queue_item: Optional[InvocationQueueItem] = None
|
queue_item: Optional[InvocationQueueItem] = None
|
||||||
|
|
||||||
while not stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
@@ -90,26 +89,28 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
queue_item_id=queue_item.session_queue_item_id,
|
queue_item_id=queue_item.session_queue_item_id,
|
||||||
queue_id=queue_item.session_queue_id,
|
queue_id=queue_item.session_queue_id,
|
||||||
graph_execution_state_id=graph_execution_state.id,
|
graph_execution_state_id=graph_execution_state.id,
|
||||||
node=invocation.dict(),
|
node=invocation.model_dump(),
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Invoke
|
# Invoke
|
||||||
try:
|
try:
|
||||||
graph_id = graph_execution_state.id
|
graph_id = graph_execution_state.id
|
||||||
model_manager = self.__invoker.services.model_manager
|
source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
|
||||||
with statistics.collect_stats(invocation, graph_id, model_manager):
|
|
||||||
|
with self.__invoker.services.performance_statistics.collect_stats(invocation, graph_id):
|
||||||
# use the internal invoke_internal(), which wraps the node's invoke() method,
|
# use the internal invoke_internal(), which wraps the node's invoke() method,
|
||||||
# which handles a few things:
|
# which handles a few things:
|
||||||
# - nodes that require a value, but get it only from a connection
|
# - nodes that require a value, but get it only from a connection
|
||||||
# - referencing the invocation cache instead of executing the node
|
# - referencing the invocation cache instead of executing the node
|
||||||
outputs = invocation.invoke_internal(
|
outputs = invocation.invoke_internal(
|
||||||
InvocationContext(
|
AppInvocationContext(
|
||||||
services=self.__invoker.services,
|
services=self.__invoker.services,
|
||||||
graph_execution_state_id=graph_execution_state.id,
|
graph_execution_state_id=graph_execution_state.id,
|
||||||
queue_item_id=queue_item.session_queue_item_id,
|
queue_item_id=queue_item.session_queue_item_id,
|
||||||
queue_id=queue_item.session_queue_id,
|
queue_id=queue_item.session_queue_id,
|
||||||
queue_batch_id=queue_item.session_queue_batch_id,
|
queue_batch_id=queue_item.session_queue_batch_id,
|
||||||
|
source_node_id=source_node_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -129,17 +130,17 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
queue_item_id=queue_item.session_queue_item_id,
|
queue_item_id=queue_item.session_queue_item_id,
|
||||||
queue_id=queue_item.session_queue_id,
|
queue_id=queue_item.session_queue_id,
|
||||||
graph_execution_state_id=graph_execution_state.id,
|
graph_execution_state_id=graph_execution_state.id,
|
||||||
node=invocation.dict(),
|
node=invocation.model_dump(),
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
result=outputs.dict(),
|
result=outputs.model_dump(),
|
||||||
)
|
)
|
||||||
statistics.log_stats()
|
self.__invoker.services.performance_statistics.log_stats()
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
except CanceledException:
|
except CanceledException:
|
||||||
statistics.reset_stats(graph_execution_state.id)
|
self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -159,12 +160,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
queue_item_id=queue_item.session_queue_item_id,
|
queue_item_id=queue_item.session_queue_item_id,
|
||||||
queue_id=queue_item.session_queue_id,
|
queue_id=queue_item.session_queue_id,
|
||||||
graph_execution_state_id=graph_execution_state.id,
|
graph_execution_state_id=graph_execution_state.id,
|
||||||
node=invocation.dict(),
|
node=invocation.model_dump(),
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
error_type=e.__class__.__name__,
|
error_type=e.__class__.__name__,
|
||||||
error=error,
|
error=error,
|
||||||
)
|
)
|
||||||
statistics.reset_stats(graph_execution_state.id)
|
self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Check queue to see if this is canceled, and skip if so
|
# Check queue to see if this is canceled, and skip if so
|
||||||
@@ -189,7 +190,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
queue_item_id=queue_item.session_queue_item_id,
|
queue_item_id=queue_item.session_queue_item_id,
|
||||||
queue_id=queue_item.session_queue_id,
|
queue_id=queue_item.session_queue_id,
|
||||||
graph_execution_state_id=graph_execution_state.id,
|
graph_execution_state_id=graph_execution_state.id,
|
||||||
node=invocation.dict(),
|
node=invocation.model_dump(),
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
error_type=e.__class__.__name__,
|
error_type=e.__class__.__name__,
|
||||||
error=traceback.format_exc(),
|
error=traceback.format_exc(),
|
||||||
0
invokeai/app/services/invocation_queue/__init__.py
Normal file
0
invokeai/app/services/invocation_queue/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from .invocation_queue_common import InvocationQueueItem
|
||||||
|
|
||||||
|
|
||||||
|
class InvocationQueueABC(ABC):
|
||||||
|
"""Abstract base class for all invocation queues"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(self) -> InvocationQueueItem:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def put(self, item: Optional[InvocationQueueItem]) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def cancel(self, graph_execution_state_id: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def is_canceled(self, graph_execution_state_id: str) -> bool:
|
||||||
|
pass
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class InvocationQueueItem(BaseModel):
|
||||||
|
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
|
||||||
|
invocation_id: str = Field(description="The ID of the node being invoked")
|
||||||
|
session_queue_id: str = Field(description="The ID of the session queue from which this invocation queue item came")
|
||||||
|
session_queue_item_id: int = Field(
|
||||||
|
description="The ID of session queue item from which this invocation queue item came"
|
||||||
|
)
|
||||||
|
session_queue_batch_id: str = Field(
|
||||||
|
description="The ID of the session batch from which this invocation queue item came"
|
||||||
|
)
|
||||||
|
invoke_all: bool = Field(default=False)
|
||||||
|
timestamp: float = Field(default_factory=time.time)
|
||||||
@@ -1,45 +1,11 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from .invocation_queue_base import InvocationQueueABC
|
||||||
|
from .invocation_queue_common import InvocationQueueItem
|
||||||
|
|
||||||
class InvocationQueueItem(BaseModel):
|
|
||||||
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
|
|
||||||
invocation_id: str = Field(description="The ID of the node being invoked")
|
|
||||||
session_queue_id: str = Field(description="The ID of the session queue from which this invocation queue item came")
|
|
||||||
session_queue_item_id: int = Field(
|
|
||||||
description="The ID of session queue item from which this invocation queue item came"
|
|
||||||
)
|
|
||||||
session_queue_batch_id: str = Field(
|
|
||||||
description="The ID of the session batch from which this invocation queue item came"
|
|
||||||
)
|
|
||||||
invoke_all: bool = Field(default=False)
|
|
||||||
timestamp: float = Field(default_factory=time.time)
|
|
||||||
|
|
||||||
|
|
||||||
class InvocationQueueABC(ABC):
|
|
||||||
"""Abstract base class for all invocation queues"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get(self) -> InvocationQueueItem:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def put(self, item: Optional[InvocationQueueItem]) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def cancel(self, graph_execution_state_id: str) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def is_canceled(self, graph_execution_state_id: str) -> bool:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryInvocationQueue(InvocationQueueABC):
|
class MemoryInvocationQueue(InvocationQueueABC):
|
||||||
@@ -6,21 +6,27 @@ from typing import TYPE_CHECKING
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
|
|
||||||
from invokeai.app.services.board_images import BoardImagesServiceABC
|
from .board_image_records.board_image_records_base import BoardImageRecordStorageBase
|
||||||
from invokeai.app.services.boards import BoardServiceABC
|
from .board_images.board_images_base import BoardImagesServiceABC
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from .board_records.board_records_base import BoardRecordStorageBase
|
||||||
from invokeai.app.services.events import EventServiceBase
|
from .boards.boards_base import BoardServiceABC
|
||||||
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
|
from .config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.images import ImageServiceABC
|
from .events.events_base import EventServiceBase
|
||||||
from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
|
from .image_files.image_files_base import ImageFileStorageBase
|
||||||
from invokeai.app.services.invocation_queue import InvocationQueueABC
|
from .image_records.image_records_base import ImageRecordStorageBase
|
||||||
from invokeai.app.services.invocation_stats import InvocationStatsServiceBase
|
from .images.images_base import ImageServiceABC
|
||||||
from invokeai.app.services.invoker import InvocationProcessorABC
|
from .invocation_cache.invocation_cache_base import InvocationCacheBase
|
||||||
from invokeai.app.services.item_storage import ItemStorageABC
|
from .invocation_processor.invocation_processor_base import InvocationProcessorABC
|
||||||
from invokeai.app.services.latent_storage import LatentsStorageBase
|
from .invocation_queue.invocation_queue_base import InvocationQueueABC
|
||||||
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
|
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
|
||||||
from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase
|
from .item_storage.item_storage_base import ItemStorageABC
|
||||||
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
|
from .latents_storage.latents_storage_base import LatentsStorageBase
|
||||||
|
from .model_manager.model_manager_base import ModelManagerServiceBase
|
||||||
|
from .names.names_base import NameServiceBase
|
||||||
|
from .session_processor.session_processor_base import SessionProcessorBase
|
||||||
|
from .session_queue.session_queue_base import SessionQueueBase
|
||||||
|
from .shared.graph import GraphExecutionState, LibraryGraph
|
||||||
|
from .urls.urls_base import UrlServiceBase
|
||||||
|
|
||||||
|
|
||||||
class InvocationServices:
|
class InvocationServices:
|
||||||
@@ -28,12 +34,16 @@ class InvocationServices:
|
|||||||
|
|
||||||
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
||||||
board_images: "BoardImagesServiceABC"
|
board_images: "BoardImagesServiceABC"
|
||||||
|
board_image_record_storage: "BoardImageRecordStorageBase"
|
||||||
boards: "BoardServiceABC"
|
boards: "BoardServiceABC"
|
||||||
|
board_records: "BoardRecordStorageBase"
|
||||||
configuration: "InvokeAIAppConfig"
|
configuration: "InvokeAIAppConfig"
|
||||||
events: "EventServiceBase"
|
events: "EventServiceBase"
|
||||||
graph_execution_manager: "ItemStorageABC[GraphExecutionState]"
|
graph_execution_manager: "ItemStorageABC[GraphExecutionState]"
|
||||||
graph_library: "ItemStorageABC[LibraryGraph]"
|
graph_library: "ItemStorageABC[LibraryGraph]"
|
||||||
images: "ImageServiceABC"
|
images: "ImageServiceABC"
|
||||||
|
image_records: "ImageRecordStorageBase"
|
||||||
|
image_files: "ImageFileStorageBase"
|
||||||
latents: "LatentsStorageBase"
|
latents: "LatentsStorageBase"
|
||||||
logger: "Logger"
|
logger: "Logger"
|
||||||
model_manager: "ModelManagerServiceBase"
|
model_manager: "ModelManagerServiceBase"
|
||||||
@@ -43,16 +53,22 @@ class InvocationServices:
|
|||||||
session_queue: "SessionQueueBase"
|
session_queue: "SessionQueueBase"
|
||||||
session_processor: "SessionProcessorBase"
|
session_processor: "SessionProcessorBase"
|
||||||
invocation_cache: "InvocationCacheBase"
|
invocation_cache: "InvocationCacheBase"
|
||||||
|
names: "NameServiceBase"
|
||||||
|
urls: "UrlServiceBase"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
board_images: "BoardImagesServiceABC",
|
board_images: "BoardImagesServiceABC",
|
||||||
|
board_image_records: "BoardImageRecordStorageBase",
|
||||||
boards: "BoardServiceABC",
|
boards: "BoardServiceABC",
|
||||||
|
board_records: "BoardRecordStorageBase",
|
||||||
configuration: "InvokeAIAppConfig",
|
configuration: "InvokeAIAppConfig",
|
||||||
events: "EventServiceBase",
|
events: "EventServiceBase",
|
||||||
graph_execution_manager: "ItemStorageABC[GraphExecutionState]",
|
graph_execution_manager: "ItemStorageABC[GraphExecutionState]",
|
||||||
graph_library: "ItemStorageABC[LibraryGraph]",
|
graph_library: "ItemStorageABC[LibraryGraph]",
|
||||||
images: "ImageServiceABC",
|
images: "ImageServiceABC",
|
||||||
|
image_files: "ImageFileStorageBase",
|
||||||
|
image_records: "ImageRecordStorageBase",
|
||||||
latents: "LatentsStorageBase",
|
latents: "LatentsStorageBase",
|
||||||
logger: "Logger",
|
logger: "Logger",
|
||||||
model_manager: "ModelManagerServiceBase",
|
model_manager: "ModelManagerServiceBase",
|
||||||
@@ -62,14 +78,20 @@ class InvocationServices:
|
|||||||
session_queue: "SessionQueueBase",
|
session_queue: "SessionQueueBase",
|
||||||
session_processor: "SessionProcessorBase",
|
session_processor: "SessionProcessorBase",
|
||||||
invocation_cache: "InvocationCacheBase",
|
invocation_cache: "InvocationCacheBase",
|
||||||
|
names: "NameServiceBase",
|
||||||
|
urls: "UrlServiceBase",
|
||||||
):
|
):
|
||||||
self.board_images = board_images
|
self.board_images = board_images
|
||||||
|
self.board_image_records = board_image_records
|
||||||
self.boards = boards
|
self.boards = boards
|
||||||
|
self.board_records = board_records
|
||||||
self.configuration = configuration
|
self.configuration = configuration
|
||||||
self.events = events
|
self.events = events
|
||||||
self.graph_execution_manager = graph_execution_manager
|
self.graph_execution_manager = graph_execution_manager
|
||||||
self.graph_library = graph_library
|
self.graph_library = graph_library
|
||||||
self.images = images
|
self.images = images
|
||||||
|
self.image_files = image_files
|
||||||
|
self.image_records = image_records
|
||||||
self.latents = latents
|
self.latents = latents
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.model_manager = model_manager
|
self.model_manager = model_manager
|
||||||
@@ -79,3 +101,5 @@ class InvocationServices:
|
|||||||
self.session_queue = session_queue
|
self.session_queue = session_queue
|
||||||
self.session_processor = session_processor
|
self.session_processor = session_processor
|
||||||
self.invocation_cache = invocation_cache
|
self.invocation_cache = invocation_cache
|
||||||
|
self.names = names
|
||||||
|
self.urls = urls
|
||||||
|
|||||||
0
invokeai/app/services/invocation_stats/__init__.py
Normal file
0
invokeai/app/services/invocation_stats/__init__.py
Normal file
121
invokeai/app/services/invocation_stats/invocation_stats_base.py
Normal file
121
invokeai/app/services/invocation_stats/invocation_stats_base.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
# Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com>
|
||||||
|
"""Utility to collect execution time and GPU usage stats on invocations in flight
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
statistics = InvocationStatsService(graph_execution_manager)
|
||||||
|
with statistics.collect_stats(invocation, graph_execution_state.id):
|
||||||
|
... execute graphs...
|
||||||
|
statistics.log_stats()
|
||||||
|
|
||||||
|
Typical output:
|
||||||
|
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3
|
||||||
|
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node Calls Seconds VRAM Used
|
||||||
|
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader 1 0.005s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip 1 0.004s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel 2 0.512s 0.26G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int 1 0.001s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size 1 0.001s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate 1 0.001s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator 1 0.002s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise 1 0.002s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l 1 3.541s 1.93G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i 1 0.679s 0.58G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME: 4.749s
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G
|
||||||
|
|
||||||
|
The abstract base class for this class is InvocationStatsServiceBase. An implementing class which
|
||||||
|
writes to the system log is stored in InvocationServices.performance_statistics.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from contextlib import AbstractContextManager
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||||
|
from invokeai.backend.model_management.model_cache import CacheStats
|
||||||
|
|
||||||
|
from .invocation_stats_common import NodeLog
|
||||||
|
|
||||||
|
|
||||||
|
class InvocationStatsServiceBase(ABC):
|
||||||
|
"Abstract base class for recording node memory/time performance statistics"
|
||||||
|
|
||||||
|
# {graph_id => NodeLog}
|
||||||
|
_stats: Dict[str, NodeLog]
|
||||||
|
_cache_stats: Dict[str, CacheStats]
|
||||||
|
ram_used: float
|
||||||
|
ram_changed: float
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(self):
|
||||||
|
"""
|
||||||
|
Initialize the InvocationStatsService and reset counters to zero
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def collect_stats(
|
||||||
|
self,
|
||||||
|
invocation: BaseInvocation,
|
||||||
|
graph_execution_state_id: str,
|
||||||
|
) -> AbstractContextManager:
|
||||||
|
"""
|
||||||
|
Return a context object that will capture the statistics on the execution
|
||||||
|
of invocaation. Use with: to place around the part of the code that executes the invocation.
|
||||||
|
:param invocation: BaseInvocation object from the current graph.
|
||||||
|
:param graph_execution_state_id: The id of the current session.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def reset_stats(self, graph_execution_state_id: str):
|
||||||
|
"""
|
||||||
|
Reset all statistics for the indicated graph
|
||||||
|
:param graph_execution_state_id
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def reset_all_stats(self):
|
||||||
|
"""Zero all statistics"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update_invocation_stats(
|
||||||
|
self,
|
||||||
|
graph_id: str,
|
||||||
|
invocation_type: str,
|
||||||
|
time_used: float,
|
||||||
|
vram_used: float,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Add timing information on execution of a node. Usually
|
||||||
|
used internally.
|
||||||
|
:param graph_id: ID of the graph that is currently executing
|
||||||
|
:param invocation_type: String literal type of the node
|
||||||
|
:param time_used: Time used by node's exection (sec)
|
||||||
|
:param vram_used: Maximum VRAM used during exection (GB)
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def log_stats(self):
|
||||||
|
"""
|
||||||
|
Write out the accumulated statistics to the log or somewhere else.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update_mem_stats(
|
||||||
|
self,
|
||||||
|
ram_used: float,
|
||||||
|
ram_changed: float,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update the collector with RAM memory usage info.
|
||||||
|
|
||||||
|
:param ram_used: How much RAM is currently in use.
|
||||||
|
:param ram_changed: How much RAM changed since last generation.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
# size of GIG in bytes
|
||||||
|
GIG = 1073741824
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NodeStats:
|
||||||
|
"""Class for tracking execution stats of an invocation node"""
|
||||||
|
|
||||||
|
calls: int = 0
|
||||||
|
time_used: float = 0.0 # seconds
|
||||||
|
max_vram: float = 0.0 # GB
|
||||||
|
cache_hits: int = 0
|
||||||
|
cache_misses: int = 0
|
||||||
|
cache_high_watermark: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NodeLog:
|
||||||
|
"""Class for tracking node usage"""
|
||||||
|
|
||||||
|
# {node_type => NodeStats}
|
||||||
|
nodes: Dict[str, NodeStats] = field(default_factory=dict)
|
||||||
@@ -1,171 +1,35 @@
|
|||||||
# Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com>
|
|
||||||
"""Utility to collect execution time and GPU usage stats on invocations in flight
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
|
|
||||||
statistics = InvocationStatsService(graph_execution_manager)
|
|
||||||
with statistics.collect_stats(invocation, graph_execution_state.id):
|
|
||||||
... execute graphs...
|
|
||||||
statistics.log_stats()
|
|
||||||
|
|
||||||
Typical output:
|
|
||||||
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3
|
|
||||||
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node Calls Seconds VRAM Used
|
|
||||||
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader 1 0.005s 0.01G
|
|
||||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip 1 0.004s 0.01G
|
|
||||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel 2 0.512s 0.26G
|
|
||||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int 1 0.001s 0.01G
|
|
||||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size 1 0.001s 0.01G
|
|
||||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate 1 0.001s 0.01G
|
|
||||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator 1 0.002s 0.01G
|
|
||||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise 1 0.002s 0.01G
|
|
||||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l 1 3.541s 1.93G
|
|
||||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i 1 0.679s 0.58G
|
|
||||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME: 4.749s
|
|
||||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G
|
|
||||||
|
|
||||||
The abstract base class for this class is InvocationStatsServiceBase. An implementing class which
|
|
||||||
writes to the system log is stored in InvocationServices.performance_statistics.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from contextlib import AbstractContextManager
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||||
|
from invokeai.app.services.invoker import Invoker
|
||||||
|
from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
|
||||||
from invokeai.backend.model_management.model_cache import CacheStats
|
from invokeai.backend.model_management.model_cache import CacheStats
|
||||||
|
|
||||||
from ..invocations.baseinvocation import BaseInvocation
|
from .invocation_stats_base import InvocationStatsServiceBase
|
||||||
from .graph import GraphExecutionState
|
from .invocation_stats_common import GIG, NodeLog, NodeStats
|
||||||
from .item_storage import ItemStorageABC
|
|
||||||
from .model_manager_service import ModelManagerService
|
|
||||||
|
|
||||||
# size of GIG in bytes
|
|
||||||
GIG = 1073741824
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class NodeStats:
|
|
||||||
"""Class for tracking execution stats of an invocation node"""
|
|
||||||
|
|
||||||
calls: int = 0
|
|
||||||
time_used: float = 0.0 # seconds
|
|
||||||
max_vram: float = 0.0 # GB
|
|
||||||
cache_hits: int = 0
|
|
||||||
cache_misses: int = 0
|
|
||||||
cache_high_watermark: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class NodeLog:
|
|
||||||
"""Class for tracking node usage"""
|
|
||||||
|
|
||||||
# {node_type => NodeStats}
|
|
||||||
nodes: Dict[str, NodeStats] = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
class InvocationStatsServiceBase(ABC):
|
|
||||||
"Abstract base class for recording node memory/time performance statistics"
|
|
||||||
|
|
||||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
|
||||||
# {graph_id => NodeLog}
|
|
||||||
_stats: Dict[str, NodeLog]
|
|
||||||
_cache_stats: Dict[str, CacheStats]
|
|
||||||
ram_used: float
|
|
||||||
ram_changed: float
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
|
|
||||||
"""
|
|
||||||
Initialize the InvocationStatsService and reset counters to zero
|
|
||||||
:param graph_execution_manager: Graph execution manager for this session
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def collect_stats(
|
|
||||||
self,
|
|
||||||
invocation: BaseInvocation,
|
|
||||||
graph_execution_state_id: str,
|
|
||||||
) -> AbstractContextManager:
|
|
||||||
"""
|
|
||||||
Return a context object that will capture the statistics on the execution
|
|
||||||
of invocaation. Use with: to place around the part of the code that executes the invocation.
|
|
||||||
:param invocation: BaseInvocation object from the current graph.
|
|
||||||
:param graph_execution_state: GraphExecutionState object from the current session.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def reset_stats(self, graph_execution_state_id: str):
|
|
||||||
"""
|
|
||||||
Reset all statistics for the indicated graph
|
|
||||||
:param graph_execution_state_id
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def reset_all_stats(self):
|
|
||||||
"""Zero all statistics"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update_invocation_stats(
|
|
||||||
self,
|
|
||||||
graph_id: str,
|
|
||||||
invocation_type: str,
|
|
||||||
time_used: float,
|
|
||||||
vram_used: float,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Add timing information on execution of a node. Usually
|
|
||||||
used internally.
|
|
||||||
:param graph_id: ID of the graph that is currently executing
|
|
||||||
:param invocation_type: String literal type of the node
|
|
||||||
:param time_used: Time used by node's exection (sec)
|
|
||||||
:param vram_used: Maximum VRAM used during exection (GB)
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def log_stats(self):
|
|
||||||
"""
|
|
||||||
Write out the accumulated statistics to the log or somewhere else.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update_mem_stats(
|
|
||||||
self,
|
|
||||||
ram_used: float,
|
|
||||||
ram_changed: float,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Update the collector with RAM memory usage info.
|
|
||||||
|
|
||||||
:param ram_used: How much RAM is currently in use.
|
|
||||||
:param ram_changed: How much RAM changed since last generation.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class InvocationStatsService(InvocationStatsServiceBase):
|
class InvocationStatsService(InvocationStatsServiceBase):
|
||||||
"""Accumulate performance information about a running graph. Collects time spent in each node,
|
"""Accumulate performance information about a running graph. Collects time spent in each node,
|
||||||
as well as the maximum and current VRAM utilisation for CUDA systems"""
|
as well as the maximum and current VRAM utilisation for CUDA systems"""
|
||||||
|
|
||||||
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
|
_invoker: Invoker
|
||||||
self.graph_execution_manager = graph_execution_manager
|
|
||||||
|
def __init__(self):
|
||||||
# {graph_id => NodeLog}
|
# {graph_id => NodeLog}
|
||||||
self._stats: Dict[str, NodeLog] = {}
|
self._stats: Dict[str, NodeLog] = {}
|
||||||
self._cache_stats: Dict[str, CacheStats] = {}
|
self._cache_stats: Dict[str, CacheStats] = {}
|
||||||
self.ram_used: float = 0.0
|
self.ram_used: float = 0.0
|
||||||
self.ram_changed: float = 0.0
|
self.ram_changed: float = 0.0
|
||||||
|
|
||||||
|
def start(self, invoker: Invoker) -> None:
|
||||||
|
self._invoker = invoker
|
||||||
|
|
||||||
class StatsContext:
|
class StatsContext:
|
||||||
"""Context manager for collecting statistics."""
|
"""Context manager for collecting statistics."""
|
||||||
|
|
||||||
@@ -174,13 +38,13 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
graph_id: str
|
graph_id: str
|
||||||
start_time: float
|
start_time: float
|
||||||
ram_used: int
|
ram_used: int
|
||||||
model_manager: ModelManagerService
|
model_manager: ModelManagerServiceBase
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
invocation: BaseInvocation,
|
invocation: BaseInvocation,
|
||||||
graph_id: str,
|
graph_id: str,
|
||||||
model_manager: ModelManagerService,
|
model_manager: ModelManagerServiceBase,
|
||||||
collector: "InvocationStatsServiceBase",
|
collector: "InvocationStatsServiceBase",
|
||||||
):
|
):
|
||||||
"""Initialize statistics for this run."""
|
"""Initialize statistics for this run."""
|
||||||
@@ -208,7 +72,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
)
|
)
|
||||||
self.collector.update_invocation_stats(
|
self.collector.update_invocation_stats(
|
||||||
graph_id=self.graph_id,
|
graph_id=self.graph_id,
|
||||||
invocation_type=self.invocation.type, # type: ignore - `type` is not on the `BaseInvocation` model, but *is* on all invocations
|
invocation_type=self.invocation.type, # type: ignore # `type` is not on the `BaseInvocation` model, but *is* on all invocations
|
||||||
time_used=time.time() - self.start_time,
|
time_used=time.time() - self.start_time,
|
||||||
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
|
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
|
||||||
)
|
)
|
||||||
@@ -217,12 +81,11 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
self,
|
self,
|
||||||
invocation: BaseInvocation,
|
invocation: BaseInvocation,
|
||||||
graph_execution_state_id: str,
|
graph_execution_state_id: str,
|
||||||
model_manager: ModelManagerService,
|
|
||||||
) -> StatsContext:
|
) -> StatsContext:
|
||||||
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
|
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
|
||||||
self._stats[graph_execution_state_id] = NodeLog()
|
self._stats[graph_execution_state_id] = NodeLog()
|
||||||
self._cache_stats[graph_execution_state_id] = CacheStats()
|
self._cache_stats[graph_execution_state_id] = CacheStats()
|
||||||
return self.StatsContext(invocation, graph_execution_state_id, model_manager, self)
|
return self.StatsContext(invocation, graph_execution_state_id, self._invoker.services.model_manager, self)
|
||||||
|
|
||||||
def reset_all_stats(self):
|
def reset_all_stats(self):
|
||||||
"""Zero all statistics"""
|
"""Zero all statistics"""
|
||||||
@@ -261,7 +124,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
errored = set()
|
errored = set()
|
||||||
for graph_id, node_log in self._stats.items():
|
for graph_id, node_log in self._stats.items():
|
||||||
try:
|
try:
|
||||||
current_graph_state = self.graph_execution_manager.get(graph_id)
|
current_graph_state = self._invoker.services.graph_execution_manager.get(graph_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
errored.add(graph_id)
|
errored.add(graph_id)
|
||||||
continue
|
continue
|
||||||
@@ -1,11 +1,10 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from abc import ABC
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from .graph import Graph, GraphExecutionState
|
from .invocation_queue.invocation_queue_common import InvocationQueueItem
|
||||||
from .invocation_queue import InvocationQueueItem
|
|
||||||
from .invocation_services import InvocationServices
|
from .invocation_services import InvocationServices
|
||||||
|
from .shared.graph import Graph, GraphExecutionState
|
||||||
|
|
||||||
|
|
||||||
class Invoker:
|
class Invoker:
|
||||||
@@ -84,7 +83,3 @@ class Invoker:
|
|||||||
self.__stop_service(getattr(self.services, service))
|
self.__stop_service(getattr(self.services, service))
|
||||||
|
|
||||||
self.services.queue.put(None)
|
self.services.queue.put(None)
|
||||||
|
|
||||||
|
|
||||||
class InvocationProcessorABC(ABC):
|
|
||||||
pass
|
|
||||||
|
|||||||
0
invokeai/app/services/item_storage/__init__.py
Normal file
0
invokeai/app/services/item_storage/__init__.py
Normal file
@@ -1,25 +1,16 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Callable, Generic, Optional, TypeVar
|
from typing import Callable, Generic, Optional, TypeVar
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel
|
||||||
from pydantic.generics import GenericModel
|
|
||||||
|
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
class PaginatedResults(GenericModel, Generic[T]):
|
|
||||||
"""Paginated results"""
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
items: list[T] = Field(description="Items")
|
|
||||||
page: int = Field(description="Current Page")
|
|
||||||
pages: int = Field(description="Total number of pages")
|
|
||||||
per_page: int = Field(description="Number of items per page")
|
|
||||||
total: int = Field(description="Total number of items in result")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
class ItemStorageABC(ABC, Generic[T]):
|
class ItemStorageABC(ABC, Generic[T]):
|
||||||
|
"""Provides storage for a single type of item. The type must be a Pydantic model."""
|
||||||
|
|
||||||
_on_changed_callbacks: list[Callable[[T], None]]
|
_on_changed_callbacks: list[Callable[[T], None]]
|
||||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||||
|
|
||||||
@@ -2,14 +2,15 @@ import sqlite3
|
|||||||
import threading
|
import threading
|
||||||
from typing import Generic, Optional, TypeVar, get_args
|
from typing import Generic, Optional, TypeVar, get_args
|
||||||
|
|
||||||
from pydantic import BaseModel, parse_raw_as
|
from pydantic import BaseModel, TypeAdapter
|
||||||
|
|
||||||
from .item_storage import ItemStorageABC, PaginatedResults
|
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||||
|
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||||
|
|
||||||
|
from .item_storage_base import ItemStorageABC
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
sqlite_memory = ":memory:"
|
|
||||||
|
|
||||||
|
|
||||||
class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||||
_table_name: str
|
_table_name: str
|
||||||
@@ -17,15 +18,17 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_id_field: str
|
_id_field: str
|
||||||
_lock: threading.RLock
|
_lock: threading.RLock
|
||||||
|
_adapter: Optional[TypeAdapter[T]]
|
||||||
|
|
||||||
def __init__(self, conn: sqlite3.Connection, table_name: str, lock: threading.RLock, id_field: str = "id"):
|
def __init__(self, db: SqliteDatabase, table_name: str, id_field: str = "id"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
self._lock = db.lock
|
||||||
|
self._conn = db.conn
|
||||||
self._table_name = table_name
|
self._table_name = table_name
|
||||||
self._id_field = id_field # TODO: validate that T has this field
|
self._id_field = id_field # TODO: validate that T has this field
|
||||||
self._lock = lock
|
|
||||||
self._conn = conn
|
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
|
self._adapter: Optional[TypeAdapter[T]] = None
|
||||||
|
|
||||||
self._create_table()
|
self._create_table()
|
||||||
|
|
||||||
@@ -44,15 +47,21 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
def _parse_item(self, item: str) -> T:
|
def _parse_item(self, item: str) -> T:
|
||||||
item_type = get_args(self.__orig_class__)[0]
|
if self._adapter is None:
|
||||||
return parse_raw_as(item_type, item)
|
"""
|
||||||
|
We don't get access to `__orig_class__` in `__init__()`, and we need this before start(), so
|
||||||
|
we can create it when it is first needed instead.
|
||||||
|
__orig_class__ is technically an implementation detail of the typing module, not a supported API
|
||||||
|
"""
|
||||||
|
self._adapter = TypeAdapter(get_args(self.__orig_class__)[0]) # type: ignore [attr-defined]
|
||||||
|
return self._adapter.validate_json(item)
|
||||||
|
|
||||||
def set(self, item: T):
|
def set(self, item: T):
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
|
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
|
||||||
(item.json(),),
|
(item.model_dump_json(warnings=False, exclude_none=True),),
|
||||||
)
|
)
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
finally:
|
finally:
|
||||||
@@ -1,119 +0,0 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from pathlib import Path
|
|
||||||
from queue import Queue
|
|
||||||
from typing import Callable, Dict, Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class LatentsStorageBase(ABC):
|
|
||||||
"""Responsible for storing and retrieving latents."""
|
|
||||||
|
|
||||||
_on_changed_callbacks: list[Callable[[torch.Tensor], None]]
|
|
||||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._on_changed_callbacks = list()
|
|
||||||
self._on_deleted_callbacks = list()
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get(self, name: str) -> torch.Tensor:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def save(self, name: str, data: torch.Tensor) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def delete(self, name: str) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None:
|
|
||||||
"""Register a callback for when an item is changed"""
|
|
||||||
self._on_changed_callbacks.append(on_changed)
|
|
||||||
|
|
||||||
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
|
||||||
"""Register a callback for when an item is deleted"""
|
|
||||||
self._on_deleted_callbacks.append(on_deleted)
|
|
||||||
|
|
||||||
def _on_changed(self, item: torch.Tensor) -> None:
|
|
||||||
for callback in self._on_changed_callbacks:
|
|
||||||
callback(item)
|
|
||||||
|
|
||||||
def _on_deleted(self, item_id: str) -> None:
|
|
||||||
for callback in self._on_deleted_callbacks:
|
|
||||||
callback(item_id)
|
|
||||||
|
|
||||||
|
|
||||||
class ForwardCacheLatentsStorage(LatentsStorageBase):
|
|
||||||
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
|
|
||||||
|
|
||||||
__cache: Dict[str, torch.Tensor]
|
|
||||||
__cache_ids: Queue
|
|
||||||
__max_cache_size: int
|
|
||||||
__underlying_storage: LatentsStorageBase
|
|
||||||
|
|
||||||
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
|
|
||||||
super().__init__()
|
|
||||||
self.__underlying_storage = underlying_storage
|
|
||||||
self.__cache = dict()
|
|
||||||
self.__cache_ids = Queue()
|
|
||||||
self.__max_cache_size = max_cache_size
|
|
||||||
|
|
||||||
def get(self, name: str) -> torch.Tensor:
|
|
||||||
cache_item = self.__get_cache(name)
|
|
||||||
if cache_item is not None:
|
|
||||||
return cache_item
|
|
||||||
|
|
||||||
latent = self.__underlying_storage.get(name)
|
|
||||||
self.__set_cache(name, latent)
|
|
||||||
return latent
|
|
||||||
|
|
||||||
def save(self, name: str, data: torch.Tensor) -> None:
|
|
||||||
self.__underlying_storage.save(name, data)
|
|
||||||
self.__set_cache(name, data)
|
|
||||||
self._on_changed(data)
|
|
||||||
|
|
||||||
def delete(self, name: str) -> None:
|
|
||||||
self.__underlying_storage.delete(name)
|
|
||||||
if name in self.__cache:
|
|
||||||
del self.__cache[name]
|
|
||||||
self._on_deleted(name)
|
|
||||||
|
|
||||||
def __get_cache(self, name: str) -> Optional[torch.Tensor]:
|
|
||||||
return None if name not in self.__cache else self.__cache[name]
|
|
||||||
|
|
||||||
def __set_cache(self, name: str, data: torch.Tensor):
|
|
||||||
if name not in self.__cache:
|
|
||||||
self.__cache[name] = data
|
|
||||||
self.__cache_ids.put(name)
|
|
||||||
if self.__cache_ids.qsize() > self.__max_cache_size:
|
|
||||||
self.__cache.pop(self.__cache_ids.get())
|
|
||||||
|
|
||||||
|
|
||||||
class DiskLatentsStorage(LatentsStorageBase):
|
|
||||||
"""Stores latents in a folder on disk without caching"""
|
|
||||||
|
|
||||||
__output_folder: Union[str, Path]
|
|
||||||
|
|
||||||
def __init__(self, output_folder: Union[str, Path]):
|
|
||||||
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
|
||||||
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
def get(self, name: str) -> torch.Tensor:
|
|
||||||
latent_path = self.get_path(name)
|
|
||||||
return torch.load(latent_path)
|
|
||||||
|
|
||||||
def save(self, name: str, data: torch.Tensor) -> None:
|
|
||||||
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
|
||||||
latent_path = self.get_path(name)
|
|
||||||
torch.save(data, latent_path)
|
|
||||||
|
|
||||||
def delete(self, name: str) -> None:
|
|
||||||
latent_path = self.get_path(name)
|
|
||||||
latent_path.unlink()
|
|
||||||
|
|
||||||
def get_path(self, name: str) -> Path:
|
|
||||||
return self.__output_folder / name
|
|
||||||
0
invokeai/app/services/latents_storage/__init__.py
Normal file
0
invokeai/app/services/latents_storage/__init__.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class LatentsStorageBase(ABC):
|
||||||
|
"""Responsible for storing and retrieving latents."""
|
||||||
|
|
||||||
|
_on_changed_callbacks: list[Callable[[torch.Tensor], None]]
|
||||||
|
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._on_changed_callbacks = list()
|
||||||
|
self._on_deleted_callbacks = list()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(self, name: str) -> torch.Tensor:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save(self, name: str, data: torch.Tensor) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete(self, name: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None:
|
||||||
|
"""Register a callback for when an item is changed"""
|
||||||
|
self._on_changed_callbacks.append(on_changed)
|
||||||
|
|
||||||
|
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
||||||
|
"""Register a callback for when an item is deleted"""
|
||||||
|
self._on_deleted_callbacks.append(on_deleted)
|
||||||
|
|
||||||
|
def _on_changed(self, item: torch.Tensor) -> None:
|
||||||
|
for callback in self._on_changed_callbacks:
|
||||||
|
callback(item)
|
||||||
|
|
||||||
|
def _on_deleted(self, item_id: str) -> None:
|
||||||
|
for callback in self._on_deleted_callbacks:
|
||||||
|
callback(item_id)
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .latents_storage_base import LatentsStorageBase
|
||||||
|
|
||||||
|
|
||||||
|
class DiskLatentsStorage(LatentsStorageBase):
|
||||||
|
"""Stores latents in a folder on disk without caching"""
|
||||||
|
|
||||||
|
__output_folder: Path
|
||||||
|
|
||||||
|
def __init__(self, output_folder: Union[str, Path]):
|
||||||
|
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||||
|
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def get(self, name: str) -> torch.Tensor:
|
||||||
|
latent_path = self.get_path(name)
|
||||||
|
return torch.load(latent_path)
|
||||||
|
|
||||||
|
def save(self, name: str, data: torch.Tensor) -> None:
|
||||||
|
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
||||||
|
latent_path = self.get_path(name)
|
||||||
|
torch.save(data, latent_path)
|
||||||
|
|
||||||
|
def delete(self, name: str) -> None:
|
||||||
|
latent_path = self.get_path(name)
|
||||||
|
latent_path.unlink()
|
||||||
|
|
||||||
|
def get_path(self, name: str) -> Path:
|
||||||
|
return self.__output_folder / name
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
from queue import Queue
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .latents_storage_base import LatentsStorageBase
|
||||||
|
|
||||||
|
|
||||||
|
class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||||
|
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
|
||||||
|
|
||||||
|
__cache: Dict[str, torch.Tensor]
|
||||||
|
__cache_ids: Queue
|
||||||
|
__max_cache_size: int
|
||||||
|
__underlying_storage: LatentsStorageBase
|
||||||
|
|
||||||
|
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
|
||||||
|
super().__init__()
|
||||||
|
self.__underlying_storage = underlying_storage
|
||||||
|
self.__cache = dict()
|
||||||
|
self.__cache_ids = Queue()
|
||||||
|
self.__max_cache_size = max_cache_size
|
||||||
|
|
||||||
|
def get(self, name: str) -> torch.Tensor:
|
||||||
|
cache_item = self.__get_cache(name)
|
||||||
|
if cache_item is not None:
|
||||||
|
return cache_item
|
||||||
|
|
||||||
|
latent = self.__underlying_storage.get(name)
|
||||||
|
self.__set_cache(name, latent)
|
||||||
|
return latent
|
||||||
|
|
||||||
|
def save(self, name: str, data: torch.Tensor) -> None:
|
||||||
|
self.__underlying_storage.save(name, data)
|
||||||
|
self.__set_cache(name, data)
|
||||||
|
self._on_changed(data)
|
||||||
|
|
||||||
|
def delete(self, name: str) -> None:
|
||||||
|
self.__underlying_storage.delete(name)
|
||||||
|
if name in self.__cache:
|
||||||
|
del self.__cache[name]
|
||||||
|
self._on_deleted(name)
|
||||||
|
|
||||||
|
def __get_cache(self, name: str) -> Optional[torch.Tensor]:
|
||||||
|
return None if name not in self.__cache else self.__cache[name]
|
||||||
|
|
||||||
|
def __set_cache(self, name: str, data: torch.Tensor):
|
||||||
|
if name not in self.__cache:
|
||||||
|
self.__cache[name] = data
|
||||||
|
self.__cache_ids.put(name)
|
||||||
|
if self.__cache_ids.qsize() > self.__max_cache_size:
|
||||||
|
self.__cache.pop(self.__cache_ids.get())
|
||||||
0
invokeai/app/services/model_manager/__init__.py
Normal file
0
invokeai/app/services/model_manager/__init__.py
Normal file
289
invokeai/app/services/model_manager/model_manager_base.py
Normal file
289
invokeai/app/services/model_manager/model_manager_base.py
Normal file
@@ -0,0 +1,289 @@
|
|||||||
|
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from logging import Logger
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||||
|
from invokeai.backend.model_management import (
|
||||||
|
AddModelResult,
|
||||||
|
BaseModelType,
|
||||||
|
MergeInterpolationMethod,
|
||||||
|
ModelInfo,
|
||||||
|
ModelType,
|
||||||
|
SchedulerPredictionType,
|
||||||
|
SubModelType,
|
||||||
|
)
|
||||||
|
from invokeai.backend.model_management.model_cache import CacheStats
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, InvocationContext
|
||||||
|
|
||||||
|
|
||||||
|
class ModelManagerServiceBase(ABC):
|
||||||
|
"""Responsible for managing models on disk and in memory"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: InvokeAIAppConfig,
|
||||||
|
logger: Logger,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize with the path to the models.yaml config file.
|
||||||
|
Optional parameters are the torch device type, precision, max_models,
|
||||||
|
and sequential_offload boolean. Note that the default device
|
||||||
|
type and precision are set up for a CUDA system running at half precision.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_model(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
queue_id: str,
|
||||||
|
queue_item_id: int,
|
||||||
|
queue_batch_id: str,
|
||||||
|
graph_execution_state_id: str,
|
||||||
|
submodel: Optional[SubModelType] = None,
|
||||||
|
node: Optional[BaseInvocation] = None,
|
||||||
|
) -> ModelInfo:
|
||||||
|
"""Retrieve the indicated model with name and type.
|
||||||
|
submodel can be used to get a part (such as the vae)
|
||||||
|
of a diffusers pipeline."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def logger(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def model_exists(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
) -> bool:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||||
|
"""
|
||||||
|
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||||
|
Uses the exact format as the omegaconf stanza.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict:
|
||||||
|
"""
|
||||||
|
Return a dict of models in the format:
|
||||||
|
{ model_type1:
|
||||||
|
{ model_name1: {'status': 'active'|'cached'|'not loaded',
|
||||||
|
'model_name' : name,
|
||||||
|
'model_type' : SDModelType,
|
||||||
|
'description': description,
|
||||||
|
'format': 'folder'|'safetensors'|'ckpt'
|
||||||
|
},
|
||||||
|
model_name2: { etc }
|
||||||
|
},
|
||||||
|
model_type2:
|
||||||
|
{ model_name_n: etc
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||||
|
"""
|
||||||
|
Return information about the model using the same format as list_models()
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||||
|
"""
|
||||||
|
Returns a list of all the model names known.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def add_model(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
model_attributes: dict,
|
||||||
|
clobber: bool = False,
|
||||||
|
) -> AddModelResult:
|
||||||
|
"""
|
||||||
|
Update the named model with a dictionary of attributes. Will fail with an
|
||||||
|
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||||
|
On a successful update, the config will be changed in memory. Will fail
|
||||||
|
with an assertion error if provided attributes are incorrect or
|
||||||
|
the model name is missing. Call commit() to write changes to disk.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update_model(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
model_attributes: dict,
|
||||||
|
) -> AddModelResult:
|
||||||
|
"""
|
||||||
|
Update the named model with a dictionary of attributes. Will fail with a
|
||||||
|
ModelNotFoundException if the name does not already exist.
|
||||||
|
|
||||||
|
On a successful update, the config will be changed in memory. Will fail
|
||||||
|
with an assertion error if provided attributes are incorrect or
|
||||||
|
the model name is missing. Call commit() to write changes to disk.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def del_model(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Delete the named model from configuration. If delete_files is true,
|
||||||
|
then the underlying weight file or diffusers directory will be deleted
|
||||||
|
as well. Call commit() to write to disk.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def rename_model(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
new_name: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Rename the indicated model.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def list_checkpoint_configs(self) -> List[Path]:
|
||||||
|
"""
|
||||||
|
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def convert_model(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: Literal[ModelType.Main, ModelType.Vae],
|
||||||
|
) -> AddModelResult:
|
||||||
|
"""
|
||||||
|
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||||
|
version and deleting the original checkpoint file if it is in the models
|
||||||
|
directory.
|
||||||
|
:param model_name: Name of the model to convert
|
||||||
|
:param base_model: Base model type
|
||||||
|
:param model_type: Type of model ['vae' or 'main']
|
||||||
|
|
||||||
|
This will raise a ValueError unless the model is not a checkpoint. It will
|
||||||
|
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||||
|
directory already in place.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def heuristic_import(
|
||||||
|
self,
|
||||||
|
items_to_import: set[str],
|
||||||
|
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||||
|
) -> dict[str, AddModelResult]:
|
||||||
|
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
||||||
|
successfully imported items.
|
||||||
|
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||||
|
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||||
|
|
||||||
|
The prediction type helper is necessary to distinguish between
|
||||||
|
models based on Stable Diffusion 2 Base (requiring
|
||||||
|
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
|
||||||
|
(requiring SchedulerPredictionType.VPrediction). It is
|
||||||
|
generally impossible to do this programmatically, so the
|
||||||
|
prediction_type_helper usually asks the user to choose.
|
||||||
|
|
||||||
|
The result is a set of successfully installed models. Each element
|
||||||
|
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||||
|
that model.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def merge_models(
|
||||||
|
self,
|
||||||
|
model_names: List[str] = Field(
|
||||||
|
default=None, min_length=2, max_length=3, description="List of model names to merge"
|
||||||
|
),
|
||||||
|
base_model: Union[BaseModelType, str] = Field(
|
||||||
|
default=None, description="Base model shared by all models to be merged"
|
||||||
|
),
|
||||||
|
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||||
|
alpha: Optional[float] = 0.5,
|
||||||
|
interp: Optional[MergeInterpolationMethod] = None,
|
||||||
|
force: Optional[bool] = False,
|
||||||
|
merge_dest_directory: Optional[Path] = None,
|
||||||
|
) -> AddModelResult:
|
||||||
|
"""
|
||||||
|
Merge two to three diffusrs pipeline models and save as a new model.
|
||||||
|
:param model_names: List of 2-3 models to merge
|
||||||
|
:param base_model: Base model to use for all models
|
||||||
|
:param merged_model_name: Name of destination merged model
|
||||||
|
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||||
|
:param interp: Interpolation method. None (default)
|
||||||
|
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def search_for_models(self, directory: Path) -> List[Path]:
|
||||||
|
"""
|
||||||
|
Return list of all models found in the designated directory.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def sync_to_config(self):
|
||||||
|
"""
|
||||||
|
Re-read models.yaml, rescan the models directory, and reimport models
|
||||||
|
in the autoimport directories. Call after making changes outside the
|
||||||
|
model manager API.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||||
|
"""
|
||||||
|
Reset model cache statistics for graph with graph_id.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def commit(self, conf_file: Optional[Path] = None) -> None:
|
||||||
|
"""
|
||||||
|
Write current configuration out to the indicated file.
|
||||||
|
If no conf_file is provided, then replaces the
|
||||||
|
original file/database used to initialize the object.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
@@ -2,16 +2,16 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import ModuleType
|
|
||||||
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from invokeai.app.models.exceptions import CanceledException
|
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||||
|
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
|
||||||
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.backend.model_management import (
|
from invokeai.backend.model_management import (
|
||||||
AddModelResult,
|
AddModelResult,
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
@@ -26,273 +26,12 @@ from invokeai.backend.model_management import (
|
|||||||
)
|
)
|
||||||
from invokeai.backend.model_management.model_cache import CacheStats
|
from invokeai.backend.model_management.model_cache import CacheStats
|
||||||
from invokeai.backend.model_management.model_search import FindModels
|
from invokeai.backend.model_management.model_search import FindModels
|
||||||
|
from invokeai.backend.util import choose_precision, choose_torch_device
|
||||||
|
|
||||||
from ...backend.util import choose_precision, choose_torch_device
|
from .model_manager_base import ModelManagerServiceBase
|
||||||
from .config import InvokeAIAppConfig
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
|
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||||
|
|
||||||
|
|
||||||
class ModelManagerServiceBase(ABC):
|
|
||||||
"""Responsible for managing models on disk and in memory"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config: InvokeAIAppConfig,
|
|
||||||
logger: ModuleType,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize with the path to the models.yaml config file.
|
|
||||||
Optional parameters are the torch device type, precision, max_models,
|
|
||||||
and sequential_offload boolean. Note that the default device
|
|
||||||
type and precision are set up for a CUDA system running at half precision.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_model(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
submodel: Optional[SubModelType] = None,
|
|
||||||
node: Optional[BaseInvocation] = None,
|
|
||||||
context: Optional[InvocationContext] = None,
|
|
||||||
) -> ModelInfo:
|
|
||||||
"""Retrieve the indicated model with name and type.
|
|
||||||
submodel can be used to get a part (such as the vae)
|
|
||||||
of a diffusers pipeline."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def logger(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def model_exists(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
) -> bool:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
|
||||||
"""
|
|
||||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
|
||||||
Uses the exact format as the omegaconf stanza.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict:
|
|
||||||
"""
|
|
||||||
Return a dict of models in the format:
|
|
||||||
{ model_type1:
|
|
||||||
{ model_name1: {'status': 'active'|'cached'|'not loaded',
|
|
||||||
'model_name' : name,
|
|
||||||
'model_type' : SDModelType,
|
|
||||||
'description': description,
|
|
||||||
'format': 'folder'|'safetensors'|'ckpt'
|
|
||||||
},
|
|
||||||
model_name2: { etc }
|
|
||||||
},
|
|
||||||
model_type2:
|
|
||||||
{ model_name_n: etc
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
|
||||||
"""
|
|
||||||
Return information about the model using the same format as list_models()
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
|
||||||
"""
|
|
||||||
Returns a list of all the model names known.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def add_model(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
model_attributes: dict,
|
|
||||||
clobber: bool = False,
|
|
||||||
) -> AddModelResult:
|
|
||||||
"""
|
|
||||||
Update the named model with a dictionary of attributes. Will fail with an
|
|
||||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
|
||||||
On a successful update, the config will be changed in memory. Will fail
|
|
||||||
with an assertion error if provided attributes are incorrect or
|
|
||||||
the model name is missing. Call commit() to write changes to disk.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update_model(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
model_attributes: dict,
|
|
||||||
) -> AddModelResult:
|
|
||||||
"""
|
|
||||||
Update the named model with a dictionary of attributes. Will fail with a
|
|
||||||
ModelNotFoundException if the name does not already exist.
|
|
||||||
|
|
||||||
On a successful update, the config will be changed in memory. Will fail
|
|
||||||
with an assertion error if provided attributes are incorrect or
|
|
||||||
the model name is missing. Call commit() to write changes to disk.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def del_model(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Delete the named model from configuration. If delete_files is true,
|
|
||||||
then the underlying weight file or diffusers directory will be deleted
|
|
||||||
as well. Call commit() to write to disk.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def rename_model(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
new_name: str,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Rename the indicated model.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def list_checkpoint_configs(self) -> List[Path]:
|
|
||||||
"""
|
|
||||||
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def convert_model(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: Literal[ModelType.Main, ModelType.Vae],
|
|
||||||
) -> AddModelResult:
|
|
||||||
"""
|
|
||||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
|
||||||
version and deleting the original checkpoint file if it is in the models
|
|
||||||
directory.
|
|
||||||
:param model_name: Name of the model to convert
|
|
||||||
:param base_model: Base model type
|
|
||||||
:param model_type: Type of model ['vae' or 'main']
|
|
||||||
|
|
||||||
This will raise a ValueError unless the model is not a checkpoint. It will
|
|
||||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
|
||||||
directory already in place.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def heuristic_import(
|
|
||||||
self,
|
|
||||||
items_to_import: set[str],
|
|
||||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
|
||||||
) -> dict[str, AddModelResult]:
|
|
||||||
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
|
||||||
successfully imported items.
|
|
||||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
|
||||||
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
|
||||||
|
|
||||||
The prediction type helper is necessary to distinguish between
|
|
||||||
models based on Stable Diffusion 2 Base (requiring
|
|
||||||
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
|
|
||||||
(requiring SchedulerPredictionType.VPrediction). It is
|
|
||||||
generally impossible to do this programmatically, so the
|
|
||||||
prediction_type_helper usually asks the user to choose.
|
|
||||||
|
|
||||||
The result is a set of successfully installed models. Each element
|
|
||||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
|
||||||
that model.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def merge_models(
|
|
||||||
self,
|
|
||||||
model_names: List[str] = Field(
|
|
||||||
default=None, min_items=2, max_items=3, description="List of model names to merge"
|
|
||||||
),
|
|
||||||
base_model: Union[BaseModelType, str] = Field(
|
|
||||||
default=None, description="Base model shared by all models to be merged"
|
|
||||||
),
|
|
||||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
|
||||||
alpha: Optional[float] = 0.5,
|
|
||||||
interp: Optional[MergeInterpolationMethod] = None,
|
|
||||||
force: Optional[bool] = False,
|
|
||||||
merge_dest_directory: Optional[Path] = None,
|
|
||||||
) -> AddModelResult:
|
|
||||||
"""
|
|
||||||
Merge two to three diffusrs pipeline models and save as a new model.
|
|
||||||
:param model_names: List of 2-3 models to merge
|
|
||||||
:param base_model: Base model to use for all models
|
|
||||||
:param merged_model_name: Name of destination merged model
|
|
||||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
|
||||||
:param interp: Interpolation method. None (default)
|
|
||||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def search_for_models(self, directory: Path) -> List[Path]:
|
|
||||||
"""
|
|
||||||
Return list of all models found in the designated directory.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def sync_to_config(self):
|
|
||||||
"""
|
|
||||||
Re-read models.yaml, rescan the models directory, and reimport models
|
|
||||||
in the autoimport directories. Call after making changes outside the
|
|
||||||
model manager API.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
|
||||||
"""
|
|
||||||
Reset model cache statistics for graph with graph_id.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def commit(self, conf_file: Optional[Path] = None) -> None:
|
|
||||||
"""
|
|
||||||
Write current configuration out to the indicated file.
|
|
||||||
If no conf_file is provided, then replaces the
|
|
||||||
original file/database used to initialize the object.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# simple implementation
|
# simple implementation
|
||||||
@@ -348,28 +87,35 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
)
|
)
|
||||||
logger.info("Model manager service initialized")
|
logger.info("Model manager service initialized")
|
||||||
|
|
||||||
|
def start(self, invoker: Invoker) -> None:
|
||||||
|
self._invoker = invoker
|
||||||
|
|
||||||
def get_model(
|
def get_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
|
queue_id: str,
|
||||||
|
queue_item_id: int,
|
||||||
|
queue_batch_id: str,
|
||||||
|
graph_execution_state_id: str,
|
||||||
submodel: Optional[SubModelType] = None,
|
submodel: Optional[SubModelType] = None,
|
||||||
context: Optional[InvocationContext] = None,
|
|
||||||
) -> ModelInfo:
|
) -> ModelInfo:
|
||||||
"""
|
"""
|
||||||
Retrieve the indicated model. submodel can be used to get a
|
Retrieve the indicated model. submodel can be used to get a
|
||||||
part (such as the vae) of a diffusers mode.
|
part (such as the vae) of a diffusers mode.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# we can emit model loading events if we are executing with access to the invocation context
|
self._emit_load_event(
|
||||||
if context:
|
queue_id=queue_id,
|
||||||
self._emit_load_event(
|
queue_item_id=queue_item_id,
|
||||||
context=context,
|
queue_batch_id=queue_batch_id,
|
||||||
model_name=model_name,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
base_model=base_model,
|
model_name=model_name,
|
||||||
model_type=model_type,
|
base_model=base_model,
|
||||||
submodel=submodel,
|
model_type=model_type,
|
||||||
)
|
submodel=submodel,
|
||||||
|
)
|
||||||
|
|
||||||
model_info = self.mgr.get_model(
|
model_info = self.mgr.get_model(
|
||||||
model_name,
|
model_name,
|
||||||
@@ -378,15 +124,17 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
submodel,
|
submodel,
|
||||||
)
|
)
|
||||||
|
|
||||||
if context:
|
self._emit_load_event(
|
||||||
self._emit_load_event(
|
queue_id=queue_id,
|
||||||
context=context,
|
queue_item_id=queue_item_id,
|
||||||
model_name=model_name,
|
queue_batch_id=queue_batch_id,
|
||||||
base_model=base_model,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
model_type=model_type,
|
model_name=model_name,
|
||||||
submodel=submodel,
|
base_model=base_model,
|
||||||
model_info=model_info,
|
model_type=model_type,
|
||||||
)
|
submodel=submodel,
|
||||||
|
model_info=model_info,
|
||||||
|
)
|
||||||
|
|
||||||
return model_info
|
return model_info
|
||||||
|
|
||||||
@@ -525,22 +273,25 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
|
|
||||||
def _emit_load_event(
|
def _emit_load_event(
|
||||||
self,
|
self,
|
||||||
context: InvocationContext,
|
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
|
queue_id: str,
|
||||||
|
queue_item_id: int,
|
||||||
|
queue_batch_id: str,
|
||||||
|
graph_execution_state_id: str,
|
||||||
submodel: Optional[SubModelType] = None,
|
submodel: Optional[SubModelType] = None,
|
||||||
model_info: Optional[ModelInfo] = None,
|
model_info: Optional[ModelInfo] = None,
|
||||||
):
|
):
|
||||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
if self._invoker.services.queue.is_canceled(graph_execution_state_id):
|
||||||
raise CanceledException()
|
raise CanceledException()
|
||||||
|
|
||||||
if model_info:
|
if model_info:
|
||||||
context.services.events.emit_model_load_completed(
|
self._invoker.services.events.emit_model_load_completed(
|
||||||
queue_id=context.queue_id,
|
queue_id=queue_id,
|
||||||
queue_item_id=context.queue_item_id,
|
queue_item_id=queue_item_id,
|
||||||
queue_batch_id=context.queue_batch_id,
|
queue_batch_id=queue_batch_id,
|
||||||
graph_execution_state_id=context.graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
@@ -548,11 +299,11 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
model_info=model_info,
|
model_info=model_info,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
context.services.events.emit_model_load_started(
|
self._invoker.services.events.emit_model_load_started(
|
||||||
queue_id=context.queue_id,
|
queue_id=queue_id,
|
||||||
queue_item_id=context.queue_item_id,
|
queue_item_id=queue_item_id,
|
||||||
queue_batch_id=context.queue_batch_id,
|
queue_batch_id=queue_batch_id,
|
||||||
graph_execution_state_id=context.graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
@@ -589,7 +340,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
def merge_models(
|
def merge_models(
|
||||||
self,
|
self,
|
||||||
model_names: List[str] = Field(
|
model_names: List[str] = Field(
|
||||||
default=None, min_items=2, max_items=3, description="List of model names to merge"
|
default=None, min_length=2, max_length=3, description="List of model names to merge"
|
||||||
),
|
),
|
||||||
base_model: Union[BaseModelType, str] = Field(
|
base_model: Union[BaseModelType, str] = Field(
|
||||||
default=None, description="Base model shared by all models to be merged"
|
default=None, description="Base model shared by all models to be merged"
|
||||||
0
invokeai/app/services/names/__init__.py
Normal file
0
invokeai/app/services/names/__init__.py
Normal file
11
invokeai/app/services/names/names_base.py
Normal file
11
invokeai/app/services/names/names_base.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class NameServiceBase(ABC):
|
||||||
|
"""Low-level service responsible for naming resources (images, latents, etc)."""
|
||||||
|
|
||||||
|
# TODO: Add customizable naming schemes
|
||||||
|
@abstractmethod
|
||||||
|
def create_image_name(self) -> str:
|
||||||
|
"""Creates a name for an image."""
|
||||||
|
pass
|
||||||
8
invokeai/app/services/names/names_common.py
Normal file
8
invokeai/app/services/names/names_common.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
from enum import Enum, EnumMeta
|
||||||
|
|
||||||
|
|
||||||
|
class ResourceType(str, Enum, metaclass=EnumMeta):
|
||||||
|
"""Enum for resource types."""
|
||||||
|
|
||||||
|
IMAGE = "image"
|
||||||
|
LATENT = "latent"
|
||||||
13
invokeai/app/services/names/names_default.py
Normal file
13
invokeai/app/services/names/names_default.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
|
||||||
|
from .names_base import NameServiceBase
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleNameService(NameServiceBase):
|
||||||
|
"""Creates image names from UUIDs."""
|
||||||
|
|
||||||
|
# TODO: Add customizable naming schemes
|
||||||
|
def create_image_name(self) -> str:
|
||||||
|
uuid_str = uuid_string()
|
||||||
|
filename = f"{uuid_str}.png"
|
||||||
|
return filename
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from enum import Enum, EnumMeta
|
|
||||||
|
|
||||||
from invokeai.app.util.misc import uuid_string
|
|
||||||
|
|
||||||
|
|
||||||
class ResourceType(str, Enum, metaclass=EnumMeta):
|
|
||||||
"""Enum for resource types."""
|
|
||||||
|
|
||||||
IMAGE = "image"
|
|
||||||
LATENT = "latent"
|
|
||||||
|
|
||||||
|
|
||||||
class NameServiceBase(ABC):
|
|
||||||
"""Low-level service responsible for naming resources (images, latents, etc)."""
|
|
||||||
|
|
||||||
# TODO: Add customizable naming schemes
|
|
||||||
@abstractmethod
|
|
||||||
def create_image_name(self) -> str:
|
|
||||||
"""Creates a name for an image."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleNameService(NameServiceBase):
|
|
||||||
"""Creates image names from UUIDs."""
|
|
||||||
|
|
||||||
# TODO: Add customizable naming schemes
|
|
||||||
def create_image_name(self) -> str:
|
|
||||||
uuid_str = uuid_string()
|
|
||||||
filename = f"{uuid_str}.png"
|
|
||||||
return filename
|
|
||||||
@@ -7,7 +7,7 @@ from typing import Optional
|
|||||||
from fastapi_events.handlers.local import local_handler
|
from fastapi_events.handlers.local import local_handler
|
||||||
from fastapi_events.typing import Event as FastAPIEvent
|
from fastapi_events.typing import Event as FastAPIEvent
|
||||||
|
|
||||||
from invokeai.app.services.events import EventServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||||
|
|
||||||
from ..invoker import Invoker
|
from ..invoker import Invoker
|
||||||
@@ -97,7 +97,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
resume_event.set()
|
resume_event.set()
|
||||||
self.__threadLimit.acquire()
|
self.__threadLimit.acquire()
|
||||||
queue_item: Optional[SessionQueueItem] = None
|
queue_item: Optional[SessionQueueItem] = None
|
||||||
self.__invoker.services.logger
|
|
||||||
while not stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
poll_now_event.clear()
|
poll_now_event.clear()
|
||||||
try:
|
try:
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user