mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Merge branch 'dev' into add-iffy-moderation
This commit is contained in:
@@ -10,6 +10,7 @@ from prisma.models import (
|
||||
AgentNodeExecution,
|
||||
AgentNodeExecutionInputOutput,
|
||||
)
|
||||
from prisma.types import AgentNodeExecutionUpdateInput, AgentNodeExecutionWhereInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import BlockData, BlockInput, CompletedBlockOutput
|
||||
@@ -166,7 +167,8 @@ async def create_graph_execution(
|
||||
"create": [ # type: ignore
|
||||
{
|
||||
"agentNodeId": node_id,
|
||||
"executionStatus": ExecutionStatus.INCOMPLETE,
|
||||
"executionStatus": ExecutionStatus.QUEUED,
|
||||
"queuedTime": datetime.now(tz=timezone.utc),
|
||||
"Input": {
|
||||
"create": [
|
||||
{"name": name, "data": Json(data)}
|
||||
@@ -291,13 +293,19 @@ async def update_graph_execution_start_time(graph_exec_id: str) -> ExecutionResu
|
||||
async def update_graph_execution_stats(
|
||||
graph_exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
stats: GraphExecutionStats,
|
||||
stats: GraphExecutionStats | None = None,
|
||||
) -> ExecutionResult:
|
||||
data = stats.model_dump()
|
||||
if isinstance(data["error"], Exception):
|
||||
data = stats.model_dump() if stats else {}
|
||||
if isinstance(data.get("error"), Exception):
|
||||
data["error"] = str(data["error"])
|
||||
res = await AgentGraphExecution.prisma().update(
|
||||
where={"id": graph_exec_id},
|
||||
where={
|
||||
"id": graph_exec_id,
|
||||
"OR": [
|
||||
{"executionStatus": ExecutionStatus.RUNNING},
|
||||
{"executionStatus": ExecutionStatus.QUEUED},
|
||||
],
|
||||
},
|
||||
data={
|
||||
"executionStatus": status,
|
||||
"stats": Json(data),
|
||||
@@ -319,6 +327,17 @@ async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionSta
|
||||
)
|
||||
|
||||
|
||||
async def update_execution_status_batch(
|
||||
node_exec_ids: list[str],
|
||||
status: ExecutionStatus,
|
||||
stats: dict[str, Any] | None = None,
|
||||
):
|
||||
await AgentNodeExecution.prisma().update_many(
|
||||
where={"id": {"in": node_exec_ids}},
|
||||
data=_get_update_status_data(status, None, stats),
|
||||
)
|
||||
|
||||
|
||||
async def update_execution_status(
|
||||
node_exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
@@ -328,20 +347,9 @@ async def update_execution_status(
|
||||
if status == ExecutionStatus.QUEUED and execution_data is None:
|
||||
raise ValueError("Execution data must be provided when queuing an execution.")
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
data = {
|
||||
**({"executionStatus": status}),
|
||||
**({"queuedTime": now} if status == ExecutionStatus.QUEUED else {}),
|
||||
**({"startedTime": now} if status == ExecutionStatus.RUNNING else {}),
|
||||
**({"endedTime": now} if status == ExecutionStatus.FAILED else {}),
|
||||
**({"endedTime": now} if status == ExecutionStatus.COMPLETED else {}),
|
||||
**({"executionData": Json(execution_data)} if execution_data else {}),
|
||||
**({"stats": Json(stats)} if stats else {}),
|
||||
}
|
||||
|
||||
res = await AgentNodeExecution.prisma().update(
|
||||
where={"id": node_exec_id},
|
||||
data=data, # type: ignore
|
||||
data=_get_update_status_data(status, execution_data, stats),
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
)
|
||||
if not res:
|
||||
@@ -350,6 +358,29 @@ async def update_execution_status(
|
||||
return ExecutionResult.from_db(res)
|
||||
|
||||
|
||||
def _get_update_status_data(
|
||||
status: ExecutionStatus,
|
||||
execution_data: BlockInput | None = None,
|
||||
stats: dict[str, Any] | None = None,
|
||||
) -> AgentNodeExecutionUpdateInput:
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
update_data: AgentNodeExecutionUpdateInput = {"executionStatus": status}
|
||||
|
||||
if status == ExecutionStatus.QUEUED:
|
||||
update_data["queuedTime"] = now
|
||||
elif status == ExecutionStatus.RUNNING:
|
||||
update_data["startedTime"] = now
|
||||
elif status in (ExecutionStatus.FAILED, ExecutionStatus.COMPLETED):
|
||||
update_data["endedTime"] = now
|
||||
|
||||
if execution_data:
|
||||
update_data["executionData"] = Json(execution_data)
|
||||
if stats:
|
||||
update_data["stats"] = Json(stats)
|
||||
|
||||
return update_data
|
||||
|
||||
|
||||
async def delete_execution(
|
||||
graph_exec_id: str, user_id: str, soft_delete: bool = True
|
||||
) -> None:
|
||||
@@ -367,41 +398,29 @@ async def delete_execution(
|
||||
)
|
||||
|
||||
|
||||
async def get_execution_results(graph_exec_id: str) -> list[ExecutionResult]:
|
||||
async def get_execution_results(
|
||||
graph_exec_id: str,
|
||||
block_ids: list[str] | None = None,
|
||||
statuses: list[ExecutionStatus] | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[ExecutionResult]:
|
||||
where_clause: AgentNodeExecutionWhereInput = {
|
||||
"agentGraphExecutionId": graph_exec_id,
|
||||
}
|
||||
if block_ids:
|
||||
where_clause["AgentNode"] = {"is": {"agentBlockId": {"in": block_ids}}}
|
||||
if statuses:
|
||||
where_clause["OR"] = [{"executionStatus": status} for status in statuses]
|
||||
|
||||
executions = await AgentNodeExecution.prisma().find_many(
|
||||
where={"agentGraphExecutionId": graph_exec_id},
|
||||
where=where_clause,
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
order=[
|
||||
{"queuedTime": "asc"},
|
||||
{"addedTime": "asc"}, # Fallback: Incomplete execs has no queuedTime.
|
||||
],
|
||||
take=limit,
|
||||
)
|
||||
res = [ExecutionResult.from_db(execution) for execution in executions]
|
||||
return res
|
||||
|
||||
|
||||
async def get_executions_in_timerange(
|
||||
user_id: str, start_time: str, end_time: str
|
||||
) -> list[ExecutionResult]:
|
||||
try:
|
||||
executions = await AgentGraphExecution.prisma().find_many(
|
||||
where={
|
||||
"startedAt": {
|
||||
"gte": datetime.fromisoformat(start_time),
|
||||
"lte": datetime.fromisoformat(end_time),
|
||||
},
|
||||
"userId": user_id,
|
||||
"isDeleted": False,
|
||||
},
|
||||
include=GRAPH_EXECUTION_INCLUDE,
|
||||
)
|
||||
return [ExecutionResult.from_graph(execution) for execution in executions]
|
||||
except Exception as e:
|
||||
raise DatabaseError(
|
||||
f"Failed to get executions in timerange {start_time} to {end_time} for user {user_id}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
LIST_SPLIT = "_$_"
|
||||
DICT_SPLIT = "_#_"
|
||||
OBJC_SPLIT = "_@_"
|
||||
@@ -544,7 +563,10 @@ async def get_latest_execution(node_id: str, graph_eid: str) -> ExecutionResult
|
||||
"agentGraphExecutionId": graph_eid,
|
||||
"executionStatus": {"not": ExecutionStatus.INCOMPLETE}, # type: ignore
|
||||
},
|
||||
order={"queuedTime": "desc"},
|
||||
order=[
|
||||
{"queuedTime": "desc"},
|
||||
{"addedTime": "desc"},
|
||||
],
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
)
|
||||
if not execution:
|
||||
|
||||
@@ -23,7 +23,7 @@ from backend.util import type as type_utils
|
||||
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
|
||||
from .db import BaseDbModel, transaction
|
||||
from .execution import ExecutionResult, ExecutionStatus
|
||||
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
|
||||
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE, GRAPH_EXECUTION_INCLUDE
|
||||
from .integrations import Webhook
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -216,10 +216,13 @@ class GraphExecution(GraphExecutionMeta):
|
||||
|
||||
graph_exec = GraphExecutionMeta.from_db(_graph_exec)
|
||||
|
||||
node_executions = [
|
||||
ExecutionResult.from_db(ne, _graph_exec.userId)
|
||||
for ne in _graph_exec.AgentNodeExecutions
|
||||
]
|
||||
node_executions = sorted(
|
||||
[
|
||||
ExecutionResult.from_db(ne, _graph_exec.userId)
|
||||
for ne in _graph_exec.AgentNodeExecutions
|
||||
],
|
||||
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
|
||||
)
|
||||
|
||||
inputs = {
|
||||
**{
|
||||
@@ -658,20 +661,13 @@ async def get_execution_meta(
|
||||
return GraphExecutionMeta.from_db(execution) if execution else None
|
||||
|
||||
|
||||
async def get_execution(user_id: str, execution_id: str) -> GraphExecution | None:
|
||||
async def get_execution(
|
||||
user_id: str,
|
||||
execution_id: str,
|
||||
) -> GraphExecution | None:
|
||||
execution = await AgentGraphExecution.prisma().find_first(
|
||||
where={"id": execution_id, "isDeleted": False, "userId": user_id},
|
||||
include={
|
||||
"AgentNodeExecutions": {
|
||||
"include": {"AgentNode": True, "Input": True, "Output": True},
|
||||
"order_by": [
|
||||
{"queuedTime": "asc"},
|
||||
{ # Fallback: Incomplete execs has no queuedTime.
|
||||
"addedTime": "asc"
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
include=GRAPH_EXECUTION_INCLUDE,
|
||||
)
|
||||
return GraphExecution.from_db(execution) if execution else None
|
||||
|
||||
|
||||
@@ -18,6 +18,8 @@ EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
|
||||
"AgentGraphExecution": True,
|
||||
}
|
||||
|
||||
MAX_NODE_EXECUTIONS_FETCH = 1000
|
||||
|
||||
GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
|
||||
"AgentNodeExecutions": {
|
||||
"include": {
|
||||
@@ -25,10 +27,17 @@ GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
|
||||
"Output": True,
|
||||
"AgentNode": True,
|
||||
"AgentGraphExecution": True,
|
||||
}
|
||||
},
|
||||
"order_by": [
|
||||
{"queuedTime": "desc"},
|
||||
# Fallback: Incomplete execs has no queuedTime.
|
||||
{"addedTime": "desc"},
|
||||
],
|
||||
"take": MAX_NODE_EXECUTIONS_FETCH, # Avoid loading excessive node executions.
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
|
||||
"AgentNodes": {"include": AGENT_NODE_INCLUDE} # type: ignore
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ from backend.data.execution import (
|
||||
get_incomplete_executions,
|
||||
get_latest_execution,
|
||||
update_execution_status,
|
||||
update_execution_status_batch,
|
||||
update_graph_execution_start_time,
|
||||
update_graph_execution_stats,
|
||||
update_node_execution_stats,
|
||||
@@ -69,6 +70,7 @@ class DatabaseManager(AppService):
|
||||
get_incomplete_executions = exposed_run_and_wait(get_incomplete_executions)
|
||||
get_latest_execution = exposed_run_and_wait(get_latest_execution)
|
||||
update_execution_status = exposed_run_and_wait(update_execution_status)
|
||||
update_execution_status_batch = exposed_run_and_wait(update_execution_status_batch)
|
||||
update_graph_execution_start_time = exposed_run_and_wait(
|
||||
update_graph_execution_start_time
|
||||
)
|
||||
|
||||
@@ -681,10 +681,6 @@ class Executor:
|
||||
try:
|
||||
queue = ExecutionQueue[NodeExecutionEntry]()
|
||||
for node_exec in graph_exec.start_node_execs:
|
||||
exec_update = cls.db_client.update_execution_status(
|
||||
node_exec.node_exec_id, ExecutionStatus.QUEUED, node_exec.data
|
||||
)
|
||||
cls.db_client.send_execution_update(exec_update)
|
||||
queue.add(node_exec)
|
||||
|
||||
running_executions: dict[str, AsyncResult] = {}
|
||||
@@ -789,7 +785,10 @@ class Executor:
|
||||
metadata = cls.db_client.get_graph_metadata(
|
||||
graph_exec.graph_id, graph_exec.graph_version
|
||||
)
|
||||
outputs = cls.db_client.get_execution_results(graph_exec.graph_exec_id)
|
||||
outputs = cls.db_client.get_execution_results(
|
||||
graph_exec.graph_exec_id,
|
||||
block_ids=[AgentOutputBlock().id],
|
||||
)
|
||||
|
||||
named_outputs = [
|
||||
{
|
||||
@@ -797,7 +796,6 @@ class Executor:
|
||||
for key, value in output.output_data.items()
|
||||
}
|
||||
for output in outputs
|
||||
if output.block_id == AgentOutputBlock().id
|
||||
]
|
||||
|
||||
event = NotificationEventDTO(
|
||||
@@ -1055,29 +1053,36 @@ class ExecutionManager(AppService):
|
||||
3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
|
||||
"""
|
||||
if graph_exec_id not in self.active_graph_runs:
|
||||
raise Exception(
|
||||
logger.warning(
|
||||
f"Graph execution #{graph_exec_id} not active/running: "
|
||||
"possibly already completed/cancelled."
|
||||
)
|
||||
else:
|
||||
future, cancel_event = self.active_graph_runs[graph_exec_id]
|
||||
if not cancel_event.is_set():
|
||||
cancel_event.set()
|
||||
future.result()
|
||||
|
||||
future, cancel_event = self.active_graph_runs[graph_exec_id]
|
||||
if cancel_event.is_set():
|
||||
return
|
||||
|
||||
cancel_event.set()
|
||||
future.result()
|
||||
|
||||
# Update the status of the unfinished node executions
|
||||
node_execs = self.db_client.get_execution_results(graph_exec_id)
|
||||
# Update the status of the graph & node executions
|
||||
self.db_client.update_graph_execution_stats(
|
||||
graph_exec_id,
|
||||
ExecutionStatus.TERMINATED,
|
||||
)
|
||||
node_execs = self.db_client.get_execution_results(
|
||||
graph_exec_id=graph_exec_id,
|
||||
statuses=[
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.RUNNING,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
],
|
||||
)
|
||||
self.db_client.update_execution_status_batch(
|
||||
[node_exec.node_exec_id for node_exec in node_execs],
|
||||
ExecutionStatus.TERMINATED,
|
||||
)
|
||||
for node_exec in node_execs:
|
||||
if node_exec.status not in (
|
||||
ExecutionStatus.COMPLETED,
|
||||
ExecutionStatus.FAILED,
|
||||
):
|
||||
exec_update = self.db_client.update_execution_status(
|
||||
node_exec.node_exec_id, ExecutionStatus.TERMINATED
|
||||
)
|
||||
self.db_client.send_execution_update(exec_update)
|
||||
node_exec.status = ExecutionStatus.TERMINATED
|
||||
self.db_client.send_execution_update(node_exec)
|
||||
|
||||
def _validate_node_input_credentials(self, graph: GraphModel, user_id: str):
|
||||
"""Checks all credentials for all nodes of the graph"""
|
||||
|
||||
@@ -640,12 +640,8 @@ async def get_graph_execution(
|
||||
graph_exec_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> graph_db.GraphExecution:
|
||||
graph = await graph_db.get_graph(graph_id, user_id=user_id)
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
|
||||
result = await graph_db.get_execution(execution_id=graph_exec_id, user_id=user_id)
|
||||
if not result:
|
||||
if not result or result.graph_id != graph_id:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Graph execution #{graph_exec_id} not found."
|
||||
)
|
||||
|
||||
@@ -44,7 +44,7 @@ from Pyro5 import config as pyro_config
|
||||
from backend.data import db, rabbitmq, redis
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.json import to_dict
|
||||
from backend.util.process import AppProcess
|
||||
from backend.util.process import AppProcess, get_service_name
|
||||
from backend.util.retry import conn_retry
|
||||
from backend.util.settings import Config, Secrets
|
||||
|
||||
@@ -190,7 +190,17 @@ class BaseAppService(AppProcess, ABC):
|
||||
|
||||
@classmethod
|
||||
def get_host(cls) -> str:
|
||||
return os.environ.get(f"{cls.service_name.upper()}_HOST", api_host)
|
||||
source_host = os.environ.get(f"{get_service_name().upper()}_HOST", api_host)
|
||||
target_host = os.environ.get(f"{cls.service_name.upper()}_HOST", api_host)
|
||||
|
||||
if source_host == target_host and source_host != api_host:
|
||||
logger.warning(
|
||||
f"Service {cls.service_name} is the same host as the source service."
|
||||
f"Use the localhost of {api_host} instead."
|
||||
)
|
||||
return api_host
|
||||
|
||||
return target_host
|
||||
|
||||
@property
|
||||
def rabbit(self) -> rabbitmq.AsyncRabbitMQ:
|
||||
|
||||
@@ -114,10 +114,19 @@ export default function useAgentGraph(
|
||||
});
|
||||
|
||||
if (flowID && flowVersion) {
|
||||
api.subscribeToExecution(flowID, flowVersion);
|
||||
console.debug(
|
||||
`Subscribed to execution events for ${flowID} v.${flowVersion}`,
|
||||
);
|
||||
api
|
||||
.subscribeToExecution(flowID, flowVersion)
|
||||
.then(() =>
|
||||
console.debug(
|
||||
`Subscribed to execution events for ${flowID} v.${flowVersion}`,
|
||||
),
|
||||
)
|
||||
.catch((error) =>
|
||||
console.error(
|
||||
`Failed to subscribe to execution events for ${flowID} v.${flowVersion}:`,
|
||||
error,
|
||||
),
|
||||
);
|
||||
}
|
||||
}, [api, flowID, flowVersion, flowExecutionID]);
|
||||
|
||||
|
||||
@@ -827,7 +827,7 @@ export default class BackendAPI {
|
||||
this.webSocket.onclose = (event) => {
|
||||
console.warn("WebSocket connection closed", event);
|
||||
this.stopHeartbeat(); // Stop heartbeat when connection closes
|
||||
this.webSocket = null;
|
||||
this.wsConnecting = null;
|
||||
// Attempt to reconnect after a delay
|
||||
setTimeout(() => this.connectWebSocket(), 1000);
|
||||
};
|
||||
@@ -835,6 +835,7 @@ export default class BackendAPI {
|
||||
this.webSocket.onerror = (error) => {
|
||||
console.error("WebSocket error:", error);
|
||||
this.stopHeartbeat(); // Stop heartbeat on error
|
||||
this.wsConnecting = null;
|
||||
reject(error);
|
||||
};
|
||||
|
||||
@@ -868,26 +869,28 @@ export default class BackendAPI {
|
||||
this.webSocket.close();
|
||||
}
|
||||
}
|
||||
|
||||
sendWebSocketMessage<M extends keyof WebsocketMessageTypeMap>(
|
||||
async sendWebSocketMessage<M extends keyof WebsocketMessageTypeMap>(
|
||||
method: M,
|
||||
data: WebsocketMessageTypeMap[M],
|
||||
callCount = 0,
|
||||
) {
|
||||
callCountLimit = 4,
|
||||
): Promise<void> {
|
||||
if (this.webSocket && this.webSocket.readyState === WebSocket.OPEN) {
|
||||
this.webSocket.send(JSON.stringify({ method, data }));
|
||||
} else {
|
||||
this.connectWebSocket().then(() => {
|
||||
callCount == 0
|
||||
? this.sendWebSocketMessage(method, data, callCount + 1)
|
||||
: setTimeout(
|
||||
() => {
|
||||
this.sendWebSocketMessage(method, data, callCount + 1);
|
||||
},
|
||||
2 ** (callCount - 1) * 1000,
|
||||
);
|
||||
});
|
||||
const result = this.webSocket.send(JSON.stringify({ method, data }));
|
||||
return;
|
||||
}
|
||||
if (callCount >= callCountLimit) {
|
||||
throw new Error(
|
||||
`WebSocket connection not open after ${callCountLimit} attempts`,
|
||||
);
|
||||
}
|
||||
await this.connectWebSocket();
|
||||
if (callCount === 0) {
|
||||
return this.sendWebSocketMessage(method, data, callCount + 1);
|
||||
}
|
||||
const delayMs = 2 ** (callCount - 1) * 1000;
|
||||
await new Promise((res) => setTimeout(res, delayMs));
|
||||
return this.sendWebSocketMessage(method, data, callCount + 1);
|
||||
}
|
||||
|
||||
onWebSocketMessage<M extends keyof WebsocketMessageTypeMap>(
|
||||
@@ -901,8 +904,8 @@ export default class BackendAPI {
|
||||
return () => this.wsMessageHandlers[method].delete(handler);
|
||||
}
|
||||
|
||||
subscribeToExecution(graphId: string, graphVersion: number) {
|
||||
this.sendWebSocketMessage("subscribe", {
|
||||
async subscribeToExecution(graphId: string, graphVersion: number) {
|
||||
await this.sendWebSocketMessage("subscribe", {
|
||||
graph_id: graphId,
|
||||
graph_version: graphVersion,
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user