Compare commits

...

1 Commits

Author SHA1 Message Date
Zamil Majdy
e9d846eebb feat(backend): migrate AgentExecutor from ProcessPoolExecutor to ThreadPoolExecutor
- Migrate execution manager from ProcessPoolExecutor to ThreadPoolExecutor for improved performance and resource efficiency
- Rename `Executor` class to `ExecutionProcessor` for better clarity
- Convert classmethods to instance methods following proper OOP design patterns
- Implement thread-local storage using `threading.local()` for thread-safe execution
- Replace process ID tracking with thread ID tracking using `threading.get_ident()`
- Replace `multiprocessing.Manager().Event()` with `threading.Event()`
- Remove signal handling code that doesn't work in worker threads
- Update ExecutionManager to use ThreadPoolExecutor with new `init_worker` initializer

Benefits:
- Performance: Reduced overhead compared to process creation/destruction
- Resource Efficiency: Lower memory footprint and faster startup
- Simplicity: Cleaner implementation using thread-local storage pattern
- Thread Safety: Maintained through isolated ExecutionProcessor instances per thread

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-08-07 06:53:32 +07:00

View File

@@ -1,11 +1,10 @@
import asyncio
import logging
import multiprocessing
import os
import threading
import time
from collections import defaultdict
from concurrent.futures import Future, ProcessPoolExecutor
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
@@ -102,6 +101,22 @@ utilization_gauge = Gauge(
"Ratio of active graph runs to max graph workers",
)
# Thread-local storage for ExecutionProcessor instances
_tls = threading.local()
def init_worker():
"""Initialize ExecutionProcessor instance in thread-local storage"""
_tls.processor = ExecutionProcessor()
_tls.processor.on_graph_executor_start()
def execute_graph(
graph_exec_entry: "GraphExecutionEntry", cancel_event: threading.Event
):
"""Execute graph using thread-local ExecutionProcessor instance"""
return _tls.processor.on_graph_execution(graph_exec_entry, cancel_event)
T = TypeVar("T")
@@ -366,7 +381,7 @@ async def _enqueue_next_nodes(
]
class Executor:
class ExecutionProcessor:
"""
This class contains event handlers for the process pool executor events.
@@ -389,10 +404,9 @@ class Executor:
9. Node executor enqueues the next executed nodes to the node execution queue.
"""
@classmethod
@async_error_logged(swallow=True)
async def on_node_execution(
cls,
self,
node_exec: NodeExecutionEntry,
node_exec_progress: NodeExecutionProgress,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
@@ -411,7 +425,7 @@ class Executor:
node = await db_client.get_node(node_exec.node_id)
execution_stats = NodeExecutionStats()
timing_info, status = await cls._on_node_execution(
timing_info, status = await self._on_node_execution(
node=node,
node_exec=node_exec,
node_exec_progress=node_exec_progress,
@@ -454,10 +468,9 @@ class Executor:
return execution_stats
@classmethod
@async_time_measured
async def _on_node_execution(
cls,
self,
node: Node,
node_exec: NodeExecutionEntry,
node_exec_progress: NodeExecutionProgress,
@@ -497,7 +510,7 @@ class Executor:
async for output_name, output_data in execute_node(
node=node,
creds_manager=cls.creds_manager,
creds_manager=self.creds_manager,
data=node_exec,
execution_stats=stats,
nodes_input_masks=nodes_input_masks,
@@ -537,29 +550,27 @@ class Executor:
return status
@classmethod
@func_retry
def on_graph_executor_start(cls):
def on_graph_executor_start(self):
configure_logging()
set_service_name("GraphExecutor")
cls.pid = os.getpid()
cls.creds_manager = IntegrationCredentialsManager()
cls.node_execution_loop = asyncio.new_event_loop()
cls.node_evaluation_loop = asyncio.new_event_loop()
cls.node_execution_thread = threading.Thread(
target=cls.node_execution_loop.run_forever, daemon=True
self.tid = threading.get_ident()
self.creds_manager = IntegrationCredentialsManager()
self.node_execution_loop = asyncio.new_event_loop()
self.node_evaluation_loop = asyncio.new_event_loop()
self.node_execution_thread = threading.Thread(
target=self.node_execution_loop.run_forever, daemon=True
)
cls.node_evaluation_thread = threading.Thread(
target=cls.node_evaluation_loop.run_forever, daemon=True
self.node_evaluation_thread = threading.Thread(
target=self.node_evaluation_loop.run_forever, daemon=True
)
cls.node_execution_thread.start()
cls.node_evaluation_thread.start()
logger.info(f"[GraphExecutor] {cls.pid} started")
self.node_execution_thread.start()
self.node_evaluation_thread.start()
logger.info(f"[GraphExecutor] {self.tid} started")
@classmethod
@error_logged(swallow=False)
def on_graph_execution(
cls,
self,
graph_exec: GraphExecutionEntry,
cancel: threading.Event,
):
@@ -615,7 +626,7 @@ class Executor:
else:
exec_stats = exec_meta.stats.to_db()
timing_info, status = cls._on_graph_execution(
timing_info, status = self._on_graph_execution(
graph_exec=graph_exec,
cancel=cancel,
log_metadata=log_metadata,
@@ -641,7 +652,7 @@ class Executor:
user_id=graph_exec.user_id,
execution_status=status,
),
cls.node_execution_loop,
self.node_execution_loop,
).result(timeout=60.0)
if activity_status is not None:
exec_stats.activity_status = activity_status
@@ -652,7 +663,7 @@ class Executor:
)
# Communication handling
cls._handle_agent_run_notif(db_client, graph_exec, exec_stats)
self._handle_agent_run_notif(db_client, graph_exec, exec_stats)
finally:
update_graph_execution_state(
@@ -662,9 +673,8 @@ class Executor:
stats=exec_stats,
)
@classmethod
def _charge_usage(
cls,
self,
node_exec: NodeExecutionEntry,
execution_count: int,
) -> int:
@@ -714,22 +724,15 @@ class Executor:
return total_cost
@classmethod
@time_measured
def _on_graph_execution(
cls,
self,
graph_exec: GraphExecutionEntry,
cancel: threading.Event,
log_metadata: LogMetadata,
execution_stats: GraphExecutionStats,
) -> ExecutionStatus:
# Agent execution is uninterrupted.
import signal
signal.signal(signal.SIGINT, signal.SIG_IGN)
signal.signal(signal.SIGTERM, signal.SIG_IGN)
"""
Returns:
dict: The execution statistics of the graph execution.
@@ -786,7 +789,7 @@ class Executor:
# Charge usage (may raise) ------------------------------
try:
cost = cls._charge_usage(
cost = self._charge_usage(
node_exec=queued_node_exec,
execution_count=increment_execution_count(graph_exec.user_id),
)
@@ -806,7 +809,7 @@ class Executor:
status=ExecutionStatus.FAILED,
)
cls._handle_low_balance_notif(
self._handle_low_balance_notif(
db_client,
graph_exec.user_id,
graph_exec.graph_id,
@@ -825,7 +828,7 @@ class Executor:
# Kick off async node execution -------------------------
node_execution_task = asyncio.run_coroutine_threadsafe(
cls.on_node_execution(
self.on_node_execution(
node_exec=queued_node_exec,
node_exec_progress=running_node_execution[node_id],
nodes_input_masks=nodes_input_masks,
@@ -834,7 +837,7 @@ class Executor:
execution_stats_lock,
),
),
cls.node_execution_loop,
self.node_execution_loop,
)
running_node_execution[node_id].add_task(
node_exec_id=queued_node_exec.node_exec_id,
@@ -875,7 +878,7 @@ class Executor:
node_output_found = True
running_node_evaluation[node_id] = (
asyncio.run_coroutine_threadsafe(
cls._process_node_output(
self._process_node_output(
output=output,
node_id=node_id,
graph_exec=graph_exec,
@@ -883,7 +886,7 @@ class Executor:
nodes_input_masks=nodes_input_masks,
execution_queue=execution_queue,
),
cls.node_evaluation_loop,
self.node_evaluation_loop,
)
)
if (
@@ -926,7 +929,7 @@ class Executor:
raise
finally:
cls._cleanup_graph_execution(
self._cleanup_graph_execution(
execution_queue=execution_queue,
running_node_execution=running_node_execution,
running_node_evaluation=running_node_evaluation,
@@ -937,10 +940,9 @@ class Executor:
db_client=db_client,
)
@classmethod
@error_logged(swallow=True)
def _cleanup_graph_execution(
cls,
self,
execution_queue: ExecutionQueue[NodeExecutionEntry],
running_node_execution: dict[str, "NodeExecutionProgress"],
running_node_evaluation: dict[str, Future],
@@ -991,10 +993,9 @@ class Executor:
clean_exec_files(graph_exec_id)
@classmethod
@async_error_logged(swallow=True)
async def _process_node_output(
cls,
self,
output: ExecutionOutputEntry,
node_id: str,
graph_exec: GraphExecutionEntry,
@@ -1027,9 +1028,8 @@ class Executor:
):
execution_queue.add(next_execution)
@classmethod
def _handle_agent_run_notif(
cls,
self,
db_client: "DatabaseManagerClient",
graph_exec: GraphExecutionEntry,
exec_stats: GraphExecutionStats,
@@ -1065,9 +1065,8 @@ class Executor:
)
)
@classmethod
def _handle_low_balance_notif(
cls,
self,
db_client: "DatabaseManagerClient",
user_id: str,
graph_id: str,
@@ -1132,11 +1131,11 @@ class ExecutionManager(AppProcess):
return self._stop_consuming
@property
def executor(self) -> ProcessPoolExecutor:
def executor(self) -> ThreadPoolExecutor:
if self._executor is None:
self._executor = ProcessPoolExecutor(
self._executor = ThreadPoolExecutor(
max_workers=self.pool_size,
initializer=Executor.on_graph_executor_start,
initializer=init_worker,
)
return self._executor
@@ -1313,11 +1312,9 @@ class ExecutionManager(AppProcess):
_ack_message(reject=True)
return
cancel_event = multiprocessing.Manager().Event()
cancel_event = threading.Event()
future = self.executor.submit(
Executor.on_graph_execution, graph_exec_entry, cancel_event
)
future = self.executor.submit(execute_graph, graph_exec_entry, cancel_event)
self.active_graph_runs[graph_exec_id] = (future, cancel_event)
self._update_prompt_metrics()