feat(app): client state persistence endpoints accept stringified data

This commit is contained in:
psychedelicious
2025-07-29 19:46:50 +10:00
parent 4226b741b1
commit ea26b5b147
3 changed files with 41 additions and 34 deletions

View File

@@ -7,7 +7,7 @@ from typing import Optional
import torch
from fastapi import Body, HTTPException, Query
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field, JsonValue
from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.invocations.upscale import ESRGAN_MODELS
@@ -178,11 +178,11 @@ async def get_invocation_cache_status() -> InvocationCacheStatus:
@app_router.get(
"/client_state",
operation_id="get_client_state_by_key",
response_model=JsonValue | None,
response_model=str | None,
)
async def get_client_state_by_key(
key: str = Query(..., description="Key to get"),
) -> JsonValue | None:
) -> str | None:
"""Gets the client state"""
try:
return ApiDependencies.invoker.services.client_state_persistence.get_by_key(key)
@@ -194,15 +194,15 @@ async def get_client_state_by_key(
@app_router.post(
"/client_state",
operation_id="set_client_state",
response_model=None,
response_model=str,
)
async def set_client_state(
key: str = Query(..., description="Key to set"),
value: JsonValue = Body(..., description="Value of the key"),
) -> None:
value: str = Body(..., description="Stringified value to set"),
) -> str:
"""Sets the client state"""
try:
ApiDependencies.invoker.services.client_state_persistence.set_by_key(key, value)
return ApiDependencies.invoker.services.client_state_persistence.set_by_key(key, value)
except Exception as e:
logging.error(f"Error setting client state: {e}")
raise HTTPException(status_code=500, detail="Error setting client state")

View File

@@ -1,7 +1,5 @@
from abc import ABC, abstractmethod
from pydantic import JsonValue
class ClientStatePersistenceABC(ABC):
"""
@@ -10,26 +8,35 @@ class ClientStatePersistenceABC(ABC):
"""
@abstractmethod
def set_by_key(self, key: str, value: JsonValue) -> None:
def set_by_key(self, key: str, value: str) -> str:
"""
Store the data for the client.
Set a key-value pair for the client.
:param data: The client data to be stored.
Args:
key (str): The key to set.
value (str): The value to set for the key.
Returns:
str: The value that was set.
"""
pass
@abstractmethod
def get_by_key(self, key: str) -> JsonValue | None:
def get_by_key(self, key: str) -> str | None:
"""
Get the data for the client.
Get the value for a specific key of the client.
:return: The client data.
Args:
key (str): The key to retrieve the value for.
Returns:
str | None: The value associated with the key, or None if the key does not exist.
"""
pass
@abstractmethod
def delete(self) -> None:
"""
Delete the data for the client.
Delete all client state.
"""
pass

View File

@@ -1,7 +1,5 @@
import json
from pydantic import JsonValue
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
@@ -21,8 +19,21 @@ class ClientStatePersistenceSqlite(ClientStatePersistenceABC):
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
def set_by_key(self, key: str, value: JsonValue) -> None:
state = self.get() or {}
def _get(self) -> dict[str, str] | None:
with self._db.transaction() as cursor:
cursor.execute(
f"""
SELECT data FROM client_state
WHERE id = {self._default_row_id}
"""
)
row = cursor.fetchone()
if row is None:
return None
return json.loads(row[0])
def set_by_key(self, key: str, value: str) -> str:
state = self._get() or {}
state.update({key: value})
with self._db.transaction() as cursor:
@@ -36,21 +47,10 @@ class ClientStatePersistenceSqlite(ClientStatePersistenceABC):
(json.dumps(state),),
)
def get(self) -> dict[str, JsonValue] | None:
with self._db.transaction() as cursor:
cursor.execute(
f"""
SELECT data FROM client_state
WHERE id = {self._default_row_id}
"""
)
row = cursor.fetchone()
if row is None:
return None
return json.loads(row[0])
return value
def get_by_key(self, key: str) -> JsonValue | None:
state = self.get()
def get_by_key(self, key: str) -> str | None:
state = self._get()
if state is None:
return None
return state.get(key, None)