Merge branch 'dev' into add-iffy-moderation

This commit is contained in:
Bently
2025-03-21 16:28:18 +00:00
committed by GitHub
9 changed files with 169 additions and 117 deletions

View File

@@ -10,6 +10,7 @@ from prisma.models import (
AgentNodeExecution,
AgentNodeExecutionInputOutput,
)
from prisma.types import AgentNodeExecutionUpdateInput, AgentNodeExecutionWhereInput
from pydantic import BaseModel
from backend.data.block import BlockData, BlockInput, CompletedBlockOutput
@@ -166,7 +167,8 @@ async def create_graph_execution(
"create": [ # type: ignore
{
"agentNodeId": node_id,
"executionStatus": ExecutionStatus.INCOMPLETE,
"executionStatus": ExecutionStatus.QUEUED,
"queuedTime": datetime.now(tz=timezone.utc),
"Input": {
"create": [
{"name": name, "data": Json(data)}
@@ -291,13 +293,19 @@ async def update_graph_execution_start_time(graph_exec_id: str) -> ExecutionResu
async def update_graph_execution_stats(
graph_exec_id: str,
status: ExecutionStatus,
stats: GraphExecutionStats,
stats: GraphExecutionStats | None = None,
) -> ExecutionResult:
data = stats.model_dump()
if isinstance(data["error"], Exception):
data = stats.model_dump() if stats else {}
if isinstance(data.get("error"), Exception):
data["error"] = str(data["error"])
res = await AgentGraphExecution.prisma().update(
where={"id": graph_exec_id},
where={
"id": graph_exec_id,
"OR": [
{"executionStatus": ExecutionStatus.RUNNING},
{"executionStatus": ExecutionStatus.QUEUED},
],
},
data={
"executionStatus": status,
"stats": Json(data),
@@ -319,6 +327,17 @@ async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionSta
)
async def update_execution_status_batch(
    node_exec_ids: list[str],
    status: ExecutionStatus,
    stats: dict[str, Any] | None = None,
):
    """Set *status* (and its derived timestamp fields) on many node executions at once.

    Unlike ``update_execution_status``, this issues a single ``update_many`` query,
    does not attach execution data, and does not return the updated records.

    Args:
        node_exec_ids: IDs of the AgentNodeExecution rows to update.
        status: New execution status to apply to every matched row.
        stats: Optional stats payload stored as JSON when provided.
    """
    await AgentNodeExecution.prisma().update_many(
        where={"id": {"in": node_exec_ids}},
        # execution_data is intentionally None: batch updates never attach inputs.
        data=_get_update_status_data(status, None, stats),
    )
async def update_execution_status(
node_exec_id: str,
status: ExecutionStatus,
@@ -328,20 +347,9 @@ async def update_execution_status(
if status == ExecutionStatus.QUEUED and execution_data is None:
raise ValueError("Execution data must be provided when queuing an execution.")
now = datetime.now(tz=timezone.utc)
data = {
**({"executionStatus": status}),
**({"queuedTime": now} if status == ExecutionStatus.QUEUED else {}),
**({"startedTime": now} if status == ExecutionStatus.RUNNING else {}),
**({"endedTime": now} if status == ExecutionStatus.FAILED else {}),
**({"endedTime": now} if status == ExecutionStatus.COMPLETED else {}),
**({"executionData": Json(execution_data)} if execution_data else {}),
**({"stats": Json(stats)} if stats else {}),
}
res = await AgentNodeExecution.prisma().update(
where={"id": node_exec_id},
data=data, # type: ignore
data=_get_update_status_data(status, execution_data, stats),
include=EXECUTION_RESULT_INCLUDE,
)
if not res:
@@ -350,6 +358,29 @@ async def update_execution_status(
return ExecutionResult.from_db(res)
def _get_update_status_data(
    status: ExecutionStatus,
    execution_data: BlockInput | None = None,
    stats: dict[str, Any] | None = None,
) -> AgentNodeExecutionUpdateInput:
    """Build the Prisma update payload for a node-execution status change.

    The timestamp field matching the new status is stamped with the current
    UTC time: ``queuedTime`` for QUEUED, ``startedTime`` for RUNNING, and
    ``endedTime`` for FAILED or COMPLETED. Other statuses (e.g. INCOMPLETE)
    set no timestamp.

    Args:
        status: The status to record.
        execution_data: Optional block input; serialized to JSON only when truthy.
        stats: Optional stats dict; serialized to JSON only when truthy.

    Returns:
        A typed update-input dict suitable for AgentNodeExecution updates.
    """
    now = datetime.now(tz=timezone.utc)
    update_data: AgentNodeExecutionUpdateInput = {"executionStatus": status}
    if status == ExecutionStatus.QUEUED:
        update_data["queuedTime"] = now
    elif status == ExecutionStatus.RUNNING:
        update_data["startedTime"] = now
    elif status in (ExecutionStatus.FAILED, ExecutionStatus.COMPLETED):
        # FAILED and COMPLETED are both terminal: either one closes the execution.
        update_data["endedTime"] = now
    if execution_data:
        update_data["executionData"] = Json(execution_data)
    if stats:
        update_data["stats"] = Json(stats)
    return update_data
async def delete_execution(
graph_exec_id: str, user_id: str, soft_delete: bool = True
) -> None:
@@ -367,41 +398,29 @@ async def delete_execution(
)
async def get_execution_results(graph_exec_id: str) -> list[ExecutionResult]:
async def get_execution_results(
graph_exec_id: str,
block_ids: list[str] | None = None,
statuses: list[ExecutionStatus] | None = None,
limit: int | None = None,
) -> list[ExecutionResult]:
where_clause: AgentNodeExecutionWhereInput = {
"agentGraphExecutionId": graph_exec_id,
}
if block_ids:
where_clause["AgentNode"] = {"is": {"agentBlockId": {"in": block_ids}}}
if statuses:
where_clause["OR"] = [{"executionStatus": status} for status in statuses]
executions = await AgentNodeExecution.prisma().find_many(
where={"agentGraphExecutionId": graph_exec_id},
where=where_clause,
include=EXECUTION_RESULT_INCLUDE,
order=[
{"queuedTime": "asc"},
{"addedTime": "asc"}, # Fallback: Incomplete execs has no queuedTime.
],
take=limit,
)
res = [ExecutionResult.from_db(execution) for execution in executions]
return res
async def get_executions_in_timerange(
    user_id: str, start_time: str, end_time: str
) -> list[ExecutionResult]:
    """Fetch a user's non-deleted graph executions started within a time window.

    Args:
        user_id: Owner of the executions.
        start_time: ISO-8601 timestamp; inclusive lower bound on ``startedAt``.
        end_time: ISO-8601 timestamp; inclusive upper bound on ``startedAt``.

    Returns:
        One ExecutionResult per matching graph execution.

    Raises:
        DatabaseError: Wrapping any query or timestamp-parsing failure,
            with the time range and user ID included for context.
    """
    try:
        executions = await AgentGraphExecution.prisma().find_many(
            where={
                "startedAt": {
                    # fromisoformat raises ValueError on malformed input,
                    # which is caught below and surfaced as DatabaseError.
                    "gte": datetime.fromisoformat(start_time),
                    "lte": datetime.fromisoformat(end_time),
                },
                "userId": user_id,
                "isDeleted": False,
            },
            include=GRAPH_EXECUTION_INCLUDE,
        )
        return [ExecutionResult.from_graph(execution) for execution in executions]
    except Exception as e:
        raise DatabaseError(
            f"Failed to get executions in timerange {start_time} to {end_time} for user {user_id}: {e}"
        ) from e
LIST_SPLIT = "_$_"
DICT_SPLIT = "_#_"
OBJC_SPLIT = "_@_"
@@ -544,7 +563,10 @@ async def get_latest_execution(node_id: str, graph_eid: str) -> ExecutionResult
"agentGraphExecutionId": graph_eid,
"executionStatus": {"not": ExecutionStatus.INCOMPLETE}, # type: ignore
},
order={"queuedTime": "desc"},
order=[
{"queuedTime": "desc"},
{"addedTime": "desc"},
],
include=EXECUTION_RESULT_INCLUDE,
)
if not execution:

View File

@@ -23,7 +23,7 @@ from backend.util import type as type_utils
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
from .db import BaseDbModel, transaction
from .execution import ExecutionResult, ExecutionStatus
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE, GRAPH_EXECUTION_INCLUDE
from .integrations import Webhook
logger = logging.getLogger(__name__)
@@ -216,10 +216,13 @@ class GraphExecution(GraphExecutionMeta):
graph_exec = GraphExecutionMeta.from_db(_graph_exec)
node_executions = [
ExecutionResult.from_db(ne, _graph_exec.userId)
for ne in _graph_exec.AgentNodeExecutions
]
node_executions = sorted(
[
ExecutionResult.from_db(ne, _graph_exec.userId)
for ne in _graph_exec.AgentNodeExecutions
],
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
)
inputs = {
**{
@@ -658,20 +661,13 @@ async def get_execution_meta(
return GraphExecutionMeta.from_db(execution) if execution else None
async def get_execution(user_id: str, execution_id: str) -> GraphExecution | None:
async def get_execution(
user_id: str,
execution_id: str,
) -> GraphExecution | None:
execution = await AgentGraphExecution.prisma().find_first(
where={"id": execution_id, "isDeleted": False, "userId": user_id},
include={
"AgentNodeExecutions": {
"include": {"AgentNode": True, "Input": True, "Output": True},
"order_by": [
{"queuedTime": "asc"},
{ # Fallback: Incomplete execs has no queuedTime.
"addedTime": "asc"
},
],
},
},
include=GRAPH_EXECUTION_INCLUDE,
)
return GraphExecution.from_db(execution) if execution else None

View File

@@ -18,6 +18,8 @@ EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
"AgentGraphExecution": True,
}
MAX_NODE_EXECUTIONS_FETCH = 1000
GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
"AgentNodeExecutions": {
"include": {
@@ -25,10 +27,17 @@ GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
"Output": True,
"AgentNode": True,
"AgentGraphExecution": True,
}
},
"order_by": [
{"queuedTime": "desc"},
# Fallback: Incomplete execs has no queuedTime.
{"addedTime": "desc"},
],
"take": MAX_NODE_EXECUTIONS_FETCH, # Avoid loading excessive node executions.
}
}
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
"AgentNodes": {"include": AGENT_NODE_INCLUDE} # type: ignore
}

View File

@@ -8,6 +8,7 @@ from backend.data.execution import (
get_incomplete_executions,
get_latest_execution,
update_execution_status,
update_execution_status_batch,
update_graph_execution_start_time,
update_graph_execution_stats,
update_node_execution_stats,
@@ -69,6 +70,7 @@ class DatabaseManager(AppService):
get_incomplete_executions = exposed_run_and_wait(get_incomplete_executions)
get_latest_execution = exposed_run_and_wait(get_latest_execution)
update_execution_status = exposed_run_and_wait(update_execution_status)
update_execution_status_batch = exposed_run_and_wait(update_execution_status_batch)
update_graph_execution_start_time = exposed_run_and_wait(
update_graph_execution_start_time
)

View File

@@ -681,10 +681,6 @@ class Executor:
try:
queue = ExecutionQueue[NodeExecutionEntry]()
for node_exec in graph_exec.start_node_execs:
exec_update = cls.db_client.update_execution_status(
node_exec.node_exec_id, ExecutionStatus.QUEUED, node_exec.data
)
cls.db_client.send_execution_update(exec_update)
queue.add(node_exec)
running_executions: dict[str, AsyncResult] = {}
@@ -789,7 +785,10 @@ class Executor:
metadata = cls.db_client.get_graph_metadata(
graph_exec.graph_id, graph_exec.graph_version
)
outputs = cls.db_client.get_execution_results(graph_exec.graph_exec_id)
outputs = cls.db_client.get_execution_results(
graph_exec.graph_exec_id,
block_ids=[AgentOutputBlock().id],
)
named_outputs = [
{
@@ -797,7 +796,6 @@ class Executor:
for key, value in output.output_data.items()
}
for output in outputs
if output.block_id == AgentOutputBlock().id
]
event = NotificationEventDTO(
@@ -1055,29 +1053,36 @@ class ExecutionManager(AppService):
3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
"""
if graph_exec_id not in self.active_graph_runs:
raise Exception(
logger.warning(
f"Graph execution #{graph_exec_id} not active/running: "
"possibly already completed/cancelled."
)
else:
future, cancel_event = self.active_graph_runs[graph_exec_id]
if not cancel_event.is_set():
cancel_event.set()
future.result()
future, cancel_event = self.active_graph_runs[graph_exec_id]
if cancel_event.is_set():
return
cancel_event.set()
future.result()
# Update the status of the unfinished node executions
node_execs = self.db_client.get_execution_results(graph_exec_id)
# Update the status of the graph & node executions
self.db_client.update_graph_execution_stats(
graph_exec_id,
ExecutionStatus.TERMINATED,
)
node_execs = self.db_client.get_execution_results(
graph_exec_id=graph_exec_id,
statuses=[
ExecutionStatus.QUEUED,
ExecutionStatus.RUNNING,
ExecutionStatus.INCOMPLETE,
],
)
self.db_client.update_execution_status_batch(
[node_exec.node_exec_id for node_exec in node_execs],
ExecutionStatus.TERMINATED,
)
for node_exec in node_execs:
if node_exec.status not in (
ExecutionStatus.COMPLETED,
ExecutionStatus.FAILED,
):
exec_update = self.db_client.update_execution_status(
node_exec.node_exec_id, ExecutionStatus.TERMINATED
)
self.db_client.send_execution_update(exec_update)
node_exec.status = ExecutionStatus.TERMINATED
self.db_client.send_execution_update(node_exec)
def _validate_node_input_credentials(self, graph: GraphModel, user_id: str):
"""Checks all credentials for all nodes of the graph"""

View File

@@ -640,12 +640,8 @@ async def get_graph_execution(
graph_exec_id: str,
user_id: Annotated[str, Depends(get_user_id)],
) -> graph_db.GraphExecution:
graph = await graph_db.get_graph(graph_id, user_id=user_id)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
result = await graph_db.get_execution(execution_id=graph_exec_id, user_id=user_id)
if not result:
if not result or result.graph_id != graph_id:
raise HTTPException(
status_code=404, detail=f"Graph execution #{graph_exec_id} not found."
)

View File

@@ -44,7 +44,7 @@ from Pyro5 import config as pyro_config
from backend.data import db, rabbitmq, redis
from backend.util.exceptions import InsufficientBalanceError
from backend.util.json import to_dict
from backend.util.process import AppProcess
from backend.util.process import AppProcess, get_service_name
from backend.util.retry import conn_retry
from backend.util.settings import Config, Secrets
@@ -190,7 +190,17 @@ class BaseAppService(AppProcess, ABC):
@classmethod
def get_host(cls) -> str:
return os.environ.get(f"{cls.service_name.upper()}_HOST", api_host)
source_host = os.environ.get(f"{get_service_name().upper()}_HOST", api_host)
target_host = os.environ.get(f"{cls.service_name.upper()}_HOST", api_host)
if source_host == target_host and source_host != api_host:
logger.warning(
f"Service {cls.service_name} is the same host as the source service."
f"Use the localhost of {api_host} instead."
)
return api_host
return target_host
@property
def rabbit(self) -> rabbitmq.AsyncRabbitMQ:

View File

@@ -114,10 +114,19 @@ export default function useAgentGraph(
});
if (flowID && flowVersion) {
api.subscribeToExecution(flowID, flowVersion);
console.debug(
`Subscribed to execution events for ${flowID} v.${flowVersion}`,
);
api
.subscribeToExecution(flowID, flowVersion)
.then(() =>
console.debug(
`Subscribed to execution events for ${flowID} v.${flowVersion}`,
),
)
.catch((error) =>
console.error(
`Failed to subscribe to execution events for ${flowID} v.${flowVersion}:`,
error,
),
);
}
}, [api, flowID, flowVersion, flowExecutionID]);

View File

@@ -827,7 +827,7 @@ export default class BackendAPI {
this.webSocket.onclose = (event) => {
console.warn("WebSocket connection closed", event);
this.stopHeartbeat(); // Stop heartbeat when connection closes
this.webSocket = null;
this.wsConnecting = null;
// Attempt to reconnect after a delay
setTimeout(() => this.connectWebSocket(), 1000);
};
@@ -835,6 +835,7 @@ export default class BackendAPI {
this.webSocket.onerror = (error) => {
console.error("WebSocket error:", error);
this.stopHeartbeat(); // Stop heartbeat on error
this.wsConnecting = null;
reject(error);
};
@@ -868,26 +869,28 @@ export default class BackendAPI {
this.webSocket.close();
}
}
sendWebSocketMessage<M extends keyof WebsocketMessageTypeMap>(
async sendWebSocketMessage<M extends keyof WebsocketMessageTypeMap>(
method: M,
data: WebsocketMessageTypeMap[M],
callCount = 0,
) {
callCountLimit = 4,
): Promise<void> {
if (this.webSocket && this.webSocket.readyState === WebSocket.OPEN) {
this.webSocket.send(JSON.stringify({ method, data }));
} else {
this.connectWebSocket().then(() => {
callCount == 0
? this.sendWebSocketMessage(method, data, callCount + 1)
: setTimeout(
() => {
this.sendWebSocketMessage(method, data, callCount + 1);
},
2 ** (callCount - 1) * 1000,
);
});
const result = this.webSocket.send(JSON.stringify({ method, data }));
return;
}
if (callCount >= callCountLimit) {
throw new Error(
`WebSocket connection not open after ${callCountLimit} attempts`,
);
}
await this.connectWebSocket();
if (callCount === 0) {
return this.sendWebSocketMessage(method, data, callCount + 1);
}
const delayMs = 2 ** (callCount - 1) * 1000;
await new Promise((res) => setTimeout(res, delayMs));
return this.sendWebSocketMessage(method, data, callCount + 1);
}
onWebSocketMessage<M extends keyof WebsocketMessageTypeMap>(
@@ -901,8 +904,8 @@ export default class BackendAPI {
return () => this.wsMessageHandlers[method].delete(handler);
}
subscribeToExecution(graphId: string, graphVersion: number) {
this.sendWebSocketMessage("subscribe", {
async subscribeToExecution(graphId: string, graphVersion: number) {
await this.sendWebSocketMessage("subscribe", {
graph_id: graphId,
graph_version: graphVersion,
});