Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-02-05 04:15:08 -05:00)
Implements persistent User Workspace storage for CoPilot, enabling
blocks to save and retrieve files across sessions. Files are stored in
session-scoped virtual paths (`/sessions/{session_id}/`).
Fixes SECRT-1833
### Changes 🏗️
**Database & Storage:**
- Add `UserWorkspace` and `UserWorkspaceFile` Prisma models
- Implement `WorkspaceStorageBackend` abstraction (GCS for cloud, local filesystem for self-hosted); see the interface sketch below
- Add `workspace_id` and `session_id` fields to `ExecutionContext`
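To make the storage abstraction concrete, here is a minimal sketch of what a backend could look like; the method names and signatures are assumptions for illustration, not the PR's exact API:

```python
from abc import ABC, abstractmethod


class WorkspaceStorageBackend(ABC):
    """Persists workspace files under virtual paths like /sessions/{session_id}/."""

    @abstractmethod
    async def store(self, workspace_id: str, path: str, content: bytes) -> str:
        """Save content and return a file ID referencable as workspace://{file_id}."""

    @abstractmethod
    async def retrieve(self, workspace_id: str, file_id: str) -> bytes:
        """Load a previously stored file's content."""

    @abstractmethod
    async def delete(self, workspace_id: str, file_id: str) -> None:
        """Remove a stored file."""
```

Cloud deployments back this with GCS; self-hosted deployments use a directory on the local filesystem.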
**Backend API:**
- Add REST endpoints: `GET/POST /api/workspace/files`, `GET/DELETE /api/workspace/files/{id}`, `GET /api/workspace/files/{id}/download`
- Add CoPilot tools: `list_workspace_files`, `read_workspace_file`, `write_workspace_file`
- Integrate workspace storage into `store_media_file()`, which now returns `workspace://file-id` references (illustrated below)
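For illustration, a media-generating block might persist its output roughly like this; the parameter names and the exact `return_format` value are assumptions for the sketch, not the PR's verbatim signature:

```python
# Sketch only: inside a block's run() method, with an ExecutionContext
# that carries workspace_id and session_id (set for CoPilot sessions).
media_ref = await store_media_file(
    file=generated_image,                 # hypothetical parameter name
    execution_context=execution_context,  # now required by the new API
    return_format="workspace",            # hypothetical value requesting a reference
)
# media_ref would be e.g. "workspace://file-id" instead of inline base64;
# outside CoPilot (no workspace), blocks fall back to data URIs.
```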
**Block Updates:**
- Refactor all file-handling blocks to use a unified `ExecutionContext` parameter
- Update media-generating blocks to persist outputs to the workspace (AIImageGenerator, AIImageCustomizer, FluxKontext, TalkingHead, FAL video, Bannerbear, etc.)
**Frontend:**
- Render `workspace://` image references in chat via proxy endpoint
- Add "AI cannot see this image" overlay indicator
**CoPilot Context Mapping:**
- Session = Agent (graph_id) = Run (graph_exec_id)
- Files scoped to `/sessions/{session_id}/`, as sketched below
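For example, a file's virtual path is derived from the session ID carried on the `ExecutionContext` (a sketch; `filename` is a placeholder):

```python
# Sketch: workspace files are addressed by session-scoped virtual paths.
virtual_path = f"/sessions/{execution_context.session_id}/{filename}"
```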
### Checklist 📋
#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [ ] I have tested my changes according to the test plan:
- [ ] Create CoPilot session, generate image with AIImageGeneratorBlock
- [ ] Verify image returns `workspace://file-id` (not base64)
- [ ] Verify image renders in chat with visibility indicator
- [ ] Verify workspace files persist across sessions
- [ ] Test list/read/write workspace files via CoPilot tools
- [ ] Test local storage backend for self-hosted deployments
#### For configuration changes:
- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my changes
- [x] I have included a list of my configuration changes in the PR description (under **Changes**)
🤖 Generated with [Claude Code](https://claude.ai/code)
<!-- CURSOR_SUMMARY -->
---
> [!NOTE]
> **Medium Risk**
> Introduces a new persistent file-storage surface area (DB tables, storage backends, download API, and chat tools) and rewires `store_media_file()`/block execution context across many blocks, so regressions could impact file handling, access control, or storage costs.
>
> **Overview**
> Adds a **persistent per-user Workspace** (new `UserWorkspace`/`UserWorkspaceFile` models plus `WorkspaceManager` + `WorkspaceStorageBackend` with GCS/local implementations) and wires it into the API via a new `/api/workspace/files/{file_id}/download` route (including header-sanitized `Content-Disposition`) and shutdown lifecycle hooks.
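>
> The header sanitization plausibly looks something like the following (an illustrative sketch, not the PR's actual helper):
>
> ```python
> import re
>
>
> def safe_content_disposition(filename: str) -> str:
>     # Strip CR/LF and quotes so a stored filename cannot inject
>     # additional headers or escape the quoted value.
>     cleaned = re.sub(r'[\r\n"]', "", filename)
>     return f'attachment; filename="{cleaned}"'
> ```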
>
> Extends `ExecutionContext` to carry execution identity + `workspace_id`/`session_id`, updates executor tooling to clone node-specific contexts, and updates `run_block` (CoPilot) to create a session-scoped workspace and synthetic graph/run/node IDs.
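>
> Since `ExecutionContext` is a Pydantic model, per-node cloning plausibly amounts to a `model_copy` with node-specific overrides (a sketch assuming this module's import path, not the verbatim executor code):
>
> ```python
> from backend.data.execution import ExecutionContext
>
> graph_context = ExecutionContext(user_id="user-1", graph_exec_id="run-1")
> node_context = graph_context.model_copy(
>     update={"node_id": "node-1", "node_exec_id": "node-exec-1"}
> )
> ```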
>
> Refactors `store_media_file()` to require `execution_context` + `return_format` and to support `workspace://` references; migrates many media/file-handling blocks and related tests to the new API and to persist generated media as `workspace://...` (or fall back to data URIs outside CoPilot), and adds CoPilot chat tools for listing/reading/writing/deleting workspace files with safeguards against context bloat.
>
> <sup>Written by [Cursor Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit 6abc70f793. This will update automatically on new commits. Configure [here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->
---------
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
1564 lines · 52 KiB · Python
import logging
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from enum import Enum
from multiprocessing import Manager
from queue import Empty
from typing import (
    TYPE_CHECKING,
    Annotated,
    Any,
    AsyncGenerator,
    Generator,
    Generic,
    Literal,
    Mapping,
    Optional,
    TypeVar,
    cast,
    overload,
)

from prisma import Json
from prisma.enums import AgentExecutionStatus
from prisma.models import (
    AgentGraphExecution,
    AgentNodeExecution,
    AgentNodeExecutionInputOutput,
    AgentNodeExecutionKeyValueData,
)
from prisma.types import (
    AgentGraphExecutionUpdateManyMutationInput,
    AgentGraphExecutionWhereInput,
    AgentNodeExecutionCreateInput,
    AgentNodeExecutionInputOutputCreateInput,
    AgentNodeExecutionKeyValueDataCreateInput,
    AgentNodeExecutionUpdateInput,
    AgentNodeExecutionWhereInput,
    AgentNodeExecutionWhereUniqueInput,
)
from pydantic import BaseModel, ConfigDict, JsonValue, ValidationError
from pydantic.fields import Field

from backend.util import type as type_utils
from backend.util.exceptions import DatabaseError
from backend.util.json import SafeJson
from backend.util.models import Pagination
from backend.util.retry import func_retry
from backend.util.settings import Config
from backend.util.truncate import truncate

from .block import (
    BlockInput,
    BlockType,
    CompletedBlockOutput,
    get_block,
    get_io_block_ids,
    get_webhook_block_ids,
)
from .db import BaseDbModel, query_raw_with_schema
from .event_bus import AsyncRedisEventBus, RedisEventBus
from .includes import (
    EXECUTION_RESULT_INCLUDE,
    EXECUTION_RESULT_ORDER,
    GRAPH_EXECUTION_INCLUDE_WITH_NODES,
    graph_execution_include,
)
from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats

if TYPE_CHECKING:
    pass

T = TypeVar("T")

logger = logging.getLogger(__name__)
config = Config()


class ExecutionContext(BaseModel):
    """
    Unified context that carries execution-level data throughout the entire execution flow.
    This includes information needed by blocks, sub-graphs, and execution management.
    """

    model_config = {"extra": "ignore"}

    # Execution identity
    user_id: Optional[str] = None
    graph_id: Optional[str] = None
    graph_exec_id: Optional[str] = None
    graph_version: Optional[int] = None
    node_id: Optional[str] = None
    node_exec_id: Optional[str] = None

    # Safety settings
    human_in_the_loop_safe_mode: bool = True
    sensitive_action_safe_mode: bool = False

    # User settings
    user_timezone: str = "UTC"

    # Execution hierarchy
    root_execution_id: Optional[str] = None
    parent_execution_id: Optional[str] = None

    # Workspace
    workspace_id: Optional[str] = None
    session_id: Optional[str] = None


# -------------------------- Models -------------------------- #


class BlockErrorStats(BaseModel):
    """Typed data structure for block error statistics."""

    block_id: str
    total_executions: int
    failed_executions: int

    @property
    def error_rate(self) -> float:
        """Calculate error rate as a percentage."""
        if self.total_executions == 0:
            return 0.0
        return (self.failed_executions / self.total_executions) * 100


ExecutionStatus = AgentExecutionStatus
NodeInputMask = Mapping[str, JsonValue]
NodesInputMasks = Mapping[str, NodeInputMask]

# dest: source
VALID_STATUS_TRANSITIONS = {
    ExecutionStatus.QUEUED: [
        ExecutionStatus.INCOMPLETE,
        ExecutionStatus.TERMINATED,  # For resuming halted execution
        ExecutionStatus.REVIEW,  # For resuming after review
    ],
    ExecutionStatus.RUNNING: [
        ExecutionStatus.INCOMPLETE,
        ExecutionStatus.QUEUED,
        ExecutionStatus.TERMINATED,  # For resuming halted execution
        ExecutionStatus.REVIEW,  # For resuming after review
    ],
    ExecutionStatus.COMPLETED: [
        ExecutionStatus.RUNNING,
    ],
    ExecutionStatus.FAILED: [
        ExecutionStatus.INCOMPLETE,
        ExecutionStatus.QUEUED,
        ExecutionStatus.RUNNING,
        ExecutionStatus.REVIEW,
    ],
    ExecutionStatus.TERMINATED: [
        ExecutionStatus.INCOMPLETE,
        ExecutionStatus.QUEUED,
        ExecutionStatus.RUNNING,
        ExecutionStatus.REVIEW,
    ],
    ExecutionStatus.REVIEW: [
        ExecutionStatus.RUNNING,
    ],
}


class GraphExecutionMeta(BaseDbModel):
    id: str  # type: ignore # Override base class to make this required
    user_id: str
    graph_id: str
    graph_version: int
    inputs: Optional[BlockInput]  # no default -> required in the OpenAPI spec
    credential_inputs: Optional[dict[str, CredentialsMetaInput]]
    nodes_input_masks: Optional[dict[str, BlockInput]]
    preset_id: Optional[str]
    status: ExecutionStatus
    started_at: Optional[datetime] = Field(
        None,
        description="When execution started running. Null if not yet started (QUEUED).",
    )
    ended_at: Optional[datetime] = Field(
        None,
        description="When execution finished. Null if not yet completed (QUEUED, RUNNING, INCOMPLETE, REVIEW).",
    )
    is_shared: bool = False
    share_token: Optional[str] = None

    class Stats(BaseModel):
        model_config = ConfigDict(
            extra="allow",
            arbitrary_types_allowed=True,
        )

        cost: int = Field(
            default=0,
            description="Execution cost (cents)",
        )
        duration: float = Field(
            default=0,
            description="Seconds from start to end of run",
        )
        duration_cpu_only: float = Field(
            default=0,
            description="CPU sec of duration",
        )
        node_exec_time: float = Field(
            default=0,
            description="Seconds of total node runtime",
        )
        node_exec_time_cpu_only: float = Field(
            default=0,
            description="CPU sec of node_exec_time",
        )
        node_exec_count: int = Field(
            default=0,
            description="Number of node executions",
        )
        node_error_count: int = Field(
            default=0,
            description="Number of node errors",
        )
        error: str | None = Field(
            default=None,
            description="Error message if any",
        )
        activity_status: str | None = Field(
            default=None,
            description="AI-generated summary of what the agent did",
        )
        correctness_score: float | None = Field(
            default=None,
            description="AI-generated score (0.0-1.0) indicating how well the execution achieved its intended purpose",
        )

        def to_db(self) -> GraphExecutionStats:
            return GraphExecutionStats(
                cost=self.cost,
                walltime=self.duration,
                cputime=self.duration_cpu_only,
                nodes_walltime=self.node_exec_time,
                nodes_cputime=self.node_exec_time_cpu_only,
                node_count=self.node_exec_count,
                node_error_count=self.node_error_count,
                error=self.error,
                activity_status=self.activity_status,
                correctness_score=self.correctness_score,
            )

        def without_activity_features(self) -> "GraphExecutionMeta.Stats":
            """Return a copy of stats with activity features (activity_status, correctness_score) set to None."""
            return self.model_copy(
                update={"activity_status": None, "correctness_score": None}
            )

    stats: Stats | None

    @staticmethod
    def from_db(_graph_exec: AgentGraphExecution):
        start_time = _graph_exec.startedAt
        end_time = _graph_exec.endedAt

        try:
            stats = GraphExecutionStats.model_validate(_graph_exec.stats)
        except ValueError as e:
            if _graph_exec.stats is not None:
                logger.warning(
                    "Failed to parse invalid graph execution stats "
                    f"{_graph_exec.stats}: {e}"
                )
            stats = None

        return GraphExecutionMeta(
            id=_graph_exec.id,
            user_id=_graph_exec.userId,
            graph_id=_graph_exec.agentGraphId,
            graph_version=_graph_exec.agentGraphVersion,
            inputs=cast(BlockInput | None, _graph_exec.inputs),
            credential_inputs=(
                {
                    name: CredentialsMetaInput.model_validate(cmi)
                    for name, cmi in cast(dict, _graph_exec.credentialInputs).items()
                }
                if _graph_exec.credentialInputs
                else None
            ),
            nodes_input_masks=cast(
                dict[str, BlockInput] | None, _graph_exec.nodesInputMasks
            ),
            preset_id=_graph_exec.agentPresetId,
            status=ExecutionStatus(_graph_exec.executionStatus),
            started_at=start_time,
            ended_at=end_time,
            stats=(
                GraphExecutionMeta.Stats(
                    cost=stats.cost,
                    duration=stats.walltime,
                    duration_cpu_only=stats.cputime,
                    node_exec_time=stats.nodes_walltime,
                    node_exec_time_cpu_only=stats.nodes_cputime,
                    node_exec_count=stats.node_count,
                    node_error_count=stats.node_error_count,
                    error=(
                        str(stats.error)
                        if isinstance(stats.error, Exception)
                        else stats.error
                    ),
                    activity_status=stats.activity_status,
                    correctness_score=stats.correctness_score,
                )
                if stats
                else None
            ),
            is_shared=_graph_exec.isShared,
            share_token=_graph_exec.shareToken,
        )


class GraphExecution(GraphExecutionMeta):
    inputs: BlockInput  # type: ignore - incompatible override is intentional
    outputs: CompletedBlockOutput

    @staticmethod
    def from_db(_graph_exec: AgentGraphExecution):
        if _graph_exec.NodeExecutions is None:
            raise ValueError("Node executions must be included in query")

        graph_exec = GraphExecutionMeta.from_db(_graph_exec)

        complete_node_executions = sorted(
            [
                NodeExecutionResult.from_db(ne, _graph_exec.userId)
                for ne in _graph_exec.NodeExecutions
                if ne.executionStatus != ExecutionStatus.INCOMPLETE
            ],
            key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
        )

        inputs = {
            **(
                graph_exec.inputs
                or {
                    # fallback: extract inputs from Agent Input Blocks
                    exec.input_data["name"]: exec.input_data.get("value")
                    for exec in complete_node_executions
                    if (
                        (block := get_block(exec.block_id))
                        and block.block_type == BlockType.INPUT
                    )
                }
            ),
            **{
                # input from webhook-triggered block
                "payload": exec.input_data["payload"]
                for exec in complete_node_executions
                if (
                    (block := get_block(exec.block_id))
                    and block.block_type
                    in [BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL]
                )
            },
        }

        outputs: CompletedBlockOutput = defaultdict(list)
        for exec in complete_node_executions:
            if (
                block := get_block(exec.block_id)
            ) and block.block_type == BlockType.OUTPUT:
                outputs[exec.input_data["name"]].append(exec.input_data.get("value"))

        return GraphExecution(
            **{
                field_name: getattr(graph_exec, field_name)
                for field_name in GraphExecutionMeta.model_fields
                if field_name != "inputs"
            },
            inputs=inputs,
            outputs=outputs,
        )


class GraphExecutionWithNodes(GraphExecution):
    node_executions: list["NodeExecutionResult"]

    @staticmethod
    def from_db(_graph_exec: AgentGraphExecution):
        if _graph_exec.NodeExecutions is None:
            raise ValueError("Node executions must be included in query")

        graph_exec_with_io = GraphExecution.from_db(_graph_exec)

        node_executions = sorted(
            [
                NodeExecutionResult.from_db(ne, _graph_exec.userId)
                for ne in _graph_exec.NodeExecutions
            ],
            key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
        )

        return GraphExecutionWithNodes(
            **{
                field_name: getattr(graph_exec_with_io, field_name)
                for field_name in GraphExecution.model_fields
            },
            node_executions=node_executions,
        )

    def to_graph_execution_entry(
        self,
        execution_context: ExecutionContext,
        compiled_nodes_input_masks: Optional[NodesInputMasks] = None,
        nodes_to_skip: Optional[set[str]] = None,
    ):
        return GraphExecutionEntry(
            user_id=self.user_id,
            graph_id=self.graph_id,
            graph_version=self.graph_version or 0,
            graph_exec_id=self.id,
            nodes_input_masks=compiled_nodes_input_masks,
            nodes_to_skip=nodes_to_skip or set(),
            execution_context=execution_context,
        )


class NodeExecutionResult(BaseModel):
    user_id: str
    graph_id: str
    graph_version: int
    graph_exec_id: str
    node_exec_id: str
    node_id: str
    block_id: str
    status: ExecutionStatus
    input_data: BlockInput
    output_data: CompletedBlockOutput
    add_time: datetime
    queue_time: datetime | None
    start_time: datetime | None
    end_time: datetime | None

    @staticmethod
    def from_db(_node_exec: AgentNodeExecution, user_id: Optional[str] = None):
        try:
            stats = NodeExecutionStats.model_validate(_node_exec.stats or {})
        except (ValueError, ValidationError):
            stats = NodeExecutionStats()

        if stats.cleared_inputs:
            input_data: BlockInput = defaultdict()
            for name, messages in stats.cleared_inputs.items():
                input_data[name] = messages[-1] if messages else ""
        elif _node_exec.executionData:
            input_data = type_utils.convert(_node_exec.executionData, dict[str, Any])
        else:
            input_data: BlockInput = defaultdict()
            for data in _node_exec.Input or []:
                input_data[data.name] = type_utils.convert(data.data, JsonValue)

        output_data: CompletedBlockOutput = defaultdict(list)

        if stats.cleared_outputs:
            for name, messages in stats.cleared_outputs.items():
                output_data[name].extend(messages)
        else:
            for data in _node_exec.Output or []:
                output_data[data.name].append(type_utils.convert(data.data, JsonValue))

        graph_execution: AgentGraphExecution | None = _node_exec.GraphExecution
        if graph_execution:
            user_id = graph_execution.userId
        elif not user_id:
            raise ValueError(
                "AgentGraphExecution must be included or user_id passed in"
            )

        return NodeExecutionResult(
            user_id=user_id,
            graph_id=graph_execution.agentGraphId if graph_execution else "",
            graph_version=graph_execution.agentGraphVersion if graph_execution else 0,
            graph_exec_id=_node_exec.agentGraphExecutionId,
            block_id=_node_exec.Node.agentBlockId if _node_exec.Node else "",
            node_exec_id=_node_exec.id,
            node_id=_node_exec.agentNodeId,
            status=_node_exec.executionStatus,
            input_data=input_data,
            output_data=output_data,
            add_time=_node_exec.addedTime,
            queue_time=_node_exec.queuedTime,
            start_time=_node_exec.startedTime,
            end_time=_node_exec.endedTime,
        )

    def to_node_execution_entry(
        self, execution_context: ExecutionContext
    ) -> "NodeExecutionEntry":
        return NodeExecutionEntry(
            user_id=self.user_id,
            graph_exec_id=self.graph_exec_id,
            graph_id=self.graph_id,
            graph_version=self.graph_version,
            node_exec_id=self.node_exec_id,
            node_id=self.node_id,
            block_id=self.block_id,
            inputs=self.input_data,
            execution_context=execution_context,
        )


# --------------------- Model functions --------------------- #


async def get_graph_executions(
    graph_exec_id: Optional[str] = None,
    graph_id: Optional[str] = None,
    graph_version: Optional[int] = None,
    user_id: Optional[str] = None,
    statuses: Optional[list[ExecutionStatus]] = None,
    created_time_gte: Optional[datetime] = None,
    created_time_lte: Optional[datetime] = None,
    limit: Optional[int] = None,
) -> list[GraphExecutionMeta]:
    """⚠️ **Optional `user_id` check**: MUST USE check in user-facing endpoints."""
    where_filter: AgentGraphExecutionWhereInput = {
        "isDeleted": False,
    }
    if graph_exec_id:
        where_filter["id"] = graph_exec_id
    if user_id:
        where_filter["userId"] = user_id
    if graph_id:
        where_filter["agentGraphId"] = graph_id
    if graph_version is not None:
        where_filter["agentGraphVersion"] = graph_version
    if created_time_gte or created_time_lte:
        where_filter["createdAt"] = {
            "gte": created_time_gte or datetime.min.replace(tzinfo=timezone.utc),
            "lte": created_time_lte or datetime.max.replace(tzinfo=timezone.utc),
        }
    if statuses:
        where_filter["OR"] = [{"executionStatus": status} for status in statuses]

    executions = await AgentGraphExecution.prisma().find_many(
        where=where_filter,
        order={"createdAt": "desc"},
        take=limit,
    )
    return [GraphExecutionMeta.from_db(execution) for execution in executions]


async def get_graph_executions_count(
    user_id: Optional[str] = None,
    graph_id: Optional[str] = None,
    statuses: Optional[list[ExecutionStatus]] = None,
    created_time_gte: Optional[datetime] = None,
    created_time_lte: Optional[datetime] = None,
) -> int:
    """
    Get count of graph executions with optional filters.

    Args:
        user_id: Optional user ID to filter by
        graph_id: Optional graph ID to filter by
        statuses: Optional list of execution statuses to filter by
        created_time_gte: Optional minimum creation time
        created_time_lte: Optional maximum creation time

    Returns:
        Count of matching graph executions
    """
    where_filter: AgentGraphExecutionWhereInput = {
        "isDeleted": False,
    }

    if user_id:
        where_filter["userId"] = user_id

    if graph_id:
        where_filter["agentGraphId"] = graph_id

    if created_time_gte or created_time_lte:
        where_filter["createdAt"] = {
            "gte": created_time_gte or datetime.min.replace(tzinfo=timezone.utc),
            "lte": created_time_lte or datetime.max.replace(tzinfo=timezone.utc),
        }
    if statuses:
        where_filter["OR"] = [{"executionStatus": status} for status in statuses]

    count = await AgentGraphExecution.prisma().count(where=where_filter)
    return count


class GraphExecutionsPaginated(BaseModel):
    """Response schema for paginated graph executions."""

    executions: list[GraphExecutionMeta]
    pagination: Pagination


async def get_graph_executions_paginated(
    user_id: str,
    graph_id: Optional[str] = None,
    page: int = 1,
    page_size: int = 25,
    statuses: Optional[list[ExecutionStatus]] = None,
    created_time_gte: Optional[datetime] = None,
    created_time_lte: Optional[datetime] = None,
) -> GraphExecutionsPaginated:
    """Get paginated graph executions for a specific graph."""
    where_filter: AgentGraphExecutionWhereInput = {
        "isDeleted": False,
        "userId": user_id,
    }

    if graph_id:
        where_filter["agentGraphId"] = graph_id
    if created_time_gte or created_time_lte:
        where_filter["createdAt"] = {
            "gte": created_time_gte or datetime.min.replace(tzinfo=timezone.utc),
            "lte": created_time_lte or datetime.max.replace(tzinfo=timezone.utc),
        }
    if statuses:
        where_filter["OR"] = [{"executionStatus": status} for status in statuses]

    total_count = await AgentGraphExecution.prisma().count(where=where_filter)
    total_pages = (total_count + page_size - 1) // page_size

    offset = (page - 1) * page_size
    executions = await AgentGraphExecution.prisma().find_many(
        where=where_filter,
        order={"createdAt": "desc"},
        take=page_size,
        skip=offset,
    )

    return GraphExecutionsPaginated(
        executions=[GraphExecutionMeta.from_db(execution) for execution in executions],
        pagination=Pagination(
            total_items=total_count,
            total_pages=total_pages,
            current_page=page,
            page_size=page_size,
        ),
    )


async def get_graph_execution_meta(
    user_id: str, execution_id: str
) -> GraphExecutionMeta | None:
    execution = await AgentGraphExecution.prisma().find_first(
        where={"id": execution_id, "isDeleted": False, "userId": user_id}
    )
    return GraphExecutionMeta.from_db(execution) if execution else None


@overload
async def get_graph_execution(
    user_id: str,
    execution_id: str,
    include_node_executions: Literal[True],
) -> GraphExecutionWithNodes | None: ...


@overload
async def get_graph_execution(
    user_id: str,
    execution_id: str,
    include_node_executions: Literal[False] = False,
) -> GraphExecution | None: ...


@overload
async def get_graph_execution(
    user_id: str,
    execution_id: str,
    include_node_executions: bool = False,
) -> GraphExecution | GraphExecutionWithNodes | None: ...


async def get_graph_execution(
    user_id: str,
    execution_id: str,
    include_node_executions: bool = False,
) -> GraphExecution | GraphExecutionWithNodes | None:
    execution = await AgentGraphExecution.prisma().find_first(
        where={"id": execution_id, "isDeleted": False, "userId": user_id},
        include=(
            GRAPH_EXECUTION_INCLUDE_WITH_NODES
            if include_node_executions
            else graph_execution_include(
                [*get_io_block_ids(), *get_webhook_block_ids()]
            )
        ),
    )
    if not execution:
        return None

    return (
        GraphExecutionWithNodes.from_db(execution)
        if include_node_executions
        else GraphExecution.from_db(execution)
    )


async def get_child_graph_executions(
    parent_exec_id: str,
) -> list[GraphExecutionMeta]:
    """
    Get all child executions of a parent execution.

    Args:
        parent_exec_id: Parent graph execution ID

    Returns:
        List of child graph executions
    """
    children = await AgentGraphExecution.prisma().find_many(
        where={"parentGraphExecutionId": parent_exec_id, "isDeleted": False}
    )

    return [GraphExecutionMeta.from_db(child) for child in children]


async def create_graph_execution(
    graph_id: str,
    graph_version: int,
    starting_nodes_input: list[tuple[str, BlockInput]],  # list[(node_id, BlockInput)]
    inputs: Mapping[str, JsonValue],
    user_id: str,
    preset_id: Optional[str] = None,
    credential_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
    nodes_input_masks: Optional[NodesInputMasks] = None,
    parent_graph_exec_id: Optional[str] = None,
) -> GraphExecutionWithNodes:
    """
    Create a new AgentGraphExecution record.

    Returns:
        The created graph execution, including the node executions created
        for each starting node.
    """
    result = await AgentGraphExecution.prisma().create(
        data={
            "agentGraphId": graph_id,
            "agentGraphVersion": graph_version,
            "executionStatus": ExecutionStatus.INCOMPLETE,
            "inputs": SafeJson(inputs),
            "credentialInputs": (
                SafeJson(credential_inputs) if credential_inputs else Json({})
            ),
            "nodesInputMasks": (
                SafeJson(nodes_input_masks) if nodes_input_masks else Json({})
            ),
            "NodeExecutions": {
                "create": [
                    AgentNodeExecutionCreateInput(
                        agentNodeId=node_id,
                        executionStatus=ExecutionStatus.QUEUED,
                        queuedTime=datetime.now(tz=timezone.utc),
                        Input={
                            "create": [
                                {"name": name, "data": SafeJson(data)}
                                for name, data in node_input.items()
                            ]
                        },
                    )
                    for node_id, node_input in starting_nodes_input
                ]
            },
            "userId": user_id,
            "agentPresetId": preset_id,
            "parentGraphExecutionId": parent_graph_exec_id,
        },
        include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
    )

    return GraphExecutionWithNodes.from_db(result)


async def upsert_execution_input(
    node_id: str,
    graph_exec_id: str,
    input_name: str,
    input_data: JsonValue,
    node_exec_id: str | None = None,
) -> tuple[NodeExecutionResult, BlockInput]:
    """
    Insert an AgentNodeExecutionInputOutput record as one of AgentNodeExecution.Input.
    If no incomplete AgentNodeExecution is still missing `input_name` as an input,
    create a new one.

    Args:
        node_id: The id of the AgentNode.
        graph_exec_id: The id of the AgentGraphExecution.
        input_name: The name of the input data.
        input_data: The input data to be inserted.
        node_exec_id: [Optional] The id of an AgentNodeExecution that does not yet have `input_name` as input. If not provided, the eligible incomplete AgentNodeExecution is found or a new one is created.

    Returns:
        NodeExecutionResult: The created or existing node execution.
        BlockInput: Node input data; key is the input name, value is the input data.
    """
    existing_exec_query_filter: AgentNodeExecutionWhereInput = {
        "agentGraphExecutionId": graph_exec_id,
        "agentNodeId": node_id,
        "executionStatus": ExecutionStatus.INCOMPLETE,
        "Input": {
            "none": {
                "name": input_name,
                "time": {"gte": datetime.now(tz=timezone.utc) - timedelta(days=1)},
            }
        },
    }
    if node_exec_id:
        existing_exec_query_filter["id"] = node_exec_id

    existing_execution = await AgentNodeExecution.prisma().find_first(
        where=existing_exec_query_filter,
        order={"addedTime": "asc"},
        include={"Input": True, "GraphExecution": True},
    )
    json_input_data = SafeJson(input_data)

    if existing_execution:
        await AgentNodeExecutionInputOutput.prisma().create(
            data=AgentNodeExecutionInputOutputCreateInput(
                name=input_name,
                data=json_input_data,
                referencedByInputExecId=existing_execution.id,
            )
        )
        return NodeExecutionResult.from_db(existing_execution), {
            # The comprehension-local `input_data` is scoped to the
            # comprehension; the outer key uses the function parameter.
            **{
                input_data.name: type_utils.convert(input_data.data, JsonValue)
                for input_data in existing_execution.Input or []
            },
            input_name: input_data,
        }

    elif not node_exec_id:
        result = await AgentNodeExecution.prisma().create(
            data=AgentNodeExecutionCreateInput(
                agentNodeId=node_id,
                agentGraphExecutionId=graph_exec_id,
                executionStatus=ExecutionStatus.INCOMPLETE,
                Input={"create": {"name": input_name, "data": json_input_data}},
            ),
            include={"GraphExecution": True},
        )
        return NodeExecutionResult.from_db(result), {input_name: input_data}

    else:
        raise ValueError(
            f"NodeExecution {node_exec_id} not found or already has input {input_name}."
        )


async def upsert_execution_output(
    node_exec_id: str,
    output_name: str,
    output_data: Any | None,
) -> None:
    """
    Insert an AgentNodeExecutionInputOutput record as one of AgentNodeExecution.Output.
    """
    data: AgentNodeExecutionInputOutputCreateInput = {
        "name": output_name,
        "referencedByOutputExecId": node_exec_id,
    }
    if output_data is not None:
        data["data"] = SafeJson(output_data)
    await AgentNodeExecutionInputOutput.prisma().create(data=data)


async def get_execution_outputs_by_node_exec_id(
    node_exec_id: str,
) -> dict[str, Any]:
    """
    Get all execution outputs for a specific node execution ID.

    Args:
        node_exec_id: The node execution ID to get outputs for

    Returns:
        Dictionary mapping output names to their data values
    """
    outputs = await AgentNodeExecutionInputOutput.prisma().find_many(
        where={"referencedByOutputExecId": node_exec_id}
    )

    result = {}
    for output in outputs:
        if output.data is not None:
            result[output.name] = type_utils.convert(output.data, JsonValue)

    return result


async def update_graph_execution_start_time(
    graph_exec_id: str,
) -> GraphExecution | None:
    res = await AgentGraphExecution.prisma().update(
        where={"id": graph_exec_id},
        data={
            "executionStatus": ExecutionStatus.RUNNING,
            "startedAt": datetime.now(tz=timezone.utc),
        },
        include=graph_execution_include(
            [*get_io_block_ids(), *get_webhook_block_ids()]
        ),
    )
    return GraphExecution.from_db(res) if res else None


async def update_graph_execution_stats(
    graph_exec_id: str,
    status: ExecutionStatus | None = None,
    stats: GraphExecutionStats | None = None,
) -> GraphExecution | None:
    if not status and not stats:
        raise ValueError(
            f"Must provide either status or stats to update for execution {graph_exec_id}"
        )

    update_data: AgentGraphExecutionUpdateManyMutationInput = {}

    if stats:
        stats_dict = stats.model_dump()
        if isinstance(stats_dict.get("error"), Exception):
            stats_dict["error"] = str(stats_dict["error"])
        update_data["stats"] = SafeJson(stats_dict)

    if status:
        update_data["executionStatus"] = status
        # Set endedAt when execution reaches a terminal status
        terminal_statuses = [
            ExecutionStatus.COMPLETED,
            ExecutionStatus.FAILED,
            ExecutionStatus.TERMINATED,
        ]
        if status in terminal_statuses:
            update_data["endedAt"] = datetime.now(tz=timezone.utc)

    where_clause: AgentGraphExecutionWhereInput = {"id": graph_exec_id}

    if status:
        if allowed_from := VALID_STATUS_TRANSITIONS.get(status, []):
            # Add OR clause to check if current status is one of the allowed source statuses
            where_clause["AND"] = [
                {"id": graph_exec_id},
                {"OR": [{"executionStatus": s} for s in allowed_from]},
            ]
        else:
            raise ValueError(
                f"Status {status} cannot be set via update for execution {graph_exec_id}. "
                f"This status can only be set at creation or is not a valid target status."
            )

    await AgentGraphExecution.prisma().update_many(
        where=where_clause,
        data=update_data,
    )

    graph_exec = await AgentGraphExecution.prisma().find_unique_or_raise(
        where={"id": graph_exec_id},
        include=graph_execution_include(
            [*get_io_block_ids(), *get_webhook_block_ids()]
        ),
    )

    return GraphExecution.from_db(graph_exec)


async def update_node_execution_status_batch(
    node_exec_ids: list[str],
    status: ExecutionStatus,
    stats: dict[str, Any] | None = None,
) -> int:
    # Validate status transitions - allowed_from should never be empty for valid statuses
    allowed_from = VALID_STATUS_TRANSITIONS.get(status, [])
    if not allowed_from:
        raise ValueError(
            f"Invalid status transition: {status} has no valid source statuses"
        )

    # For batch updates, we filter to only update nodes with valid current statuses
    where_clause = cast(
        AgentNodeExecutionWhereInput,
        {
            "id": {"in": node_exec_ids},
            "executionStatus": {"in": [s.value for s in allowed_from]},
        },
    )

    return await AgentNodeExecution.prisma().update_many(
        where=where_clause,
        data=_get_update_status_data(status, None, stats),
    )


async def update_node_execution_status(
    node_exec_id: str,
    status: ExecutionStatus,
    execution_data: BlockInput | None = None,
    stats: dict[str, Any] | None = None,
) -> NodeExecutionResult:
    if status == ExecutionStatus.QUEUED and execution_data is None:
        raise ValueError("Execution data must be provided when queuing an execution.")

    # Validate status transitions - allowed_from should never be empty for valid statuses
    allowed_from = VALID_STATUS_TRANSITIONS.get(status, [])
    if not allowed_from:
        raise ValueError(
            f"Invalid status transition: {status} has no valid source statuses"
        )

    if res := await AgentNodeExecution.prisma().update(
        where=cast(
            AgentNodeExecutionWhereUniqueInput,
            {
                "id": node_exec_id,
                "executionStatus": {"in": [s.value for s in allowed_from]},
            },
        ),
        data=_get_update_status_data(status, execution_data, stats),
        include=EXECUTION_RESULT_INCLUDE,
    ):
        return NodeExecutionResult.from_db(res)

    if res := await AgentNodeExecution.prisma().find_unique(
        where={"id": node_exec_id}, include=EXECUTION_RESULT_INCLUDE
    ):
        return NodeExecutionResult.from_db(res)

    raise ValueError(f"Execution {node_exec_id} not found.")


def _get_update_status_data(
    status: ExecutionStatus,
    execution_data: BlockInput | None = None,
    stats: dict[str, Any] | None = None,
) -> AgentNodeExecutionUpdateInput:
    now = datetime.now(tz=timezone.utc)
    update_data: AgentNodeExecutionUpdateInput = {"executionStatus": status}

    if status == ExecutionStatus.QUEUED:
        update_data["queuedTime"] = now
    elif status == ExecutionStatus.RUNNING:
        update_data["startedTime"] = now
    elif status in (ExecutionStatus.FAILED, ExecutionStatus.COMPLETED):
        update_data["endedTime"] = now

    if execution_data:
        update_data["executionData"] = SafeJson(execution_data)
    if stats:
        update_data["stats"] = SafeJson(stats)

    return update_data


async def delete_graph_execution(
    graph_exec_id: str, user_id: str, soft_delete: bool = True
) -> None:
    if soft_delete:
        deleted_count = await AgentGraphExecution.prisma().update_many(
            where={"id": graph_exec_id, "userId": user_id}, data={"isDeleted": True}
        )
    else:
        deleted_count = await AgentGraphExecution.prisma().delete_many(
            where={"id": graph_exec_id, "userId": user_id}
        )
    if deleted_count < 1:
        raise DatabaseError(
            f"Could not delete graph execution #{graph_exec_id}: not found"
        )


async def get_node_execution(node_exec_id: str) -> NodeExecutionResult | None:
    """⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
    execution = await AgentNodeExecution.prisma().find_first(
        where={"id": node_exec_id},
        include=EXECUTION_RESULT_INCLUDE,
    )
    if not execution:
        return None
    return NodeExecutionResult.from_db(execution)


def _build_node_execution_where_clause(
    graph_exec_id: str | None = None,
    node_id: str | None = None,
    block_ids: list[str] | None = None,
    statuses: list[ExecutionStatus] | None = None,
    created_time_gte: datetime | None = None,
    created_time_lte: datetime | None = None,
) -> AgentNodeExecutionWhereInput:
    """
    Build where clause for node execution queries.
    """
    where_clause: AgentNodeExecutionWhereInput = {}
    if graph_exec_id:
        where_clause["agentGraphExecutionId"] = graph_exec_id
    if node_id:
        where_clause["agentNodeId"] = node_id
    if block_ids:
        where_clause["Node"] = {"is": {"agentBlockId": {"in": block_ids}}}
    if statuses:
        where_clause["OR"] = [{"executionStatus": status} for status in statuses]

    if created_time_gte or created_time_lte:
        where_clause["addedTime"] = {
            "gte": created_time_gte or datetime.min.replace(tzinfo=timezone.utc),
            "lte": created_time_lte or datetime.max.replace(tzinfo=timezone.utc),
        }

    return where_clause


async def get_node_executions(
    graph_exec_id: str | None = None,
    node_id: str | None = None,
    block_ids: list[str] | None = None,
    statuses: list[ExecutionStatus] | None = None,
    limit: int | None = None,
    created_time_gte: datetime | None = None,
    created_time_lte: datetime | None = None,
    include_exec_data: bool = True,
) -> list[NodeExecutionResult]:
    """⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
    where_clause = _build_node_execution_where_clause(
        graph_exec_id=graph_exec_id,
        node_id=node_id,
        block_ids=block_ids,
        statuses=statuses,
        created_time_gte=created_time_gte,
        created_time_lte=created_time_lte,
    )

    executions = await AgentNodeExecution.prisma().find_many(
        where=where_clause,
        include=(
            EXECUTION_RESULT_INCLUDE
            if include_exec_data
            else {"Node": True, "GraphExecution": True}
        ),
        order=EXECUTION_RESULT_ORDER,
        take=limit,
    )
    res = [NodeExecutionResult.from_db(execution) for execution in executions]
    return res


async def get_latest_node_execution(
    node_id: str, graph_eid: str
) -> NodeExecutionResult | None:
    """⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
    execution = await AgentNodeExecution.prisma().find_first(
        where={
            "agentGraphExecutionId": graph_eid,
            "agentNodeId": node_id,
            "OR": [
                {"executionStatus": ExecutionStatus.QUEUED},
                {"executionStatus": ExecutionStatus.RUNNING},
                {"executionStatus": ExecutionStatus.COMPLETED},
                {"executionStatus": ExecutionStatus.TERMINATED},
                {"executionStatus": ExecutionStatus.FAILED},
            ],
        },
        include=EXECUTION_RESULT_INCLUDE,
        order=EXECUTION_RESULT_ORDER,
    )
    if not execution:
        return None
    return NodeExecutionResult.from_db(execution)


# ----------------- Execution Infrastructure ----------------- #


class GraphExecutionEntry(BaseModel):
    model_config = {"extra": "ignore"}

    user_id: str
    graph_exec_id: str
    graph_id: str
    graph_version: int
    nodes_input_masks: Optional[NodesInputMasks] = None
    nodes_to_skip: set[str] = Field(default_factory=set)
    """Node IDs that should be skipped due to optional credentials not being configured."""
    execution_context: ExecutionContext = Field(default_factory=ExecutionContext)


class NodeExecutionEntry(BaseModel):
    model_config = {"extra": "ignore"}

    user_id: str
    graph_exec_id: str
    graph_id: str
    graph_version: int
    node_exec_id: str
    node_id: str
    block_id: str
    inputs: BlockInput
    execution_context: ExecutionContext = Field(default_factory=ExecutionContext)


class ExecutionQueue(Generic[T]):
    """
    Queue for managing the execution of agents.
    This will be shared between different processes.
    """

    def __init__(self):
        self.queue = Manager().Queue()

    def add(self, execution: T) -> T:
        self.queue.put(execution)
        return execution

    def get(self) -> T:
        return self.queue.get()

    def empty(self) -> bool:
        return self.queue.empty()

    def get_or_none(self) -> T | None:
        try:
            return self.queue.get_nowait()
        except Empty:
            return None


# --------------------- Event Bus --------------------- #


class ExecutionEventType(str, Enum):
    GRAPH_EXEC_UPDATE = "graph_execution_update"
    NODE_EXEC_UPDATE = "node_execution_update"
    ERROR_COMMS_UPDATE = "error_comms_update"


class GraphExecutionEvent(GraphExecution):
    event_type: Literal[ExecutionEventType.GRAPH_EXEC_UPDATE] = (
        ExecutionEventType.GRAPH_EXEC_UPDATE
    )


class NodeExecutionEvent(NodeExecutionResult):
    event_type: Literal[ExecutionEventType.NODE_EXEC_UPDATE] = (
        ExecutionEventType.NODE_EXEC_UPDATE
    )


class SharedExecutionResponse(BaseModel):
    """Public-safe response for shared executions"""

    id: str
    graph_name: str
    graph_description: Optional[str]
    status: ExecutionStatus
    created_at: datetime
    outputs: CompletedBlockOutput  # Only the final outputs, no intermediate data
    # Deliberately exclude: user_id, inputs, credentials, node details


ExecutionEvent = Annotated[
    GraphExecutionEvent | NodeExecutionEvent, Field(discriminator="event_type")
]


class RedisExecutionEventBus(RedisEventBus[ExecutionEvent]):
    Model = ExecutionEvent  # type: ignore

    @property
    def event_bus_name(self) -> str:
        return config.execution_event_bus_name

    def publish(self, res: GraphExecution | NodeExecutionResult):
        if isinstance(res, GraphExecution):
            self._publish_graph_exec_update(res)
        else:
            self._publish_node_exec_update(res)

    def _publish_node_exec_update(self, res: NodeExecutionResult):
        event = NodeExecutionEvent.model_validate(res.model_dump())
        self._publish(event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}")

    def _publish_graph_exec_update(self, res: GraphExecution):
        event = GraphExecutionEvent.model_validate(res.model_dump())
        self._publish(event, f"{res.user_id}/{res.graph_id}/{res.id}")

    def _publish(self, event: ExecutionEvent, channel: str):
        """Truncate inputs and outputs to avoid large payloads."""
        limit = config.max_message_size_limit // 2
        if isinstance(event, GraphExecutionEvent):
            event.inputs = truncate(event.inputs, limit)
            event.outputs = truncate(event.outputs, limit)
        elif isinstance(event, NodeExecutionEvent):
            event.input_data = truncate(event.input_data, limit)
            event.output_data = truncate(event.output_data, limit)

        super().publish_event(event, channel)

    def listen(
        self, user_id: str, graph_id: str = "*", graph_exec_id: str = "*"
    ) -> Generator[ExecutionEvent, None, None]:
        for event in self.listen_events(f"{user_id}/{graph_id}/{graph_exec_id}"):
            yield event


class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionEvent]):
    Model = ExecutionEvent  # type: ignore

    @property
    def event_bus_name(self) -> str:
        return config.execution_event_bus_name

    @func_retry
    async def publish(self, res: GraphExecutionMeta | NodeExecutionResult):
        if isinstance(res, GraphExecutionMeta):
            await self._publish_graph_exec_update(res)
        else:
            await self._publish_node_exec_update(res)

    async def _publish_node_exec_update(self, res: NodeExecutionResult):
        event = NodeExecutionEvent.model_validate(res.model_dump())
        await self._publish(event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}")

    async def _publish_graph_exec_update(self, res: GraphExecutionMeta):
        # GraphExecutionEvent requires inputs and outputs fields that GraphExecutionMeta doesn't have
        # Add default empty values for compatibility
        event_data = res.model_dump()
        event_data.setdefault("inputs", {})
        event_data.setdefault("outputs", {})
        event = GraphExecutionEvent.model_validate(event_data)
        await self._publish(event, f"{res.user_id}/{res.graph_id}/{res.id}")

    async def _publish(self, event: ExecutionEvent, channel: str):
        """Truncate inputs and outputs to avoid large payloads."""
        limit = config.max_message_size_limit // 2
        if isinstance(event, GraphExecutionEvent):
            event.inputs = truncate(event.inputs, limit)
            event.outputs = truncate(event.outputs, limit)
        elif isinstance(event, NodeExecutionEvent):
            event.input_data = truncate(event.input_data, limit)
            event.output_data = truncate(event.output_data, limit)

        await super().publish_event(event, channel)

    async def listen(
        self, user_id: str, graph_id: str = "*", graph_exec_id: str = "*"
    ) -> AsyncGenerator[ExecutionEvent, None]:
        async for event in self.listen_events(f"{user_id}/{graph_id}/{graph_exec_id}"):
            yield event


# --------------------- KV Data Functions --------------------- #


async def get_execution_kv_data(user_id: str, key: str) -> Any | None:
    """
    Get key-value data for a user and key.

    Args:
        user_id: The id of the User.
        key: The key to retrieve data for.

    Returns:
        The data associated with the key, or None if not found.
    """
    kv_data = await AgentNodeExecutionKeyValueData.prisma().find_unique(
        where={"userId_key": {"userId": user_id, "key": key}}
    )
    return (
        type_utils.convert(kv_data.data, type[Any])
        if kv_data and kv_data.data
        else None
    )


async def set_execution_kv_data(
    user_id: str, node_exec_id: str, key: str, data: Any
) -> Any | None:
    """
    Set key-value data for a user and key.

    Args:
        user_id: The id of the User.
        node_exec_id: The id of the AgentNodeExecution.
        key: The key to store data under.
        data: The data to store.

    Returns:
        The stored data, or None if nothing was stored.
    """
    resp = await AgentNodeExecutionKeyValueData.prisma().upsert(
        where={"userId_key": {"userId": user_id, "key": key}},
        data={
            "create": AgentNodeExecutionKeyValueDataCreateInput(
                userId=user_id,
                agentNodeExecutionId=node_exec_id,
                key=key,
                data=SafeJson(data) if data is not None else None,
            ),
            "update": {
                "agentNodeExecutionId": node_exec_id,
                "data": SafeJson(data) if data is not None else None,
            },
        },
    )
    return type_utils.convert(resp.data, type[Any]) if resp and resp.data else None


async def get_block_error_stats(
    start_time: datetime, end_time: datetime
) -> list[BlockErrorStats]:
    """Get block execution stats using efficient SQL aggregation."""

    query_template = """
        SELECT
            n."agentBlockId" as block_id,
            COUNT(*) as total_executions,
            SUM(CASE WHEN ne."executionStatus" = 'FAILED' THEN 1 ELSE 0 END) as failed_executions
        FROM {schema_prefix}"AgentNodeExecution" ne
        JOIN {schema_prefix}"AgentNode" n ON ne."agentNodeId" = n.id
        WHERE ne."addedTime" >= $1::timestamp AND ne."addedTime" <= $2::timestamp
        GROUP BY n."agentBlockId"
        HAVING COUNT(*) >= 10
    """

    result = await query_raw_with_schema(query_template, start_time, end_time)

    # Convert to typed data structures
    return [
        BlockErrorStats(
            block_id=row["block_id"],
            total_executions=int(row["total_executions"]),
            failed_executions=int(row["failed_executions"]),
        )
        for row in result
    ]


async def update_graph_execution_share_status(
    execution_id: str,
    user_id: str,
    is_shared: bool,
    share_token: str | None,
    shared_at: datetime | None,
) -> None:
    """Update the sharing status of a graph execution."""
    await AgentGraphExecution.prisma().update(
        where={"id": execution_id},
        data={
            "isShared": is_shared,
            "shareToken": share_token,
            "sharedAt": shared_at,
        },
    )


async def get_graph_execution_by_share_token(
    share_token: str,
) -> SharedExecutionResponse | None:
    """Get a shared execution with limited public-safe data."""
    execution = await AgentGraphExecution.prisma().find_first(
        where={
            "shareToken": share_token,
            "isShared": True,
            "isDeleted": False,
        },
        include={
            "AgentGraph": True,
            "NodeExecutions": {
                "include": {
                    "Output": True,
                    "Node": {
                        "include": {
                            "AgentBlock": True,
                        }
                    },
                },
            },
        },
    )

    if not execution:
        return None

    # Extract outputs from OUTPUT blocks only (consistent with GraphExecution.from_db)
    outputs: CompletedBlockOutput = defaultdict(list)
    if execution.NodeExecutions:
        for node_exec in execution.NodeExecutions:
            if node_exec.Node and node_exec.Node.agentBlockId:
                # Get the block definition to check its type
                block = get_block(node_exec.Node.agentBlockId)

                if block and block.block_type == BlockType.OUTPUT:
                    # For OUTPUT blocks, the data is stored in executionData or Input
                    # The executionData contains the structured input with 'name' and 'value' fields
                    if hasattr(node_exec, "executionData") and node_exec.executionData:
                        exec_data = type_utils.convert(
                            node_exec.executionData, dict[str, Any]
                        )
                        if "name" in exec_data:
                            name = exec_data["name"]
                            value = exec_data.get("value")
                            outputs[name].append(value)
                    elif node_exec.Input:
                        # Build input_data from Input relation
                        input_data = {}
                        for data in node_exec.Input:
                            if data.name and data.data is not None:
                                input_data[data.name] = type_utils.convert(
                                    data.data, JsonValue
                                )

                        if "name" in input_data:
                            name = input_data["name"]
                            value = input_data.get("value")
                            outputs[name].append(value)

    return SharedExecutionResponse(
        id=execution.id,
        graph_name=(
            execution.AgentGraph.name
            if (execution.AgentGraph and execution.AgentGraph.name)
            else "Untitled Agent"
        ),
        graph_description=(
            execution.AgentGraph.description if execution.AgentGraph else None
        ),
        status=ExecutionStatus(execution.executionStatus),
        created_at=execution.createdAt,
        outputs=outputs,
    )


async def get_frequently_executed_graphs(
    days_back: int = 30,
    min_executions: int = 10,
) -> list[dict]:
    """Get graphs that have been frequently executed for monitoring."""
    query_template = """
        SELECT DISTINCT
            e."agentGraphId" as graph_id,
            e."userId" as user_id,
            COUNT(*) as execution_count
        FROM {schema_prefix}"AgentGraphExecution" e
        WHERE e."createdAt" >= $1::timestamp
        AND e."isDeleted" = false
        AND e."executionStatus" IN ('COMPLETED', 'FAILED', 'TERMINATED')
        GROUP BY e."agentGraphId", e."userId"
        HAVING COUNT(*) >= $2
        ORDER BY execution_count DESC
    """

    start_date = datetime.now(timezone.utc) - timedelta(days=days_back)
    result = await query_raw_with_schema(query_template, start_date, min_executions)

    return [
        {
            "graph_id": row["graph_id"],
            "user_id": row["user_id"],
            "execution_count": int(row["execution_count"]),
        }
        for row in result
    ]