mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Port SqliteSessionQueue to a SQLAlchemy Core / SQLModel hybrid that keeps the existing public API and DB schema (migrations and triggers untouched). Hot paths (enqueue bulk insert, dequeue, bulk cancel/delete, list with cursor pagination, status aggregations) use Core to avoid ORM hydration overhead; single-row reads stay ORM-style for clarity. - Add SqlModelSessionQueue alongside the legacy SqliteSessionQueue - Add the missing `workflow` column to SessionQueueTable (was added by migration_2 but never declared on the SQLModel) - Wire dependencies.py to the new implementation - Add 36 unit tests covering enqueue/dequeue, status mutations, bulk cancel/delete, prune-to-limit, retry, pagination and aggregations - Avoid nested write sessions on the single StaticPool connection by reading the current item before opening the outer write session
238 lines
11 KiB
Python
238 lines
11 KiB
Python
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
|
|
|
import asyncio
|
|
from logging import Logger
|
|
|
|
import torch
|
|
|
|
from invokeai.app.services.app_settings.app_settings_sqlmodel import AppSettingsServiceSqlModel
|
|
from invokeai.app.services.auth.token_service import set_jwt_secret
|
|
from invokeai.app.services.board_image_records.board_image_records_sqlmodel import SqlModelBoardImageRecordStorage
|
|
from invokeai.app.services.board_images.board_images_default import BoardImagesService
|
|
from invokeai.app.services.board_records.board_records_sqlmodel import SqlModelBoardRecordStorage
|
|
from invokeai.app.services.boards.boards_default import BoardService
|
|
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
|
|
from invokeai.app.services.client_state_persistence.client_state_persistence_sqlmodel import (
|
|
ClientStatePersistenceSqlModel,
|
|
)
|
|
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
|
from invokeai.app.services.download.download_default import DownloadQueueService
|
|
from invokeai.app.services.events.events_fastapievents import FastAPIEventService
|
|
from invokeai.app.services.external_generation.external_generation_default import ExternalGenerationService
|
|
from invokeai.app.services.external_generation.providers import GeminiProvider, OpenAIProvider
|
|
from invokeai.app.services.external_generation.startup import sync_configured_external_starter_models
|
|
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
|
|
from invokeai.app.services.image_records.image_records_sqlmodel import SqlModelImageRecordStorage
|
|
from invokeai.app.services.images.images_default import ImageService
|
|
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
|
from invokeai.app.services.invocation_services import InvocationServices
|
|
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
|
|
from invokeai.app.services.invoker import Invoker
|
|
from invokeai.app.services.model_images.model_images_default import ModelImageFileStorageDisk
|
|
from invokeai.app.services.model_manager.model_manager_default import ModelManagerService
|
|
from invokeai.app.services.model_records.model_records_sqlmodel import ModelRecordServiceSqlModel
|
|
from invokeai.app.services.model_relationship_records.model_relationship_records_sqlmodel import (
|
|
SqlModelModelRelationshipRecordStorage,
|
|
)
|
|
from invokeai.app.services.model_relationships.model_relationships_default import ModelRelationshipsService
|
|
from invokeai.app.services.names.names_default import SimpleNameService
|
|
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
|
|
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
|
|
from invokeai.app.services.session_processor.session_processor_default import (
|
|
DefaultSessionProcessor,
|
|
DefaultSessionRunner,
|
|
)
|
|
from invokeai.app.services.session_queue.session_queue_sqlmodel import SqlModelSessionQueue
|
|
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
|
from invokeai.app.services.style_preset_images.style_preset_images_disk import StylePresetImageFileStorageDisk
|
|
from invokeai.app.services.style_preset_records.style_preset_records_sqlmodel import SqlModelStylePresetRecordsStorage
|
|
from invokeai.app.services.urls.urls_default import LocalUrlService
|
|
from invokeai.app.services.users.users_sqlmodel import UserServiceSqlModel
|
|
from invokeai.app.services.workflow_records.workflow_records_sqlmodel import SqlModelWorkflowRecordsStorage
|
|
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_disk import WorkflowThumbnailFileStorageDisk
|
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
|
AnimaConditioningInfo,
|
|
BasicConditioningInfo,
|
|
CogView4ConditioningInfo,
|
|
ConditioningFieldData,
|
|
FLUXConditioningInfo,
|
|
QwenImageConditioningInfo,
|
|
SD3ConditioningInfo,
|
|
SDXLConditioningInfo,
|
|
ZImageConditioningInfo,
|
|
)
|
|
from invokeai.backend.util.logging import InvokeAILogger
|
|
from invokeai.version.invokeai_version import __version__
|
|
|
|
|
|
# TODO: is there a better way to achieve this?
def check_internet(*, host: str = "http://huggingface.co", timeout: float = 1) -> bool:
    """
    Return true if the internet is reachable.

    It does this by opening a request to ``host`` (by default http://huggingface.co).

    Args:
        host: URL to probe. Defaults to the Hugging Face homepage.
        timeout: Seconds to wait for the connection before giving up.

    Returns:
        True if the request succeeded, False on any error.
    """
    import urllib.request

    try:
        # Use the response as a context manager so the underlying connection is
        # closed immediately instead of leaking until garbage collection.
        with urllib.request.urlopen(host, timeout=timeout):
            return True
    except Exception:
        # Best-effort probe: any failure (DNS, timeout, HTTP error, bad URL)
        # simply means "not reachable".
        return False
|
|
|
|
|
|
# Module-level logger; also the default `logger` argument of `ApiDependencies.initialize`.
logger = InvokeAILogger.get_logger()
|
|
|
|
|
|
class ApiDependencies:
    """Contains and initializes all dependencies for the API"""

    # Assigned by `initialize()`. It is only annotated here (no default value), so
    # the class has no `invoker` attribute until `initialize()` has run.
    invoker: Invoker

    @staticmethod
    def initialize(
        config: InvokeAIAppConfig,
        event_handler_id: int,
        loop: asyncio.AbstractEventLoop,
        logger: Logger = logger,
    ) -> None:
        """Build every service, bundle them into an `Invoker`, and store it on the class.

        Besides constructing services, this loads the JWT secret from the database,
        syncs starter models for configured external generation providers, and
        finally calls `db.clean()`.

        Args:
            config: Application configuration. Supplies the output, models, style
                preset and workflow thumbnail paths and the node cache size, and is
                passed to services as `configuration`.
            event_handler_id: Forwarded to `FastAPIEventService`.
            loop: Event loop forwarded to `FastAPIEventService`.
            logger: Logger used during startup and injected into the services.

        Raises:
            ValueError: If `config.outputs_path` is None.
        """
        logger.info(f"InvokeAI version {__version__}")
        logger.info(f"Root directory = {str(config.root_path)}")

        output_folder = config.outputs_path
        if output_folder is None:
            raise ValueError("Output folder is not set")

        image_files = DiskImageFileStorage(f"{output_folder}/images")

        model_images_folder = config.models_path
        style_presets_folder = config.style_presets_path
        workflow_thumbnails_folder = config.workflow_thumbnails_path

        # `image_files` is handed to `init_db` as well as to the services below.
        db = init_db(config=config, logger=logger, image_files=image_files)

        # Initialize JWT secret from database
        app_settings = AppSettingsServiceSqlModel(db=db)
        jwt_secret = app_settings.get_jwt_secret()
        set_jwt_secret(jwt_secret)
        logger.info("JWT secret loaded from database")

        configuration = config

        board_image_records = SqlModelBoardImageRecordStorage(db=db)
        board_images = BoardImagesService()
        board_records = SqlModelBoardRecordStorage(db=db)
        boards = BoardService()
        events = FastAPIEventService(event_handler_id, loop=loop)
        bulk_download = BulkDownloadService()
        image_records = SqlModelImageRecordStorage(db=db)
        images = ImageService()
        invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
        tensors = ObjectSerializerForwardCache(
            ObjectSerializerDisk[torch.Tensor](
                output_folder / "tensors",
                safe_globals=[torch.Tensor],
                ephemeral=True,
            ),
        )
        # Every conditioning-info class that may be deserialized from disk must be
        # listed in `safe_globals`.
        conditioning = ObjectSerializerForwardCache(
            ObjectSerializerDisk[ConditioningFieldData](
                output_folder / "conditioning",
                safe_globals=[
                    ConditioningFieldData,
                    BasicConditioningInfo,
                    SDXLConditioningInfo,
                    FLUXConditioningInfo,
                    SD3ConditioningInfo,
                    CogView4ConditioningInfo,
                    ZImageConditioningInfo,
                    QwenImageConditioningInfo,
                    AnimaConditioningInfo,
                ],
                ephemeral=True,
            ),
        )
        download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
        model_record_service = ModelRecordServiceSqlModel(db=db, logger=logger)
        model_manager = ModelManagerService.build_model_manager(
            app_config=configuration,
            model_record_service=model_record_service,
            download_queue=download_queue_service,
            events=events,
        )
        external_generation = ExternalGenerationService(
            providers={
                GeminiProvider.provider_id: GeminiProvider(app_config=configuration, logger=logger),
                OpenAIProvider.provider_id: OpenAIProvider(app_config=configuration, logger=logger),
            },
            logger=logger,
            record_store=model_record_service,
        )
        model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
        model_relationships = ModelRelationshipsService()
        model_relationship_records = SqlModelModelRelationshipRecordStorage(db=db)
        names = SimpleNameService()
        performance_statistics = InvocationStatsService()
        session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner())
        session_queue = SqlModelSessionQueue(db=db)
        urls = LocalUrlService()
        workflow_records = SqlModelWorkflowRecordsStorage(db=db)
        style_preset_records = SqlModelStylePresetRecordsStorage(db=db)
        style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
        workflow_thumbnails = WorkflowThumbnailFileStorageDisk(workflow_thumbnails_folder)
        client_state_persistence = ClientStatePersistenceSqlModel(db=db)
        users = UserServiceSqlModel(db=db)

        services = InvocationServices(
            board_image_records=board_image_records,
            board_images=board_images,
            board_records=board_records,
            boards=boards,
            bulk_download=bulk_download,
            configuration=configuration,
            events=events,
            image_files=image_files,
            image_records=image_records,
            images=images,
            invocation_cache=invocation_cache,
            logger=logger,
            model_images=model_images_service,
            model_manager=model_manager,
            model_relationships=model_relationships,
            model_relationship_records=model_relationship_records,
            download_queue=download_queue_service,
            external_generation=external_generation,
            names=names,
            performance_statistics=performance_statistics,
            session_processor=session_processor,
            session_queue=session_queue,
            urls=urls,
            workflow_records=workflow_records,
            tensors=tensors,
            conditioning=conditioning,
            style_preset_records=style_preset_records,
            style_preset_image_files=style_preset_image_files,
            workflow_thumbnails=workflow_thumbnails,
            client_state_persistence=client_state_persistence,
            users=users,
        )

        ApiDependencies.invoker = Invoker(services)

        # Only providers reporting `configured` get their starter models synced.
        configured_external_providers = {
            provider_id
            for provider_id, status in external_generation.get_provider_statuses().items()
            if status.configured
        }
        sync_configured_external_starter_models(
            configured_provider_ids=configured_external_providers,
            model_manager=model_manager,
            logger=logger,
        )
        db.clean()

    @staticmethod
    def shutdown() -> None:
        """Stop the invoker if `initialize()` created one; otherwise do nothing."""
        # `invoker` is only annotated at class level, so a plain attribute read
        # raises AttributeError when `initialize()` never ran (or failed before
        # assigning it). `getattr` with a default makes the guard actually work.
        invoker = getattr(ApiDependencies, "invoker", None)
        if invoker is not None:
            invoker.stop()
|