mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-09 07:08:09 -05:00
feat(backend): implement KV data storage blocks (#10294)
This PR introduces key-value storage blocks. ### Changes 🏗️ - **Database Schema**: Add AgentNodeExecutionKeyValueData table with composite primary key (userId, key) - **Persistence Blocks**: Create PersistInformationBlock and RetrieveInformationBlock in persistence.py - **Scope-based Storage**: Support for within_agent (per agent) vs across_agents (global user) persistence - **Key Structure**: Use formal # delimiter for storage keys: `agent#{graph_id}#{key}` and `global#{key}` ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] Run all 244 block tests - all passing ✅ - [x] Test PersistInformationBlock with mock data storage - [x] Test RetrieveInformationBlock with mock data retrieval - [x] Verify scope-based key generation (within_agent vs across_agents) - [x] Verify database function integration through all manager classes - [x] Run lint and type checking - all passing ✅ - [x] Verify database migration is included and valid #### For configuration changes: - [x] `.env.example` is updated or already compatible with my changes - [x] `docker-compose.yml` is updated or already compatible with my changes - [x] I have included a list of my configuration changes in the PR description (under **Changes**) Note: This change adds database schema and new blocks but doesn't require environment or docker-compose changes as it uses existing database infrastructure. --------- Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
153
autogpt_platform/backend/backend/blocks/persistence.py
Normal file
153
autogpt_platform/backend/backend/blocks/persistence.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@thread_cached
def get_database_manager_client():
    """Return the async DatabaseManager service client, cached per thread."""
    # Imported lazily here rather than at module top — presumably to avoid an
    # import cycle between the blocks package and the executor package; confirm.
    from backend.executor import DatabaseManagerAsyncClient
    from backend.util.service import get_service_client

    # health_check=False: skip the service health probe when building the client.
    return get_service_client(DatabaseManagerAsyncClient, health_check=False)
|
||||
|
||||
|
||||
def get_storage_key(key: str, scope: str, graph_id: str) -> str:
    """Generate the namespaced storage key for a user-scoped KV entry.

    Args:
        key: Caller-supplied key name.
        scope: "across_agents" (one value per user, shared by all agents) or
            "within_agent" (one value per user per agent graph).
        graph_id: Id of the current agent graph; used only for agent scope.

    Returns:
        "global#{key}" for across_agents, otherwise "agent#{graph_id}#{key}".
        Unrecognized scopes fall back to the narrower per-agent key.
    """
    if scope == "across_agents":
        return f"global#{key}"
    # "within_agent" and any unknown scope share the same (safer) agent-scoped
    # key — the original spelled this as two identical branches.
    return f"agent#{graph_id}#{key}"
|
||||
|
||||
|
||||
class PersistInformationBlock(Block):
    """Block for persisting key-value data for the current user with configurable scope"""

    class Input(BlockSchema):
        key: str = SchemaField(description="Key to store the information under")
        value: Any = SchemaField(description="Value to store")
        scope: Literal["within_agent", "across_agents"] = SchemaField(
            description="Scope of persistence: within_agent (shared across all runs of this agent) or across_agents (shared across all agents for this user)",
            default="within_agent",
        )

    class Output(BlockSchema):
        value: Any = SchemaField(description="Value that was stored")

    def __init__(self):
        super().__init__(
            id="1d055e55-a2b9-4547-8311-907d05b0304d",
            description="Persist key-value information for the current user",
            categories={BlockCategory.DATA},
            input_schema=PersistInformationBlock.Input,
            output_schema=PersistInformationBlock.Output,
            test_input={
                "key": "user_preference",
                "value": {"theme": "dark", "language": "en"},
                "scope": "within_agent",
            },
            test_output=[
                ("value", {"theme": "dark", "language": "en"}),
            ],
            # The block test harness patches _store_data so tests run without
            # a database; the mock echoes the test_input value back.
            test_mock={
                "_store_data": lambda *args, **kwargs: {
                    "theme": "dark",
                    "language": "en",
                }
            },
        )

    async def run(
        self,
        input_data: Input,
        *,
        user_id: str,
        graph_id: str,
        node_exec_id: str,
        **kwargs,
    ) -> BlockOutput:
        """Store input_data.value under a scope-namespaced key, yielding it back."""
        # Determine the storage key based on scope
        storage_key = get_storage_key(input_data.key, input_data.scope, graph_id)

        # Store the data
        yield "value", await self._store_data(
            user_id=user_id,
            node_exec_id=node_exec_id,
            key=storage_key,
            data=input_data.value,
        )

    async def _store_data(
        self, user_id: str, node_exec_id: str, key: str, data: Any
    ) -> Any | None:
        """Persist `data` under `key` for `user_id` via the DatabaseManager service."""
        return await get_database_manager_client().set_execution_kv_data(
            user_id=user_id,
            node_exec_id=node_exec_id,
            key=key,
            data=data,
        )
|
||||
|
||||
|
||||
class RetrieveInformationBlock(Block):
    """Block for retrieving key-value data for the current user with configurable scope"""

    class Input(BlockSchema):
        key: str = SchemaField(description="Key to retrieve the information for")
        scope: Literal["within_agent", "across_agents"] = SchemaField(
            description="Scope of persistence: within_agent (shared across all runs of this agent) or across_agents (shared across all agents for this user)",
            default="within_agent",
        )
        default_value: Any = SchemaField(
            description="Default value to return if key is not found", default=None
        )

    class Output(BlockSchema):
        value: Any = SchemaField(description="Retrieved value or default value")

    def __init__(self):
        super().__init__(
            id="d8710fc9-6e29-481e-a7d5-165eb16f8471",
            description="Retrieve key-value information for the current user",
            categories={BlockCategory.DATA},
            input_schema=RetrieveInformationBlock.Input,
            output_schema=RetrieveInformationBlock.Output,
            test_input={
                "key": "user_preference",
                "scope": "within_agent",
                "default_value": {"theme": "light", "language": "en"},
            },
            test_output=[
                ("value", {"theme": "light", "language": "en"}),
            ],
            # Mock a miss (None) so the test exercises the default_value path.
            test_mock={"_retrieve_data": lambda *args, **kwargs: None},
        )

    async def run(
        self, input_data: Input, *, user_id: str, graph_id: str, **kwargs
    ) -> BlockOutput:
        """Yield the stored value for the scoped key, or default_value if absent."""
        # Determine the storage key based on scope
        storage_key = get_storage_key(input_data.key, input_data.scope, graph_id)

        # Retrieve the data
        stored_value = await self._retrieve_data(
            user_id=user_id,
            key=storage_key,
        )

        # NOTE(review): a stored value of None is indistinguishable here from
        # "key not found" — default_value is yielded in both cases.
        if stored_value is not None:
            yield "value", stored_value
        else:
            yield "value", input_data.default_value

    async def _retrieve_data(self, user_id: str, key: str) -> Any | None:
        """Fetch the value for `key`/`user_id` via the DatabaseManager service (None on miss)."""
        return await get_database_manager_client().get_execution_kv_data(
            user_id=user_id,
            key=key,
        )
|
||||
@@ -26,10 +26,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
@thread_cached
|
||||
def get_database_manager_client():
|
||||
from backend.executor import DatabaseManagerClient
|
||||
from backend.executor import DatabaseManagerAsyncClient
|
||||
from backend.util.service import get_service_client
|
||||
|
||||
return get_service_client(DatabaseManagerClient)
|
||||
return get_service_client(DatabaseManagerAsyncClient, health_check=False)
|
||||
|
||||
|
||||
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
|
||||
@@ -273,7 +273,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()
|
||||
|
||||
@staticmethod
|
||||
def _create_block_function_signature(
|
||||
async def _create_block_function_signature(
|
||||
sink_node: "Node", links: list["Link"]
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
@@ -312,7 +312,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
return {"type": "function", "function": tool_function}
|
||||
|
||||
@staticmethod
|
||||
def _create_agent_function_signature(
|
||||
async def _create_agent_function_signature(
|
||||
sink_node: "Node", links: list["Link"]
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
@@ -334,7 +334,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
raise ValueError("Graph ID or Graph Version not found in sink node.")
|
||||
|
||||
db_client = get_database_manager_client()
|
||||
sink_graph_meta = db_client.get_graph_metadata(graph_id, graph_version)
|
||||
sink_graph_meta = await db_client.get_graph_metadata(graph_id, graph_version)
|
||||
if not sink_graph_meta:
|
||||
raise ValueError(
|
||||
f"Sink graph metadata not found: {graph_id} {graph_version}"
|
||||
@@ -374,7 +374,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
return {"type": "function", "function": tool_function}
|
||||
|
||||
@staticmethod
|
||||
def _create_function_signature(node_id: str) -> list[dict[str, Any]]:
|
||||
async def _create_function_signature(node_id: str) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Creates function signatures for tools linked to a specified node within a graph.
|
||||
|
||||
@@ -396,13 +396,13 @@ class SmartDecisionMakerBlock(Block):
|
||||
db_client = get_database_manager_client()
|
||||
tools = [
|
||||
(link, node)
|
||||
for link, node in db_client.get_connected_output_nodes(node_id)
|
||||
for link, node in await db_client.get_connected_output_nodes(node_id)
|
||||
if link.source_name.startswith("tools_^_") and link.source_id == node_id
|
||||
]
|
||||
if not tools:
|
||||
raise ValueError("There is no next node to execute.")
|
||||
|
||||
return_tool_functions = []
|
||||
return_tool_functions: list[dict[str, Any]] = []
|
||||
|
||||
grouped_tool_links: dict[str, tuple["Node", list["Link"]]] = {}
|
||||
for link, node in tools:
|
||||
@@ -417,13 +417,13 @@ class SmartDecisionMakerBlock(Block):
|
||||
|
||||
if sink_node.block_id == AgentExecutorBlock().id:
|
||||
return_tool_functions.append(
|
||||
SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
await SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
sink_node, links
|
||||
)
|
||||
)
|
||||
else:
|
||||
return_tool_functions.append(
|
||||
SmartDecisionMakerBlock._create_block_function_signature(
|
||||
await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
sink_node, links
|
||||
)
|
||||
)
|
||||
@@ -442,7 +442,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
tool_functions = self._create_function_signature(node_id)
|
||||
tool_functions = await self._create_function_signature(node_id)
|
||||
yield "tool_functions", json.dumps(tool_functions)
|
||||
|
||||
input_data.conversation_history = input_data.conversation_history or []
|
||||
|
||||
@@ -161,7 +161,7 @@ async def test_smart_decision_maker_function_signature(server: SpinTestServer):
|
||||
)
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
tool_functions = SmartDecisionMakerBlock._create_function_signature(
|
||||
tool_functions = await SmartDecisionMakerBlock._create_function_signature(
|
||||
test_graph.nodes[0].id
|
||||
)
|
||||
assert tool_functions is not None, "Tool functions should not be None"
|
||||
|
||||
@@ -22,6 +22,7 @@ from prisma.models import (
|
||||
AgentGraphExecution,
|
||||
AgentNodeExecution,
|
||||
AgentNodeExecutionInputOutput,
|
||||
AgentNodeExecutionKeyValueData,
|
||||
)
|
||||
from prisma.types import (
|
||||
AgentGraphExecutionCreateInput,
|
||||
@@ -29,6 +30,7 @@ from prisma.types import (
|
||||
AgentGraphExecutionWhereInput,
|
||||
AgentNodeExecutionCreateInput,
|
||||
AgentNodeExecutionInputOutputCreateInput,
|
||||
AgentNodeExecutionKeyValueDataCreateInput,
|
||||
AgentNodeExecutionUpdateInput,
|
||||
AgentNodeExecutionWhereInput,
|
||||
)
|
||||
@@ -907,3 +909,57 @@ class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionEvent]):
|
||||
) -> AsyncGenerator[ExecutionEvent, None]:
|
||||
async for event in self.listen_events(f"{user_id}/{graph_id}/{graph_exec_id}"):
|
||||
yield event
|
||||
|
||||
|
||||
# --------------------- KV Data Functions --------------------- #
|
||||
|
||||
|
||||
async def get_execution_kv_data(user_id: str, key: str) -> Any | None:
    """
    Get key-value data for a user and key.

    Args:
        user_id: The id of the User.
        key: The key to retrieve data for.

    Returns:
        The data associated with the key, or None if not found (or if the
        stored value itself is null).
    """
    # Composite-PK lookup: (userId, key) is the table's primary key.
    kv_data = await AgentNodeExecutionKeyValueData.prisma().find_unique(
        where={"userId_key": {"userId": user_id, "key": key}}
    )
    # Explicit `is not None` so falsy-but-valid stored values (0, "", False,
    # [], {}) round-trip instead of being treated as "not found" — the
    # original truthiness check (`kv_data and kv_data.data`) dropped them.
    if kv_data is None or kv_data.data is None:
        return None
    return type_utils.convert(kv_data.data, type[Any])
|
||||
|
||||
|
||||
async def set_execution_kv_data(
    user_id: str, node_exec_id: str, key: str, data: Any
) -> Any | None:
    """
    Set (upsert) key-value data for a user and key.

    Args:
        user_id: The id of the User.
        node_exec_id: The id of the AgentNodeExecution writing the value.
        key: The key to store data under.
        data: The data to store.

    Returns:
        The stored data as persisted, or None if a null value was stored.
    """
    # Wrap once so the create and update payloads cannot drift apart.
    json_data = Json(data) if data is not None else None
    resp = await AgentNodeExecutionKeyValueData.prisma().upsert(
        where={"userId_key": {"userId": user_id, "key": key}},
        data={
            "create": AgentNodeExecutionKeyValueDataCreateInput(
                userId=user_id,
                agentNodeExecutionId=node_exec_id,
                key=key,
                data=json_data,
            ),
            # On conflict, record which execution last wrote the value.
            "update": {
                "agentNodeExecutionId": node_exec_id,
                "data": json_data,
            },
        },
    )
    # `is not None` so falsy stored values (0, "", False, [], {}) are echoed
    # back instead of collapsing to None, matching get_execution_kv_data.
    if resp is None or resp.data is None:
        return None
    return type_utils.convert(resp.data, type[Any])
|
||||
|
||||
@@ -5,12 +5,14 @@ from backend.data import db
|
||||
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||
from backend.data.execution import (
|
||||
create_graph_execution,
|
||||
get_execution_kv_data,
|
||||
get_graph_execution,
|
||||
get_graph_execution_meta,
|
||||
get_graph_executions,
|
||||
get_latest_node_execution,
|
||||
get_node_execution,
|
||||
get_node_executions,
|
||||
set_execution_kv_data,
|
||||
update_graph_execution_start_time,
|
||||
update_graph_execution_stats,
|
||||
update_node_execution_stats,
|
||||
@@ -101,6 +103,8 @@ class DatabaseManager(AppService):
|
||||
update_node_execution_stats = _(update_node_execution_stats)
|
||||
upsert_execution_input = _(upsert_execution_input)
|
||||
upsert_execution_output = _(upsert_execution_output)
|
||||
get_execution_kv_data = _(get_execution_kv_data)
|
||||
set_execution_kv_data = _(set_execution_kv_data)
|
||||
|
||||
# Graphs
|
||||
get_node = _(get_node)
|
||||
@@ -159,6 +163,8 @@ class DatabaseManagerClient(AppServiceClient):
|
||||
update_node_execution_stats = _(d.update_node_execution_stats)
|
||||
upsert_execution_input = _(d.upsert_execution_input)
|
||||
upsert_execution_output = _(d.upsert_execution_output)
|
||||
get_execution_kv_data = _(d.get_execution_kv_data)
|
||||
set_execution_kv_data = _(d.set_execution_kv_data)
|
||||
|
||||
# Graphs
|
||||
get_node = _(d.get_node)
|
||||
@@ -202,8 +208,10 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
return DatabaseManager
|
||||
|
||||
create_graph_execution = d.create_graph_execution
|
||||
get_connected_output_nodes = d.get_connected_output_nodes
|
||||
get_latest_node_execution = d.get_latest_node_execution
|
||||
get_graph = d.get_graph
|
||||
get_graph_metadata = d.get_graph_metadata
|
||||
get_graph_execution_meta = d.get_graph_execution_meta
|
||||
get_node = d.get_node
|
||||
get_node_execution = d.get_node_execution
|
||||
@@ -216,3 +224,5 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
update_node_execution_status = d.update_node_execution_status
|
||||
update_node_execution_status_batch = d.update_node_execution_status_batch
|
||||
update_user_integrations = d.update_user_integrations
|
||||
get_execution_kv_data = d.get_execution_kv_data
|
||||
set_execution_kv_data = d.set_execution_kv_data
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
-- CreateTable
-- Per-user key-value storage backing the persistence blocks.
-- Composite primary key (userId, key): one value per user per key;
-- writers upsert, so later executions overwrite earlier ones.
CREATE TABLE "AgentNodeExecutionKeyValueData" (
    "userId" TEXT NOT NULL,
    "key" TEXT NOT NULL,
    -- Last node execution that wrote this entry (updated on upsert).
    "agentNodeExecutionId" TEXT NOT NULL,
    -- Stored payload; NULL when a null value was persisted.
    "data" JSONB,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3),

    CONSTRAINT "AgentNodeExecutionKeyValueData_pkey" PRIMARY KEY ("userId","key")
);
|
||||
@@ -413,6 +413,16 @@ model AgentNodeExecutionInputOutput {
|
||||
@@index([name, time])
|
||||
}
|
||||
|
||||
// Per-user key-value store backing the persistence blocks
// (PersistInformationBlock / RetrieveInformationBlock).
model AgentNodeExecutionKeyValueData {
  userId               String
  // Namespaced by scope: "agent#{graph_id}#{key}" or "global#{key}".
  key                  String
  // Last node execution that wrote this entry (updated on upsert).
  agentNodeExecutionId String
  data                 Json?
  createdAt            DateTime  @default(now())
  updatedAt            DateTime? @updatedAt

  // One value per user per key — upserts overwrite across executions.
  @@id([userId, key])
}
|
||||
|
||||
// Webhook that is registered with a provider and propagates to one or more nodes
|
||||
model IntegrationWebhook {
|
||||
id String @id @default(uuid())
|
||||
@@ -432,8 +442,8 @@ model IntegrationWebhook {
|
||||
|
||||
providerWebhookId String // Webhook ID assigned by the provider
|
||||
|
||||
AgentNodes AgentNode[]
|
||||
AgentPresets AgentPreset[]
|
||||
AgentNodes AgentNode[]
|
||||
AgentPresets AgentPreset[]
|
||||
|
||||
@@index([userId])
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user