mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-10 15:47:59 -05:00
feat(backend/executor): Move execution queue + cancel mechanism to RabbitMQ (#9759)
The graph execution queue is not disk-persisted; when the executor dies, queued executions are lost. The scope of this issue is migrating the execution queue from an inter-process queue to a RabbitMQ message queue. A sync client should be used for this.

- Resolves #9746
- Resolves #9714

### Changes 🏗️

Move the execution manager's queue from `multiprocessing.Queue` onto a persisted RabbitMQ queue.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Execute agents.

<details>
<summary>Example test plan</summary>

- [ ] Create from scratch and execute an agent with at least 3 blocks
- [ ] Import an agent from file upload, and confirm it executes correctly
- [ ] Upload agent to marketplace
- [ ] Import an agent from marketplace and confirm it executes correctly
- [ ] Edit an agent from monitor, and confirm it executes correctly
</details>

#### For configuration changes:
- [ ] `.env.example` is updated or already compatible with my changes
- [ ] `docker-compose.yml` is updated or already compatible with my changes
- [ ] I have included a list of my configuration changes in the PR description (under **Changes**)

<details>
<summary>Examples of configuration changes</summary>

- Changing ports
- Adding new services that need to communicate with each other
- Secrets or environment variable changes
- New or infrastructure changes such as databases
</details>
```diff
@@ -5,11 +5,15 @@ import os
 import signal
 import sys
 import threading
+import time
 from concurrent.futures import Future, ProcessPoolExecutor
 from contextlib import contextmanager
 from multiprocessing.pool import AsyncResult, Pool
 from typing import TYPE_CHECKING, Any, Generator, Optional, TypeVar, cast
 
+from pika.adapters.blocking_connection import BlockingChannel
+from pika.spec import Basic
+from pydantic import BaseModel
 from redis.lock import Lock as RedisLock
 
 from backend.blocks.io import AgentOutputBlock
```
```diff
@@ -20,6 +24,13 @@ from backend.data.notifications import (
     NotificationEventDTO,
     NotificationType,
 )
+from backend.data.rabbitmq import (
+    Exchange,
+    ExchangeType,
+    Queue,
+    RabbitMQConfig,
+    SyncRabbitMQ,
+)
 from backend.util.exceptions import InsufficientBalanceError
 
 if TYPE_CHECKING:
```
```diff
@@ -454,6 +465,51 @@ def validate_exec(
     return data, node_block.name
 
 
+GRAPH_EXECUTION_EXCHANGE = Exchange(
+    name="graph_execution",
+    type=ExchangeType.DIRECT,
+    durable=True,
+    auto_delete=False,
+)
+GRAPH_EXECUTION_QUEUE_NAME = "graph_execution_queue"
+GRAPH_EXECUTION_ROUTING_KEY = "graph_execution.run"
+
+GRAPH_EXECUTION_CANCEL_EXCHANGE = Exchange(
+    name="graph_execution_cancel",
+    type=ExchangeType.FANOUT,
+    durable=True,
+    auto_delete=True,
+)
+GRAPH_EXECUTION_CANCEL_QUEUE_NAME = "graph_execution_cancel_queue"
+
+
+def create_execution_config() -> RabbitMQConfig:
+    """
+    Define two exchanges and queues:
+    - 'graph_execution' (DIRECT) for run tasks.
+    - 'graph_execution_cancel' (FANOUT) for cancel requests.
+    """
+    run_queue = Queue(
+        name=GRAPH_EXECUTION_QUEUE_NAME,
+        exchange=GRAPH_EXECUTION_EXCHANGE,
+        routing_key=GRAPH_EXECUTION_ROUTING_KEY,
+        durable=True,
+        auto_delete=False,
+    )
+    cancel_queue = Queue(
+        name=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
+        exchange=GRAPH_EXECUTION_CANCEL_EXCHANGE,
+        routing_key="",  # not used for FANOUT
+        durable=True,
+        auto_delete=False,
+    )
+    return RabbitMQConfig(
+        vhost="/",
+        exchanges=[GRAPH_EXECUTION_EXCHANGE, GRAPH_EXECUTION_CANCEL_EXCHANGE],
+        queues=[run_queue, cancel_queue],
+    )
+
+
 class Executor:
     """
     This class contains event handlers for the process pool executor events.
```
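For context, here is a rough sketch of the broker topology these declarations imply, expressed in raw `pika`. The `SyncRabbitMQ` wrapper's actual declaration API isn't shown in this diff, so the calls below are illustrative only:

```python
# Hypothetical illustration of the topology above, using raw pika primitives.
# SyncRabbitMQ presumably performs equivalent declarations on connect.
import pika

connection = pika.BlockingConnection(pika.ConnectionParameters(host="localhost"))
channel = connection.channel()

# DIRECT exchange: messages are routed only to queues bound with a matching key.
channel.exchange_declare(
    exchange="graph_execution", exchange_type="direct",
    durable=True, auto_delete=False,
)
channel.queue_declare(queue="graph_execution_queue", durable=True, auto_delete=False)
channel.queue_bind(
    queue="graph_execution_queue",
    exchange="graph_execution",
    routing_key="graph_execution.run",
)

# FANOUT exchange: every bound queue receives a copy, so each connected
# executor sees every cancel request regardless of routing key.
channel.exchange_declare(
    exchange="graph_execution_cancel", exchange_type="fanout",
    durable=True, auto_delete=True,
)
channel.queue_declare(queue="graph_execution_cancel_queue", durable=True)
channel.queue_bind(
    queue="graph_execution_cancel_queue", exchange="graph_execution_cancel"
)
```

The split matches the PR's goal: the durable DIRECT queue gives broker-persisted, point-to-point delivery of run tasks, while the FANOUT exchange broadcasts cancels to every executor instance.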
```diff
@@ -927,12 +983,18 @@ class Executor:
         )
 
 
+class CancelExecutionEvent(BaseModel):
+    graph_exec_id: str
+
+
 class ExecutionManager(AppService):
     def __init__(self):
         super().__init__()
         self.pool_size = settings.config.num_graph_workers
-        self.queue = ExecutionQueue[GraphExecutionEntry]()
-        self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
+        self.rabbit_config = create_execution_config()
+        self.rabbitmq_service = SyncRabbitMQ(self.rabbit_config)
+        self.running = True
+        self.active_graph_runs: dict[str, tuple[Future, threading.Event, int]] = {}
 
     @classmethod
     def get_port(cls) -> int:
```
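`CancelExecutionEvent` is a plain pydantic model, so cancel requests travel across the broker as JSON bodies. A minimal, self-contained round-trip (the exec ID is a made-up example):

```python
from pydantic import BaseModel

class CancelExecutionEvent(BaseModel):
    graph_exec_id: str

# Producer side: serialize to a JSON body for publishing.
body = CancelExecutionEvent(graph_exec_id="exec-123").model_dump_json()
assert body == '{"graph_exec_id":"exec-123"}'

# Consumer side: model_validate_json accepts the raw bytes from the queue.
event = CancelExecutionEvent.model_validate_json(b'{"graph_exec_id":"exec-123"}')
assert event.graph_exec_id == "exec-123"
```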
```diff
@@ -943,6 +1005,10 @@ class ExecutionManager(AppService):
 
         self.credentials_store = IntegrationCredentialsStore()
 
+        logger.info(f"[{self.service_name}] ⏳ Connecting to RabbitMQ...")
+        self.rabbitmq_service.connect()
+        channel = self.rabbitmq_service.get_channel()
+
         logger.info(f"[{self.service_name}] ⏳ Spawn max-{self.pool_size} workers...")
         self.executor = ProcessPoolExecutor(
             max_workers=self.pool_size,
```
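The diff only shows `connect()` and `get_channel()` being called on the wrapper. A wrapper with that shape might look roughly like the sketch below; everything beyond those two method names is an assumption:

```python
# Speculative sketch of a SyncRabbitMQ-style wrapper; names and config
# fields other than connect()/get_channel()/disconnect() are assumptions.
import pika

class SyncRabbitMQWrapper:
    def __init__(self, host: str = "localhost", vhost: str = "/"):
        self._params = pika.ConnectionParameters(host=host, virtual_host=vhost)
        self._connection: pika.BlockingConnection | None = None
        self._channel = None

    def connect(self) -> None:
        # BlockingConnection is pika's synchronous client, matching the PR's
        # stated requirement to use a sync client in the service loop.
        self._connection = pika.BlockingConnection(self._params)
        self._channel = self._connection.channel()

    def get_channel(self):
        assert self._channel is not None, "connect() must be called first"
        return self._channel

    def disconnect(self) -> None:
        if self._connection and self._connection.is_open:
            self._connection.close()
```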
```diff
@@ -952,25 +1018,112 @@ class ExecutionManager(AppService):
         logger.info(f"[{self.service_name}] ⏳ Connecting to Redis...")
         redis.connect()
 
-        sync_manager = multiprocessing.Manager()
+        logger.info(f"[{self.service_name}] Ready to consume messages...")
         while True:
-            graph_exec_data = self.queue.get()
-            graph_exec_id = graph_exec_data.graph_exec_id
-            logger.debug(
-                f"[ExecutionManager] Dispatching graph execution {graph_exec_id}"
-            )
-            cancel_event = sync_manager.Event()
-            future = self.executor.submit(
-                Executor.on_graph_execution, graph_exec_data, cancel_event
-            )
-            self.active_graph_runs[graph_exec_id] = (future, cancel_event)
-            future.add_done_callback(
-                lambda _: self.active_graph_runs.pop(graph_exec_id, None)
-            )
+            # cancel graph execution requests
+            method_frame, _, body = channel.basic_get(
+                queue=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
+                auto_ack=True,
+            )
+            if method_frame:
+                self._handle_cancel_message(body)
+
+            # start graph execution requests
+            method_frame, _, body = channel.basic_get(
+                queue=GRAPH_EXECUTION_QUEUE_NAME,
+                auto_ack=False,
+            )
+            if method_frame:
+                self._handle_run_message(channel, method_frame, body)
+            else:
+                time.sleep(0.1)
+
+    def _handle_cancel_message(self, body: bytes):
+        try:
+            request = CancelExecutionEvent.model_validate_json(body)
+            graph_exec_id = request.graph_exec_id
+            if not graph_exec_id:
+                logger.warning(
+                    f"[{self.service_name}] Cancel message missing 'graph_exec_id'"
+                )
+                return
+            if graph_exec_id not in self.active_graph_runs:
+                logger.debug(
+                    f"[{self.service_name}] Cancel received for {graph_exec_id} but not active."
+                )
+                return
+
+            _, cancel_event, _ = self.active_graph_runs[graph_exec_id]
+            logger.info(f"[{self.service_name}] Received cancel for {graph_exec_id}")
+            if not cancel_event.is_set():
+                cancel_event.set()
+            else:
+                logger.debug(
+                    f"[{self.service_name}] Cancel already set for {graph_exec_id}"
+                )
+
+        except Exception as e:
+            logger.exception(f"Error handling cancel message: {e}")
+
+    def _handle_run_message(
+        self, channel: BlockingChannel, method_frame: Basic.GetOk, body: bytes
+    ):
+        delivery_tag = method_frame.delivery_tag
+        try:
+            graph_exec_entry = GraphExecutionEntry.model_validate_json(body)
+        except Exception as e:
+            logger.error(f"[{self.service_name}] Could not parse run message: {e}")
+            channel.basic_nack(delivery_tag, requeue=False)
+            return
+
+        graph_exec_id = graph_exec_entry.graph_exec_id
+        logger.info(
+            f"[{self.service_name}] Received RUN for graph_exec_id={graph_exec_id}"
+        )
+        if graph_exec_id in self.active_graph_runs:
+            logger.warning(
+                f"[{self.service_name}] Graph {graph_exec_id} already running; rejecting duplicate run."
+            )
+            channel.basic_nack(delivery_tag, requeue=False)
+            return
+
+        cancel_event = multiprocessing.Manager().Event()
+        future = self.executor.submit(
+            Executor.on_graph_execution, graph_exec_entry, cancel_event
+        )
+        self.active_graph_runs[graph_exec_id] = (future, cancel_event, delivery_tag)
+
+        def _on_run_done(f: Future):
+            logger.info(f"[{self.service_name}] Run completed for {graph_exec_id}")
+            info = self.active_graph_runs.pop(graph_exec_id, None)
+            if not info:
+                return
+            _, _, delivery_tag = info
+            if future.exception():
+                logger.error(
+                    f"[{self.service_name}] Execution for {graph_exec_id} failed: {future.exception()}"
+                )
+                channel.basic_nack(delivery_tag, requeue=False)
+            else:
+                channel.basic_ack(delivery_tag)
+
+        future.add_done_callback(_on_run_done)
 
     def cleanup(self):
         super().cleanup()
 
+        logger.info(f"[{self.service_name}] ⏳ Shutting down service loop...")
+        self.running = False
+
-        logger.info(f"[{self.service_name}] ⏳ Shutting down graph executor pool...")
-        self.executor.shutdown(cancel_futures=True)
-
         logger.info(f"[{self.service_name}] ⏳ Disconnecting Redis...")
         redis.disconnect()
 
+        logger.info(f"[{self.service_name}] ⏳ Disconnecting RabbitMQ...")
+        self.rabbitmq_service.disconnect()
+
+        logger.info(f"[{self.service_name}] ⏳ Shutting down graph executor pool...")
+        self.executor.shutdown(cancel_futures=True)
```
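The acknowledgment modes in the loop carry the reliability guarantee: run messages use `auto_ack=False`, so the broker holds a message as unacknowledged until the execution finishes and it is explicitly acked (or nacked on failure), while cancels use `auto_ack=True` as fire-and-forget. A minimal standalone sketch of that poll/ack pattern, assuming a local broker, a pre-declared durable queue, and a hypothetical `handle` work function:

```python
# Standalone sketch of the basic_get poll/ack pattern used above.
import time
import pika

def handle(body: bytes) -> None:
    print("processing", body)  # hypothetical stand-in for real work

connection = pika.BlockingConnection(pika.ConnectionParameters("localhost"))
channel = connection.channel()

while True:
    # auto_ack=False: the broker keeps the message "unacked" until we ack it,
    # so a consumer crash before ack lets the message be redelivered.
    method_frame, _properties, body = channel.basic_get(
        queue="graph_execution_queue", auto_ack=False
    )
    if method_frame is None:
        time.sleep(0.1)  # queue empty; avoid a busy loop
        continue
    try:
        handle(body)
        channel.basic_ack(method_frame.delivery_tag)
    except Exception:
        # requeue=False drops (or dead-letters) poison messages rather than
        # redelivering them forever.
        channel.basic_nack(method_frame.delivery_tag, requeue=False)
```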
```diff
@@ -1064,8 +1217,11 @@ class ExecutionManager(AppService):
                 for node_exec in graph_exec.node_executions
             ],
         )
-        self.queue.add(graph_exec_entry)
+
+        self.rabbitmq_service.publish_message(
+            routing_key=GRAPH_EXECUTION_ROUTING_KEY,
+            message=graph_exec_entry.model_dump_json(),
+            exchange=GRAPH_EXECUTION_EXCHANGE,
+        )
         return graph_exec_entry
 
     @expose
```
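One detail worth noting: a durable queue alone does not make a queued run survive a broker restart; the message itself must also be published as persistent. Whether `publish_message` sets this internally isn't visible in this diff. In raw `pika` that would look like:

```python
# Illustrative only: persistent publish to the durable run queue.
import pika

connection = pika.BlockingConnection(pika.ConnectionParameters("localhost"))
channel = connection.channel()

# Durable queue + persistent message: both are needed for the run request
# to survive a broker restart.
channel.basic_publish(
    exchange="graph_execution",
    routing_key="graph_execution.run",
    body='{"graph_exec_id": "exec-123"}',  # e.g. a GraphExecutionEntry JSON body
    properties=pika.BasicProperties(delivery_mode=2),  # 2 = persistent
)
```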
```diff
@@ -1077,16 +1233,11 @@ class ExecutionManager(AppService):
            reinitializes worker pool, and returns.
         3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
         """
-        if graph_exec_id not in self.active_graph_runs:
-            logger.warning(
-                f"Graph execution #{graph_exec_id} not active/running: "
-                "possibly already completed/cancelled."
-            )
-        else:
-            future, cancel_event = self.active_graph_runs[graph_exec_id]
-            if not cancel_event.is_set():
-                cancel_event.set()
-                future.result()
+        self.rabbitmq_service.publish_message(
+            routing_key="",
+            message=CancelExecutionEvent(graph_exec_id=graph_exec_id).model_dump_json(),
+            exchange=GRAPH_EXECUTION_CANCEL_EXCHANGE,
+        )
 
         # Update the status of the graph & node executions
         self.db_client.update_graph_execution_stats(
```
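The cancel path no longer requires the canceling process to own the run: it broadcasts on the FANOUT exchange, and each executor's `_handle_cancel_message` simply ignores IDs not in its own `active_graph_runs`. A sketch of that broadcast in raw `pika` (the exec ID is illustrative):

```python
# Illustrative only: broadcasting a cancel event on the fanout exchange.
import pika

connection = pika.BlockingConnection(pika.ConnectionParameters("localhost"))
channel = connection.channel()

# FANOUT ignores the routing key entirely; every queue bound to the
# exchange receives a copy of the cancel event.
channel.basic_publish(
    exchange="graph_execution_cancel",
    routing_key="",
    body='{"graph_exec_id": "exec-123"}',
)
```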
```diff
@@ -15,21 +15,25 @@ def to_dict(data) -> dict:
 
 
 def dumps(data) -> str:
-    return json.dumps(jsonable_encoder(data))
+    return json.dumps(to_dict(data))
 
 
 T = TypeVar("T")
 
 
 @overload
-def loads(data: str, *args, target_type: Type[T], **kwargs) -> T: ...
+def loads(data: str | bytes, *args, target_type: Type[T], **kwargs) -> T: ...
 
 
 @overload
-def loads(data: str, *args, **kwargs) -> Any: ...
+def loads(data: str | bytes, *args, **kwargs) -> Any: ...
 
 
-def loads(data: str, *args, target_type: Type[T] | None = None, **kwargs) -> Any:
+def loads(
+    data: str | bytes, *args, target_type: Type[T] | None = None, **kwargs
+) -> Any:
+    if isinstance(data, bytes):
+        data = data.decode("utf-8")
     parsed = json.loads(data, *args, **kwargs)
     if target_type:
         return type_match(parsed, target_type)
```
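Widening `loads` to `str | bytes` lets RabbitMQ message bodies (delivered as `bytes`) be parsed without callers decoding first. A usage sketch, assuming this is `backend.util.json` and that `type_match` coerces parsed JSON into the target model as used elsewhere in the codebase:

```python
# Usage sketch for the widened loads() signature (module path and the
# CancelExecutionEvent model are taken from this PR's diff).
from pydantic import BaseModel

from backend.util.json import loads

class CancelExecutionEvent(BaseModel):
    graph_exec_id: str

raw: bytes = b'{"graph_exec_id": "exec-123"}'  # e.g. a RabbitMQ message body

event = loads(raw, target_type=CancelExecutionEvent)  # bytes decoded internally
assert event.graph_exec_id == "exec-123"
```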
```diff
@@ -14,9 +14,7 @@ def sentry_init():
         traces_sample_rate=1.0,
         profiles_sample_rate=1.0,
         environment=f"app:{Settings().config.app_env.value}-behave:{Settings().config.behave_as.value}",
-        _experiments={
-            "enable_logs": True,
-        },
+        _experiments={"enable_logs": True},
         integrations=[
             LoggingIntegration(sentry_logs_level=logging.INFO),
             AnthropicIntegration(
```