mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-21 04:57:58 -05:00
Compare commits
1 Commits
testing-cl
...
feat/migra
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e9d846eebb |
@@ -1,11 +1,10 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing
|
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from concurrent.futures import Future, ProcessPoolExecutor
|
from concurrent.futures import Future, ThreadPoolExecutor
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
|
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
|
||||||
|
|
||||||
@@ -102,6 +101,22 @@ utilization_gauge = Gauge(
|
|||||||
"Ratio of active graph runs to max graph workers",
|
"Ratio of active graph runs to max graph workers",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Thread-local storage for ExecutionProcessor instances
|
||||||
|
_tls = threading.local()
|
||||||
|
|
||||||
|
|
||||||
|
def init_worker():
|
||||||
|
"""Initialize ExecutionProcessor instance in thread-local storage"""
|
||||||
|
_tls.processor = ExecutionProcessor()
|
||||||
|
_tls.processor.on_graph_executor_start()
|
||||||
|
|
||||||
|
|
||||||
|
def execute_graph(
|
||||||
|
graph_exec_entry: "GraphExecutionEntry", cancel_event: threading.Event
|
||||||
|
):
|
||||||
|
"""Execute graph using thread-local ExecutionProcessor instance"""
|
||||||
|
return _tls.processor.on_graph_execution(graph_exec_entry, cancel_event)
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
@@ -366,7 +381,7 @@ async def _enqueue_next_nodes(
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class Executor:
|
class ExecutionProcessor:
|
||||||
"""
|
"""
|
||||||
This class contains event handlers for the process pool executor events.
|
This class contains event handlers for the process pool executor events.
|
||||||
|
|
||||||
@@ -389,10 +404,9 @@ class Executor:
|
|||||||
9. Node executor enqueues the next executed nodes to the node execution queue.
|
9. Node executor enqueues the next executed nodes to the node execution queue.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@async_error_logged(swallow=True)
|
@async_error_logged(swallow=True)
|
||||||
async def on_node_execution(
|
async def on_node_execution(
|
||||||
cls,
|
self,
|
||||||
node_exec: NodeExecutionEntry,
|
node_exec: NodeExecutionEntry,
|
||||||
node_exec_progress: NodeExecutionProgress,
|
node_exec_progress: NodeExecutionProgress,
|
||||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
|
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
|
||||||
@@ -411,7 +425,7 @@ class Executor:
|
|||||||
node = await db_client.get_node(node_exec.node_id)
|
node = await db_client.get_node(node_exec.node_id)
|
||||||
execution_stats = NodeExecutionStats()
|
execution_stats = NodeExecutionStats()
|
||||||
|
|
||||||
timing_info, status = await cls._on_node_execution(
|
timing_info, status = await self._on_node_execution(
|
||||||
node=node,
|
node=node,
|
||||||
node_exec=node_exec,
|
node_exec=node_exec,
|
||||||
node_exec_progress=node_exec_progress,
|
node_exec_progress=node_exec_progress,
|
||||||
@@ -454,10 +468,9 @@ class Executor:
|
|||||||
|
|
||||||
return execution_stats
|
return execution_stats
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@async_time_measured
|
@async_time_measured
|
||||||
async def _on_node_execution(
|
async def _on_node_execution(
|
||||||
cls,
|
self,
|
||||||
node: Node,
|
node: Node,
|
||||||
node_exec: NodeExecutionEntry,
|
node_exec: NodeExecutionEntry,
|
||||||
node_exec_progress: NodeExecutionProgress,
|
node_exec_progress: NodeExecutionProgress,
|
||||||
@@ -497,7 +510,7 @@ class Executor:
|
|||||||
|
|
||||||
async for output_name, output_data in execute_node(
|
async for output_name, output_data in execute_node(
|
||||||
node=node,
|
node=node,
|
||||||
creds_manager=cls.creds_manager,
|
creds_manager=self.creds_manager,
|
||||||
data=node_exec,
|
data=node_exec,
|
||||||
execution_stats=stats,
|
execution_stats=stats,
|
||||||
nodes_input_masks=nodes_input_masks,
|
nodes_input_masks=nodes_input_masks,
|
||||||
@@ -537,29 +550,27 @@ class Executor:
|
|||||||
|
|
||||||
return status
|
return status
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@func_retry
|
@func_retry
|
||||||
def on_graph_executor_start(cls):
|
def on_graph_executor_start(self):
|
||||||
configure_logging()
|
configure_logging()
|
||||||
set_service_name("GraphExecutor")
|
set_service_name("GraphExecutor")
|
||||||
cls.pid = os.getpid()
|
self.tid = threading.get_ident()
|
||||||
cls.creds_manager = IntegrationCredentialsManager()
|
self.creds_manager = IntegrationCredentialsManager()
|
||||||
cls.node_execution_loop = asyncio.new_event_loop()
|
self.node_execution_loop = asyncio.new_event_loop()
|
||||||
cls.node_evaluation_loop = asyncio.new_event_loop()
|
self.node_evaluation_loop = asyncio.new_event_loop()
|
||||||
cls.node_execution_thread = threading.Thread(
|
self.node_execution_thread = threading.Thread(
|
||||||
target=cls.node_execution_loop.run_forever, daemon=True
|
target=self.node_execution_loop.run_forever, daemon=True
|
||||||
)
|
)
|
||||||
cls.node_evaluation_thread = threading.Thread(
|
self.node_evaluation_thread = threading.Thread(
|
||||||
target=cls.node_evaluation_loop.run_forever, daemon=True
|
target=self.node_evaluation_loop.run_forever, daemon=True
|
||||||
)
|
)
|
||||||
cls.node_execution_thread.start()
|
self.node_execution_thread.start()
|
||||||
cls.node_evaluation_thread.start()
|
self.node_evaluation_thread.start()
|
||||||
logger.info(f"[GraphExecutor] {cls.pid} started")
|
logger.info(f"[GraphExecutor] {self.tid} started")
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@error_logged(swallow=False)
|
@error_logged(swallow=False)
|
||||||
def on_graph_execution(
|
def on_graph_execution(
|
||||||
cls,
|
self,
|
||||||
graph_exec: GraphExecutionEntry,
|
graph_exec: GraphExecutionEntry,
|
||||||
cancel: threading.Event,
|
cancel: threading.Event,
|
||||||
):
|
):
|
||||||
@@ -615,7 +626,7 @@ class Executor:
|
|||||||
else:
|
else:
|
||||||
exec_stats = exec_meta.stats.to_db()
|
exec_stats = exec_meta.stats.to_db()
|
||||||
|
|
||||||
timing_info, status = cls._on_graph_execution(
|
timing_info, status = self._on_graph_execution(
|
||||||
graph_exec=graph_exec,
|
graph_exec=graph_exec,
|
||||||
cancel=cancel,
|
cancel=cancel,
|
||||||
log_metadata=log_metadata,
|
log_metadata=log_metadata,
|
||||||
@@ -641,7 +652,7 @@ class Executor:
|
|||||||
user_id=graph_exec.user_id,
|
user_id=graph_exec.user_id,
|
||||||
execution_status=status,
|
execution_status=status,
|
||||||
),
|
),
|
||||||
cls.node_execution_loop,
|
self.node_execution_loop,
|
||||||
).result(timeout=60.0)
|
).result(timeout=60.0)
|
||||||
if activity_status is not None:
|
if activity_status is not None:
|
||||||
exec_stats.activity_status = activity_status
|
exec_stats.activity_status = activity_status
|
||||||
@@ -652,7 +663,7 @@ class Executor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Communication handling
|
# Communication handling
|
||||||
cls._handle_agent_run_notif(db_client, graph_exec, exec_stats)
|
self._handle_agent_run_notif(db_client, graph_exec, exec_stats)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
update_graph_execution_state(
|
update_graph_execution_state(
|
||||||
@@ -662,9 +673,8 @@ class Executor:
|
|||||||
stats=exec_stats,
|
stats=exec_stats,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _charge_usage(
|
def _charge_usage(
|
||||||
cls,
|
self,
|
||||||
node_exec: NodeExecutionEntry,
|
node_exec: NodeExecutionEntry,
|
||||||
execution_count: int,
|
execution_count: int,
|
||||||
) -> int:
|
) -> int:
|
||||||
@@ -714,22 +724,15 @@ class Executor:
|
|||||||
|
|
||||||
return total_cost
|
return total_cost
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@time_measured
|
@time_measured
|
||||||
def _on_graph_execution(
|
def _on_graph_execution(
|
||||||
cls,
|
self,
|
||||||
graph_exec: GraphExecutionEntry,
|
graph_exec: GraphExecutionEntry,
|
||||||
cancel: threading.Event,
|
cancel: threading.Event,
|
||||||
log_metadata: LogMetadata,
|
log_metadata: LogMetadata,
|
||||||
execution_stats: GraphExecutionStats,
|
execution_stats: GraphExecutionStats,
|
||||||
) -> ExecutionStatus:
|
) -> ExecutionStatus:
|
||||||
|
|
||||||
# Agent execution is uninterrupted.
|
|
||||||
import signal
|
|
||||||
|
|
||||||
signal.signal(signal.SIGINT, signal.SIG_IGN)
|
|
||||||
signal.signal(signal.SIGTERM, signal.SIG_IGN)
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Returns:
|
Returns:
|
||||||
dict: The execution statistics of the graph execution.
|
dict: The execution statistics of the graph execution.
|
||||||
@@ -786,7 +789,7 @@ class Executor:
|
|||||||
|
|
||||||
# Charge usage (may raise) ------------------------------
|
# Charge usage (may raise) ------------------------------
|
||||||
try:
|
try:
|
||||||
cost = cls._charge_usage(
|
cost = self._charge_usage(
|
||||||
node_exec=queued_node_exec,
|
node_exec=queued_node_exec,
|
||||||
execution_count=increment_execution_count(graph_exec.user_id),
|
execution_count=increment_execution_count(graph_exec.user_id),
|
||||||
)
|
)
|
||||||
@@ -806,7 +809,7 @@ class Executor:
|
|||||||
status=ExecutionStatus.FAILED,
|
status=ExecutionStatus.FAILED,
|
||||||
)
|
)
|
||||||
|
|
||||||
cls._handle_low_balance_notif(
|
self._handle_low_balance_notif(
|
||||||
db_client,
|
db_client,
|
||||||
graph_exec.user_id,
|
graph_exec.user_id,
|
||||||
graph_exec.graph_id,
|
graph_exec.graph_id,
|
||||||
@@ -825,7 +828,7 @@ class Executor:
|
|||||||
|
|
||||||
# Kick off async node execution -------------------------
|
# Kick off async node execution -------------------------
|
||||||
node_execution_task = asyncio.run_coroutine_threadsafe(
|
node_execution_task = asyncio.run_coroutine_threadsafe(
|
||||||
cls.on_node_execution(
|
self.on_node_execution(
|
||||||
node_exec=queued_node_exec,
|
node_exec=queued_node_exec,
|
||||||
node_exec_progress=running_node_execution[node_id],
|
node_exec_progress=running_node_execution[node_id],
|
||||||
nodes_input_masks=nodes_input_masks,
|
nodes_input_masks=nodes_input_masks,
|
||||||
@@ -834,7 +837,7 @@ class Executor:
|
|||||||
execution_stats_lock,
|
execution_stats_lock,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
cls.node_execution_loop,
|
self.node_execution_loop,
|
||||||
)
|
)
|
||||||
running_node_execution[node_id].add_task(
|
running_node_execution[node_id].add_task(
|
||||||
node_exec_id=queued_node_exec.node_exec_id,
|
node_exec_id=queued_node_exec.node_exec_id,
|
||||||
@@ -875,7 +878,7 @@ class Executor:
|
|||||||
node_output_found = True
|
node_output_found = True
|
||||||
running_node_evaluation[node_id] = (
|
running_node_evaluation[node_id] = (
|
||||||
asyncio.run_coroutine_threadsafe(
|
asyncio.run_coroutine_threadsafe(
|
||||||
cls._process_node_output(
|
self._process_node_output(
|
||||||
output=output,
|
output=output,
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
graph_exec=graph_exec,
|
graph_exec=graph_exec,
|
||||||
@@ -883,7 +886,7 @@ class Executor:
|
|||||||
nodes_input_masks=nodes_input_masks,
|
nodes_input_masks=nodes_input_masks,
|
||||||
execution_queue=execution_queue,
|
execution_queue=execution_queue,
|
||||||
),
|
),
|
||||||
cls.node_evaluation_loop,
|
self.node_evaluation_loop,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
@@ -926,7 +929,7 @@ class Executor:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
cls._cleanup_graph_execution(
|
self._cleanup_graph_execution(
|
||||||
execution_queue=execution_queue,
|
execution_queue=execution_queue,
|
||||||
running_node_execution=running_node_execution,
|
running_node_execution=running_node_execution,
|
||||||
running_node_evaluation=running_node_evaluation,
|
running_node_evaluation=running_node_evaluation,
|
||||||
@@ -937,10 +940,9 @@ class Executor:
|
|||||||
db_client=db_client,
|
db_client=db_client,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@error_logged(swallow=True)
|
@error_logged(swallow=True)
|
||||||
def _cleanup_graph_execution(
|
def _cleanup_graph_execution(
|
||||||
cls,
|
self,
|
||||||
execution_queue: ExecutionQueue[NodeExecutionEntry],
|
execution_queue: ExecutionQueue[NodeExecutionEntry],
|
||||||
running_node_execution: dict[str, "NodeExecutionProgress"],
|
running_node_execution: dict[str, "NodeExecutionProgress"],
|
||||||
running_node_evaluation: dict[str, Future],
|
running_node_evaluation: dict[str, Future],
|
||||||
@@ -991,10 +993,9 @@ class Executor:
|
|||||||
|
|
||||||
clean_exec_files(graph_exec_id)
|
clean_exec_files(graph_exec_id)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@async_error_logged(swallow=True)
|
@async_error_logged(swallow=True)
|
||||||
async def _process_node_output(
|
async def _process_node_output(
|
||||||
cls,
|
self,
|
||||||
output: ExecutionOutputEntry,
|
output: ExecutionOutputEntry,
|
||||||
node_id: str,
|
node_id: str,
|
||||||
graph_exec: GraphExecutionEntry,
|
graph_exec: GraphExecutionEntry,
|
||||||
@@ -1027,9 +1028,8 @@ class Executor:
|
|||||||
):
|
):
|
||||||
execution_queue.add(next_execution)
|
execution_queue.add(next_execution)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _handle_agent_run_notif(
|
def _handle_agent_run_notif(
|
||||||
cls,
|
self,
|
||||||
db_client: "DatabaseManagerClient",
|
db_client: "DatabaseManagerClient",
|
||||||
graph_exec: GraphExecutionEntry,
|
graph_exec: GraphExecutionEntry,
|
||||||
exec_stats: GraphExecutionStats,
|
exec_stats: GraphExecutionStats,
|
||||||
@@ -1065,9 +1065,8 @@ class Executor:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _handle_low_balance_notif(
|
def _handle_low_balance_notif(
|
||||||
cls,
|
self,
|
||||||
db_client: "DatabaseManagerClient",
|
db_client: "DatabaseManagerClient",
|
||||||
user_id: str,
|
user_id: str,
|
||||||
graph_id: str,
|
graph_id: str,
|
||||||
@@ -1132,11 +1131,11 @@ class ExecutionManager(AppProcess):
|
|||||||
return self._stop_consuming
|
return self._stop_consuming
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def executor(self) -> ProcessPoolExecutor:
|
def executor(self) -> ThreadPoolExecutor:
|
||||||
if self._executor is None:
|
if self._executor is None:
|
||||||
self._executor = ProcessPoolExecutor(
|
self._executor = ThreadPoolExecutor(
|
||||||
max_workers=self.pool_size,
|
max_workers=self.pool_size,
|
||||||
initializer=Executor.on_graph_executor_start,
|
initializer=init_worker,
|
||||||
)
|
)
|
||||||
return self._executor
|
return self._executor
|
||||||
|
|
||||||
@@ -1313,11 +1312,9 @@ class ExecutionManager(AppProcess):
|
|||||||
_ack_message(reject=True)
|
_ack_message(reject=True)
|
||||||
return
|
return
|
||||||
|
|
||||||
cancel_event = multiprocessing.Manager().Event()
|
cancel_event = threading.Event()
|
||||||
|
|
||||||
future = self.executor.submit(
|
future = self.executor.submit(execute_graph, graph_exec_entry, cancel_event)
|
||||||
Executor.on_graph_execution, graph_exec_entry, cancel_event
|
|
||||||
)
|
|
||||||
self.active_graph_runs[graph_exec_id] = (future, cancel_event)
|
self.active_graph_runs[graph_exec_id] = (future, cancel_event)
|
||||||
self._update_prompt_metrics()
|
self._update_prompt_metrics()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user