mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-12 16:48:06 -05:00
Compare commits
1 Commits
dev
...
feat/migra
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e9d846eebb |
@@ -1,11 +1,10 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import Future, ProcessPoolExecutor
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
|
||||
|
||||
@@ -102,6 +101,22 @@ utilization_gauge = Gauge(
|
||||
"Ratio of active graph runs to max graph workers",
|
||||
)
|
||||
|
||||
# Thread-local storage for ExecutionProcessor instances
|
||||
_tls = threading.local()
|
||||
|
||||
|
||||
def init_worker():
|
||||
"""Initialize ExecutionProcessor instance in thread-local storage"""
|
||||
_tls.processor = ExecutionProcessor()
|
||||
_tls.processor.on_graph_executor_start()
|
||||
|
||||
|
||||
def execute_graph(
|
||||
graph_exec_entry: "GraphExecutionEntry", cancel_event: threading.Event
|
||||
):
|
||||
"""Execute graph using thread-local ExecutionProcessor instance"""
|
||||
return _tls.processor.on_graph_execution(graph_exec_entry, cancel_event)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -366,7 +381,7 @@ async def _enqueue_next_nodes(
|
||||
]
|
||||
|
||||
|
||||
class Executor:
|
||||
class ExecutionProcessor:
|
||||
"""
|
||||
This class contains event handlers for the process pool executor events.
|
||||
|
||||
@@ -389,10 +404,9 @@ class Executor:
|
||||
9. Node executor enqueues the next executed nodes to the node execution queue.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@async_error_logged(swallow=True)
|
||||
async def on_node_execution(
|
||||
cls,
|
||||
self,
|
||||
node_exec: NodeExecutionEntry,
|
||||
node_exec_progress: NodeExecutionProgress,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
|
||||
@@ -411,7 +425,7 @@ class Executor:
|
||||
node = await db_client.get_node(node_exec.node_id)
|
||||
execution_stats = NodeExecutionStats()
|
||||
|
||||
timing_info, status = await cls._on_node_execution(
|
||||
timing_info, status = await self._on_node_execution(
|
||||
node=node,
|
||||
node_exec=node_exec,
|
||||
node_exec_progress=node_exec_progress,
|
||||
@@ -454,10 +468,9 @@ class Executor:
|
||||
|
||||
return execution_stats
|
||||
|
||||
@classmethod
|
||||
@async_time_measured
|
||||
async def _on_node_execution(
|
||||
cls,
|
||||
self,
|
||||
node: Node,
|
||||
node_exec: NodeExecutionEntry,
|
||||
node_exec_progress: NodeExecutionProgress,
|
||||
@@ -497,7 +510,7 @@ class Executor:
|
||||
|
||||
async for output_name, output_data in execute_node(
|
||||
node=node,
|
||||
creds_manager=cls.creds_manager,
|
||||
creds_manager=self.creds_manager,
|
||||
data=node_exec,
|
||||
execution_stats=stats,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
@@ -537,29 +550,27 @@ class Executor:
|
||||
|
||||
return status
|
||||
|
||||
@classmethod
|
||||
@func_retry
|
||||
def on_graph_executor_start(cls):
|
||||
def on_graph_executor_start(self):
|
||||
configure_logging()
|
||||
set_service_name("GraphExecutor")
|
||||
cls.pid = os.getpid()
|
||||
cls.creds_manager = IntegrationCredentialsManager()
|
||||
cls.node_execution_loop = asyncio.new_event_loop()
|
||||
cls.node_evaluation_loop = asyncio.new_event_loop()
|
||||
cls.node_execution_thread = threading.Thread(
|
||||
target=cls.node_execution_loop.run_forever, daemon=True
|
||||
self.tid = threading.get_ident()
|
||||
self.creds_manager = IntegrationCredentialsManager()
|
||||
self.node_execution_loop = asyncio.new_event_loop()
|
||||
self.node_evaluation_loop = asyncio.new_event_loop()
|
||||
self.node_execution_thread = threading.Thread(
|
||||
target=self.node_execution_loop.run_forever, daemon=True
|
||||
)
|
||||
cls.node_evaluation_thread = threading.Thread(
|
||||
target=cls.node_evaluation_loop.run_forever, daemon=True
|
||||
self.node_evaluation_thread = threading.Thread(
|
||||
target=self.node_evaluation_loop.run_forever, daemon=True
|
||||
)
|
||||
cls.node_execution_thread.start()
|
||||
cls.node_evaluation_thread.start()
|
||||
logger.info(f"[GraphExecutor] {cls.pid} started")
|
||||
self.node_execution_thread.start()
|
||||
self.node_evaluation_thread.start()
|
||||
logger.info(f"[GraphExecutor] {self.tid} started")
|
||||
|
||||
@classmethod
|
||||
@error_logged(swallow=False)
|
||||
def on_graph_execution(
|
||||
cls,
|
||||
self,
|
||||
graph_exec: GraphExecutionEntry,
|
||||
cancel: threading.Event,
|
||||
):
|
||||
@@ -615,7 +626,7 @@ class Executor:
|
||||
else:
|
||||
exec_stats = exec_meta.stats.to_db()
|
||||
|
||||
timing_info, status = cls._on_graph_execution(
|
||||
timing_info, status = self._on_graph_execution(
|
||||
graph_exec=graph_exec,
|
||||
cancel=cancel,
|
||||
log_metadata=log_metadata,
|
||||
@@ -641,7 +652,7 @@ class Executor:
|
||||
user_id=graph_exec.user_id,
|
||||
execution_status=status,
|
||||
),
|
||||
cls.node_execution_loop,
|
||||
self.node_execution_loop,
|
||||
).result(timeout=60.0)
|
||||
if activity_status is not None:
|
||||
exec_stats.activity_status = activity_status
|
||||
@@ -652,7 +663,7 @@ class Executor:
|
||||
)
|
||||
|
||||
# Communication handling
|
||||
cls._handle_agent_run_notif(db_client, graph_exec, exec_stats)
|
||||
self._handle_agent_run_notif(db_client, graph_exec, exec_stats)
|
||||
|
||||
finally:
|
||||
update_graph_execution_state(
|
||||
@@ -662,9 +673,8 @@ class Executor:
|
||||
stats=exec_stats,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _charge_usage(
|
||||
cls,
|
||||
self,
|
||||
node_exec: NodeExecutionEntry,
|
||||
execution_count: int,
|
||||
) -> int:
|
||||
@@ -714,22 +724,15 @@ class Executor:
|
||||
|
||||
return total_cost
|
||||
|
||||
@classmethod
|
||||
@time_measured
|
||||
def _on_graph_execution(
|
||||
cls,
|
||||
self,
|
||||
graph_exec: GraphExecutionEntry,
|
||||
cancel: threading.Event,
|
||||
log_metadata: LogMetadata,
|
||||
execution_stats: GraphExecutionStats,
|
||||
) -> ExecutionStatus:
|
||||
|
||||
# Agent execution is uninterrupted.
|
||||
import signal
|
||||
|
||||
signal.signal(signal.SIGINT, signal.SIG_IGN)
|
||||
signal.signal(signal.SIGTERM, signal.SIG_IGN)
|
||||
|
||||
"""
|
||||
Returns:
|
||||
dict: The execution statistics of the graph execution.
|
||||
@@ -786,7 +789,7 @@ class Executor:
|
||||
|
||||
# Charge usage (may raise) ------------------------------
|
||||
try:
|
||||
cost = cls._charge_usage(
|
||||
cost = self._charge_usage(
|
||||
node_exec=queued_node_exec,
|
||||
execution_count=increment_execution_count(graph_exec.user_id),
|
||||
)
|
||||
@@ -806,7 +809,7 @@ class Executor:
|
||||
status=ExecutionStatus.FAILED,
|
||||
)
|
||||
|
||||
cls._handle_low_balance_notif(
|
||||
self._handle_low_balance_notif(
|
||||
db_client,
|
||||
graph_exec.user_id,
|
||||
graph_exec.graph_id,
|
||||
@@ -825,7 +828,7 @@ class Executor:
|
||||
|
||||
# Kick off async node execution -------------------------
|
||||
node_execution_task = asyncio.run_coroutine_threadsafe(
|
||||
cls.on_node_execution(
|
||||
self.on_node_execution(
|
||||
node_exec=queued_node_exec,
|
||||
node_exec_progress=running_node_execution[node_id],
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
@@ -834,7 +837,7 @@ class Executor:
|
||||
execution_stats_lock,
|
||||
),
|
||||
),
|
||||
cls.node_execution_loop,
|
||||
self.node_execution_loop,
|
||||
)
|
||||
running_node_execution[node_id].add_task(
|
||||
node_exec_id=queued_node_exec.node_exec_id,
|
||||
@@ -875,7 +878,7 @@ class Executor:
|
||||
node_output_found = True
|
||||
running_node_evaluation[node_id] = (
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
cls._process_node_output(
|
||||
self._process_node_output(
|
||||
output=output,
|
||||
node_id=node_id,
|
||||
graph_exec=graph_exec,
|
||||
@@ -883,7 +886,7 @@ class Executor:
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
execution_queue=execution_queue,
|
||||
),
|
||||
cls.node_evaluation_loop,
|
||||
self.node_evaluation_loop,
|
||||
)
|
||||
)
|
||||
if (
|
||||
@@ -926,7 +929,7 @@ class Executor:
|
||||
raise
|
||||
|
||||
finally:
|
||||
cls._cleanup_graph_execution(
|
||||
self._cleanup_graph_execution(
|
||||
execution_queue=execution_queue,
|
||||
running_node_execution=running_node_execution,
|
||||
running_node_evaluation=running_node_evaluation,
|
||||
@@ -937,10 +940,9 @@ class Executor:
|
||||
db_client=db_client,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@error_logged(swallow=True)
|
||||
def _cleanup_graph_execution(
|
||||
cls,
|
||||
self,
|
||||
execution_queue: ExecutionQueue[NodeExecutionEntry],
|
||||
running_node_execution: dict[str, "NodeExecutionProgress"],
|
||||
running_node_evaluation: dict[str, Future],
|
||||
@@ -991,10 +993,9 @@ class Executor:
|
||||
|
||||
clean_exec_files(graph_exec_id)
|
||||
|
||||
@classmethod
|
||||
@async_error_logged(swallow=True)
|
||||
async def _process_node_output(
|
||||
cls,
|
||||
self,
|
||||
output: ExecutionOutputEntry,
|
||||
node_id: str,
|
||||
graph_exec: GraphExecutionEntry,
|
||||
@@ -1027,9 +1028,8 @@ class Executor:
|
||||
):
|
||||
execution_queue.add(next_execution)
|
||||
|
||||
@classmethod
|
||||
def _handle_agent_run_notif(
|
||||
cls,
|
||||
self,
|
||||
db_client: "DatabaseManagerClient",
|
||||
graph_exec: GraphExecutionEntry,
|
||||
exec_stats: GraphExecutionStats,
|
||||
@@ -1065,9 +1065,8 @@ class Executor:
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _handle_low_balance_notif(
|
||||
cls,
|
||||
self,
|
||||
db_client: "DatabaseManagerClient",
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
@@ -1132,11 +1131,11 @@ class ExecutionManager(AppProcess):
|
||||
return self._stop_consuming
|
||||
|
||||
@property
|
||||
def executor(self) -> ProcessPoolExecutor:
|
||||
def executor(self) -> ThreadPoolExecutor:
|
||||
if self._executor is None:
|
||||
self._executor = ProcessPoolExecutor(
|
||||
self._executor = ThreadPoolExecutor(
|
||||
max_workers=self.pool_size,
|
||||
initializer=Executor.on_graph_executor_start,
|
||||
initializer=init_worker,
|
||||
)
|
||||
return self._executor
|
||||
|
||||
@@ -1313,11 +1312,9 @@ class ExecutionManager(AppProcess):
|
||||
_ack_message(reject=True)
|
||||
return
|
||||
|
||||
cancel_event = multiprocessing.Manager().Event()
|
||||
cancel_event = threading.Event()
|
||||
|
||||
future = self.executor.submit(
|
||||
Executor.on_graph_execution, graph_exec_entry, cancel_event
|
||||
)
|
||||
future = self.executor.submit(execute_graph, graph_exec_entry, cancel_event)
|
||||
self.active_graph_runs[graph_exec_id] = (future, cancel_event)
|
||||
self._update_prompt_metrics()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user