feat(backend): implement KV data storage blocks (#10294)

This PR introduces key-value storage blocks that let an agent persist and retrieve data for the current user across runs.

### Changes 🏗️

- **Database Schema**: Add AgentNodeExecutionKeyValueData table with composite primary key (userId, key)
- **Persistence Blocks**: Create PersistInformationBlock and RetrieveInformationBlock in persistence.py
- **Scope-based Storage**: Support for within_agent (per agent) vs across_agents (global user) persistence
- **Key Structure**: Use a `#` delimiter for storage keys: `agent#{graph_id}#{key}` and `global#{key}` (see the sketch below)
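
For illustration, a minimal sketch of the key scheme, mirroring the `get_storage_key` helper introduced below (the example key and graph id are hypothetical):

```python
def storage_key(key: str, scope: str, graph_id: str) -> str:
    """Build the storage key; unrecognized scopes fall back to within_agent."""
    if scope == "across_agents":
        return f"global#{key}"  # shared by all agents of the user
    return f"agent#{graph_id}#{key}"  # within_agent: namespaced per agent graph

assert storage_key("user_preference", "within_agent", "g-123") == "agent#g-123#user_preference"
assert storage_key("user_preference", "across_agents", "g-123") == "global#user_preference"
```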

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Run all 244 block tests - all passing
  - [x] Test PersistInformationBlock with mock data storage
  - [x] Test RetrieveInformationBlock with mock data retrieval
  - [x] Verify scope-based key generation (within_agent vs across_agents)
  - [x] Verify database function integration through all manager classes
  - [x] Run lint and type checking - all passing
  - [x] Verify database migration is included and valid

#### For configuration changes:
- [x] `.env.example` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my changes
- [x] I have included a list of my configuration changes in the PR description (under **Changes**)

Note: This change adds a new database table and blocks but requires no environment or docker-compose changes, since it reuses the existing database infrastructure.

---------

Co-authored-by: Claude <noreply@anthropic.com>
Commit 095199bfa6 (parent 90fb223114) by Zamil Majdy, 2025-07-03 07:24:51 -07:00, committed by GitHub. 7 changed files with 254 additions and 14 deletions.

#### New file: `persistence.py` (+153 lines)
import logging
from typing import Any, Literal

from autogpt_libs.utils.cache import thread_cached

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField

logger = logging.getLogger(__name__)


@thread_cached
def get_database_manager_client():
    from backend.executor import DatabaseManagerAsyncClient
    from backend.util.service import get_service_client

    return get_service_client(DatabaseManagerAsyncClient, health_check=False)


def get_storage_key(key: str, scope: str, graph_id: str) -> str:
    """Generate the storage key based on scope"""
    if scope == "within_agent":
        return f"agent#{graph_id}#{key}"
    elif scope == "across_agents":
        return f"global#{key}"
    else:
        return f"agent#{graph_id}#{key}"


class PersistInformationBlock(Block):
    """Block for persisting key-value data for the current user with configurable scope"""

    class Input(BlockSchema):
        key: str = SchemaField(description="Key to store the information under")
        value: Any = SchemaField(description="Value to store")
        scope: Literal["within_agent", "across_agents"] = SchemaField(
            description="Scope of persistence: within_agent (shared across all runs of this agent) or across_agents (shared across all agents for this user)",
            default="within_agent",
        )

    class Output(BlockSchema):
        value: Any = SchemaField(description="Value that was stored")

    def __init__(self):
        super().__init__(
            id="1d055e55-a2b9-4547-8311-907d05b0304d",
            description="Persist key-value information for the current user",
            categories={BlockCategory.DATA},
            input_schema=PersistInformationBlock.Input,
            output_schema=PersistInformationBlock.Output,
            test_input={
                "key": "user_preference",
                "value": {"theme": "dark", "language": "en"},
                "scope": "within_agent",
            },
            test_output=[
                ("value", {"theme": "dark", "language": "en"}),
            ],
            test_mock={
                "_store_data": lambda *args, **kwargs: {
                    "theme": "dark",
                    "language": "en",
                }
            },
        )

    async def run(
        self,
        input_data: Input,
        *,
        user_id: str,
        graph_id: str,
        node_exec_id: str,
        **kwargs,
    ) -> BlockOutput:
        # Determine the storage key based on scope
        storage_key = get_storage_key(input_data.key, input_data.scope, graph_id)

        # Store the data
        yield "value", await self._store_data(
            user_id=user_id,
            node_exec_id=node_exec_id,
            key=storage_key,
            data=input_data.value,
        )

    async def _store_data(
        self, user_id: str, node_exec_id: str, key: str, data: Any
    ) -> Any | None:
        return await get_database_manager_client().set_execution_kv_data(
            user_id=user_id,
            node_exec_id=node_exec_id,
            key=key,
            data=data,
        )


class RetrieveInformationBlock(Block):
    """Block for retrieving key-value data for the current user with configurable scope"""

    class Input(BlockSchema):
        key: str = SchemaField(description="Key to retrieve the information for")
        scope: Literal["within_agent", "across_agents"] = SchemaField(
            description="Scope of persistence: within_agent (shared across all runs of this agent) or across_agents (shared across all agents for this user)",
            default="within_agent",
        )
        default_value: Any = SchemaField(
            description="Default value to return if key is not found", default=None
        )

    class Output(BlockSchema):
        value: Any = SchemaField(description="Retrieved value or default value")

    def __init__(self):
        super().__init__(
            id="d8710fc9-6e29-481e-a7d5-165eb16f8471",
            description="Retrieve key-value information for the current user",
            categories={BlockCategory.DATA},
            input_schema=RetrieveInformationBlock.Input,
            output_schema=RetrieveInformationBlock.Output,
            test_input={
                "key": "user_preference",
                "scope": "within_agent",
                "default_value": {"theme": "light", "language": "en"},
            },
            test_output=[
                ("value", {"theme": "light", "language": "en"}),
            ],
            test_mock={"_retrieve_data": lambda *args, **kwargs: None},
        )

    async def run(
        self, input_data: Input, *, user_id: str, graph_id: str, **kwargs
    ) -> BlockOutput:
        # Determine the storage key based on scope
        storage_key = get_storage_key(input_data.key, input_data.scope, graph_id)

        # Retrieve the data
        stored_value = await self._retrieve_data(
            user_id=user_id,
            key=storage_key,
        )

        if stored_value is not None:
            yield "value", stored_value
        else:
            yield "value", input_data.default_value

    async def _retrieve_data(self, user_id: str, key: str) -> Any | None:
        return await get_database_manager_client().get_execution_kv_data(
            user_id=user_id,
            key=key,
        )
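
Taken together, the two blocks give an agent a simple read-your-writes store. A hypothetical round-trip through their `run()` generators (illustrative ids only, and assuming a reachable DatabaseManager service):

```python
async def round_trip() -> None:
    persist = PersistInformationBlock()
    retrieve = RetrieveInformationBlock()

    # Store under the within_agent scope (key becomes "agent#g-1#user_preference").
    async for name, value in persist.run(
        PersistInformationBlock.Input(key="user_preference", value={"theme": "dark"}),
        user_id="u-1",
        graph_id="g-1",
        node_exec_id="ne-1",
    ):
        print(name, value)  # ("value", {"theme": "dark"})

    # Read it back; default_value would be yielded if the key were missing.
    async for name, value in retrieve.run(
        RetrieveInformationBlock.Input(key="user_preference"),
        user_id="u-1",
        graph_id="g-1",
    ):
        print(name, value)  # ("value", {"theme": "dark"})
```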

#### SmartDecisionMakerBlock: migrate to the async database client

@@ -26,10 +26,10 @@ logger = logging.getLogger(__name__)

 @thread_cached
 def get_database_manager_client():
-    from backend.executor import DatabaseManagerClient
+    from backend.executor import DatabaseManagerAsyncClient
     from backend.util.service import get_service_client

-    return get_service_client(DatabaseManagerClient)
+    return get_service_client(DatabaseManagerAsyncClient, health_check=False)


 def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
@@ -273,7 +273,7 @@ class SmartDecisionMakerBlock(Block):
         return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()

     @staticmethod
-    def _create_block_function_signature(
+    async def _create_block_function_signature(
         sink_node: "Node", links: list["Link"]
     ) -> dict[str, Any]:
         """
@@ -312,7 +312,7 @@ class SmartDecisionMakerBlock(Block):
         return {"type": "function", "function": tool_function}

     @staticmethod
-    def _create_agent_function_signature(
+    async def _create_agent_function_signature(
         sink_node: "Node", links: list["Link"]
     ) -> dict[str, Any]:
         """
@@ -334,7 +334,7 @@ class SmartDecisionMakerBlock(Block):
             raise ValueError("Graph ID or Graph Version not found in sink node.")

         db_client = get_database_manager_client()
-        sink_graph_meta = db_client.get_graph_metadata(graph_id, graph_version)
+        sink_graph_meta = await db_client.get_graph_metadata(graph_id, graph_version)
         if not sink_graph_meta:
             raise ValueError(
                 f"Sink graph metadata not found: {graph_id} {graph_version}"
@@ -374,7 +374,7 @@ class SmartDecisionMakerBlock(Block):
         return {"type": "function", "function": tool_function}

     @staticmethod
-    def _create_function_signature(node_id: str) -> list[dict[str, Any]]:
+    async def _create_function_signature(node_id: str) -> list[dict[str, Any]]:
         """
         Creates function signatures for tools linked to a specified node within a graph.
@@ -396,13 +396,13 @@ class SmartDecisionMakerBlock(Block):
         db_client = get_database_manager_client()

         tools = [
             (link, node)
-            for link, node in db_client.get_connected_output_nodes(node_id)
+            for link, node in await db_client.get_connected_output_nodes(node_id)
             if link.source_name.startswith("tools_^_") and link.source_id == node_id
         ]
         if not tools:
             raise ValueError("There is no next node to execute.")

-        return_tool_functions = []
+        return_tool_functions: list[dict[str, Any]] = []
         grouped_tool_links: dict[str, tuple["Node", list["Link"]]] = {}

         for link, node in tools:
@@ -417,13 +417,13 @@ class SmartDecisionMakerBlock(Block):
             if sink_node.block_id == AgentExecutorBlock().id:
                 return_tool_functions.append(
-                    SmartDecisionMakerBlock._create_agent_function_signature(
+                    await SmartDecisionMakerBlock._create_agent_function_signature(
                         sink_node, links
                     )
                 )
             else:
                 return_tool_functions.append(
-                    SmartDecisionMakerBlock._create_block_function_signature(
+                    await SmartDecisionMakerBlock._create_block_function_signature(
                         sink_node, links
                     )
                 )
@@ -442,7 +442,7 @@ class SmartDecisionMakerBlock(Block):
         user_id: str,
         **kwargs,
     ) -> BlockOutput:
-        tool_functions = self._create_function_signature(node_id)
+        tool_functions = await self._create_function_signature(node_id)
        yield "tool_functions", json.dumps(tool_functions)

         input_data.conversation_history = input_data.conversation_history or []
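
Since the signature builders are now coroutines (an `async def` under `@staticmethod` is valid Python), every call site, including calls through the class, must be awaited. A minimal illustration of the pattern, with hypothetical names:

```python
import asyncio


class Tools:
    @staticmethod
    async def build(node_id: str) -> dict[str, str]:
        # Stand-in for an awaitable database lookup.
        return {"node": node_id}


async def main() -> None:
    sig = await Tools.build("n-1")  # before the change: Tools.build("n-1")
    print(sig)


asyncio.run(main())
```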

#### Test update: await the now-async signature builder

@@ -161,7 +161,7 @@ async def test_smart_decision_maker_function_signature(server: SpinTestServer):
     )
     test_graph = await create_graph(server, test_graph, test_user)

-    tool_functions = SmartDecisionMakerBlock._create_function_signature(
+    tool_functions = await SmartDecisionMakerBlock._create_function_signature(
         test_graph.nodes[0].id
     )
     assert tool_functions is not None, "Tool functions should not be None"
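
The new blocks themselves are exercised through their built-in `test_input`/`test_mock` harness; a hypothetical standalone test (assuming the module lands at `backend.blocks.persistence` and `pytest-asyncio` is installed) might look like:

```python
import pytest

from backend.blocks.persistence import PersistInformationBlock  # hypothetical path


@pytest.mark.asyncio
async def test_persist_block_yields_stored_value(monkeypatch):
    async def fake_store(self, user_id, node_exec_id, key, data):
        return data  # pretend the database echoed the stored value back

    # Patch out the DB call so no DatabaseManager service is needed.
    monkeypatch.setattr(PersistInformationBlock, "_store_data", fake_store)

    block = PersistInformationBlock()
    outputs = [
        out
        async for out in block.run(
            PersistInformationBlock.Input(key="user_preference", value={"theme": "dark"}),
            user_id="u-1",
            graph_id="g-1",
            node_exec_id="ne-1",
        )
    ]
    assert outputs == [("value", {"theme": "dark"})]
```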

#### `backend/data/execution.py`: KV data functions

@@ -22,6 +22,7 @@ from prisma.models import (
     AgentGraphExecution,
     AgentNodeExecution,
     AgentNodeExecutionInputOutput,
+    AgentNodeExecutionKeyValueData,
 )
 from prisma.types import (
     AgentGraphExecutionCreateInput,
@@ -29,6 +30,7 @@ from prisma.types import (
     AgentGraphExecutionWhereInput,
     AgentNodeExecutionCreateInput,
     AgentNodeExecutionInputOutputCreateInput,
+    AgentNodeExecutionKeyValueDataCreateInput,
     AgentNodeExecutionUpdateInput,
     AgentNodeExecutionWhereInput,
 )
@@ -907,3 +909,57 @@ class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionEvent]):
     ) -> AsyncGenerator[ExecutionEvent, None]:
         async for event in self.listen_events(f"{user_id}/{graph_id}/{graph_exec_id}"):
             yield event
+
+
+# --------------------- KV Data Functions --------------------- #
+
+
+async def get_execution_kv_data(user_id: str, key: str) -> Any | None:
+    """
+    Get key-value data for a user and key.
+
+    Args:
+        user_id: The id of the User.
+        key: The key to retrieve data for.
+
+    Returns:
+        The data associated with the key, or None if not found.
+    """
+    kv_data = await AgentNodeExecutionKeyValueData.prisma().find_unique(
+        where={"userId_key": {"userId": user_id, "key": key}}
+    )
+    return (
+        type_utils.convert(kv_data.data, type[Any])
+        if kv_data and kv_data.data
+        else None
+    )
+
+
+async def set_execution_kv_data(
+    user_id: str, node_exec_id: str, key: str, data: Any
+) -> Any | None:
+    """
+    Set key-value data for a user and key.
+
+    Args:
+        user_id: The id of the User.
+        node_exec_id: The id of the AgentNodeExecution.
+        key: The key to store data under.
+        data: The data to store.
+    """
+    resp = await AgentNodeExecutionKeyValueData.prisma().upsert(
+        where={"userId_key": {"userId": user_id, "key": key}},
+        data={
+            "create": AgentNodeExecutionKeyValueDataCreateInput(
+                userId=user_id,
+                agentNodeExecutionId=node_exec_id,
+                key=key,
+                data=Json(data) if data is not None else None,
+            ),
+            "update": {
+                "agentNodeExecutionId": node_exec_id,
+                "data": Json(data) if data is not None else None,
+            },
+        },
+    )
+    return type_utils.convert(resp.data, type[Any]) if resp and resp.data else None
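
The upsert keys on the composite `(userId, key)` primary key, so writes are last-writer-wins: a second `set_execution_kv_data` for the same key overwrites `data` and re-points `agentNodeExecutionId` at the latest writing execution. A sketch of the resulting behavior, with illustrative ids, to run inside an async context:

```python
# Two executions write the same key for the same user: the upsert's update
# path wins, and the read returns the most recent value.
await set_execution_kv_data("u-1", "ne-1", "agent#g-1#count", 1)
await set_execution_kv_data("u-1", "ne-2", "agent#g-1#count", 2)
assert await get_execution_kv_data("u-1", "agent#g-1#count") == 2

# Caveat visible in get_execution_kv_data above: a falsy stored value
# (None, 0, "", False) reads back as None, indistinguishable from a missing key.
```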

#### DatabaseManager service and clients: expose the KV functions

@@ -5,12 +5,14 @@ from backend.data import db
 from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
 from backend.data.execution import (
     create_graph_execution,
+    get_execution_kv_data,
     get_graph_execution,
     get_graph_execution_meta,
     get_graph_executions,
     get_latest_node_execution,
     get_node_execution,
     get_node_executions,
+    set_execution_kv_data,
     update_graph_execution_start_time,
     update_graph_execution_stats,
     update_node_execution_stats,
@@ -101,6 +103,8 @@ class DatabaseManager(AppService):
     update_node_execution_stats = _(update_node_execution_stats)
     upsert_execution_input = _(upsert_execution_input)
     upsert_execution_output = _(upsert_execution_output)
+    get_execution_kv_data = _(get_execution_kv_data)
+    set_execution_kv_data = _(set_execution_kv_data)

     # Graphs
     get_node = _(get_node)
@@ -159,6 +163,8 @@ class DatabaseManagerClient(AppServiceClient):
     update_node_execution_stats = _(d.update_node_execution_stats)
     upsert_execution_input = _(d.upsert_execution_input)
     upsert_execution_output = _(d.upsert_execution_output)
+    get_execution_kv_data = _(d.get_execution_kv_data)
+    set_execution_kv_data = _(d.set_execution_kv_data)

     # Graphs
     get_node = _(d.get_node)
@@ -202,8 +208,10 @@ class DatabaseManagerAsyncClient(AppServiceClient):
         return DatabaseManager

     create_graph_execution = d.create_graph_execution
     get_connected_output_nodes = d.get_connected_output_nodes
+    get_latest_node_execution = d.get_latest_node_execution
     get_graph = d.get_graph
     get_graph_metadata = d.get_graph_metadata
+    get_graph_execution_meta = d.get_graph_execution_meta
     get_node = d.get_node
     get_node_execution = d.get_node_execution
@@ -216,3 +224,5 @@ class DatabaseManagerAsyncClient(AppServiceClient):
     update_node_execution_status = d.update_node_execution_status
     update_node_execution_status_batch = d.update_node_execution_status_batch
     update_user_integrations = d.update_user_integrations
+    get_execution_kv_data = d.get_execution_kv_data
+    set_execution_kv_data = d.set_execution_kv_data

#### New migration: create the `AgentNodeExecutionKeyValueData` table (+11 lines)
-- CreateTable
CREATE TABLE "AgentNodeExecutionKeyValueData" (
    "userId" TEXT NOT NULL,
    "key" TEXT NOT NULL,
    "agentNodeExecutionId" TEXT NOT NULL,
    "data" JSONB,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3),

    CONSTRAINT "AgentNodeExecutionKeyValueData_pkey" PRIMARY KEY ("userId", "key")
);
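
For reference, the Prisma upsert in `set_execution_kv_data` corresponds to a standard Postgres conflict-target write against this composite primary key; roughly (a hand-written sketch with illustrative values, not Prisma's generated SQL):

```sql
INSERT INTO "AgentNodeExecutionKeyValueData"
    ("userId", "key", "agentNodeExecutionId", "data")
VALUES ('u-1', 'agent#g-1#user_preference', 'ne-1', '{"theme": "dark"}'::jsonb)
ON CONFLICT ("userId", "key") DO UPDATE
SET "agentNodeExecutionId" = EXCLUDED."agentNodeExecutionId",
    "data" = EXCLUDED."data";
```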

#### `schema.prisma`: add the KV model

@@ -413,6 +413,16 @@ model AgentNodeExecutionInputOutput {
   @@index([name, time])
 }

+model AgentNodeExecutionKeyValueData {
+  userId               String
+  key                  String
+  agentNodeExecutionId String
+  data                 Json?
+  createdAt            DateTime  @default(now())
+  updatedAt            DateTime? @updatedAt
+
+  @@id([userId, key])
+}
+
 // Webhook that is registered with a provider and propagates to one or more nodes
 model IntegrationWebhook {
   id                String         @id @default(uuid())
@@ -432,8 +442,8 @@ model IntegrationWebhook {
   providerWebhookId String // Webhook ID assigned by the provider

-  AgentNodes   AgentNode[]
-  AgentPresets AgentPreset[]
+  AgentNodes        AgentNode[]
+  AgentPresets      AgentPreset[]

   @@index([userId])
 }