mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-13 05:05:21 -05:00
feat(app): client state persistence endpoints accept stringified data
This commit is contained in:
@@ -7,7 +7,7 @@ from typing import Optional
|
||||
import torch
|
||||
from fastapi import Body, HTTPException, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, Field, JsonValue
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.invocations.upscale import ESRGAN_MODELS
|
||||
@@ -178,11 +178,11 @@ async def get_invocation_cache_status() -> InvocationCacheStatus:
|
||||
@app_router.get(
|
||||
"/client_state",
|
||||
operation_id="get_client_state_by_key",
|
||||
response_model=JsonValue | None,
|
||||
response_model=str | None,
|
||||
)
|
||||
async def get_client_state_by_key(
|
||||
key: str = Query(..., description="Key to get"),
|
||||
) -> JsonValue | None:
|
||||
) -> str | None:
|
||||
"""Gets the client state"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.client_state_persistence.get_by_key(key)
|
||||
@@ -194,15 +194,15 @@ async def get_client_state_by_key(
|
||||
@app_router.post(
|
||||
"/client_state",
|
||||
operation_id="set_client_state",
|
||||
response_model=None,
|
||||
response_model=str,
|
||||
)
|
||||
async def set_client_state(
|
||||
key: str = Query(..., description="Key to set"),
|
||||
value: JsonValue = Body(..., description="Value of the key"),
|
||||
) -> None:
|
||||
value: str = Body(..., description="Stringified value to set"),
|
||||
) -> str:
|
||||
"""Sets the client state"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.client_state_persistence.set_by_key(key, value)
|
||||
return ApiDependencies.invoker.services.client_state_persistence.set_by_key(key, value)
|
||||
except Exception as e:
|
||||
logging.error(f"Error setting client state: {e}")
|
||||
raise HTTPException(status_code=500, detail="Error setting client state")
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
|
||||
class ClientStatePersistenceABC(ABC):
|
||||
"""
|
||||
@@ -10,26 +8,35 @@ class ClientStatePersistenceABC(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_by_key(self, key: str, value: JsonValue) -> None:
|
||||
def set_by_key(self, key: str, value: str) -> str:
|
||||
"""
|
||||
Store the data for the client.
|
||||
Set a key-value pair for the client.
|
||||
|
||||
:param data: The client data to be stored.
|
||||
Args:
|
||||
key (str): The key to set.
|
||||
value (str): The value to set for the key.
|
||||
|
||||
Returns:
|
||||
str: The value that was set.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_by_key(self, key: str) -> JsonValue | None:
|
||||
def get_by_key(self, key: str) -> str | None:
|
||||
"""
|
||||
Get the data for the client.
|
||||
Get the value for a specific key of the client.
|
||||
|
||||
:return: The client data.
|
||||
Args:
|
||||
key (str): The key to retrieve the value for.
|
||||
|
||||
Returns:
|
||||
str | None: The value associated with the key, or None if the key does not exist.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self) -> None:
|
||||
"""
|
||||
Delete the data for the client.
|
||||
Delete all client state.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import json
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
@@ -21,8 +19,21 @@ class ClientStatePersistenceSqlite(ClientStatePersistenceABC):
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
|
||||
def set_by_key(self, key: str, value: JsonValue) -> None:
|
||||
state = self.get() or {}
|
||||
def _get(self) -> dict[str, str] | None:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT data FROM client_state
|
||||
WHERE id = {self._default_row_id}
|
||||
"""
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return json.loads(row[0])
|
||||
|
||||
def set_by_key(self, key: str, value: str) -> str:
|
||||
state = self._get() or {}
|
||||
state.update({key: value})
|
||||
|
||||
with self._db.transaction() as cursor:
|
||||
@@ -36,21 +47,10 @@ class ClientStatePersistenceSqlite(ClientStatePersistenceABC):
|
||||
(json.dumps(state),),
|
||||
)
|
||||
|
||||
def get(self) -> dict[str, JsonValue] | None:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT data FROM client_state
|
||||
WHERE id = {self._default_row_id}
|
||||
"""
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return json.loads(row[0])
|
||||
return value
|
||||
|
||||
def get_by_key(self, key: str) -> JsonValue | None:
|
||||
state = self.get()
|
||||
def get_by_key(self, key: str) -> str | None:
|
||||
state = self._get()
|
||||
if state is None:
|
||||
return None
|
||||
return state.get(key, None)
|
||||
|
||||
Reference in New Issue
Block a user