Files
InvokeAI/invokeai/app/api/dependencies.py
Alexander Eichhorn 0a428ffff4 Migrate session_queue to SQLModel (Phase 3)
Port SqliteSessionQueue to a SQLAlchemy Core / SQLModel hybrid that keeps the
existing public API and DB schema (migrations and triggers untouched). Hot
paths (enqueue bulk insert, dequeue, bulk cancel/delete, list with cursor
pagination, status aggregations) use Core to avoid ORM hydration overhead;
single-row reads stay ORM-style for clarity.

- Add SqlModelSessionQueue alongside the legacy SqliteSessionQueue
- Add the missing `workflow` column to SessionQueueTable (was added by
  migration_2 but never declared on the SQLModel)
- Wire dependencies.py to the new implementation
- Add 36 unit tests covering enqueue/dequeue, status mutations, bulk
  cancel/delete, prune-to-limit, retry, pagination and aggregations
- Avoid nested write sessions on the single StaticPool connection by reading
  the current item before opening the outer write session
2026-04-20 23:43:24 +02:00

238 lines
11 KiB
Python

# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
from logging import Logger
import torch
from invokeai.app.services.app_settings.app_settings_sqlmodel import AppSettingsServiceSqlModel
from invokeai.app.services.auth.token_service import set_jwt_secret
from invokeai.app.services.board_image_records.board_image_records_sqlmodel import SqlModelBoardImageRecordStorage
from invokeai.app.services.board_images.board_images_default import BoardImagesService
from invokeai.app.services.board_records.board_records_sqlmodel import SqlModelBoardRecordStorage
from invokeai.app.services.boards.boards_default import BoardService
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
from invokeai.app.services.client_state_persistence.client_state_persistence_sqlmodel import (
ClientStatePersistenceSqlModel,
)
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.download.download_default import DownloadQueueService
from invokeai.app.services.events.events_fastapievents import FastAPIEventService
from invokeai.app.services.external_generation.external_generation_default import ExternalGenerationService
from invokeai.app.services.external_generation.providers import GeminiProvider, OpenAIProvider
from invokeai.app.services.external_generation.startup import sync_configured_external_starter_models
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
from invokeai.app.services.image_records.image_records_sqlmodel import SqlModelImageRecordStorage
from invokeai.app.services.images.images_default import ImageService
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_images.model_images_default import ModelImageFileStorageDisk
from invokeai.app.services.model_manager.model_manager_default import ModelManagerService
from invokeai.app.services.model_records.model_records_sqlmodel import ModelRecordServiceSqlModel
from invokeai.app.services.model_relationship_records.model_relationship_records_sqlmodel import (
SqlModelModelRelationshipRecordStorage,
)
from invokeai.app.services.model_relationships.model_relationships_default import ModelRelationshipsService
from invokeai.app.services.names.names_default import SimpleNameService
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
from invokeai.app.services.session_processor.session_processor_default import (
DefaultSessionProcessor,
DefaultSessionRunner,
)
from invokeai.app.services.session_queue.session_queue_sqlmodel import SqlModelSessionQueue
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.app.services.style_preset_images.style_preset_images_disk import StylePresetImageFileStorageDisk
from invokeai.app.services.style_preset_records.style_preset_records_sqlmodel import SqlModelStylePresetRecordsStorage
from invokeai.app.services.urls.urls_default import LocalUrlService
from invokeai.app.services.users.users_sqlmodel import UserServiceSqlModel
from invokeai.app.services.workflow_records.workflow_records_sqlmodel import SqlModelWorkflowRecordsStorage
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_disk import WorkflowThumbnailFileStorageDisk
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
AnimaConditioningInfo,
BasicConditioningInfo,
CogView4ConditioningInfo,
ConditioningFieldData,
FLUXConditioningInfo,
QwenImageConditioningInfo,
SD3ConditioningInfo,
SDXLConditioningInfo,
ZImageConditioningInfo,
)
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
# TODO: is there a better way to achieve this?
def check_internet(host: str = "http://huggingface.co", timeout: float = 1) -> bool:
    """
    Return True if the internet is reachable.

    Reachability is probed by opening `host` (huggingface.co by default) with
    `urllib.request.urlopen` and a short timeout. Any failure to open the URL
    is reported as "not reachable" rather than raised.

    Args:
        host: URL to open as the reachability probe.
        timeout: Timeout, in seconds, passed to `urlopen`.

    Returns:
        True if the URL could be opened, False otherwise.
    """
    import urllib.request

    try:
        # Only success/failure matters; the context manager closes the response
        # immediately so the underlying connection is not leaked.
        with urllib.request.urlopen(host, timeout=timeout):
            return True
    except Exception:
        # Deliberately broad: this is a best-effort probe, so every error
        # (DNS failure, timeout, malformed URL, ...) simply means "unreachable".
        return False
# Module-level logger; it is also the default value of the `logger` parameter of
# `ApiDependencies.initialize`.
logger = InvokeAILogger.get_logger()
class ApiDependencies:
    """Contains and initializes all dependencies for the API"""

    # Assigned by `initialize()`. This line is only an annotation, so the
    # attribute does not exist on the class until `initialize()` has set it.
    invoker: Invoker

    @staticmethod
    def initialize(
        config: InvokeAIAppConfig,
        event_handler_id: int,
        loop: asyncio.AbstractEventLoop,
        logger: Logger = logger,
    ) -> None:
        """
        Construct every service, bundle them into `InvocationServices` and store
        the resulting `Invoker` on `ApiDependencies.invoker`.

        Args:
            config: Application configuration; supplies the root/output/model/
                style-preset/workflow-thumbnail paths and the node cache size.
            event_handler_id: Passed through to `FastAPIEventService`.
            loop: Event loop passed through to `FastAPIEventService`.
            logger: Logger used for startup messages and handed to the services
                that accept one. Defaults to the module-level logger.

        Raises:
            ValueError: If `config.outputs_path` is None.
        """
        logger.info(f"InvokeAI version {__version__}")
        logger.info(f"Root directory = {str(config.root_path)}")

        output_folder = config.outputs_path
        if output_folder is None:
            raise ValueError("Output folder is not set")

        image_files = DiskImageFileStorage(f"{output_folder}/images")

        model_images_folder = config.models_path
        style_presets_folder = config.style_presets_path
        workflow_thumbnails_folder = config.workflow_thumbnails_path

        # The database is created first: most record stores below are built on it.
        db = init_db(config=config, logger=logger, image_files=image_files)

        # Initialize JWT secret from database
        app_settings = AppSettingsServiceSqlModel(db=db)
        jwt_secret = app_settings.get_jwt_secret()
        set_jwt_secret(jwt_secret)
        logger.info("JWT secret loaded from database")

        # Alias so the local names line up with the InvocationServices keyword names.
        configuration = config

        board_image_records = SqlModelBoardImageRecordStorage(db=db)
        board_images = BoardImagesService()
        board_records = SqlModelBoardRecordStorage(db=db)
        boards = BoardService()
        events = FastAPIEventService(event_handler_id, loop=loop)
        bulk_download = BulkDownloadService()
        image_records = SqlModelImageRecordStorage(db=db)
        images = ImageService()
        invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)

        # Disk-backed serializers wrapped in a forward cache. `safe_globals`
        # lists the types passed to the serializer for each store.
        tensors = ObjectSerializerForwardCache(
            ObjectSerializerDisk[torch.Tensor](
                output_folder / "tensors",
                safe_globals=[torch.Tensor],
                ephemeral=True,
            ),
        )
        conditioning = ObjectSerializerForwardCache(
            ObjectSerializerDisk[ConditioningFieldData](
                output_folder / "conditioning",
                safe_globals=[
                    ConditioningFieldData,
                    BasicConditioningInfo,
                    SDXLConditioningInfo,
                    FLUXConditioningInfo,
                    SD3ConditioningInfo,
                    CogView4ConditioningInfo,
                    ZImageConditioningInfo,
                    QwenImageConditioningInfo,
                    AnimaConditioningInfo,
                ],
                ephemeral=True,
            ),
        )

        download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
        model_record_service = ModelRecordServiceSqlModel(db=db, logger=logger)
        model_manager = ModelManagerService.build_model_manager(
            app_config=configuration,
            model_record_service=model_record_service,
            download_queue=download_queue_service,
            events=events,
        )
        external_generation = ExternalGenerationService(
            providers={
                GeminiProvider.provider_id: GeminiProvider(app_config=configuration, logger=logger),
                OpenAIProvider.provider_id: OpenAIProvider(app_config=configuration, logger=logger),
            },
            logger=logger,
            record_store=model_record_service,
        )
        model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
        model_relationships = ModelRelationshipsService()
        model_relationship_records = SqlModelModelRelationshipRecordStorage(db=db)
        names = SimpleNameService()
        performance_statistics = InvocationStatsService()
        session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner())
        session_queue = SqlModelSessionQueue(db=db)
        urls = LocalUrlService()
        workflow_records = SqlModelWorkflowRecordsStorage(db=db)
        style_preset_records = SqlModelStylePresetRecordsStorage(db=db)
        style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
        workflow_thumbnails = WorkflowThumbnailFileStorageDisk(workflow_thumbnails_folder)
        client_state_persistence = ClientStatePersistenceSqlModel(db=db)
        users = UserServiceSqlModel(db=db)

        services = InvocationServices(
            board_image_records=board_image_records,
            board_images=board_images,
            board_records=board_records,
            boards=boards,
            bulk_download=bulk_download,
            configuration=configuration,
            events=events,
            image_files=image_files,
            image_records=image_records,
            images=images,
            invocation_cache=invocation_cache,
            logger=logger,
            model_images=model_images_service,
            model_manager=model_manager,
            model_relationships=model_relationships,
            model_relationship_records=model_relationship_records,
            download_queue=download_queue_service,
            external_generation=external_generation,
            names=names,
            performance_statistics=performance_statistics,
            session_processor=session_processor,
            session_queue=session_queue,
            urls=urls,
            workflow_records=workflow_records,
            tensors=tensors,
            conditioning=conditioning,
            style_preset_records=style_preset_records,
            style_preset_image_files=style_preset_image_files,
            workflow_thumbnails=workflow_thumbnails,
            client_state_persistence=client_state_persistence,
            users=users,
        )

        ApiDependencies.invoker = Invoker(services)

        # Only providers whose status reports `configured` take part in the
        # starter-model sync below.
        configured_external_providers = {
            provider_id
            for provider_id, status in external_generation.get_provider_statuses().items()
            if status.configured
        }
        sync_configured_external_starter_models(
            configured_provider_ids=configured_external_providers,
            model_manager=model_manager,
            logger=logger,
        )

        # Last step: call the database's own clean() routine (behavior is
        # defined by the object returned from init_db).
        db.clean()

    @staticmethod
    def shutdown() -> None:
        """
        Stop the invoker, if one was created.

        Safe to call even when `initialize()` never ran or raised before the
        invoker was assigned: `invoker` is only annotated on the class, so it
        must be looked up defensively instead of accessed directly.
        """
        invoker = getattr(ApiDependencies, "invoker", None)
        if invoker:
            invoker.stop()