From ea26b5b1471fca6925c1a3af97383b709903338f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 29 Jul 2025 19:46:50 +1000 Subject: [PATCH] feat(app): client state persistence endpoints accept stringified data --- invokeai/app/api/routers/app_info.py | 14 ++++---- .../client_state_persistence_base.py | 25 ++++++++----- .../client_state_persistence_sqlite.py | 36 +++++++++---------- 3 files changed, 41 insertions(+), 34 deletions(-) diff --git a/invokeai/app/api/routers/app_info.py b/invokeai/app/api/routers/app_info.py index fb949b74db..15cbe2700d 100644 --- a/invokeai/app/api/routers/app_info.py +++ b/invokeai/app/api/routers/app_info.py @@ -7,7 +7,7 @@ from typing import Optional import torch from fastapi import Body, HTTPException, Query from fastapi.routing import APIRouter -from pydantic import BaseModel, Field, JsonValue +from pydantic import BaseModel, Field from invokeai.app.api.dependencies import ApiDependencies from invokeai.app.invocations.upscale import ESRGAN_MODELS @@ -178,11 +178,11 @@ async def get_invocation_cache_status() -> InvocationCacheStatus: @app_router.get( "/client_state", operation_id="get_client_state_by_key", - response_model=JsonValue | None, + response_model=str | None, ) async def get_client_state_by_key( key: str = Query(..., description="Key to get"), -) -> JsonValue | None: +) -> str | None: """Gets the client state""" try: return ApiDependencies.invoker.services.client_state_persistence.get_by_key(key) @@ -194,15 +194,15 @@ async def get_client_state_by_key( @app_router.post( "/client_state", operation_id="set_client_state", - response_model=None, + response_model=str, ) async def set_client_state( key: str = Query(..., description="Key to set"), - value: JsonValue = Body(..., description="Value of the key"), -) -> None: + value: str = Body(..., description="Stringified value to set"), +) -> str: """Sets the client state""" try: - ApiDependencies.invoker.services.client_state_persistence.set_by_key(key, value) + return ApiDependencies.invoker.services.client_state_persistence.set_by_key(key, value) except Exception as e: logging.error(f"Error setting client state: {e}") raise HTTPException(status_code=500, detail="Error setting client state") diff --git a/invokeai/app/services/client_state_persistence/client_state_persistence_base.py b/invokeai/app/services/client_state_persistence/client_state_persistence_base.py index 9f42210bd2..5b62470bc8 100644 --- a/invokeai/app/services/client_state_persistence/client_state_persistence_base.py +++ b/invokeai/app/services/client_state_persistence/client_state_persistence_base.py @@ -1,7 +1,5 @@ from abc import ABC, abstractmethod -from pydantic import JsonValue - class ClientStatePersistenceABC(ABC): """ @@ -10,26 +8,35 @@ class ClientStatePersistenceABC(ABC): """ @abstractmethod - def set_by_key(self, key: str, value: JsonValue) -> None: + def set_by_key(self, key: str, value: str) -> str: """ - Store the data for the client. + Set a key-value pair for the client. - :param data: The client data to be stored. + Args: + key (str): The key to set. + value (str): The value to set for the key. + + Returns: + str: The value that was set. """ pass @abstractmethod - def get_by_key(self, key: str) -> JsonValue | None: + def get_by_key(self, key: str) -> str | None: """ - Get the data for the client. + Get the value for a specific key of the client. - :return: The client data. + Args: + key (str): The key to retrieve the value for. + + Returns: + str | None: The value associated with the key, or None if the key does not exist. """ pass @abstractmethod def delete(self) -> None: """ - Delete the data for the client. + Delete all client state. """ pass diff --git a/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py b/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py index 263fd02647..a5e0a0bbff 100644 --- a/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py +++ b/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py @@ -1,7 +1,5 @@ import json -from pydantic import JsonValue - from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC from invokeai.app.services.invoker import Invoker from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase @@ -21,8 +19,21 @@ class ClientStatePersistenceSqlite(ClientStatePersistenceABC): def start(self, invoker: Invoker) -> None: self._invoker = invoker - def set_by_key(self, key: str, value: JsonValue) -> None: - state = self.get() or {} + def _get(self) -> dict[str, str] | None: + with self._db.transaction() as cursor: + cursor.execute( + f""" + SELECT data FROM client_state + WHERE id = {self._default_row_id} + """ + ) + row = cursor.fetchone() + if row is None: + return None + return json.loads(row[0]) + + def set_by_key(self, key: str, value: str) -> str: + state = self._get() or {} state.update({key: value}) with self._db.transaction() as cursor: @@ -36,21 +47,10 @@ class ClientStatePersistenceSqlite(ClientStatePersistenceABC): (json.dumps(state),), ) - def get(self) -> dict[str, JsonValue] | None: - with self._db.transaction() as cursor: - cursor.execute( - f""" - SELECT data FROM client_state - WHERE id = {self._default_row_id} - """ - ) - row = cursor.fetchone() - if row is None: - return None - return json.loads(row[0]) + return value - def get_by_key(self, key: str) -> JsonValue | None: - state = self.get() + def get_by_key(self, key: str) -> str | None: + state = self._get() if state is None: return None return state.get(key, None)