feat: server-side client state persistence

This commit is contained in:
psychedelicious
2025-07-21 18:08:55 +10:00
parent 70ac58e64a
commit 28633c9983
35 changed files with 575 additions and 118 deletions

View File

@@ -10,6 +10,7 @@ from invokeai.app.services.board_images.board_images_default import BoardImagesS
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from invokeai.app.services.boards.boards_default import BoardService
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
from invokeai.app.services.client_state_persistence.client_state_persistence_sqlite import ClientStatePersistenceSqlite
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.download.download_default import DownloadQueueService
from invokeai.app.services.events.events_fastapievents import FastAPIEventService
@@ -151,6 +152,7 @@ class ApiDependencies:
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
workflow_thumbnails = WorkflowThumbnailFileStorageDisk(workflow_thumbnails_folder)
client_state_persistence = ClientStatePersistenceSqlite(db=db)
services = InvocationServices(
board_image_records=board_image_records,
@@ -181,6 +183,7 @@ class ApiDependencies:
style_preset_records=style_preset_records,
style_preset_image_files=style_preset_image_files,
workflow_thumbnails=workflow_thumbnails,
client_state_persistence=client_state_persistence,
)
ApiDependencies.invoker = Invoker(services)

View File

@@ -5,9 +5,9 @@ from pathlib import Path
from typing import Optional
import torch
from fastapi import Body
from fastapi import Body, HTTPException, Query
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, JsonValue
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.invocations.upscale import ESRGAN_MODELS
@@ -173,3 +173,50 @@ async def disable_invocation_cache() -> None:
async def get_invocation_cache_status() -> InvocationCacheStatus:
"""Clears the invocation cache"""
return ApiDependencies.invoker.services.invocation_cache.get_status()
@app_router.get(
"/client_state",
operation_id="get_client_state_by_key",
response_model=JsonValue | None,
)
async def get_client_state_by_key(
key: str = Query(..., description="Key to retrieve from client state persistence"),
) -> JsonValue | None:
"""Gets the client state"""
try:
return ApiDependencies.invoker.services.client_state_persistence.get_by_key(key)
except Exception as e:
logging.error(f"Error getting client state: {e}")
raise HTTPException(status_code=500, detail="Error setting client state")
@app_router.post(
"/client_state",
operation_id="set_client_state",
response_model=None,
)
async def set_client_state(
key: str = Body(..., description="Key to set"),
value: JsonValue = Body(..., description="Value of the key"),
) -> None:
"""Sets the client state"""
try:
ApiDependencies.invoker.services.client_state_persistence.set_by_key(key, value)
except Exception as e:
logging.error(f"Error setting client state: {e}")
raise HTTPException(status_code=500, detail="Error setting client state")
@app_router.delete(
"/client_state",
operation_id="delete_client_state",
responses={204: {"description": "Client state deleted"}},
)
async def delete_client_state() -> None:
"""Deletes the client state"""
try:
ApiDependencies.invoker.services.client_state_persistence.delete()
except Exception as e:
logging.error(f"Error deleting client state: {e}")
raise HTTPException(status_code=500, detail="Error deleting client state")

View File

@@ -0,0 +1,35 @@
from abc import ABC, abstractmethod
from pydantic import JsonValue
class ClientStatePersistenceABC(ABC):
"""
Base class for client persistence implementations.
This class defines the interface for persisting client data.
"""
@abstractmethod
def set_by_key(self, key: str, value: JsonValue) -> None:
"""
Store the data for the client.
:param data: The client data to be stored.
"""
pass
@abstractmethod
def get_by_key(self, key: str) -> JsonValue | None:
"""
Get the data for the client.
:return: The client data.
"""
pass
@abstractmethod
def delete(self) -> None:
"""
Delete the data for the client.
"""
pass

View File

@@ -0,0 +1,65 @@
import json
from pydantic import JsonValue
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class ClientStatePersistenceSqlite(ClientStatePersistenceABC):
"""
Base class for client persistence implementations.
This class defines the interface for persisting client data.
"""
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
self._default_row_id = 1
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
def set_by_key(self, key: str, value: JsonValue) -> None:
state = self.get() or {}
state.update({key: value})
with self._db.transaction() as cursor:
cursor.execute(
f"""
INSERT INTO client_state (id, data)
VALUES ({self._default_row_id}, ?)
ON CONFLICT(id) DO UPDATE
SET data = excluded.data;
""",
(json.dumps(state),),
)
def get(self) -> dict[str, JsonValue] | None:
with self._db.transaction() as cursor:
cursor.execute(
f"""
SELECT data FROM client_state
WHERE id = {self._default_row_id}
"""
)
row = cursor.fetchone()
if row is None:
return None
return json.loads(row[0])
def get_by_key(self, key: str) -> JsonValue | None:
state = self.get()
if state is None:
return None
return state.get(key, None)
def delete(self) -> None:
with self._db.transaction() as cursor:
cursor.execute(
f"""
DELETE FROM client_state
WHERE id = {self._default_row_id}
"""
)

View File

@@ -17,6 +17,7 @@ if TYPE_CHECKING:
from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
from invokeai.app.services.boards.boards_base import BoardServiceABC
from invokeai.app.services.bulk_download.bulk_download_base import BulkDownloadBase
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
@@ -73,6 +74,7 @@ class InvocationServices:
style_preset_records: "StylePresetRecordsStorageBase",
style_preset_image_files: "StylePresetImageFileStorageBase",
workflow_thumbnails: "WorkflowThumbnailServiceBase",
client_state_persistence: "ClientStatePersistenceABC",
):
self.board_images = board_images
self.board_image_records = board_image_records
@@ -102,3 +104,4 @@ class InvocationServices:
self.style_preset_records = style_preset_records
self.style_preset_image_files = style_preset_image_files
self.workflow_thumbnails = workflow_thumbnails
self.client_state_persistence = client_state_persistence

View File

@@ -23,6 +23,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_17 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_18 import build_migration_18
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_19 import build_migration_19
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_20 import build_migration_20
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_21 import build_migration_21
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@@ -63,6 +64,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_18())
migrator.register_migration(build_migration_19(app_config=config))
migrator.register_migration(build_migration_20())
migrator.register_migration(build_migration_21())
migrator.run_migrations()
return db

View File

@@ -0,0 +1,40 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration21Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
cursor.execute(
"""
CREATE TABLE client_state (
id INTEGER PRIMARY KEY CHECK(id = 1),
data TEXT NOT NULL, -- Frontend will handle the shape of this data
updated_at DATETIME NOT NULL DEFAULT (CURRENT_TIMESTAMP)
);
"""
)
cursor.execute(
"""
CREATE TRIGGER tg_client_state_updated_at
AFTER UPDATE ON client_state
FOR EACH ROW
BEGIN
UPDATE client_state
SET updated_at = CURRENT_TIMESTAMP
WHERE id = OLD.id;
END;
"""
)
def build_migration_21() -> Migration:
"""Builds the migration object for migrating from version 20 to version 21. This includes:
- Creating the `client_state` table.
- Adding a trigger to update the `updated_at` field on updates.
"""
return Migration(
from_version=20,
to_version=21,
callback=Migration21Callback(),
)

View File

@@ -1,9 +1,12 @@
import { StorageError } from 'app/store/enhancers/reduxRemember/errors';
import { $authToken } from 'app/store/nanostores/authToken';
import { $projectId } from 'app/store/nanostores/projectId';
import type { UseStore } from 'idb-keyval';
import { clear, createStore as createIDBKeyValStore, get, set } from 'idb-keyval';
import { atom } from 'nanostores';
import type { Driver } from 'redux-remember';
import { getBaseUrl } from 'services/api';
import { buildAppInfoUrl } from 'services/api/endpoints/appInfo';
// Create a custom idb-keyval store (just needed to customize the name)
const $idbKeyValStore = atom<UseStore>(createIDBKeyValStore('invoke', 'invoke-store'));
@@ -38,3 +41,73 @@ export const idbKeyValDriver: Driver = {
}
},
};
const getHeaders = (extra?: Record<string, string>) => {
const headers = new Headers();
const authToken = $authToken.get();
if (authToken) {
headers.set('Authorization', `Bearer ${authToken}`);
}
const projectId = $projectId.get();
if (projectId) {
headers.set('project-id', projectId);
}
for (const [key, value] of Object.entries(extra ?? {})) {
headers.set(key, value);
}
return headers;
};
export const serverBackedDriver: Driver = {
getItem: async (key) => {
try {
const baseUrl = getBaseUrl();
const path = buildAppInfoUrl('client_state', { key });
const url = `${baseUrl}/${path}`;
const headers = getHeaders();
const res = await fetch(url, { headers, method: 'GET' });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
const json = await res.json();
return json;
} catch (originalError) {
throw new StorageError({
key,
projectId: $projectId.get(),
originalError,
});
}
},
setItem: async (key, value) => {
try {
const baseUrl = getBaseUrl();
const path = buildAppInfoUrl('client_state');
const url = `${baseUrl}/${path}`;
const headers = getHeaders({ 'content-type': 'application/json' });
const res = await fetch(url, { headers, method: 'POST', body: JSON.stringify({ key, value }) });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
return value;
} catch (originalError) {
throw new StorageError({
key,
value,
projectId: $projectId.get(),
originalError,
});
}
},
};
export const resetClientState = async () => {
const baseUrl = getBaseUrl();
const path = buildAppInfoUrl('client_state');
const url = `${baseUrl}/${path}`;
const headers = getHeaders();
const res = await fetch(url, { headers, method: 'DELETE' });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
};

View File

@@ -1,73 +0,0 @@
import type { TypedStartListening } from '@reduxjs/toolkit';
import { addListener, createListenerMiddleware } from '@reduxjs/toolkit';
import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued';
import { addAppConfigReceivedListener } from 'app/store/middleware/listenerMiddleware/listeners/appConfigReceived';
import { addAppStartedListener } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/batchEnqueued';
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard';
import { addImageUploadedFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageUploaded';
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
import type { AppDispatch, RootState } from 'app/store/store';
import { addArchivedOrDeletedBoardListener } from './listeners/addArchivedOrDeletedBoardListener';
export const listenerMiddleware = createListenerMiddleware();
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
const startAppListening = listenerMiddleware.startListening as AppStartListening;
export const addAppListener = addListener.withTypes<RootState, AppDispatch>();
/**
* The RTK listener middleware is a lightweight alternative sagas/observables.
*
* Most side effect logic should live in a listener.
*/
// Image uploaded
addImageUploadedFulfilledListener(startAppListening);
// Image deleted
addDeleteBoardAndImagesFulfilledListener(startAppListening);
// User Invoked
addAnyEnqueuedListener(startAppListening);
addBatchEnqueuedListener(startAppListening);
// Socket.IO
addSocketConnectedEventListener(startAppListening);
// Gallery bulk download
addBulkDownloadListeners(startAppListening);
// Boards
addImageAddedToBoardFulfilledListener(startAppListening);
addImageRemovedFromBoardFulfilledListener(startAppListening);
addBoardIdSelectedListener(startAppListening);
addArchivedOrDeletedBoardListener(startAppListening);
// Node schemas
addGetOpenAPISchemaListener(startAppListening);
// Models
addModelSelectedListener(startAppListening);
// app startup
addAppStartedListener(startAppListening);
addModelsLoadedListener(startAppListening);
addAppConfigReceivedListener(startAppListening);
// Ad-hoc upscale workflwo
addAdHocPostProcessingRequestedListener(startAppListening);
addSetDefaultSettingsListener(startAppListening);

View File

@@ -1,6 +1,6 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { buildAdHocPostProcessingGraph } from 'features/nodes/util/graph/buildAdHocPostProcessingGraph';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';

View File

@@ -1,5 +1,5 @@
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import {
autoAddBoardIdChanged,

View File

@@ -1,4 +1,4 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { queueApi, selectQueueStatus } from 'services/api/endpoints/queue';
export const addAnyEnqueuedListener = (startAppListening: AppStartListening) => {

View File

@@ -1,4 +1,4 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { setInfillMethod } from 'features/controlLayers/store/paramsSlice';
import { shouldUseNSFWCheckerChanged, shouldUseWatermarkerChanged } from 'features/system/store/systemSlice';
import { appInfoApi } from 'services/api/endpoints/appInfo';

View File

@@ -1,5 +1,5 @@
import { createAction } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { truncate } from 'es-toolkit/compat';
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
import { toast } from 'features/toast/toast';

View File

@@ -1,4 +1,4 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { getImageUsage } from 'features/deleteImageModal/store/state';

View File

@@ -1,5 +1,5 @@
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { selectGetImageNamesQueryArgs, selectSelectedBoardId } from 'features/gallery/store/gallerySelectors';
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { parseify } from 'common/util/serialize';
import { size } from 'es-toolkit/compat';
import { $templates } from 'features/nodes/store/nodesSlice';

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { imagesApi } from 'services/api/endpoints/images';
const log = logger('gallery');

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { imagesApi } from 'services/api/endpoints/images';
const log = logger('gallery');

View File

@@ -1,7 +1,6 @@
import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { RootState } from 'app/store/store';
import type { AppStartListening, RootState } from 'app/store/store';
import { omit } from 'es-toolkit/compat';
import { imageUploadedClientSide } from 'features/gallery/store/actions';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { bboxSyncedToOptimalDimension, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';

View File

@@ -1,6 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch, RootState } from 'app/store/store';
import type { AppDispatch, AppStartListening, RootState } from 'app/store/store';
import { controlLayerModelChanged, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import {

View File

@@ -1,4 +1,4 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { isNil } from 'es-toolkit';
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasSlice';
import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';

View File

@@ -1,8 +1,8 @@
import { objectEquals } from '@observ33r/object-equals';
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import type { AppStartListening } from 'app/store/store';
import { atom } from 'nanostores';
import { api } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models';

View File

@@ -1,8 +1,23 @@
import type { ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
import { autoBatchEnhancer, combineReducers, configureStore } from '@reduxjs/toolkit';
import type { ThunkDispatch, TypedStartListening, UnknownAction } from '@reduxjs/toolkit';
import { addListener, combineReducers, configureStore, createAction, createListenerMiddleware } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import { idbKeyValDriver } from 'app/store/enhancers/reduxRemember/driver';
import { serverBackedDriver } from 'app/store/enhancers/reduxRemember/driver';
import { errorHandler } from 'app/store/enhancers/reduxRemember/errors';
import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued';
import { addAppConfigReceivedListener } from 'app/store/middleware/listenerMiddleware/listeners/appConfigReceived';
import { addAppStartedListener } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/batchEnqueued';
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard';
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
import { deepClone } from 'common/util/deepClone';
import { keys, mergeWith, omit, pick } from 'es-toolkit/compat';
import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice';
@@ -28,20 +43,24 @@ import { configSlice } from 'features/system/store/configSlice';
import { systemPersistConfig, systemSlice } from 'features/system/store/systemSlice';
import { uiPersistConfig, uiSlice } from 'features/ui/store/uiSlice';
import { diff } from 'jsondiffpatch';
import { atom } from 'nanostores';
import dynamicMiddlewares from 'redux-dynamic-middlewares';
import type { SerializeFunction, UnserializeFunction } from 'redux-remember';
import { rememberEnhancer, rememberReducer } from 'redux-remember';
import { REMEMBER_PERSISTED, rememberEnhancer, rememberReducer } from 'redux-remember';
import undoable, { newHistory } from 'redux-undo';
import { serializeError } from 'serialize-error';
import { api } from 'services/api';
import { authToastMiddleware } from 'services/api/authToastMiddleware';
import type { JsonObject } from 'type-fest';
import { STORAGE_PREFIX } from './constants';
import { getDebugLoggerMiddleware } from './middleware/debugLoggerMiddleware';
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
import { listenerMiddleware } from './middleware/listenerMiddleware';
import { addArchivedOrDeletedBoardListener } from './middleware/listenerMiddleware/listeners/addArchivedOrDeletedBoardListener';
import { addImageUploadedFulfilledListener } from './middleware/listenerMiddleware/listeners/imageUploaded';
export const listenerMiddleware = createListenerMiddleware();
const log = logger('system');
@@ -94,7 +113,7 @@ export type PersistConfig<T = any> = {
persistDenylist: (keyof T)[];
};
const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
export const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[galleryPersistConfig.name]: galleryPersistConfig,
[nodesPersistConfig.name]: nodesPersistConfig,
[systemPersistConfig.name]: systemPersistConfig,
@@ -113,6 +132,8 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[refImagesSlice.name]: refImagesPersistConfig,
};
export const $isPendingPersist = atom(false);
const unserialize: UnserializeFunction = (data, key) => {
const persistConfig = persistConfigs[key as keyof typeof persistConfigs];
if (!persistConfig) {
@@ -176,28 +197,32 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
reducer: rememberedRootReducer,
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
serializableCheck: import.meta.env.MODE === 'development',
immutableCheck: import.meta.env.MODE === 'development',
serializableCheck: false,
immutableCheck: false,
// serializableCheck: import.meta.env.MODE === 'development',
// immutableCheck: import.meta.env.MODE === 'development',
})
.concat(api.middleware)
.concat(dynamicMiddlewares)
.concat(authToastMiddleware)
// .concat(getDebugLoggerMiddleware())
.concat(getDebugLoggerMiddleware())
.prepend(listenerMiddleware.middleware),
enhancers: (getDefaultEnhancers) => {
const _enhancers = getDefaultEnhancers().concat(autoBatchEnhancer());
const enhancers = getDefaultEnhancers();
if (persist) {
_enhancers.push(
rememberEnhancer(idbKeyValDriver, keys(persistConfigs), {
persistDebounce: 300,
const res = enhancers.prepend(
rememberEnhancer(serverBackedDriver, keys(persistConfigs), {
persistDebounce: 3000,
serialize,
unserialize,
prefix: uniqueStoreKey ? `${STORAGE_PREFIX}${uniqueStoreKey}-` : STORAGE_PREFIX,
prefix: '',
errorHandler,
})
);
return res;
} else {
return enhancers;
}
return _enhancers;
},
devTools: {
actionSanitizer,
@@ -218,3 +243,76 @@ export type RootState = ReturnType<AppStore['getState']>;
export type AppThunkDispatch = ThunkDispatch<RootState, any, UnknownAction>;
export type AppDispatch = ReturnType<typeof createStore>['dispatch'];
export type AppGetState = ReturnType<typeof createStore>['getState'];
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
export const addAppListener = addListener.withTypes<RootState, AppDispatch>();
const startAppListening = listenerMiddleware.startListening as AppStartListening;
addImageUploadedFulfilledListener(startAppListening);
// Image deleted
addDeleteBoardAndImagesFulfilledListener(startAppListening);
// User Invoked
addAnyEnqueuedListener(startAppListening);
addBatchEnqueuedListener(startAppListening);
// Socket.IO
addSocketConnectedEventListener(startAppListening);
// Gallery bulk download
addBulkDownloadListeners(startAppListening);
// Boards
addImageAddedToBoardFulfilledListener(startAppListening);
addImageRemovedFromBoardFulfilledListener(startAppListening);
addBoardIdSelectedListener(startAppListening);
addArchivedOrDeletedBoardListener(startAppListening);
// Node schemas
addGetOpenAPISchemaListener(startAppListening);
// Models
addModelSelectedListener(startAppListening);
// app startup
addAppStartedListener(startAppListening);
addModelsLoadedListener(startAppListening);
addAppConfigReceivedListener(startAppListening);
// Ad-hoc upscale workflwo
addAdHocPostProcessingRequestedListener(startAppListening);
addSetDefaultSettingsListener(startAppListening);
const addPersistenceListener = (startAppListening: AppStartListening) => {
startAppListening({
predicate: (action, currentRootState, originalRootState) => {
for (const { name, persistDenylist } of Object.values(persistConfigs)) {
const originalState = originalRootState[name];
const currentState = currentRootState[name];
for (const [k, v] of Object.entries(currentState)) {
if (persistDenylist.includes(k)) {
continue;
}
if (v !== originalState[k as keyof typeof originalState]) {
return true;
}
}
}
return false;
},
effect: () => {
$isPendingPersist.set(true);
},
});
startAppListening({
matcher: createAction(REMEMBER_PERSISTED).match,
effect: () => {
$isPendingPersist.set(false);
},
});
};
addPersistenceListener(startAppListening);

View File

@@ -1,9 +1,10 @@
import { clearIdbKeyValStore } from 'app/store/enhancers/reduxRemember/driver';
import { resetClientState } from 'app/store/enhancers/reduxRemember/driver';
import { useCallback } from 'react';
export const useClearStorage = () => {
const clearStorage = useCallback(() => {
clearIdbKeyValStore();
// clearIdbKeyValStore();
resetClientState();
localStorage.clear();
}, []);

View File

@@ -1,7 +1,7 @@
import { $alt, $ctrl, $meta, $shift } from '@invoke-ai/ui-library';
import type { Selector } from '@reduxjs/toolkit';
import { addAppListener } from 'app/store/middleware/listenerMiddleware';
import type { AppStore, RootState } from 'app/store/store';
import { addAppListener } from 'app/store/store';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';

View File

@@ -77,7 +77,7 @@ export const dynamicPromptsPersistConfig: PersistConfig<DynamicPromptsState> = {
name: dynamicPromptsSlice.name,
initialState: initialDynamicPromptsState,
migrate: migrateDynamicPromptsState,
persistDenylist: ['prompts'],
persistDenylist: ['prompts', 'parsingError', 'isError', 'isLoading'],
};
export const selectDynamicPromptsSlice = (state: RootState) => state.dynamicPrompts;

View File

@@ -1,6 +1,7 @@
import { Flex, Spacer } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { $customNavComponent } from 'app/store/nanostores/customNavComponent';
import { $isPendingPersist } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import InvokeAILogoComponent from 'features/system/components/InvokeAILogoComponent';
import SettingsMenu from 'features/system/components/SettingsModal/SettingsMenu';
@@ -37,6 +38,7 @@ export const VerticalNavBar = memo(() => {
const withWorkflowsTab = useAppSelector(selectWithWorkflowsTab);
const withModelsTab = useAppSelector(selectWithModelsTab);
const withQueueTab = useAppSelector(selectWithQueueTab);
const isPendingPersist = useStore($isPendingPersist);
return (
<Flex flexDir="column" alignItems="center" py={6} ps={4} pe={2} gap={4} minW={0} flexShrink={0}>
@@ -48,6 +50,7 @@ export const VerticalNavBar = memo(() => {
{withWorkflowsTab && <TabButton tab="workflows" icon={<PiFlowArrowBold />} label={t('ui.tabs.workflows')} />}
{withModelsTab && <TabButton tab="models" icon={<PiCubeBold />} label={t('ui.tabs.models')} />}
{withQueueTab && <TabButton tab="queue" icon={<PiQueueBold />} label={t('ui.tabs.queue')} />}
{isPendingPersist && <Flex w={4} h={4} bg="red" />}
</Flex>
<Spacer />
<StatusIndicator />

View File

@@ -1,5 +1,6 @@
import { $openAPISchemaUrl } from 'app/store/nanostores/openAPISchemaUrl';
import type { OpenAPIV3_1 } from 'openapi-types';
import type { stringify } from 'querystring';
import type { paths } from 'services/api/schema';
import type { AppConfig, AppVersion } from 'services/api/types';
@@ -11,7 +12,8 @@ import { api, buildV1Url } from '..';
* buildAppInfoUrl('some-path')
* // '/api/v1/app/some-path'
*/
const buildAppInfoUrl = (path: string = '') => buildV1Url(`app/${path}`);
export const buildAppInfoUrl = (path: string = '', query?: Parameters<typeof stringify>[0]) =>
buildV1Url(`app/${path}`, query);
export const appInfoApi = api.injectEndpoints({
endpoints: (build) => ({
@@ -87,6 +89,31 @@ export const appInfoApi = api.injectEndpoints({
},
providesTags: ['Schema'],
}),
getClientStateByKey: build.query<
paths['/api/v1/app/client_state']['get']['responses']['200']['content']['application/json'],
paths['/api/v1/app/client_state']['get']['parameters']['query']
>({
query: () => ({
url: buildAppInfoUrl('client_state'),
method: 'GET',
}),
}),
setClientStateByKey: build.mutation<
paths['/api/v1/app/client_state']['post']['responses']['200']['content']['application/json'],
paths['/api/v1/app/client_state']['post']['requestBody']['content']['application/json']
>({
query: (body) => ({
url: buildAppInfoUrl('client_state'),
method: 'POST',
body,
}),
}),
deleteClientState: build.mutation<void, void>({
query: () => ({
url: buildAppInfoUrl('client_state'),
method: 'DELETE',
}),
}),
}),
});

View File

@@ -57,13 +57,18 @@ const tagTypes = [
// This is invalidated on reconnect. It should be used for queries that have changing data,
// especially related to the queue and generation.
'FetchOnReconnect',
'ClientState',
] as const;
export type ApiTagDescription = TagDescription<(typeof tagTypes)[number]>;
export const LIST_TAG = 'LIST';
export const LIST_ALL_TAG = 'LIST_ALL';
const dynamicBaseQuery: BaseQueryFn<string | FetchArgs, unknown, FetchBaseQueryError> = (args, api, extraOptions) => {
export const getBaseUrl = (): string => {
const baseUrl = $baseUrl.get();
return baseUrl || window.location.href.replace(/\/$/, '');
};
const dynamicBaseQuery: BaseQueryFn<string | FetchArgs, unknown, FetchBaseQueryError> = (args, api, extraOptions) => {
const authToken = $authToken.get();
const projectId = $projectId.get();
const isOpenAPIRequest =
@@ -71,7 +76,7 @@ const dynamicBaseQuery: BaseQueryFn<string | FetchArgs, unknown, FetchBaseQueryE
(typeof args === 'string' && args.includes('openapi.json'));
const fetchBaseQueryArgs: FetchBaseQueryArgs = {
baseUrl: baseUrl || window.location.href.replace(/\/$/, ''),
baseUrl: getBaseUrl(),
};
// When fetching the openapi.json, we need to remove circular references from the JSON.

View File

@@ -1164,6 +1164,34 @@ export type paths = {
patch?: never;
trace?: never;
};
"/api/v1/app/client_state": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
/**
* Get Client State By Key
* @description Gets the client state
*/
get: operations["get_client_state_by_key"];
put?: never;
/**
* Set Client State
* @description Sets the client state
*/
post: operations["set_client_state"];
/**
* Delete Client State
* @description Deletes the client state
*/
delete: operations["delete_client_state"];
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/v1/queue/{queue_id}/enqueue_batch": {
parameters: {
query?: never;
@@ -2726,6 +2754,16 @@ export type components = {
*/
image_names: string[];
};
/** Body_set_client_state */
Body_set_client_state: {
/**
* Key
* @description Key to set
*/
key: string;
/** @description Value of the key */
value: components["schemas"]["JsonValue"];
};
/** Body_set_workflow_thumbnail */
Body_set_workflow_thumbnail: {
/**
@@ -24697,6 +24735,98 @@ export interface operations {
};
};
};
get_client_state_by_key: {
parameters: {
query: {
/** @description Key to retrieve from client state persistence */
key: string;
};
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["JsonValue"] | null;
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
set_client_state: {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
requestBody: {
content: {
"application/json": components["schemas"]["Body_set_client_state"];
};
};
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": unknown;
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
delete_client_state: {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": unknown;
};
};
/** @description Client state deleted */
204: {
headers: {
[name: string]: unknown;
};
content?: never;
};
};
};
enqueue_batch: {
parameters: {
query?: never;

View File

@@ -1,12 +1,12 @@
import { ExternalLink } from '@invoke-ai/ui-library';
import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import { listenerMiddleware } from 'app/store/middleware/listenerMiddleware';
import { socketConnected } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId';
import { $queueId } from 'app/store/nanostores/queueId';
import type { AppStore } from 'app/store/store';
import { listenerMiddleware } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { forEach, isNil, round } from 'es-toolkit/compat';
import {