AutoGPT/autogpt_platform/backend/backend/data/execution.py
Nicholas Tindle 7668c17d9c feat(platform): add User Workspace for persistent CoPilot file storage (#11867)
Implements persistent User Workspace storage for CoPilot, enabling
blocks to save and retrieve files across sessions. Files are stored in
session-scoped virtual paths (`/sessions/{session_id}/`).

Fixes SECRT-1833

### Changes 🏗️

**Database & Storage:**
- Add `UserWorkspace` and `UserWorkspaceFile` Prisma models
- Implement `WorkspaceStorageBackend` abstraction (GCS for cloud, local
filesystem for self-hosted)
- Add `workspace_id` and `session_id` fields to `ExecutionContext`

**Backend API:**
- Add REST endpoints: `GET/POST /api/workspace/files`, `GET/DELETE
/api/workspace/files/{id}`, `GET /api/workspace/files/{id}/download`
- Add CoPilot tools: `list_workspace_files`, `read_workspace_file`,
`write_workspace_file`
- Integrate workspace storage into `store_media_file()`, which now returns
`workspace://file-id` references (see the sketch below)
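
A minimal sketch of the resulting call shape, assuming the parameter names
mentioned in the summary below (`execution_context`, `return_format`); the
exact signature may differ:

```python
# Illustrative only: names and values here are assumptions from this PR's
# description, not a verified API.
ref = await store_media_file(
    file=generated_image,                 # media produced by a block
    execution_context=execution_context,  # carries workspace_id / session_id
    return_format="workspace_ref",        # hypothetical value selecting workspace:// output
)
assert ref.startswith("workspace://")     # e.g. "workspace://<file-id>"
```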

**Block Updates:**
- Refactor all file-handling blocks to use unified `ExecutionContext`
parameter
- Update media-generating blocks to persist outputs to workspace
(AIImageGenerator, AIImageCustomizer, FluxKontext, TalkingHead, FAL
video, Bannerbear, etc.)

**Frontend:**
- Render `workspace://` image references in chat via proxy endpoint
- Add "AI cannot see this image" overlay indicator

**CoPilot Context Mapping:**
- One CoPilot session maps to one Agent (`graph_id`) and one Run (`graph_exec_id`)
- Files scoped to `/sessions/{session_id}/` (sketched below)
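
For illustration, a file written during a CoPilot session lands under a
session-scoped virtual path (the file name below is hypothetical):

```python
# Hypothetical sketch of the virtual-path scheme described above.
session_id = "3f6a1c0e"
virtual_path = f"/sessions/{session_id}/generated-image.png"
```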

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [ ] I have tested my changes according to the test plan:
  - [ ] Create a CoPilot session and generate an image with AIImageGeneratorBlock
  - [ ] Verify image returns `workspace://file-id` (not base64)
  - [ ] Verify image renders in chat with visibility indicator
  - [ ] Verify workspace files persist across sessions
  - [ ] Test list/read/write workspace files via CoPilot tools
  - [ ] Test local storage backend for self-hosted deployments

#### For configuration changes:
- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

🤖 Generated with [Claude Code](https://claude.ai/code)

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Medium Risk**
> Introduces a new persistent file-storage surface area (DB tables,
storage backends, download API, and chat tools) and rewires
`store_media_file()`/block execution context across many blocks, so
regressions could impact file handling, access control, or storage
costs.
> 
> **Overview**
> Adds a **persistent per-user Workspace** (new
`UserWorkspace`/`UserWorkspaceFile` models plus `WorkspaceManager` +
`WorkspaceStorageBackend` with GCS/local implementations) and wires it
into the API via a new `/api/workspace/files/{file_id}/download` route
(including header-sanitized `Content-Disposition`) and shutdown
lifecycle hooks.
> 
> Extends `ExecutionContext` to carry execution identity +
`workspace_id`/`session_id`, updates executor tooling to clone
node-specific contexts, and updates `run_block` (CoPilot) to create a
session-scoped workspace and synthetic graph/run/node IDs.
> 
> Refactors `store_media_file()` to require `execution_context` +
`return_format` and to support `workspace://` references; migrates many
media/file-handling blocks and related tests to the new API and to
persist generated media as `workspace://...` (or fall back to data URIs
outside CoPilot), and adds CoPilot chat tools for
listing/reading/writing/deleting workspace files with safeguards against
context bloat.
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
6abc70f793. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2026-01-29 05:49:47 +00:00


import logging
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from enum import Enum
from multiprocessing import Manager
from queue import Empty
from typing import (
TYPE_CHECKING,
Annotated,
Any,
AsyncGenerator,
Generator,
Generic,
Literal,
Mapping,
Optional,
TypeVar,
cast,
overload,
)
from prisma import Json
from prisma.enums import AgentExecutionStatus
from prisma.models import (
AgentGraphExecution,
AgentNodeExecution,
AgentNodeExecutionInputOutput,
AgentNodeExecutionKeyValueData,
)
from prisma.types import (
AgentGraphExecutionUpdateManyMutationInput,
AgentGraphExecutionWhereInput,
AgentNodeExecutionCreateInput,
AgentNodeExecutionInputOutputCreateInput,
AgentNodeExecutionKeyValueDataCreateInput,
AgentNodeExecutionUpdateInput,
AgentNodeExecutionWhereInput,
AgentNodeExecutionWhereUniqueInput,
)
from pydantic import BaseModel, ConfigDict, JsonValue, ValidationError
from pydantic.fields import Field
from backend.util import type as type_utils
from backend.util.exceptions import DatabaseError
from backend.util.json import SafeJson
from backend.util.models import Pagination
from backend.util.retry import func_retry
from backend.util.settings import Config
from backend.util.truncate import truncate
from .block import (
BlockInput,
BlockType,
CompletedBlockOutput,
get_block,
get_io_block_ids,
get_webhook_block_ids,
)
from .db import BaseDbModel, query_raw_with_schema
from .event_bus import AsyncRedisEventBus, RedisEventBus
from .includes import (
EXECUTION_RESULT_INCLUDE,
EXECUTION_RESULT_ORDER,
GRAPH_EXECUTION_INCLUDE_WITH_NODES,
graph_execution_include,
)
from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats
if TYPE_CHECKING:
pass
T = TypeVar("T")
logger = logging.getLogger(__name__)
config = Config()
class ExecutionContext(BaseModel):
"""
Unified context that carries execution-level data throughout the entire execution flow.
This includes information needed by blocks, sub-graphs, and execution management.
"""
model_config = {"extra": "ignore"}
# Execution identity
user_id: Optional[str] = None
graph_id: Optional[str] = None
graph_exec_id: Optional[str] = None
graph_version: Optional[int] = None
node_id: Optional[str] = None
node_exec_id: Optional[str] = None
# Safety settings
human_in_the_loop_safe_mode: bool = True
sensitive_action_safe_mode: bool = False
# User settings
user_timezone: str = "UTC"
# Execution hierarchy
root_execution_id: Optional[str] = None
parent_execution_id: Optional[str] = None
# Workspace
workspace_id: Optional[str] = None
session_id: Optional[str] = None
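# Example (illustrative values): a CoPilot run might carry
#   ExecutionContext(user_id="user-1", graph_id="g-1", graph_exec_id="run-1",
#                    workspace_id="ws-1", session_id="sess-1")
# so blocks can resolve session-scoped workspace paths (/sessions/{session_id}/).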
# -------------------------- Models -------------------------- #
class BlockErrorStats(BaseModel):
"""Typed data structure for block error statistics."""
block_id: str
total_executions: int
failed_executions: int
@property
def error_rate(self) -> float:
"""Calculate error rate as a percentage."""
if self.total_executions == 0:
return 0.0
return (self.failed_executions / self.total_executions) * 100
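# e.g. BlockErrorStats(block_id="b1", total_executions=200, failed_executions=5)
# has an error_rate of 2.5 (percent).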
ExecutionStatus = AgentExecutionStatus
NodeInputMask = Mapping[str, JsonValue]
NodesInputMasks = Mapping[str, NodeInputMask]
# dest: source
VALID_STATUS_TRANSITIONS = {
ExecutionStatus.QUEUED: [
ExecutionStatus.INCOMPLETE,
ExecutionStatus.TERMINATED, # For resuming halted execution
ExecutionStatus.REVIEW, # For resuming after review
],
ExecutionStatus.RUNNING: [
ExecutionStatus.INCOMPLETE,
ExecutionStatus.QUEUED,
ExecutionStatus.TERMINATED, # For resuming halted execution
ExecutionStatus.REVIEW, # For resuming after review
],
ExecutionStatus.COMPLETED: [
ExecutionStatus.RUNNING,
],
ExecutionStatus.FAILED: [
ExecutionStatus.INCOMPLETE,
ExecutionStatus.QUEUED,
ExecutionStatus.RUNNING,
ExecutionStatus.REVIEW,
],
ExecutionStatus.TERMINATED: [
ExecutionStatus.INCOMPLETE,
ExecutionStatus.QUEUED,
ExecutionStatus.RUNNING,
ExecutionStatus.REVIEW,
],
ExecutionStatus.REVIEW: [
ExecutionStatus.RUNNING,
],
}
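# Example: COMPLETED is only reachable from RUNNING, so a guarded update that
# tries to mark a TERMINATED execution as COMPLETED matches no rows (a no-op).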
class GraphExecutionMeta(BaseDbModel):
id: str # type: ignore # Override base class to make this required
user_id: str
graph_id: str
graph_version: int
inputs: Optional[BlockInput] # no default -> required in the OpenAPI spec
credential_inputs: Optional[dict[str, CredentialsMetaInput]]
nodes_input_masks: Optional[dict[str, BlockInput]]
preset_id: Optional[str]
status: ExecutionStatus
started_at: Optional[datetime] = Field(
None,
description="When execution started running. Null if not yet started (QUEUED).",
)
ended_at: Optional[datetime] = Field(
None,
description="When execution finished. Null if not yet completed (QUEUED, RUNNING, INCOMPLETE, REVIEW).",
)
is_shared: bool = False
share_token: Optional[str] = None
class Stats(BaseModel):
model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
)
cost: int = Field(
default=0,
description="Execution cost (cents)",
)
duration: float = Field(
default=0,
description="Seconds from start to end of run",
)
duration_cpu_only: float = Field(
default=0,
description="CPU sec of duration",
)
node_exec_time: float = Field(
default=0,
description="Seconds of total node runtime",
)
node_exec_time_cpu_only: float = Field(
default=0,
description="CPU sec of node_exec_time",
)
node_exec_count: int = Field(
default=0,
description="Number of node executions",
)
node_error_count: int = Field(
default=0,
description="Number of node errors",
)
error: str | None = Field(
default=None,
description="Error message if any",
)
activity_status: str | None = Field(
default=None,
description="AI-generated summary of what the agent did",
)
correctness_score: float | None = Field(
default=None,
description="AI-generated score (0.0-1.0) indicating how well the execution achieved its intended purpose",
)
def to_db(self) -> GraphExecutionStats:
return GraphExecutionStats(
cost=self.cost,
walltime=self.duration,
cputime=self.duration_cpu_only,
nodes_walltime=self.node_exec_time,
nodes_cputime=self.node_exec_time_cpu_only,
node_count=self.node_exec_count,
node_error_count=self.node_error_count,
error=self.error,
activity_status=self.activity_status,
correctness_score=self.correctness_score,
)
def without_activity_features(self) -> "GraphExecutionMeta.Stats":
"""Return a copy of stats with activity features (activity_status, correctness_score) set to None."""
return self.model_copy(
update={"activity_status": None, "correctness_score": None}
)
stats: Stats | None
@staticmethod
def from_db(_graph_exec: AgentGraphExecution):
start_time = _graph_exec.startedAt
end_time = _graph_exec.endedAt
try:
stats = GraphExecutionStats.model_validate(_graph_exec.stats)
except ValueError as e:
if _graph_exec.stats is not None:
logger.warning(
"Failed to parse invalid graph execution stats "
f"{_graph_exec.stats}: {e}"
)
stats = None
return GraphExecutionMeta(
id=_graph_exec.id,
user_id=_graph_exec.userId,
graph_id=_graph_exec.agentGraphId,
graph_version=_graph_exec.agentGraphVersion,
inputs=cast(BlockInput | None, _graph_exec.inputs),
credential_inputs=(
{
name: CredentialsMetaInput.model_validate(cmi)
for name, cmi in cast(dict, _graph_exec.credentialInputs).items()
}
if _graph_exec.credentialInputs
else None
),
nodes_input_masks=cast(
dict[str, BlockInput] | None, _graph_exec.nodesInputMasks
),
preset_id=_graph_exec.agentPresetId,
status=ExecutionStatus(_graph_exec.executionStatus),
started_at=start_time,
ended_at=end_time,
stats=(
GraphExecutionMeta.Stats(
cost=stats.cost,
duration=stats.walltime,
duration_cpu_only=stats.cputime,
node_exec_time=stats.nodes_walltime,
node_exec_time_cpu_only=stats.nodes_cputime,
node_exec_count=stats.node_count,
node_error_count=stats.node_error_count,
error=(
str(stats.error)
if isinstance(stats.error, Exception)
else stats.error
),
activity_status=stats.activity_status,
correctness_score=stats.correctness_score,
)
if stats
else None
),
is_shared=_graph_exec.isShared,
share_token=_graph_exec.shareToken,
)
class GraphExecution(GraphExecutionMeta):
inputs: BlockInput # type: ignore - incompatible override is intentional
outputs: CompletedBlockOutput
@staticmethod
def from_db(_graph_exec: AgentGraphExecution):
if _graph_exec.NodeExecutions is None:
raise ValueError("Node executions must be included in query")
graph_exec = GraphExecutionMeta.from_db(_graph_exec)
complete_node_executions = sorted(
[
NodeExecutionResult.from_db(ne, _graph_exec.userId)
for ne in _graph_exec.NodeExecutions
if ne.executionStatus != ExecutionStatus.INCOMPLETE
],
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
)
inputs = {
**(
graph_exec.inputs
or {
# fallback: extract inputs from Agent Input Blocks
exec.input_data["name"]: exec.input_data.get("value")
for exec in complete_node_executions
if (
(block := get_block(exec.block_id))
and block.block_type == BlockType.INPUT
)
}
),
**{
# input from webhook-triggered block
"payload": exec.input_data["payload"]
for exec in complete_node_executions
if (
(block := get_block(exec.block_id))
and block.block_type
in [BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL]
)
},
}
outputs: CompletedBlockOutput = defaultdict(list)
for exec in complete_node_executions:
if (
block := get_block(exec.block_id)
) and block.block_type == BlockType.OUTPUT:
outputs[exec.input_data["name"]].append(exec.input_data.get("value"))
return GraphExecution(
**{
field_name: getattr(graph_exec, field_name)
for field_name in GraphExecutionMeta.model_fields
if field_name != "inputs"
},
inputs=inputs,
outputs=outputs,
)
class GraphExecutionWithNodes(GraphExecution):
node_executions: list["NodeExecutionResult"]
@staticmethod
def from_db(_graph_exec: AgentGraphExecution):
if _graph_exec.NodeExecutions is None:
raise ValueError("Node executions must be included in query")
graph_exec_with_io = GraphExecution.from_db(_graph_exec)
node_executions = sorted(
[
NodeExecutionResult.from_db(ne, _graph_exec.userId)
for ne in _graph_exec.NodeExecutions
],
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
)
return GraphExecutionWithNodes(
**{
field_name: getattr(graph_exec_with_io, field_name)
for field_name in GraphExecution.model_fields
},
node_executions=node_executions,
)
def to_graph_execution_entry(
self,
execution_context: ExecutionContext,
compiled_nodes_input_masks: Optional[NodesInputMasks] = None,
nodes_to_skip: Optional[set[str]] = None,
):
return GraphExecutionEntry(
user_id=self.user_id,
graph_id=self.graph_id,
graph_version=self.graph_version or 0,
graph_exec_id=self.id,
nodes_input_masks=compiled_nodes_input_masks,
nodes_to_skip=nodes_to_skip or set(),
execution_context=execution_context,
)
class NodeExecutionResult(BaseModel):
user_id: str
graph_id: str
graph_version: int
graph_exec_id: str
node_exec_id: str
node_id: str
block_id: str
status: ExecutionStatus
input_data: BlockInput
output_data: CompletedBlockOutput
add_time: datetime
queue_time: datetime | None
start_time: datetime | None
end_time: datetime | None
@staticmethod
def from_db(_node_exec: AgentNodeExecution, user_id: Optional[str] = None):
try:
stats = NodeExecutionStats.model_validate(_node_exec.stats or {})
except (ValueError, ValidationError):
stats = NodeExecutionStats()
if stats.cleared_inputs:
input_data: BlockInput = defaultdict()
for name, messages in stats.cleared_inputs.items():
input_data[name] = messages[-1] if messages else ""
elif _node_exec.executionData:
input_data = type_utils.convert(_node_exec.executionData, dict[str, Any])
else:
input_data: BlockInput = defaultdict()
for data in _node_exec.Input or []:
input_data[data.name] = type_utils.convert(data.data, JsonValue)
output_data: CompletedBlockOutput = defaultdict(list)
if stats.cleared_outputs:
for name, messages in stats.cleared_outputs.items():
output_data[name].extend(messages)
else:
for data in _node_exec.Output or []:
output_data[data.name].append(type_utils.convert(data.data, JsonValue))
graph_execution: AgentGraphExecution | None = _node_exec.GraphExecution
if graph_execution:
user_id = graph_execution.userId
elif not user_id:
raise ValueError(
"AgentGraphExecution must be included or user_id passed in"
)
return NodeExecutionResult(
user_id=user_id,
graph_id=graph_execution.agentGraphId if graph_execution else "",
graph_version=graph_execution.agentGraphVersion if graph_execution else 0,
graph_exec_id=_node_exec.agentGraphExecutionId,
block_id=_node_exec.Node.agentBlockId if _node_exec.Node else "",
node_exec_id=_node_exec.id,
node_id=_node_exec.agentNodeId,
status=_node_exec.executionStatus,
input_data=input_data,
output_data=output_data,
add_time=_node_exec.addedTime,
queue_time=_node_exec.queuedTime,
start_time=_node_exec.startedTime,
end_time=_node_exec.endedTime,
)
def to_node_execution_entry(
self, execution_context: ExecutionContext
) -> "NodeExecutionEntry":
return NodeExecutionEntry(
user_id=self.user_id,
graph_exec_id=self.graph_exec_id,
graph_id=self.graph_id,
graph_version=self.graph_version,
node_exec_id=self.node_exec_id,
node_id=self.node_id,
block_id=self.block_id,
inputs=self.input_data,
execution_context=execution_context,
)
# --------------------- Model functions --------------------- #
async def get_graph_executions(
graph_exec_id: Optional[str] = None,
graph_id: Optional[str] = None,
graph_version: Optional[int] = None,
user_id: Optional[str] = None,
statuses: Optional[list[ExecutionStatus]] = None,
created_time_gte: Optional[datetime] = None,
created_time_lte: Optional[datetime] = None,
limit: Optional[int] = None,
) -> list[GraphExecutionMeta]:
"""⚠️ **Optional `user_id` check**: MUST USE check in user-facing endpoints."""
where_filter: AgentGraphExecutionWhereInput = {
"isDeleted": False,
}
if graph_exec_id:
where_filter["id"] = graph_exec_id
if user_id:
where_filter["userId"] = user_id
if graph_id:
where_filter["agentGraphId"] = graph_id
if graph_version is not None:
where_filter["agentGraphVersion"] = graph_version
if created_time_gte or created_time_lte:
where_filter["createdAt"] = {
"gte": created_time_gte or datetime.min.replace(tzinfo=timezone.utc),
"lte": created_time_lte or datetime.max.replace(tzinfo=timezone.utc),
}
if statuses:
where_filter["OR"] = [{"executionStatus": status} for status in statuses]
executions = await AgentGraphExecution.prisma().find_many(
where=where_filter,
order={"createdAt": "desc"},
take=limit,
)
return [GraphExecutionMeta.from_db(execution) for execution in executions]
async def get_graph_executions_count(
user_id: Optional[str] = None,
graph_id: Optional[str] = None,
statuses: Optional[list[ExecutionStatus]] = None,
created_time_gte: Optional[datetime] = None,
created_time_lte: Optional[datetime] = None,
) -> int:
"""
Get count of graph executions with optional filters.
Args:
user_id: Optional user ID to filter by
graph_id: Optional graph ID to filter by
statuses: Optional list of execution statuses to filter by
created_time_gte: Optional minimum creation time
created_time_lte: Optional maximum creation time
Returns:
Count of matching graph executions
"""
where_filter: AgentGraphExecutionWhereInput = {
"isDeleted": False,
}
if user_id:
where_filter["userId"] = user_id
if graph_id:
where_filter["agentGraphId"] = graph_id
if created_time_gte or created_time_lte:
where_filter["createdAt"] = {
"gte": created_time_gte or datetime.min.replace(tzinfo=timezone.utc),
"lte": created_time_lte or datetime.max.replace(tzinfo=timezone.utc),
}
if statuses:
where_filter["OR"] = [{"executionStatus": status} for status in statuses]
count = await AgentGraphExecution.prisma().count(where=where_filter)
return count
class GraphExecutionsPaginated(BaseModel):
"""Response schema for paginated graph executions."""
executions: list[GraphExecutionMeta]
pagination: Pagination
async def get_graph_executions_paginated(
user_id: str,
graph_id: Optional[str] = None,
page: int = 1,
page_size: int = 25,
statuses: Optional[list[ExecutionStatus]] = None,
created_time_gte: Optional[datetime] = None,
created_time_lte: Optional[datetime] = None,
) -> GraphExecutionsPaginated:
"""Get paginated graph executions for a specific graph."""
where_filter: AgentGraphExecutionWhereInput = {
"isDeleted": False,
"userId": user_id,
}
if graph_id:
where_filter["agentGraphId"] = graph_id
if created_time_gte or created_time_lte:
where_filter["createdAt"] = {
"gte": created_time_gte or datetime.min.replace(tzinfo=timezone.utc),
"lte": created_time_lte or datetime.max.replace(tzinfo=timezone.utc),
}
if statuses:
where_filter["OR"] = [{"executionStatus": status} for status in statuses]
total_count = await AgentGraphExecution.prisma().count(where=where_filter)
total_pages = (total_count + page_size - 1) // page_size
offset = (page - 1) * page_size
executions = await AgentGraphExecution.prisma().find_many(
where=where_filter,
order={"createdAt": "desc"},
take=page_size,
skip=offset,
)
return GraphExecutionsPaginated(
executions=[GraphExecutionMeta.from_db(execution) for execution in executions],
pagination=Pagination(
total_items=total_count,
total_pages=total_pages,
current_page=page,
page_size=page_size,
),
)
async def get_graph_execution_meta(
user_id: str, execution_id: str
) -> GraphExecutionMeta | None:
execution = await AgentGraphExecution.prisma().find_first(
where={"id": execution_id, "isDeleted": False, "userId": user_id}
)
return GraphExecutionMeta.from_db(execution) if execution else None
@overload
async def get_graph_execution(
user_id: str,
execution_id: str,
include_node_executions: Literal[True],
) -> GraphExecutionWithNodes | None: ...
@overload
async def get_graph_execution(
user_id: str,
execution_id: str,
include_node_executions: Literal[False] = False,
) -> GraphExecution | None: ...
@overload
async def get_graph_execution(
user_id: str,
execution_id: str,
include_node_executions: bool = False,
) -> GraphExecution | GraphExecutionWithNodes | None: ...
async def get_graph_execution(
user_id: str,
execution_id: str,
include_node_executions: bool = False,
) -> GraphExecution | GraphExecutionWithNodes | None:
execution = await AgentGraphExecution.prisma().find_first(
where={"id": execution_id, "isDeleted": False, "userId": user_id},
include=(
GRAPH_EXECUTION_INCLUDE_WITH_NODES
if include_node_executions
else graph_execution_include(
[*get_io_block_ids(), *get_webhook_block_ids()]
)
),
)
if not execution:
return None
return (
GraphExecutionWithNodes.from_db(execution)
if include_node_executions
else GraphExecution.from_db(execution)
)
async def get_child_graph_executions(
parent_exec_id: str,
) -> list[GraphExecutionMeta]:
"""
Get all child executions of a parent execution.
Args:
parent_exec_id: Parent graph execution ID
Returns:
List of child graph executions
"""
children = await AgentGraphExecution.prisma().find_many(
where={"parentGraphExecutionId": parent_exec_id, "isDeleted": False}
)
return [GraphExecutionMeta.from_db(child) for child in children]
async def create_graph_execution(
graph_id: str,
graph_version: int,
starting_nodes_input: list[tuple[str, BlockInput]], # list[(node_id, BlockInput)]
inputs: Mapping[str, JsonValue],
user_id: str,
preset_id: Optional[str] = None,
credential_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
nodes_input_masks: Optional[NodesInputMasks] = None,
parent_graph_exec_id: Optional[str] = None,
) -> GraphExecutionWithNodes:
"""
Create a new AgentGraphExecution record.
Returns:
The created graph execution, including a NodeExecutionResult for each starting node.
"""
result = await AgentGraphExecution.prisma().create(
data={
"agentGraphId": graph_id,
"agentGraphVersion": graph_version,
"executionStatus": ExecutionStatus.INCOMPLETE,
"inputs": SafeJson(inputs),
"credentialInputs": (
SafeJson(credential_inputs) if credential_inputs else Json({})
),
"nodesInputMasks": (
SafeJson(nodes_input_masks) if nodes_input_masks else Json({})
),
"NodeExecutions": {
"create": [
AgentNodeExecutionCreateInput(
agentNodeId=node_id,
executionStatus=ExecutionStatus.QUEUED,
queuedTime=datetime.now(tz=timezone.utc),
Input={
"create": [
{"name": name, "data": SafeJson(data)}
for name, data in node_input.items()
]
},
)
for node_id, node_input in starting_nodes_input
]
},
"userId": user_id,
"agentPresetId": preset_id,
"parentGraphExecutionId": parent_graph_exec_id,
},
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
)
return GraphExecutionWithNodes.from_db(result)
async def upsert_execution_input(
node_id: str,
graph_exec_id: str,
input_name: str,
input_data: JsonValue,
node_exec_id: str | None = None,
) -> tuple[NodeExecutionResult, BlockInput]:
"""
Insert an AgentNodeExecutionInputOutput record as one of an AgentNodeExecution's inputs.
If no incomplete AgentNodeExecution is still missing `input_name` as an input, create a new one.
Args:
node_id: The id of the AgentNode.
graph_exec_id: The id of the AgentGraphExecution.
input_name: The name of the input data.
input_data: The input data to be inserted.
node_exec_id: [Optional] The id of an AgentNodeExecution that does not yet have `input_name` as an input. If not provided, an eligible incomplete AgentNodeExecution is found or a new one is created.
Returns:
NodeExecutionResult: The created or existing node execution.
BlockInput: The node's input data; key is the input name, value is the input data.
"""
existing_exec_query_filter: AgentNodeExecutionWhereInput = {
"agentGraphExecutionId": graph_exec_id,
"agentNodeId": node_id,
"executionStatus": ExecutionStatus.INCOMPLETE,
"Input": {
"none": {
"name": input_name,
"time": {"gte": datetime.now(tz=timezone.utc) - timedelta(days=1)},
}
},
}
if node_exec_id:
existing_exec_query_filter["id"] = node_exec_id
existing_execution = await AgentNodeExecution.prisma().find_first(
where=existing_exec_query_filter,
order={"addedTime": "asc"},
include={"Input": True, "GraphExecution": True},
)
json_input_data = SafeJson(input_data)
if existing_execution:
await AgentNodeExecutionInputOutput.prisma().create(
data=AgentNodeExecutionInputOutputCreateInput(
name=input_name,
data=json_input_data,
referencedByInputExecId=existing_execution.id,
)
)
return NodeExecutionResult.from_db(existing_execution), {
**{
input_data.name: type_utils.convert(input_data.data, JsonValue)
for input_data in existing_execution.Input or []
},
input_name: input_data,
}
elif not node_exec_id:
result = await AgentNodeExecution.prisma().create(
data=AgentNodeExecutionCreateInput(
agentNodeId=node_id,
agentGraphExecutionId=graph_exec_id,
executionStatus=ExecutionStatus.INCOMPLETE,
Input={"create": {"name": input_name, "data": json_input_data}},
),
include={"GraphExecution": True},
)
return NodeExecutionResult.from_db(result), {input_name: input_data}
else:
raise ValueError(
f"NodeExecution {node_exec_id} not found or already has input {input_name}."
)
async def upsert_execution_output(
node_exec_id: str,
output_name: str,
output_data: Any | None,
) -> None:
"""
Insert an AgentNodeExecutionInputOutput record as one of an AgentNodeExecution's outputs.
"""
data: AgentNodeExecutionInputOutputCreateInput = {
"name": output_name,
"referencedByOutputExecId": node_exec_id,
}
if output_data is not None:
data["data"] = SafeJson(output_data)
await AgentNodeExecutionInputOutput.prisma().create(data=data)
async def get_execution_outputs_by_node_exec_id(
node_exec_id: str,
) -> dict[str, Any]:
"""
Get all execution outputs for a specific node execution ID.
Args:
node_exec_id: The node execution ID to get outputs for
Returns:
Dictionary mapping output names to their data values
"""
outputs = await AgentNodeExecutionInputOutput.prisma().find_many(
where={"referencedByOutputExecId": node_exec_id}
)
result = {}
for output in outputs:
if output.data is not None:
result[output.name] = type_utils.convert(output.data, JsonValue)
return result
async def update_graph_execution_start_time(
graph_exec_id: str,
) -> GraphExecution | None:
res = await AgentGraphExecution.prisma().update(
where={"id": graph_exec_id},
data={
"executionStatus": ExecutionStatus.RUNNING,
"startedAt": datetime.now(tz=timezone.utc),
},
include=graph_execution_include(
[*get_io_block_ids(), *get_webhook_block_ids()]
),
)
return GraphExecution.from_db(res) if res else None
async def update_graph_execution_stats(
graph_exec_id: str,
status: ExecutionStatus | None = None,
stats: GraphExecutionStats | None = None,
) -> GraphExecution | None:
if not status and not stats:
raise ValueError(
f"Must provide either status or stats to update for execution {graph_exec_id}"
)
update_data: AgentGraphExecutionUpdateManyMutationInput = {}
if stats:
stats_dict = stats.model_dump()
if isinstance(stats_dict.get("error"), Exception):
stats_dict["error"] = str(stats_dict["error"])
update_data["stats"] = SafeJson(stats_dict)
if status:
update_data["executionStatus"] = status
# Set endedAt when execution reaches a terminal status
terminal_statuses = [
ExecutionStatus.COMPLETED,
ExecutionStatus.FAILED,
ExecutionStatus.TERMINATED,
]
if status in terminal_statuses:
update_data["endedAt"] = datetime.now(tz=timezone.utc)
where_clause: AgentGraphExecutionWhereInput = {"id": graph_exec_id}
if status:
if allowed_from := VALID_STATUS_TRANSITIONS.get(status, []):
# Add OR clause to check if current status is one of the allowed source statuses
where_clause["AND"] = [
{"id": graph_exec_id},
{"OR": [{"executionStatus": s} for s in allowed_from]},
]
else:
raise ValueError(
f"Status {status} cannot be set via update for execution {graph_exec_id}. "
f"This status can only be set at creation or is not a valid target status."
)
await AgentGraphExecution.prisma().update_many(
where=where_clause,
data=update_data,
)
graph_exec = await AgentGraphExecution.prisma().find_unique_or_raise(
where={"id": graph_exec_id},
include=graph_execution_include(
[*get_io_block_ids(), *get_webhook_block_ids()]
),
)
return GraphExecution.from_db(graph_exec)
async def update_node_execution_status_batch(
node_exec_ids: list[str],
status: ExecutionStatus,
stats: dict[str, Any] | None = None,
) -> int:
# Validate status transitions - allowed_from should never be empty for valid statuses
allowed_from = VALID_STATUS_TRANSITIONS.get(status, [])
if not allowed_from:
raise ValueError(
f"Invalid status transition: {status} has no valid source statuses"
)
# For batch updates, we filter to only update nodes with valid current statuses
where_clause = cast(
AgentNodeExecutionWhereInput,
{
"id": {"in": node_exec_ids},
"executionStatus": {"in": [s.value for s in allowed_from]},
},
)
return await AgentNodeExecution.prisma().update_many(
where=where_clause,
data=_get_update_status_data(status, None, stats),
)
async def update_node_execution_status(
node_exec_id: str,
status: ExecutionStatus,
execution_data: BlockInput | None = None,
stats: dict[str, Any] | None = None,
) -> NodeExecutionResult:
if status == ExecutionStatus.QUEUED and execution_data is None:
raise ValueError("Execution data must be provided when queuing an execution.")
# Validate status transitions - allowed_from should never be empty for valid statuses
allowed_from = VALID_STATUS_TRANSITIONS.get(status, [])
if not allowed_from:
raise ValueError(
f"Invalid status transition: {status} has no valid source statuses"
)
if res := await AgentNodeExecution.prisma().update(
where=cast(
AgentNodeExecutionWhereUniqueInput,
{
"id": node_exec_id,
"executionStatus": {"in": [s.value for s in allowed_from]},
},
),
data=_get_update_status_data(status, execution_data, stats),
include=EXECUTION_RESULT_INCLUDE,
):
return NodeExecutionResult.from_db(res)
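# The guarded update matches no row when the transition is disallowed;
# in that case, return the current record unchanged.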
if res := await AgentNodeExecution.prisma().find_unique(
where={"id": node_exec_id}, include=EXECUTION_RESULT_INCLUDE
):
return NodeExecutionResult.from_db(res)
raise ValueError(f"Execution {node_exec_id} not found.")
def _get_update_status_data(
status: ExecutionStatus,
execution_data: BlockInput | None = None,
stats: dict[str, Any] | None = None,
) -> AgentNodeExecutionUpdateInput:
now = datetime.now(tz=timezone.utc)
update_data: AgentNodeExecutionUpdateInput = {"executionStatus": status}
if status == ExecutionStatus.QUEUED:
update_data["queuedTime"] = now
elif status == ExecutionStatus.RUNNING:
update_data["startedTime"] = now
elif status in (ExecutionStatus.FAILED, ExecutionStatus.COMPLETED):
update_data["endedTime"] = now
if execution_data:
update_data["executionData"] = SafeJson(execution_data)
if stats:
update_data["stats"] = SafeJson(stats)
return update_data
async def delete_graph_execution(
graph_exec_id: str, user_id: str, soft_delete: bool = True
) -> None:
if soft_delete:
deleted_count = await AgentGraphExecution.prisma().update_many(
where={"id": graph_exec_id, "userId": user_id}, data={"isDeleted": True}
)
else:
deleted_count = await AgentGraphExecution.prisma().delete_many(
where={"id": graph_exec_id, "userId": user_id}
)
if deleted_count < 1:
raise DatabaseError(
f"Could not delete graph execution #{graph_exec_id}: not found"
)
async def get_node_execution(node_exec_id: str) -> NodeExecutionResult | None:
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
execution = await AgentNodeExecution.prisma().find_first(
where={"id": node_exec_id},
include=EXECUTION_RESULT_INCLUDE,
)
if not execution:
return None
return NodeExecutionResult.from_db(execution)
def _build_node_execution_where_clause(
graph_exec_id: str | None = None,
node_id: str | None = None,
block_ids: list[str] | None = None,
statuses: list[ExecutionStatus] | None = None,
created_time_gte: datetime | None = None,
created_time_lte: datetime | None = None,
) -> AgentNodeExecutionWhereInput:
"""
Build where clause for node execution queries.
"""
where_clause: AgentNodeExecutionWhereInput = {}
if graph_exec_id:
where_clause["agentGraphExecutionId"] = graph_exec_id
if node_id:
where_clause["agentNodeId"] = node_id
if block_ids:
where_clause["Node"] = {"is": {"agentBlockId": {"in": block_ids}}}
if statuses:
where_clause["OR"] = [{"executionStatus": status} for status in statuses]
if created_time_gte or created_time_lte:
where_clause["addedTime"] = {
"gte": created_time_gte or datetime.min.replace(tzinfo=timezone.utc),
"lte": created_time_lte or datetime.max.replace(tzinfo=timezone.utc),
}
return where_clause
async def get_node_executions(
graph_exec_id: str | None = None,
node_id: str | None = None,
block_ids: list[str] | None = None,
statuses: list[ExecutionStatus] | None = None,
limit: int | None = None,
created_time_gte: datetime | None = None,
created_time_lte: datetime | None = None,
include_exec_data: bool = True,
) -> list[NodeExecutionResult]:
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
where_clause = _build_node_execution_where_clause(
graph_exec_id=graph_exec_id,
node_id=node_id,
block_ids=block_ids,
statuses=statuses,
created_time_gte=created_time_gte,
created_time_lte=created_time_lte,
)
executions = await AgentNodeExecution.prisma().find_many(
where=where_clause,
include=(
EXECUTION_RESULT_INCLUDE
if include_exec_data
else {"Node": True, "GraphExecution": True}
),
order=EXECUTION_RESULT_ORDER,
take=limit,
)
res = [NodeExecutionResult.from_db(execution) for execution in executions]
return res
async def get_latest_node_execution(
node_id: str, graph_eid: str
) -> NodeExecutionResult | None:
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
execution = await AgentNodeExecution.prisma().find_first(
where={
"agentGraphExecutionId": graph_eid,
"agentNodeId": node_id,
"OR": [
{"executionStatus": ExecutionStatus.QUEUED},
{"executionStatus": ExecutionStatus.RUNNING},
{"executionStatus": ExecutionStatus.COMPLETED},
{"executionStatus": ExecutionStatus.TERMINATED},
{"executionStatus": ExecutionStatus.FAILED},
],
},
include=EXECUTION_RESULT_INCLUDE,
order=EXECUTION_RESULT_ORDER,
)
if not execution:
return None
return NodeExecutionResult.from_db(execution)
# ----------------- Execution Infrastructure ----------------- #
class GraphExecutionEntry(BaseModel):
model_config = {"extra": "ignore"}
user_id: str
graph_exec_id: str
graph_id: str
graph_version: int
nodes_input_masks: Optional[NodesInputMasks] = None
nodes_to_skip: set[str] = Field(default_factory=set)
"""Node IDs that should be skipped due to optional credentials not being configured."""
execution_context: ExecutionContext = Field(default_factory=ExecutionContext)
class NodeExecutionEntry(BaseModel):
model_config = {"extra": "ignore"}
user_id: str
graph_exec_id: str
graph_id: str
graph_version: int
node_exec_id: str
node_id: str
block_id: str
inputs: BlockInput
execution_context: ExecutionContext = Field(default_factory=ExecutionContext)
class ExecutionQueue(Generic[T]):
"""
Queue for managing the execution of agents.
This queue is shared between different processes.
"""
def __init__(self):
self.queue = Manager().Queue()
def add(self, execution: T) -> T:
self.queue.put(execution)
return execution
def get(self) -> T:
return self.queue.get()
def empty(self) -> bool:
return self.queue.empty()
def get_or_none(self) -> T | None:
try:
return self.queue.get_nowait()
except Empty:
return None
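# Example (illustrative):
#   queue: ExecutionQueue[GraphExecutionEntry] = ExecutionQueue()
#   queue.add(entry)     # producer process
#   entry = queue.get()  # consumer process; blocks until an item is available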
# --------------------- Event Bus --------------------- #
class ExecutionEventType(str, Enum):
GRAPH_EXEC_UPDATE = "graph_execution_update"
NODE_EXEC_UPDATE = "node_execution_update"
ERROR_COMMS_UPDATE = "error_comms_update"
class GraphExecutionEvent(GraphExecution):
event_type: Literal[ExecutionEventType.GRAPH_EXEC_UPDATE] = (
ExecutionEventType.GRAPH_EXEC_UPDATE
)
class NodeExecutionEvent(NodeExecutionResult):
event_type: Literal[ExecutionEventType.NODE_EXEC_UPDATE] = (
ExecutionEventType.NODE_EXEC_UPDATE
)
class SharedExecutionResponse(BaseModel):
"""Public-safe response for shared executions"""
id: str
graph_name: str
graph_description: Optional[str]
status: ExecutionStatus
created_at: datetime
outputs: CompletedBlockOutput # Only the final outputs, no intermediate data
# Deliberately exclude: user_id, inputs, credentials, node details
ExecutionEvent = Annotated[
GraphExecutionEvent | NodeExecutionEvent, Field(discriminator="event_type")
]
class RedisExecutionEventBus(RedisEventBus[ExecutionEvent]):
Model = ExecutionEvent # type: ignore
@property
def event_bus_name(self) -> str:
return config.execution_event_bus_name
def publish(self, res: GraphExecution | NodeExecutionResult):
if isinstance(res, GraphExecution):
self._publish_graph_exec_update(res)
else:
self._publish_node_exec_update(res)
def _publish_node_exec_update(self, res: NodeExecutionResult):
event = NodeExecutionEvent.model_validate(res.model_dump())
self._publish(event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}")
def _publish_graph_exec_update(self, res: GraphExecution):
event = GraphExecutionEvent.model_validate(res.model_dump())
self._publish(event, f"{res.user_id}/{res.graph_id}/{res.id}")
def _publish(self, event: ExecutionEvent, channel: str):
"""
Truncate inputs and outputs to avoid large payloads.
"""
limit = config.max_message_size_limit // 2
if isinstance(event, GraphExecutionEvent):
event.inputs = truncate(event.inputs, limit)
event.outputs = truncate(event.outputs, limit)
elif isinstance(event, NodeExecutionEvent):
event.input_data = truncate(event.input_data, limit)
event.output_data = truncate(event.output_data, limit)
super().publish_event(event, channel)
def listen(
self, user_id: str, graph_id: str = "*", graph_exec_id: str = "*"
) -> Generator[ExecutionEvent, None, None]:
for event in self.listen_events(f"{user_id}/{graph_id}/{graph_exec_id}"):
yield event
class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionEvent]):
Model = ExecutionEvent # type: ignore
@property
def event_bus_name(self) -> str:
return config.execution_event_bus_name
@func_retry
async def publish(self, res: GraphExecutionMeta | NodeExecutionResult):
if isinstance(res, GraphExecutionMeta):
await self._publish_graph_exec_update(res)
else:
await self._publish_node_exec_update(res)
async def _publish_node_exec_update(self, res: NodeExecutionResult):
event = NodeExecutionEvent.model_validate(res.model_dump())
await self._publish(event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}")
async def _publish_graph_exec_update(self, res: GraphExecutionMeta):
# GraphExecutionEvent requires inputs and outputs fields that GraphExecutionMeta doesn't have
# Add default empty values for compatibility
event_data = res.model_dump()
event_data.setdefault("inputs", {})
event_data.setdefault("outputs", {})
event = GraphExecutionEvent.model_validate(event_data)
await self._publish(event, f"{res.user_id}/{res.graph_id}/{res.id}")
async def _publish(self, event: ExecutionEvent, channel: str):
"""
Truncate inputs and outputs to avoid large payloads.
"""
limit = config.max_message_size_limit // 2
if isinstance(event, GraphExecutionEvent):
event.inputs = truncate(event.inputs, limit)
event.outputs = truncate(event.outputs, limit)
elif isinstance(event, NodeExecutionEvent):
event.input_data = truncate(event.input_data, limit)
event.output_data = truncate(event.output_data, limit)
await super().publish_event(event, channel)
async def listen(
self, user_id: str, graph_id: str = "*", graph_exec_id: str = "*"
) -> AsyncGenerator[ExecutionEvent, None]:
async for event in self.listen_events(f"{user_id}/{graph_id}/{graph_exec_id}"):
yield event
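# Example (illustrative): stream all execution events for a user:
#   bus = AsyncRedisExecutionEventBus()
#   async for event in bus.listen(user_id="user-1"):
#       ...  # GraphExecutionEvent or NodeExecutionEvent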
# --------------------- KV Data Functions --------------------- #
async def get_execution_kv_data(user_id: str, key: str) -> Any | None:
"""
Get key-value data for a user and key.
Args:
user_id: The id of the User.
key: The key to retrieve data for.
Returns:
The data associated with the key, or None if not found.
"""
kv_data = await AgentNodeExecutionKeyValueData.prisma().find_unique(
where={"userId_key": {"userId": user_id, "key": key}}
)
return (
type_utils.convert(kv_data.data, type[Any])
if kv_data and kv_data.data
else None
)
async def set_execution_kv_data(
user_id: str, node_exec_id: str, key: str, data: Any
) -> Any | None:
"""
Set key-value data for a user and key.
Args:
user_id: The id of the User.
node_exec_id: The id of the AgentNodeExecution.
key: The key to store data under.
data: The data to store.
"""
resp = await AgentNodeExecutionKeyValueData.prisma().upsert(
where={"userId_key": {"userId": user_id, "key": key}},
data={
"create": AgentNodeExecutionKeyValueDataCreateInput(
userId=user_id,
agentNodeExecutionId=node_exec_id,
key=key,
data=SafeJson(data) if data is not None else None,
),
"update": {
"agentNodeExecutionId": node_exec_id,
"data": SafeJson(data) if data is not None else None,
},
},
)
return type_utils.convert(resp.data, type[Any]) if resp and resp.data else None
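# Example (illustrative):
#   await set_execution_kv_data(user_id, node_exec_id, key="counter", data=1)
#   value = await get_execution_kv_data(user_id, "counter")  # -> 1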
async def get_block_error_stats(
start_time: datetime, end_time: datetime
) -> list[BlockErrorStats]:
"""Get block execution stats using efficient SQL aggregation."""
query_template = """
SELECT
n."agentBlockId" as block_id,
COUNT(*) as total_executions,
SUM(CASE WHEN ne."executionStatus" = 'FAILED' THEN 1 ELSE 0 END) as failed_executions
FROM {schema_prefix}"AgentNodeExecution" ne
JOIN {schema_prefix}"AgentNode" n ON ne."agentNodeId" = n.id
WHERE ne."addedTime" >= $1::timestamp AND ne."addedTime" <= $2::timestamp
GROUP BY n."agentBlockId"
HAVING COUNT(*) >= 10
"""
result = await query_raw_with_schema(query_template, start_time, end_time)
# Convert to typed data structures
return [
BlockErrorStats(
block_id=row["block_id"],
total_executions=int(row["total_executions"]),
failed_executions=int(row["failed_executions"]),
)
for row in result
]
async def update_graph_execution_share_status(
execution_id: str,
user_id: str,
is_shared: bool,
share_token: str | None,
shared_at: datetime | None,
) -> None:
"""Update the sharing status of a graph execution."""
await AgentGraphExecution.prisma().update(
where={"id": execution_id},
data={
"isShared": is_shared,
"shareToken": share_token,
"sharedAt": shared_at,
},
)
async def get_graph_execution_by_share_token(
share_token: str,
) -> SharedExecutionResponse | None:
"""Get a shared execution with limited public-safe data."""
execution = await AgentGraphExecution.prisma().find_first(
where={
"shareToken": share_token,
"isShared": True,
"isDeleted": False,
},
include={
"AgentGraph": True,
"NodeExecutions": {
"include": {
"Output": True,
"Node": {
"include": {
"AgentBlock": True,
}
},
},
},
},
)
if not execution:
return None
# Extract outputs from OUTPUT blocks only (consistent with GraphExecution.from_db)
outputs: CompletedBlockOutput = defaultdict(list)
if execution.NodeExecutions:
for node_exec in execution.NodeExecutions:
if node_exec.Node and node_exec.Node.agentBlockId:
# Get the block definition to check its type
block = get_block(node_exec.Node.agentBlockId)
if block and block.block_type == BlockType.OUTPUT:
# For OUTPUT blocks, the data is stored in executionData or Input
# The executionData contains the structured input with 'name' and 'value' fields
if hasattr(node_exec, "executionData") and node_exec.executionData:
exec_data = type_utils.convert(
node_exec.executionData, dict[str, Any]
)
if "name" in exec_data:
name = exec_data["name"]
value = exec_data.get("value")
outputs[name].append(value)
elif node_exec.Input:
# Build input_data from Input relation
input_data = {}
for data in node_exec.Input:
if data.name and data.data is not None:
input_data[data.name] = type_utils.convert(
data.data, JsonValue
)
if "name" in input_data:
name = input_data["name"]
value = input_data.get("value")
outputs[name].append(value)
return SharedExecutionResponse(
id=execution.id,
graph_name=(
execution.AgentGraph.name
if (execution.AgentGraph and execution.AgentGraph.name)
else "Untitled Agent"
),
graph_description=(
execution.AgentGraph.description if execution.AgentGraph else None
),
status=ExecutionStatus(execution.executionStatus),
created_at=execution.createdAt,
outputs=outputs,
)
async def get_frequently_executed_graphs(
days_back: int = 30,
min_executions: int = 10,
) -> list[dict]:
"""Get graphs that have been frequently executed for monitoring."""
query_template = """
SELECT DISTINCT
e."agentGraphId" as graph_id,
e."userId" as user_id,
COUNT(*) as execution_count
FROM {schema_prefix}"AgentGraphExecution" e
WHERE e."createdAt" >= $1::timestamp
AND e."isDeleted" = false
AND e."executionStatus" IN ('COMPLETED', 'FAILED', 'TERMINATED')
GROUP BY e."agentGraphId", e."userId"
HAVING COUNT(*) >= $2
ORDER BY execution_count DESC
"""
start_date = datetime.now(timezone.utc) - timedelta(days=days_back)
result = await query_raw_with_schema(query_template, start_date, min_executions)
return [
{
"graph_id": row["graph_id"],
"user_id": row["user_id"],
"execution_count": int(row["execution_count"]),
}
for row in result
]