mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
refactor(backend): Refactor log client and resource cleanup (#10558)
## Summary

- Created centralized service client helpers with thread caching in `util/clients.py`
- Refactored service client management to eliminate health checks and improve performance
- Enhanced logging in process cleanup to include error details
- Improved retry mechanisms and resource cleanup across the platform
- Updated multiple services to use new centralized client patterns

## Key Changes

### New Centralized Client Factory (`util/clients.py`)

- Added thread-cached factory functions for all major service clients:
  - Database managers (sync and async)
  - Scheduler client
  - Notification manager
  - Execution event bus (Redis-based)
  - RabbitMQ execution queue (sync and async)
  - Integration credentials store
- All clients use the `@thread_cached` decorator for performance optimization

### Service Client Improvements

- **Removed health checks**: Eliminated unnecessary health check calls from `get_service_client()` to reduce startup overhead
- **Enhanced retry support**: Database manager clients now use request retry by default
- **Better error handling**: Improved error propagation and logging

### Enhanced Logging and Cleanup

- **Process termination logs**: Added error details to termination messages in `util/process.py`
- **Retry mechanism updates**: Improved retry logic with better error handling in `util/retry.py`
- **Resource cleanup**: Better resource management across executors and monitoring services

### Updated Service Usage

- Refactored 21+ files to use new centralized client patterns
- Updated all executor, monitoring, and notification services
- Maintained backward compatibility while improving performance

## Files Changed

- **Created**: `backend/util/clients.py` - Centralized client factory with thread caching
- **Modified**: 21 files across blocks, executor, monitoring, and utility modules
- **Key areas**: Service client initialization, resource cleanup, retry mechanisms

## Test Plan

- [x] Verify all existing tests pass
- [x] Validate service startup and client initialization
- [x] Test resource cleanup on process termination
- [x] Confirm retry mechanisms work correctly
- [x] Validate thread caching performance improvements
- [x] Ensure no breaking changes to existing functionality

## Breaking Changes

None - all changes maintain backward compatibility.

## Additional Notes

This refactoring centralizes client management patterns that were scattered across the codebase, making them more consistent and performant through thread caching. The removal of health checks reduces startup time while maintaining reliability through improved retry mechanisms.

🤖 Generated with [Claude Code](https://claude.ai/code)
This commit is contained in:
@@ -14,7 +14,8 @@ from backend.data.block import (
|
||||
)
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.model import NodeExecutionStats, SchemaField
|
||||
from backend.util import json, retry
|
||||
from backend.util.json import validate_with_jsonschema
|
||||
from backend.util.retry import func_retry
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -48,7 +49,7 @@ class AgentExecutorBlock(Block):
|
||||
|
||||
@classmethod
|
||||
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
||||
return json.validate_with_jsonschema(cls.get_input_schema(data), data)
|
||||
return validate_with_jsonschema(cls.get_input_schema(data), data)
|
||||
|
||||
class Output(BlockSchema):
|
||||
pass
|
||||
@@ -180,7 +181,7 @@ class AgentExecutorBlock(Block):
|
||||
)
|
||||
yield output_name, output_data
|
||||
|
||||
@retry.func_retry
|
||||
@func_retry
|
||||
async def _stop(
|
||||
self,
|
||||
graph_exec_id: str,
|
||||
|
||||
@@ -1,26 +1,18 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.block import BlockSchema
|
||||
from backend.data.model import SchemaField, UserIntegrations
|
||||
from backend.integrations.ayrshare import AyrshareClient
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
from backend.util.exceptions import MissingConfigError
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_database_manager_client():
|
||||
from backend.executor import DatabaseManagerAsyncClient
|
||||
from backend.util.service import get_service_client
|
||||
|
||||
return get_service_client(DatabaseManagerAsyncClient, health_check=False)
|
||||
|
||||
|
||||
async def get_profile_key(user_id: str):
|
||||
user_integrations: UserIntegrations = (
|
||||
await get_database_manager_client().get_user_integrations(user_id)
|
||||
await get_database_manager_async_client().get_user_integrations(user_id)
|
||||
)
|
||||
return user_integrations.managed_credentials.ayrshare_profile_key
|
||||
|
||||
|
||||
@@ -1,22 +1,13 @@
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_database_manager_client():
|
||||
from backend.executor import DatabaseManagerAsyncClient
|
||||
from backend.util.service import get_service_client
|
||||
|
||||
return get_service_client(DatabaseManagerAsyncClient, health_check=False)
|
||||
|
||||
|
||||
StorageScope = Literal["within_agent", "across_agents"]
|
||||
|
||||
|
||||
@@ -88,7 +79,7 @@ class PersistInformationBlock(Block):
|
||||
async def _store_data(
|
||||
self, user_id: str, node_exec_id: str, key: str, data: Any
|
||||
) -> Any | None:
|
||||
return await get_database_manager_client().set_execution_kv_data(
|
||||
return await get_database_manager_async_client().set_execution_kv_data(
|
||||
user_id=user_id,
|
||||
node_exec_id=node_exec_id,
|
||||
key=key,
|
||||
@@ -149,7 +140,7 @@ class RetrieveInformationBlock(Block):
|
||||
yield "value", input_data.default_value
|
||||
|
||||
async def _retrieve_data(self, user_id: str, key: str) -> Any | None:
|
||||
return await get_database_manager_client().get_execution_kv_data(
|
||||
return await get_database_manager_async_client().get_execution_kv_data(
|
||||
user_id=user_id,
|
||||
key=key,
|
||||
)
|
||||
|
||||
@@ -3,8 +3,6 @@ import re
|
||||
from collections import Counter
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
|
||||
import backend.blocks.llm as llm
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.data.block import (
|
||||
@@ -17,6 +15,7 @@ from backend.data.block import (
|
||||
)
|
||||
from backend.data.model import NodeExecutionStats, SchemaField
|
||||
from backend.util import json
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.graph import Link, Node
|
||||
@@ -24,14 +23,6 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_database_manager_client():
|
||||
from backend.executor import DatabaseManagerAsyncClient
|
||||
from backend.util.service import get_service_client
|
||||
|
||||
return get_service_client(DatabaseManagerAsyncClient, health_check=False)
|
||||
|
||||
|
||||
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
|
||||
"""
|
||||
Return a list of tool_call_ids if the entry is a tool request.
|
||||
@@ -333,7 +324,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
if not graph_id or not graph_version:
|
||||
raise ValueError("Graph ID or Graph Version not found in sink node.")
|
||||
|
||||
db_client = get_database_manager_client()
|
||||
db_client = get_database_manager_async_client()
|
||||
sink_graph_meta = await db_client.get_graph_metadata(graph_id, graph_version)
|
||||
if not sink_graph_meta:
|
||||
raise ValueError(
|
||||
@@ -393,7 +384,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
ValueError: If no tool links are found for the specified node_id, or if a sink node
|
||||
or its metadata cannot be found.
|
||||
"""
|
||||
db_client = get_database_manager_client()
|
||||
db_client = get_database_manager_async_client()
|
||||
tools = [
|
||||
(link, node)
|
||||
for link, node in await db_client.get_connected_output_nodes(node_id)
|
||||
|
||||
@@ -39,6 +39,7 @@ from pydantic.fields import Field
|
||||
from backend.server.v2.store.exceptions import DatabaseError
|
||||
from backend.util import type as type_utils
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.retry import func_retry
|
||||
from backend.util.settings import Config
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
@@ -883,15 +884,15 @@ class RedisExecutionEventBus(RedisEventBus[ExecutionEvent]):
|
||||
|
||||
def publish(self, res: GraphExecution | NodeExecutionResult):
|
||||
if isinstance(res, GraphExecution):
|
||||
self.publish_graph_exec_update(res)
|
||||
self._publish_graph_exec_update(res)
|
||||
else:
|
||||
self.publish_node_exec_update(res)
|
||||
self._publish_node_exec_update(res)
|
||||
|
||||
def publish_node_exec_update(self, res: NodeExecutionResult):
|
||||
def _publish_node_exec_update(self, res: NodeExecutionResult):
|
||||
event = NodeExecutionEvent.model_validate(res.model_dump())
|
||||
self._publish(event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}")
|
||||
|
||||
def publish_graph_exec_update(self, res: GraphExecution):
|
||||
def _publish_graph_exec_update(self, res: GraphExecution):
|
||||
event = GraphExecutionEvent.model_validate(res.model_dump())
|
||||
self._publish(event, f"{res.user_id}/{res.graph_id}/{res.id}")
|
||||
|
||||
@@ -923,17 +924,18 @@ class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionEvent]):
|
||||
def event_bus_name(self) -> str:
|
||||
return config.execution_event_bus_name
|
||||
|
||||
@func_retry
|
||||
async def publish(self, res: GraphExecutionMeta | NodeExecutionResult):
|
||||
if isinstance(res, GraphExecutionMeta):
|
||||
await self.publish_graph_exec_update(res)
|
||||
await self._publish_graph_exec_update(res)
|
||||
else:
|
||||
await self.publish_node_exec_update(res)
|
||||
await self._publish_node_exec_update(res)
|
||||
|
||||
async def publish_node_exec_update(self, res: NodeExecutionResult):
|
||||
async def _publish_node_exec_update(self, res: NodeExecutionResult):
|
||||
event = NodeExecutionEvent.model_validate(res.model_dump())
|
||||
await self._publish(event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}")
|
||||
|
||||
async def publish_graph_exec_update(self, res: GraphExecutionMeta):
|
||||
async def _publish_graph_exec_update(self, res: GraphExecutionMeta):
|
||||
# GraphExecutionEvent requires inputs and outputs fields that GraphExecutionMeta doesn't have
|
||||
# Add default empty values for compatibility
|
||||
event_data = res.model_dump()
|
||||
|
||||
@@ -4,20 +4,12 @@ from enum import Enum
|
||||
from typing import Awaitable, Optional
|
||||
|
||||
import aio_pika
|
||||
import aio_pika.exceptions as aio_ex
|
||||
import pika
|
||||
import pika.adapters.blocking_connection
|
||||
from pika.exceptions import AMQPError
|
||||
from pika.spec import BasicProperties
|
||||
from pydantic import BaseModel
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from backend.util.retry import conn_retry
|
||||
from backend.util.retry import conn_retry, func_retry
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -198,12 +190,7 @@ class SyncRabbitMQ(RabbitMQBase):
|
||||
routing_key=queue.routing_key or queue.name,
|
||||
)
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type((AMQPError, ConnectionError)),
|
||||
wait=wait_random_exponential(multiplier=1, max=5),
|
||||
stop=stop_after_attempt(5),
|
||||
reraise=True,
|
||||
)
|
||||
@func_retry
|
||||
def publish_message(
|
||||
self,
|
||||
routing_key: str,
|
||||
@@ -302,12 +289,7 @@ class AsyncRabbitMQ(RabbitMQBase):
|
||||
exchange, routing_key=queue.routing_key or queue.name
|
||||
)
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type((aio_ex.AMQPError, ConnectionError)),
|
||||
wait=wait_random_exponential(multiplier=1, max=5),
|
||||
stop=stop_after_attempt(5),
|
||||
reraise=True,
|
||||
)
|
||||
@func_retry
|
||||
async def publish_message(
|
||||
self,
|
||||
routing_key: str,
|
||||
|
||||
@@ -13,6 +13,7 @@ from backend.blocks.llm import LlmModel, llm_call
|
||||
from backend.data.block import get_block
|
||||
from backend.data.execution import ExecutionStatus, NodeExecutionResult
|
||||
from backend.data.model import APIKeyCredentials, GraphExecutionStats
|
||||
from backend.util.retry import func_retry
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
@@ -415,6 +416,7 @@ def _build_execution_summary(
|
||||
}
|
||||
|
||||
|
||||
@func_retry
|
||||
async def _call_llm_direct(
|
||||
credentials: APIKeyCredentials, prompt: list[dict[str, str]]
|
||||
) -> str:
|
||||
|
||||
@@ -26,14 +26,13 @@ from backend.data.rabbitmq import SyncRabbitMQ
|
||||
from backend.executor.activity_status_generator import (
|
||||
generate_activity_status_for_execution,
|
||||
)
|
||||
from backend.executor.utils import LogMetadata, create_execution_queue_config
|
||||
from backend.executor.utils import LogMetadata
|
||||
from backend.notifications.notifications import queue_notification
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerClient, DatabaseManagerAsyncClient
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from prometheus_client import Gauge, start_http_server
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
@@ -63,14 +62,19 @@ from backend.executor.utils import (
|
||||
ExecutionOutputEntry,
|
||||
NodeExecutionProgress,
|
||||
block_usage_cost,
|
||||
create_execution_queue_config,
|
||||
execution_usage_cost,
|
||||
get_async_execution_event_bus,
|
||||
get_execution_event_bus,
|
||||
parse_execution_output,
|
||||
validate_exec,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util import json
|
||||
from backend.util.clients import (
|
||||
get_async_execution_event_bus,
|
||||
get_database_manager_async_client,
|
||||
get_database_manager_client,
|
||||
get_execution_event_bus,
|
||||
)
|
||||
from backend.util.decorator import (
|
||||
async_error_logged,
|
||||
async_time_measured,
|
||||
@@ -81,7 +85,6 @@ from backend.util.file import clean_exec_files
|
||||
from backend.util.logging import TruncatedLogger, configure_logging
|
||||
from backend.util.process import AppProcess, set_service_name
|
||||
from backend.util.retry import continuous_retry, func_retry
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Settings
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
@@ -1088,11 +1091,33 @@ class ExecutionManager(AppProcess):
|
||||
super().__init__()
|
||||
self.pool_size = settings.config.num_graph_workers
|
||||
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
|
||||
self._stop_consuming = None
|
||||
|
||||
self._executor = None
|
||||
self._stop_consuming = None
|
||||
|
||||
self._cancel_thread = None
|
||||
self._cancel_client = None
|
||||
self._run_thread = None
|
||||
self._run_client = None
|
||||
|
||||
@property
|
||||
def cancel_thread(self) -> threading.Thread:
|
||||
if self._cancel_thread is None:
|
||||
self._cancel_thread = threading.Thread(
|
||||
target=lambda: self._consume_execution_cancel(),
|
||||
daemon=True,
|
||||
)
|
||||
return self._cancel_thread
|
||||
|
||||
@property
|
||||
def run_thread(self) -> threading.Thread:
|
||||
if self._run_thread is None:
|
||||
self._run_thread = threading.Thread(
|
||||
target=lambda: self._consume_execution_run(),
|
||||
daemon=True,
|
||||
)
|
||||
return self._run_thread
|
||||
|
||||
@property
|
||||
def stop_consuming(self) -> threading.Event:
|
||||
if self._stop_consuming is None:
|
||||
@@ -1108,44 +1133,55 @@ class ExecutionManager(AppProcess):
|
||||
)
|
||||
return self._executor
|
||||
|
||||
@property
|
||||
def cancel_client(self) -> SyncRabbitMQ:
|
||||
if self._cancel_client is None:
|
||||
self._cancel_client = SyncRabbitMQ(create_execution_queue_config())
|
||||
return self._cancel_client
|
||||
|
||||
@property
|
||||
def run_client(self) -> SyncRabbitMQ:
|
||||
if self._run_client is None:
|
||||
self._run_client = SyncRabbitMQ(create_execution_queue_config())
|
||||
return self._run_client
|
||||
|
||||
def run(self):
|
||||
logger.info(f"[{self.service_name}] ⏳ Spawn max-{self.pool_size} workers...")
|
||||
|
||||
pool_size_gauge.set(self.pool_size)
|
||||
self._update_prompt_metrics()
|
||||
start_http_server(settings.config.execution_manager_port)
|
||||
|
||||
threading.Thread(
|
||||
target=lambda: self._consume_execution_cancel(),
|
||||
daemon=True,
|
||||
).start()
|
||||
self.cancel_thread.start()
|
||||
self.run_thread.start()
|
||||
|
||||
threading.Thread(
|
||||
target=lambda: self._consume_execution_run(),
|
||||
daemon=True,
|
||||
).start()
|
||||
|
||||
threading.Thread(
|
||||
target=start_http_server,
|
||||
args=(settings.config.execution_manager_port,),
|
||||
daemon=True,
|
||||
).start()
|
||||
|
||||
while not self.stop_consuming.is_set():
|
||||
while True:
|
||||
time.sleep(1e5)
|
||||
|
||||
@continuous_retry()
|
||||
def _consume_execution_cancel(self):
|
||||
self._cancel_client = SyncRabbitMQ(create_execution_queue_config())
|
||||
self._cancel_client.connect()
|
||||
cancel_channel = self._cancel_client.get_channel()
|
||||
logger.info(f"[{self.service_name}] ⏳ Starting cancel message consumer...")
|
||||
if self.stop_consuming.is_set() and not self.active_graph_runs:
|
||||
logger.info(
|
||||
f"[{self.service_name}] Stop reconnecting cancel consumer since the service is cleaned up."
|
||||
)
|
||||
return
|
||||
|
||||
self.cancel_client.connect()
|
||||
cancel_channel = self.cancel_client.get_channel()
|
||||
cancel_channel.basic_consume(
|
||||
queue=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
on_message_callback=self._handle_cancel_message,
|
||||
auto_ack=True,
|
||||
)
|
||||
logger.info(f"[{self.service_name}] ⏳ Starting cancel message consumer...")
|
||||
cancel_channel.start_consuming()
|
||||
raise RuntimeError(f"❌ cancel message consumer is stopped: {cancel_channel}")
|
||||
if not self.stop_consuming.is_set() or self.active_graph_runs:
|
||||
raise RuntimeError(
|
||||
f"[{self.service_name}] ❌ cancel message consumer is stopped: {cancel_channel}"
|
||||
)
|
||||
logger.info(
|
||||
f"[{self.service_name}] ✅ Cancel message consumer stopped gracefully"
|
||||
)
|
||||
|
||||
@continuous_retry()
|
||||
def _consume_execution_run(self):
|
||||
@@ -1159,9 +1195,8 @@ class ExecutionManager(AppProcess):
|
||||
)
|
||||
return
|
||||
|
||||
self._run_client = SyncRabbitMQ(create_execution_queue_config())
|
||||
self._run_client.connect()
|
||||
run_channel = self._run_client.get_channel()
|
||||
self.run_client.connect()
|
||||
run_channel = self.run_client.get_channel()
|
||||
run_channel.basic_qos(prefetch_count=self.pool_size)
|
||||
|
||||
# Configure consumer for long-running graph executions
|
||||
@@ -1173,21 +1208,12 @@ class ExecutionManager(AppProcess):
|
||||
consumer_tag="graph_execution_consumer",
|
||||
)
|
||||
run_channel.confirm_delivery()
|
||||
|
||||
logger.info(f"[{self.service_name}] ⏳ Starting to consume run messages...")
|
||||
|
||||
# Continue consuming messages until stop flag is set
|
||||
# This keeps the connection alive but rejects new messages in _handle_run_message
|
||||
while not self.stop_consuming.is_set():
|
||||
try:
|
||||
run_channel.connection.process_data_events(time_limit=1)
|
||||
except Exception as e:
|
||||
if self.stop_consuming.is_set():
|
||||
# Expected during shutdown
|
||||
break
|
||||
logger.error(f"[{self.service_name}] Error processing events: {e}")
|
||||
raise
|
||||
|
||||
run_channel.start_consuming()
|
||||
if not self.stop_consuming.is_set():
|
||||
raise RuntimeError(
|
||||
f"[{self.service_name}] ❌ run message consumer is stopped: {run_channel}"
|
||||
)
|
||||
logger.info(f"[{self.service_name}] ✅ Run message consumer stopped gracefully")
|
||||
|
||||
@error_logged(swallow=True)
|
||||
@@ -1233,18 +1259,30 @@ class ExecutionManager(AppProcess):
|
||||
):
|
||||
delivery_tag = method.delivery_tag
|
||||
|
||||
@func_retry
|
||||
def _ack_message(reject: bool = False):
|
||||
"""Acknowledge or reject the message based on execution status."""
|
||||
if reject:
|
||||
channel.connection.add_callback_threadsafe(
|
||||
lambda: channel.basic_nack(delivery_tag, requeue=True)
|
||||
)
|
||||
else:
|
||||
channel.connection.add_callback_threadsafe(
|
||||
lambda: channel.basic_ack(delivery_tag)
|
||||
)
|
||||
|
||||
# Check if we're shutting down - reject new messages but keep connection alive
|
||||
if self.stop_consuming.is_set():
|
||||
logger.info(
|
||||
f"[{self.service_name}] Rejecting new execution during shutdown"
|
||||
)
|
||||
channel.basic_nack(delivery_tag, requeue=True)
|
||||
_ack_message(reject=True)
|
||||
return
|
||||
|
||||
# Check if we can accept more runs
|
||||
self._cleanup_completed_runs()
|
||||
if len(self.active_graph_runs) >= self.pool_size:
|
||||
channel.basic_nack(delivery_tag, requeue=True)
|
||||
_ack_message(reject=True)
|
||||
return
|
||||
|
||||
try:
|
||||
@@ -1253,7 +1291,7 @@ class ExecutionManager(AppProcess):
|
||||
logger.error(
|
||||
f"[{self.service_name}] Could not parse run message: {e}, body={body}"
|
||||
)
|
||||
channel.basic_nack(delivery_tag, requeue=False)
|
||||
_ack_message(reject=True)
|
||||
return
|
||||
|
||||
graph_exec_id = graph_exec_entry.graph_exec_id
|
||||
@@ -1265,7 +1303,7 @@ class ExecutionManager(AppProcess):
|
||||
logger.error(
|
||||
f"[{self.service_name}] Graph {graph_exec_id} already running; rejecting duplicate run."
|
||||
)
|
||||
channel.basic_nack(delivery_tag, requeue=False)
|
||||
_ack_message(reject=True)
|
||||
return
|
||||
|
||||
cancel_event = multiprocessing.Manager().Event()
|
||||
@@ -1283,23 +1321,9 @@ class ExecutionManager(AppProcess):
|
||||
logger.error(
|
||||
f"[{self.service_name}] Execution for {graph_exec_id} failed: {type(exec_error)} {exec_error}"
|
||||
)
|
||||
try:
|
||||
channel.connection.add_callback_threadsafe(
|
||||
lambda: channel.basic_nack(delivery_tag, requeue=True)
|
||||
)
|
||||
except Exception as ack_error:
|
||||
logger.error(
|
||||
f"[{self.service_name}] Failed to NACK message for {graph_exec_id}: {ack_error}"
|
||||
)
|
||||
_ack_message(reject=True)
|
||||
else:
|
||||
try:
|
||||
channel.connection.add_callback_threadsafe(
|
||||
lambda: channel.basic_ack(delivery_tag)
|
||||
)
|
||||
except Exception as ack_error:
|
||||
logger.error(
|
||||
f"[{self.service_name}] Failed to ACK message for {graph_exec_id}: {ack_error}"
|
||||
)
|
||||
_ack_message(reject=False)
|
||||
except BaseException as e:
|
||||
logger.exception(
|
||||
f"[{self.service_name}] Error in run completion callback: {e}"
|
||||
@@ -1326,7 +1350,7 @@ class ExecutionManager(AppProcess):
|
||||
def _update_prompt_metrics(self):
|
||||
active_count = len(self.active_graph_runs)
|
||||
active_runs_gauge.set(active_count)
|
||||
if self._stop_consuming and self._stop_consuming.is_set():
|
||||
if self.stop_consuming.is_set():
|
||||
utilization_gauge.set(1.0)
|
||||
else:
|
||||
utilization_gauge.set(active_count / self.pool_size)
|
||||
@@ -1337,8 +1361,15 @@ class ExecutionManager(AppProcess):
|
||||
logger.info(f"{prefix} 🧹 Starting graceful shutdown...")
|
||||
|
||||
# Signal the consumer thread to stop (thread-safe)
|
||||
self.stop_consuming.set()
|
||||
logger.info(f"{prefix} ✅ Signaled execution message consumer to stop")
|
||||
try:
|
||||
self.stop_consuming.set()
|
||||
run_channel = self.run_client.get_channel()
|
||||
run_channel.connection.add_callback_threadsafe(
|
||||
lambda: run_channel.stop_consuming()
|
||||
)
|
||||
logger.info(f"{prefix} ✅ Exec consumer has been signaled to stop")
|
||||
except Exception as e:
|
||||
logger.error(f"{prefix} ⚠️ Error signaling consumer to stop: {type(e)} {e}")
|
||||
|
||||
# Wait for active executions to complete
|
||||
if self.active_graph_runs:
|
||||
@@ -1371,34 +1402,34 @@ class ExecutionManager(AppProcess):
|
||||
else:
|
||||
logger.info(f"{prefix} ✅ All executions completed gracefully")
|
||||
|
||||
# NOW shutdown executor pool after all executions and cleanup are complete
|
||||
if self._executor:
|
||||
logger.info(f"{prefix} ⏳ Shutting down GraphExec pool...")
|
||||
try:
|
||||
# All active executions are done, safe to shutdown workers
|
||||
self._executor.shutdown(cancel_futures=True, wait=False)
|
||||
logger.info(f"{prefix} ✅ Executor shutdown completed")
|
||||
except Exception as e:
|
||||
logger.error(f"{prefix} ⚠️ Error during executor shutdown: {e}")
|
||||
# Shutdown the executor
|
||||
try:
|
||||
self.executor.shutdown(cancel_futures=True, wait=False)
|
||||
logger.info(f"{prefix} ✅ Executor shutdown completed")
|
||||
except Exception as e:
|
||||
logger.error(f"{prefix} ⚠️ Error during executor shutdown: {type(e)} {e}")
|
||||
|
||||
# Clean up RabbitMQ connections
|
||||
if self._cancel_client:
|
||||
logger.info(f"{prefix} ⏳ Disconnecting cancel RabbitMQ client...")
|
||||
try:
|
||||
self._cancel_client.disconnect()
|
||||
logger.info(f"{prefix} ✅ Cancel RabbitMQ client disconnected")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"{prefix} ⚠️ Error disconnecting cancel RabbitMQ client: {e}"
|
||||
)
|
||||
# Disconnect the run execution consumer
|
||||
try:
|
||||
run_channel = self.run_client.get_channel()
|
||||
run_channel.connection.add_callback_threadsafe(
|
||||
lambda: self.run_client.disconnect()
|
||||
)
|
||||
self.run_thread.join()
|
||||
logger.info(f"{prefix} ✅ Run client disconnected")
|
||||
except Exception as e:
|
||||
logger.error(f"{prefix} ⚠️ Error disconnecting run client: {type(e)} {e}")
|
||||
|
||||
if self._run_client:
|
||||
logger.info(f"{prefix} ⏳ Disconnecting run RabbitMQ client...")
|
||||
try:
|
||||
self._run_client.disconnect()
|
||||
logger.info(f"{prefix} ✅ Run RabbitMQ client disconnected")
|
||||
except Exception as e:
|
||||
logger.error(f"{prefix} ⚠️ Error disconnecting run RabbitMQ client: {e}")
|
||||
# Disconnect the cancel execution consumer
|
||||
try:
|
||||
cancel_channel = self.cancel_client.get_channel()
|
||||
cancel_channel.connection.add_callback_threadsafe(
|
||||
lambda: self.cancel_client.disconnect()
|
||||
)
|
||||
self.cancel_thread.join()
|
||||
logger.info(f"{prefix} ✅ Cancel client disconnected")
|
||||
except Exception as e:
|
||||
logger.error(f"{prefix} ⚠️ Error disconnecting cancel client: {type(e)} {e}")
|
||||
|
||||
logger.info(f"{prefix} ✅ Finished GraphExec cleanup")
|
||||
|
||||
@@ -1406,26 +1437,15 @@ class ExecutionManager(AppProcess):
|
||||
# ------- UTILITIES ------- #
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_db_client() -> "DatabaseManagerClient":
|
||||
from backend.executor import DatabaseManagerClient
|
||||
|
||||
# Disable health check for the service client to avoid breaking process initializer.
|
||||
return get_service_client(
|
||||
DatabaseManagerClient, health_check=False, request_retry=True
|
||||
)
|
||||
return get_database_manager_client()
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_db_async_client() -> "DatabaseManagerAsyncClient":
|
||||
from backend.executor import DatabaseManagerAsyncClient
|
||||
|
||||
# Disable health check for the service client to avoid breaking process initializer.
|
||||
return get_service_client(
|
||||
DatabaseManagerAsyncClient, health_check=False, request_retry=True
|
||||
)
|
||||
return get_database_manager_async_client()
|
||||
|
||||
|
||||
@func_retry
|
||||
async def send_async_execution_update(
|
||||
entry: GraphExecution | NodeExecutionResult | None,
|
||||
) -> None:
|
||||
@@ -1434,6 +1454,7 @@ async def send_async_execution_update(
|
||||
await get_async_execution_event_bus().publish(entry)
|
||||
|
||||
|
||||
@func_retry
|
||||
def send_execution_update(entry: GraphExecution | NodeExecutionResult | None):
|
||||
if entry is None:
|
||||
return
|
||||
|
||||
@@ -93,6 +93,7 @@ async def _execute_graph(**kwargs):
|
||||
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
# TODO: We need to communicate this error to the user somehow.
|
||||
logger.error(f"Error executing graph {args.graph_id}: {e}")
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import pytest
|
||||
|
||||
from backend.data import db
|
||||
from backend.executor.scheduler import SchedulerClient
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.usecases.sample import create_test_graph, create_test_user
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
|
||||
@@ -17,7 +16,7 @@ async def test_agent_schedule(server: SpinTestServer):
|
||||
user_id=test_user.id,
|
||||
)
|
||||
|
||||
scheduler = get_service_client(SchedulerClient)
|
||||
scheduler = get_scheduler_client()
|
||||
schedules = await scheduler.get_execution_schedules(test_graph.id, test_user.id)
|
||||
assert len(schedules) == 0
|
||||
|
||||
|
||||
@@ -4,9 +4,8 @@ import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import Future
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from pydantic import BaseModel, JsonValue, ValidationError
|
||||
|
||||
from backend.data import execution as execution_db
|
||||
@@ -16,33 +15,25 @@ from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.cost import BlockCostType
|
||||
from backend.data.db import prisma
|
||||
from backend.data.execution import (
|
||||
AsyncRedisExecutionEventBus,
|
||||
ExecutionStatus,
|
||||
GraphExecutionStats,
|
||||
GraphExecutionWithNodes,
|
||||
RedisExecutionEventBus,
|
||||
)
|
||||
from backend.data.graph import GraphModel, Node
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.rabbitmq import (
|
||||
AsyncRabbitMQ,
|
||||
Exchange,
|
||||
ExchangeType,
|
||||
Queue,
|
||||
RabbitMQConfig,
|
||||
SyncRabbitMQ,
|
||||
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||
from backend.util.clients import (
|
||||
get_async_execution_event_bus,
|
||||
get_async_execution_queue,
|
||||
get_database_manager_async_client,
|
||||
get_integration_credentials_store,
|
||||
)
|
||||
from backend.util.exceptions import GraphValidationError, NotFoundError
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.mock import MockObject
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Config
|
||||
from backend.util.type import convert
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
|
||||
config = Config()
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[GraphExecutorUtil]")
|
||||
|
||||
@@ -79,51 +70,6 @@ class LogMetadata(TruncatedLogger):
|
||||
)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_execution_event_bus() -> RedisExecutionEventBus:
|
||||
return RedisExecutionEventBus()
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_async_execution_event_bus() -> AsyncRedisExecutionEventBus:
|
||||
return AsyncRedisExecutionEventBus()
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_execution_queue() -> SyncRabbitMQ:
|
||||
client = SyncRabbitMQ(create_execution_queue_config())
|
||||
client.connect()
|
||||
return client
|
||||
|
||||
|
||||
@thread_cached
|
||||
async def get_async_execution_queue() -> AsyncRabbitMQ:
|
||||
client = AsyncRabbitMQ(create_execution_queue_config())
|
||||
await client.connect()
|
||||
return client
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_integration_credentials_store() -> "IntegrationCredentialsStore":
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
|
||||
return IntegrationCredentialsStore()
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_db_client() -> "DatabaseManagerClient":
|
||||
from backend.executor import DatabaseManagerClient
|
||||
|
||||
return get_service_client(DatabaseManagerClient)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_db_async_client() -> "DatabaseManagerAsyncClient":
|
||||
from backend.executor import DatabaseManagerAsyncClient
|
||||
|
||||
return get_service_client(DatabaseManagerAsyncClient)
|
||||
|
||||
|
||||
# ============ Execution Cost Helpers ============ #
|
||||
|
||||
|
||||
@@ -450,7 +396,7 @@ def validate_exec(
|
||||
# Last validation: Validate the input values against the schema.
|
||||
if error := schema.get_mismatch_error(data):
|
||||
error_message = f"{error_prefix} {error}"
|
||||
logger.error(error_message)
|
||||
logger.warning(error_message)
|
||||
return None, error_message
|
||||
|
||||
return data, node_block.name
|
||||
@@ -685,11 +631,6 @@ def _merge_nodes_input_masks(
|
||||
|
||||
# ============ Execution Queue Helpers ============ #
|
||||
|
||||
|
||||
class CancelExecutionEvent(BaseModel):
|
||||
graph_exec_id: str
|
||||
|
||||
|
||||
GRAPH_EXECUTION_EXCHANGE = Exchange(
|
||||
name="graph_execution",
|
||||
type=ExchangeType.DIRECT,
|
||||
@@ -750,6 +691,10 @@ def create_execution_queue_config() -> RabbitMQConfig:
|
||||
)
|
||||
|
||||
|
||||
class CancelExecutionEvent(BaseModel):
    """Message payload naming a single graph execution to cancel.

    `stop_graph_execution` serialises an instance with `model_dump_json()`
    and publishes it through the async execution queue client.
    """

    # ID of the graph execution the cancel request refers to.
    graph_exec_id: str
|
||||
|
||||
|
||||
async def stop_graph_execution(
|
||||
user_id: str,
|
||||
graph_exec_id: str,
|
||||
@@ -763,7 +708,7 @@ async def stop_graph_execution(
|
||||
3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
|
||||
"""
|
||||
queue_client = await get_async_execution_queue()
|
||||
db = execution_db if prisma.is_connected() else get_db_async_client()
|
||||
db = execution_db if prisma.is_connected() else get_database_manager_async_client()
|
||||
await queue_client.publish_message(
|
||||
routing_key="",
|
||||
message=CancelExecutionEvent(graph_exec_id=graph_exec_id).model_dump_json(),
|
||||
@@ -849,8 +794,8 @@ async def add_graph_execution(
|
||||
gdb = graph_db
|
||||
edb = execution_db
|
||||
else:
|
||||
gdb = get_db_async_client()
|
||||
edb = get_db_async_client()
|
||||
gdb = get_database_manager_async_client()
|
||||
edb = get_database_manager_async_client()
|
||||
|
||||
graph: GraphModel | None = await gdb.get_graph(
|
||||
graph_id=graph_id,
|
||||
@@ -903,7 +848,7 @@ async def add_graph_execution(
|
||||
except BaseException as e:
|
||||
err = str(e) or type(e).__name__
|
||||
if not graph_exec:
|
||||
logger.error(f"Graph execution #{graph_id} failed: {err}")
|
||||
logger.error(f"Unable to execute graph #{graph_id} failed: {err}")
|
||||
raise
|
||||
|
||||
logger.error(
|
||||
|
||||
@@ -5,7 +5,6 @@ from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from autogpt_libs.utils.synchronize import AsyncRedisKeyedMutex
|
||||
from pydantic import SecretStr
|
||||
|
||||
@@ -229,17 +228,15 @@ class IntegrationCredentialsStore:
|
||||
return self._locks
|
||||
|
||||
@property
|
||||
@thread_cached
|
||||
def db_manager(self):
|
||||
if prisma.is_connected():
|
||||
from backend.data import user
|
||||
|
||||
return user
|
||||
else:
|
||||
from backend.executor.database import DatabaseManagerAsyncClient
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
return get_service_client(DatabaseManagerAsyncClient)
|
||||
return get_database_manager_async_client()
|
||||
|
||||
# =============== USER-MANAGED CREDENTIALS =============== #
|
||||
async def add_creds(self, user_id: str, credentials: Credentials) -> None:
|
||||
|
||||
@@ -8,10 +8,11 @@ from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import get_block
|
||||
from backend.data.execution import ExecutionStatus, NodeExecutionResult
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.notifications.notifications import NotificationManagerClient
|
||||
from backend.util.clients import (
|
||||
get_database_manager_client,
|
||||
get_notification_manager_client,
|
||||
)
|
||||
from backend.util.metrics import sentry_capture_error
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -40,7 +41,7 @@ class BlockErrorMonitor:
|
||||
|
||||
def __init__(self, include_top_blocks: int | None = None):
|
||||
self.config = config
|
||||
self.notification_client = get_service_client(NotificationManagerClient)
|
||||
self.notification_client = get_notification_manager_client()
|
||||
self.include_top_blocks = (
|
||||
include_top_blocks
|
||||
if include_top_blocks is not None
|
||||
@@ -107,7 +108,7 @@ class BlockErrorMonitor:
|
||||
) -> dict[str, BlockStatsWithSamples]:
|
||||
"""Get block execution stats using efficient SQL aggregation."""
|
||||
|
||||
result = execution_utils.get_db_client().get_block_error_stats(
|
||||
result = get_database_manager_client().get_block_error_stats(
|
||||
start_time, end_time
|
||||
)
|
||||
|
||||
@@ -197,7 +198,7 @@ class BlockErrorMonitor:
|
||||
) -> list[str]:
|
||||
"""Get error samples for a specific block - just a few recent ones."""
|
||||
# Only fetch a small number of recent failed executions for this specific block
|
||||
executions = execution_utils.get_db_client().get_node_executions(
|
||||
executions = get_database_manager_client().get_node_executions(
|
||||
block_ids=[block_id],
|
||||
statuses=[ExecutionStatus.FAILED],
|
||||
created_time_gte=start_time,
|
||||
|
||||
@@ -4,10 +4,11 @@ import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.notifications.notifications import NotificationManagerClient
|
||||
from backend.util.clients import (
|
||||
get_database_manager_client,
|
||||
get_notification_manager_client,
|
||||
)
|
||||
from backend.util.metrics import sentry_capture_error
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -25,13 +26,13 @@ class LateExecutionMonitor:
|
||||
|
||||
def __init__(self):
|
||||
self.config = config
|
||||
self.notification_client = get_service_client(NotificationManagerClient)
|
||||
self.notification_client = get_notification_manager_client()
|
||||
|
||||
def check_late_executions(self) -> str:
|
||||
"""Check for late executions and send alerts if found."""
|
||||
|
||||
# Check for QUEUED executions
|
||||
queued_late_executions = execution_utils.get_db_client().get_graph_executions(
|
||||
queued_late_executions = get_database_manager_client().get_graph_executions(
|
||||
statuses=[ExecutionStatus.QUEUED],
|
||||
created_time_gte=datetime.now(timezone.utc)
|
||||
- timedelta(
|
||||
@@ -43,7 +44,7 @@ class LateExecutionMonitor:
|
||||
)
|
||||
|
||||
# Check for RUNNING executions stuck for more than 24 hours
|
||||
running_late_executions = execution_utils.get_db_client().get_graph_executions(
|
||||
running_late_executions = get_database_manager_client().get_graph_executions(
|
||||
statuses=[ExecutionStatus.RUNNING],
|
||||
created_time_gte=datetime.now(timezone.utc)
|
||||
- timedelta(hours=24)
|
||||
|
||||
@@ -2,12 +2,10 @@
|
||||
|
||||
import logging
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from prisma.enums import NotificationType
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.notifications.notifications import NotificationManagerClient
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.clients import get_notification_manager_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -17,11 +15,6 @@ class NotificationJobArgs(BaseModel):
|
||||
cron: str
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_notification_manager_client():
|
||||
return get_service_client(NotificationManagerClient)
|
||||
|
||||
|
||||
def process_existing_batches(**kwargs):
|
||||
"""Process existing notification batches."""
|
||||
args = NotificationJobArgs(**kwargs)
|
||||
|
||||
@@ -5,7 +5,6 @@ from datetime import datetime, timedelta, timezone
|
||||
from typing import Callable
|
||||
|
||||
import aio_pika
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from prisma.enums import NotificationType
|
||||
|
||||
from backend.data import rabbitmq
|
||||
@@ -26,26 +25,14 @@ from backend.data.notifications import (
|
||||
get_notif_data_type,
|
||||
get_summary_params_type,
|
||||
)
|
||||
from backend.data.rabbitmq import (
|
||||
AsyncRabbitMQ,
|
||||
Exchange,
|
||||
ExchangeType,
|
||||
Queue,
|
||||
RabbitMQConfig,
|
||||
SyncRabbitMQ,
|
||||
)
|
||||
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||
from backend.data.user import generate_unsubscribe_link
|
||||
from backend.notifications.email import EmailSender
|
||||
from backend.util.clients import get_database_manager_client
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.metrics import discord_send_alert
|
||||
from backend.util.retry import continuous_retry
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
AppServiceClient,
|
||||
endpoint_to_sync,
|
||||
expose,
|
||||
get_service_client,
|
||||
)
|
||||
from backend.util.service import AppService, AppServiceClient, endpoint_to_sync, expose
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), "[NotificationManager]")
|
||||
@@ -116,27 +103,6 @@ def create_notification_config() -> RabbitMQConfig:
|
||||
)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_db():
|
||||
from backend.executor.database import DatabaseManagerClient
|
||||
|
||||
return get_service_client(DatabaseManagerClient)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_notification_queue() -> SyncRabbitMQ:
|
||||
client = SyncRabbitMQ(create_notification_config())
|
||||
client.connect()
|
||||
return client
|
||||
|
||||
|
||||
@thread_cached
|
||||
async def get_async_notification_queue() -> AsyncRabbitMQ:
|
||||
client = AsyncRabbitMQ(create_notification_config())
|
||||
await client.connect()
|
||||
return client
|
||||
|
||||
|
||||
def get_routing_key(event_type: NotificationType) -> str:
|
||||
strategy = NotificationTypeOverride(event_type).strategy
|
||||
"""Get the appropriate routing key for an event"""
|
||||
@@ -161,6 +127,8 @@ def queue_notification(event: NotificationEventModel) -> NotificationResult:
|
||||
exchange = "notifications"
|
||||
routing_key = get_routing_key(event.type)
|
||||
|
||||
from backend.util.clients import get_notification_queue
|
||||
|
||||
queue = get_notification_queue()
|
||||
queue.publish_message(
|
||||
routing_key=routing_key,
|
||||
@@ -186,6 +154,8 @@ async def queue_notification_async(event: NotificationEventModel) -> Notificatio
|
||||
exchange = "notifications"
|
||||
routing_key = get_routing_key(event.type)
|
||||
|
||||
from backend.util.clients import get_async_notification_queue
|
||||
|
||||
queue = await get_async_notification_queue()
|
||||
await queue.publish_message(
|
||||
routing_key=routing_key,
|
||||
@@ -241,7 +211,7 @@ class NotificationManager(AppService):
|
||||
processed_count = 0
|
||||
current_time = datetime.now(tz=timezone.utc)
|
||||
start_time = current_time - timedelta(days=7)
|
||||
users = get_db().get_active_user_ids_in_timerange(
|
||||
users = get_database_manager_client().get_active_user_ids_in_timerange(
|
||||
end_time=current_time.isoformat(),
|
||||
start_time=start_time.isoformat(),
|
||||
)
|
||||
@@ -275,14 +245,14 @@ class NotificationManager(AppService):
|
||||
|
||||
for notification_type in notification_types:
|
||||
# Get all batches for this notification type
|
||||
batches = get_db().get_all_batches_by_type(notification_type)
|
||||
batches = get_database_manager_client().get_all_batches_by_type(
|
||||
notification_type
|
||||
)
|
||||
|
||||
for batch in batches:
|
||||
# Check if batch has aged out
|
||||
oldest_message = (
|
||||
get_db().get_user_notification_oldest_message_in_batch(
|
||||
batch.user_id, notification_type
|
||||
)
|
||||
oldest_message = get_database_manager_client().get_user_notification_oldest_message_in_batch(
|
||||
batch.user_id, notification_type
|
||||
)
|
||||
|
||||
if not oldest_message:
|
||||
@@ -296,7 +266,11 @@ class NotificationManager(AppService):
|
||||
|
||||
# If batch has aged out, process it
|
||||
if oldest_message.created_at + max_delay < current_time:
|
||||
recipient_email = get_db().get_user_email_by_id(batch.user_id)
|
||||
recipient_email = (
|
||||
get_database_manager_client().get_user_email_by_id(
|
||||
batch.user_id
|
||||
)
|
||||
)
|
||||
|
||||
if not recipient_email:
|
||||
logger.error(
|
||||
@@ -313,13 +287,15 @@ class NotificationManager(AppService):
|
||||
f"User {batch.user_id} does not want to receive {notification_type} notifications"
|
||||
)
|
||||
# Clear the batch
|
||||
get_db().empty_user_notification_batch(
|
||||
get_database_manager_client().empty_user_notification_batch(
|
||||
batch.user_id, notification_type
|
||||
)
|
||||
continue
|
||||
|
||||
batch_data = get_db().get_user_notification_batch(
|
||||
batch.user_id, notification_type
|
||||
batch_data = (
|
||||
get_database_manager_client().get_user_notification_batch(
|
||||
batch.user_id, notification_type
|
||||
)
|
||||
)
|
||||
|
||||
if not batch_data or not batch_data.notifications:
|
||||
@@ -327,7 +303,7 @@ class NotificationManager(AppService):
|
||||
f"Batch data not found for user {batch.user_id}"
|
||||
)
|
||||
# Clear the batch
|
||||
get_db().empty_user_notification_batch(
|
||||
get_database_manager_client().empty_user_notification_batch(
|
||||
batch.user_id, notification_type
|
||||
)
|
||||
continue
|
||||
@@ -363,7 +339,7 @@ class NotificationManager(AppService):
|
||||
)
|
||||
|
||||
# Clear the batch
|
||||
get_db().empty_user_notification_batch(
|
||||
get_database_manager_client().empty_user_notification_batch(
|
||||
batch.user_id, notification_type
|
||||
)
|
||||
|
||||
@@ -412,9 +388,11 @@ class NotificationManager(AppService):
|
||||
self, user_id: str, event_type: NotificationType
|
||||
) -> bool:
|
||||
"""Check if a user wants to receive a notification based on their preferences and email verification status"""
|
||||
validated_email = get_db().get_user_email_verification(user_id)
|
||||
validated_email = get_database_manager_client().get_user_email_verification(
|
||||
user_id
|
||||
)
|
||||
preference = (
|
||||
get_db()
|
||||
get_database_manager_client()
|
||||
.get_user_notification_preference(user_id)
|
||||
.preferences.get(event_type, True)
|
||||
)
|
||||
@@ -505,10 +483,14 @@ class NotificationManager(AppService):
|
||||
self, user_id: str, event_type: NotificationType, event: NotificationEventModel
|
||||
) -> bool:
|
||||
|
||||
get_db().create_or_add_to_user_notification_batch(user_id, event_type, event)
|
||||
get_database_manager_client().create_or_add_to_user_notification_batch(
|
||||
user_id, event_type, event
|
||||
)
|
||||
|
||||
oldest_message = get_db().get_user_notification_oldest_message_in_batch(
|
||||
user_id, event_type
|
||||
oldest_message = (
|
||||
get_database_manager_client().get_user_notification_oldest_message_in_batch(
|
||||
user_id, event_type
|
||||
)
|
||||
)
|
||||
if not oldest_message:
|
||||
logger.error(
|
||||
@@ -559,7 +541,9 @@ class NotificationManager(AppService):
|
||||
return False
|
||||
logger.debug(f"Processing immediate notification: {event}")
|
||||
|
||||
recipient_email = get_db().get_user_email_by_id(event.user_id)
|
||||
recipient_email = get_database_manager_client().get_user_email_by_id(
|
||||
event.user_id
|
||||
)
|
||||
if not recipient_email:
|
||||
logger.error(f"User email not found for user {event.user_id}")
|
||||
return False
|
||||
@@ -594,7 +578,9 @@ class NotificationManager(AppService):
|
||||
return False
|
||||
logger.info(f"Processing batch notification: {event}")
|
||||
|
||||
recipient_email = get_db().get_user_email_by_id(event.user_id)
|
||||
recipient_email = get_database_manager_client().get_user_email_by_id(
|
||||
event.user_id
|
||||
)
|
||||
if not recipient_email:
|
||||
logger.error(f"User email not found for user {event.user_id}")
|
||||
return False
|
||||
@@ -613,7 +599,9 @@ class NotificationManager(AppService):
|
||||
if not should_send:
|
||||
logger.info("Batch not old enough to send")
|
||||
return False
|
||||
batch = get_db().get_user_notification_batch(event.user_id, event.type)
|
||||
batch = get_database_manager_client().get_user_notification_batch(
|
||||
event.user_id, event.type
|
||||
)
|
||||
if not batch or not batch.notifications:
|
||||
logger.error(f"Batch not found for user {event.user_id}")
|
||||
return False
|
||||
@@ -714,7 +702,9 @@ class NotificationManager(AppService):
|
||||
logger.info(
|
||||
f"Successfully sent all {successfully_sent_count} notifications, clearing batch"
|
||||
)
|
||||
get_db().empty_user_notification_batch(event.user_id, event.type)
|
||||
get_database_manager_client().empty_user_notification_batch(
|
||||
event.user_id, event.type
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Only sent {successfully_sent_count} of {len(batch_messages)} notifications. "
|
||||
@@ -736,7 +726,9 @@ class NotificationManager(AppService):
|
||||
|
||||
logger.info(f"Processing summary notification: {model}")
|
||||
|
||||
recipient_email = get_db().get_user_email_by_id(event.user_id)
|
||||
recipient_email = get_database_manager_client().get_user_email_by_id(
|
||||
event.user_id
|
||||
)
|
||||
if not recipient_email:
|
||||
logger.error(f"User email not found for user {event.user_id}")
|
||||
return False
|
||||
|
||||
@@ -9,7 +9,6 @@ import pydantic
|
||||
import stripe
|
||||
from autogpt_libs.auth.middleware import auth_middleware
|
||||
from autogpt_libs.feature_flag.client import feature_flag
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Body,
|
||||
@@ -51,7 +50,6 @@ from backend.data.credit import (
|
||||
get_user_credit_model,
|
||||
set_auto_top_up,
|
||||
)
|
||||
from backend.data.execution import AsyncRedisExecutionEventBus
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||
from backend.data.onboarding import (
|
||||
@@ -84,18 +82,13 @@ from backend.server.model import (
|
||||
UploadFileResponse,
|
||||
)
|
||||
from backend.server.utils import get_user_id
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.cloud_storage import get_cloud_storage_handler
|
||||
from backend.util.exceptions import GraphValidationError, NotFoundError
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
|
||||
|
||||
@thread_cached
|
||||
def execution_scheduler_client() -> scheduler.SchedulerClient:
|
||||
return get_service_client(scheduler.SchedulerClient, health_check=False)
|
||||
|
||||
|
||||
def _create_file_size_error(size_bytes: int, max_size_mb: int) -> HTTPException:
|
||||
"""Create standardized file size error response."""
|
||||
return HTTPException(
|
||||
@@ -104,11 +97,6 @@ def _create_file_size_error(size_bytes: int, max_size_mb: int) -> HTTPException:
|
||||
)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def execution_event_bus() -> AsyncRedisExecutionEventBus:
|
||||
return AsyncRedisExecutionEventBus()
|
||||
|
||||
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -924,7 +912,7 @@ async def create_graph_execution_schedule(
|
||||
detail=f"Graph #{graph_id} v{schedule_params.graph_version} not found.",
|
||||
)
|
||||
|
||||
return await execution_scheduler_client().add_execution_schedule(
|
||||
return await get_scheduler_client().add_execution_schedule(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
@@ -945,7 +933,7 @@ async def list_graph_execution_schedules(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
graph_id: str = Path(),
|
||||
) -> list[scheduler.GraphExecutionJobInfo]:
|
||||
return await execution_scheduler_client().get_execution_schedules(
|
||||
return await get_scheduler_client().get_execution_schedules(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
)
|
||||
@@ -960,7 +948,7 @@ async def list_graph_execution_schedules(
|
||||
async def list_all_graphs_execution_schedules(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> list[scheduler.GraphExecutionJobInfo]:
|
||||
return await execution_scheduler_client().get_execution_schedules(user_id=user_id)
|
||||
return await get_scheduler_client().get_execution_schedules(user_id=user_id)
|
||||
|
||||
|
||||
@v1_router.delete(
|
||||
@@ -974,7 +962,7 @@ async def delete_graph_execution_schedule(
|
||||
schedule_id: str = Path(..., description="ID of the schedule to delete"),
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
await execution_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
|
||||
await get_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
|
||||
except NotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_404_NOT_FOUND,
|
||||
|
||||
134
autogpt_platform/backend/backend/util/clients.py
Normal file
134
autogpt_platform/backend/backend/util/clients.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""
|
||||
Centralized service client helpers with thread caching.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.execution import (
|
||||
AsyncRedisExecutionEventBus,
|
||||
RedisExecutionEventBus,
|
||||
)
|
||||
from backend.data.rabbitmq import AsyncRabbitMQ, SyncRabbitMQ
|
||||
from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
|
||||
from backend.executor.scheduler import SchedulerClient
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
from backend.notifications.notifications import NotificationManagerClient
|
||||
|
||||
|
||||
@thread_cached
def get_database_manager_client() -> "DatabaseManagerClient":
    """Return the thread-cached sync DatabaseManagerClient (request retry on).

    Both imports happen at call time instead of module load time
    (presumably to keep this module free of import cycles - confirm).
    """
    from backend.executor import DatabaseManagerClient
    from backend.util.service import get_service_client

    db_client = get_service_client(DatabaseManagerClient, request_retry=True)
    return db_client
|
||||
|
||||
|
||||
@thread_cached
def get_database_manager_async_client() -> "DatabaseManagerAsyncClient":
    """Return the thread-cached async DatabaseManagerAsyncClient (request retry on).

    Both imports happen at call time instead of module load time
    (presumably to keep this module free of import cycles - confirm).
    """
    from backend.executor import DatabaseManagerAsyncClient
    from backend.util.service import get_service_client

    db_client = get_service_client(DatabaseManagerAsyncClient, request_retry=True)
    return db_client
|
||||
|
||||
|
||||
@thread_cached
def get_scheduler_client() -> "SchedulerClient":
    """Return the thread-cached SchedulerClient.

    The client is built through `get_service_client` with that function's
    default arguments; the imports are deferred to call time.
    """
    from backend.executor.scheduler import SchedulerClient
    from backend.util.service import get_service_client

    scheduler_client = get_service_client(SchedulerClient)
    return scheduler_client
|
||||
|
||||
|
||||
@thread_cached
def get_notification_manager_client() -> "NotificationManagerClient":
    """Return the thread-cached NotificationManagerClient.

    The client is built through `get_service_client` with that function's
    default arguments; the imports are deferred to call time.
    """
    from backend.notifications.notifications import NotificationManagerClient
    from backend.util.service import get_service_client

    notification_client = get_service_client(NotificationManagerClient)
    return notification_client
|
||||
|
||||
|
||||
# ============ Execution Event Bus Helpers ============ #
|
||||
|
||||
|
||||
@thread_cached
def get_execution_event_bus() -> "RedisExecutionEventBus":
    """Return the thread-cached RedisExecutionEventBus.

    The bus is constructed with no arguments; its class is imported at
    call time rather than at module load.
    """
    from backend.data.execution import RedisExecutionEventBus

    event_bus = RedisExecutionEventBus()
    return event_bus
|
||||
|
||||
|
||||
@thread_cached
def get_async_execution_event_bus() -> "AsyncRedisExecutionEventBus":
    """Return the thread-cached AsyncRedisExecutionEventBus.

    The bus is constructed with no arguments; its class is imported at
    call time rather than at module load.
    """
    from backend.data.execution import AsyncRedisExecutionEventBus

    event_bus = AsyncRedisExecutionEventBus()
    return event_bus
|
||||
|
||||
|
||||
# ============ Execution Queue Helpers ============ #
|
||||
|
||||
|
||||
@thread_cached
def get_execution_queue() -> "SyncRabbitMQ":
    """Return the thread-cached, already-connected sync execution queue client.

    The client is configured from `create_execution_queue_config()` and
    `connect()` is called on it before it is handed back.
    """
    from backend.data.rabbitmq import SyncRabbitMQ
    from backend.executor.utils import create_execution_queue_config

    queue_config = create_execution_queue_config()
    queue_client = SyncRabbitMQ(queue_config)
    queue_client.connect()
    return queue_client
|
||||
|
||||
|
||||
@thread_cached
async def get_async_execution_queue() -> "AsyncRabbitMQ":
    """Return the thread-cached, already-connected async execution queue client.

    The client is configured from `create_execution_queue_config()` and
    its `connect()` coroutine is awaited before it is handed back.
    """
    from backend.data.rabbitmq import AsyncRabbitMQ
    from backend.executor.utils import create_execution_queue_config

    queue_config = create_execution_queue_config()
    queue_client = AsyncRabbitMQ(queue_config)
    await queue_client.connect()
    return queue_client
|
||||
|
||||
|
||||
# ============ Integration Credentials Store ============ #
|
||||
|
||||
|
||||
@thread_cached
def get_integration_credentials_store() -> "IntegrationCredentialsStore":
    """Return the thread-cached IntegrationCredentialsStore.

    The store is constructed with no arguments; its class is imported at
    call time rather than at module load.
    """
    from backend.integrations.credentials_store import IntegrationCredentialsStore

    credentials_store = IntegrationCredentialsStore()
    return credentials_store
|
||||
|
||||
|
||||
# ============ Notification Queue Helpers ============ #
|
||||
|
||||
|
||||
@thread_cached
def get_notification_queue() -> "SyncRabbitMQ":
    """Return the thread-cached, already-connected sync notification queue client.

    The client is configured from `create_notification_config()` and
    `connect()` is called on it before it is handed back.
    """
    from backend.data.rabbitmq import SyncRabbitMQ
    from backend.notifications.notifications import create_notification_config

    queue_config = create_notification_config()
    queue_client = SyncRabbitMQ(queue_config)
    queue_client.connect()
    return queue_client
|
||||
|
||||
|
||||
@thread_cached
async def get_async_notification_queue() -> "AsyncRabbitMQ":
    """Return the thread-cached, already-connected async notification queue client.

    The client is configured from `create_notification_config()` and its
    `connect()` coroutine is awaited before it is handed back.
    """
    from backend.data.rabbitmq import AsyncRabbitMQ
    from backend.notifications.notifications import create_notification_config

    queue_config = create_notification_config()
    queue_client = AsyncRabbitMQ(queue_config)
    await queue_client.connect()
    return queue_client
|
||||
@@ -75,7 +75,7 @@ class AppProcess(ABC):
|
||||
self.run()
|
||||
except BaseException as e:
|
||||
logger.warning(
|
||||
f"[{self.service_name}] Termination request: {type(e).__name__}; executing cleanup."
|
||||
f"[{self.service_name}] Termination request: {type(e).__name__}; {e} executing cleanup."
|
||||
)
|
||||
finally:
|
||||
self.cleanup()
|
||||
|
||||
@@ -161,14 +161,21 @@ def continuous_retry(*, retry_delay: float = 1.0):
|
||||
|
||||
@wraps(func)
def sync_wrapper(*args, **kwargs):
    """Call ``func`` over and over until it returns without raising ``Exception``.

    Each failure is counted and logged, then the wrapper sleeps for
    ``retry_delay`` seconds and tries again. Every tenth failure is logged
    through ``logger.exception``; all others go through ``logger.warning``.
    """
    failures = 0
    while True:
        try:
            return func(*args, **kwargs)
        except Exception as exc:
            failures += 1
            # Every tenth consecutive failure goes through logger.exception.
            emit = logger.warning if failures % 10 else logger.exception
            reason = str(exc) or type(exc).__name__
            emit(
                "%s failed for the %s times, error: [%s] — retrying in %.2fs",
                func.__name__,
                failures,
                reason,
                retry_delay,
            )
            time.sleep(retry_delay)
|
||||
@@ -176,13 +183,20 @@ def continuous_retry(*, retry_delay: float = 1.0):
|
||||
@wraps(func)
async def async_wrapper(*args, **kwargs):
    """Await ``func`` over and over until it returns without raising ``Exception``.

    Each failure is counted and logged, then the wrapper awaits
    ``asyncio.sleep(retry_delay)`` and tries again. Every tenth failure is
    logged through ``logger.exception``; all others go through
    ``logger.warning``.
    """
    # The counter must live outside the loop. When it was reset at the top
    # of every iteration it was always 1 at logging time, so the logged
    # attempt number was wrong and the every-tenth-failure branch below
    # could never run. This now matches the sync wrapper.
    counter = 0
    while True:
        try:
            return await func(*args, **kwargs)
        except Exception as exc:
            counter += 1
            if counter % 10 == 0:
                log = logger.exception
            else:
                log = logger.warning
            log(
                "%s failed for the %s times, error: [%s] — retrying in %.2fs",
                func.__name__,
                counter,
                str(exc) or type(exc).__name__,
                retry_delay,
            )
            await asyncio.sleep(retry_delay)
|
||||
|
||||
@@ -277,7 +277,6 @@ ASC = TypeVar("ASC", bound=AppServiceClient)
|
||||
def get_service_client(
|
||||
service_client_type: Type[ASC],
|
||||
call_timeout: int | None = api_call_timeout,
|
||||
health_check: bool = True,
|
||||
request_retry: bool = False,
|
||||
) -> ASC:
|
||||
|
||||
@@ -461,8 +460,6 @@ def get_service_client(
|
||||
return sync_method
|
||||
|
||||
client = cast(ASC, DynamicClient())
|
||||
if health_check and hasattr(client, "health_check"):
|
||||
client.health_check()
|
||||
|
||||
return client
|
||||
|
||||
|
||||
@@ -16,6 +16,12 @@ from backend.util.service import (
|
||||
TEST_SERVICE_PORT = 8765
|
||||
|
||||
|
||||
def wait_for_service_ready(service_client_type, timeout_seconds=30):
    """Wait for a service to be ready by calling its health check once.

    A client for ``service_client_type`` is built with ``request_retry=True``
    and ``health_check()`` is invoked on it a single time; any waiting comes
    from that client's own retry behaviour.

    NOTE(review): ``timeout_seconds`` is accepted but never referenced in the
    body, so it does not bound the wait here - confirm whether it should be
    forwarded to ``get_service_client`` (e.g. as a call timeout) or dropped.
    """
    client = get_service_client(service_client_type, request_retry=True)
    client.health_check()  # This will retry until service is ready
|
||||
|
||||
|
||||
class ServiceTest(AppService):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -28,6 +34,15 @@ class ServiceTest(AppService):
|
||||
def get_port(cls) -> int:
|
||||
return TEST_SERVICE_PORT
|
||||
|
||||
def __enter__(self):
    """Start the service via the parent ``__enter__``, wait for readiness, return its result.

    Readiness is checked by ``wait_for_service_ready(ServiceTestClient)``
    after the parent context manager has been entered.
    """
    entered = super().__enter__()  # starts the service
    wait_for_service_ready(ServiceTestClient)  # wait until it is ready
    return entered
|
||||
|
||||
@expose
|
||||
def add(self, a: int, b: int) -> int:
|
||||
return a + b
|
||||
@@ -48,13 +63,13 @@ class ServiceTest(AppService):
|
||||
"""Method that fails 2 times then succeeds - for testing retry logic"""
|
||||
self.fail_count += 1
|
||||
if self.fail_count <= 2:
|
||||
raise RuntimeError("Database connection failed")
|
||||
raise RuntimeError(f"Intended error for testing {self.fail_count}/2")
|
||||
return a + b
|
||||
|
||||
@expose
|
||||
def always_failing_add(self, a: int, b: int) -> int:
|
||||
"""Method that always fails - for testing no retry when disabled"""
|
||||
raise RuntimeError("Database connection failed")
|
||||
raise RuntimeError("Intended error for testing")
|
||||
|
||||
|
||||
class ServiceTestClient(AppServiceClient):
|
||||
@@ -349,5 +364,5 @@ def test_service_no_retry_when_disabled(server):
|
||||
client = get_service_client(ServiceTestClient, request_retry=False)
|
||||
|
||||
# This should fail immediately without retry
|
||||
with pytest.raises(RuntimeError, match="Database connection failed"):
|
||||
with pytest.raises(RuntimeError, match="Intended error for testing"):
|
||||
client.always_failing_add(5, 3)
|
||||
|
||||
Reference in New Issue
Block a user